# Prepare Data

We create the data loader of the Coco dataset for training and testing.

In [7]:
import os
import random
from PIL import Image

import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from torchvision import datasets, transforms
from torchvision.models import vgg19
from torchvision.models import vgg19
from torchvision.utils import save_image

import matplotlib.pyplot as plt

In [6]:
# Mount Google Drive so the notebook can read and write files stored on it.
# NOTE(review): `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Project root on the mounted drive. Later cells build dataset, checkpoint and
# result paths by appending directly to this string, so the trailing slash matters.
base_path = '/content/drive/My Drive/cs131_data/'

Mounted at /content/drive


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook displays the selected device

device(type='cuda')

In [4]:
# Lower-case file suffixes that are treated as images
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp'}


def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS, ignoring case."""
    lowered = filename.lower()
    # str.endswith accepts a tuple of candidate suffixes
    return lowered.endswith(tuple(IMG_EXTENSIONS))

def make_dataset(dir):
    """
    Recursively collect the image files found under `dir`.

    A file is kept when `is_image_file` accepts its name and the name does not
    contain the substring "outfit".

    Args:
        dir: path of the directory to scan.

    Returns:
        Sorted list of the full paths of the selected files.

    Raises:
        AssertionError: if `dir` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    found = [
        os.path.join(folder, name)
        for folder, _, names in sorted(os.walk(dir))
        for name in names
        if is_image_file(name) and "outfit" not in name
    ]
    found.sort()
    return found

class AnimeDataset(Dataset):
    """
    Index-aligned pairs of real and anime images read from disk.

    Item `i` is the `i`-th real image (from `data/d/trainA`) together with the
    `i`-th anime image (from `data/d/trainB`); both path lists come from
    `make_dataset`. The same transform is applied to both images of a pair.
    """
    def __init__(self, transform=None):
        self.real_image_paths = make_dataset(f'{base_path}data/d/trainA')
        self.anime_image_paths = make_dataset(f'{base_path}data/d/trainB')

        # Without a caller-supplied transform, resize to 256x256 and convert to a tensor
        if not transform:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    @staticmethod
    def loader(path):
        """Read the image stored at `path` and return it converted to RGB."""
        with open(path, 'rb') as handle:
            return Image.open(handle).convert('RGB')

    def __getitem__(self, index):
        """Return the `(real_image, anime_image)` pair stored at `index`."""
        paths = (self.real_image_paths[index], self.anime_image_paths[index])
        real, anime = (self.loader(p) for p in paths)

        if self.transform is None:
            return real, anime
        return self.transform(real), self.transform(anime)

    def __len__(self):
        """Number of pairs, limited by the shorter of the two path lists."""
        return min(map(len, (self.anime_image_paths, self.real_image_paths)))


def collate_fn(batch):
    """
    Keep only the first element of every sample and stack them into one tensor.

    Args:
        batch: sequence of samples; element 0 of each sample must be a tensor
            and all of them must share the same shape.

    Returns:
        Tensor with a new leading dimension of size `len(batch)`.
    """
    firsts = tuple(sample[0] for sample in batch)
    return torch.stack(firsts, dim=0)

In [None]:
# Preprocessing applied to both images of every (real, anime) pair.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images to 256x256
    transforms.ToTensor(),
    # Per-channel normalisation.
    # NOTE(review): these look like the usual ImageNet statistics expected by
    # torchvision's pretrained VGG - confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = AnimeDataset(transform=transform)

# Shuffled batches of 4 pairs; drop_last discards a final batch smaller than 4.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    drop_last =True
)

# Utils

In [None]:
def calc_mean_std(feat, eps=1e-5):
    """
    Compute the per-sample, per-channel mean and standard deviation of a feature map.

    Args:
        feat: 4-D tensor whose first two dimensions are (N, C); all remaining
            positions of each (n, c) slice are pooled together.
        eps: small value added to the variance before the square root to
            avoid divide-by-zero when the result is used as a divisor.

    Returns:
        Tuple `(mean, std)`, each of shape (N, C, 1, 1).

    Raises:
        AssertionError: if `feat` is not 4-dimensional.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape instead of view: view raises on non-contiguous inputs (for
    # example a spatially transposed map), reshape copies only when required.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def normal(feat, eps=1e-5):
    """
    Standardise `feat` per sample and per channel.

    Subtracts the mean and divides by the standard deviation returned by
    `calc_mean_std(feat, eps)`; the result has the same shape as `feat`.
    """
    mean, std = calc_mean_std(feat, eps)
    return (feat - mean) / std

# Build Model

## Embedding Layer

In [None]:
class PatchEmbedding(nn.Module):
    """
    Project non-overlapping square patches of an image to `emb_size` channels.

    A single convolution whose kernel size and stride both equal the patch size
    performs the split and the projection in one step. `img_size` is accepted
    but not used by this module.
    """
    def __init__(self, img_size, patch_size, in_channels=3, emb_size=512):
        super().__init__()
        patch = (patch_size, patch_size)
        self.patch_size = patch
        self.projection = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch,
            stride=patch,
        )

    def forward(self, x):
        """Return the patch embedding map produced by the projection convolution."""
        return self.projection(x)

## Transformer

Here we use Content-Aware Positional Encoding [1].

[1] Deng, Yingying, et al. "[Stytr2: Image style transfer with transformers.](https://openaccess.thecvf.com/content/CVPR2022/html/Deng_StyTr2_Image_Style_Transfer_With_Transformers_CVPR_2022_paper.html)" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.

In [None]:
class ContentAwarePositionalEncoding(nn.Module):
    """
    Positional encoding derived from the feature map itself.

    The map is average-pooled to an `n x n` grid, passed through a 1x1
    convolution, bilinearly resized back to the input resolution and added to
    the input.
    """
    def __init__(self, emb_size, n_positional_encoding=18):
        super().__init__()
        self.emb_size = emb_size
        self.n_positional_encoding = n_positional_encoding
        grid = (n_positional_encoding, n_positional_encoding)
        self.avgpool = nn.AdaptiveAvgPool2d(grid)
        self.conv = nn.Conv2d(emb_size, emb_size, 1)

    def forward(self, feature_map):
        """Return `feature_map` plus its content-aware positional encoding."""
        height, width = feature_map.shape[-2:]
        pooled = self.avgpool(feature_map)          # [B, emb_size, n, n]
        encoding = self.conv(pooled)
        # Bring the encoding back to the spatial size of the input: [B, emb_size, H', W']
        encoding = F.interpolate(encoding, size=(height, width), mode='bilinear')
        return feature_map + encoding

class TransformerEncoder(nn.Module):
    """Thin wrapper around `nn.TransformerEncoder` built from `num_layers` default encoder layers."""
    def __init__(self, emb_size=512, nhead=8, num_layers=3):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=nhead),
            num_layers=num_layers,
        )

    def forward(self, x):
        """Run `x` through the encoder stack and return the result."""
        encoded = self.encoder(x)
        return encoded

class TransformerDecoder(nn.Module):
    """Thin wrapper around `nn.TransformerDecoder` built from `num_layers` default decoder layers."""
    def __init__(self, emb_size=512, nhead=8, num_layers=3):
        super().__init__()
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=emb_size, nhead=nhead),
            num_layers=num_layers,
        )

    def forward(self, content, style):
        """Decode with `content` as the target sequence and `style` as the memory."""
        decoded = self.decoder(content, style)
        return decoded

class TransformerModel(nn.Module):
    """
    Encode content and style embeddings separately, then fuse them in a decoder.

    Both inputs receive the same content-aware positional encoding module, are
    flattened into token sequences, encoded by their own encoder, and finally
    decoded with the content tokens as target and the style tokens as memory.
    `max_len` is accepted but not used.
    """
    def __init__(self, emb_size=512, nhead=8, num_layers=3, max_len=5000):
        super().__init__()
        shared = dict(emb_size=emb_size, nhead=nhead, num_layers=num_layers)
        self.positional_encoding = ContentAwarePositionalEncoding(emb_size)
        self.transformer_encoder_content = TransformerEncoder(**shared)
        self.transformer_encoder_style = TransformerEncoder(**shared)
        self.transformer_decoder = TransformerDecoder(**shared)

    @staticmethod
    def _to_tokens(feature_map):
        """Flatten [B, emb_size, H, W] into a [HW, B, emb_size] token sequence."""
        return feature_map.flatten(2).permute(2, 0, 1)

    def forward(self, content_emb, style_emb):
        """
        Args:
            content_emb: content embedding map, [B, emb_size, H, W].
            style_emb: style embedding map with the same batch and channel sizes.

        Returns:
            Decoded map reshaped to the spatial layout of `content_emb`,
            [B, emb_size, H, W].
        """
        # Positional encoding on the 2-D maps
        content_map = self.positional_encoding(content_emb)
        style_map = self.positional_encoding(style_emb)

        # Remember the content layout so the output can be folded back
        batch, channels, height, width = content_map.size()
        content_tokens = self._to_tokens(content_map)
        style_tokens = self._to_tokens(style_map)

        # Independent encoders for the two streams
        content_memory = self.transformer_encoder_content(content_tokens)
        style_memory = self.transformer_encoder_style(style_tokens)

        # Content queries attend to the style memory
        fused = self.transformer_decoder(content_memory, style_memory)

        # [HW, B, emb_size] -> [B, emb_size, H, W]
        return fused.permute(1, 2, 0).view(batch, channels, height, width)

## CNN Decoder

The CNN decoder will upsample the transformer output to the original image size.

In [None]:
class CNNDecoder(nn.Module):
    """
    Convolutional decoder that upsamples a patch-level feature map to image size.

    The input is expected to hold one cell per patch, i.e. to be
    `img_size // patch_size` cells wide. Reaching `img_size` therefore needs an
    overall enlargement by `patch_size`, which is done with `log2(patch_size)`
    stages of (3x3 conv that halves the channels, ReLU, nearest x2 upsample).

    Args:
        img_size: side length of the target image. Kept for interface
            compatibility; the number of stages depends only on `patch_size`.
        patch_size: side length of a patch; must be a positive power of two so
            that x2 stages can reach the target size exactly.
        input_dim: number of channels of the incoming feature map.
        output_channels: number of channels of the produced image.

    Raises:
        ValueError: if `patch_size` is not a positive power of two.
    """
    def __init__(self, img_size, patch_size, input_dim=512, output_channels=3):
        super(CNNDecoder, self).__init__()
        patch_size = int(patch_size)
        if patch_size < 1 or patch_size & (patch_size - 1):
            raise ValueError(
                'patch_size must be a positive power of two, got %d' % patch_size
            )
        # Exact integer log2. The grid is already img_size // patch_size wide,
        # so the remaining scale factor is patch_size (not img_size // patch_size).
        num_upsamples = patch_size.bit_length() - 1

        # Initial convolution layer to prepare for upsampling
        current_dim = input_dim
        self.initial_conv = nn.Conv2d(current_dim, current_dim, kernel_size=3, padding=1)
        self.upsample_blocks = nn.ModuleList()

        # Create upsampling blocks; each one halves the channels and doubles the resolution
        for _ in range(num_upsamples):
            next_dim = current_dim // 2
            self.upsample_blocks.append(
                nn.Sequential(
                    nn.Conv2d(current_dim, next_dim, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Upsample(scale_factor=2, mode='nearest')
                )
            )
            current_dim = next_dim

        # Final convolution to get the requested number of image channels
        self.final_conv = nn.Conv2d(current_dim, output_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a [B, input_dim, h, w] feature map to [B, output_channels, h * patch_size, w * patch_size]."""
        # Pass through the initial conv layer
        x = self.initial_conv(x)

        # Pass through each upsample block
        for upsample_block in self.upsample_blocks:
            x = upsample_block(x)

        # Final convolution to get the output channels
        x = self.final_conv(x)

        return x

## Style Transfer Model

Now we combine all the components into a single model.

In [None]:
class StyleTransferModel(nn.Module):
    """
    End-to-end style-transfer network with built-in loss computation.

    Pipeline: patch embedding -> transformer (content/style encoders + decoder)
    -> CNN decoder. A VGG-19 feature extractor, frozen up to relu5_1, supplies
    the features on which the content and style losses are computed.

    Args:
        img_size: side length of the (square) input images.
        patch_size: side length of the patches used by the embedding layer.
        emb_size: channel width of the embeddings.
    """
    def __init__(self, img_size=256, patch_size=8, emb_size=512):
        super().__init__()

        # Indices (as strings) of the layers of `vgg19().features` whose
        # outputs are used as loss features.
        self.enc_layers = [
            '1',    # relu1-1
            '6',    # relu2-1
            '11',   # relu3-1
            '20',   # relu4-1
            '29'    # relu5-1
        ]

        self.img_size = img_size
        self.patch_size = patch_size
        self.vgg = self.initialize_vgg()
        self.mse_loss = nn.MSELoss()

        self.patch_embedding = PatchEmbedding(img_size, patch_size, emb_size=emb_size)
        self.transformer = TransformerModel(emb_size=emb_size)
        self.cnn_decoder = CNNDecoder(img_size, patch_size, input_dim=emb_size)

    def initialize_vgg(self):
        """Load the pretrained VGG-19 feature stack and freeze it up to the last loss layer."""
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True`; kept as-is to stay compatible with the installed version.
        vgg_model = vgg19(pretrained=True).features
        freeze_layers_until = int(self.enc_layers[-1])

        # Freeze the parameters for layers up to relu5_1.
        # Parameter names look like "<layer index>.weight", so the prefix is the layer index.
        for name, param in vgg_model.named_parameters():
            layer_index = int(name.split('.')[0])
            if layer_index <= freeze_layers_until:
                param.requires_grad_(False)

        return vgg_model

    def get_features(self, image):
        """
        Run `image` through the VGG layers and collect the outputs of `enc_layers`.

        Returns:
            List of feature maps in layer order, one per requested layer.
        """
        layers = set(self.enc_layers)
        features = []
        x = image
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in layers:
                features.append(x)
                # Everything requested has been collected; the remaining VGG
                # layers would only burn compute.
                if len(features) == len(layers):
                    break
        return features

    def calc_content_loss(self, input, target):
        """Mean squared error between `input` and a gradient-free `target` of equal size."""
        assert input.size() == target.size()
        assert target.requires_grad == False
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target): # no gram version
        """Sum of the MSEs between the channel-wise means and between the channel-wise stds."""
        assert input.size() == target.size()
        assert target.requires_grad == False
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + self.mse_loss(input_std, target_std)

    def forward(self, content, style):
        """
        Stylise `content` with `style` and compute the training losses.

        Args:
            content: batch of content images, [B, 3, H, W].
            style: batch of style images with the same shape.

        Returns:
            Tuple `(generated_image, total_loss, content_loss, style_loss)`
            where `total_loss = 3 * content_loss + 2 * style_loss`.
        """
        # Embed images from [B, 3, H, W] to [B, emb_size, H / patch_size, W / patch_size]
        content_emb = self.patch_embedding(content)
        style_emb = self.patch_embedding(style)

        # Use the TransformerModel for encoding and decoding
        decoded = self.transformer(content_emb, style_emb)

        # Pass through CNN Decoder to get back to image resolution
        generated_image = self.cnn_decoder(decoded)

        # Extract features for loss calculation
        content_features = self.get_features(content)
        style_features = self.get_features(style)
        target_features = self.get_features(generated_image)

        # Content loss on the two deepest feature levels, compared after standardisation
        content_loss = (
            self.calc_content_loss(normal(target_features[-1]), normal(content_features[-1]))
            + self.calc_content_loss(normal(target_features[-2]), normal(content_features[-2]))
        )
        # Style loss accumulated over every feature level
        style_loss = sum([
            self.calc_style_loss(normal(target), normal(style))
            for target, style in zip(target_features, style_features)
        ])

        # Combine the losses
        total_loss = 3 * content_loss + 2 * style_loss

        # Return generated image and losses
        return generated_image, total_loss, content_loss, style_loss

# Train and Test

In [None]:
# Build the model with its default sizes and move it to the selected device.
model = StyleTransferModel()
model.to(device)

# Adam over everything returned by model.parameters(); this includes the VGG
# parameters that initialize_vgg marked as not requiring gradients.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Function to save checkpoint
def save_checkpoint(model, optimizer, step, losses, base_path):
    """
    Write a training checkpoint and refresh `latest_checkpoint.pth`.

    Two files are produced inside `base_path`: `checkpoint_{step}.pth` and
    `latest_checkpoint.pth`, both with identical content.

    Args:
        model: module whose `state_dict()` is stored.
        optimizer: optimizer whose `state_dict()` is stored.
        step: training step recorded in the checkpoint and used in the file name.
        losses: loss history stored alongside the weights.
        base_path: existing directory that receives the files.
    """
    import shutil  # local import: only needed here

    checkpoint_path = os.path.join(base_path, f'checkpoint_{step}.pth')
    latest_checkpoint_path = os.path.join(base_path, 'latest_checkpoint.pth')

    # Serialise once ...
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
    }, checkpoint_path)

    # ... then publish a copy under the fixed "latest" name. Copying to a
    # staging file and renaming it into place means an interrupted write never
    # leaves a truncated latest_checkpoint.pth behind (where rename is atomic).
    staging_path = latest_checkpoint_path + '.tmp'
    shutil.copyfile(checkpoint_path, staging_path)
    os.replace(staging_path, latest_checkpoint_path)


# Function to load checkpoint
def load_checkpoint(model, optimizer, path):
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: module that receives `model_state_dict`.
        optimizer: optimizer that receives `optimizer_state_dict`, or None to
            skip restoring optimizer state (e.g. for inference).
        path: checkpoint file written by `save_checkpoint`.

    Returns:
        Tuple `(start_step, losses)`; `(0, [])` when `path` does not exist.
    """
    if os.path.isfile(path):
        # Load tensors straight onto the device the model lives on, so a
        # checkpoint saved on GPU can still be opened on a CPU-only runtime.
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else None
        checkpoint = torch.load(path, map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_step = checkpoint['step']
        losses = checkpoint.get('losses', [])
        print(f"Checkpoint loaded: Step {start_step}, Losses loaded")
        return start_step, losses
    else:
        print("No checkpoint found at:", path)
        return 0, []  # Start from the beginning if no checkpoint found

In [None]:
# Output locations (all under base_path) for periodic checkpoints and sample images.
checkpoint_dir = f'{base_path}model/checkpoint/'
results_dir = f'{base_path}results'

# Final weights file, and the rolling "latest" checkpoint used to resume training.
model_save_path = f'{base_path}model/style_transfer_model.pth'
latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pth')

# Ensure the directory exists
# (checkpoint_dir ends with '/', so dirname() only strips that slash and the
# directory created here is checkpoint_dir itself).
os.makedirs(os.path.dirname(checkpoint_dir), exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Set maximum training steps
max_steps = 200000

# Load checkpoint if exists
start_step, losses = load_checkpoint(model, optimizer, latest_checkpoint_path)

# An empty loader would make the outer while-loop spin forever without ever
# advancing current_step, so fail loudly instead.
if len(dataloader) == 0:
    raise RuntimeError('dataloader yields no batches; check the dataset paths and batch size')

# Make sure layers such as dropout are in training mode, even if the model was
# switched to eval mode by an earlier cell run.
model.train()

# Training loop: keep cycling over the data until max_steps optimisation steps are done
current_step = start_step
while current_step < max_steps:
    for content_images, style_images in dataloader:
        if current_step >= max_steps:
            break

        content_images = content_images.to(device)
        style_images = style_images.to(device)

        # Forward pass (the model computes its own losses)
        optimizer.zero_grad()
        stylized_images, total_loss, content_loss, style_loss = model(content_images, style_images)

        # Track loss
        losses.append(total_loss.item())

        # Backward pass and optimization
        total_loss.backward()
        optimizer.step()

        # Every 50 steps: report the running loss, save a checkpoint and sample images
        if (current_step + 1) % 50 == 0:
            recent_losses = losses[-100:]
            avg_loss = sum(recent_losses) / len(recent_losses)
            print(f'Step {current_step + 1}/{max_steps} - avg loss over last {len(recent_losses)} steps: {avg_loss:.4f}')
            save_checkpoint(model, optimizer, current_step + 1, losses, checkpoint_dir)
            # NOTE(review): the tensors written here are still in the normalised
            # value range produced by the data transform - confirm whether they
            # should be de-normalised before saving.
            save_image(stylized_images[0], os.path.join(results_dir, f'stylized_{current_step + 1}.png'))
            save_image(content_images[0], os.path.join(results_dir, f'content_{current_step + 1}.png'))
            save_image(style_images[0], os.path.join(results_dir, f'style_{current_step + 1}.png'))

        current_step += 1  # Increment global step count

# Save final model parameters after training
torch.save(model.state_dict(), model_save_path)


In [None]:
# Fresh model for inference; its weights come from the most recent checkpoint.
model = StyleTransferModel()
model.to(device)
# latest_checkpoint_path already points at the checkpoint file itself.
# The optimizer is passed only because load_checkpoint expects one.
load_checkpoint(model, optimizer, latest_checkpoint_path)
model.eval()

# Sample one batch of (content, style) image pairs.
# The loader yields a pair of tensors, so each one is moved to the device separately.
images_iter = iter(dataloader)
content_image, style_image = next(images_iter)
content_image = content_image.to(device)
style_image = style_image.to(device)

# Generate the stylized images; the style batch already has the same batch
# size as the content batch, so the two are paired one-to-one.
with torch.no_grad():
    stylized_image, _, _, _ = model(content_image, style_image)

original image
 tensor([[[0.2275, 0.2314, 0.2471,  ..., 0.9451, 0.9529, 0.9569],
         [0.2235, 0.2235, 0.2510,  ..., 0.9451, 0.9529, 0.9569],
         [0.2314, 0.2235, 0.2431,  ..., 0.9412, 0.9529, 0.9569],
         ...,
         [0.1882, 0.2000, 0.2000,  ..., 0.2627, 0.2667, 0.2745],
         [0.1922, 0.2000, 0.2039,  ..., 0.2706, 0.2784, 0.2824],
         [0.1882, 0.2118, 0.2157,  ..., 0.2745, 0.2745, 0.2745]],

        [[0.2235, 0.2275, 0.2471,  ..., 0.9647, 0.9725, 0.9765],
         [0.2196, 0.2235, 0.2510,  ..., 0.9765, 0.9725, 0.9765],
         [0.2314, 0.2235, 0.2431,  ..., 0.9725, 0.9725, 0.9765],
         ...,
         [0.1529, 0.1608, 0.1608,  ..., 0.2235, 0.2353, 0.2353],
         [0.1569, 0.1608, 0.1647,  ..., 0.2392, 0.2471, 0.2471],
         [0.1529, 0.1725, 0.1765,  ..., 0.2471, 0.2471, 0.2431]],

        [[0.2431, 0.2431, 0.2471,  ..., 0.9412, 0.9490, 0.9529],
         [0.2353, 0.2353, 0.2510,  ..., 0.9490, 0.9490, 0.9529],
         [0.2392, 0.2314, 0.2471,  ..., 0.