In [None]:
# %% [code] Stage II: Imports & Global Settings
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Update paths – adjust these to your local directories
STAGE1_DIR = './stage1_outputs/'  # Folder where Stage I generated images are saved
REAL_DIR = './data/coco2014/train2014/'  # Real high-res images folder
# NOTE(review): Stage2Dataset calls `.keys()` on the unpickled object and indexes it
# by filename, so this file must contain a mapping {filename: embedding}. Confirm
# that this pickle really has that layout, or point to a dedicated embeddings pickle.
TEXT_EMB_FILE = './data/coco/filenames.pickle'  # (Or a dedicated text embeddings pickle)
MODEL_SAVE_DIR = './models/'

# Reproducibility: weight initialisation and DataLoader shuffling both draw from
# torch's global RNG, so seed it once, before any model or loader is created.
SEED = 42
torch.manual_seed(SEED)

# Device selection
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [None]:
## Stage II: Dataset Definition
class Stage2Dataset(Dataset):
    """Paired dataset of Stage I outputs, text embeddings and real images.

    Every sample is addressed by a filename taken from the keys of the unpickled
    embedding mapping; the same filename is looked up in both ``stage1_dir`` and
    ``real_dir``.

    Args:
        stage1_dir: Directory containing the Stage I generated images.
        real_dir: Directory containing the real images.
        text_embedding_file: Path to a pickle holding a mapping
            ``filename -> embedding``.
        transform: Optional callable applied to both images of a sample.
    """

    def __init__(self, stage1_dir, real_dir, text_embedding_file, transform=None):
        self.stage1_dir = stage1_dir
        self.real_dir = real_dir
        self.transform = transform

        # NOTE: pickle.load can execute arbitrary code – only open trusted files.
        with open(text_embedding_file, 'rb') as handle:
            embeddings = pickle.load(handle)
        self.text_embeddings = embeddings

        # The mapping's keys double as the image filenames in both directories.
        self.filenames = list(embeddings.keys())

    def __len__(self):
        """Return the number of filenames found in the embedding mapping."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(stage1_img, text_embedding, real_img)`` for sample ``idx``."""
        name = self.filenames[idx]

        # Load the Stage I image first, then the real one, both forced to RGB.
        images = [
            Image.open(os.path.join(folder, name)).convert('RGB')
            for folder in (self.stage1_dir, self.real_dir)
        ]
        if self.transform:
            images = [self.transform(img) for img in images]
        stage1_img, real_img = images

        # The stored embedding is converted to a float32 tensor on access.
        embedding = torch.tensor(self.text_embeddings[name], dtype=torch.float)
        return stage1_img, embedding, real_img

# Example transforms – adjust the image size as required.
# ToTensor() produces values in [0, 1]; Normalize with mean = std = 0.5 rescales
# them to [-1, 1] so that real images share the value range of the generator's
# Tanh output. Without it the discriminator can tell real from fake by value
# range alone.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Create the dataset and dataloader for Stage II
dataset2 = Stage2Dataset(STAGE1_DIR, REAL_DIR, TEXT_EMB_FILE, transform=transform)
dataloader2 = DataLoader(dataset2, batch_size=32, shuffle=True)


In [None]:
# Stage II: Model Definitions
# Residual block used in the Stage II Generator
class ResBlock(nn.Module):
    """Residual unit: conv-BN-ReLU-conv-BN with an identity skip connection.

    Both convolutions are 3x3, stride 1, padding 1 and bias-free, so the
    channel count and spatial size of the input are preserved.

    Args:
        channel_num: Number of input (and output) channels.
    """

    def __init__(self, channel_num):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution; bias is omitted because a
            # BatchNorm layer follows directly.
            return nn.Conv2d(channel_num, channel_num, kernel_size=3,
                             stride=1, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(channel_num),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

# Stage II Generator: Refines the coarse image using the text embedding.
class Stage2Generator(nn.Module):
    """Refines a Stage I image conditioned on a text embedding.

    Pipeline: encoder (two stride-2 downsamplings) -> concatenation with the
    spatially tiled text embedding -> joint conv -> residual blocks -> decoder
    (two stride-2 upsamplings) -> Tanh. When the input height and width are
    divisible by 4, the output has the same spatial size as the input.

    Args:
        ngf: Base number of generator feature maps.
        nef: Number of channels contributed by the text embedding; the
            flattened per-sample embedding must have exactly this many values.
        n_residual: Number of residual blocks applied to the fused features.
    """

    def __init__(self, ngf=64, nef=128, n_residual=4):
        super().__init__()
        self.ngf = ngf
        mid_ch, wide_ch = ngf * 2, ngf * 4

        # Encoder: stem conv at full resolution, then two halving stages.
        encoder_layers = [
            nn.Conv2d(3, ngf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
        ]
        for c_in, c_out in ((ngf, mid_ch), (mid_ch, wide_ch)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Fuses image features with the tiled text embedding.
        self.joint_conv = nn.Sequential(
            nn.Conv2d(wide_ch + nef, wide_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(wide_ch),
            nn.ReLU(True),
        )

        # Residual refinement at the lowest resolution.
        self.res_blocks = nn.Sequential(
            *(ResBlock(wide_ch) for _ in range(n_residual))
        )

        # Decoder: two doubling stages, then a projection to 3 channels.
        decoder_layers = []
        for c_in, c_out in ((wide_ch, mid_ch), (mid_ch, ngf)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        decoder_layers += [
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # Output in range [-1, 1]
        ]
        self.upsample = nn.Sequential(*decoder_layers)

    def forward(self, stage1_img, text_embedding):
        """Return the refined image for ``stage1_img`` given ``text_embedding``."""
        feats = self.encoder(stage1_img)

        # Tile the embedding over every spatial location of the feature map.
        n, _, height, width = feats.size()
        tiled = text_embedding.view(n, -1, 1, 1).repeat(1, 1, height, width)

        fused = self.joint_conv(torch.cat((feats, tiled), 1))
        return self.upsample(self.res_blocks(fused))

# Stage II Discriminator: Evaluates high-res images conditioned on text.
class Stage2Discriminator(nn.Module):
    """Text-conditioned discriminator returning one probability per sample.

    The image is downsampled 16x by four stride-2 convolutions, concatenated
    with the spatially tiled text embedding, and reduced by a joint head that
    ends in a 4x4 / stride-4 convolution plus Sigmoid. For inputs larger than
    64x64 that head produces a grid of patch scores (4x4 for 256x256 inputs);
    the scores are averaged per sample so the result always has shape
    ``(batch,)`` and can be fed to ``nn.BCELoss`` with per-sample labels.

    Args:
        ndf: Base number of discriminator feature maps.
        nef: Number of channels contributed by the text embedding; the
            flattened per-sample embedding must have exactly this many values.
    """

    def __init__(self, ndf=64, nef=128):
        super(Stage2Discriminator, self).__init__()
        # Image encoder: progressively downsample the input image
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Combine image features with text embedding and output a probability score
        self.joint_conv = nn.Sequential(
            nn.Conv2d(ndf * 8 + nef, ndf * 8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
            nn.Sigmoid()
        )

    def forward(self, img, text_embedding):
        """Score ``img`` against ``text_embedding``.

        Returns:
            1-D tensor of length ``batch`` with values in [0, 1].
        """
        x = self.image_encoder(img)
        batch_size, _, h, w = x.size()
        # Tile the embedding over every spatial location of the feature map.
        text_embedding = text_embedding.view(batch_size, -1, 1, 1).repeat(1, 1, h, w)
        x = torch.cat((x, text_embedding), 1)
        out = self.joint_conv(x)
        # Flattening everything with view(-1) yields batch * n_patches values
        # whenever the head emits more than one patch score per image, which
        # breaks the loss against per-sample labels. Averaging the patch
        # probabilities keeps the output at one value per sample and is a
        # no-op when there is exactly one patch (64x64 inputs).
        return out.view(batch_size, -1).mean(dim=1)

# Build the generator first, then the discriminator, placing each on the
# selected device, and print both architectures for inspection.
netG, netD = Stage2Generator().to(device), Stage2Discriminator().to(device)
for net in (netG, netD):
    print(net)


In [None]:
# Stage II: Training Loop
# Hyperparameters
num_epochs_stage2 = 50
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Create the checkpoint directory up front; otherwise torch.save raises only
# after the first full epoch has already been trained.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# The inference cell switches netG to eval mode. Re-running this cell afterwards
# must put both networks back into training mode (BatchNorm batch statistics).
netG.train()
netD.train()

for epoch in range(num_epochs_stage2):
    for i, (stage1_imgs, text_embeddings, real_imgs) in enumerate(dataloader2):
        stage1_imgs = stage1_imgs.to(device)
        text_embeddings = text_embeddings.to(device)
        real_imgs = real_imgs.to(device)

        # --- Train Discriminator ---
        netD.zero_grad()
        outputs_real = netD(real_imgs, text_embeddings)

        # Build the targets from the discriminator output itself so that their
        # shape, dtype and device always match what BCELoss receives.
        real_labels = torch.ones_like(outputs_real)
        fake_labels = torch.zeros_like(outputs_real)

        d_loss_real = criterion(outputs_real, real_labels)

        fake_imgs = netG(stage1_imgs, text_embeddings)
        # detach(): the discriminator update must not backpropagate into netG.
        outputs_fake = netD(fake_imgs.detach(), text_embeddings)
        d_loss_fake = criterion(outputs_fake, fake_labels)

        d_loss_total = d_loss_real + d_loss_fake
        d_loss_total.backward()
        optimizerD.step()

        # --- Train Generator ---
        # The generator is rewarded when the updated discriminator labels its
        # images as real. Gradients that reach netD here are discarded by
        # netD.zero_grad() at the start of the next iteration.
        netG.zero_grad()
        outputs_fake_for_G = netD(fake_imgs, text_embeddings)
        g_loss = criterion(outputs_fake_for_G, real_labels)
        g_loss.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"Stage II Epoch [{epoch+1}/{num_epochs_stage2}], Batch [{i}/{len(dataloader2)}], D_loss: {d_loss_total.item():.4f}, G_loss: {g_loss.item():.4f}")

    # Save Stage II checkpoints after every epoch
    torch.save(netG.state_dict(), os.path.join(MODEL_SAVE_DIR, f"netG2_{epoch+1}.pt"))
    torch.save(netD.state_dict(), os.path.join(MODEL_SAVE_DIR, f"netD2_{epoch+1}.pt"))


In [None]:
# Stage II: Inference
# Put the Stage II generator into evaluation mode before sampling.
netG.eval()

# Forward pass only, so gradient tracking is switched off.
with torch.no_grad():
    # Demonstration: draw a single batch from the Stage II dataloader; the
    # real images in the batch are not needed here.
    stage1_imgs, text_embeddings, _ = next(iter(dataloader2))
    stage1_imgs, text_embeddings = stage1_imgs.to(device), text_embeddings.to(device)
    fake_highres = netG(stage1_imgs, text_embeddings)

# Helper function to convert a tensor to a PIL image for visualization
def tensor_to_pil(tensor):
    """Convert an image tensor with values in [-1, 1] to an 8-bit PIL image.

    Args:
        tensor: Channel-first image tensor; a leading dimension of size 1 is
            squeezed away before the channel axis is moved last.

    Returns:
        A ``PIL.Image.Image`` built from the rescaled uint8 array.
    """
    tensor = tensor.detach().cpu().clone()
    # Scale from [-1, 1] to [0, 1]; clamp so that values slightly outside the
    # expected range cannot wrap around when the array is cast to uint8.
    tensor = ((tensor + 1) / 2).clamp(0.0, 1.0)
    array = tensor.squeeze(0).permute(1, 2, 0).numpy()
    return Image.fromarray((array * 255).astype('uint8'))

# Display the first generated high-resolution image
generated_image = tensor_to_pil(fake_highres[0])
generated_image.show()
