### Vanilla GAN(Generative Adverserial Network) implementation using pytorch

#### Imports and dataset loading for mnist dataset,

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, transforms
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # GPU is not wirking, so using CPU
device = 'cpu'
print(device)

cpu


In [16]:
mnist_data = datasets.FashionMNIST(
    root= '../data',
    train= True,
    download= True,
    transform= ToTensor(),
    target_transform= Lambda(lambda y: torch.zeros(10, dtype= torch.float).scatter_(0, torch.tensor(y), value=1))
)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.FashionMNIST(root='../data', train= True, download= True, transform= transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size= 128, shuffle= True)


#### Model Generation code
* decalre the model with model

In [19]:
class Generator(nn.Module):
    def __init__(self, input_dim= 100, output_dim= 784):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
            nn.Tanh() #Because we normalized image to [-1, 1] range
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_dim= 784):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # Output probability
        )
        
    def forward(self, x):
        return self.model(x)
    

In [20]:
noise_dim = 100
image_dim = 28*28
#model initialization
generator = Generator(input_dim=noise_dim, output_dim=image_dim).to(device)
discriminator = Discriminator(input_dim=image_dim).to(device)

# Loss and Optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr= 0.0002)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr= 0.0002)

# Loss function
criterion = nn.BCELoss()

In [22]:
def show_generated_images(epoch, generator, fixed_noise):
    generator.eval()
    with torch.no_grad():
        fake_images = generator(fixed_noise).view(-1, 1, 28, 28)
        fake_images = fake_images * 0.5 + 0.5  # Rescale to [0, 1]
    grid = torchvision.utils.make_grid(fake_images, nrow= 8)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title(f'Epoch {epoch}')
    plt.axis('off')
    plt.show()

In [None]:
def train_generator(generator, train_loader, noise_dim, fixed_noise= None):
    if not  fixed_noise:
        fixed_noise = torch.randn(64, noise_dim).to(device)
    batch_size = train_loader.batch_size
    z = torch.randn()


In [None]:
def train_gan(generator, discriminator, data_loader, mode, num_epochs= 50, batch_size= 128, lr= 0.0002):
    fixed_noise = torch.randn(64, noise_dim).to(device)  # Fixed noise for generating images
    
    for epoch in range(num_epochs):
        for batch_idx, (real_images, _) in enumerate(data_loader):
            batch_size = real_images.size(0)
            real = real.view(batch_size, -1).to(device)  # Flatten images

            # create real and fake labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            ### ======================================================
            ### Train Disciminator
            ### ======================================================

            # real images 
            outputs = discriminator(real)
            d_loss_real = criterion(outputs, real_labels)
            real_score = outputs

            # fake images
            z = torch.randn(batch_size, noise_dim).to(device)
            fake = generator(z)
            outputs = discriminator(fake.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs

            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake

            discriminator.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            ### ======================================================
            ### Train Generator
            ### ======================================================
            z = torch.randn(batch_size, noise_dim).to(device)
            fake = generator(z)
            outputs = discriminator(fake)

            g_loss = criterion(outputs, real_labels)

            generator.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(data_loader)}], '
                      f'D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, '
                      f'D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')
            
            ### ======================================================
            ### 3. Modify according to mode
            ### ======================================================

            if mode == 'five_gen_one_disc':
                # four times generator train
                for _ in range(4):
                    z = torch.randn(batch_size, noise_dim).to(device)
                    fake = generator(z)
                    outputs = discriminator(fake)

                    g_loss = criterion(outputs, real_labels)
                    
                    generator.zero_grad()
                    g_loss.backward()
                    g_optimizer.step()
                
                # one time discriminator train
            elif mode == 'five_disc_one_gen':
                for _ in range(4):
                    real_labels = torch.ones(batch_size, 1).to(device)
                    outputs = discriminator(real)
                    real_score = outputs

                    d_loss_real = criterion(outputs, real_labels)

                    # fake
                    z = torch.randn(batch_size, noise_dim).to(device)
                    fake = generator(z)
                    fake_labels = torch.zeros(batch_size, 1).to(device)
                    outputs = discriminator(fake)
                    d_loss_fake = criterion(outputs, fake_labels)
                    fake_score = d_loss_fake

                    d_loss  = real_score + fake_score
                    discriminator.zero_grad()
                    d_loss.backward()
                    discriminator.step()
                else:
                    print('default mode selected')
                
            if (epoch + 1) % 5 ==0:
                print(f'Epoch[{epoch+1}/{num_epochs}], D_loss: {d_loss.items():.4f}, G_Loss: {g_loss.item():.4f}:')
                show_generated_images(epoch+1, generator, fixed_noise)









In [25]:
## standard training process
print("Training: 1-step gen 1-step disc")
train_gan(generator, discriminator, train_loader, mode= 'default', num_epochs= 50)

print('train 5-gen, 1-step discriminator')
train_gan(generator, discriminator, train_loader, mode='five_gen_one_disc', num_epochs=50)

print('train 5-disc, 1-gen')
train_gan(generator, discriminator, train_loader, mode='five_disc_one_gen', num_epochs=50)


Training: 1-step gen 1-step disc


UnboundLocalError: cannot access local variable 'real' where it is not associated with a value