In [1]:
!pip install torch torchvision



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils

# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
num_generator_updates = 500000
adam_beta1 = 0.5
adam_beta2 = 0.9
gradient_penalty_weight = 10
lr_generator = 2e-5
lr_discriminator = 2e-4
beta_ema = 0.9999

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Generator
class Generator(nn.Module):
    def __init__(self, z_dim=128):
        super(Generator, self).__init__()

        self.linear = nn.Linear(z_dim, 128 * 4 * 4)
        self.resblocks = nn.Sequential(
            ResBlock(128),
            ResBlock(128),
            ResBlock(128)
        )
        self.batch_norm = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()
        self.transposed_conv = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, z):
        x = self.linear(z)
        x = x.view(-1, 128, 4, 4)
        x = self.resblocks(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.transposed_conv(x)
        x = self.tanh(x)
        return x

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.resblocks = nn.Sequential(
            ResBlock(3),
            ResBlock(128),
            ResBlock(128),
            ResBlock(128)
        )
        self.linear = nn.Linear(128, 1)

    def forward(self, x):
        x = self.resblocks(x)
        x = x.mean(dim=(2, 3))  # Global average pooling
        x = self.linear(x)
        return x

# Residual Block
class ResBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, stride=1, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(128, in_channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.batch_norm2(x)
        x += residual
        x = self.relu(x)
        return x

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_generator = optim.Adam(generator.parameters(), lr=lr_generator, betas=(adam_beta1, adam_beta2))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=lr_discriminator, betas=(adam_beta1, adam_beta2))

import torch

def compute_discriminator_loss(real_data, fake_data, discriminator):
    # Real data loss
    output_real = discriminator(real_data)
    loss_real = -torch.mean(output_real)

    # Fake data loss
    output_fake = discriminator(fake_data.detach())
    loss_fake = torch.mean(output_fake)

    # Total loss
    loss_discriminator = loss_real + loss_fake

    return loss_discriminator

def compute_generator_loss(fake_data, discriminator):
    # Generator loss
    output_fake = discriminator(fake_data)
    loss_generator = -torch.mean(output_fake)

    return loss_generator

def compute_gradient_penalty(real_data, fake_data, discriminator):
    # Interpolation
    alpha = torch.rand(real_data.size(0), 1, 1, 1).to(real_data.device)
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)

    # Calculate discriminator scores
    disc_interpolates = discriminator(interpolates)

    # Compute gradients
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size()).to(real_data.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Compute gradient penalty
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gradient_penalty_weight

    return gradient_penalty

# EMA update function
def update_ema(model_params, beta_ema=0.9999):
    ema_model_params = [ema_param * beta_ema + (1.0 - beta_ema) * param
                        for ema_param, param in zip(ema_model_params, model_params)]

# Assuming ema_model_params is a list that stores the EMA parameters
ema_model_params = [p.data.clone() for p in generator.parameters()]

# Training loop
for update in range(num_generator_updates):
    for real_data in dataloader:
        # Train discriminator
        discriminator.zero_grad()

        # Generate fake data
        fake_data = generator(torch.randn(batch_size, 128))  # Assuming z_dim is 128

        # Compute WGAN loss with gradient penalty
        loss_discriminator = compute_discriminator_loss(real_data, fake_data, discriminator)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Train generator
        generator.zero_grad()

        # Generate fake data
        fake_data = generator(torch.randn(batch_size, 128))  # Assuming z_dim is 128

        # Compute generator loss
        loss_generator = compute_generator_loss(fake_data, discriminator)
        loss_generator.backward()
        optimizer_generator.step()

        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(real_data, fake_data, discriminator)
        gradient_penalty.backward()
        optimizer_discriminator.step()

        # Update EMA parameters
        update_ema(generator.parameters(), beta_ema)

    # Print losses or save images for evaluation
    if update % 100 == 0:
        print(f'Update: {update}, Generator Loss: {loss_generator.item()}, Discriminator Loss: {loss_discriminator.item()}')

    if update % 500 == 0:
        with torch.no_grad():
            generator.eval()
            fake_samples = generator(torch.randn(64, 128))  # Assuming z_dim is 128
            vutils.save_image(fake_samples, f'generated_samples_{update}.png', normalize=True)

# Evaluate the generator (generate samples)
with torch.no_grad():
    generator.eval()
    fake_samples = generator(torch.randn(64, 128))  # Assuming z_dim is 128
    vutils.save_image(fake_samples, 'generated_samples.png', normalize=True)

# Save the trained models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 74463688.13it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data




TypeError: ignored

#Optimizer

In [None]:
import torch
from torch.optim import Optimizer

class CustomAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameters")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(CustomAdam, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Sample new mini-batch and compute stochastic gradient
                g = p.grad.data

                # Update estimate of first moment
                group['exp_avg'] = group['betas'][0] * group['exp_avg'] + (1 - group['betas'][0]) * g

                # Update estimate of second moment
                group['exp_avg_sq'] = group['betas'][1] * group['exp_avg_sq'] + (1 - group['betas'][1]) * g ** 2

                # Perform update step
                if group['amsgrad']:
                    # Maintain the maximum of all second moment running averages
                    group['max_exp_avg_sq'] = torch.max(group['max_exp_avg_sq'], group['exp_avg_sq'])
                    denom = group['max_exp_avg_sq'].sqrt() + group['eps']
                else:
                    denom = group['exp_avg_sq'].sqrt() + group['eps']

                p.data.addcdiv_(-group['lr'], group['exp_avg'], denom)

        return loss

# Usage example
optimizer = CustomAdam(model.parameters(), lr=0.001, betas=(0.9, 0.999))