In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tqdm
import os
import wandb


# Hyperparameters
mb_size = 64
Z_dim = 1000
h_dim = 128
lr = 1e-3

# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten the 28x28 image to 784
])

train_dataset = datasets.MNIST(root='../MNIST', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=mb_size, shuffle=True)

X_dim = 784  # 28 x 28

# Xavier Initialization
def xavier_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Generator
class Generator(nn.Module):
    def __init__(self, z_dim, h_dim, x_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(z_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, x_dim)
        self.apply(xavier_init)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        out = torch.sigmoid(self.fc2(h))
        return out

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, x_dim, h_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(x_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, 1)
        self.apply(xavier_init)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        out = torch.sigmoid(self.fc2(h))
        return out



# Training
def cGANTraining(G, D, loss_fn, train_loader):
    G.train()
    D.train()

    D_loss_real_total = 0
    D_loss_fake_total = 0
    G_loss_total = 0
    t = tqdm.tqdm(train_loader)

    for it, (X_real, labels) in enumerate(t):
        # Prepare real data
        X_real = X_real.float().to(device)

        # Sample noise and labels
        z = torch.randn(X_real.size(0), Z_dim).to(device)
        ones_label = torch.ones(X_real.size(0), 1).to(device)
        zeros_label = torch.zeros(X_real.size(0), 1).to(device)

        # ================= Train Discriminator =================
        G_sample = G(z)
        D_real = D(X_real)
        D_fake = D(G_sample.detach())

        D_loss_real = loss_fn(D_real, ones_label)
        D_loss_fake = loss_fn(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake
        D_loss_real_total += D_loss_real.item()
        D_loss_fake_total += D_loss_fake.item()

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ================= Train Generator ====================
        z = torch.randn(X_real.size(0), Z_dim).to(device)
        G_sample = G(z)
        D_fake = D(G_sample)

        G_loss = loss_fn(D_fake, ones_label)
        G_loss_total += G_loss.item()

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

    # ================= Logging =================
    D_loss_real_avg = D_loss_real_total / len(train_loader)
    D_loss_fake_avg = D_loss_fake_total / len(train_loader)
    D_loss_avg = D_loss_real_avg + D_loss_fake_avg
    G_loss_avg = G_loss_total / len(train_loader)

    wandb.log({
        "D_loss_real": D_loss_real_avg,
        "D_loss_fake": D_loss_fake_avg,
        "D_loss": D_loss_avg,
        "G_loss": G_loss_avg
    })

    return G, D, G_loss_avg, D_loss_avg



def save_sample(G, epoch, mb_size, Z_dim):
    out_dir = "out_vanila_GAN2"
    G.eval()
    with torch.no_grad():
        z = torch.randn(mb_size, Z_dim).to(device)
        samples = G(z).detach().cpu().numpy()[:16]

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    if not os.path.exists(f'{out_dir}'):
        os.makedirs(f'{out_dir}')

    plt.savefig(f'{out_dir}/{str(epoch).zfill(3)}.png', bbox_inches='tight')
    plt.close(fig)



########################### Main #######################################
wandb_log = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate models
G = Generator(Z_dim, h_dim, X_dim).to(device)
D = Discriminator(X_dim, h_dim).to(device)

# Optimizers
G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

# Loss function
def my_bce_loss(preds, targets):
    return F.binary_cross_entropy(preds, targets)

#loss_fn = nn.BCEWithLogitsLoss()
loss_fn = my_bce_loss

if wandb_log:
    wandb.init(project="conditional-gan-mnist")

    # Log hyperparameters
    wandb.config.update({
        "batch_size": mb_size,
        "Z_dim": Z_dim,
        "X_dim": X_dim,
        "h_dim": h_dim,
        "lr": lr,
    })

best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 5
print("5 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")
# Inference
# G.load_state_dict(torch.load('checkpoints/G_best.pth'))
# G.eval()

# save_sample(G, "best", mb_size, Z_dim)

100%|██████████| 9.91M/9.91M [00:00<00:00, 42.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.13MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.65MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.44MB/s]


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33memisde-2[0m ([33mertveh-4-lule-university-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


5 EPOCHS: 


100%|██████████| 938/938 [00:13<00:00, 71.60it/s]


epoch0; D_loss: 0.0624; G_loss: 5.9588
Saved Best Models at epoch 0 | G_loss: 5.9588


100%|██████████| 938/938 [00:12<00:00, 73.13it/s]


epoch1; D_loss: 0.0299; G_loss: 6.2204


100%|██████████| 938/938 [00:13<00:00, 72.01it/s]


epoch2; D_loss: 0.0535; G_loss: 5.5215
Saved Best Models at epoch 2 | G_loss: 5.5215


100%|██████████| 938/938 [00:12<00:00, 72.95it/s]


epoch3; D_loss: 0.0894; G_loss: 5.6593


100%|██████████| 938/938 [00:12<00:00, 72.57it/s]


epoch4; D_loss: 0.1733; G_loss: 4.8484
Saved Best Models at epoch 4 | G_loss: 4.8484
DONE


In [None]:
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 10
print("10 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

10 EPOCHS: 


100%|██████████| 938/938 [00:25<00:00, 36.13it/s]


epoch0; D_loss: 0.3020; G_loss: 4.3416
Saved Best Models at epoch 0 | G_loss: 4.3416


100%|██████████| 938/938 [00:18<00:00, 51.72it/s]


epoch1; D_loss: 0.3841; G_loss: 3.7982
Saved Best Models at epoch 1 | G_loss: 3.7982


100%|██████████| 938/938 [00:19<00:00, 47.99it/s]


epoch2; D_loss: 0.4585; G_loss: 3.5628
Saved Best Models at epoch 2 | G_loss: 3.5628


100%|██████████| 938/938 [00:19<00:00, 48.68it/s]


epoch3; D_loss: 0.5386; G_loss: 3.3502
Saved Best Models at epoch 3 | G_loss: 3.3502


100%|██████████| 938/938 [00:19<00:00, 48.61it/s]


epoch4; D_loss: 0.6063; G_loss: 3.1493
Saved Best Models at epoch 4 | G_loss: 3.1493


100%|██████████| 938/938 [00:18<00:00, 52.04it/s]


epoch5; D_loss: 0.6383; G_loss: 2.9762
Saved Best Models at epoch 5 | G_loss: 2.9762


100%|██████████| 938/938 [00:18<00:00, 49.59it/s]


epoch6; D_loss: 0.6645; G_loss: 2.7660
Saved Best Models at epoch 6 | G_loss: 2.7660


100%|██████████| 938/938 [00:18<00:00, 51.74it/s]


epoch7; D_loss: 0.6965; G_loss: 2.6393
Saved Best Models at epoch 7 | G_loss: 2.6393


100%|██████████| 938/938 [00:18<00:00, 50.57it/s]


epoch8; D_loss: 0.6895; G_loss: 2.5483
Saved Best Models at epoch 8 | G_loss: 2.5483


100%|██████████| 938/938 [00:19<00:00, 48.51it/s]


epoch9; D_loss: 0.7135; G_loss: 2.4641
Saved Best Models at epoch 9 | G_loss: 2.4641
DONE


In [None]:
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

epochs = 50
print("50 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")

50 EPOCHS: 


100%|██████████| 938/938 [00:21<00:00, 44.66it/s]


epoch0; D_loss: 0.7093; G_loss: 2.2676
Saved Best Models at epoch 0 | G_loss: 2.2676


100%|██████████| 938/938 [00:19<00:00, 48.77it/s]


epoch1; D_loss: 0.7099; G_loss: 2.2622
Saved Best Models at epoch 1 | G_loss: 2.2622


100%|██████████| 938/938 [00:18<00:00, 50.19it/s]


epoch2; D_loss: 0.7020; G_loss: 2.2078
Saved Best Models at epoch 2 | G_loss: 2.2078


100%|██████████| 938/938 [00:18<00:00, 49.96it/s]


epoch3; D_loss: 0.6940; G_loss: 2.2463


100%|██████████| 938/938 [00:19<00:00, 48.73it/s]


epoch4; D_loss: 0.6981; G_loss: 2.2387


100%|██████████| 938/938 [00:18<00:00, 50.23it/s]


epoch5; D_loss: 0.6978; G_loss: 2.2604


100%|██████████| 938/938 [00:19<00:00, 47.77it/s]


epoch6; D_loss: 0.6927; G_loss: 2.2728


100%|██████████| 938/938 [00:18<00:00, 49.70it/s]


epoch7; D_loss: 0.6975; G_loss: 2.2704


100%|██████████| 938/938 [00:20<00:00, 46.86it/s]


epoch8; D_loss: 0.6896; G_loss: 2.2899


100%|██████████| 938/938 [00:18<00:00, 49.42it/s]


epoch9; D_loss: 0.6844; G_loss: 2.2876


100%|██████████| 938/938 [00:20<00:00, 46.29it/s]


epoch10; D_loss: 0.6824; G_loss: 2.2945


100%|██████████| 938/938 [00:19<00:00, 49.02it/s]


epoch11; D_loss: 0.6774; G_loss: 2.3128


100%|██████████| 938/938 [00:20<00:00, 45.89it/s]


epoch12; D_loss: 0.6723; G_loss: 2.3350


100%|██████████| 938/938 [00:19<00:00, 48.12it/s]


epoch13; D_loss: 0.6693; G_loss: 2.3519


100%|██████████| 938/938 [00:20<00:00, 45.64it/s]


epoch14; D_loss: 0.6662; G_loss: 2.3853


100%|██████████| 938/938 [00:19<00:00, 48.18it/s]


epoch15; D_loss: 0.6603; G_loss: 2.4030


100%|██████████| 938/938 [00:20<00:00, 45.19it/s]


epoch16; D_loss: 0.6507; G_loss: 2.4274


100%|██████████| 938/938 [00:19<00:00, 48.42it/s]


epoch17; D_loss: 0.6481; G_loss: 2.4113


100%|██████████| 938/938 [00:21<00:00, 44.25it/s]


epoch18; D_loss: 0.6427; G_loss: 2.4589


100%|██████████| 938/938 [00:20<00:00, 46.49it/s]


epoch19; D_loss: 0.6307; G_loss: 2.4875


100%|██████████| 938/938 [00:19<00:00, 47.07it/s]


epoch20; D_loss: 0.6379; G_loss: 2.5151


100%|██████████| 938/938 [00:20<00:00, 45.34it/s]


epoch21; D_loss: 0.6285; G_loss: 2.5054


100%|██████████| 938/938 [00:19<00:00, 48.05it/s]


epoch22; D_loss: 0.6244; G_loss: 2.5362


100%|██████████| 938/938 [00:20<00:00, 45.42it/s]


epoch23; D_loss: 0.6269; G_loss: 2.5589


100%|██████████| 938/938 [00:19<00:00, 48.01it/s]


epoch24; D_loss: 0.6127; G_loss: 2.5838


100%|██████████| 938/938 [00:20<00:00, 45.32it/s]


epoch25; D_loss: 0.6138; G_loss: 2.5817


100%|██████████| 938/938 [00:19<00:00, 47.70it/s]


epoch26; D_loss: 0.6040; G_loss: 2.6363


100%|██████████| 938/938 [00:20<00:00, 45.40it/s]


epoch27; D_loss: 0.5991; G_loss: 2.6694


100%|██████████| 938/938 [00:19<00:00, 47.77it/s]


epoch28; D_loss: 0.5901; G_loss: 2.6736


100%|██████████| 938/938 [00:20<00:00, 45.28it/s]


epoch29; D_loss: 0.5827; G_loss: 2.6887


100%|██████████| 938/938 [00:20<00:00, 46.72it/s]


epoch30; D_loss: 0.5779; G_loss: 2.7251


100%|██████████| 938/938 [00:20<00:00, 46.38it/s]


epoch31; D_loss: 0.5739; G_loss: 2.7300


100%|██████████| 938/938 [00:20<00:00, 45.88it/s]


epoch32; D_loss: 0.5678; G_loss: 2.7560


100%|██████████| 938/938 [00:19<00:00, 47.15it/s]


epoch33; D_loss: 0.5587; G_loss: 2.7865


100%|██████████| 938/938 [00:20<00:00, 45.55it/s]


epoch34; D_loss: 0.5604; G_loss: 2.7937


100%|██████████| 938/938 [00:19<00:00, 47.49it/s]


epoch35; D_loss: 0.5507; G_loss: 2.7933


100%|██████████| 938/938 [00:22<00:00, 42.23it/s]


epoch36; D_loss: 0.5488; G_loss: 2.8057


100%|██████████| 938/938 [00:19<00:00, 47.50it/s]


epoch37; D_loss: 0.5397; G_loss: 2.8088


100%|██████████| 938/938 [00:21<00:00, 44.47it/s]


epoch38; D_loss: 0.5362; G_loss: 2.8597


100%|██████████| 938/938 [00:19<00:00, 46.94it/s]


epoch39; D_loss: 0.5297; G_loss: 2.8245


100%|██████████| 938/938 [00:20<00:00, 45.25it/s]


epoch40; D_loss: 0.5280; G_loss: 2.8464


100%|██████████| 938/938 [00:20<00:00, 45.02it/s]


epoch41; D_loss: 0.5143; G_loss: 2.8857


100%|██████████| 938/938 [00:20<00:00, 46.85it/s]


epoch42; D_loss: 0.5178; G_loss: 2.8759


100%|██████████| 938/938 [00:21<00:00, 44.53it/s]


epoch43; D_loss: 0.5104; G_loss: 2.8887


100%|██████████| 938/938 [00:20<00:00, 46.78it/s]


epoch44; D_loss: 0.5087; G_loss: 2.8779


100%|██████████| 938/938 [00:21<00:00, 43.83it/s]


epoch45; D_loss: 0.5009; G_loss: 2.8859


100%|██████████| 938/938 [00:20<00:00, 45.27it/s]


epoch46; D_loss: 0.5002; G_loss: 2.9532


100%|██████████| 938/938 [00:21<00:00, 44.05it/s]


epoch47; D_loss: 0.4926; G_loss: 2.9395


100%|██████████| 938/938 [00:21<00:00, 42.88it/s]


epoch48; D_loss: 0.4902; G_loss: 2.9586


100%|██████████| 938/938 [00:20<00:00, 45.33it/s]


epoch49; D_loss: 0.4861; G_loss: 2.9701
DONE


In [5]:
#With BCE Logit loss as Loss function
import torch.nn.functional as F
def my_BCEL_loss(preds, targets):
    return F.binary_cross_entropy_with_logits(preds, targets)

#loss_fn = nn.BCEWithLogitsLoss()
loss_fn = my_BCEL_loss

best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 5
print("5 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")

5 EPOCHS: 


100%|██████████| 938/938 [00:25<00:00, 36.57it/s]


epoch0; D_loss: 1.0119; G_loss: 0.6916
Saved Best Models at epoch 0 | G_loss: 0.6916


100%|██████████| 938/938 [00:19<00:00, 48.96it/s]


epoch1; D_loss: 1.0073; G_loss: 0.6929


100%|██████████| 938/938 [00:21<00:00, 43.71it/s]


epoch2; D_loss: 1.0068; G_loss: 0.6930


100%|██████████| 938/938 [00:20<00:00, 45.04it/s]


epoch3; D_loss: 1.0066; G_loss: 0.6931


100%|██████████| 938/938 [00:21<00:00, 43.61it/s]


epoch4; D_loss: 1.0065; G_loss: 0.6931
DONE


In [4]:
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 10
print("10 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")

10 EPOCHS: 


100%|██████████| 938/938 [00:12<00:00, 72.71it/s]


epoch0; D_loss: 0.2931; G_loss: 4.3655
Saved Best Models at epoch 0 | G_loss: 4.3655


100%|██████████| 938/938 [00:13<00:00, 72.07it/s]


epoch1; D_loss: 0.3556; G_loss: 3.8379
Saved Best Models at epoch 1 | G_loss: 3.8379


100%|██████████| 938/938 [00:12<00:00, 72.51it/s]


epoch2; D_loss: 0.3788; G_loss: 3.7475
Saved Best Models at epoch 2 | G_loss: 3.7475


100%|██████████| 938/938 [00:13<00:00, 71.56it/s]


epoch3; D_loss: 0.4597; G_loss: 3.3984
Saved Best Models at epoch 3 | G_loss: 3.3984


100%|██████████| 938/938 [00:13<00:00, 70.82it/s]


epoch4; D_loss: 0.5640; G_loss: 3.0592
Saved Best Models at epoch 4 | G_loss: 3.0592


100%|██████████| 938/938 [00:13<00:00, 69.78it/s]


epoch5; D_loss: 0.5895; G_loss: 2.7915
Saved Best Models at epoch 5 | G_loss: 2.7915


100%|██████████| 938/938 [00:13<00:00, 70.10it/s]


epoch6; D_loss: 0.6245; G_loss: 2.8057


100%|██████████| 938/938 [00:13<00:00, 70.91it/s]


epoch7; D_loss: 0.6239; G_loss: 2.6591
Saved Best Models at epoch 7 | G_loss: 2.6591


100%|██████████| 938/938 [00:13<00:00, 71.66it/s]


epoch8; D_loss: 0.6614; G_loss: 2.5551
Saved Best Models at epoch 8 | G_loss: 2.5551


100%|██████████| 938/938 [00:12<00:00, 72.30it/s]


epoch9; D_loss: 0.6725; G_loss: 2.4337
Saved Best Models at epoch 9 | G_loss: 2.4337
DONE


In [6]:
best_g_loss = float('inf')  # Initialize best generator loss
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

#Train epochs
epochs = 50
print("50 EPOCHS: ")
for epoch in range(epochs):
    G, D, G_loss_avg, D_loss_avg= cGANTraining(G, D, loss_fn, train_loader)

    print(f'epoch{epoch}; D_loss: {D_loss_avg:.4f}; G_loss: {G_loss_avg:.4f}')

    if G_loss_avg < best_g_loss:
        best_g_loss = G_loss_avg
        torch.save(G.state_dict(), os.path.join(save_dir, 'G_best.pth'))
        torch.save(D.state_dict(), os.path.join(save_dir, 'D_best.pth'))
        print(f"Saved Best Models at epoch {epoch} | G_loss: {best_g_loss:.4f}")

    save_sample(G, epoch, mb_size, Z_dim)

print("DONE")

50 EPOCHS: 


100%|██████████| 938/938 [00:19<00:00, 47.58it/s]


epoch0; D_loss: 1.0065; G_loss: 0.6931
Saved Best Models at epoch 0 | G_loss: 0.6931


100%|██████████| 938/938 [00:20<00:00, 46.18it/s]


epoch1; D_loss: 1.0065; G_loss: 0.6931


100%|██████████| 938/938 [00:20<00:00, 46.21it/s]


epoch2; D_loss: 1.0065; G_loss: 0.6931


100%|██████████| 938/938 [00:19<00:00, 47.11it/s]


epoch3; D_loss: 1.0065; G_loss: 0.6931


100%|██████████| 938/938 [00:20<00:00, 46.56it/s]


epoch4; D_loss: 1.0240; G_loss: 0.6888
Saved Best Models at epoch 4 | G_loss: 0.6888


100%|██████████| 938/938 [00:19<00:00, 47.38it/s]


epoch5; D_loss: 1.0678; G_loss: 0.6770
Saved Best Models at epoch 5 | G_loss: 0.6770


100%|██████████| 938/938 [00:20<00:00, 46.07it/s]


epoch6; D_loss: 1.0785; G_loss: 0.6744
Saved Best Models at epoch 6 | G_loss: 0.6744


100%|██████████| 938/938 [00:19<00:00, 47.76it/s]


epoch7; D_loss: 1.0784; G_loss: 0.6733
Saved Best Models at epoch 7 | G_loss: 0.6733


100%|██████████| 938/938 [00:20<00:00, 44.95it/s]


epoch8; D_loss: 1.0815; G_loss: 0.6732
Saved Best Models at epoch 8 | G_loss: 0.6732


100%|██████████| 938/938 [00:19<00:00, 47.95it/s]


epoch9; D_loss: 1.0807; G_loss: 0.6730
Saved Best Models at epoch 9 | G_loss: 0.6730


100%|██████████| 938/938 [00:21<00:00, 44.03it/s]


epoch10; D_loss: 1.0364; G_loss: 0.6851


100%|██████████| 938/938 [00:19<00:00, 47.00it/s]


epoch11; D_loss: 1.0618; G_loss: 0.6787


100%|██████████| 938/938 [00:21<00:00, 44.55it/s]


epoch12; D_loss: 1.0808; G_loss: 0.6736


100%|██████████| 938/938 [00:20<00:00, 46.56it/s]


epoch13; D_loss: 1.0601; G_loss: 0.6786


100%|██████████| 938/938 [00:21<00:00, 43.23it/s]


epoch14; D_loss: 1.0394; G_loss: 0.6844


100%|██████████| 938/938 [00:22<00:00, 42.41it/s]


epoch15; D_loss: 1.0297; G_loss: 0.6868


100%|██████████| 938/938 [00:21<00:00, 43.66it/s]


epoch16; D_loss: 1.0323; G_loss: 0.6859


100%|██████████| 938/938 [00:23<00:00, 40.17it/s]


epoch17; D_loss: 1.0077; G_loss: 0.6927


100%|██████████| 938/938 [00:22<00:00, 41.95it/s]


epoch18; D_loss: 1.0562; G_loss: 0.6800


100%|██████████| 938/938 [00:22<00:00, 42.58it/s]


epoch19; D_loss: 1.0430; G_loss: 0.6837


100%|██████████| 938/938 [00:22<00:00, 41.23it/s]


epoch20; D_loss: 1.0637; G_loss: 0.6772


100%|██████████| 938/938 [00:24<00:00, 38.84it/s]


epoch21; D_loss: 1.0217; G_loss: 0.6889


100%|██████████| 938/938 [00:21<00:00, 42.68it/s]


epoch22; D_loss: 1.0841; G_loss: 0.6703
Saved Best Models at epoch 22 | G_loss: 0.6703


100%|██████████| 938/938 [00:23<00:00, 40.67it/s]


epoch23; D_loss: 1.0357; G_loss: 0.6852


100%|██████████| 938/938 [00:24<00:00, 38.45it/s]


epoch24; D_loss: 1.0164; G_loss: 0.6903


100%|██████████| 938/938 [00:25<00:00, 36.88it/s]


epoch25; D_loss: 1.0233; G_loss: 0.6884


100%|██████████| 938/938 [00:23<00:00, 40.73it/s]


epoch26; D_loss: 1.0916; G_loss: 0.6683
Saved Best Models at epoch 26 | G_loss: 0.6683


100%|██████████| 938/938 [00:23<00:00, 39.19it/s]


epoch27; D_loss: 1.0462; G_loss: 0.6814


100%|██████████| 938/938 [00:22<00:00, 42.30it/s]


epoch28; D_loss: 1.0756; G_loss: 0.6737


100%|██████████| 938/938 [00:24<00:00, 38.58it/s]


epoch29; D_loss: 1.0439; G_loss: 0.6826


100%|██████████| 938/938 [00:23<00:00, 39.18it/s]


epoch30; D_loss: 1.0662; G_loss: 0.6746


100%|██████████| 938/938 [00:23<00:00, 40.09it/s]


epoch31; D_loss: 1.0717; G_loss: 0.6742


100%|██████████| 938/938 [00:25<00:00, 37.44it/s]


epoch32; D_loss: 1.0085; G_loss: 0.6926


100%|██████████| 938/938 [00:24<00:00, 38.12it/s]


epoch33; D_loss: 1.0447; G_loss: 0.6823


100%|██████████| 938/938 [00:23<00:00, 39.18it/s]


epoch34; D_loss: 1.0578; G_loss: 0.6789


100%|██████████| 938/938 [00:25<00:00, 37.49it/s]


epoch35; D_loss: 1.0273; G_loss: 0.6870


100%|██████████| 938/938 [00:26<00:00, 35.34it/s]


epoch36; D_loss: 1.0080; G_loss: 0.6927


100%|██████████| 938/938 [00:27<00:00, 34.25it/s]


epoch37; D_loss: 1.0071; G_loss: 0.6929


100%|██████████| 938/938 [00:29<00:00, 32.25it/s]


epoch38; D_loss: 1.0069; G_loss: 0.6930


100%|██████████| 938/938 [00:30<00:00, 31.15it/s]


epoch39; D_loss: 1.0216; G_loss: 0.6900


100%|██████████| 938/938 [00:25<00:00, 37.23it/s]


epoch40; D_loss: 1.0622; G_loss: 0.6773


100%|██████████| 938/938 [00:24<00:00, 38.92it/s]


epoch41; D_loss: 1.0633; G_loss: 0.6767


100%|██████████| 938/938 [00:25<00:00, 37.31it/s]


epoch42; D_loss: 1.0648; G_loss: 0.6753


100%|██████████| 938/938 [00:24<00:00, 37.81it/s]


epoch43; D_loss: 1.0601; G_loss: 0.6778


100%|██████████| 938/938 [00:24<00:00, 38.56it/s]


epoch44; D_loss: 1.0592; G_loss: 0.6773


100%|██████████| 938/938 [00:24<00:00, 38.71it/s]


epoch45; D_loss: 1.0628; G_loss: 0.6763


100%|██████████| 938/938 [00:24<00:00, 38.40it/s]


epoch46; D_loss: 1.0529; G_loss: 0.6792


100%|██████████| 938/938 [00:23<00:00, 40.54it/s]


epoch47; D_loss: 1.0675; G_loss: 0.6746


100%|██████████| 938/938 [00:23<00:00, 39.22it/s]


epoch48; D_loss: 1.0672; G_loss: 0.6751


100%|██████████| 938/938 [00:24<00:00, 38.84it/s]


epoch49; D_loss: 1.0698; G_loss: 0.6752
DONE
