In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19, VGG19_Weights
from torchvision.transforms import transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import os
from datasets import load_dataset
from tqdm import tqdm

In [17]:
def _tensor_pipeline(spatial_steps, mean, std):
    """Build `spatial_steps -> ToTensor -> Normalize(mean, std)` as one Compose."""
    return transforms.Compose([
        *spatial_steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# Generator input: 64x64, normalised with mean/std 0.5 (values end up in [-1, 1]).
low_res_transform = _tensor_pipeline(
    [transforms.Resize((64, 64))],
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# Content target: 224x224, same mean/std 0.5 normalisation as the low-res input.
content_transform = _tensor_pipeline(
    [transforms.Resize((224, 224))],
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# Style reference: resize to 256, centre-crop to 224, ImageNet statistics.
style_transform = _tensor_pipeline(
    [transforms.Resize((256, 256)), transforms.CenterCrop((224, 224))],
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


## Generator

In [18]:
class AdaINLayer(nn.Module):
    """Adaptive Instance Normalisation.

    Re-scales each channel of `content` so that its spatial mean and standard
    deviation match those of the corresponding channel of `style`.

    Args:
        eps: Small constant added to the variance before the square root.
            Without it a spatially constant channel has zero standard
            deviation and the normalisation evaluates 0 / 0 (NaN in both the
            forward and the backward pass).
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def _mean_std(self, features):
        """Return per-sample, per-channel (mean, std), each shaped (N, C, 1, 1)."""
        mean = features.mean([2, 3], keepdim=True)
        # sqrt(var + eps) instead of std(): keeps the value and its gradient finite at var == 0.
        std = (features.var(dim=[2, 3], keepdim=True) + self.eps).sqrt()
        return mean, std

    def forward(self, content, style):
        """Apply AdaIN.

        Args:
            content: 4-D tensor (N, C, H, W) to be normalised.
            style: 4-D tensor whose statistics over dims 2 and 3 are transferred;
                its (N, C) dims must broadcast against `content`. The spatial
                size may differ from that of `content`.

        Returns:
            Tensor with the shape of `content`.
        """
        style_mean, style_std = self._mean_std(style)
        content_mean, content_std = self._mean_std(content)
        normalized_content = (content - content_mean) / content_std
        return normalized_content * style_std + style_mean

In [19]:
class ResidualBlockAdaIN(nn.Module):
    """Residual block: two 3x3 convolutions, each followed by AdaIN.

    The block computes `x + AdaIN(conv2(ReLU(AdaIN(conv1(x)))))`, where both
    AdaIN layers take their target statistics from `style_features`.
    """

    def __init__(self, channels):
        super().__init__()
        # Both convolutions keep the channel count and the spatial size.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.adain1 = AdaINLayer()

        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.adain2 = AdaINLayer()

    def forward(self, x, style_features):
        """Return the stylised residual added to the block input `x`."""
        hidden = self.conv1(x)
        hidden = self.adain1(hidden, style_features)
        hidden = F.relu(hidden)

        hidden = self.conv2(hidden)
        hidden = self.adain2(hidden, style_features)

        # Skip connection around the two conv/AdaIN stages.
        return hidden + x

In [None]:
class GeneratorB2A(nn.Module):
    """Generator mapping a high-resolution image back to low resolution.

    The network downsamples by a factor of 4 overall (two stride-2
    convolutions followed by size-preserving 3x3 convolutions), i.e. it is the
    spatial inverse of `GeneratorA2B`, which upsamples by 4. A 256x256 input
    therefore yields a 64x64 output, which is what the cycle-consistency loss
    and the low-resolution discriminator in the training loop compare against.

    The previous layout (four stride-2 convolutions plus a k=4/s=1 convolution)
    reduced 256x256 to 15x15, which cannot be compared with the 64x64
    low-resolution input.

    Input:  (N, 3, H, W) with H and W divisible by 4.
    Output: (N, 3, H/4, W/4), squashed to [-1, 1] by Tanh.
    """

    def __init__(self):
        super(GeneratorB2A, self).__init__()

        self.initial_conv = nn.Sequential(
            # Two stride-2 stages: H -> H/2 -> H/4.
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Size-preserving refinement stages (k=3, s=1, p=1).
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # Project back to RGB without changing the spatial size.
            nn.Conv2d(512, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the 4x-downsampled reconstruction of `x`."""
        return self.initial_conv(x)

In [20]:
class GeneratorA2B(nn.Module):
    """Style-conditioned super-resolution generator (low-res -> high-res).

    Pipeline: 7x7 stem -> two stride-2 downsampling convs -> six AdaIN
    residual blocks at 256 channels -> four stride-2 transposed convs ->
    7x7 projection to RGB with Tanh.

    Two halvings followed by four doublings give a net 4x upscaling, e.g. a
    64x64 input produces a 256x256 output (for H and W divisible by 4).

    The residual blocks apply AdaIN against `style_features`, so that tensor
    must have 256 channels (and a batch size that broadcasts with the input).
    """

    def __init__(self):
        super(GeneratorA2B, self).__init__()

        # Stem: keeps the spatial size, lifts RGB to 64 channels.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels=3,
                      out_channels=64,
                      kernel_size=7,
                      stride=1,
                      padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Encoder: each stage halves H and W.
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.down2 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # ModuleList rather than Sequential: every block takes two inputs
        # (x, style_features), so the container is only ever iterated.
        # Parameter names (res_blocks.<i>....) are the same for both containers.
        self.res_blocks = nn.ModuleList(
            [ResidualBlockAdaIN(256) for _ in range(6)]
        )

        # Decoder: each transposed conv (k=3, s=2, p=1, output_padding=1) doubles H and W.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256,
                               out_channels=128,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128,
                               out_channels=64,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64,
                               out_channels=32,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True)
        )

        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32,
                               out_channels=32,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True)
        )

        # Projection to RGB in [-1, 1].
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=32,
                      out_channels=3,
                      kernel_size=7,
                      stride=1,
                      padding=3),
            nn.Tanh()
        )

    def forward(self, x, style_features):
        """Generate the stylised high-resolution image.

        Args:
            x: Low-resolution batch, shape (N, 3, H, W).
            style_features: Feature map whose per-channel statistics drive the
                AdaIN layers; must have 256 channels.

        Returns:
            Tensor of shape (N, 3, 4*H, 4*W) with values in [-1, 1]
            (for H and W divisible by 4).
        """
        x = self.initial_conv(x)
        x = self.down1(x)
        x = self.down2(x)

        for block in self.res_blocks:
            x = block(x, style_features)

        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)

        return self.final_conv(x)


## Discriminator

In [21]:
class CNNBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) building block.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (default 4).
        stride: Convolution stride (default 2).
        padding: Convolution padding (default 0).
        padding_mode: Padding mode forwarded to `nn.Conv2d` (default
            'reflect'). This argument used to be accepted but ignored; it is
            now honoured.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=0, padding_mode='reflect'):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size,
                      stride,
                      padding,
                      padding_mode=padding_mode
                      ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack to `x`."""
        return self.conv(x)


In [22]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that returns a grid of raw logits.

    Args:
        in_channels: Channels of the input image (default 3). Previously the
            first convolution hard-coded 3 and ignored this argument.
        features: Channel widths of the successive stages. `features[0]` is
            the width of the initial layer; every later entry adds one
            `CNNBlock`. All blocks use stride 2 except the last one, which
            uses stride 1.

    The forward pass returns logits, NOT probabilities: the loss used in this
    notebook is `BCEWithLogitsLoss`, which applies the sigmoid itself, so a
    sigmoid here would be applied twice. Call `torch.sigmoid` on the output if
    probabilities are needed.

    For a (N, 3, 256, 256) input and the default widths the output has shape
    (N, 1, 26, 26).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        features = list(features)

        self.initial_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=features[0],
                      kernel_size=4,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        layers = list()

        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            layers.append(
                CNNBlock(
                    in_channels,
                    out_channels=feature,
                    kernel_size=4,
                    # Decide by position, not by value, so repeated widths
                    # (e.g. [64, 512, 512]) still downsample as intended.
                    stride=1 if index == last_index else 2,
                    padding=0
                )
            )
            in_channels = feature

        # Final 1-channel projection producing one logit per patch.
        layers.append(
            nn.Conv2d(in_channels,
                      out_channels=1,
                      kernel_size=4,
                      stride=1,
                      padding=1,
                      padding_mode='reflect')
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch real/fake logits for image batch `x`."""
        x = self.initial_layer(x)
        return self.model(x)
        

def test():
    """Smoke test: print the discriminator output shape for a random 256x256 batch."""
    dummy_batch = torch.randn((5, 3, 256, 256))
    discriminator = Discriminator(in_channels=3)
    print(discriminator(dummy_batch).shape)


if __name__ == '__main__':
    test()

torch.Size([5, 1, 26, 26])


### WikiArt dataset

In [23]:
class CombinedDataset(Dataset):
    """Pairs each DIV2K image with a WikiArt image.

    Every item is a dict with three tensors:
        "low_res": the DIV2K image passed through `low_res_transform`,
        "content": the same DIV2K image passed through `content_transform`,
        "style":   a WikiArt image passed through `style_transform`.

    File listings are sorted so that the index -> file mapping is the same on
    every run and machine (`os.listdir` returns entries in arbitrary order).

    Args:
        div2k_dir: Directory containing the DIV2K images.
        wikiart_dir: Directory containing the WikiArt images.
        low_res_transform: Callable applied to the DIV2K image for "low_res".
        content_transform: Callable applied to the DIV2K image for "content".
        style_transform: Callable applied to the WikiArt image for "style".
    """

    def __init__(self, div2k_dir, wikiart_dir, low_res_transform, content_transform, style_transform):
        self.div2k_dir = div2k_dir
        self.wikiart_dir = wikiart_dir
        self.div2k_filenames = self._list_files(div2k_dir)
        self.wikiart_filenames = self._list_files(wikiart_dir)

        self.low_res_transform = low_res_transform
        self.content_transform = content_transform
        self.style_transform = style_transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside `directory`."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open `path` and return an RGB copy, closing the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file handle is closed.
            return img.convert('RGB')

    def __len__(self):
        # Use the smaller of the two listings so every index has a partner.
        return min(len(self.div2k_filenames), len(self.wikiart_filenames))

    def __getitem__(self, idx):
        div2k_path = os.path.join(self.div2k_dir, self.div2k_filenames[idx])
        # The modulo only matters when idx is passed directly and exceeds the WikiArt listing.
        wikiart_path = os.path.join(self.wikiart_dir, self.wikiart_filenames[idx % len(self.wikiart_filenames)])

        div2k_img = self._load_rgb(div2k_path)
        wikiart_img = self._load_rgb(wikiart_path)

        # Low-res input and content target both come from the same DIV2K image.
        low_res_image = self.low_res_transform(div2k_img)
        content_image = self.content_transform(div2k_img)
        # Style reference comes from WikiArt.
        style_image = self.style_transform(wikiart_img)

        return {
            "low_res": low_res_image,
            "content": content_image,
            "style": style_image,
        }


In [24]:
# Dataset locations (absolute paths on the author's machine).
div2k_dataset_path = "/home/mehran/Projects/SRStyle/Dataset/div2k_x2_train/"
wikiart_dataset_path = "/home/mehran/Projects/SRStyle/Dataset/wikiart_shuffled/"

# Pair the two image folders and serve them in shuffled mini-batches of 4.
combined_dataset = CombinedDataset(
    div2k_dir=div2k_dataset_path,
    wikiart_dir=wikiart_dataset_path,
    low_res_transform=low_res_transform,
    content_transform=content_transform,
    style_transform=style_transform,
)
dataloader = DataLoader(
    dataset=combined_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [25]:
# Sanity check: fetch the first sample (a dict with 'low_res', 'content' and 'style' tensors).
dataloader.dataset[0]

{'low_res': tensor([[[ 0.9294,  0.9294,  0.9294,  ...,  0.8745,  0.8824,  0.8824],
          [ 0.9294,  0.9216,  0.9216,  ...,  0.8667,  0.8667,  0.8667],
          [ 0.9059,  0.9373,  0.9294,  ...,  0.8431,  0.8353,  0.8275],
          ...,
          [-0.8431, -0.7020, -0.6392,  ..., -0.2314, -0.3412, -0.4431],
          [-0.7647, -0.6706, -0.6549,  ..., -0.3882, -0.4824, -0.5373],
          [-0.7255, -0.7255, -0.7098,  ..., -0.5059, -0.5608, -0.5529]],
 
         [[ 0.9059,  0.9137,  0.9059,  ...,  0.8588,  0.8510,  0.8510],
          [ 0.9059,  0.8980,  0.8980,  ...,  0.8431,  0.8353,  0.8353],
          [ 0.8980,  0.9216,  0.9137,  ...,  0.8196,  0.8196,  0.8118],
          ...,
          [-0.7961, -0.7020, -0.7020,  ..., -0.2078, -0.3490, -0.4196],
          [-0.7255, -0.6706, -0.6863,  ..., -0.4039, -0.4980, -0.5451],
          [-0.6706, -0.7333, -0.6941,  ..., -0.5216, -0.5294, -0.5216]],
 
         [[ 0.9373,  0.9373,  0.9451,  ...,  0.8745,  0.8745,  0.8745],
          [ 0.952

In [26]:
# === VGG Feature Extractor ===
class VGGFeatures(nn.Module):
    """Frozen, truncated VGG-19 that returns intermediate activations.

    Args:
        layers_to_extract: Non-empty iterable of layer names chosen from
            "conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1".

    Raises:
        ValueError: If `layers_to_extract` is empty.
        KeyError: If a name is not one of the supported layers.

    The network is cut right after the deepest requested layer and all of its
    parameters are frozen: it is used only as a fixed feature extractor for
    the perceptual losses, so no gradients are needed for its weights
    (gradients still flow through to the input image).
    """

    def __init__(self, layers_to_extract):
        super(VGGFeatures, self).__init__()
        vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        # Index of each named convolution inside `vgg19().features`.
        layer_map = {
            "conv1_1": 0,
            "conv2_1": 5,
            "conv3_1": 10,
            "conv4_1": 19,
            "conv5_1": 28
        }

        if not layers_to_extract:
            raise ValueError("layers_to_extract must name at least one VGG layer.")

        self.required_layers = sorted([layer_map[layer] for layer in layers_to_extract])
        # Keep only the layers up to (and including) the deepest requested one.
        self.model = vgg[:self.required_layers[-1] + 1]

        # Fixed feature extractor: never train the VGG weights.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return `{"layer_<index>": activation}` for every requested layer.

        NOTE(review): torchvision's VGG appears to use in-place ReLUs after
        each convolution, in which case the stored tensors are overwritten by
        the following ReLU and hold post-activation values - confirm if the
        pre-activation output is what is wanted.
        """
        features = {}
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i in self.required_layers:
                features[f"layer_{i}"] = x
        return features

In [27]:
# === Loss Functions ===
def content_loss(content_features, generated_features):
    """Mean squared error between generated and reference feature tensors."""
    return F.mse_loss(generated_features, content_features)


def adversarial_loss(prediction, target_is_real):
    """Binary cross-entropy (on logits) against an all-ones or all-zeros target.

    `prediction` is treated as raw logits; the target is all ones when
    `target_is_real` is truthy and all zeros otherwise.
    """
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return F.binary_cross_entropy_with_logits(prediction, make_target(prediction))


def style_loss(style_features, generated_features):
    """Sum of per-layer MSE between style and generated feature maps.

    Both arguments are mappings of feature tensors; their values are paired
    in iteration order. Raises RuntimeError when a pair differs in shape.
    """
    total = 0
    paired_maps = zip(style_features.values(), generated_features.values())
    for style_map, generated_map in paired_maps:
        if style_map.size() != generated_map.size():
            raise RuntimeError("Style and generated features have different shapes.")

        # Plain MSE on the raw feature maps.
        total = total + F.mse_loss(generated_map, style_map)
    return total

In [29]:
# === Training Loop ===
def train_model(generator_a2b, generator_b2a, discriminator_a, discriminator_b, vgg_features, optimizer_g,
                optimizer_d_a, optimizer_d_b, dataloader, device, epochs):
    """Train both generators and both discriminators.

    For every batch:
      1. One joint generator update. The A2B loss (content + style +
         0.1 * adversarial) and the B2A loss (cycle consistency +
         0.1 * adversarial) are summed and back-propagated once, because the
         B2A loss depends on the A2B output. Back-propagating them in two
         separate passes would try to reuse an autograd graph that has already
         been freed.
      2. Discriminator A update (real: low-res input, fake: B2A output).
      3. Discriminator B update (real: content image, fake: A2B output).

    All four state dicts are saved to the working directory after each epoch.

    Args:
        generator_a2b: Callable `(low_res, style_feature) -> high_res`.
        generator_b2a: Callable `high_res -> low_res`; its output must have
            the same shape as the low-res input for the cycle loss.
        discriminator_a: Discriminator for the low-resolution domain.
        discriminator_b: Discriminator for the high-resolution domain.
        vgg_features: Callable returning a dict that contains "layer_10".
        optimizer_g: Optimiser over the parameters of both generators.
        optimizer_d_a: Optimiser for `discriminator_a`.
        optimizer_d_b: Optimiser for `discriminator_b`.
        dataloader: Yields dicts with "low_res", "content" and "style" batches.
        device: Device the batches are moved to.
        epochs: Number of passes over `dataloader`.
    """
    # Loop-invariant preprocessing for the VGG feature extractor.
    crop_for_vgg = transforms.CenterCrop((224, 224))
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def _to_vgg_input(images):
        """Map [-1, 1] images (content batch / Tanh output) to ImageNet-normalised input."""
        # First back to [0, 1], then ImageNet statistics - the same
        # normalisation the style images already receive in their transform.
        return imagenet_normalize((images + 1.0) / 2.0)

    for epoch in range(epochs):
        epoch_progress = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch + 1}/{epochs}]")
        for _, data in epoch_progress:
            low_res = data['low_res'].to(device)
            style = data['style'].to(device)
            content = data['content'].to(device)

            # Target features are constants for this step: no graph needed.
            with torch.no_grad():
                content_feature = vgg_features(_to_vgg_input(content))["layer_10"]
                style_feature = vgg_features(style)["layer_10"]

            # === Generators (A2B + B2A cycle), single joint update ===
            optimizer_g.zero_grad()
            generated_high_res = generator_a2b(low_res, style_feature)

            generated_features = vgg_features(_to_vgg_input(crop_for_vgg(generated_high_res)))['layer_10']
            c_loss = content_loss(content_feature, generated_features)
            # style_loss expects mappings of dense feature maps.
            s_loss = style_loss({"layer_10": style_feature}, {"layer_10": generated_features})
            g_loss_adv = adversarial_loss(discriminator_b(generated_high_res), True)
            loss_g_a2b = c_loss + s_loss + 0.1 * g_loss_adv

            reconstructed_low_res = generator_b2a(generated_high_res)
            cycle_consistency_loss = content_loss(low_res, reconstructed_low_res)
            g_loss_adv_b2a = adversarial_loss(discriminator_a(reconstructed_low_res), True)
            loss_g_b2a = cycle_consistency_loss + 0.1 * g_loss_adv_b2a

            (loss_g_a2b + loss_g_b2a).backward()
            optimizer_g.step()

            # === Discriminator A Training ===
            # zero_grad also clears the gradients left behind by the generator step.
            optimizer_d_a.zero_grad()
            real_loss_a = adversarial_loss(discriminator_a(low_res), True)
            fake_loss_a = adversarial_loss(discriminator_a(reconstructed_low_res.detach()), False)
            loss_d_a = 0.5 * (real_loss_a + fake_loss_a)

            loss_d_a.backward()
            optimizer_d_a.step()

            # === Discriminator B Training ===
            # NOTE(review): `content` is 224x224 while the generated image is
            # 4x the low-res size (256x256 for 64x64 inputs); the discriminator
            # accepts both, but real and fake samples differ in resolution -
            # confirm this is intended.
            optimizer_d_b.zero_grad()
            real_loss_b = adversarial_loss(discriminator_b(content), True)
            fake_loss_b = adversarial_loss(discriminator_b(generated_high_res.detach()), False)
            loss_d_b = 0.5 * (real_loss_b + fake_loss_b)

            loss_d_b.backward()
            optimizer_d_b.step()

            # Update loop description
            epoch_progress.set_postfix({
                "Loss_G_A2B": loss_g_a2b.item(),
                "Loss_G_B2A": loss_g_b2a.item(),
                "Loss_D_A": loss_d_a.item(),
                "Loss_D_B": loss_d_b.item()
            })
        print('()()()()()()()() Loop Description completed successfully!')
        # Checkpoint every model at the end of each epoch (files are overwritten).
        torch.save(generator_a2b.state_dict(), "generator_a2b.pth")
        torch.save(generator_b2a.state_dict(), "generator_b2a.pth")
        torch.save(discriminator_a.state_dict(), "discriminator_a.pth")
        torch.save(discriminator_b.state_dict(), "discriminator_b.pth")


In [None]:
# Fall back to the CPU when no GPU is available instead of failing on the first `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the VGG feature extractor for content and style losses
vgg_features = VGGFeatures(layers_to_extract=["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]).to(device)

# Build the models directly on the target device, before the optimisers are
# created, so the optimisers reference the on-device parameters.
generator_a2b = GeneratorA2B().to(device)
generator_b2a = GeneratorB2A().to(device)
discriminator_a = Discriminator().to(device)
discriminator_b = Discriminator().to(device)

# One optimiser for both generators, one per discriminator.
optimizer_g = torch.optim.Adam(
    list(generator_a2b.parameters()) + list(generator_b2a.parameters()), lr=2e-4, betas=(0.5, 0.999)
)
optimizer_d_a = torch.optim.Adam(discriminator_a.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_d_b = torch.optim.Adam(discriminator_b.parameters(), lr=2e-4, betas=(0.5, 0.999))


# Call the training loop
train_model(
    generator_a2b=generator_a2b,
    generator_b2a=generator_b2a,
    discriminator_a=discriminator_a,
    discriminator_b=discriminator_b,
    vgg_features=vgg_features,
    optimizer_g=optimizer_g,
    optimizer_d_a=optimizer_d_a,
    optimizer_d_b=optimizer_d_b,
    dataloader=dataloader,
    device=device,
    epochs=10  # Adjust the number of epochs
)