In [80]:
import wandb
import torch
import functools
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils as utils
import torchvision.utils as vu
import torch.nn.functional as F
import torchvision.datasets as Datasets
import torchvision.transforms as T
from torchvision.utils import save_image
from torch.cuda.amp import GradScaler, autocast


from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import helpMe
import layers
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] =1


## Configrations

In [81]:
model_name = "Cgan_mnist"
image_size = 32
batch_size = 100
z_dim = 128
# DATA_DIR = './imageNet_lp/torch_image_folder/mnt/volume_sfo3_01/imagenet-lt/ImageDataset/train'
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
channels =1
epochs = 11

In [82]:
wandb.init(
    # set the wandb project where this run will be logged
    project="Cgan_mnist",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.0002,
    "architecture": "Gan",
    "dataset": "MNIST",
    "epochs": 10,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maruntd008[0m ([33maruntd08[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888884685, max=1.0…

## Dataset

In [83]:
# transforms = T.Compose([T.Resize(image_size), T.CenterCrop(image_size), T.ToTensor(), T.Normalize(*stats)])
transforms = T.Compose([
    T.Resize(image_size),
    T.RandomRotation(10),  # Random rotation within 10 degrees
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])
dataset = Datasets.MNIST(root='./Datasxts/MNIST', train=True, download=True,transform=transforms) # cifar-10
# dataset = D.ImageFolder(DATA_DIR, transform=transforms) # imageNet_lp
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./Datasxts/MNIST
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=32, interpolation=bilinear, max_size=None, antialias=True)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [84]:
class_list=dataset.classes
print('No of Classes:',len(class_list))
print('Class Names: ',class_list)

No of Classes: 10
Class Names:  ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


In [85]:
device = helpMe.get_default_device()
device

device(type='cuda')

# Model Architecture

In [86]:
class GBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
                 upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels,
                                     kernel_size=1, padding=0)
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(self.hidden_channels, self.out_channels,
                                     kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, :self.out_channels]
            # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x


def G_arch(ch=64, attention='64'):
    arch = {}
    arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
                 'upsample': [True] * 6,
                 'resolution': [8, 16, 32, 64, 128, 256],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])               #{8: False, 16: False, 32: False, 64: True, 128: False, 256: False}
                               for i in range(3, 9)}}                                                         #This dictionary indicates which of the powers of 2 from 2^3 to 2^8 are present in the attention string.
    arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
                 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
                 'upsample': [True] * 5,
                 'resolution': [8, 16, 32, 64, 128],
                 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                               for i in range(3, 8)}}
    arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
                'out_channels': [ch * item for item in [16, 8, 4, 2]],
                'upsample': [True] * 4,
                'resolution': [8, 16, 32, 64],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 7)}}
    arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
                'out_channels': [ch * item for item in [4, 4, 4]],
                'upsample': [True] * 3,
                'resolution': [8, 16, 32],
                'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
                              for i in range(3, 6)}}

    return arch

In [87]:
class Generator(nn.Module):
    def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=32,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 skip_init=False,
                 **kwargs):
        super(Generator, self).__init__()
        # Channel width mulitplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes
        self.n_classes = n_classes
        # Dimensionality of the shared embedding
        self.shared_dim = dim_z
        # nonlinearity for residual blocks
        self.activation = G_activation

    
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]
        
        # Which convs, batchnorms, and linear layers to use
        self.which_conv = functools.partial(layers.SNConv2d,kernel_size=3, padding=1)
        self.which_linear = functools.partial(layers.SNLinear)


        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (functools.partial(self.which_linear, bias=False))
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          input_size=(self.shared_dim + self.dim_z))

        # Prepare model
        

        self.shared = (self.which_embedding(n_classes, self.shared_dim))
        # First linear layer
        self.linear = self.which_linear(self.dim_z + self.shared_dim,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        
        # arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
        #          'out_channels': [ch * item for item in   [16, 8, 4, 2, 1]],
        #          'upsample': [True] * 5,
        #          'resolution': [8, 16, 32, 64, 128],
        #          'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
        #                        for i in range(3, 8)}}
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[GBlock(in_channels=self.arch['in_channels'][index],
                                    out_channels=self.arch['in_channels'][index] if g_index == 0 else
                                    self.arch['out_channels'][index],
                                    which_conv=self.which_conv,
                                    which_bn=self.which_bn,
                                    activation=self.activation,
                                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                                              if self.arch['upsample'][index] and g_index == (
                                                self.G_depth - 1) else None))]
                            for g_index in range(self.G_depth)]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1]),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], channels))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer

        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps

        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)):

                init.orthogonal_(module.weight)

                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for G''s initialized parameters: %d' % self.param_count)

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    # NOTE: The z vs y dichotomy here is for compatibility with not-y
    def forward(self, z, y):
        # concatenate zs and ys
        # print('Shape of z',z.shape)
        # print('Shape of y',y.shape)
        z = torch.cat([y, z], 1)
        y = z
        # First linear layer
        # print('Shape of cont z',z.shape)
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))

In [88]:
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, which_conv=layers.SNConv2d,
                 preactivation=True, activation=None, downsample=None,
                 channel_ratio=4):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.out_channels // channel_ratio
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels,
                                     kernel_size=1, padding=0)
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(self.hidden_channels, self.out_channels,
                                     kernel_size=1, padding=0)

        self.learnable_sc = True if (in_channels != out_channels) else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels - in_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        if self.downsample:
            x = self.downsample(x)
        if self.learnable_sc:
            x = torch.cat([x, self.conv_sc(x)], 1)
        return x

    def forward(self, x):
        # 1x1 bottleneck conv
        h = self.conv1(F.relu(x))
        # 3x3 convs
        h = self.conv2(self.activation(h))
        h = self.conv3(self.activation(h))
        # relu before downsample
        h = self.activation(h)
        # downsample
        if self.downsample:
            h = self.downsample(h)
            # final 1x1 conv
        h = self.conv4(h)
        return h + self.shortcut(x)


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64'):
    arch = {}
    arch[256] = {'in_channels': [item * ch for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[128] = {'in_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                               for i in range(2, 8)}}
    arch[64] = {'in_channels': [item * ch for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 7)}}
    arch[32] = {'in_channels': [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
                              for i in range(2, 6)}}
    return arch



In [89]:
class Discriminator(nn.Module):

    def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=32,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, skip_init=False, **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # How many resblocks per stage?
        self.D_depth = D_depth
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation

        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now

        self.which_conv = functools.partial(layers.SNConv2d, kernel_size=3, padding=1,)
        self.which_linear = functools.partial(layers.SNLinear)
        self.which_embedding = functools.partial(layers.SNEmbedding)

        # Prepare model
        # Stem convolution
        self.input_conv = self.which_conv(channels, self.arch['in_channels'][0])
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[DBlock(
                in_channels=self.arch['in_channels'][index] if d_index == 0 else self.arch['out_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                activation=self.activation,
                preactivation=True,
                downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index == 0 else None))
                             for d_index in range(self.D_depth)]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps

        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)


    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):

                init.orthogonal_(module.weight)
               
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for D''s initialized parameters: %d' % self.param_count)

    def forward(self, x, y=None):
        # Run input conv
        h = self.input_conv(x)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        # Get projection of final featureset onto class vectors and add to evidence
        out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out

## Training

In [90]:

import os

sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, fake_images, show=False):
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(helpMe.denorm(fake_images[:100]), os.path.join(sample_dir, fake_fname), nrow=10)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(4, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(vu.make_grid(fake_images.cpu().detach(), nrow=10).permute(1, 2, 0))

In [91]:
import numpy as np

def score(real, fake):
  real = real.cpu().detach().numpy()
  fake = fake.cpu().detach().numpy()

  # Ensure NumPy arrays for efficient vectorized operations
  real = np.asarray(real)
  fake = np.asarray(fake)

  # Count elements using vectorized comparison and summation
  real_count = np.sum(real >= 1)
  fake_count = np.sum(fake <= -1)

  # Calculate percentages with handling for empty lists or arrays
  real_score = real_count / (len(real) if len(real) > 0 else 1)
  fake_score = fake_count / (len(fake) if len(fake) > 0 else 1)

  return real_score, fake_score



In [92]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

# Define the loss function (hinge loss in this example)
def loss_hinge_dis(dis_fake, dis_real):
    loss_real = torch.mean(F.relu(1. - dis_real))
    loss_fake = torch.mean(F.relu(1. + dis_fake))
    return loss_real, loss_fake

def loss_hinge_gen(dis_fake):
    loss = -torch.mean(dis_fake)
    return loss

# Create a mixed precision optimizer for each network
scaler_G = GradScaler()
scaler_D = GradScaler()

# Define the gradient accumulation steps
accumulation_steps = 4

# Training function with tqdm
def train_gan(generator, discriminator, dataloader, num_epochs, z_dim, batch_size, checkpoint_dir=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    
    start_epoch = 1
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            generator.optim.load_state_dict(checkpoint['optimizer_G_state_dict'])
            discriminator.optim.load_state_dict(checkpoint['optimizer_D_state_dict'])
            scaler_G.load_state_dict(checkpoint['scaler_G'])
            scaler_D.load_state_dict(checkpoint['scaler_D'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resuming training from epoch {start_epoch}.")

    for epoch in range(start_epoch, num_epochs):
        total_d_loss = 0.0
        total_g_loss = 0.0
        
        with tqdm(enumerate(dataloader), total=len(dataloader)) as t:
            for i, (real_images, labels) in t:
                real_images = real_images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                
                # Train discriminator
                discriminator.optim.zero_grad()
                for a in range(accumulation_steps):
                    # Generate fake images
                    z = torch.randn(batch_size, z_dim, device=device)
                    fake_images = generator(z, generator.shared(labels))
                    
                    # Combine real and fake images
                    all_images = torch.cat([real_images, fake_images], dim=0)
                    all_labels = torch.cat([labels, labels], dim=0)
                    
                    with autocast():
                        dis_output = discriminator(all_images, all_labels)
                        dis_real, dis_fake = torch.chunk(dis_output, 2)
                        loss_real, loss_fake = loss_hinge_dis(dis_fake, dis_real)
                        d_loss = (loss_real + loss_fake) / accumulation_steps
                        # print("rel",a,dis_real)
                        # print("fake",a,dis_fake)
                        
                    scaler_D.scale(d_loss).backward()
                # print("disrel",dis_real)
                # print("disfake",dis_fake)
                real_score, fake_score = score(dis_real, dis_fake)
                # print("real score",real_score)
                # print("fake score",fake_score)
                # return
                scaler_D.step(discriminator.optim)
                scaler_D.update()
                
                total_d_loss += d_loss.item() * accumulation_steps
                
                # Train generator
                generator.optim.zero_grad()
                for _ in range(accumulation_steps):
                    z = torch.randn(batch_size, z_dim, device=device)
                    fake_images = generator(z, generator.shared(labels))
                    
                    with autocast():
                        dis_fake = discriminator(fake_images, labels)
                        g_loss = loss_hinge_gen(dis_fake) / accumulation_steps
                        
                    scaler_G.scale(g_loss).backward()
                    
                scaler_G.step(generator.optim)
                scaler_G.update()
                
                total_g_loss += g_loss.item() * accumulation_steps
                D_loss = total_d_loss / ((i + 1) * accumulation_steps)
                G_loss = total_g_loss / ((i + 1) * accumulation_steps)
                # Print losses
                t.set_description(f'Epoch [{epoch}/{num_epochs}]')
                t.set_postfix({'D_loss': f'{D_loss:.3f}',
                               'G_loss': f'{G_loss:.3f}',
                               'Real_score': f'{real_score:.3f}',
                               'Fake_score': f'{fake_score:.3f}'})
                
                wandb.log({'Loss/Gen': G_loss,
                   'Loss/Dis': D_loss,
                   'Score/Real': real_score,
                   'Score/Fake': fake_score
                   })
                # Clear unused variables to free up memory
                del real_images, labels, fake_images, all_images, all_labels, z, dis_output, dis_real, dis_fake, real_score, fake_score, D_loss, G_loss, a
                torch.cuda.empty_cache()

        # Save generated images
        # save_generated_images(generator, epoch, batch_size, z_dim, checkpoint_dir, device)
        # Save the model
        # save_model(generator, discriminator, epoch, checkpoint_dir)

def save_generated_images(generator, epoch, batch_size, z_dim, path, device):
    os.makedirs(f"{path}Generated", exist_ok=True)
    with torch.no_grad():
        fixed_latent = torch.randn(batch_size, z_dim, device=device)
        fixed_labels = torch.tensor([i % 10 for i in range(batch_size)], device=device)
        fake_images = generator(fixed_latent, generator.shared(fixed_labels))
    torchvision.utils.save_image(fake_images.detach(), f"{path}Generated/generated_images_epoch_{epoch}.png", normalize=True,nrow=10)

def save_model(generator, discriminator, epoch, checkpoint_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': generator.optim.state_dict(),
        'optimizer_D_state_dict': discriminator.optim.state_dict(),
        'scaler_G': scaler_G.state_dict(),
        'scaler_D': scaler_D.state_dict()
    }, checkpoint_path)


train_dl = DataLoader(dataset, batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

G = Generator(resolution=image_size, n_classes=len(class_list))
D = Discriminator(resolution=image_size, n_classes=len(class_list))

print(G)
print(D)
print('Number of params in G: {} D: {}'.format(
*[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
checkpoint_dir = f"Models/{model_name}/"  # Directory to save or load the checkpoint
train_gan(G, D, train_dl, num_epochs=epochs, z_dim=z_dim, batch_size=batch_size, checkpoint_dir=checkpoint_dir)


Param count for Gs initialized parameters: 3074177
Param count for Ds initialized parameters: 647041
Generator(
  (activation): ReLU()
  (shared): Embedding(10, 128)
  (linear): SNLinear(in_features=256, out_features=4096, bias=True)
  (blocks): ModuleList(
    (0-5): 6 x ModuleList(
      (0): GBlock(
        (activation): ReLU()
        (conv1): SNConv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        (conv2): SNConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): SNConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): SNConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 256, in: 256, cross_replica=False
          (gain): SNLinear(in_features=256, out_features=256, bias=False)
          (bias): SNLinear(in_features=256, out_features=256, bias=False)
        )
        (bn2): ccbn(
          out: 64, in: 256, cross_replica=False
          (gain): SNLinear(in_features=256, out_features=64,

Epoch [10/11]:   3%|▎         | 16/600 [00:22<13:29,  1.39s/it, D_loss=0.484, G_loss=0.040, Real_score=0.010, Fake_score=0.010]


KeyboardInterrupt: 

In [None]:
wandb.finish()