# GANDhani: CycleGAN for Cultural Style Transfer — Translating Bandhani Textile Motifs onto Contemporary Apparel

### Import Required Libraries

In [6]:
import torch
import torch.nn as nn

### Data Preprocessing

### Discriminator Network

In [7]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a spatial map of real/fake scores.

    Three stride-2 convolutions downsample the input. A stride-1 convolution
    widens it to ``base_features * 8`` channels, and a final stride-1
    convolution projects to one channel. Every conv uses a 4x4 kernel with
    padding 1. The last layer has no activation, so the output values are
    raw, unbounded scores, one per output spatial position.

    Args:
        in_channels: Number of channels in the input image.
        base_features: Width of the first convolution. Later convolutions
            use 2x, 4x and 8x this width.
    """

    def __init__(self, in_channels=3, base_features=64):
        super().__init__()

        def conv4x4(c_in, c_out, stride):
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        width = base_features
        # The first stage is not normalized; every intermediate stage is
        # conv -> InstanceNorm -> LeakyReLU.
        stack = [conv4x4(in_channels, width, 2), leaky()]
        stages = (
            (width, width * 2, 2),
            (width * 2, width * 4, 2),
            (width * 4, width * 8, 1),
        )
        for c_in, c_out, stride in stages:
            stack.extend([conv4x4(c_in, c_out, stride), nn.InstanceNorm2d(c_out), leaky()])
        # Project down to a single score channel.
        stack.append(conv4x4(width * 8, 1, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the per-patch score map for ``x``.

        ``x`` is expected to be an image tensor laid out as
        ``(N, in_channels, H, W)``, as required by ``nn.Conv2d``.
        """
        scores = self.model(x)
        return scores

### Generator Network

In [8]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions on reflection-padded input, plus an identity shortcut.

    Each 3x3 conv uses ``padding=0`` after a 1-pixel ``ReflectionPad2d``.
    That keeps the spatial size unchanged, and channels stay at ``dim``, so
    the branch output can be added element-wise to the input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def reflect_conv():
            # Pad by one pixel with reflection, then apply an unpadded 3x3
            # conv, so height and width are preserved.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0),
            ]

        self.block = nn.Sequential(
            *reflect_conv(),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            *reflect_conv(),
            nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual  # identity skip connection

In [None]:
class Generator(nn.Module):
    """ResNet-style encoder / residual-bottleneck / decoder image generator.

    Layout:
      * ``g1``: reflection-padded 7x7 conv to ``features`` channels.
      * ``g2``, ``g3``: two stride-2 3x3 convs, each halving H and W and
        doubling the channel count (to ``features * 4``).
      * ``res_blocks``: ``n_residual_blocks`` ``ResidualBlock`` modules at
        ``features * 4`` channels.
      * ``g4``, ``g5``: two stride-2 transposed convs, each doubling H and W
        and halving the channel count.
      * ``g6``: reflection-padded 7x7 conv to ``out_channels``, followed by
        ``tanh``, so outputs lie in (-1, 1).

    The output has the same height/width as the input when both are
    divisible by 4. Otherwise each dimension comes out rounded up to the
    next multiple of 4.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the first conv stage.
        n_residual_blocks: Number of residual blocks in the bottleneck.
            Defaults to 9, the value previously hard-coded. The CycleGAN
            reference architecture uses 6 for 128x128 training images and
            9 for 256x256 images.

    Raises:
        ValueError: If ``n_residual_blocks`` is negative.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, n_residual_blocks=9):
        super().__init__()
        if n_residual_blocks < 0:
            raise ValueError(
                f"n_residual_blocks must be non-negative, got {n_residual_blocks}"
            )

        # Stem: the conv uses padding=0 because ReflectionPad2d(3) already
        # supplies the border needed by the 7x7 kernel.
        self.g1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, features, kernel_size=7, stride=1, padding=0),
            nn.InstanceNorm2d(features),
            nn.ReLU(True),
        )

        # Downsampling: stride 2 halves H/W, channels double.
        self.g2 = nn.Sequential(
            nn.Conv2d(features, features * 2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features * 2),
            nn.ReLU(True),
        )

        self.g3 = nn.Sequential(
            nn.Conv2d(features * 2, features * 4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(features * 4),
            nn.ReLU(True),
        )

        # Bottleneck; blocks are constructed in order so seeded
        # initialisation matches the previous append loop.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(features * 4) for _ in range(n_residual_blocks)]
        )

        # Upsampling: output_padding=1 makes each transposed conv output
        # exactly twice the input H/W (kernel 3, stride 2, padding 1).
        self.g4 = nn.Sequential(
            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(features * 2),
            nn.ReLU(True),
        )

        self.g5 = nn.Sequential(
            nn.ConvTranspose2d(features * 2, features, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(True),
        )

        # Head: project back to image channels and squash into (-1, 1).
        self.g6 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(features, out_channels, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate ``x`` (expected ``(N, in_channels, H, W)``) into an output image."""
        g1 = self.g1(x)
        g2 = self.g2(g1)
        g3 = self.g3(g2)
        res = self.res_blocks(g3)
        g4 = self.g4(res)
        g5 = self.g5(g4)
        return self.g6(g5)

### Discriminator Training

### Generator Training

### Full Training Loop