In [86]:
import pandas as pd
import numpy as np
import torch
from PIL import Image
import math
import os

In [151]:
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

class HW_Dataset(Dataset):
    
    def __init__(self, data_root, transform=None, max_seq_len=50):
        self.data_root = data_root
        self.transform = transform
        self.max_seq_len = max_seq_len
        
        self.samples = self.load_dataset()
        self.char_to_int = self.build_vocab()
        
    def load_dataset(self):
        labelled_pairs = []
        with open(f"{self.data_root}/train.txt", 'r', encoding='utf-8') as file:
            file.seek(0)
            lines = file.readlines()
            for line in lines:
                line = line.split()
                labelled_pairs.append((line[0], line[1]))
        return labelled_pairs
    
    def build_vocab(self):
        unique_chars = []
        with open(f"{self.data_root}/hindi_vocab.txt", 'r', encoding='utf-8') as file:
            file.seek(0)
            lines = file.readlines()
            for line in lines:
                for char in line:
                    if char not in unique_chars:
                        unique_chars.append(char)
        char_to_int = {'<PAD>': 0, '<UNK>': 1}
        for i in range(len(unique_chars)):
            char_to_int[unique_chars[i]] = i + 2
        return char_to_int
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, text = self.samples[idx]
        
        # load image
        img = Image.open(os.path.join(self.data_root, img_path)).convert('L')
        if self.transform:
            img_tensor = self.transform(img)
        
        # Process text
        text_ids = [self.char_to_int.get(c, self.char_to_int['<UNK>']) for c in text]
        padded_text_ids = torch.zeros(self.max_seq_len, dtype=torch.long)
        padded_text_ids[:len(text_ids)] = torch.tensor(text_ids)
        
        return img_tensor, padded_text_ids

In [71]:
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Use an LSTM to process the sequence of characters
        self.rnn = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, text_ids):
        # text_ids shape: (Batch_size, Max_seq_len)
        embedded = self.embedding(text_ids)
        # Pass through RNN/LSTM
        _, (hidden, _) = self.rnn(embedded)
        # Use the final hidden state and project to the desired output size
        # hidden shape: (1, Batch_size, Hidden_size)
        condition = self.fc(hidden.squeeze(0))
        return condition

In [140]:
class Generator(nn.Module):
    def __init__(self, z_dim, condition_dim, img_channels, img_size_h, img_size_w):
        super().__init__()
        self.img_size_h = img_size_h
        self.img_size_w = img_size_w
        self.img_channels = img_channels
        
        # We start by projecting and reshaping the combined input
        # Input: z_dim + condition_dim
        self.fc = nn.Sequential(
            nn.Linear(z_dim + condition_dim, 1024 * (img_size_h // 16) * (img_size_w // 16)),
            nn.BatchNorm1d(1024 * (img_size_h // 16) * (img_size_w // 16)),
            nn.ReLU()
        )
        
        # Now, we upsample using ConvTranspose2d
        # We'll go from (4x16) -> (8x32) -> (16x64) -> (32x128) -> (64x256)
        self.gen = nn.Sequential(
            # Input: (1024, 4, 16)
            self._block(1024, 512, 4, 2, 1),  # -> (512, 8, 32)
            self._block(512, 256, 4, 2, 1),   # -> (256, 16, 64)
            self._block(256, 128, 4, 2, 1),   # -> (128, 32, 128)
            
            # Final layer to get to the target size and channels
            nn.ConvTranspose2d(
                128, img_channels, kernel_size=4, stride=2, padding=1
            ),
            # Output: (img_channels, 64, 256)
            nn.Tanh() # Normalize output to [-1, 1], matching data normalization
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, z, condition):
        # z shape: (N, z_dim)
        # condition shape: (N, condition_dim)
        
        # Combine noise and condition
        combined_input = torch.cat([z, condition], dim=1) # (N, z_dim + condition_dim)
        
        # Project and reshape
        x = self.fc(combined_input)
        # Reshape to (N, 1024, H/16, W/16) -> (N, 1024, 4, 16)
        x = x.view(-1, 1024, self.img_size_h // 16, self.img_size_w // 16)
        
        # Pass through the generator blocks
        return self.gen(x)

In [141]:
class Discriminator(nn.Module):
    def __init__(self, condition_dim, img_channels, img_size_h, img_size_w):
        super().__init__()
        
        # CNN blocks to process the image
        # Input: (img_channels, 64, 256)
        self.disc = nn.Sequential(
            # -> (128, 32, 128)
            self._block(img_channels, 128, 4, 2, 1, use_norm=False), 
            # -> (256, 16, 64)
            self._block(128, 256, 4, 2, 1),
            # -> (512, 8, 32)
            self._block(256, 512, 4, 2, 1),
            # -> (1024, 4, 16)
            self._block(512, 1024, 4, 2, 1),
        )
        
        # Flatten and combine with condition
        # Output of disc: (N, 1024, 4, 16)
        # Flattened size: 1024 * 4 * 16 = 65536
        self.fc = nn.Sequential(
            nn.Linear(1024 * (img_size_h // 16) * (img_size_w // 16) + condition_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1)
            # No Sigmoid here! We'll use BCEWithLogitsLoss for stability.
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, use_norm=True):
        layers = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            )
        ]
        if use_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def forward(self, x, condition):
        # x shape: (N, C, H, W)
        # condition shape: (N, condition_dim)
        
        x = self.disc(x) # (N, 1024, 4, 16)
        
        # Flatten and concatenate condition
        x_flat = x.view(x.shape[0], -1) # (N, 65536)
        combined = torch.cat([x_flat, condition], dim=1) # (N, 65536 + condition_dim)
        
        # Classify
        return self.fc(combined)

In [137]:
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms import functional as F

class FixedHeightResize:
    """Resizes an image to a fixed height while preserving the aspect ratio."""
    def __init__(self, target_height=64):
        self.target_height = target_height

    def __call__(self, img):
        # Get original dimensions (PIL image returns width, height)
        original_width, original_height = img.size
        # Calculate the new width to maintain the aspect ratio
        aspect_ratio = original_width / original_height
        new_width = math.ceil(aspect_ratio * self.target_height)
        # Resize the image
        resized_img = img.resize((new_width, self.target_height), Image.BICUBIC)
        return resized_img

class PadToWidth:
    """Pads an image (PIL or Tensor) to a fixed width while maintaining height."""
    def __init__(self, target_width, fill_color=255): 
        # For standard handwriting data on a white background, 255 (white) is best.
        self.target_width = target_width
        self.fill_color = fill_color 

    def __call__(self, img):
        if not isinstance(img, Image.Image):
             # Assumes we are operating on a PIL Image before ToTensor()
             raise TypeError("Input must be a PIL Image.")

        current_width = img.width
        
        if current_width >= self.target_width:
            # If the image is already wide enough (or too wide), we just center-crop it.
            # This handles outliers, though you should choose max_width carefully.
            return F.center_crop(img, (img.height, self.target_width)) 
        
        # Calculate padding needed (only on the right)
        padding_needed = self.target_width - current_width
        
        # Pad with the background color (left, top, right, bottom)
        # We only pad on the right (right padding = padding_needed)
        padding = (0, 0, padding_needed, 0) 
        
        return F.pad(img, padding, fill=self.fill_color)

### Testing for the time being
Move to another notebook once all looks good

In [162]:
import torch.optim as optim
import torchvision.utils as vutils
from tqdm import tqdm

if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
    
print(f"Using {DEVICE}")

# Data setup
IMG_H = 64
IMG_W = 256
IMG_CHANNELS = 1
MAX_SEQ_LEN = 50
BATCH_SIZE = 32

# Model params
VOCAB_SIZE = len(dataset.char_to_int)
EMBEDDING_DIM = 256
CONDITION_DIM = 128 # The output size of TextEncoder
Z_DIM = 100 # Noise dimension
LR = 2e-4
BETA1 = 0.5
NUM_EPOCHS = 100 # Add more later

data_transforms = T.Compose([
    FixedHeightResize(target_height=64),
    PadToWidth(target_width=256),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

# Load dataset
dataset = HW_Dataset(data_root='IIIT-HW-Hindi_v1', transform=data_transforms)

train_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Initialize Models
text_encoder = TextEncoder(VOCAB_SIZE, EMBEDDING_DIM, CONDITION_DIM, CONDITION_DIM).to(DEVICE)
gen = Generator(Z_DIM, CONDITION_DIM, IMG_CHANNELS, IMG_H, IMG_W).to(DEVICE)
disc = Discriminator(CONDITION_DIM, IMG_CHANNELS, IMG_H, IMG_W).to(DEVICE)

# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

gen.apply(weights_init)
disc.apply(weights_init)
text_encoder.apply(weights_init)

# Loss and Optimizers
criterion = nn.BCEWithLogitsLoss()
opt_disc = optim.Adam(disc.parameters(), lr=LR, betas=(BETA1, 0.999))
opt_gen = optim.Adam(
    list(gen.parameters()) + list(text_encoder.parameters()), 
    lr=LR, 
    betas=(BETA1, 0.999)
)

# We'll use a fixed batch of noise and text to see G's progress
fixed_noise = torch.randn(BATCH_SIZE, Z_DIM).to(DEVICE)
fixed_batch_data = next(iter(train_loader))
fixed_real_images, fixed_text_ids = fixed_batch_data
fixed_text_ids = fixed_text_ids.to(DEVICE)
fixed_real_images = fixed_real_images.to(DEVICE)

if not os.path.exists("outputs"):
    os.makedirs("outputs")

# Save the fixed real batch for comparison
vutils.save_image(fixed_real_images, "outputs/real_samples.png", normalize=True)

print("Starting Training...")

progress_bar = tqdm(total=dataset.__len__())

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real_images, text_ids) in enumerate(train_loader):
        real_images = real_images.to(DEVICE)
        text_ids = text_ids.to(DEVICE)
        
        # Get the condition vector (c)
        condition = text_encoder(text_ids)
        
        # --- Train Discriminator ---
        opt_disc.zero_grad()
        
        # Train with real images
        # Use condition.detach() so text encoder is not updated
        d_real_output = disc(real_images, condition.detach()).reshape(-1)
        d_real_loss = criterion(d_real_output, torch.ones_like(d_real_output))
        d_real_loss.backward()
        
        # Train with fake images
        noise = torch.randn(real_images.size(0), Z_DIM).to(DEVICE)
        fake_images = gen(noise, condition.detach())
        # Use condition.detach() so generator is not updated
        d_fake_output = disc(fake_images.detach(), condition.detach()).reshape(-1)
        d_fake_loss = criterion(d_fake_output, torch.zeros_like(d_fake_output))
        d_fake_loss.backward()
        
        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        opt_disc.step()
        
        # --- Train Generator ---
        opt_gen.zero_grad()
        
        # Get fresh condition vector (no detaching so encoder is updated)
        condition_gen = text_encoder(text_ids)
        
        # Generate new fake images
        fake_images_gen = gen(noise, condition_gen)
        
        # See what the discriminator thinks (no detaching)
        g_output = disc(fake_images_gen, condition_gen).reshape(-1)
        
        # Calculate loss (Generator wants discriminator to think they are real)
        g_loss = criterion(g_output, torch.ones_like(g_output))
        
        # Backprop (updates both generatir and text encoder)
        g_loss.backward()
        opt_gen.step()
        
        # Logging
        if (batch_idx + 1) % 100 == 0:
            print(
                f"[Epoch {epoch+1}/{NUM_EPOCHS}] [Batch {batch_idx+1}/{len(train_loader)}] "
                f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}"
            )
        progress_bar.update(1)
            
    # Save generated images at the end of each epoch
    with torch.no_grad():
        fixed_condition = text_encoder(fixed_text_ids)
        fake_samples = gen(fixed_noise, fixed_condition)
        vutils.save_image(
            fake_samples,
            f"outputs/fake_samples_epoch_{epoch+1}.png",
            normalize=True
        )

print("Training finished.")

Using cpu
Starting Training...


  0%|                                     | 14/69853 [00:52<73:28:06,  3.79s/it]

KeyboardInterrupt: 