In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [None]:
# Use the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [None]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Define the generator network
class Generator(nn.Module):
    """MLP generator mapping a noise vector of length ``input_dim`` to ``output_dim`` values.

    Hidden widths are 256 -> 512 -> 1024, each followed by LeakyReLU(0.2). The final
    Tanh squashes every output into [-1, 1], the same range the real MNIST images are
    normalised to by ``Normalize((0.5,), (0.5,))``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = [input_dim, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], output_dim), nn.Tanh()]
        # The attribute must stay ``fc`` (and the layer order unchanged) so that
        # state_dict keys remain "fc.<index>.weight/bias" for saved checkpoints.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)

In [None]:
# Define the discriminator network
class Discriminator(nn.Module):
    """MLP discriminator scoring a flattened input of length ``input_dim``.

    Hidden widths are 1024 -> 512 -> 256, each followed by LeakyReLU(0.2) and
    Dropout(0.3). A final Sigmoid maps the single output unit into [0, 1], the
    probability form that ``nn.BCELoss`` expects.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = []
        fan_in = input_dim
        for fan_out in (1024, 512, 256):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)])
            fan_in = fan_out
        layers.extend([nn.Linear(fan_in, 1), nn.Sigmoid()])
        # Keep the attribute name ``fc`` so state_dict keys stay "fc.<index>.*".
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)

In [None]:
# Hyperparameters
batch_size = 128
input_dim = 100      # length of the latent noise vector z fed to the generator
output_dim = 784     # flattened image size (28 * 28)
num_epochs = 100
learning_rate = 0.0002

# Load MNIST dataset. ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,))
# rescales them to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator on the selected device. Without .to(device)
# the `device` chosen earlier is never used and training silently runs on the CPU.
generator = Generator(input_dim, output_dim).to(device)
discriminator = Discriminator(output_dim).to(device)

# Define loss function and optimizers
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training loop
total_step = len(train_loader)

try:
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_loader):
            # Size of *this* batch (the last one of an epoch may be smaller). Kept in
            # its own name so the `batch_size` hyperparameter is not overwritten.
            cur_batch = images.size(0)

            # Flatten the images and move them to the training device
            images = images.view(cur_batch, -1).to(device)

            # Create labels for real and fake images
            real_labels = torch.ones(cur_batch, 1, device=device)
            fake_labels = torch.zeros(cur_batch, 1, device=device)

            # Train the discriminator: real images -> 1, generated images -> 0
            disc_optimizer.zero_grad()
            outputs = discriminator(images)
            real_loss = criterion(outputs, real_labels)
            real_score = outputs.detach()  # logging only; detach so the graph can be freed

            z = torch.randn(cur_batch, input_dim, device=device)
            fake_images = generator(z)
            # detach(): the discriminator update must not backpropagate into the generator
            outputs = discriminator(fake_images.detach())
            fake_loss = criterion(outputs, fake_labels)
            fake_score = outputs.detach()

            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optimizer.step()

            # Train the generator: its loss is low when D labels fresh samples as real
            gen_optimizer.zero_grad()
            z = torch.randn(cur_batch, input_dim, device=device)
            fake_images = generator(z)
            outputs = discriminator(fake_images)
            gen_loss = criterion(outputs, real_labels)

            gen_loss.backward()
            gen_optimizer.step()

            # Print training progress
            if (i+1) % 200 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], d_loss: {disc_loss.item():.4f}, g_loss: {gen_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")
finally:
    # Return both models to the CPU, even if training is interrupted, so that code
    # which later feeds them CPU tensors and the checkpoints written with torch.save
    # do not depend on a GPU being present.
    generator.to('cpu')
    discriminator.to('cpu')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 112660618.22it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 12326823.43it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25510763.29it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17820887.53it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Epoch [1/100], Step [200/469], d_loss: 0.2171, g_loss: 4.4883, D(x): 0.94, D(G(z)): 0.13
Epoch [1/100], Step [400/469], d_loss: 1.2546, g_loss: 0.5103, D(x): 0.93, D(G(z)): 0.68
Epoch [2/100], Step [200/469], d_loss: 0.3712, g_loss: 1.9822, D(x): 0.86, D(G(z)): 0.19
Epoch [2/100], Step [400/469], d_loss: 1.0824, g_loss: 1.0854, D(x): 0.66, D(G(z)): 0.45
Epoch [3/100], Step [200/469], d_loss: 0.3570, g_loss: 2.7302, D(x): 0.84, D(G(z)): 0.14
Epoch [3/100], Step [400/469], d_loss: 0.5669, g_loss: 2.8812, D(x): 0.81, D(G(z)): 0.22
Epoch [4/100], Step [200/469], d_loss: 1.4973, g_loss: 1.4512, D(x): 0.56, D(G(z)): 0.40
Epoch [4/100], Step [400/469], d_loss: 0.6208, g_loss: 1.6831, D(x): 0.87, D(G(z)): 0.32
Epoch [5/100], Step [200/469], d_loss: 0.8895, g_loss: 1.4811, D(x): 0.78, D(G(z)): 0.33
Epoch [5/100], Step [400/469], d_loss: 0.5546, g_loss: 1.9755, D(x): 0.84, D(G(z)): 0.24
Epoch [6/100], Step [200/469], d_loss: 

In [None]:
# Generate fake digits: one PNG per sample in Fake_Digits/, each paired with a .txt
# file containing the latent vector z (as "[v0,v1,...]") that produced it.
num_samples = 1          # the .txt writer below records only row 0 of z
num_fake_digits = 500
fake_dir = 'Fake_Digits'
os.makedirs(fake_dir, exist_ok=True)  # save_image/open fail if the folder is missing

# Sample noise on whichever device the generator's weights currently live on.
gen_device = next(generator.parameters()).device

with torch.no_grad():  # inference only: no autograd graph is needed
    for idx in range(num_fake_digits):
        z = torch.randn(num_samples, input_dim, device=gen_device)
        generated_images = generator(z).view(num_samples, 1, 28, 28)
        torchvision.utils.save_image(generated_images, os.path.join(fake_dir, str(idx) + '.png'), nrow=1, normalize=True)
        with open(os.path.join(fake_dir, str(idx) + '.txt'), 'w') as fo:
            fo.write('[' + ','.join(str(v) for v in z.tolist()[0]) + ']')

In [None]:
# Save Models using pickle (torch.save serialises with pickle).
# - *_dict.pkl hold only the weights: reload with model.load_state_dict(torch.load(path)).
# - G.pkl / D.pkl pickle the whole module objects: loading them needs the Generator /
#   Discriminator classes importable and, on PyTorch >= 2.6, torch.load(path, weights_only=False).
#   Only load such pickles from sources you trust (unpickling can execute code).
os.makedirs('Models', exist_ok=True)  # torch.save does not create missing directories
torch.save(generator.state_dict(), 'Models/G_dict.pkl')
torch.save(discriminator.state_dict(), 'Models/D_dict.pkl')
torch.save(generator, 'Models/G.pkl')
torch.save(discriminator, 'Models/D.pkl')

In [None]:
# Generate 10x10 (100) images and save them as a single grid (nrow=10 -> 10 per row)
num_samples = 100
# Create the noise on the same device as the generator's weights.
gen_device = next(generator.parameters()).device
with torch.no_grad():  # inference only: skip building an autograd graph
    z = torch.randn(num_samples, input_dim, device=gen_device)
    generated_images = generator(z).view(num_samples, 1, 28, 28)
torchvision.utils.save_image(generated_images, 'generated_images.png', nrow=10, normalize=True)