In [None]:
# Mount Google Drive so the data under /content/drive is reachable (only works inside Colab).
from google.colab import drive
drive.mount('/content/drive')
# Optional manual install of PyTorch from the cu121 wheel index; intentionally left disabled.
# !pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

In [None]:
# Importing libraries
import os
import random
import numpy as np
import pandas as pd
import pathlib
import shutil
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.utils.spectral_norm as spectral_norm
from datetime import datetime
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# Import for French text processing
from transformers import CamembertTokenizer, CamembertModel


In [None]:
# Reproducibility: seed NumPy, Python's `random` and PyTorch with one shared value
manual_seed = 999
np.random.seed(manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)


In [None]:
# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')


In [None]:
# Hyperparameters
num_epochs = 50
batch_size = 64
learning_rateD = 8e-5  # Discriminator learning rate
learning_rateG = 2e-4  # Generator learning rate
beta1 = 0.5  # NOTE(review): not referenced by the RMSprop optimizers in the visible notebook -- confirm before relying on it
image_size = 256  # NOTE(review): the Generator/Discriminator layer stacks below appear hard-wired to 256x256; changing this alone will not resize them
nz = 128  # Size of z latent vector (i.e., the input to the generator)
ngf = 256  # Size of feature maps in generator
ndf = 32  # Size of feature maps in discriminator
nc = 3    # Number of channels in the training images (RGB)
embedding_size = 256  # Width the 768-d text embeddings are projected to inside G and D (not the raw text-model output size)


In [None]:
# Data directories
# HOME_DIR already ends with "/", so the f-strings below contain a double slash ("...//data/...").
HOME_DIR = "/content/drive/MyDrive/art-data/pytorch_olive/"
image_dir = f"{HOME_DIR}/data/images/"  # folder scanned for training images
description_file = f"{HOME_DIR}/data/image_annotations.csv"  # CSV read with columns Filename / Description
literature_text_file = f"{HOME_DIR}/data/corpus_texte.txt"  # raw text corpus that is split into sentences


In [None]:
# Load SpaCy French language model
# The `!` line is an IPython shell escape; it runs the download command every time this cell is executed.
!python -m spacy download fr_core_news_sm
nlp = spacy.load('fr_core_news_sm')
# Raise spaCy's maximum document length so the whole corpus file can be passed to nlp() in a single call.
nlp.max_length = 9000000


In [None]:
# Load CamemBERT tokenizer and model
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
text_model = CamembertModel.from_pretrained('camembert-base').to(device)
text_model.eval()  # Set to evaluation mode

# NOTE(review): this message is printed after loading the text encoder, not after setting hyperparameters.
print("Hyperparameters set.")


In [None]:
# Function to process literature texts
def process_literature_texts(text_file):
    """Read a UTF-8 text file and return its sentences as a list of strings.

    The whole file content is run through the module-level spaCy pipeline
    ``nlp``. Each detected sentence is stripped of surrounding whitespace and
    kept only when the stripped text is longer than 10 characters.

    Args:
        text_file: path of the text file to read.

    Returns:
        list[str]: the retained, stripped sentences in document order.
    """
    with open(text_file, 'r', encoding='utf-8') as handle:
        corpus = handle.read()

    kept = []
    for sentence in nlp(corpus).sents:
        cleaned = sentence.text.strip()
        if len(cleaned) > 10:
            kept.append(cleaned)
    return kept


In [None]:
# Custom Dataset class
class DualConditionDataset(Dataset):
    """Pairs each image file with two text embeddings.

    Every item is a tuple ``(image, desc_embedding, lit_embedding)``:

    * ``image`` -- the RGB image after ``transform`` (``ToTensor`` when no
      transform is given);
    * ``desc_embedding`` -- embedding of the image's description taken from the
      CSV (columns ``Filename`` and ``Description``), or of a default sentence
      when the file has no usable description;
    * ``lit_embedding`` -- embedding of a sentence drawn at random from
      ``literature_texts`` on every access.

    All three tensors are returned on the module-level ``device``. The text
    encoder (module-level ``tokenizer`` / ``text_model``) runs inside
    ``__getitem__``.
    NOTE(review): because items are produced on ``device``, this dataset is
    presumably only usable with ``num_workers=0`` when ``device`` is a GPU.

    Args:
        image_dir: directory whose regular files are treated as images.
        description_file: CSV path with ``Filename`` and ``Description`` columns.
        literature_texts: sequence of sentences to sample the second condition from.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, image_dir, description_file, literature_texts, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.literature_texts = literature_texts

        # List all regular files in the directory, sorted for a stable index -> file mapping
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))
        )

        # Read the CSV file and create a mapping from filenames to descriptions
        descriptions_df = pd.read_csv(description_file)
        self.filename_to_description = pd.Series(descriptions_df.Description.values, index=descriptions_df.Filename).to_dict()

        # Default description for images without a matching description
        self.default_description = "Description non disponible."

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _embed_text(text):
        """Return the mean-pooled last hidden state of ``text_model`` for ``text``.

        The text is truncated/padded to 128 tokens. The mean is taken over all
        128 positions, padding included (no attention-mask weighting).
        """
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = text_model(**inputs)
            return outputs.last_hidden_state.mean(dim=1).squeeze()

    def __getitem__(self, idx):
        # Get image filename
        img_filename = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_filename)

        # Load image; the context manager closes the underlying file handle,
        # and convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Get description; use default if not found
        description = self.filename_to_description.get(img_filename, self.default_description)
        if not isinstance(description, str):
            # Empty CSV cells are read as NaN (a float) and numeric cells as
            # numbers; the tokenizer only accepts strings.
            description = self.default_description if pd.isna(description) else str(description)

        # Randomly select a literature text
        literature_text = random.choice(self.literature_texts)

        desc_embedding = self._embed_text(description)
        lit_embedding = self._embed_text(literature_text)

        return image.to(device), desc_embedding.to(device), lit_embedding.to(device)

In [None]:
# Image preprocessing: resize, tensor conversion and normalization (no random augmentation is applied here)
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # force a square image_size x image_size input
    transforms.ToTensor(),
    transforms.Normalize([0.5] * nc, [0.5] * nc)  # mean 0.5 / std 0.5 for each of the nc channels
])


In [None]:
# Prepare datasets
# Split the literature corpus into sentences once, up front.
literature_texts = process_literature_texts(literature_text_file)
dataset = DualConditionDataset(image_dir, description_file, literature_texts, transform=transform)
# num_workers=0 keeps loading in the main process, where the module-level text model and `device` live.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

print("Data preparation complete.")


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm layers wrapped in an identity skip connection.

    The number of channels and the spatial size are preserved.

    Args:
        channels: number of input (and output) feature maps.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        # The layer order fixes the `block.<idx>` keys of the state dict.
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Residual sum followed by the final activation
        return self.relu(self.block(x) + x)

In [None]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Queries and keys come from 1x1 convolutions with ``in_dim // 8`` channels,
    values from a 1x1 convolution with ``in_dim`` channels. The attended
    features are scaled by the learnable scalar ``gamma`` (initialised to 0)
    and added back to the input.

    Args:
        in_dim: number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        # (n, positions, c//8) @ (n, c//8, positions) -> pairwise scores between positions
        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(n, -1, positions)
        weights = torch.softmax(torch.bmm(queries, keys), dim=-1)

        # Mix the value vectors with the attention weights and restore the map layout
        values = self.value_conv(x).view(n, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2)).view(n, channels, dim_a, dim_b)

        return self.gamma * attended + x

In [None]:
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition vector.

    The feature map is normalised without learnable affine parameters; two
    linear layers then map the condition ``y`` to a per-channel scale and shift.

    Args:
        num_features: number of channels of the feature map.
        embedding_size: size of the condition vector fed to the linear layers.
    """

    def __init__(self, num_features, embedding_size):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma_embed = nn.Linear(embedding_size, num_features)
        self.beta_embed = nn.Linear(embedding_size, num_features)

    def forward(self, x, y):
        # Trailing singleton axes let the per-channel values broadcast over the spatial dims
        scale = self.gamma_embed(y)[:, :, None, None]
        shift = self.beta_embed(y)[:, :, None, None]
        return scale * self.bn(x) + shift

In [None]:
class SequentialWithArgs(nn.Sequential):
    """``nn.Sequential`` variant that forwards extra call arguments to conditional layers.

    Layers are run in registration order. ``ConditionalBatchNorm2d`` layers
    receive the running activation plus every extra positional/keyword
    argument; all other layers receive the activation only.
    """

    def forward(self, x, *args, **kwargs):
        for layer in self._modules.values():
            takes_condition = isinstance(layer, ConditionalBatchNorm2d)
            x = layer(x, *args, **kwargs) if takes_condition else layer(x)
        return x

In [None]:
# Define the Generator with dual conditioning
class Generator(nn.Module):
    """Conditional generator driven by noise and two text embeddings.

    The two 768-d text embeddings are each projected to ``embedding_size``.
    The projections are used twice: mapped to ``nz``-sized vectors that are
    stacked with the noise as the network input, and concatenated as the
    condition for every ``ConditionalBatchNorm2d`` layer.

    The layer stack starts at 4x4 and doubles the resolution six times, ending
    with a ``Tanh`` over ``nc`` channels. The order of the layers in ``main``
    determines the ``main.<idx>`` state-dict keys.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Per-condition projections of the raw text embeddings
        self.desc_proj = nn.Linear(768, embedding_size)  # From 768 to 256
        self.lit_proj = nn.Linear(768, embedding_size)
        # Map each projected embedding to an nz-sized vector that is stacked with the noise
        self.desc_emb = nn.Linear(embedding_size, nz)
        self.lit_emb = nn.Linear(embedding_size, nz)

        self.main = SequentialWithArgs(
            nn.ConvTranspose2d(nz * 3, ngf * 16, 4, 1, 0, bias=False),  # Output: (ngf*16) x 4 x 4
            ConditionalBatchNorm2d(ngf * 16, embedding_size * 2),
            nn.ReLU(True),

            ResidualBlock(ngf * 16),
            ResidualBlock(ngf * 16),
            ResidualBlock(ngf * 16),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf * 16, ngf * 8, 3, 1, 1, bias=False),  # Output: (ngf*8) x 8 x 8
            ConditionalBatchNorm2d(ngf * 8, embedding_size * 2),
            nn.ReLU(True),
            SelfAttention(ngf * 8),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf * 8, ngf * 4, 3, 1, 1, bias=False),   # Output: (ngf*4) x 16 x 16
            ConditionalBatchNorm2d(ngf * 4, embedding_size * 2),
            nn.ReLU(True),
            SelfAttention(ngf * 4),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf * 4, ngf * 2, 3, 1, 1, bias=False),   # Output: (ngf*2) x 32 x 32
            ConditionalBatchNorm2d(ngf * 2, embedding_size * 2),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf * 2, ngf, 3, 1, 1, bias=False),       # Output: (ngf) x 64 x 64
            ConditionalBatchNorm2d(ngf, embedding_size * 2),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf, (ngf // 2), 3, 1, 1, bias=False),      # Output: (ngf//2) x 128 x 128
            ConditionalBatchNorm2d((ngf // 2), embedding_size * 2),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ngf // 2, nc, 3, 1, 1, bias=False),       # Output: (nc) x 256 x 256
            nn.Tanh()
        )

    def forward(self, noise, desc_embed, lit_embed):
        """Generate images from noise and the two text embeddings.

        Args:
            noise: latent tensor; concatenated on dim 1 with the two
                (batch, nz, 1, 1) text features, so presumably (batch, nz, 1, 1).
            desc_embed: description embeddings with 768 features per sample.
            lit_embed: literature embeddings with 768 features per sample.

        Returns:
            The output of ``main`` (Tanh-activated image tensor).
        """
        # Embed descriptions and literature texts
        desc_proj = self.desc_proj(desc_embed)  # Shape: (batch_size, 256)
        lit_proj = self.lit_proj(lit_embed)     # Shape: (batch_size, 256)

        # Transform embeddings to latent space
        desc_feat = self.desc_emb(desc_proj)    # Shape: (batch_size, nz)
        lit_feat = self.lit_emb(lit_proj)     # Shape: (batch_size, nz)

        # Concatenate noise and transformed embeddings
        combined_input = torch.cat(
            (noise, desc_feat.unsqueeze(2).unsqueeze(3), lit_feat.unsqueeze(2).unsqueeze(3)),
            dim=1
        )  # Shape: (batch_size, nz * 3, 1, 1)

        # Concatenate the projected embeddings (not the raw 768-d ones) for ConditionalBatchNorm2d
        combined_embed = torch.cat((desc_proj, lit_proj), dim=1)  # Shape: (batch_size, 512)

        return self.main(combined_input, combined_embed)



In [None]:
# Define the Discriminator with dual conditioning
class Discriminator(nn.Module):
    """Conditional discriminator that scores an image together with two text embeddings.

    Each 768-d text embedding is projected to ``embedding_size`` and then
    expanded to one ``image_size x image_size`` plane. The two planes are
    stacked with the image as extra channels (``nc + 2`` in total).

    The network ends with a convolution and no sigmoid, so ``forward`` returns
    raw scores (logits), one per sample. The layer order in ``main`` determines
    the ``main.<idx>`` state-dict keys. The shape comments below hold for
    256x256 inputs.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.desc_proj = nn.Linear(768, embedding_size)  # From 768 to 256
        self.lit_proj = nn.Linear(768, embedding_size)

        # Expand each projected embedding to a full-resolution single-channel plane
        self.desc_emb = nn.Linear(embedding_size, image_size * image_size)
        self.lit_emb = nn.Linear(embedding_size, image_size * image_size)
        self.main = SequentialWithArgs(
            spectral_norm(nn.Conv2d(nc + 2, ndf, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf x 128 x 128
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*2 x 64 x 64
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*4 x 32 x 32
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*8 x 16 x 16
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 8, ndf * 16, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*16 x 8 x 8
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 16, ndf * 32, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*32 x 4 x 4
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 32, ndf * 64, kernel_size=4, stride=2, padding=1, bias=False)),  # Output: ndf*64 x 2 x 2
            nn.BatchNorm2d(ndf * 64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 64, 1, kernel_size=2, stride=1, padding=0, bias=False)),  # Output: 1 x 1 x 1

        )

    def forward(self, img, desc_embed, lit_embed):
        """Score a batch of images given their two text embeddings.

        Args:
            img: image batch; concatenated on dim 1 with two
                (batch, 1, image_size, image_size) planes.
            desc_embed: description embeddings with 768 features per sample.
            lit_embed: literature embeddings with 768 features per sample.

        Returns:
            The output of ``main`` flattened to 1-D (raw scores, no sigmoid).
        """
        # Project the text embeddings, then expand each to an image-sized plane
        desc_proj = self.desc_proj(desc_embed)  # Shape: (batch_size, 256)
        lit_proj = self.lit_proj(lit_embed)
        desc_feat = self.desc_emb(desc_proj).view(-1, 1, image_size, image_size)
        lit_feat = self.lit_emb(lit_proj).view(-1, 1, image_size, image_size)

        # Concatenate tensors
        combined_input = torch.cat((img, desc_feat, lit_feat), 1)
        # Continue with forward pass (no conditional layers here, so only one argument is passed)
        output = self.main(combined_input)
        return output.view(-1)



In [None]:
# Initialize the models
# Both networks take no constructor arguments (they read the module-level hyperparameters)
# and are moved to `device`.
netG = Generator().to(device)
netD = Discriminator().to(device)

print("Models initialized.")


In [None]:
# Initialize weights
def weights_init(m):
    """DCGAN-style initialiser, meant to be passed to ``nn.Module.apply``.

    * Conv2d / ConvTranspose2d / Linear: weights ~ N(0, 0.02), biases 0.
    * BatchNorm2d with affine parameters: weights ~ N(1, 0.02), biases 0.
    * ConditionalBatchNorm2d: its two linear layers get weights ~ N(0, 0.02)
      and zero biases.

    Layers wrapped by the legacy ``torch.nn.utils.spectral_norm`` keep their
    trainable tensor in ``weight_orig``; ``weight`` is recomputed from it on
    every forward pass. The initialiser therefore writes to ``weight_orig``
    when it exists, so the function also works when it is (re-)applied after
    the module has already been run.

    Args:
        m: the module to initialise in place. Other module types are left untouched.
    """
    # Initialize Convolutional and Linear layers
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        weight = getattr(m, 'weight_orig', m.weight)
        nn.init.normal_(weight.data, 0.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0)

    # Initialize BatchNorm layers (excluding ConditionalBatchNorm2d)
    elif isinstance(m, nn.BatchNorm2d):
        if m.affine:  # Only initialize if affine parameters exist
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Initialize ConditionalBatchNorm2d's internal Linear layers
    elif isinstance(m, ConditionalBatchNorm2d):
        for linear in (m.gamma_embed, m.beta_embed):
            nn.init.normal_(linear.weight.data, 0.0, 0.02)
            nn.init.constant_(linear.bias.data, 0)


# Module.apply calls weights_init on every submodule of each network and on the network itself.
netG.apply(weights_init)
netD.apply(weights_init)

In [None]:
# Loss function
# Binary cross-entropy computed on raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()


In [None]:
# Optimizers
# One RMSprop optimizer per network, each with its own learning rate.
optimizerG = optim.RMSprop(netG.parameters(), lr=learning_rateG)
optimizerD = optim.RMSprop(netD.parameters(), lr=learning_rateD)

In [None]:
# Fixed noise and embeddings for sample generation
# 15 latent vectors of shape (nz, 1, 1), drawn once so the per-epoch sample grids are comparable.
fixed_noise = torch.randn(15, nz, 1, 1, device=device)

# Select fixed descriptions and literature texts for consistent samples
fixed_descriptions = [
    "La chambre de Cécile est située au premier étage, donnant sur le jardin, avec une vue dégagée sur les allées ombragées et les parterres fleuris.",
    "Ce château, situé au milieu d'une forêt épaisse, était entouré de fossés profonds remplis d'eau ; on n'y entrait que par un pont-levis, et les murs en étaient si élevés, qu'il était impossible de les escalader.",
    "Paris est aussi grand qu'Ispahan : les maisons y sont si hautes qu'on jugerait qu'elles ne sont habitées que par des astrologues. Tu juges bien qu'une ville bâtie en l'air, qui a six ou sept maisons les unes sur les autres, est extrêmement peuplée, et que, quand tout le monde est descendu dans la rue, il s'y fait un bel embarras.",
    "Madame de Saint-Ange, femme d'environ trente-six ans, dont les charmes enivraient encore plus qu'ils ne séduisaient, était grande, belle, bien faite ; ses yeux bleus, languissants, mais pleins de feu, la bouche petite, les dents éblouissantes, la peau de lis et de rose ; c'était enfin une de ces créatures que la volupté semblait n'avoir formée que pour elle, et dont tous les mouvements annonçaient ce que leur âme cachait.",
    "On a mis auprès de Virginie, au pied des mêmes roseaux, son ami Paul, et autour d'eux leurs tendres mères et leurs fidèles serviteurs. Elle représente une allée de bambous qui conduit vers la mer ; elle est éclairée par les derniers rayons du soleil couchant : on aperçoit, entre quatre gerbes de ces bambous, trois tombes rustiques.",
    "À l'ombre des bois épais qui couvrent la montagne, on trouve de frais ruisseaux, bordés de fleurs inconnues en Europe, et où les oiseaux, d'une variété infinie de plumages, viennent boire en chantant. Ces vallées fertiles sont plantées de cannes à sucre, de caféiers, et de citronniers qui embaument l'air ; et, le soir, les brises de la mer, chargées d'une fraîcheur humide, répandent partout la vie et la volupté.",
    "Quelle contrée charmante ! Jamais les bords du Lignon, ni les rives de l'Eurotas, ni la vallée de Tempé ne me parurent si délicieux. Ces montagnes, ces forêts, ces rochers, ces torrents, tout ce qui ici parle à mes sens et à mon cœur, semble unir le sublime de la nature sauvage à l'élégance des paysages cultivés.",
    "Ici, tout respire la paix et l'innocence ; les bois, les prairies, la fraîcheur des eaux, tout invite au calme et à la rêverie. Mais, dans ce lieu tranquille, où la nature semble n'avoir rien laissé à désirer, combien de passions orageuses se sont développées, combien de cœurs ont été troublés !",
    "Elle portait une robe légère de mousseline, qui flottait autour de sa taille fine avec une grâce naturelle. Une écharpe de soie, négligemment jetée sur ses épaules, laissait entrevoir un cou d'une blancheur éclatante. Son chapeau, orné de rubans pastel, était posé légèrement de côté, découvrant des boucles brunes qui encadraient son visage rayonnant de jeunesse. Chaque détail de sa parure semblait choisi non pour briller, mais pour accentuer encore la douceur et le charme de sa personne.",
    "Venise, cette ville unique suspendue entre ciel et eau, semble flotter sur un miroir infini. Ses palais, ses églises, et ses dômes se reflètent dans les eaux calmes de la lagune, tandis que ses canaux, véritables rues liquides, résonnent du bruit des rames des gondoles. Nulle part ailleurs le mariage de l'art et de la nature n'est si parfait, et chaque pierre de cette cité raconte une histoire de grandeur passée.",
    "À Venise, chaque instant est une peinture, chaque rue d'eau une scène unique. Ses palais aux marbres somptueux s'élèvent comme des fantômes au-dessus des lagunes, et les gondoles qui glissent silencieusement semblent des ombres fuyant sous la lune. Il n'est point d'autre ville où l'art, l'eau, et le ciel se confondent ainsi dans un spectacle perpétuel.",
    "M. Boucher est le peintre des Grâces. Ses figures sont des nymphes, des déesses, des amours ; tout cela folâtre, badine, se poursuit, se fuit, se joue, se caresse, se pelote, se renverse, se relève, s'embrasse, se mord, se couche, se relève, s'endort, s'éveille, se cache, se montre, se cache encore, et tout cela sans fin.",
    "Que voyez-vous dans ses compositions ? Des Vénus sur des nuages, des Cupidon, des bergères, des femmes nues ; toujours des nudités, des chairs, des postures lascives, des draperies transparentes. Rien n'est vrai, tout est faux, tout est pour les yeux, rien pour l'esprit, rien pour le cœur.",
    "Là, des anges sereins, aux ailes éclatantes, Sur des nuées d'azur, dans des postures flottantes, Semblent guider des rayons de lumière, Tandis qu'en bas les Grâces légères, Couronnent d'immortelles fleurs, les héros et leurs mystères.",
    "Le plafond était orné de peintures représentant des nymphes et des chérubins, voltigeant dans un ciel rose et azuré. Vénus, entourée des Grâces, semblait distribuer des fleurs aux Amours ailés, tandis que, plus loin, Apollon sur son char traversait les nuages dorés, jetant une lumière divine sur cette scène enchantée."
]
# Repeat the list ceil(15 / len) times, then cut it to exactly 15 entries (matches the 15 fixed noise vectors).
fixed_descriptions = (fixed_descriptions * ((15 + len(fixed_descriptions) - 1) // len(fixed_descriptions)))[:15]

# Draw 15 corpus sentences without replacement; random.sample raises ValueError if fewer than 15 are available.
fixed_literature_texts = random.sample(literature_texts, 15)


In [None]:
# Generate embeddings for fixed descriptions and literature texts
def _encode_fixed_texts(texts):
    """Encode each text with the module-level tokenizer/text model and stack the results.

    Every text is truncated/padded to 128 tokens; its embedding is the mean of
    the last hidden state over the token axis. Returns one row per text.
    """
    vectors = []
    with torch.no_grad():
        for text in texts:
            encoded = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
            hidden_states = text_model(**encoded).last_hidden_state
            vectors.append(hidden_states.mean(dim=1).squeeze())
    return torch.stack(vectors)


fixed_desc_embeddings = _encode_fixed_texts(fixed_descriptions)
fixed_lit_embeddings = _encode_fixed_texts(fixed_literature_texts)


In [None]:
# Training Loop
# NOTE(review): the message below announces a backup, but this cell only creates the
# timestamped output directory; no code or hyperparameters are copied here.
print("Backup of the code and hyperparameters...")
now_fmt = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = f"{HOME_DIR}/data_generated/{now_fmt}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

accumulation_steps = 2  # Number of batches to accumulate gradients over
checkpoint_every = 50   # Save intermediate checkpoints every N epochs
num_batches = len(dataloader)

print("Starting Training Loop with Gradient Accumulation...")
for epoch in range(num_epochs):
    netG.train()
    netD.train()

    # Gradients are accumulated across batches, so the buffers are cleared only here and
    # right after each optimizer step -- never at the top of an iteration (that would
    # discard everything accumulated so far).
    netD.zero_grad(set_to_none=True)
    netG.zero_grad(set_to_none=True)

    progress_bar = tqdm(dataloader, desc=f'Epoch [{epoch+1}/{num_epochs}]', leave=False)
    for i, data in enumerate(progress_bar, 0):
        real_images, desc_embeddings, lit_embeddings = data

        batch_size_i = real_images.size(0)

        # Move data to device
        real_images = real_images.to(device)
        desc_embeddings = desc_embeddings.to(device)
        lit_embeddings = lit_embeddings.to(device)

        # Create labels (two-sided label smoothing: 0.9 for real, 0.1 for fake)
        real_labels = torch.full((batch_size_i,), 0.9, dtype=torch.float, device=device)
        fake_labels = torch.full((batch_size_i,), 0.1, dtype=torch.float, device=device)

        # Step on every `accumulation_steps`-th batch and on the last batch of the epoch,
        # so gradients from a trailing incomplete group are not thrown away.
        is_update_step = (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches

        ############################
        # (1) Update Discriminator
        ###########################
        # Forward pass real images
        output_real = netD(real_images, desc_embeddings, lit_embeddings)
        errD_real = criterion(output_real, real_labels)
        # netD returns logits; convert to probabilities for the D(x) log value
        D_x = torch.sigmoid(output_real).mean().item()

        # Generate fake images; no generator graph is needed for the discriminator update
        noise = torch.randn(batch_size_i, nz, 1, 1, device=device)
        with torch.no_grad():
            fake_images = netG(noise, desc_embeddings, lit_embeddings)

        # Forward pass fake images
        output_fake = netD(fake_images, desc_embeddings, lit_embeddings)
        errD_fake = criterion(output_fake, fake_labels)
        D_G_z1 = torch.sigmoid(output_fake).mean().item()

        # Combine losses
        errD = errD_real + errD_fake
        errD = errD / accumulation_steps  # Normalize loss
        errD.backward()

        if is_update_step:
            optimizerD.step()
            netD.zero_grad(set_to_none=True)

        ############################
        # (2) Update Generator
        ###########################
        # Freeze the discriminator so the generator's backward pass does not add
        # gradients to the discriminator's accumulation buffers.
        netD.requires_grad_(False)

        # Generate a fresh batch of fake images for the generator update
        noise = torch.randn(batch_size_i, nz, 1, 1, device=device)
        fake_images = netG(noise, desc_embeddings, lit_embeddings)

        output = netD(fake_images, desc_embeddings, lit_embeddings)
        errG = criterion(output, real_labels)
        D_G_z2 = torch.sigmoid(output).mean().item()

        errG = errG / accumulation_steps  # Normalize loss
        errG.backward()

        netD.requires_grad_(True)

        if is_update_step:
            optimizerG.step()
            netG.zero_grad(set_to_none=True)

        # Update progress bar (losses shown are the normalized per-batch values)
        progress_bar.set_postfix({'Loss_D': f'{errD.item():.4f}', 'Loss_G': f'{errG.item():.4f}'})


    # Generate samples after each epoch
    with torch.no_grad():
        fake_samples = netG(fixed_noise, fixed_desc_embeddings, fixed_lit_embeddings).detach().cpu()
    img_grid = vutils.make_grid(fake_samples, padding=2, normalize=True)
    os.makedirs(f"{OUTPUT_DIR}/images", exist_ok=True)
    vutils.save_image(img_grid, f"{OUTPUT_DIR}/images/epoch_{epoch+1}.png")

    print(f"Epoch [{epoch+1}/{num_epochs}]  Loss_D: {errD.item():.4f}  Loss_G: {errG.item():.4f}  D(x): {D_x:.4f}  D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}")

    # Save model checkpoints every `checkpoint_every` epochs
    if (epoch + 1) % checkpoint_every == 0:
        os.makedirs(f"{OUTPUT_DIR}/model_checkpoints", exist_ok=True)
        torch.save(netG.state_dict(), f"{OUTPUT_DIR}/model_checkpoints/netG_epoch_{epoch+1}.pth")
        torch.save(netD.state_dict(), f"{OUTPUT_DIR}/model_checkpoints/netD_epoch_{epoch+1}.pth")


In [None]:
# Save final models
print("Training complete. Saving final models...")
# The training loop only creates this directory when a periodic checkpoint is written,
# so make sure it exists before saving (otherwise torch.save fails and the run is lost).
os.makedirs(f"{OUTPUT_DIR}/model_checkpoints", exist_ok=True)
torch.save(netG.state_dict(), f"{OUTPUT_DIR}/model_checkpoints/netG_final.pth")
torch.save(netD.state_dict(), f"{OUTPUT_DIR}/model_checkpoints/netD_final.pth")
# Release the Colab runtime once the weights are on disk.
from google.colab import runtime
runtime.unassign()

