In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import os
import numpy as np

In [2]:
# Create output directory if it doesn't exist
os.makedirs('output', exist_ok=True)

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)


In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the generator and discriminator
netG = Generator().to(device)
netD = Discriminator().to(device)


In [6]:
# Initialize the weights
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Dropout(p=0.3, inplace=False)
    (7): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): LeakyReLU(negative_slope=0.2, inplace=True)
    (10): Dropout(p=0.3, inplace=False)
    (11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.2, inplace=True)
    (14)

In [7]:
# Loss and optimizers
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=10, gamma=0.1)
schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=10, gamma=0.1)

num_epochs = 50
batch_size = 64
latent_vector_size = 100

In [8]:
# Data transformation for CIFAR-10
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


In [9]:
# Load CIFAR-10 dataset
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:15<00:00, 11155265.45it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [10]:
# Function to convert RGB images to grayscale
def rgb_to_grayscale(batch):
    return batch.mean(dim=1, keepdim=True)


In [11]:
# Training loop
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # Update Discriminator
        netD.zero_grad()
        real_images, _ = data
        real_images = real_images.to(device)

        # Convert to grayscale
        grayscale_images = rgb_to_grayscale(real_images)

        batch_size = real_images.size(0)
        real_labels = torch.full((batch_size,), 0.9, dtype=torch.float, device=device)
        fake_labels = torch.full((batch_size,), 0.1, dtype=torch.float, device=device)

        output = netD(real_images).view(-1)
        errD_real = criterion(output, real_labels)
        errD_real.backward()

        noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)
        fake_images = netG(noise)
        output = netD(fake_images.detach()).view(-1)
        errD_fake = criterion(output, fake_labels)
        errD_fake.backward()
        optimizerD.step()

        # Update Generator
        netG.zero_grad()
        real_labels.fill_(1)
        output = netD(fake_images).view(-1)
        errG = criterion(output, real_labels)
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD_real.item() + errD_fake.item()} Loss_G: {errG.item()}')

    # Save generated images every epoch
    with torch.no_grad():
        noise = torch.randn(64, latent_vector_size, 1, 1, device=device)
        fake_images = netG(noise).detach().cpu()
        save_image(fake_images, f'output/fake_images_epoch_{epoch}.png', normalize=True)

    # Adjust learning rate
    schedulerD.step()
    schedulerG.step()

print('Training finished.')

[0/50][0/782] Loss_D: 1.638987421989441 Loss_G: 1.3647303581237793
[0/50][100/782] Loss_D: 0.8063732087612152 Loss_G: 2.396559000015259
[0/50][200/782] Loss_D: 0.8122829794883728 Loss_G: 3.898791790008545
[0/50][300/782] Loss_D: 0.8122454881668091 Loss_G: 2.5928597450256348
[0/50][400/782] Loss_D: 0.8059958219528198 Loss_G: 3.5814950466156006
[0/50][500/782] Loss_D: 0.9128357768058777 Loss_G: 3.64278507232666
[0/50][600/782] Loss_D: 0.8553335070610046 Loss_G: 5.585565567016602
[0/50][700/782] Loss_D: 0.8134957551956177 Loss_G: 3.0699892044067383


  return F.conv_transpose2d(


[1/50][0/782] Loss_D: 0.7672998011112213 Loss_G: 2.7860302925109863
[1/50][100/782] Loss_D: 0.9705862402915955 Loss_G: 2.52182674407959
[1/50][200/782] Loss_D: 0.7904083728790283 Loss_G: 3.916564464569092
[1/50][300/782] Loss_D: 0.866606742143631 Loss_G: 2.238795518875122
[1/50][400/782] Loss_D: 0.76784548163414 Loss_G: 2.450723171234131
[1/50][500/782] Loss_D: 0.7976526021957397 Loss_G: 2.8430893421173096
[1/50][600/782] Loss_D: 0.9038259387016296 Loss_G: 3.626544237136841
[1/50][700/782] Loss_D: 0.7813024520874023 Loss_G: 2.3806233406066895
[2/50][0/782] Loss_D: 0.9762203097343445 Loss_G: 3.2030699253082275
[2/50][100/782] Loss_D: 0.9232001006603241 Loss_G: 1.9121179580688477
[2/50][200/782] Loss_D: 0.9463008344173431 Loss_G: 4.486671447753906
[2/50][300/782] Loss_D: 0.7438027858734131 Loss_G: 3.7400026321411133
[2/50][400/782] Loss_D: 0.9862045645713806 Loss_G: 1.543668508529663
[2/50][500/782] Loss_D: 0.9402974247932434 Loss_G: 3.53641939163208
[2/50][600/782] Loss_D: 0.82357224822

In [12]:
!zip -r output.zip output

  adding: output/ (stored 0%)
  adding: output/fake_images_epoch_7.png (deflated 0%)
  adding: output/fake_images_epoch_33.png (deflated 0%)
  adding: output/fake_images_epoch_29.png (deflated 0%)
  adding: output/fake_images_epoch_47.png (deflated 0%)
  adding: output/fake_images_epoch_38.png (deflated 0%)
  adding: output/fake_images_epoch_24.png (deflated 0%)
  adding: output/fake_images_epoch_39.png (deflated 0%)
  adding: output/fake_images_epoch_32.png (deflated 0%)
  adding: output/fake_images_epoch_42.png (deflated 0%)
  adding: output/fake_images_epoch_5.png (deflated 0%)
  adding: output/fake_images_epoch_26.png (deflated 0%)
  adding: output/fake_images_epoch_6.png (deflated 0%)
  adding: output/fake_images_epoch_40.png (deflated 0%)
  adding: output/fake_images_epoch_35.png (deflated 0%)
  adding: output/fake_images_epoch_30.png (deflated 0%)
  adding: output/fake_images_epoch_11.png (deflated 0%)
  adding: output/fake_images_epoch_28.png (deflated 0%)
  adding: output/fake

In [13]:
from google.colab import files
files.download('output.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [15]:
# Assuming netG is your trained Generator model
torch.save(netG.state_dict(), 'saved_generator_weights.pth')

# Download the saved weights
from google.colab import files
files.download('saved_generator_weights.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>