In [9]:
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Hyperparameters
mb_size = 64     # mini-batch size used by the DataLoader and when sampling
Z_dim = 1000     # length of the noise vector fed to the generator
h_dim = 128      # width of the single hidden layer in both networks
lr = 1e-3        # learning rate shared by both Adam optimizers
X_dim = 28 * 28  # 784: one flattened MNIST image

# MNIST training split; every image is converted to a tensor and then
# flattened to a 1-D vector so it can be fed straight into nn.Linear.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ]
)

train_dataset = datasets.MNIST(
    root='../MNIST',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=mb_size)

# Xavier Initialization
def xavier_init(m):
    """Initialise an ``nn.Linear`` in place; leave any other module untouched.

    The weight is drawn with Xavier-normal initialisation and the bias, when
    the layer has one, is set to zero. Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

# Generator
class Generator(nn.Module):
    """Two-layer MLP mapping a noise vector to a flattened image.

    Architecture: Linear(z_dim -> h_dim), ReLU, Linear(h_dim -> x_dim),
    sigmoid. The sigmoid keeps every output element in (0, 1).
    """

    def __init__(self, z_dim, h_dim, x_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=z_dim, out_features=h_dim)
        self.fc2 = nn.Linear(in_features=h_dim, out_features=x_dim)
        # Re-initialise both linear layers (Xavier weights, zero biases).
        self.apply(xavier_init)

    def forward(self, z):
        """Return the generated sample for the noise batch ``z``."""
        hidden = torch.relu(self.fc1(z))
        return torch.sigmoid(self.fc2(hidden))

# Discriminator
class Discriminator(nn.Module):
    """Two-layer MLP scoring a flattened image with a single value in (0, 1).

    Architecture: Linear(x_dim -> h_dim), ReLU, Linear(h_dim -> 1), sigmoid.
    """

    def __init__(self, x_dim, h_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=x_dim, out_features=h_dim)
        self.fc2 = nn.Linear(in_features=h_dim, out_features=1)
        # Re-initialise both linear layers (Xavier weights, zero biases).
        self.apply(xavier_init)

    def forward(self, x):
        """Return one sigmoid score per row of ``x``."""
        hidden = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))



# Training
def cGANTraining(G, D, loss_fn, train_loader):
    """Run one epoch of adversarial training and return the updated models.

    Despite the name, the step is unconditional: the labels yielded by
    ``train_loader`` are ignored. For every batch the discriminator is
    updated first (real batch vs. detached fake batch), then the generator is
    updated on a fresh noise batch.

    Relies on the module-level globals ``device``, ``Z_dim``, ``D_solver`` and
    ``G_solver``.

    Args:
        G: generator module; called as ``G(z)`` with ``z`` of width ``Z_dim``.
        D: discriminator module; its output is compared against targets of
            shape ``(batch, 1)``.
        loss_fn: callable ``loss_fn(preds, targets)`` returning a scalar
            tensor.
        train_loader: sized iterable yielding ``(batch, labels)`` pairs.

    Returns:
        Tuple ``(G, D, G_loss_avg, D_loss_avg)`` where the averages are taken
        over the number of batches and ``D_loss_avg`` is the sum of the real
        and fake discriminator averages.
    """
    G.train()
    D.train()

    D_loss_real_total = 0.0
    D_loss_fake_total = 0.0
    G_loss_total = 0.0

    for X_real, _ in tqdm.tqdm(train_loader):
        # Prepare real data
        X_real = X_real.float().to(device)
        batch_size = X_real.size(0)

        # Targets: 1 for "real", 0 for "fake". Created directly on the
        # target device to avoid a host-to-device copy per batch.
        ones_label = torch.ones(batch_size, 1, device=device)
        zeros_label = torch.zeros(batch_size, 1, device=device)

        # ================= Train Discriminator =================
        z = torch.randn(batch_size, Z_dim, device=device)
        G_sample = G(z)
        D_real = D(X_real)
        # detach(): the discriminator step must not back-propagate into G.
        D_fake = D(G_sample.detach())

        D_loss_real = loss_fn(D_real, ones_label)
        D_loss_fake = loss_fn(D_fake, zeros_label)
        # A high value means the discriminator is separating real from fake
        # poorly.
        D_loss = D_loss_real + D_loss_fake
        D_loss_real_total += D_loss_real.item()
        D_loss_fake_total += D_loss_fake.item()

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ================= Train Generator ====================
        z = torch.randn(batch_size, Z_dim, device=device)
        G_sample = G(z)
        # D is trained towards 1 on real data, so a HIGH score here means
        # the discriminator took the generated sample for a real one.
        D_fake = D(G_sample)

        # Low loss means the generator is fooling the discriminator.
        G_loss = loss_fn(D_fake, ones_label)
        G_loss_total += G_loss.item()

        # This backward pass also fills D's gradients; they are discarded by
        # D_solver.zero_grad() at the start of the next discriminator step.
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

    # ================= Logging =================
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError; the
    # averages are then simply 0.
    num_batches = max(len(train_loader), 1)
    D_loss_real_avg = D_loss_real_total / num_batches
    D_loss_fake_avg = D_loss_fake_total / num_batches
    D_loss_avg = D_loss_real_avg + D_loss_fake_avg
    G_loss_avg = G_loss_total / num_batches

    # Only log when a wandb run is active: wandb.init() is optional in the
    # driver script and wandb.log() without a run is an error.
    if wandb.run is not None:
        wandb.log({
            "D_loss_real": D_loss_real_avg,
            "D_loss_fake": D_loss_fake_avg,
            "D_loss": D_loss_avg,
            "G_loss": G_loss_avg
        })

    return G, D, G_loss_avg, D_loss_avg
    


def save_sample(G, epoch, mb_size, Z_dim, out_dir="out_vanila_GAN2"):
    """Save a 4x4 grid of generated samples as ``<out_dir>/<epoch>.png``.

    Uses the module-level global ``device`` for the noise tensor. ``G`` is
    switched to eval mode and is left in eval mode on return.

    Args:
        G: generator module; each output row is reshaped to 28x28.
        epoch: value used for the file name, zero-padded to three characters.
        mb_size: upper bound on the number of samples; at most 16 are drawn
            because the grid has 16 cells.
        Z_dim: width of the noise vector passed to ``G``.
        out_dir: directory for the image; created if it does not exist.
    """
    # Only 16 images fit in the grid, so do not generate more than that.
    n_samples = min(mb_size, 16)

    G.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, Z_dim).to(device)
        samples = G(z).cpu().numpy()

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(out_dir, exist_ok=True)

    fig.savefig(f'{out_dir}/{str(epoch).zfill(3)}.png', bbox_inches='tight')
    plt.close(fig)



########################### Main #######################################
# Toggle for the wandb.init / config logging done further down.
wandb_log = True
# First CUDA device when available, CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Networks (generator is built first, then the discriminator)
G = Generator(z_dim=Z_dim, h_dim=h_dim, x_dim=X_dim).to(device)
D = Discriminator(x_dim=X_dim, h_dim=h_dim).to(device)

# One Adam optimizer per network, both with the same learning rate
G_solver = optim.Adam(params=G.parameters(), lr=lr)
D_solver = optim.Adam(params=D.parameters(), lr=lr)

# Loss function
def my_bce_loss(preds, targets):
    """Return the mean binary cross-entropy between ``preds`` and ``targets``.

    ``preds`` are passed to ``F.binary_cross_entropy`` unchanged, so they
    must already be probabilities rather than logits.
    """
    loss = F.binary_cross_entropy(input=preds, target=targets)
    return loss

# Alternative kept for reference (not active):
#   loss_fn = nn.BCEWithLogitsLoss()
loss_fn = my_bce_loss

if wandb_log:
    wandb.init(project="conditional-gan-mnist")

    # Record the hyperparameters of this run in the wandb config
    wandb.config.update(
        {
            "batch_size": mb_size,
            "Z_dim": Z_dim,
            "X_dim": X_dim,
            "h_dim": h_dim,
            "lr": lr,
        }
    )

# Lowest epoch-average generator loss seen so far
best_g_loss = float('inf')
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

# Number of training epochs
epochs = 60

for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg = cGANTraining(G, D, loss_fn, train_loader)
    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    # Checkpoint both networks whenever the generator loss reaches a new
    # minimum.
    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        for model, fname in ((G, 'G_best.pth'), (D, 'D_best.pth')):
            torch.save(model.state_dict(), os.path.join(save_dir, fname))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    # Write a sample grid for every epoch, regardless of the loss.
    save_sample(G, epoch, mb_size, Z_dim)


# Inference    
# G.load_state_dict(torch.load('checkpoints/G_best.pth'))
# G.eval()

# save_sample(G, "best", mb_size, Z_dim)

0,1
D_loss,▁▂▂▃▄▆▆▇▇██████████████▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆
D_loss_fake,▁▁▁▂▃▆▇▇███████████████▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆
D_loss_real,▂▁▂▂▄▅▆▇████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆
G_loss,██▇▆▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂

0,1
D_loss,0.49075
D_loss_fake,0.21669
D_loss_real,0.27406
G_loss,3.01146


100%|██████████| 938/938 [00:11<00:00, 81.91it/s]


epoch0; D_loss: 0.0626; G_loss: 6.0155
Saved Best Models at epoch 0 | G_loss: 6.0155


100%|██████████| 938/938 [00:11<00:00, 83.89it/s]


epoch1; D_loss: 0.0526; G_loss: 5.1583
Saved Best Models at epoch 1 | G_loss: 5.1583


100%|██████████| 938/938 [00:09<00:00, 102.21it/s]


epoch2; D_loss: 0.0782; G_loss: 5.3637


100%|██████████| 938/938 [00:11<00:00, 78.99it/s]


epoch3; D_loss: 0.1325; G_loss: 5.3420


100%|██████████| 938/938 [00:14<00:00, 64.86it/s]


epoch4; D_loss: 0.2380; G_loss: 4.8234
Saved Best Models at epoch 4 | G_loss: 4.8234


100%|██████████| 938/938 [00:11<00:00, 82.32it/s]


epoch5; D_loss: 0.3587; G_loss: 4.3161
Saved Best Models at epoch 5 | G_loss: 4.3161


100%|██████████| 938/938 [00:10<00:00, 89.28it/s] 


epoch6; D_loss: 0.4744; G_loss: 3.6588
Saved Best Models at epoch 6 | G_loss: 3.6588


100%|██████████| 938/938 [00:10<00:00, 85.48it/s]


epoch7; D_loss: 0.5932; G_loss: 3.4041
Saved Best Models at epoch 7 | G_loss: 3.4041


100%|██████████| 938/938 [00:10<00:00, 89.16it/s] 


epoch8; D_loss: 0.7043; G_loss: 3.0389
Saved Best Models at epoch 8 | G_loss: 3.0389


100%|██████████| 938/938 [00:10<00:00, 89.50it/s] 


epoch9; D_loss: 0.7517; G_loss: 2.7463
Saved Best Models at epoch 9 | G_loss: 2.7463


100%|██████████| 938/938 [00:11<00:00, 84.85it/s] 


epoch10; D_loss: 0.7271; G_loss: 2.6301
Saved Best Models at epoch 10 | G_loss: 2.6301


100%|██████████| 938/938 [00:10<00:00, 88.22it/s] 


epoch11; D_loss: 0.7121; G_loss: 2.4783
Saved Best Models at epoch 11 | G_loss: 2.4783


100%|██████████| 938/938 [00:11<00:00, 81.94it/s]


epoch12; D_loss: 0.7306; G_loss: 2.3877
Saved Best Models at epoch 12 | G_loss: 2.3877


100%|██████████| 938/938 [00:11<00:00, 83.27it/s]


epoch13; D_loss: 0.7321; G_loss: 2.4003


100%|██████████| 938/938 [00:10<00:00, 87.56it/s] 


epoch14; D_loss: 0.7401; G_loss: 2.3195
Saved Best Models at epoch 14 | G_loss: 2.3195


100%|██████████| 938/938 [00:11<00:00, 82.34it/s]


epoch15; D_loss: 0.7508; G_loss: 2.2868
Saved Best Models at epoch 15 | G_loss: 2.2868


100%|██████████| 938/938 [00:11<00:00, 81.74it/s]


epoch16; D_loss: 0.7542; G_loss: 2.2101
Saved Best Models at epoch 16 | G_loss: 2.2101


100%|██████████| 938/938 [00:11<00:00, 84.53it/s]


epoch17; D_loss: 0.7517; G_loss: 2.1412
Saved Best Models at epoch 17 | G_loss: 2.1412


100%|██████████| 938/938 [00:12<00:00, 73.19it/s]


epoch18; D_loss: 0.7314; G_loss: 2.1189
Saved Best Models at epoch 18 | G_loss: 2.1189


100%|██████████| 938/938 [00:12<00:00, 74.76it/s]


epoch19; D_loss: 0.7429; G_loss: 2.1504


100%|██████████| 938/938 [00:11<00:00, 83.64it/s]


epoch20; D_loss: 0.7289; G_loss: 2.1572


100%|██████████| 938/938 [00:11<00:00, 81.40it/s]


epoch21; D_loss: 0.7340; G_loss: 2.1642


100%|██████████| 938/938 [00:11<00:00, 81.43it/s]


epoch22; D_loss: 0.7240; G_loss: 2.1226


100%|██████████| 938/938 [00:10<00:00, 86.00it/s]


epoch23; D_loss: 0.7113; G_loss: 2.1728


100%|██████████| 938/938 [00:10<00:00, 90.74it/s] 


epoch24; D_loss: 0.7140; G_loss: 2.1888


100%|██████████| 938/938 [00:16<00:00, 56.92it/s]


epoch25; D_loss: 0.7042; G_loss: 2.2290


100%|██████████| 938/938 [00:17<00:00, 54.68it/s]


epoch26; D_loss: 0.6931; G_loss: 2.2686


100%|██████████| 938/938 [00:17<00:00, 53.56it/s]


epoch27; D_loss: 0.6910; G_loss: 2.3061


100%|██████████| 938/938 [00:17<00:00, 54.64it/s]


epoch28; D_loss: 0.6876; G_loss: 2.3347


100%|██████████| 938/938 [00:17<00:00, 54.74it/s]


epoch29; D_loss: 0.6840; G_loss: 2.3210


100%|██████████| 938/938 [00:17<00:00, 54.26it/s]


epoch30; D_loss: 0.6796; G_loss: 2.3659


100%|██████████| 938/938 [00:14<00:00, 64.42it/s]


epoch31; D_loss: 0.6667; G_loss: 2.4256


100%|██████████| 938/938 [00:17<00:00, 54.70it/s]


epoch32; D_loss: 0.6733; G_loss: 2.3941


100%|██████████| 938/938 [00:17<00:00, 54.61it/s]


epoch33; D_loss: 0.6601; G_loss: 2.4652


100%|██████████| 938/938 [00:17<00:00, 53.56it/s]


epoch34; D_loss: 0.6608; G_loss: 2.4564


100%|██████████| 938/938 [00:17<00:00, 54.38it/s]


epoch35; D_loss: 0.6555; G_loss: 2.4692


100%|██████████| 938/938 [00:16<00:00, 55.23it/s]


epoch36; D_loss: 0.6472; G_loss: 2.5100


100%|██████████| 938/938 [00:17<00:00, 54.66it/s]


epoch37; D_loss: 0.6413; G_loss: 2.5649


100%|██████████| 938/938 [00:14<00:00, 66.12it/s] 


epoch38; D_loss: 0.6378; G_loss: 2.5840


100%|██████████| 938/938 [00:16<00:00, 55.30it/s]


epoch39; D_loss: 0.6265; G_loss: 2.5851


100%|██████████| 938/938 [00:17<00:00, 54.26it/s]


epoch40; D_loss: 0.6266; G_loss: 2.6056


100%|██████████| 938/938 [00:16<00:00, 58.01it/s]


epoch41; D_loss: 0.6204; G_loss: 2.6278


100%|██████████| 938/938 [00:16<00:00, 57.27it/s]


epoch42; D_loss: 0.6147; G_loss: 2.6426


100%|██████████| 938/938 [00:17<00:00, 54.10it/s]


epoch43; D_loss: 0.6095; G_loss: 2.6594


100%|██████████| 938/938 [00:16<00:00, 55.34it/s]


epoch44; D_loss: 0.6000; G_loss: 2.6731


100%|██████████| 938/938 [00:13<00:00, 67.17it/s] 


epoch45; D_loss: 0.6016; G_loss: 2.7145


100%|██████████| 938/938 [00:17<00:00, 53.98it/s]


epoch46; D_loss: 0.5897; G_loss: 2.7487


100%|██████████| 938/938 [00:17<00:00, 55.02it/s]


epoch47; D_loss: 0.5849; G_loss: 2.7793


100%|██████████| 938/938 [00:17<00:00, 54.26it/s]


epoch48; D_loss: 0.5806; G_loss: 2.7664


100%|██████████| 938/938 [00:17<00:00, 53.41it/s]


epoch49; D_loss: 0.5784; G_loss: 2.8078


100%|██████████| 938/938 [00:17<00:00, 54.09it/s]


epoch50; D_loss: 0.5681; G_loss: 2.8234


100%|██████████| 938/938 [00:17<00:00, 54.50it/s]


epoch51; D_loss: 0.5682; G_loss: 2.8245


100%|██████████| 938/938 [00:14<00:00, 64.63it/s]


epoch52; D_loss: 0.5617; G_loss: 2.8529


100%|██████████| 938/938 [00:17<00:00, 54.23it/s]


epoch53; D_loss: 0.5549; G_loss: 2.8775


100%|██████████| 938/938 [00:17<00:00, 54.76it/s]


epoch54; D_loss: 0.5446; G_loss: 2.9150


100%|██████████| 938/938 [00:18<00:00, 52.08it/s]


epoch55; D_loss: 0.5434; G_loss: 2.9272


100%|██████████| 938/938 [00:17<00:00, 54.46it/s]


epoch56; D_loss: 0.5369; G_loss: 2.9444


100%|██████████| 938/938 [00:17<00:00, 53.88it/s]


epoch57; D_loss: 0.5337; G_loss: 2.9527


100%|██████████| 938/938 [00:16<00:00, 55.65it/s]


epoch58; D_loss: 0.5305; G_loss: 2.9863


100%|██████████| 938/938 [00:14<00:00, 65.27it/s]


epoch59; D_loss: 0.5251; G_loss: 2.9832
