In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torchvision.utils import save_image

# Define Generator
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent vector to a 32x32 image.

    A fully connected layer projects the latent vector to a 512x4x4 feature
    map. Three stride-2 transposed convolutions then double the spatial size
    each time (4 -> 8 -> 16 -> 32), and a final stride-1 transposed
    convolution with ``Tanh`` produces the image in the range [-1, 1].

    Args:
        latent_dim: Width of the input noise vector. Defaults to 100.
        out_channels: Number of channels in the generated image. Defaults to
            1 (grayscale).
    """

    def __init__(self, latent_dim: int = 100, out_channels: int = 1) -> None:
        super().__init__()
        # Exposed so callers can size their noise tensors from the model
        # instead of repeating a magic number.
        self.latent_dim = latent_dim
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.ReLU(True),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # Kernel 3 / stride 1 / padding 1 keeps the 32x32 resolution and
            # only reduces the channel count.
            nn.ConvTranspose2d(64, out_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate images from a ``(batch, latent_dim)`` noise tensor.

        Returns:
            A ``(batch, out_channels, 32, 32)`` tensor with values in [-1, 1].
        """
        x = self.fc(x)
        # Reshape the flat projection into the 512x4x4 seed feature map.
        x = x.view(x.size(0), 512, 4, 4)
        return self.deconv(x)

# Define Discriminator
class Discriminator(nn.Module):
    """Convolutional critic that scores a single-channel image as real or fake.

    Four stride-2 convolutions (64, 128, 256, 512 channels) each halve the
    spatial size; every stage after the first is followed by batch
    normalisation, and all use LeakyReLU(0.2). The linear head expects
    512 * 2 * 2 features, i.e. a 2x2 map after the conv stack, which is what
    a 32x32 input produces. A sigmoid turns the score into a probability.
    """

    def __init__(self):
        super().__init__()

        stage_widths = (64, 128, 256, 512)
        stages = []
        in_channels = 1
        for stage_idx, out_channels in enumerate(stage_widths):
            stages.append(
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
            )
            # The first stage works on raw pixels and has no batch norm.
            if stage_idx > 0:
                stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            in_channels = out_channels
        self.conv = nn.Sequential(*stages)

        head = [nn.Linear(512 * 2 * 2, 1), nn.Sigmoid()]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real-image probabilities."""
        features = self.conv(x)
        batch = features.size(0)
        flattened = features.view(batch, -1)
        return self.fc(flattened)

# Training parameters
num_epochs = 100
batch_size = 64
latent_dim = 100  # must match the input width of Generator's first Linear layer
num_generated = 100  # images sampled per class once training finishes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformation: Convert to grayscale, Resize to 32x32, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Define dataset path
dataset_path = '/kaggle/input/gujarati-ocr/Gujarati/Train'
output_path = './generated_images'  # Path to save generated images
os.makedirs(output_path, exist_ok=True)

# Load full dataset
full_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
class_names = full_dataset.classes

# Group sample indices by class once, using the labels ImageFolder already
# holds in `targets`. Iterating the dataset itself would decode and transform
# every image, and doing that inside the class loop repeats it per class.
indices_by_class = [[] for _ in class_names]
for sample_index, label in enumerate(full_dataset.targets):
    indices_by_class[label].append(sample_index)

# Iterate over all classes, training one independent GAN per class
for target_class_label, class_name in enumerate(class_names):
    # Select the images of the current class
    target_class_indices = indices_by_class[target_class_label]
    target_class_dataset = Subset(full_dataset, target_class_indices)
    train_loader = DataLoader(target_class_dataset, batch_size=batch_size, shuffle=True)

    # Initialize GAN models, criterion, and optimizers for each class
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Labels for real and fake data
    real_label = 1.0
    fake_label = 0.0

    # Train the GAN for the current class
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_loader):
            # A trailing batch with a single sample cannot be used: the
            # generator's BatchNorm1d raises in training mode when it gets
            # only one value per channel.
            if images.size(0) < 2:
                continue

            images = images.to(device)
            current_batch_size = images.size(0)

            # Build the target tensors once per step; they are reused by the
            # discriminator and generator losses below.
            real_targets = torch.full((current_batch_size,), real_label, device=device)
            fake_targets = torch.full((current_batch_size,), fake_label, device=device)

            # Train Discriminator: real images should score 1, fakes 0
            optimizer_D.zero_grad()
            output_real = discriminator(images).view(-1)
            loss_real = criterion(output_real, real_targets)
            loss_real.backward()

            noise = torch.randn(current_batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            # detach() keeps the discriminator update from back-propagating
            # into the generator.
            output_fake = discriminator(fake_images.detach()).view(-1)
            loss_fake = criterion(output_fake, fake_targets)
            loss_fake.backward()
            optimizer_D.step()

            # Train Generator: it is rewarded when the discriminator labels
            # its fakes as real
            optimizer_G.zero_grad()
            output_fake = discriminator(fake_images).view(-1)
            loss_G = criterion(output_fake, real_targets)
            loss_G.backward()
            optimizer_G.step()

            # Log progress
            if i % 50 == 0:
                print(f"Class: {class_name}, Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], D Loss: {loss_real.item() + loss_fake.item():.4f}, G Loss: {loss_G.item():.4f}")

    # Generate and save images for the current class
    generator.eval()
    save_class_path = os.path.join(output_path, class_name)
    os.makedirs(save_class_path, exist_ok=True)
    with torch.no_grad():
        # In eval mode batch norm uses running statistics, so sampling all
        # images in one forward pass is equivalent to sampling one at a time.
        noise = torch.randn(num_generated, latent_dim, device=device)
        fake_batch = generator(noise)
        for img_num in range(num_generated):
            # Keep a batch axis of 1 so each file is normalised on its own,
            # as when images are saved individually.
            fake_image = fake_batch[img_num:img_num + 1]
            save_image(fake_image, os.path.join(save_class_path, f"{img_num+1}.png"), normalize=True)

    print(f"Generated images for class '{class_name}' saved in '{save_class_path}'")


Class: A, Epoch [0/100], Step [0/3], D Loss: 1.4437, G Loss: 2.1200
Class: A, Epoch [1/100], Step [0/3], D Loss: 0.1678, G Loss: 4.5438
Class: A, Epoch [2/100], Step [0/3], D Loss: 0.0629, G Loss: 5.6918
Class: A, Epoch [3/100], Step [0/3], D Loss: 0.0215, G Loss: 5.9926
Class: A, Epoch [4/100], Step [0/3], D Loss: 0.0159, G Loss: 6.1338
Class: A, Epoch [5/100], Step [0/3], D Loss: 0.0067, G Loss: 6.5560
Class: A, Epoch [6/100], Step [0/3], D Loss: 0.0070, G Loss: 6.4354
Class: A, Epoch [7/100], Step [0/3], D Loss: 0.0066, G Loss: 6.5628
Class: A, Epoch [8/100], Step [0/3], D Loss: 0.0045, G Loss: 6.8400
Class: A, Epoch [9/100], Step [0/3], D Loss: 0.0045, G Loss: 6.7123
Class: A, Epoch [10/100], Step [0/3], D Loss: 0.0026, G Loss: 7.5535
Class: A, Epoch [11/100], Step [0/3], D Loss: 0.0016, G Loss: 7.5607
Class: A, Epoch [12/100], Step [0/3], D Loss: 0.0023, G Loss: 7.3801
Class: A, Epoch [13/100], Step [0/3], D Loss: 0.0016, G Loss: 7.5515
Class: A, Epoch [14/100], Step [0/3], D Loss