In [1]:
# Install necessary libraries if not already installed
!pip install torch torchvision matplotlib scipy numpy


Defaulting to user installation because normal site-packages is not writeable


In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load and preprocess CIFAR-10 dataset
def preprocess_data():
    transform = transforms.Compose([
        transforms.Resize(64),  # Resize to 64x64
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    return dataloader

dataloader = preprocess_data()



Files already downloaded and verified


In [3]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_dim, text_dim):
        super(Generator, self).__init__()
        input_dim = noise_dim + text_dim
        self.main = nn.Sequential(
            nn.Linear(input_dim, 4 * 4 * 512),
            nn.ReLU(True),
            nn.Unflatten(1, (512, 4, 4)),
            
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, noise, text):
        x = torch.cat((noise, text), dim=1)
        return self.main(x)


In [4]:
class Discriminator(nn.Module):
    def __init__(self, text_dim):
        super(Discriminator, self).__init__()
        self.text_embed = nn.Sequential(
            nn.Linear(text_dim, 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 1, 1))
        )
        
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.final = nn.Sequential(
            nn.Conv2d(512 + 256, 512, kernel_size=1),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, image, text):
        text_emb = self.text_embed(text).expand(-1, -1, 4, 4)
        image_features = self.main(image)
        combined = torch.cat((image_features, text_emb), dim=1)
        return self.final(combined)


In [5]:
import torch.optim as optim

def train(generator, discriminator, dataloader, epochs, noise_dim, text_dim, device):
    # Optimizers and loss
    gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    generator.to(device)
    discriminator.to(device)

    for epoch in range(epochs):
        for images, _ in dataloader:
            images = images.to(device)
            batch_size = images.size(0)

            # Labels
            real_labels = torch.full((batch_size, 1), 0.9, device=device)
            fake_labels = torch.full((batch_size, 1), 0.1, device=device)


            # Train Discriminator
            noise = torch.randn(batch_size, noise_dim).to(device)
            text = torch.randn(batch_size, text_dim).to(device)
            fake_images = generator(noise, text)

            disc_real = discriminator(images, text)
            disc_fake = discriminator(fake_images.detach(), text)

            disc_loss_real = criterion(disc_real, real_labels)
            disc_loss_fake = criterion(disc_fake, fake_labels)
            disc_loss = disc_loss_real + disc_loss_fake

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Train Generator
            fake_images = generator(noise, text)
            disc_fake = discriminator(fake_images, text)
            gen_loss = criterion(disc_fake, real_labels)

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}] | Gen Loss: {gen_loss.item()} | Disc Loss: {disc_loss.item()}")


In [6]:
from torchvision.models import inception_v3
import numpy as np
from scipy.linalg import sqrtm

# Inception Score
def inception_score(images, splits=10):
    inception_model = inception_v3(pretrained=True, transform_input=False).eval().to(images.device)
    images_resized = torch.nn.functional.interpolate(images, size=(299, 299), mode='bilinear')

    preds = []
    with torch.no_grad():
        for img_batch in torch.split(images_resized, 32):
            pred = inception_model(img_batch)
            preds.append(torch.nn.functional.softmax(pred, dim=1))
    preds = torch.cat(preds, dim=0)

    split_scores = []
    N = preds.size(0)
    for k in range(splits):
        part = preds[k * (N // splits): (k + 1) * (N // splits), :]
        py = torch.mean(part, dim=0)
        scores = torch.exp(torch.sum(part * torch.log(part / py[None, :]), dim=1))
        split_scores.append(scores.mean().item())
    return np.mean(split_scores), np.std(split_scores)

# Frechet Inception Distance
def calculate_fid(real_images, fake_images):
    inception_model = inception_v3(pretrained=True, transform_input=False).eval().to(real_images.device)

    def get_activations(images):
        images_resized = torch.nn.functional.interpolate(images, size=(299, 299), mode='bilinear')
        with torch.no_grad():
            activations = []
            for img_batch in torch.split(images_resized, 32):
                act = inception_model(img_batch).detach()
                activations.append(act)
        return torch.cat(activations, dim=0).cpu().numpy()

    real_activations = get_activations(real_images)
    fake_activations = get_activations(fake_images)

    mu_real, sigma_real = real_activations.mean(axis=0), np.cov(real_activations, rowvar=False)
    mu_fake, sigma_fake = fake_activations.mean(axis=0), np.cov(fake_activations, rowvar=False)

    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real.dot(sigma_fake))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean))


In [7]:
import matplotlib.pyplot as plt

def generate_and_save_images(generator, epoch, noise_dim, text_dim, device, num_examples=10):
    generator.eval()
    noise = torch.randn(num_examples, noise_dim).to(device)
    text = torch.randn(num_examples, text_dim).to(device)
    fake_images = generator(noise, text).cpu().detach()

    fake_images = (fake_images + 1) / 2  # Rescale to [0, 1]

    plt.figure(figsize=(10, 2))
    for i in range(num_examples):
        plt.subplot(1, num_examples, i + 1)
        plt.imshow(fake_images[i].permute(1, 2, 0).numpy())
        plt.axis("off")
    plt.savefig(f"generated_images_epoch_{epoch}.png")
    plt.show()


In [8]:
# Hyperparameters
noise_dim = 100
text_dim = 119
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
generator = Generator(noise_dim, text_dim)
discriminator = Discriminator(text_dim)

# Train models
train(generator, discriminator, dataloader, epochs, noise_dim, text_dim, device)

# Generate images and evaluate
num_samples = 1000
real_images = next(iter(dataloader))[0][:num_samples].to(device)
noise = torch.randn(num_samples, noise_dim).to(device)
text = torch.randn(num_samples, text_dim).to(device)
fake_images = generator(noise, text)

# Evaluation metrics
mean_is, std_is = inception_score(fake_images)
print(f"Inception Score: {mean_is} ± {std_is}")

fid_score = calculate_fid(real_images, fake_images)
print(f"FID Score: {fid_score}")


Epoch [1/50] | Gen Loss: 2.62575626373291 | Disc Loss: 0.7526130676269531
Epoch [2/50] | Gen Loss: 2.7964208126068115 | Disc Loss: 0.7040140628814697
Epoch [3/50] | Gen Loss: 2.083345413208008 | Disc Loss: 0.6548295021057129
Epoch [4/50] | Gen Loss: 2.061908721923828 | Disc Loss: 0.6509606838226318
Epoch [5/50] | Gen Loss: 2.0748696327209473 | Disc Loss: 0.6652166843414307
Epoch [6/50] | Gen Loss: 1.9977624416351318 | Disc Loss: 0.6566369533538818
Epoch [7/50] | Gen Loss: 2.1267926692962646 | Disc Loss: 0.6509142518043518
Epoch [8/50] | Gen Loss: 1.266461968421936 | Disc Loss: 1.1185492277145386
Epoch [9/50] | Gen Loss: 1.9870193004608154 | Disc Loss: 0.6522029638290405
Epoch [10/50] | Gen Loss: 2.0760927200317383 | Disc Loss: 0.650847315788269
Epoch [11/50] | Gen Loss: 2.0695438385009766 | Disc Loss: 0.6502892971038818
Epoch [12/50] | Gen Loss: 2.1052074432373047 | Disc Loss: 0.6511006355285645
Epoch [13/50] | Gen Loss: 2.037440776824951 | Disc Loss: 0.6516380906105042
Epoch [14/50] |



Inception Score: 2.4239993572235106 ± 0.23913216295525017
FID Score: 994.4297109825263
