### Import Libraries and Set Device


In [2]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")
print("Using device:", device)


Using device: cpu


### Text Processing Utilities

In [3]:
# Simple text tokenizer and encoder utilities

class TextEncoder(nn.Module):
    """Bag-of-embeddings text encoder.

    Each token index is mapped to a learned vector and the vectors are
    averaged along the sequence axis. No masking is applied, so every
    position of the sequence (padding included) contributes to the average.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # One trainable row of size ``embedding_dim`` per vocabulary index.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, indices):
        """Return the mean token embedding of every sequence.

        ``indices`` is an integer tensor of shape (batch, seq_len); the
        result has shape (batch, embedding_dim).
        """
        return self.embedding(indices).mean(dim=1)

def build_vocab(texts, min_freq=1):
    """Build a word -> index vocabulary from a list of texts.

    Words are obtained by whitespace splitting. Index 0 is reserved for
    ``"<pad>"`` and index 1 for ``"<unk>"``; every other word that occurs at
    least ``min_freq`` times receives the next free index, in order of first
    appearance.

    Args:
        texts: iterable of strings.
        min_freq: minimum number of occurrences a word needs to be kept.

    Returns:
        dict mapping each kept word to a unique index in
        ``range(len(vocab))``.
    """
    freq = {}
    for txt in texts:
        for word in txt.split():
            freq[word] = freq.get(word, 0) + 1

    vocab = {"<pad>": 0, "<unk>": 1}
    for word, count in freq.items():
        # A corpus word spelled like a reserved token must not overwrite the
        # reserved index: that would leave a gap in the id range and make the
        # largest id >= len(vocab), which breaks an embedding table sized
        # with len(vocab).
        if count >= min_freq and word not in vocab:
            vocab[word] = len(vocab)
    return vocab

def tokenize(text, vocab, max_len=20):
    """Encode ``text`` as a list of vocabulary indices with ``max_len`` entries.

    The text is split on whitespace; words missing from ``vocab`` map to the
    ``"<unk>"`` index. Short sequences are right-padded with the ``"<pad>"``
    index and long ones are cut off after ``max_len`` tokens.
    """
    ids = []
    for word in text.split():
        ids.append(vocab.get(word, vocab["<unk>"]))

    shortfall = max_len - len(ids)
    if shortfall > 0:
        return ids + [vocab["<pad>"]] * shortfall
    return ids[:max_len]


In [4]:
class StackGANDataset(Dataset):
    """Image/caption pairs read from a two-level directory tree.

    Expects root_dir structure:
        root_dir/
            subdir1/
                image1.jpg, image1.txt, ...
            subdir2/
                ...

    An image (``.jpg``, ``.png`` or ``.jpeg``) is used only when a ``.txt``
    file with the same base name sits next to it. Each item is a tuple
    ``(image_tensor, token_indices)`` where ``token_indices`` is a long tensor
    of length ``max_text_len``.

    Args:
        root_dir: directory that contains the sub-directories.
        vocab: optional word -> index mapping. When ``None`` a vocabulary is
            built from every ``.txt`` file found under ``root_dir``.
        max_text_len: number of token indices returned per caption.
        transform: optional callable applied to the RGB PIL image. When it is
            falsy, the image is resized to 64x64, converted to a tensor and
            normalised with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, root_dir, vocab=None, max_text_len=20, transform=None):
        self.image_paths = []
        self.text_paths = []
        self.transform = transform
        self.max_text_len = max_text_len
        # Built once here instead of on every __getitem__ call.
        self._default_transform = transforms.Compose(
            [
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        sub_paths = []
        for sub in sorted(os.listdir(root_dir)):
            sub_path = os.path.join(root_dir, sub)
            if os.path.isdir(sub_path):
                sub_paths.append(sub_path)

        if vocab is None:
            # The caption files are only needed when a vocabulary has to be
            # built. They are visited in sorted order so that the word ->
            # index assignment is the same on every run and file system.
            texts = []
            for sub_path in sub_paths:
                for txt_file in sorted(glob.glob(os.path.join(sub_path, "*.txt"))):
                    with open(txt_file, "r", encoding="utf-8") as f:
                        texts.append(f.read().strip().lower())
            self.vocab = build_vocab(texts)
        else:
            self.vocab = vocab

        # Collect (image, caption) pairs; sorting makes the sample order
        # independent of the order in which the OS lists directory entries.
        for sub_path in sub_paths:
            for name in sorted(os.listdir(sub_path)):
                if not name.endswith((".jpg", ".png", ".jpeg")):
                    continue
                txt_path = os.path.join(sub_path, os.path.splitext(name)[0] + ".txt")
                if os.path.exists(txt_path):
                    self.image_paths.append(os.path.join(sub_path, name))
                    self.text_paths.append(txt_path)

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, token_indices)``."""
        # The context manager closes the underlying file as soon as the
        # pixels have been copied into the converted image.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        else:
            img = self._default_transform(img)

        with open(self.text_paths[idx], "r", encoding="utf-8") as f:
            text = f.read().strip().lower()
        tokens = tokenize(text, self.vocab, max_len=self.max_text_len)
        return img, torch.tensor(tokens, dtype=torch.long)

### Model Architectures for Stage-I and Stage-II

In [5]:
# Stage-I Generator
class Stage1Generator(nn.Module):
    """Stage-I generator: noise + text embedding -> 64x64 image in [-1, 1].

    A fully connected layer expands the concatenated input to a 1024x4x4
    feature map, which four stride-2 transposed convolutions upsample to
    ``output_channels`` x 64 x 64, followed by ``tanh``.

    ``ngf`` is accepted for signature compatibility but the layer widths are
    fixed multiples of 128 and do not depend on it.
    """

    def __init__(self, noise_dim, text_embedding_dim, ngf=64, output_channels=3):
        super().__init__()
        base = 128
        seed_features = base * 8 * 4 * 4

        self.fc = nn.Sequential(
            nn.Linear(noise_dim + text_embedding_dim, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(True),
        )

        # Three "halve the channels, double the resolution" stages
        # (4x4 -> 32x32), then a final projection to image channels (64x64).
        layers = []
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers.append(nn.ConvTranspose2d(base * mult_in, base * mult_out, 4, 2, 1))
            layers.append(nn.BatchNorm2d(base * mult_out))
            layers.append(nn.ReLU(True))
        layers.append(nn.ConvTranspose2d(base, output_channels, 4, 2, 1))
        layers.append(nn.Tanh())
        self.deconv = nn.Sequential(*layers)

    def forward(self, noise, text_embedding):
        """Generate one image per row of ``noise`` / ``text_embedding``."""
        latent = torch.cat((noise, text_embedding), dim=1)
        features = self.fc(latent)
        seed_map = features.view(features.size(0), 128 * 8, 4, 4)
        return self.deconv(seed_map)

# Stage-I Discriminator
class Stage1Discriminator(nn.Module):
    """Stage-I discriminator: scores an (image, text embedding) pair in [0, 1].

    Four stride-2 convolutions reduce the image to ``ndf * 8`` channels; the
    flattened features are concatenated with the text embedding and passed
    through a two-layer classifier ending in a sigmoid. The classifier input
    size assumes a 4x4 final feature map, i.e. a 64x64 input image.
    """

    def __init__(self, text_embedding_dim, ndf=64, input_channels=3):
        super().__init__()

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # The first stage has no batch norm; the next three double the width.
        stages = [nn.Conv2d(input_channels, ndf, 4, 2, 1), leaky()]
        width = ndf
        for _ in range(3):
            stages.append(nn.Conv2d(width, width * 2, 4, 2, 1))
            stages.append(nn.BatchNorm2d(width * 2))
            stages.append(leaky())
            width = width * 2
        self.conv = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Linear(width * 4 * 4 + text_embedding_dim, 1024),
            leaky(),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text_embedding):
        """Return a (batch, 1) tensor of validity scores."""
        features = self.conv(img)
        flat = features.view(features.size(0), -1)
        joint = torch.cat((flat, text_embedding), dim=1)
        return self.fc(joint)

# Stage-II Generator (Refinement network)
class Stage2Generator(nn.Module):
    """Stage-II generator: refines a Stage-I image given noise and text.

    The concatenated noise / text embedding is expanded to a 1024x4x4 feature
    map by ``fc``. ``img_encoder`` turns the Stage-I image into a feature map
    of the same shape; the two maps are summed and upsampled by ``deconv`` to
    an ``output_channels`` x 64 x 64 image in [-1, 1].

    Previously ``forward`` ignored ``stage1_img`` entirely, so the output did
    not depend on the Stage-I result. The ``fc`` and ``deconv`` layers are
    unchanged; ``img_encoder`` adds new parameters, so a state dict saved
    before this change lacks the ``img_encoder.*`` entries.

    ``ngf`` is accepted for signature compatibility but the layer widths are
    fixed multiples of 128 and do not depend on it.
    """

    def __init__(self, text_embedding_dim, noise_dim, ngf=64, output_channels=3):
        super(Stage2Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + text_embedding_dim, 128 * 8 * 4 * 4),
            nn.BatchNorm1d(128 * 8 * 4 * 4),
            nn.ReLU(True)
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128 * 8, 128 * 4, 4, 2, 1),
            nn.BatchNorm2d(128 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(128 * 4, 128 * 2, 4, 2, 1),
            nn.BatchNorm2d(128 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(128 * 2, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, output_channels, 4, 2, 1),
            nn.Tanh()
        )
        # Mirror of ``deconv``: four stride-2 convolutions take a 64x64 image
        # down to 1024x4x4. The final adaptive pooling is the identity for a
        # 64x64 input and keeps the 4x4 shape for other input resolutions.
        # Assumes the Stage-I image has ``output_channels`` channels.
        self.img_encoder = nn.Sequential(
            nn.Conv2d(output_channels, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128 * 2, 4, 2, 1),
            nn.BatchNorm2d(128 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128 * 2, 128 * 4, 4, 2, 1),
            nn.BatchNorm2d(128 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128 * 4, 128 * 8, 4, 2, 1),
            nn.BatchNorm2d(128 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d((4, 4))
        )

    def forward(self, noise, text_embedding, stage1_img):
        """Return the refined image for each sample.

        Args:
            noise: (batch, noise_dim) tensor.
            text_embedding: (batch, text_embedding_dim) tensor.
            stage1_img: (batch, output_channels, H, W) Stage-I image,
                normally 64x64. Gradients flow back into it unless the
                caller detaches it.
        """
        x = torch.cat((noise, text_embedding), dim=1)
        x = self.fc(x)
        x = x.view(x.size(0), 128 * 8, 4, 4)
        # Condition the latent feature map on the Stage-I result.
        x = x + self.img_encoder(stage1_img)
        refined_img = self.deconv(x)
        return refined_img

# Stage-II Discriminator
class Stage2Discriminator(nn.Module):
    """Stage-II discriminator: scores an (image, text embedding) pair in [0, 1].

    The image passes through four stride-2 convolutions (widths ``ndf`` to
    ``ndf * 8``); the flattened result is joined with the text embedding and
    classified by two linear layers and a sigmoid. The first linear layer is
    sized for a 4x4 final feature map, i.e. a 64x64 input image.
    """

    def __init__(self, text_embedding_dim, ndf=64, input_channels=3):
        super().__init__()
        w1, w2, w3, w4 = ndf, ndf * 2, ndf * 4, ndf * 8
        slope = 0.2

        self.conv = nn.Sequential(
            # input_channels -> w1 (no batch norm on the first stage)
            nn.Conv2d(input_channels, w1, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(slope, inplace=True),
            # w1 -> w2
            nn.Conv2d(w1, w2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(w2),
            nn.LeakyReLU(slope, inplace=True),
            # w2 -> w3
            nn.Conv2d(w2, w3, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(w3),
            nn.LeakyReLU(slope, inplace=True),
            # w3 -> w4
            nn.Conv2d(w3, w4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(w4),
            nn.LeakyReLU(slope, inplace=True),
        )

        flat_features = w4 * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(flat_features + text_embedding_dim, 1024),
            nn.LeakyReLU(slope, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text_embedding):
        """Return a (batch, 1) tensor of validity scores."""
        feature_map = self.conv(img)
        batch = feature_map.size(0)
        joint = torch.cat((feature_map.view(batch, -1), text_embedding), dim=1)
        return self.fc(joint)


### Training and Evaluation Functions

In [6]:
def train_epoch(
    dataloader,
    text_encoder,
    G1,
    D1,
    G2,
    D2,
    optimizer_G,
    optimizer_D,
    criterion,
    noise_dim,
    epoch,
):
    """Run one epoch of joint Stage-I / Stage-II adversarial training.

    For every batch the generators (and whatever else ``optimizer_G`` owns)
    are updated first, so that both discriminators label the generated images
    as real; the discriminators are then updated on real images and on the
    detached generated images.

    Args:
        dataloader: yields ``(images, token_indices)`` batches and supports
            ``len()``.
        text_encoder: maps token indices to a text embedding.
        G1, D1: Stage-I generator and discriminator.
        G2, D2: Stage-II generator and discriminator.
        optimizer_G: optimizer stepped after the generator loss.
        optimizer_D: optimizer stepped after the discriminator loss.
        criterion: loss applied to discriminator outputs and 0/1 targets.
        noise_dim: size of the noise vector fed to both generators.
        epoch: epoch number, used only in the progress message.

    Returns:
        ``(mean generator loss, mean discriminator loss)`` over the batches.
    """
    G1.train()
    D1.train()
    G2.train()
    D2.train()
    running_loss_G = 0.0
    running_loss_D = 0.0
    for i, (imgs, texts) in enumerate(dataloader):
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)
        texts = texts.to(device)

        text_emb = text_encoder(texts)
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Train Generators
        optimizer_G.zero_grad()
        noise = torch.randn(batch_size, noise_dim, device=device)
        gen_img1 = G1(noise, text_emb)
        noise2 = torch.randn(batch_size, noise_dim, device=device)
        gen_img2 = G2(noise2, text_emb, gen_img1)

        validity1 = D1(gen_img1, text_emb)
        validity2 = D2(gen_img2, text_emb)
        loss_G1 = criterion(validity1, valid)
        loss_G2 = criterion(validity2, valid)
        loss_G = loss_G1 + loss_G2
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminators
        # The graph behind ``text_emb`` was consumed by ``loss_G.backward()``;
        # reusing the attached tensor here would make ``loss_D.backward()``
        # walk that freed graph again and raise. The discriminator step only
        # needs the embedding values, so a detached copy is used.
        text_emb_d = text_emb.detach()
        optimizer_D.zero_grad()
        real_loss1 = criterion(D1(real_imgs, text_emb_d), valid)
        real_loss2 = criterion(D2(real_imgs, text_emb_d), valid)
        fake_loss1 = criterion(D1(gen_img1.detach(), text_emb_d), fake)
        fake_loss2 = criterion(D2(gen_img2.detach(), text_emb_d), fake)
        loss_D1 = (real_loss1 + fake_loss1) / 2
        loss_D2 = (real_loss2 + fake_loss2) / 2
        loss_D = loss_D1 + loss_D2
        loss_D.backward()
        optimizer_D.step()

        running_loss_G += loss_G.item()
        running_loss_D += loss_D.item()

        if i % 20 == 0:
            print(
                f"Epoch [{epoch}] Batch [{i}/{len(dataloader)}] Loss G: {loss_G.item():.4f}, Loss D: {loss_D.item():.4f}"
            )

    return running_loss_G / len(dataloader), running_loss_D / len(dataloader)


def evaluate_model(dataloader, text_encoder, D1, D2, criterion, noise_dim):
    """Score real image/text pairs with both discriminators.

    Only real samples are evaluated, so every true label is 1. The two
    discriminator outputs are averaged; the average is compared with the
    "real" target through ``criterion`` and thresholded at 0.5 to obtain the
    predicted label.

    Args:
        dataloader: yields ``(images, token_indices)`` batches and supports
            ``len()``.
        text_encoder: maps token indices to a text embedding.
        D1, D2: discriminators; both are switched to eval mode.
        criterion: loss applied to the averaged prediction and the target.
        noise_dim: unused; kept so existing callers keep working.

    Returns:
        ``(mean loss per batch, confusion matrix)``. The confusion matrix is
        always 2x2 with rows/columns ordered as labels ``[0, 1]``.
    """
    D1.eval()
    D2.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    with torch.no_grad():
        for imgs, texts in dataloader:
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)
            texts = texts.to(device)
            text_emb = text_encoder(texts)
            valid = torch.ones(batch_size, 1, device=device)

            pred1 = D1(real_imgs, text_emb)
            pred2 = D2(real_imgs, text_emb)
            preds = (pred1 + pred2) / 2.0
            total_loss += criterion(preds, valid).item()

            preds_bin = (preds > 0.5).long().cpu().numpy()
            all_preds.extend(preds_bin.flatten())
            all_labels.extend([1] * batch_size)
    avg_loss = total_loss / len(dataloader)
    # Without an explicit label list the matrix shrinks to 1x1 whenever every
    # prediction is 1 (only label 1 is then present), which does not match a
    # two-label display of the result.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    return avg_loss, cm

### Main Training Routine

In [None]:
def main():
    """Train both GAN stages, validate after every epoch, save and demo.

    Builds the train/valid/test datasets (the vocabulary comes from the
    training set), trains for a fixed number of epochs while plotting a
    validation confusion matrix per epoch, writes the model weights to
    ``*.pth`` files, plots the loss curves and finally shows one image
    generated from the first test batch.
    """
    # Hyperparameters
    num_epochs, batch_size = 10, 32
    noise_dim, text_embedding_dim = 100, 256
    lr = 0.0002
    max_text_len = 20

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Datasets: the training split defines the vocabulary used by all splits.
    dataset_kwargs = {"transform": transform, "max_text_len": max_text_len}
    train_dataset = StackGANDataset("./data/dataset/train", **dataset_kwargs)
    vocab = train_dataset.vocab
    valid_dataset = StackGANDataset("./data/dataset/valid", vocab=vocab, **dataset_kwargs)
    test_dataset = StackGANDataset("./data/dataset/test", vocab=vocab, **dataset_kwargs)

    loader_kwargs = {"batch_size": batch_size, "num_workers": 4}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Models
    text_encoder = TextEncoder(vocab_size=len(vocab), embedding_dim=text_embedding_dim).to(device)
    G1 = Stage1Generator(noise_dim, text_embedding_dim).to(device)
    D1 = Stage1Discriminator(text_embedding_dim).to(device)
    G2 = Stage2Generator(text_embedding_dim, noise_dim).to(device)
    D2 = Stage2Discriminator(text_embedding_dim).to(device)

    # Loss and optimizers: the text encoder is trained with the generators.
    criterion = nn.BCELoss()
    generator_params = [*G1.parameters(), *G2.parameters(), *text_encoder.parameters()]
    discriminator_params = [*D1.parameters(), *D2.parameters()]
    optimizer_G = optim.Adam(generator_params, lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator_params, lr=lr, betas=(0.5, 0.999))

    G_losses, D_losses, valid_losses = [], [], []

    for epoch in range(1, num_epochs + 1):
        loss_G, loss_D = train_epoch(
            train_loader, text_encoder, G1, D1, G2, D2,
            optimizer_G, optimizer_D, criterion, noise_dim, epoch,
        )
        G_losses.append(loss_G)
        D_losses.append(loss_D)
        print(f"Epoch [{epoch}/{num_epochs}] Generator Loss: {loss_G:.4f} Discriminator Loss: {loss_D:.4f}")

        # Validation evaluation
        val_loss, cm = evaluate_model(valid_loader, text_encoder, D1, D2, criterion, noise_dim)
        valid_losses.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}")
        _show_confusion_matrix(cm, epoch)

    # Save the trained models
    checkpoints = (
        ("G1.pth", G1),
        ("G2.pth", G2),
        ("D1.pth", D1),
        ("D2.pth", D2),
        ("text_encoder.pth", text_encoder),
    )
    for filename, module in checkpoints:
        torch.save(module.state_dict(), filename)

    _plot_loss_curves(G_losses, D_losses, valid_losses)
    _show_generated_sample(test_loader, text_encoder, G1, G2, noise_dim)


def _show_confusion_matrix(cm, epoch):
    """Plot the validation confusion matrix of one epoch."""
    display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
    display.plot(cmap=plt.cm.Blues)
    plt.title(f"Validation Confusion Matrix - Epoch {epoch}")
    plt.show()


def _plot_loss_curves(G_losses, D_losses, valid_losses):
    """Plot the per-epoch generator, discriminator and validation losses."""
    plt.figure(figsize=(10, 5))
    for series, label in (
        (G_losses, "Generator Loss"),
        (D_losses, "Discriminator Loss"),
        (valid_losses, "Validation Loss"),
    ):
        plt.plot(series, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Losses")
    plt.legend()
    plt.show()


def _show_generated_sample(test_loader, text_encoder, G1, G2, noise_dim):
    """Generate images for the first test batch and display the first one."""
    text_encoder.eval()
    G1.eval()
    G2.eval()
    with torch.no_grad():
        for imgs, texts in test_loader:
            count = imgs.size(0)
            text_emb = text_encoder(texts.to(device))
            stage1_img = G1(torch.randn(count, noise_dim).to(device), text_emb)
            gen_img = G2(torch.randn(count, noise_dim).to(device), text_emb, stage1_img)
            # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1].
            gen_img = gen_img.cpu() * 0.5 + 0.5
            plt.imshow(gen_img[0].permute(1, 2, 0).numpy())
            plt.title("Generated Image from Test Set")
            plt.axis("off")
            plt.show()
            break  # Display one batch for demonstration

# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()
