In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tqdm
import os
import wandb

ImportError: cannot import name 'TypeAliasType' from 'typing_extensions' (C:\Users\elias\Documents\Anaconda\envs\ADL\lib\site-packages\typing_extensions.py)

In [None]:
# Authenticate with Weights & Biases and start a run for this conditional-GAN
# MNIST experiment (the project/run names previously pointed at an unrelated
# sentiment-analysis project).
wandb.login()

wandb.init(project="conditional-gan-mnist", name="cgan-mnist", reinit=True)

In [15]:
# Hyperparameters
mb_size = 64    # minibatch size for the training DataLoader
Z_dim = 1000    # length of the generator's noise vector z
h_dim = 128     # hidden-layer width shared by G and D
c_dim = 10      # number of classes = width of the one-hot condition vector
lr = 1e-3       # Adam learning rate used for both G and D

# Seed torch's RNG so weight init, batch shuffling and noise sampling are
# reproducible across Restart & Run All (full determinism on CUDA may still
# require extra settings).
SEED = 42
torch.manual_seed(SEED)

In [16]:
# Load MNIST data
# Both networks are fully connected, so every 28 x 28 MNIST image is fed to
# them as a flat vector of X_dim values.
X_dim = 784  # 28 x 28

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),  # flatten the 28x28 image to 784
    ]
)

train_dataset = datasets.MNIST(
    root='../MNIST',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=mb_size, shuffle=True)

In [17]:
# Xavier Initialization
def xavier_init(m):
    """Initialise ``m`` in place when it is an ``nn.Linear`` layer.

    Weights get Xavier/Glorot-normal values and the bias, if the layer has
    one, is reset to zero. Any other module type is ignored, which makes the
    function suitable for ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)


In [18]:
# To onehot
def to_onehot(x, num_classes=10):
    """Convert integer class labels to a float one-hot matrix.

    Args:
        x: a single Python ``int`` label, or a ``torch.Tensor`` of labels.
            A 1-D tensor of length N (or an (N, 1) column) gives N rows; a
            0-dim tensor such as ``torch.tensor(3)`` is treated as one label.
        num_classes: width of the one-hot encoding.

    Returns:
        Tensor of shape ``(N, num_classes)`` with a 1 in each row's label
        column and 0 elsewhere, allocated on the same device as ``x``.

    Raises:
        TypeError: if ``x`` is neither an ``int`` nor a ``torch.Tensor``.
    """
    if isinstance(x, int):
        x = torch.tensor([x], dtype=torch.long)
    elif not isinstance(x, torch.Tensor):
        raise TypeError("Input must be an int or a torch.Tensor")

    x = x.long()  # scatter_ needs an int64 index tensor
    if x.dim() == 0:
        # A scalar tensor has no size(0); promote it to a batch of one label.
        x = x.unsqueeze(0)

    one_hot = torch.zeros(x.size(0), num_classes, device=x.device)
    one_hot.scatter_(1, x.view(-1, 1), 1)
    return one_hot

In [19]:
# Generator
class Generator(nn.Module):
    """Conditional generator: maps (noise ``z``, one-hot class ``c``) to an image.

    Pipeline: concat(z, c) -> Linear(z_dim + c_dim, h_dim) -> ReLU
    -> Linear(h_dim, x_dim) -> sigmoid, so each output value lies in [0, 1].
    All Linear layers are re-initialised with ``xavier_init``.
    """

    def __init__(self, z_dim, h_dim, x_dim, c_dim):
        super().__init__()
        in_features = z_dim + c_dim
        self.fc1 = nn.Linear(in_features, h_dim)
        self.fc2 = nn.Linear(h_dim, x_dim)
        self.apply(xavier_init)

    def forward(self, z, c):
        # z and c are joined along dim 1, so they must share the batch dimension.
        joint = torch.cat([z, c], dim=1)
        hidden = F.relu(self.fc1(joint))
        return torch.sigmoid(self.fc2(hidden))


In [20]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image ``x``, one-hot class ``c``) pair.

    Pipeline: concat(x, c) -> Linear(x_dim + c_dim, h_dim) -> ReLU
    -> Linear(h_dim, 1) -> sigmoid, i.e. one probability-like score per row.
    All Linear layers are re-initialised with ``xavier_init``.
    """

    def __init__(self, x_dim, h_dim, c_dim):
        super().__init__()
        self.fc1 = nn.Linear(x_dim + c_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, 1)
        self.apply(xavier_init)

    def forward(self, x, c):
        # x and c are joined along dim 1, so they must share the batch dimension.
        joint = torch.cat([x, c], dim=1)
        hidden = F.relu(self.fc1(joint))
        score = self.fc2(hidden)
        return torch.sigmoid(score)


In [32]:
# Training
def cGANTraining(G, D, loss_fn, train_loader):
    """Run one epoch of conditional-GAN training over ``train_loader``.

    For every minibatch the discriminator is updated first (real images vs.
    detached generator samples, both conditioned on the batch's one-hot
    labels). The generator is then updated on a fresh noise draw by scoring
    its samples against "real" (ones) targets.

    Relies on notebook-level globals: ``device``, ``Z_dim``, ``c_dim``,
    ``G_solver`` and ``D_solver``, plus the ``to_onehot`` helper.

    Args:
        G: conditional generator, called as ``G(z, c)``.
        D: conditional discriminator, called as ``D(x, c)``; its output is
            passed directly to ``loss_fn`` together with 0/1 targets.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        train_loader: iterable of ``(images, labels)`` minibatches.

    Returns:
        ``(G, D, G_loss_avg, D_loss_avg)``: the (in-place trained) models and
        the per-batch mean losses for the epoch; ``D_loss_avg`` is the sum of
        the real-side and fake-side discriminator averages.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    G.train()
    D.train()

    D_loss_real_total = 0.0
    D_loss_fake_total = 0.0
    G_loss_total = 0.0
    n_batches = 0

    for X_real, labels in tqdm.tqdm(train_loader):
        batch_size = X_real.size(0)

        # Real data and its class condition; use c_dim so the one-hot width
        # always matches the width the models were built with.
        X_real = X_real.float().to(device)
        c = to_onehot(labels, num_classes=c_dim).to(device)

        ones_label = torch.ones(batch_size, 1, device=device)
        zeros_label = torch.zeros(batch_size, 1, device=device)

        # ================= Train Discriminator =================
        z = torch.randn(batch_size, Z_dim, device=device)
        # The fakes are only D inputs here, so skip building G's autograd graph.
        with torch.no_grad():
            G_sample = G(z, c)
        D_real = D(X_real, c)
        D_fake = D(G_sample, c)

        D_loss_real = loss_fn(D_real, ones_label)
        D_loss_fake = loss_fn(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ================= Train Generator ====================
        # Fresh noise; G is rewarded when D scores its samples as real.
        z = torch.randn(batch_size, Z_dim, device=device)
        D_fake = D(G(z, c), c)
        G_loss = loss_fn(D_fake, ones_label)

        # G_loss.backward() also leaves gradients on D's parameters; they are
        # cleared by D_solver.zero_grad() before D's next update.
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        D_loss_real_total += D_loss_real.item()
        D_loss_fake_total += D_loss_fake.item()
        G_loss_total += G_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # ================= Logging =================
    D_loss_real_avg = D_loss_real_total / n_batches
    D_loss_fake_avg = D_loss_fake_total / n_batches
    D_loss_avg = D_loss_real_avg + D_loss_fake_avg
    G_loss_avg = G_loss_total / n_batches

    #    wandb.log({
    #        "D_loss_real": D_loss_real_avg,
    #        "D_loss_fake": D_loss_fake_avg,
    #        "D_loss": D_loss_avg,
    #        "G_loss": G_loss_avg
    #    })

    return G, D, G_loss_avg, D_loss_avg
    


def save_sample(G, epoch, mb_size, Z_dim, num_classes=10, out_dir="out_vanila_GAN2"):
    """Save a 1 x ``num_classes`` strip with one generated image per class.

    Args:
        G: conditional generator, called as ``G(z, c)``; each output row is
            reshaped to 28 x 28 for display, so it must have 784 values.
        epoch: used for the file name via ``str(epoch).zfill(3)``; e.g. ``7``
            becomes ``007.png`` and ``"best"`` becomes ``best.png``.
        mb_size: unused; kept so existing call sites keep working.
        Z_dim: noise dimensionality expected by ``G``.
        num_classes: one-hot width and number of tiles in the strip.
        out_dir: directory the PNG is written to (created if missing).
    """
    # Take the device from G itself instead of relying on a global ``device``.
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_classes, Z_dim, device=device)
            # Row i of the identity matrix is the one-hot condition for class i.
            c = torch.eye(num_classes, device=device)
            samples = G(z, c).cpu().numpy()
    finally:
        # Restore G's previous mode so callers are not left with an eval-mode model.
        G.train(was_training)

    fig, axes = plt.subplots(1, num_classes, figsize=(10, 1), squeeze=False)
    try:
        for i, ax in enumerate(axes[0]):
            ax.axis('off')
            ax.imshow(samples[i].reshape(28, 28), cmap='Greys_r')

        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f'{out_dir}/{str(epoch).zfill(3)}.png', bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory.
        plt.close(fig)



########################### Main #######################################
# Intended toggle for Weights & Biases logging (currently only referenced by
# commented-out code).
wandb_log = True
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Build the conditional generator and discriminator, then move them to `device`.
# G is constructed before D so the random initialisation order is unchanged.
G = Generator(Z_dim, h_dim, X_dim, c_dim)
G = G.to(device)
D = Discriminator(X_dim, h_dim, c_dim)
D = D.to(device)

# One Adam optimizer per network, both using the same learning rate.
G_solver = optim.Adam(params=G.parameters(), lr=lr)
D_solver = optim.Adam(params=D.parameters(), lr=lr)

# Loss function
def my_bce_loss(preds, targets):
    """Mean binary cross-entropy between probabilities ``preds`` and 0/1 ``targets``.

    ``preds`` must already be probabilities (e.g. sigmoid outputs), not logits.
    """
    loss = F.binary_cross_entropy(preds, targets, reduction='mean')
    return loss

# D already ends in a sigmoid, so plain BCE on probabilities is the matching
# loss. nn.BCEWithLogitsLoss expects raw logits and would only be correct if
# that final sigmoid were removed from D.
#loss_fn = nn.BCEWithLogitsLoss()
loss_fn = my_bce_loss

#if wandb_log: 
#    wandb.init(project="conditional-gan-mnist")
#
#    # Log hyperparameters
#    wandb.config.update({
#        "batch_size": mb_size,
#        "Z_dim": Z_dim,
#        "X_dim": X_dim,
#        "h_dim": h_dim,
#        "lr": lr,
#    })
#
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 100

for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg = cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    # Overwrite a "latest" checkpoint every epoch (models, optimizers, epoch)
    # so an interrupted run (e.g. a KeyboardInterrupt) still leaves resumable
    # weights on disk, not just the "best" snapshot below.
    torch.save(
        {
            'epoch': epoch,
            'G': G.state_dict(),
            'D': D.state_dict(),
            'G_solver': G_solver.state_dict(),
            'D_solver': D_solver.state_dict(),
        },
        os.path.join(save_dir, 'last.pth'),
    )

    # NOTE: in adversarial training a lower generator loss is not a reliable
    # quality signal (it also depends on how strong D is at that moment), so
    # "best" here is only a heuristic; compare against the saved sample images.
    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)


# Inference    
# G.load_state_dict(torch.load('checkpoints/G_best.pth'))
# G.eval()

# save_sample(G, "best", mb_size, Z_dim)

100%|██████████| 938/938 [00:07<00:00, 122.12it/s]


epoch0; D_loss: 0.0471; G_loss: 6.8148
Saved Best Models at epoch 0 | G_loss: 6.8148


100%|██████████| 938/938 [00:07<00:00, 120.08it/s]


epoch1; D_loss: 0.0135; G_loss: 7.4256


100%|██████████| 938/938 [00:07<00:00, 122.39it/s]


epoch2; D_loss: 0.0217; G_loss: 7.0546


100%|██████████| 938/938 [00:07<00:00, 127.36it/s]


epoch3; D_loss: 0.0402; G_loss: 6.4778
Saved Best Models at epoch 3 | G_loss: 6.4778


100%|██████████| 938/938 [00:06<00:00, 134.77it/s]


epoch4; D_loss: 0.0961; G_loss: 6.0634
Saved Best Models at epoch 4 | G_loss: 6.0634


100%|██████████| 938/938 [00:06<00:00, 136.79it/s]


epoch5; D_loss: 0.1670; G_loss: 5.4971
Saved Best Models at epoch 5 | G_loss: 5.4971


100%|██████████| 938/938 [00:07<00:00, 128.26it/s]


epoch6; D_loss: 0.2671; G_loss: 5.0769
Saved Best Models at epoch 6 | G_loss: 5.0769


100%|██████████| 938/938 [00:07<00:00, 121.61it/s]


epoch7; D_loss: 0.3576; G_loss: 4.4312
Saved Best Models at epoch 7 | G_loss: 4.4312


100%|██████████| 938/938 [00:07<00:00, 125.71it/s]


epoch8; D_loss: 0.4883; G_loss: 3.8133
Saved Best Models at epoch 8 | G_loss: 3.8133


100%|██████████| 938/938 [00:07<00:00, 126.30it/s]


epoch9; D_loss: 0.5624; G_loss: 3.3175
Saved Best Models at epoch 9 | G_loss: 3.3175


100%|██████████| 938/938 [00:07<00:00, 133.97it/s]


epoch10; D_loss: 0.5918; G_loss: 3.1490
Saved Best Models at epoch 10 | G_loss: 3.1490


100%|██████████| 938/938 [00:07<00:00, 129.94it/s]


epoch11; D_loss: 0.6316; G_loss: 2.8862
Saved Best Models at epoch 11 | G_loss: 2.8862


100%|██████████| 938/938 [00:07<00:00, 121.35it/s]


epoch12; D_loss: 0.6381; G_loss: 2.6561
Saved Best Models at epoch 12 | G_loss: 2.6561


100%|██████████| 938/938 [00:07<00:00, 122.33it/s]


epoch13; D_loss: 0.6760; G_loss: 2.4757
Saved Best Models at epoch 13 | G_loss: 2.4757


100%|██████████| 938/938 [00:07<00:00, 126.08it/s]


epoch14; D_loss: 0.7077; G_loss: 2.4017
Saved Best Models at epoch 14 | G_loss: 2.4017


100%|██████████| 938/938 [00:07<00:00, 125.67it/s]


epoch15; D_loss: 0.7295; G_loss: 2.2649
Saved Best Models at epoch 15 | G_loss: 2.2649


100%|██████████| 938/938 [00:07<00:00, 123.97it/s]


epoch16; D_loss: 0.7278; G_loss: 2.2733


100%|██████████| 938/938 [00:07<00:00, 125.61it/s]


epoch17; D_loss: 0.7333; G_loss: 2.2005
Saved Best Models at epoch 17 | G_loss: 2.2005


100%|██████████| 938/938 [00:07<00:00, 124.58it/s]


epoch18; D_loss: 0.7260; G_loss: 2.1935
Saved Best Models at epoch 18 | G_loss: 2.1935


100%|██████████| 938/938 [00:07<00:00, 126.26it/s]


epoch19; D_loss: 0.7266; G_loss: 2.1545
Saved Best Models at epoch 19 | G_loss: 2.1545


100%|██████████| 938/938 [00:06<00:00, 138.96it/s]


epoch20; D_loss: 0.7289; G_loss: 2.1901


100%|██████████| 938/938 [00:06<00:00, 138.84it/s]


epoch21; D_loss: 0.7314; G_loss: 2.1397
Saved Best Models at epoch 21 | G_loss: 2.1397


100%|██████████| 938/938 [00:07<00:00, 124.81it/s]


epoch22; D_loss: 0.7269; G_loss: 2.1538


100%|██████████| 938/938 [00:06<00:00, 138.28it/s]


epoch23; D_loss: 0.7251; G_loss: 2.1585


100%|██████████| 938/938 [00:07<00:00, 124.85it/s]


epoch24; D_loss: 0.7237; G_loss: 2.1401


100%|██████████| 938/938 [00:07<00:00, 123.39it/s]


epoch25; D_loss: 0.7230; G_loss: 2.1367
Saved Best Models at epoch 25 | G_loss: 2.1367


100%|██████████| 938/938 [00:07<00:00, 127.20it/s]


epoch26; D_loss: 0.7279; G_loss: 2.1597


100%|██████████| 938/938 [00:07<00:00, 120.26it/s]


epoch27; D_loss: 0.7197; G_loss: 2.1543


100%|██████████| 938/938 [00:07<00:00, 132.58it/s]


epoch28; D_loss: 0.7213; G_loss: 2.1536


100%|██████████| 938/938 [00:07<00:00, 130.05it/s]


epoch29; D_loss: 0.7186; G_loss: 2.1695


100%|██████████| 938/938 [00:07<00:00, 120.81it/s]


epoch30; D_loss: 0.7169; G_loss: 2.1591


100%|██████████| 938/938 [00:07<00:00, 124.62it/s]


epoch31; D_loss: 0.7156; G_loss: 2.1582


100%|██████████| 938/938 [00:07<00:00, 122.92it/s]


epoch32; D_loss: 0.7175; G_loss: 2.1800


100%|██████████| 938/938 [00:07<00:00, 122.63it/s]


epoch33; D_loss: 0.7155; G_loss: 2.1843


100%|██████████| 938/938 [00:07<00:00, 124.83it/s]


epoch34; D_loss: 0.7187; G_loss: 2.1749


100%|██████████| 938/938 [00:07<00:00, 122.38it/s]


epoch35; D_loss: 0.7110; G_loss: 2.1588


100%|██████████| 938/938 [00:07<00:00, 124.71it/s]


epoch36; D_loss: 0.7098; G_loss: 2.1671


100%|██████████| 938/938 [00:07<00:00, 123.57it/s]


epoch37; D_loss: 0.7087; G_loss: 2.1644


100%|██████████| 938/938 [00:07<00:00, 122.85it/s]


epoch38; D_loss: 0.7076; G_loss: 2.1738


100%|██████████| 938/938 [00:06<00:00, 137.91it/s]


epoch39; D_loss: 0.7050; G_loss: 2.1626


100%|██████████| 938/938 [00:07<00:00, 126.28it/s]


epoch40; D_loss: 0.7048; G_loss: 2.1895


100%|██████████| 938/938 [00:07<00:00, 123.42it/s]


epoch41; D_loss: 0.7031; G_loss: 2.1681


100%|██████████| 938/938 [00:07<00:00, 125.34it/s]


epoch42; D_loss: 0.7015; G_loss: 2.2011


100%|██████████| 938/938 [00:07<00:00, 125.54it/s]


epoch43; D_loss: 0.6991; G_loss: 2.2100


100%|██████████| 938/938 [00:07<00:00, 121.47it/s]


epoch44; D_loss: 0.6968; G_loss: 2.2111


100%|██████████| 938/938 [00:06<00:00, 134.14it/s]


epoch45; D_loss: 0.6972; G_loss: 2.2140


100%|██████████| 938/938 [00:06<00:00, 138.94it/s]


epoch46; D_loss: 0.6946; G_loss: 2.2385


100%|██████████| 938/938 [00:07<00:00, 123.62it/s]


epoch47; D_loss: 0.6865; G_loss: 2.2316


100%|██████████| 938/938 [00:07<00:00, 129.62it/s]


epoch48; D_loss: 0.6905; G_loss: 2.2458


100%|██████████| 938/938 [00:07<00:00, 124.25it/s]


epoch49; D_loss: 0.6880; G_loss: 2.2246


100%|██████████| 938/938 [00:07<00:00, 124.01it/s]


epoch50; D_loss: 0.6832; G_loss: 2.2351


100%|██████████| 938/938 [00:07<00:00, 126.59it/s]


epoch51; D_loss: 0.6815; G_loss: 2.2502


100%|██████████| 938/938 [00:07<00:00, 125.43it/s]


epoch52; D_loss: 0.6809; G_loss: 2.2555


100%|██████████| 938/938 [00:07<00:00, 123.93it/s]


epoch53; D_loss: 0.6793; G_loss: 2.2492


100%|██████████| 938/938 [00:07<00:00, 124.04it/s]


epoch54; D_loss: 0.6758; G_loss: 2.2566


100%|██████████| 938/938 [00:07<00:00, 123.07it/s]


epoch55; D_loss: 0.6762; G_loss: 2.2690


100%|██████████| 938/938 [00:07<00:00, 123.06it/s]


epoch56; D_loss: 0.6742; G_loss: 2.2649


100%|██████████| 938/938 [00:07<00:00, 129.61it/s]


epoch57; D_loss: 0.6758; G_loss: 2.2624


100%|██████████| 938/938 [00:07<00:00, 123.64it/s]


epoch58; D_loss: 0.6671; G_loss: 2.2571


100%|██████████| 938/938 [00:07<00:00, 127.92it/s]


epoch59; D_loss: 0.6697; G_loss: 2.2750


100%|██████████| 938/938 [00:07<00:00, 128.14it/s]


epoch60; D_loss: 0.6602; G_loss: 2.2510


100%|██████████| 938/938 [00:07<00:00, 125.74it/s]


epoch61; D_loss: 0.6653; G_loss: 2.2674


100%|██████████| 938/938 [00:07<00:00, 126.30it/s]


epoch62; D_loss: 0.6640; G_loss: 2.2716


100%|██████████| 938/938 [00:07<00:00, 132.70it/s]


epoch63; D_loss: 0.6591; G_loss: 2.2842


100%|██████████| 938/938 [00:07<00:00, 127.14it/s]


epoch64; D_loss: 0.6588; G_loss: 2.2703


100%|██████████| 938/938 [00:07<00:00, 124.12it/s]


epoch65; D_loss: 0.6591; G_loss: 2.2735


100%|██████████| 938/938 [00:07<00:00, 124.49it/s]


epoch66; D_loss: 0.6535; G_loss: 2.2751


100%|██████████| 938/938 [00:07<00:00, 124.61it/s]


epoch67; D_loss: 0.6495; G_loss: 2.2858


100%|██████████| 938/938 [00:07<00:00, 124.26it/s]


epoch68; D_loss: 0.6542; G_loss: 2.2888


100%|██████████| 938/938 [00:07<00:00, 123.70it/s]


epoch69; D_loss: 0.6518; G_loss: 2.2651


100%|██████████| 938/938 [00:07<00:00, 122.56it/s]


epoch70; D_loss: 0.6502; G_loss: 2.2794


100%|██████████| 938/938 [00:07<00:00, 123.32it/s]


epoch71; D_loss: 0.6512; G_loss: 2.2866


 80%|████████  | 755/938 [00:06<00:01, 121.62it/s]


KeyboardInterrupt: 