In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Set device: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Where CIFAR-10 is cached/downloaded. This is a Google Colab Drive path;
# change it when running anywhere else (e.g. a local './data' directory).
data_root = '/content/drive/MyDrive/Abstract'
batch_size = 32

# Define the transform: ToTensor() scales pixels to [0, 1], then Normalize with
# mean = std = 0.5 for each RGB channel maps them to [-1, 1], matching the
# generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the CIFAR-10 training split (downloaded on first use) and batch it.
train_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Hyperparameters
latent_dim = 100   # length of the generator's input noise vector z
lr = 0.0002        # Adam learning rate, shared by both networks
beta1 = 0.5        # Adam beta1 (first-moment decay)
beta2 = 0.999      # Adam beta2 (second-moment decay)
num_epochs = 10    # full passes over the training set

# Seed PyTorch's RNGs so weight init, noise sampling and data shuffling are
# repeatable across Restart & Run All (cuDNN kernels may still introduce some
# nondeterminism on GPU).
seed = 42
torch.manual_seed(seed)

# Define the generator
class Generator(nn.Module):
    """Map a latent noise vector to a 3x32x32 image with values in [-1, 1].

    Architecture: Linear -> 128x8x8 feature map, then two
    (Upsample x2 -> Conv -> BatchNorm -> ReLU) stages to reach 32x32,
    then a 3-channel Conv followed by Tanh.

    Args:
        latent_dim: length of the input noise vector z.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # PyTorch's BatchNorm momentum weights the *new* batch:
        #   running = (1 - momentum) * running + momentum * batch_stat
        # which is the opposite of the Keras convention. 0.22 is the PyTorch
        # equivalent of a Keras-style momentum of 0.78; passing 0.78 here would
        # make the running statistics track almost only the latest batch.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.Upsample(scale_factor=2),                    # 8x8 -> 16x16
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.22),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),                    # 16x16 -> 32x32
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.22),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()                                       # outputs in [-1, 1]
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: noise tensor of shape (batch, latent_dim).

        Returns:
            Image tensor of shape (batch, 3, 32, 32) with values in [-1, 1].
        """
        img = self.model(z)
        return img

# Define the discriminator
class Discriminator(nn.Module):
    """Score 3x32x32 images with the probability that each one is real.

    Spatial size through the stack: 32 -> 16 -> 8 -> (pad) 9 -> 5 -> 5,
    giving a 256x5x5 feature map before the final Linear + Sigmoid, so the
    input must be 32x32 for the Linear layer's size to match.
    """

    def __init__(self):
        super().__init__()
        # PyTorch BatchNorm momentum weights the new batch
        # (running = (1 - momentum) * running + momentum * batch_stat), the
        # inverse of Keras; 0.18 / 0.2 correspond to Keras-style 0.82 / 0.8.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),     # 32 -> 16
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),    # 16 -> 8
            nn.ZeroPad2d((0, 1, 0, 1)),                               # 8 -> 9 (pad right/bottom)
            nn.BatchNorm2d(64, momentum=0.18),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # 9 -> 5
            nn.BatchNorm2d(128, momentum=0.18),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),  # 5 -> 5
            nn.BatchNorm2d(256, momentum=0.2),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return real-probabilities of shape (batch, 1) for a batch of 3x32x32 images."""
        validity = self.model(img)
        return validity.view(img.size(0), -1)  # Flatten the output

# Build both networks and move their parameters to the selected device.
# The generator is constructed first, so parameter initialisation consumes
# the global RNG in the same order as a statement-by-statement version.
generator, discriminator = (
    Generator(latent_dim).to(device),
    Discriminator().to(device),
)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both sharing the same lr and betas.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2))
    for net in (generator, discriminator)
)

# Training loop
sample_every = 10  # show a grid of generated samples every N epochs (and always after the last one)

for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        real_images = batch[0].to(device)  # class labels (batch[1]) are unused
        n_real = real_images.size(0)       # the final batch of an epoch may be smaller
        valid = torch.ones(n_real, 1, device=device)
        fake = torch.zeros(n_real, 1, device=device)

        # Train Discriminator: push real images towards 1 and generated ones towards 0.
        optimizer_D.zero_grad()
        z = torch.randn(n_real, latent_dim, device=device)
        fake_images = generator(z)
        real_loss = adversarial_loss(discriminator(real_images), valid)
        # detach() so this loss does not backpropagate into the generator.
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: make the (just updated) discriminator label fakes as real.
        # fake_images is reused: the generator has not been stepped since producing
        # it, so a second forward pass on the same z would only repeat that work.
        # g_loss.backward() also writes discriminator grads; optimizer_D.zero_grad()
        # clears them at the start of the next iteration.
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(fake_images), valid)
        g_loss.backward()
        optimizer_G.step()

        # Progress Monitoring
        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch {i+1}/{len(dataloader)} "
                  f"Discriminator Loss: {d_loss.item():.4f} Generator Loss: {g_loss.item():.4f}")

    # Show a 4x4 grid of generated samples every `sample_every` epochs and after the final epoch.
    if (epoch + 1) % sample_every == 0 or epoch + 1 == num_epochs:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
        # normalize=True rescales the grid from the Tanh range to [0, 1] for display.
        grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.title(f"Generated samples after epoch {epoch + 1}")
        plt.axis("off")
        plt.show()


Files already downloaded and verified
Epoch [1/10] Batch 100/1563 Discriminator Loss: 0.5082 Generator Loss: 1.1477
Epoch [1/10] Batch 200/1563 Discriminator Loss: 0.6927 Generator Loss: 0.8794
Epoch [1/10] Batch 300/1563 Discriminator Loss: 0.6862 Generator Loss: 1.0297
Epoch [1/10] Batch 400/1563 Discriminator Loss: 0.5986 Generator Loss: 1.1756
Epoch [1/10] Batch 500/1563 Discriminator Loss: 0.6212 Generator Loss: 1.1365
Epoch [1/10] Batch 600/1563 Discriminator Loss: 0.7207 Generator Loss: 0.7874
Epoch [1/10] Batch 700/1563 Discriminator Loss: 0.6555 Generator Loss: 0.9291
Epoch [1/10] Batch 800/1563 Discriminator Loss: 0.6734 Generator Loss: 0.9019
Epoch [1/10] Batch 900/1563 Discriminator Loss: 0.5815 Generator Loss: 1.0152
Epoch [1/10] Batch 1000/1563 Discriminator Loss: 0.7481 Generator Loss: 0.8303
Epoch [1/10] Batch 1100/1563 Discriminator Loss: 0.4935 Generator Loss: 1.2252
Epoch [1/10] Batch 1200/1563 Discriminator Loss: 0.7237 Generator Loss: 0.9293
Epoch [1/10] Batch 1300

In [None]:
import torchvision.utils

# Generate new data
num_samples = 10  # You can adjust this based on your preference
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim, device=device)
    generated_images = generator(z).cpu()

# Take real images from one shuffled batch for a side-by-side comparison.
real_images_batch = next(iter(dataloader))[0][:num_samples].cpu()


def to_display_image(img):
    """Convert a (C, H, W) tensor in [-1, 1] to an (H, W, C) array in [0, 1] for imshow.

    The dataset transform (Normalize with mean = std = 0.5) and the generator's
    Tanh both produce values in [-1, 1]. imshow clips float RGB data outside
    [0, 1], so undo the normalisation before plotting.
    """
    return (img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()


# Plot the original (top row) and generated (bottom row) images
fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 2, 4))

for i in range(num_samples):
    axes[0, i].imshow(to_display_image(real_images_batch[i]))
    axes[0, i].axis('off')
    axes[0, i].set_title('Original')

    axes[1, i].imshow(to_display_image(generated_images[i]))
    axes[1, i].axis('off')
    axes[1, i].set_title('Generated')

plt.show()
