In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2


In [None]:
# Folder that holds the training images; the dataset cell joins file names onto it.
TRAIN_PATH = 'train2017' # coco 2017 images

In [None]:
# create dataset
# One row per entry of TRAIN_PATH; the single column holds the joined path.

images = pd.DataFrame({
    "images_paths": [
        os.path.join(TRAIN_PATH, img_name)
        # List TRAIN_PATH itself rather than a second hard-coded copy of the
        # folder name, so changing the constant cannot leave the two out of
        # sync. Sorted because os.listdir returns entries in arbitrary order.
        for img_name in sorted(os.listdir(TRAIN_PATH))
    ]
})

In [None]:
class ConvBlock(nn.Module):
    """Convolution -> optional instance norm -> optional ReLU.

    The convolution uses reflect padding. A stage that is switched off is
    replaced by ``nn.Identity``, so the wrapped ``nn.Sequential`` always has
    three entries.

    Parameters:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            positionally to ``nn.Conv2d``.
        use_activation (bool): use ``nn.ReLU`` as the last stage when True.
        use_instance_norm (bool): use ``nn.InstanceNorm2d(affine=True)`` as the
            middle stage when True.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 use_activation=True,
                 use_instance_norm=True):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                         padding_mode='reflect')

        if use_instance_norm:
            norm = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            norm = nn.Identity()

        if use_activation:
            activation = nn.ReLU()
        else:
            activation = nn.Identity()

        # Keep the attribute name and the three-slot layout: state_dict keys
        # are derived from them.
        self.conv_block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the three stages in order."""
        stages = self.conv_block
        return stages(x)
        
        

class ResBlock(nn.Module):
    """Residual unit: two stacked ``ConvBlock``s whose output is added to the input.

    Both blocks map ``in_channels`` to ``in_channels``; the second one is built
    with ``use_activation=False``.
    """

    def __init__(self, in_channels, kernel_size, stride, padding):
        super().__init__()
        first = ConvBlock(in_channels, in_channels, kernel_size, stride, padding)
        second = ConvBlock(in_channels, in_channels, kernel_size, stride, padding,
                           use_activation=False)
        self.res_block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked blocks."""
        residual = self.res_block(x)
        return x + residual
        


class TransformerNet(nn.Module):
    """Image-to-image network: strided conv blocks, residual blocks, upsample + conv blocks.

    The shape comments below are for a 3x256x256 input.
    """

    def __init__(self):
        super().__init__()

        # Downsampling path.
        encoder = [
            ConvBlock(3, 32, (9, 9), 1, 4),    # 32x256x256
            ConvBlock(32, 64, (3, 3), 2, 1),   # 64x128x128
            ConvBlock(64, 128, (3, 3), 2, 1),  # 128x64x64
        ]

        # Five residual blocks at constant resolution.
        residuals = [ResBlock(128, (3, 3), 1, 1) for _ in range(5)]  # 128x64x64

        # Upsampling path; the last block has neither norm nor activation.
        decoder = [
            nn.Upsample(scale_factor=2),       # 128x128x128
            ConvBlock(128, 64, (3, 3), 1, 1),  # 64x128x128
            nn.Upsample(scale_factor=2),       # 64x256x256
            ConvBlock(64, 32, (3, 3), 1, 1),   # 32x256x256
            ConvBlock(32, 3, (9, 9), 1, 4, use_activation=False, use_instance_norm=False),  # 3x256x256
        ]

        self.model = nn.Sequential(*encoder, *residuals, *decoder)

    def forward(self, x):
        """Run ``x`` through the whole stack."""
        stylised = self.model(x)
        return stylised

In [None]:

def gram_matrix(y):
    """Return the normalised Gram matrix of every feature map in a batch.

    Parameters:
        y (torch.Tensor): 4-D tensor, unpacked as (batch, channels, height, width).

    Returns:
        torch.Tensor of shape (batch, channels, channels): channel-by-channel
        inner products over all spatial positions, divided by
        channels * height * width.
    """
    (b, ch, h, w) = y.size()
    # reshape (not view) so that non-contiguous inputs, e.g. cropped or
    # transposed feature maps, are accepted; it still avoids a copy when the
    # memory layout allows it.
    features = y.reshape(b, ch, h * w)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram


def normalize_batch(batch):
    """Scale a batch by 1/255 and normalise it with the ImageNet mean and std.

    Parameters:
        batch (torch.Tensor): tensor whose channel dimension (size 3) is third
            from the end, e.g. (batch, 3, height, width).

    Returns:
        torch.Tensor: ``(batch / 255 - mean) / std`` as a new tensor; the input
        is left unmodified.
    """
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    # Divide out of place: an in-place div_ would silently rescale the
    # caller's tensor as a side effect.
    scaled = batch / 255.0
    return (scaled - mean) / std


def tv_loss(img):
    """Total-variation penalty of a 4-D image batch.

    Takes the mean absolute difference between neighbours along dim 2 and the
    mean absolute difference between neighbours along dim 3, and returns the
    average of the two as a scalar tensor.
    """
    along_rows = img[:, :, 1:, :] - img[:, :, :-1, :]
    along_cols = img[:, :, :, 1:] - img[:, :, :, :-1]
    rows_mean = along_rows.abs().mean()
    cols_mean = along_cols.abs().mean()
    return 0.5 * (rows_mean + cols_mean)



class Vgg16(nn.Module):
    """Frozen VGG16 feature stack that returns selected intermediate activations.

    ``style_layers`` and ``content_layers`` hold the child names (as strings)
    of ``self.vgg16`` whose outputs ``forward`` collects.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; left unchanged so older installs keep working.
        self.vgg16 = torchvision.models.vgg16(pretrained=True).features

        self.style_layers = ['3', '8', '15', '22']
        self.content_layers = ['15']

        # The network is only used as a fixed feature extractor.
        for param in self.vgg16.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return ``{layer_name: activation}`` for every requested layer.

        Entries appear in the order the layers are executed. Evaluation stops
        as soon as all requested layers have been collected, so layers after
        the last requested one are never run.
        """
        # Build the lookup once per call instead of concatenating the two
        # lists on every loop iteration.
        wanted = set(self.style_layers) | set(self.content_layers)
        layer_features = {}
        if not wanted:
            return layer_features

        for name, layer in self.vgg16.named_children():
            x = layer(x)
            if name in wanted:
                layer_features[name] = x
                # Derive the stopping point from the layer lists rather than a
                # hard-coded index, so editing the lists cannot leave the two
                # out of sync.
                if len(layer_features) == len(wanted):
                    break

        return layer_features

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    """Dataset that loads the images listed in a DataFrame as float tensors.

    Parameters:
        data (pandas.DataFrame): table whose first column holds image paths.
        gray (bool): forwarded to ``load_image``.
    """

    def __init__(self, data, gray=False):
        self.data = data
        self.gray = gray

    def __getitem__(self, index):
        """Load the image of row ``index`` and return it as a float32 tensor.

        The array returned by ``load_image`` is permuted with (2, 0, 1), i.e.
        its last axis becomes the first one.
        """
        # Purely positional lookup (row `index`, first column). Chaining
        # `.iloc[index][0]` would index the row Series with the integer 0,
        # which pandas deprecates when the labels are column names.
        image_path = self.data.iloc[index, 0]
        image_arr = load_image(image_path, gray=self.gray)
        image_arr = torch.from_numpy(image_arr)
        image_arr = image_arr.permute((2, 0, 1))
        return image_arr.float()

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.data)
    
    
def load_image(image_path: str, gray=False, size=(256, 256)):
    '''Loads an image using opencv, converts it to 3-channel RGB order instead
    of opencv's BGR, resizes it and returns it.

    Parameters:
        image_path (str): path to the image
        gray (bool): if True, convert to grayscale and replicate the single
            channel three times along the last axis
        size (tuple): target size handed to cv2.resize, which takes
            (width, height); defaults to (256, 256)

    Returns:
        image_arr (numpy array): the resized image, channels last

    Raises:
        OSError: if opencv could not read the file
    '''
    image_arr = cv2.imread(image_path)
    if image_arr is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an obscure error inside cvtColor.
        raise OSError(f'could not read image: {image_path!r}')

    if gray:
        image_arr = cv2.cvtColor(image_arr, cv2.COLOR_BGR2GRAY)
        # Back to three identical channels so downstream code sees the same layout.
        image_arr = np.stack([image_arr, image_arr, image_arr], axis=2)
    else:
        image_arr = cv2.cvtColor(image_arr, cv2.COLOR_BGR2RGB)
    image_arr = cv2.resize(image_arr, size)

    return image_arr


def denormalize_img(img_arr):
    """Undo the ImageNet mean/std normalisation of a single image tensor.

    The tensor is moved to the CPU and detached, multiplied by the per-channel
    std and shifted by the per-channel mean (channel axis first), then
    permuted with (1, 2, 0) and returned as a numpy array.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    restored = (img_arr.cpu().detach() * imagenet_std) + imagenet_mean
    channels_last = restored.permute(1, 2, 0)
    return channels_last.numpy()

In [None]:
# Loss weights for the three terms of the objective.
content_weight = 1e4
style_weight = 5e9
tv_weight = 8.5e-5

STYLE_IMAGE = 'scream.jpg'
BATCH_SIZE = 4

# Fall back to the CPU so the cell still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = ImageDataset(images)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)

transformer = TransformerNet().to(device)
optimizer = torch.optim.Adam(transformer.parameters(), 1e-3)

vgg = Vgg16().to(device)

# STYLE_IMAGE already carries its extension; appending another '.jpg' would
# point at 'style/scream.jpg.jpg'.
style = torch.from_numpy(load_image(f'style/{STYLE_IMAGE}')).permute(2, 0, 1).unsqueeze(0).float()
style = style.repeat(BATCH_SIZE, 1, 1, 1).to(device)

# Gram matrices of the style image, one per returned VGG layer, computed once.
features_style = vgg(normalize_batch(style))
gram_style = [gram_matrix(ft) for ft in features_style.values()]


for epoch in range(1):

    transformer.train()

    for batch_id, x in enumerate(train_loader):
        n_batch = len(x)
        optimizer.zero_grad()
        x = x.to(device)
        y = transformer(x)

        y = normalize_batch(y)
        x = normalize_batch(x)

        features_y = vgg(y)
        features_x = vgg(x)

        # NOTE(review): the content loss reads layer '8' while
        # Vgg16.content_layers is ['15'] - confirm which one is intended.
        content_loss = content_weight * F.mse_loss(features_y['8'], features_x['8'])

        style_loss = 0.
        for ft_y, gm_s in zip(features_y.values(), gram_style):
            gm_y = gram_matrix(ft_y)
            # The last batch of the epoch can hold fewer than BATCH_SIZE
            # images; slice the style targets so both operands have the same
            # batch dimension.
            style_loss += F.mse_loss(gm_y, gm_s[:n_batch])
        style_loss *= style_weight

        total_variation_loss = tv_loss(y)
        total_variation_loss *= tv_weight

        total_loss = content_loss + style_loss + total_variation_loss

        total_loss.backward()
        optimizer.step()

        print(f'{batch_id}, {total_loss.item()}')
        if batch_id % 10 == 0:
            # Preview: stylised output (left) next to the content image (right).
            fig, axs = plt.subplots(1, 2, figsize=(10, 7))
            # The network output is unbounded, so clip to the [0, 1] range
            # that imshow expects for float images.
            axs[0].imshow(np.clip(denormalize_img(y[0]), 0, 1))
            axs[1].imshow(np.clip(denormalize_img(x[0]), 0, 1))
            axs[0].set_title('stylised')
            axs[1].set_title('content')
            axs[0].axis('off')
            axs[1].axis('off')

            plt.show()
            # Release the figure; many previews are created over an epoch.
            plt.close(fig)


In [None]:
# torch.save does not create missing parent folders, so make sure the target
# directory exists before writing the checkpoint.
os.makedirs('models', exist_ok=True)
torch.save(transformer.state_dict(), f'models/{STYLE_IMAGE}.pth')