In [1]:
import numpy as np
import pandas as pd

import os

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import tqdm

In this project we use only part of the COCO dataset: the `test2014` image split.

In [2]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# `device` is always a `torch.device`, so later cells can rely on a single type
# (previously it was a plain string on the CPU/CUDA paths only).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # `torch.backends.mps` is missing on older PyTorch builds, hence the getattr guard.
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

cuda


In [26]:
class MS_COCO(Dataset):
    """Unlabelled image dataset serving every regular file found in one directory.

    Each item is a single RGB image (there is no target), optionally passed
    through ``transform``.
    """

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir (str): Path to the directory containing images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.img_dir = img_dir
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories. Sort for a reproducible index -> file mapping and
        # keep regular files only so __getitem__ never tries to open a folder.
        candidates = (os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir)))
        self.image_paths = [path for path in candidates if os.path.isfile(path)]
        self.transform = transform

    def __len__(self):
        """Number of image files found in ``img_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` when one was given."""
        image_path = self.image_paths[idx]
        # convert() returns an independent image, so the source file can be
        # closed explicitly as soon as the conversion is done.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')

        # Compare against None: a callable that happens to be falsy (e.g. an
        # empty container-like transform) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image

# Spatial size every image is resized to before entering the networks.
img_size = [256, 256]

# Per-channel mean / std used for normalization (these are the values commonly
# used with ImageNet-pretrained torchvision models such as the VGG16 loaded below).
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

# Preprocessing pipeline: resize -> float tensor -> per-channel normalization.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(img_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ]
)

# Content images: every file of the COCO test2014 folder, served in shuffled
# batches of 16.
dataset = MS_COCO(
    img_dir="/kaggle/input/coco-2014-dataset-for-yolov3/coco2014/images/test2014",
    transform=transforms,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

In [27]:
# VGG16 with its ImageNet-pretrained weights (IMAGENET1K_V1), selected through
# the weights enum rather than its string spelling.
pretrained_net = torchvision.models.vgg16(
    weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
)

In [28]:
# Loss network: keep only the first 24 modules (indices 0-23) of VGG16's
# feature extractor; the deepest layer read by the training cell is index 23.
net = pretrained_net.features[:24]
# The loss network is a fixed feature extractor and is never optimized, so
# freeze it: eval mode and no gradients for its weights. Gradients still flow
# *through* it back to the generated image; we only stop computing and storing
# them for the VGG parameters themselves on every backward pass.
net = net.to(device).eval().requires_grad_(False)

In [29]:
# Building blocks of the network that "generates" the stylized image.
def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1, 
    upsample=None, instance_norm=True, relu=True):
    """Assemble the modules that make up one convolution stage.

    Application order: optional nearest-neighbour upsampling (when ``upsample``
    is truthy it is used as the scale factor), reflection padding of
    ``kernel_size // 2``, the convolution, optional instance normalization,
    optional ReLU.

    Returns:
        list[nn.Module]: the modules in application order. They are not
        wrapped in an ``nn.Sequential``; callers unpack the list themselves.
    """
    prefix = [nn.Upsample(mode='nearest', scale_factor=upsample)] if upsample else []
    core = [
        nn.ReflectionPad2d(kernel_size // 2),
        nn.Conv2d(in_channels, out_channels, kernel_size, stride),
    ]
    suffix = []
    if instance_norm:
        suffix.append(nn.InstanceNorm2d(out_channels))
    if relu:
        suffix.append(nn.ReLU())
    return prefix + core + suffix

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvLayer`` stages wrapped in an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        # The first stage ends in a ReLU; the second one is built with
        # relu=False, so the sum with the skip path is not rectified.
        first_stage = ConvLayer(channels, channels, kernel_size=3, stride=1)
        second_stage = ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False)
        self.conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return ``conv(x) + x``."""
        branch = self.conv(x)
        return branch + x

class TransformNet(nn.Module):
    """Image transformation network: downsample, residual blocks, upsample.

    ``base`` is the channel width after the first convolution; the two
    stride-2 stages widen it to ``2 * base`` and ``4 * base``, and the
    upsampling path mirrors that back down to 3 output channels.
    """

    def __init__(self, base=32):
        super().__init__()
        wide, wider = base * 2, base * 4

        # 9x9 stem followed by two stride-2 3x3 convolutions.
        encoder = []
        encoder += ConvLayer(3, base, kernel_size=9)
        encoder += ConvLayer(base, wide, kernel_size=3, stride=2)
        encoder += ConvLayer(wide, wider, kernel_size=3, stride=2)
        self.downsampling = nn.Sequential(*encoder)

        # Five residual blocks at the narrowest resolution.
        self.residuals = nn.Sequential(*(ResidualBlock(wider) for _ in range(5)))

        # Two x2 upsample+conv stages, then a 9x9 output convolution with
        # neither normalization nor activation.
        decoder = []
        decoder += ConvLayer(wider, wide, kernel_size=3, upsample=2)
        decoder += ConvLayer(wide, base, kernel_size=3, upsample=2)
        decoder += ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False)
        self.upsampling = nn.Sequential(*decoder)

    def forward(self, X):
        """Apply the three stages in order and return the result."""
        return self.upsampling(self.residuals(self.downsampling(X)))

In [30]:
# Instantiate the transform network with its default width and move it to the
# selected device.
tf_net = TransformNet().to(device)

Load and preprocess the style image.

In [31]:
style_path = '/kaggle/input/style-image/the_scream.jpg'
style_image = Image.open(style_path).convert('RGB')
# Same preprocessing as the content images, then add a leading batch axis of
# size 1 and move the result to the selected device.
style_tensor = transforms(style_image).unsqueeze(0).to(device)

### Extract content and style

In [32]:
def extract_feature(images, vgg_model, content_layer, style_layers):
    """Run ``images`` through ``vgg_model`` layer by layer and collect activations.

    Args:
        images: input fed to the first layer.
        vgg_model: iterable of callables applied in order (e.g. ``nn.Sequential``).
        content_layer: layer index, or collection of indices, whose output is
            returned as the content feature. If several indices match, the
            output of the last matching layer is returned.
        style_layers: layer index, or collection of indices, whose outputs are
            collected as style features.

    Returns:
        tuple: ``(content, styles)``. ``content`` is the selected activation,
        or ``None`` when no layer index matched (previously this case raised
        ``UnboundLocalError``). ``styles`` is the list of matching activations
        in layer order.
    """
    # Accept a bare index as well as a collection of indices.
    if isinstance(content_layer, int):
        content_layer = (content_layer,)
    if isinstance(style_layers, int):
        style_layers = (style_layers,)

    content = None
    styles = []
    x = images
    for i, layer in enumerate(vgg_model):
        x = layer(x)
        if i in style_layers:
            styles.append(x)
        if i in content_layer:
            content = x
    return content, styles

### Calculate loss
#### Calculate content loss
#### Calculate style loss
#### Calculate TV loss

In [33]:
def calculate_content_loss(content_weight, content_loss_fn, 
                           transformed_content, original_content):
    """Weighted content loss between generated-image and content-image features.

    Only a single layer is used as the content layer, so there is no loop
    over layers here.

    Args:
        content_weight: scalar multiplier applied to the loss.
        content_loss_fn: callable ``(input, target) -> loss``, e.g. ``nn.MSELoss()``.
        transformed_content: content-layer output for the generated ('fake')
            image, torch.Tensor.
        original_content: content-layer output for the original *content*
            image (not the style image), torch.Tensor. It is treated as a
            constant target.

    Returns:
        ``content_weight * content_loss_fn(transformed_content, original_content)``.
    """
    # The target is a constant: detach it so no gradient flows into the graph
    # that produced it (consistent with how the style targets are handled).
    target = original_content.detach()
    # Conventional (input, target) argument order; identical for MSELoss.
    content_loss = content_weight * content_loss_fn(transformed_content, target)
    return content_loss

def calculate_tv_loss(tv_weight, transfromed_img):
    """
    Total-variation penalty: the mean absolute difference between neighbouring
    elements along axis 2 plus the same quantity along axis 3, scaled by
    ``tv_weight``.

    tv_weight: scalar multiplier applied to the sum of both terms
    transfromed_img: output from transform network (indexed as a 4-D tensor)
    """
    img = transfromed_img
    # Differences between consecutive rows (axis 2) and columns (axis 3).
    along_rows = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean()
    along_cols = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean()
    return tv_weight * (along_rows + along_cols)

def gram_matrix(y):
    """Batched Gram matrix of a 4-D feature map.

    Args:
        y: tensor of size ``(b, ch, h, w)``.

    Returns:
        Tensor of size ``(b, ch, ch)``: for every batch element, the inner
        products between all pairs of flattened channels, divided by
        ``ch * h * w``.
    """
    b, ch, h, w = y.size()
    # reshape (rather than view) also handles non-contiguous inputs, e.g.
    # transposed or channels_last tensors, on which view raises.
    features = y.reshape(b, ch, h * w)
    gram = torch.bmm(features, features.transpose(1, 2)) / (ch * h * w)
    return gram

def calculate_style_loss(style_weight, style_loss_fn, 
                         transformed_style, original_style_gram):
    """Weighted style loss summed over all style layers.

    Several layers are used for style, so the loss is accumulated in a loop.

    Args:
        style_weight: scalar multiplier applied to every layer's loss.
        style_loss_fn: callable ``(input, target) -> loss``, e.g. ``nn.MSELoss()``.
        transformed_style: style-layer outputs of the generated ('fake')
            image, list[torch.Tensor].
        original_style_gram: Gram matrices of the style-layer outputs of the
            style image, list[torch.Tensor], one per entry of
            ``transformed_style``. Each one is expanded to the shape of the
            corresponding generated Gram matrix.

    Returns:
        Sum over layers of ``style_weight * style_loss_fn(gram, target)``;
        the integer ``0`` when both lists are empty.

    Raises:
        ValueError: if the two lists differ in length (previously the extra
            entries were silently ignored).

    Note:
        The target style is fixed for this project, so its Gram matrices are
        computed once outside the training loop and only passed in here.
    """
    transformed_style = list(transformed_style)
    target_grams = list(original_style_gram)
    if len(transformed_style) != len(target_grams):
        raise ValueError(
            "transformed_style and original_style_gram must have the same length, "
            f"got {len(transformed_style)} and {len(target_grams)}"
        )

    style_loss = 0
    # Distinct loop names: the original loop rebound the parameter it was iterating.
    for features, target_gram in zip(transformed_style, target_grams):
        transformed_gram = gram_matrix(features)
        # Targets are constants; never back-propagate into them.
        target_gram = target_gram.detach().expand_as(transformed_gram)
        # Out-of-place accumulation keeps the autograd graph free of in-place edits.
        style_loss = style_loss + style_weight * style_loss_fn(transformed_gram, target_gram)
    return style_loss

In [None]:
# Training: all components are defined above; here they are put together to
# optimize the transform network against the fixed VGG loss network.

content_layer = [23]
style_layers = [4, 9, 16, 23]
lr = 0.01
epochs = 5

# Loss weights (same values that were previously inlined in the loop).
content_weight = 1
style_weight = 1e3
tv_weight = 1

# MSELoss holds no state, so build it once instead of once per batch.
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(tf_net.parameters(), lr=lr)

# The style image is fixed, so its Gram matrices are computed a single time,
# without recording an autograd graph.
with torch.no_grad():
    _, target_style = extract_feature(style_tensor, net, content_layer, style_layers)
    target_gram = [gram_matrix(x).to(device) for x in target_style]

# Per-epoch loss history. Created once, before the loop, so it is not wiped at
# the start of every epoch.
content_losses, style_losses, Tv_losses = [], [], []

for epoch in range(epochs):
    # Per-batch losses of the current epoch. Reset once per epoch (not once per
    # batch) so the sums printed below cover the whole epoch.
    c_losses, s_losses, tv_losses = [], [], []

    for content_img in tqdm.tqdm(dataloader):
        # clear the grad
        optimizer.zero_grad()

        # move to device
        content_img = content_img.to(device)

        # generate the stylized image
        transformed_img = tf_net(content_img)

        # content target: features of the untouched content image; it is a
        # constant, so no graph is needed for it
        with torch.no_grad():
            target_content, _ = extract_feature(content_img, net, content_layer, style_layers)
        # content and style features of the generated image (graph required)
        transformed_content, transformed_styles = extract_feature(
            transformed_img, net, content_layer, style_layers)

        # content loss
        c_loss = calculate_content_loss(content_weight, mse_loss, transformed_content, target_content)
        c_losses.append(c_loss.item())

        # style loss
        s_loss = calculate_style_loss(style_weight, mse_loss, transformed_styles, target_gram)
        s_losses.append(s_loss.item())

        # total-variation loss; note the call: `.item()` — storing the bound
        # method `.item` made the later sum() raise TypeError
        tv_loss = calculate_tv_loss(tv_weight, transformed_img)
        tv_losses.append(tv_loss.item())

        loss = c_loss + s_loss + tv_loss
        # update the parameters of tf_net
        loss.backward()
        optimizer.step()

    # epoch totals
    print(sum(c_losses))
    print(sum(s_losses))
    print(sum(tv_losses))
    content_losses.append(sum(c_losses))
    style_losses.append(sum(s_losses))
    Tv_losses.append(sum(tv_losses))

 16%|█▌        | 402/2549 [04:29<24:01,  1.49it/s]