# GANDhani: CycleGAN for Cultural Style Transfer Translating Bandhani Textile Motifs onto Contemporary Apparel

### Import Required Libraries

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

### Data Preprocessing

### Discriminator Network

In [7]:
class Discriminator(nn.Module):
    """Fully convolutional critic that maps an image to a one-channel score map.

    The network is a stack of 4x4 convolutions: three stride-2 stages followed
    by two stride-1 stages. Every stage except the first and the last is
    followed by instance normalisation, and every stage except the last is
    followed by LeakyReLU(0.2). No activation is applied to the final output.

    Args:
        in_channels: Number of channels of the input image.
        base_features: Width of the first convolution; later stages use
            2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=3, base_features=64):
        super().__init__()

        # (output width, stride, followed by InstanceNorm?) for each hidden stage.
        stage_plan = (
            (base_features, 2, False),
            (base_features * 2, 2, True),
            (base_features * 4, 2, True),
            (base_features * 8, 1, True),
        )

        layers = []
        width_in = in_channels
        for width_out, stride, normalise in stage_plan:
            layers.append(
                nn.Conv2d(width_in, width_out, kernel_size=4, stride=stride, padding=1)
            )
            if normalise:
                layers.append(nn.InstanceNorm2d(width_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            width_in = width_out

        # Final projection down to a single score channel.
        layers.append(nn.Conv2d(width_in, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-activated) score map for the batch ``x``."""
        scores = self.model(x)
        return scores

### Generator Network

In [8]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + InstanceNorm stages wrapped in a skip connection.

    A ReLU sits between the two stages; there is no activation after the
    second one. Channel count and spatial size are preserved, so the block's
    output can be added directly to its input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def conv_stage():
            # Reflection padding of 1 keeps the spatial size with an unpadded 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0),
                nn.InstanceNorm2d(dim),
            ]

        self.block = nn.Sequential(*conv_stage(), nn.ReLU(True), *conv_stage())

    def forward(self, x):
        """Return ``x`` plus the transformed ``x`` (the residual / skip path)."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout (submodule names are kept stable so existing checkpoints load):

    * ``g1``: reflection pad + 7x7 conv to ``features`` channels.
    * ``g2``/``g3``: two stride-2 3x3 convs, doubling the channels each time.
    * ``res_blocks``: ``n_res_blocks`` residual blocks at ``features * 4`` channels.
    * ``g4``/``g5``: two stride-2 transposed convs, halving the channels each time.
    * ``g6``: reflection pad + 7x7 conv to ``out_channels`` followed by ``tanh``,
      so outputs lie in [-1, 1].

    With these kernel/stride/padding settings the output has the same height
    and width as the input when both are multiples of 4.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the first convolution.
        n_res_blocks: Number of residual blocks in the bottleneck. Defaults to
            9, the value that was previously hard-coded; 0 yields an empty
            (pass-through) residual stage.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, n_res_blocks=9):
        super().__init__()

        # Stem: full-resolution 7x7 conv.
        self.g1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, features, kernel_size=7, stride=1, padding=0),
            nn.InstanceNorm2d(features),
            nn.ReLU(True),
        )

        # Downsampling: 1/2 resolution, 2x channels.
        self.g2 = nn.Sequential(
            nn.Conv2d(features, features*2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features*2),
            nn.ReLU(True),
        )

        # Downsampling: 1/4 resolution, 4x channels.
        self.g3 = nn.Sequential(
            nn.Conv2d(features*2, features*4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features*4),
            nn.ReLU(True),
        )

        # Bottleneck: configurable number of residual blocks.
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(features*4) for _ in range(n_res_blocks))
        )

        # Upsampling back to 1/2 resolution.
        self.g4 = nn.Sequential(
            nn.ConvTranspose2d(features*4, features*2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(features*2),
            nn.ReLU(True),
        )

        # Upsampling back to full resolution.
        self.g5 = nn.Sequential(
            nn.ConvTranspose2d(features*2, features, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(True),
        )

        # Output head: tanh bounds the result to [-1, 1].
        self.g6 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(features, out_channels, kernel_size=7, stride=1, padding=0),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate the image batch ``x`` and return the generated batch."""
        g1 = self.g1(x)
        g2 = self.g2(g1)
        g3 = self.g3(g2)
        res = self.res_blocks(g3)
        g4 = self.g4(res)
        g5 = self.g5(g4)

        return self.g6(g5)

### Discriminator Training

In [None]:
def train_discriminator(discriminator_A, discriminator_B,
                        generator_A, generator_B,
                        real_A, real_B, 
                        fake_A, fake_B,
                        opt_d, criterion_GAN):
    """Run one optimisation step on both discriminators.

    Each discriminator is scored on a real batch (target: all ones) and on a
    fake batch (target: all zeros); its loss is the average of the two
    criterion values. The two discriminator losses are summed, back-propagated
    once, and ``opt_d`` takes a single step.

    Args:
        discriminator_A, discriminator_B: Critics for domain A and domain B.
        generator_A, generator_B: Accepted for signature compatibility; not
            used inside this function.
        real_A, real_B: Real batches from each domain.
        fake_A, fake_B: Generated batches for each domain. They are detached
            here, so no gradient reaches whatever produced them.
        opt_d: Optimiser that is zeroed and stepped by this call.
        criterion_GAN: Callable ``(prediction, target) -> loss``.

    Returns:
        dict with the Python-float losses ``loss_D_A`` / ``loss_D_B`` and the
        mean raw predictions ``real_A_score``, ``fake_A_score``,
        ``real_B_score``, ``fake_B_score``.
    """

    def critic_pass(critic, real_batch, fake_batch):
        # Real samples should be scored as 1.
        real_out = critic(real_batch)
        real_loss = criterion_GAN(real_out, torch.ones_like(real_out))
        real_mean = real_out.mean().item()

        # Fake samples should be scored as 0; detach keeps the update local to the critic.
        fake_out = critic(fake_batch.detach())
        fake_loss = criterion_GAN(fake_out, torch.zeros_like(fake_out))
        fake_mean = fake_out.mean().item()

        return 0.5 * (real_loss + fake_loss), real_mean, fake_mean

    for critic in (discriminator_A, discriminator_B):
        critic.train()

    # Start from clean gradients.
    opt_d.zero_grad()

    loss_D_A, real_score_A, fake_score_A = critic_pass(discriminator_A, real_A, fake_A)
    loss_D_B, real_score_B, fake_score_B = critic_pass(discriminator_B, real_B, fake_B)

    # One backward pass over the combined objective, then one optimiser step.
    loss_D = loss_D_A + loss_D_B
    loss_D.backward()
    opt_d.step()

    metrics = {
        'loss_D_A': loss_D_A.item(),
        'loss_D_B': loss_D_B.item(),
        'real_A_score': real_score_A,
        'fake_A_score': fake_score_A,
        'real_B_score': real_score_B,
        'fake_B_score': fake_score_B
    }
    return metrics

### Generator Training

In [None]:
def train_generator(discriminator_A, discriminator_B,
                        generator_A, generator_B,
                        real_A, real_B, 
                        lambda_a, lambda_b, lambda_id,
                        opt_g, criterion_GAN):
    """Run one optimisation step on both generators.

    ``generator_A`` is applied to ``real_A`` to produce ``fake_B`` and
    ``generator_B`` is applied to ``real_B`` to produce ``fake_A``. Each
    generator's objective is the sum of three terms:

    * adversarial: the matching discriminator's output on the fake, compared
      against an all-ones target with ``criterion_GAN``;
    * cycle consistency: L1 distance between the round-trip reconstruction and
      the original, weighted by ``lambda_a`` / ``lambda_b``;
    * identity: L1 distance between the generator applied to a sample of the
      domain it outputs and that sample, weighted by ``lambda_id`` times the
      other direction's cycle weight.

    The two objectives are summed, back-propagated once, and ``opt_g`` takes a
    single step.

    Returns:
        dict of Python floats: ``loss_total``, ``G_A_loss``, ``G_B_loss``,
        ``adv_A``, ``adv_B``, ``cycle_A``, ``cycle_B``, ``idt_A``, ``idt_B``.
    """

    def fooling_loss(critic, fake_batch):
        # The generator wants the critic to label its output as real (1).
        verdict = critic(fake_batch)
        return criterion_GAN(verdict, torch.ones_like(verdict))

    # Start from clean generator gradients.
    opt_g.zero_grad()

    # Translations in both directions.
    fake_B = generator_A(real_A)
    fake_A = generator_B(real_B)

    # 1) Adversarial terms (``adv_total_a`` belongs to generator_A, ``adv_total_b`` to generator_B).
    adv_total_b = fooling_loss(discriminator_A, fake_A)
    adv_total_a = fooling_loss(discriminator_B, fake_B)

    # 2) Cycle terms: A -> B -> A and B -> A -> B should reproduce the input.
    loss_cycle_A = F.l1_loss(generator_B(fake_B), real_A)
    loss_cycle_B = F.l1_loss(generator_A(fake_A), real_B)

    # 3) Identity terms: feeding a generator its own target domain should change nothing.
    identity_a = F.l1_loss(generator_A(real_B), real_B)
    identity_b = F.l1_loss(generator_B(real_A), real_A)

    loss_a = adv_total_a + (lambda_a * loss_cycle_A) + (lambda_id * lambda_b * identity_a)
    loss_b = adv_total_b + (lambda_b * loss_cycle_B) + (lambda_id * lambda_a * identity_b)
    loss_total = loss_a + loss_b

    loss_total.backward()
    opt_g.step()

    tensors = {
        'loss_total': loss_total,
        'G_A_loss':   loss_a,
        'G_B_loss':   loss_b,
        'adv_A':      adv_total_a,
        'adv_B':      adv_total_b,
        'cycle_A':    loss_cycle_A,
        'cycle_B':    loss_cycle_B,
        'idt_A':      identity_a,
        'idt_B':      identity_b
    }
    return {name: value.item() for name, value in tensors.items()}

In [None]:
# Default output folder name; created up front (no error if it already exists).
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

### Saving Generated Samples

In [None]:
# Denormalize from [-1,1] back to [0,1]
def denorm(imgs):
    """Linearly rescale ``imgs`` by halving and adding 0.5 (maps -1 -> 0 and 1 -> 1)."""
    scale = shift = 0.5
    halved = imgs * scale
    return halved + shift

In [None]:
def save_cycle_samples(
    epoch: int,
    real_A: torch.Tensor,
    real_B: torch.Tensor,
    generator_A: nn.Module,
    generator_B: nn.Module,
    denorm,
    sample_dir: str = "generated",
    nrow: int = 8
):
    """Translate a batch in both directions and write two comparison grids.

    ``generator_A`` is applied to ``real_A`` and ``generator_B`` to ``real_B``
    under ``torch.no_grad()``. For each direction the real batch and the
    generated batch are concatenated along the batch dimension (reals first),
    passed through ``denorm``, arranged with ``make_grid`` and saved as
    ``A2B_epoch{epoch:03d}.png`` / ``B2A_epoch{epoch:03d}.png`` inside
    ``sample_dir``.

    The generators are switched to eval mode only for the duration of the
    forward passes; whatever train/eval mode they were in on entry is restored
    afterwards, so calling this from inside a training loop does not silently
    leave the models in eval mode.

    Args:
        epoch: Epoch number used in the output file names.
        real_A, real_B: Image batches for the two domains; they are moved to
            each generator's device before the forward pass.
        generator_A, generator_B: Translators A->B and B->A respectively.
        denorm: Callable applied to every batch before it is placed in a grid.
        sample_dir: Output directory; created if it does not exist.
        nrow: Number of images per grid row (forwarded to ``make_grid``).
    """
    # Remember the current modes so they can be restored after sampling.
    was_training_A = generator_A.training
    was_training_B = generator_B.training

    generator_A.eval()
    generator_B.eval()
    try:
        with torch.no_grad():
            fake_B = generator_A(real_A.to(next(generator_A.parameters()).device))
            fake_A = generator_B(real_B.to(next(generator_B.parameters()).device))
    finally:
        generator_A.train(was_training_A)
        generator_B.train(was_training_B)

    # Apply the caller-supplied de-normalisation on CPU copies.
    real_A_vis = denorm(real_A.cpu())
    real_B_vis = denorm(real_B.cpu())
    fake_A_vis = denorm(fake_A.cpu())
    fake_B_vis = denorm(fake_B.cpu())

    # Grids: real images first, their translations after them.
    grid_A2B = make_grid(
        torch.cat([real_A_vis, fake_B_vis], dim=0),
        nrow=nrow,
        padding=2,
        normalize=False
    )
    grid_B2A = make_grid(
        torch.cat([real_B_vis, fake_A_vis], dim=0),
        nrow=nrow,
        padding=2,
        normalize=False
    )

    # A caller-supplied directory may not exist yet.
    os.makedirs(sample_dir, exist_ok=True)
    save_image(grid_A2B, os.path.join(sample_dir, f"A2B_epoch{epoch:03d}.png"))
    save_image(grid_B2A, os.path.join(sample_dir, f"B2A_epoch{epoch:03d}.png"))

### Full Training Loop

In [None]:
class ReplayBuffer():
    """Bounded pool of previously seen images, sampled from at random.

    While the pool holds fewer than ``max_size`` entries every incoming image
    is stored and returned unchanged. Once it is full, each incoming image is,
    with probability 0.5, swapped with a randomly chosen stored image (the
    stored one is returned as a clone and the new one takes its slot);
    otherwise the incoming image is returned and the pool is left alone.

    Args:
        max_size: Capacity of the pool. A value of 0 or less disables
            buffering: ``push_and_pop`` then returns its input untouched.
    """

    def __init__(self, max_size=50):
        self.data = []
        self.max_size = max_size

    def push_and_pop(self, images):
        """Feed a batch through the pool and return a batch of the same length.

        Args:
            images: Iterable of image tensors (typically one batch tensor);
                each element is handled individually.

        Returns:
            The selected images concatenated along dimension 0. When
            buffering is disabled or ``images`` is empty, ``images`` itself
            is returned.
        """
        # A non-positive capacity means "no history": pass the batch straight through
        # instead of failing in random.randint(0, -1).
        if self.max_size <= 0:
            return images

        out = []
        for img in images:
            img = img.unsqueeze(0)  # restore the batch dimension for torch.cat
            if len(self.data) < self.max_size:
                # Pool still filling up: keep the image and hand it back as-is.
                self.data.append(img)
                out.append(img)
            elif random.random() < 0.5:
                # Swap: return an older image, store the new one in its place.
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())
                self.data[i] = img
            else:
                out.append(img)

        # torch.cat rejects an empty list, so hand an empty batch back unchanged.
        if not out:
            return images
        return torch.cat(out)

In [1]:
epochs = 200


def lambda_rule(epoch):
    """Learning-rate multiplier for a given epoch index.

    Uses the module-level ``epochs``: the factor is 1.0 up to and including
    ``epochs // 2`` and then falls linearly, reaching 0.0 at ``epoch == epochs``.
    """
    half = epochs // 2
    overshoot = epoch - half
    if overshoot < 0:
        # Still in the constant phase.
        overshoot = 0
    return 1.0 - overshoot / float(half)

In [None]:
def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` / ``InstanceNorm2d`` weights (when present) are drawn
      from N(1, 0.02).
    * Any module that has a non-``None`` ``bias`` gets it set to 0, regardless
      of the module type.

    Modules without a ``weight``/``bias`` attribute are left untouched.
    """
    weight = getattr(m, "weight", None)
    if weight is not None:
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            mean = 0.0
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            mean = 1.0
        else:
            mean = None  # other weighted modules keep their default init
        if mean is not None:
            nn.init.normal_(m.weight, mean, 0.02)

    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.constant_(m.bias, 0)

In [None]:
def fit(
    discriminator_A: nn.Module,
    discriminator_B: nn.Module,
    generator_A:     nn.Module,
    generator_B:     nn.Module,
    train_dl:        DataLoader,
    denorm,
    device:          torch.device,
    epochs:          int      = 200,
    lr:              float    = 2e-4,
    lambda_a:        float    = 10.0,
    lambda_b:        float    = 10.0,
    lambda_id:       float    = 0.5,
    pool_size:       int      = 50,
    sample_dir:      str      = "generated",
    nrow:            int      = 8
):
    """Train the two generators and two discriminators and return loss curves.

    Per batch: one generator step (``train_generator``), then fresh fakes are
    produced, passed through per-domain replay buffers, and used for one
    discriminator step (``train_discriminator``). Per epoch: both LR
    schedulers step, sample grids are written for a fixed batch, and every
    fifth epoch the generator weights are saved to
    ``checkpoint_gen_{a,b}_epoch{epoch}.pth`` in the working directory.

    Args:
        discriminator_A, discriminator_B: Critics for domains A and B.
        generator_A, generator_B: Translators A->B and B->A.
        train_dl: Iterable yielding ``(real_A, real_B)`` batch pairs.
        denorm: Callable forwarded to ``save_cycle_samples``.
        device: Device the batches are moved to.
        epochs: Number of training epochs. The learning rate is held constant
            for the first ``epochs // 2`` epochs and then decays linearly
            towards zero over the remaining ones.
        lr: Initial Adam learning rate for both optimisers.
        lambda_a, lambda_b: Cycle-loss weights.
        lambda_id: Identity-loss weight factor.
        pool_size: Capacity of each replay buffer.
        sample_dir: Directory for the sample grids; created if missing.
        nrow: Images per row in the sample grids.

    Returns:
        dict mapping ``'G_A'``, ``'G_B'``, ``'D_A'``, ``'D_B'``, ``'cycle_A'``,
        ``'cycle_B'``, ``'idt_A'``, ``'idt_B'``, ``'adv_A'``, ``'adv_B'`` to
        lists with one float per processed batch.
    """
    opt_G = torch.optim.Adam(
        list(generator_A.parameters()) + list(generator_B.parameters()),
        lr=lr, betas=(0.5, 0.999)
    )
    opt_D = torch.optim.Adam(
        list(discriminator_A.parameters()) + list(discriminator_B.parameters()),
        lr=lr, betas=(0.5, 0.999)
    )

    # Build the schedule from THIS call's `epochs` rather than a module-level
    # constant, so a non-default epoch count decays over the right range.
    decay_start = epochs // 2
    decay_span = max(1, epochs - decay_start)  # never divide by zero

    def lr_factor(step):
        # 1.0 while step <= decay_start, then linear down to 0.0 at step == epochs.
        return 1.0 - max(0, step - decay_start) / float(decay_span)

    sched_G = torch.optim.lr_scheduler.LambdaLR(opt_G, lr_lambda=lr_factor)
    sched_D = torch.optim.lr_scheduler.LambdaLR(opt_D, lr_lambda=lr_factor)

    # grab one fixed batch for visualization 
    fixed_A, fixed_B = next(iter(train_dl))
    fixed_A, fixed_B = fixed_A.to(device), fixed_B.to(device)

    # Make sure the sample directory exists even when a custom one is passed.
    os.makedirs(sample_dir, exist_ok=True)

    criterion_GAN = nn.MSELoss() # LSGAN
    buffer_A = ReplayBuffer(pool_size)  # for fake_A
    buffer_B = ReplayBuffer(pool_size)  # for fake_B

    # ——— history ———
    history = {
        'G_A': [], 'G_B': [],
        'D_A': [], 'D_B': [],
        'cycle_A': [], 'cycle_B': [],
        'idt_A': [],   'idt_B': [],
        'adv_A': [],   'adv_B': []
    }

    for epoch in range(1, epochs + 1):
        # Sampling at the end of an epoch may leave the generators in eval
        # mode; put them back in train mode before the next pass.
        generator_A.train()
        generator_B.train()

        for real_A, real_B in train_dl:
            real_A, real_B = real_A.to(device), real_B.to(device)

            # ——— 1) Generators ———
            gen_metrics = train_generator(
                discriminator_A, discriminator_B,
                generator_A, generator_B,
                real_A, real_B,
                lambda_a, lambda_b, lambda_id,
                opt_G, criterion_GAN
            )

            # produce fresh fakes (for the buffer); no graph is needed here
            with torch.no_grad():
                fake_B = generator_A(real_A)
                fake_A = generator_B(real_B)

            # ——— 2) Discriminators ———
            # pull from buffer
            fake_A_buf = buffer_A.push_and_pop(fake_A)
            fake_B_buf = buffer_B.push_and_pop(fake_B)

            # train_discriminator takes the generators as its 3rd/4th
            # positional arguments; omitting them raises a TypeError.
            disc_metrics = train_discriminator(
                discriminator_A, discriminator_B,
                generator_A, generator_B,
                real_A, real_B,
                fake_A_buf, fake_B_buf,
                opt_D, criterion_GAN
            )

            # Log losses & scores
            history['G_A'].append(gen_metrics['G_A_loss'])
            history['G_B'].append(gen_metrics['G_B_loss'])
            history['adv_A'].append(gen_metrics['adv_A'])
            history['adv_B'].append(gen_metrics['adv_B'])
            history['cycle_A'].append(gen_metrics['cycle_A'])
            history['cycle_B'].append(gen_metrics['cycle_B'])
            history['idt_A'].append(gen_metrics['idt_A'])
            history['idt_B'].append(gen_metrics['idt_B'])

            history['D_A'].append(disc_metrics['loss_D_A'])
            history['D_B'].append(disc_metrics['loss_D_B'])

        # Step the schedulers each epoch
        sched_G.step()
        sched_D.step()

        if epoch % 5 == 0:
            torch.save(generator_A.state_dict(), f"checkpoint_gen_a_epoch{epoch}.pth")
            torch.save(generator_B.state_dict(), f"checkpoint_gen_b_epoch{epoch}.pth")

        # every epoch dump sample grids A→B and B→A
        save_cycle_samples(
            epoch,
            fixed_A, fixed_B,
            generator_A, generator_B,
            denorm,
            sample_dir=sample_dir,
            nrow=nrow
        )

        # Reports the metrics of the last batch of the epoch.
        print(f"Epoch {epoch}/{epochs}  "
              f"G_A: {gen_metrics['G_A_loss']:.3f}, "
              f"G_B: {gen_metrics['G_B_loss']:.3f}, "
              f"D_A: {disc_metrics['loss_D_A']:.3f}, "
              f"D_B: {disc_metrics['loss_D_B']:.3f}")

    return history


SyntaxError: expected ':' (3572880838.py, line 1)

In [None]:
# Prefer the first CUDA device when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

discriminator_a = Discriminator().to(device)
discriminator_b = Discriminator().to(device)
generator_a     = Generator().to(device)
generator_b     = Generator().to(device)

# Re-initialise every submodule with init_weights.
generator_a.apply(init_weights)
generator_b.apply(init_weights)
discriminator_a.apply(init_weights)
discriminator_b.apply(init_weights)

# Keyword names must match fit()'s parameters exactly (capital A/B suffix);
# the lower-case spellings raise "unexpected keyword argument".
history = fit(
    discriminator_A=discriminator_a,
    discriminator_B=discriminator_b,
    generator_A=generator_a,
    generator_B=generator_b,
    train_dl=train_dl,
    denorm=denorm,
    device=device
)

In [None]:
# Pull the per-batch curves out of the training history, one A/B pair per line.
losses_G_A, losses_G_B = history['G_A'], history['G_B']
losses_D_A, losses_D_B = history['D_A'], history['D_B']
adv_A, adv_B = history['adv_A'], history['adv_B']
cycle_A, cycle_B = history['cycle_A'], history['cycle_B']
idt_A, idt_B = history['idt_A'], history['idt_B']