#Optimizer

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import math
from tqdm import tqdm

In [None]:

class CustomAdam(optim.Optimizer):
    """Adam-style optimizer that bias-corrects only the second moment.

    Per parameter it keeps an exponential moving average of the gradient
    (``exp_avg``) and of the squared gradient (``exp_avg_sq``).  Each step
    applies::

        denom = sqrt(exp_avg_sq) / sqrt(1 - beta2 ** step) + eps
        p    -= lr * exp_avg / denom

    NOTE(review): unlike ``torch.optim.Adam`` the step size is NOT divided by
    ``1 - beta1 ** step``.  That behaviour is kept as-is here; confirm whether
    it is intentional.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates for the two running averages.
        eps: constant added to the denominator for numerical stability.

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or a beta is outside
            ``[0, 1)``.
    """

    def __init__(self, params, lr=2e-4, betas=(0.5, 0.9), eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        for index, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(f"Invalid beta parameter at index {index}: {beta}")
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The method runs under no_grad, so re-enable autograd for the
            # closure, which typically calls backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]

                # Lazy state initialization on the first step for this parameter
                if not state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Update biased first moment estimate (keyword `alpha` replaces
                # the deprecated positional-scalar overload).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Update biased second moment estimate
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction (second moment only, see class docstring)
                bias_correction2 = 1 - beta2 ** state['step']
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Update parameters in place
                p.addcdiv_(exp_avg, denom, value=-lr)

        return loss

In [None]:
# Seed PyTorch's RNG so that weight init and noise sampling are repeatable
torch.manual_seed(42)

# --- Hyperparameters ---
# Data / schedule
batch_size = 64
num_generator_updates = 1

# Optimizer settings
adam_beta1 = 0.5
adam_beta2 = 0.9
lr_generator = 2e-5
lr_discriminator = 2e-4

# Regularisation / averaging
gradient_penalty_weight = 10
beta_ema = 0.9999

# --- CIFAR-10 input pipeline ---
# Resize to 32, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

dataset = datasets.CIFAR10(root='./data', transform=transform, download=True)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Generator
class Generator(nn.Module):
    """Map a latent vector to a 3-channel image squashed by tanh.

    The latent vector is projected to a 128 x 4 x 4 feature map, passed
    through three residual blocks, normalised and activated, then upsampled
    by three stride-2 transposed convolutions (4 -> 8 -> 16 -> 32).

    Args:
        z_dim: size of the latent input vector.
    """

    def __init__(self, z_dim=128):
        super().__init__()

        # Latent projection to the initial 128 x 4 x 4 feature map
        self.linear = nn.Linear(z_dim, 128 * 4 * 4)
        self.resblocks = nn.Sequential(*(ResBlock(128) for _ in range(3)))
        self.batch_norm = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()

        # Each transposed convolution doubles the spatial resolution
        self.transposed_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.transposed_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.transposed_conv3 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

        self.tanh = nn.Tanh()

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        features = self.linear(z).view(-1, 128, 4, 4)
        features = self.relu(self.batch_norm(self.resblocks(features)))

        # Upsample to 32x32 (no activation between the transposed convolutions)
        upsamplers = (self.transposed_conv1, self.transposed_conv2, self.transposed_conv3)
        for upsample in upsamplers:
            features = upsample(features)

        return self.tanh(features)

# Discriminator
class Discriminator(nn.Module):
    """Score a batch of 3-channel images with one raw value per sample.

    Four residual blocks (3 -> 128 channels, then 128 -> 128 three times) are
    followed by global average pooling and a linear layer.  No sigmoid is
    applied, so the output is an unbounded score.
    """

    def __init__(self):
        super().__init__()

        self.resblocks = nn.Sequential(
            ResBlock(3),
            *(ResBlock(128) for _ in range(3)),
        )
        self.linear = nn.Linear(128, 1)

    def forward(self, x):
        """Return a tensor with one score per input sample."""
        feature_maps = self.resblocks(x)
        # Global average pooling over the two spatial dimensions
        pooled = feature_maps.mean(dim=(2, 3))
        return self.linear(pooled)

# Residual Block
class ResBlock(nn.Module):
    """Two-convolution residual block with batch norm and ReLU.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN.  The skip path is
    the identity unless the stride or channel count changes, in which case a
    1x1 convolution followed by batch norm projects the input to the output
    shape.  The two paths are summed and passed through a final ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels, out_channels=128, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            # Project the input so it can be added to the main path
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        skip = self.shortcut(x)

        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        main += skip
        return self.relu(main)

# Build the two networks (generator first, then discriminator)
generator = Generator()
discriminator = Discriminator()

# Optimizers
'''
optimizer_generator = optim.Adam(generator.parameters(), lr=lr_generator, betas=(adam_beta1, adam_beta2))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=lr_discriminator, betas=(adam_beta1, adam_beta2))
'''
# Both networks use CustomAdam with the same betas but separate learning rates
adam_betas = (adam_beta1, adam_beta2)
optimizer_generator = CustomAdam(
    generator.parameters(),
    lr=lr_generator,
    betas=adam_betas,
)
optimizer_discriminator = CustomAdam(
    discriminator.parameters(),
    lr=lr_discriminator,
    betas=adam_betas,
)


import torch

# Binary cross-entropy criterion shared by both loss helpers.  BCELoss expects
# its first argument to hold probabilities in [0, 1], not raw scores.
loss_function = nn.BCELoss()

def compute_discriminator_loss(output_discriminator, all_samples_labels):
    """Return the discriminator loss for a batch of predictions.

    Args:
        output_discriminator: discriminator predictions, passed as the first
            argument of the module-level ``loss_function``.
        all_samples_labels: target labels, passed as the second argument.

    Returns:
        Whatever ``loss_function`` returns for the two arguments.
    """
    discriminator_loss = loss_function(output_discriminator, all_samples_labels)
    return discriminator_loss

def compute_generator_loss(output_discriminator_generated, all_samples_labels):
    """Return the generator loss for discriminator predictions on fakes.

    Args:
        output_discriminator_generated: discriminator predictions on generated
            samples, passed as the first argument of ``loss_function``.
        all_samples_labels: target labels, passed as the second argument.

    Returns:
        Whatever ``loss_function`` returns for the two arguments.
    """
    generator_loss = loss_function(output_discriminator_generated, all_samples_labels)
    return generator_loss

def compute_gradient_penalty(real_data, fake_data, discriminator, penalty_weight=None):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    For every sample a point is drawn uniformly on the line between the real
    and the fake sample; the penalty is the mean squared deviation of the
    discriminator's per-sample gradient norm from 1, scaled by the weight.

    Args:
        real_data: batch of real samples, first dimension is the batch.
        fake_data: batch of generated samples with the same shape.
        discriminator: callable mapping a batch to one score per sample.
        penalty_weight: multiplier for the penalty; when ``None`` the
            module-level ``gradient_penalty_weight`` is used.

    Returns:
        Scalar tensor holding the weighted penalty; it stays attached to the
        graph (``create_graph=True``) so it can be back-propagated into the
        discriminator.
    """
    if penalty_weight is None:
        penalty_weight = gradient_penalty_weight

    batch = real_data.size(0)

    # One mixing coefficient per sample, broadcast over all remaining dims
    alpha_shape = (batch,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)

    # Detach so the penalty only trains the discriminator and the
    # interpolates form a leaf tensor we can differentiate against.
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)

    # Calculate discriminator scores
    disc_interpolates = discriminator(interpolates)

    # Gradient of the scores with respect to the interpolated inputs
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten first: the norm must be taken over ALL non-batch dimensions,
    # not only over the channel dimension.
    gradient_norms = gradients.reshape(batch, -1).norm(2, dim=1)

    return ((gradient_norms - 1) ** 2).mean() * penalty_weight

# Exponential moving average (EMA) of the generator weights, refreshed after
# every generator step below.
ema_model_params = [p.data.clone() for p in generator.parameters()]

# Size of the latent vector; matches Generator's default z_dim
latent_dim = 128

# Training loop
for update in tqdm(range(num_generator_updates), desc="Generator Updates"):
    # Sampling at the end of an iteration switches the generator to eval
    # mode, so put it back into training mode before each pass.
    generator.train()

    for real_data in tqdm(dataloader, desc="Dataloader", leave=False):
        real_images = real_data[0]
        # The last batch can be smaller than batch_size, so size the noise
        # and the labels from the actual batch.
        current_batch = real_images.size(0)
        real_labels = torch.ones(current_batch, 1)
        fake_labels = torch.zeros(current_batch, 1)

        # ---- Train discriminator ----
        discriminator.zero_grad()

        # Fakes are produced without a graph: this step must not touch the
        # generator's gradients.
        with torch.no_grad():
            fake_data = generator(torch.randn(current_batch, latent_dim))

        all_samples = torch.cat([real_images, fake_data])
        all_samples_labels = torch.cat([real_labels, fake_labels])

        # The discriminator returns raw scores while the BCE-based loss
        # helpers expect probabilities, hence the sigmoid.
        output_discriminator = torch.sigmoid(discriminator(all_samples))
        loss_discriminator = compute_discriminator_loss(output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # ---- Train generator ----
        generator.zero_grad()

        fake_data = generator(torch.randn(current_batch, latent_dim))

        # The generator is rewarded when its fakes are classified as real
        output_discriminator_generated = torch.sigmoid(discriminator(fake_data))
        loss_generator = compute_generator_loss(output_discriminator_generated, real_labels)
        loss_generator.backward()
        optimizer_generator.step()

        # NOTE: the gradient penalty (compute_gradient_penalty) is not part
        # of the discriminator objective in this loop.

        # ---- Update EMA parameters ----
        with torch.no_grad():
            for ema_param, param in zip(ema_model_params, generator.parameters()):
                ema_param.mul_(beta_ema).add_(param, alpha=1 - beta_ema)

    # Print losses or save images for evaluation
    if update % 100 == 0:
        print(f'Update: {update}, Generator Loss: {loss_generator.item()}, Discriminator Loss: {loss_discriminator.item()}')

    if update % 500 == 0:
        with torch.no_grad():
            generator.eval()
            fake_samples = generator(torch.randn(64, latent_dim))
            vutils.save_image(fake_samples, f'generated_samples_{update}.png', normalize=True)

# Evaluate the generator (generate samples)
with torch.no_grad():
    generator.eval()
    fake_samples = generator(torch.randn(64, latent_dim))
    vutils.save_image(fake_samples, 'generated_samples.png', normalize=True)

# Save the trained models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')


Files already downloaded and verified


Generator Updates:   0%|          | 0/1 [00:00<?, ?it/s]
Dataloader:   0%|          | 0/782 [00:00<?, ?it/s][A