### Task: Neural Style Transfer 

#### For this task, I have selected two datasets. First, the dataset of best artworks of all time, and other is the dataset of Images of Dragon Ball Z characters. I will choose a style image from the art dataset and train the CNN network to transfer its style to the anime characters' images.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import cv2

  from .autonotebook import tqdm as notebook_tqdm


### Creating the Dataset

In [2]:
image_size=224

In [3]:
class ContentDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform

        files = os.listdir(folder_path)
        image_files = [file for file in files if file.lower().endswith(('.png', '.jpg', '.jpeg'))]

        self.images = [Image.open(os.path.join(folder_path, image_file)) for image_file in image_files]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]

        if self.transform:
            image = self.transform(image)

        return image

# Define a transformation to be applied to the images
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
]) 

# Create a custom dataset
dataset = ContentDataset(folder_path="C:\\Users\\awast\\Downloads\\archive_3", transform=transform)

In [4]:
dataset_size=len(dataset)
dataset_size

3145

#### Importing Pretrained VGG network

In [5]:
import torchvision.models as models


In [6]:
# print(model)

In [7]:
## we are going to select a few layers to calculate the content and style loss
#layers_list=['0','5','10','19','28']

class VGG(nn.Module):
    def __init__(self):
        super(VGG,self).__init__()
        
        self.model=models.vgg19(pretrained=True).features[:29]
        self.layers_list=['0','5','10','19','28']
        
    def forward(self, x):
        features=[]
        
        for layer_numb,layer in enumerate(self.model):
            x=layer(x)
            if str(layer_numb) in self.layers_list:
                features.append(x)
        
        return features
        

In [8]:
loader = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  (Original Normalization constansts for the VGG network)
    ]
)

def plot_image(tensor):
#     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
#         _log_api_usage_once(save_image)
#     grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    plt.imshow(im)

def load_image(image):
    image = loader(image).unsqueeze(0)
    return image.to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# image_size = 224



# original_img = dataset[0]
# style_img = load_image(styles[2])

#### I have used an encoder decoder architecture, which will take the original image as the input and output the stylized image as the output

In [9]:
import torch
import torch.nn as nn

class EncodeDecoder(nn.Module):
   
    def __init__(self):
        super(EncodeDecoder, self).__init__()
        self.ConvBlock = nn.Sequential(
            ConvLayer(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
            nn.ReLU()
        )
        self.ResidualBlock = nn.Sequential(
            ResidualLayer(128, 3), 
            ResidualLayer(128, 3), 
            ResidualLayer(128, 3), 
            ResidualLayer(128, 3), 
            ResidualLayer(128, 3)
        )
        self.DeconvBlock = nn.Sequential(
            DeconvLayer(128, 64, 3, 2, 1),
            nn.ReLU(),
            DeconvLayer(64, 32, 3, 2, 1),
            nn.ReLU(),
            ConvLayer(32, 3, 9, 1, norm="None")
        )

    def forward(self, x):
        x = self.ConvBlock(x)
        x = self.ResidualBlock(x)
        out = self.DeconvBlock(x)
        return out


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"):
        super(ConvLayer, self).__init__()
        # Padding Layers
        padding_size = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding_size)

        # Convolution Layer
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        # Normalization Layers
        self.norm_type = norm
        if (norm=="instance"):
            self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif (norm=="batch"):
            self.norm_layer = nn.BatchNorm2d(out_channels, affine=True)

    def forward(self, x):
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        if (self.norm_type=="None"):
            out = x
        else:
            out = self.norm_layer(x)
        return out

class ResidualLayer(nn.Module):
    
    def __init__(self, channels=128, kernel_size=3):
        super(ResidualLayer, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1)

    def forward(self, x):
        identity = x                     # preserve residual
        out = self.relu(self.conv1(x))   # 1st conv layer + activation
        out = self.conv2(out)            # 2nd conv layer
        out = out + identity             # add residual
        return out

class DeconvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"):
        super(DeconvLayer, self).__init__()

        # Transposed Convolution 
        padding_size = kernel_size // 2
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding_size, output_padding)

        # Normalization Layers
        self.norm_type = norm
        if (norm=="instance"):
            self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif (norm=="batch"):
            self.norm_layer = nn.BatchNorm2d(out_channels, affine=True)

    def forward(self, x):
        x = self.conv_transpose(x)
        if (self.norm_type=="None"):
            out = x
        else:
            out = self.norm_layer(x)
        return out

In [10]:
class VGG16(nn.Module):
    def __init__(self, vgg_path="vgg16-00b39a1b.pth"):
        super(VGG16, self).__init__()
        # Load VGG Skeleton, Pretrained Weights
        vgg16_features = models.vgg16(pretrained=False)
        vgg16_features.load_state_dict(torch.load(vgg_path), strict=False)
        self.features = vgg16_features.features

        # Turn-off Gradient History
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        layers = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3'}
        features = {}
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
                if (name=='22'):
                    break

        return features

In [11]:
# train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

#### Calculating Gram Matrix to quantify style loss

In [12]:
# Gram Matrix
def gram(tensor):
    B, C, H, W = tensor.shape
    x = tensor.view(B, C, H*W)
    x_t = x.transpose(1, 2)
    return  torch.bmm(x, x_t) / (C*H*W)

In [13]:

# Show image
def show(img):
    # Convert from BGR to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # imshow() only accepts float [0,1] or int [0,255]
    img = np.array(img/255).clip(0,1)
    
    plt.figure(figsize=(10, 5))
    plt.imshow(img)
    plt.show()

def saveimg(img, image_path):
    img = img.clip(0, 255)
    cv2.imwrite(image_path, img)

# Preprocessing ~ Image to Tensor
def itot(img, max_size=None):
    # Rescale the image
    if (max_size==None):
        itot_t = transforms.Compose([
            #transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])    
    else:
        H, W, C = img.shape
        image_size = tuple([int((float(max_size) / max([H,W]))*x) for x in [H, W]])
        itot_t = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])

    # Convert image to tensor
    tensor = itot_t(img)

    # Add the batch_size dimension
    tensor = tensor.unsqueeze(dim=0)
    return tensor

# Preprocessing ~ Tensor to Image
def ttoi(tensor):
    
    tensor = tensor.squeeze()
    #img = ttoi_t(tensor)
    img = tensor.cpu().numpy()
    
    # Transpose from [C, H, W] -> [H, W, C]
    img = img.transpose(1, 2, 0)
    return img


In [14]:
# plt.imshow(styles[3])

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import random
import numpy as np
import time


TRAIN_IMAGE_SIZE = 224
DATASET_PATH = "C:\\Users\\awast\\Downloads\\dragon_ball"
NUM_EPOCHS = 5
# STYLE_IMAGE_PATH = "images/mosaic.jpg"
style_datapath="C:\\Users\\awast\\Downloads\\artworks"
files = os.listdir(style_datapath)
image_files = [file for file in files if file.lower().endswith(('.png', '.jpg', '.jpeg'))]

styles = [Image.open(os.path.join(style_datapath, image_file)) for image_file in image_files]


BATCH_SIZE = 4 
CONTENT_WEIGHT = 10
STYLE_WEIGHT = 1
ADAM_LR = 0.001
SAVE_MODEL_PATH = "C:/Users/awast/dashtoon/models/"
SAVE_IMAGE_PATH = "C:/Users/awast/dashtoon/images/out/"
SAVE_MODEL_EVERY = 100 # 2,000 Images with batch size 4
SEED = 35
PLOT_LOSS = 1

def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
#         transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # Load networks
    TransformerNetwork1 = EncodeDecoder().to(device)
    VGG = VGG16().to(device)

    imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).to(device)
    style_image = styles[4]
    style_tensor = itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork1.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch+1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:,[2,1,0]].to(device)
            generated_batch = TransformerNetwork1(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu2_2'], content_features['relu2_2'])            
            batch_content_loss_sum += content_loss

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(gram(value), style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            print(style_loss)
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count-1)%SAVE_MODEL_EVERY == 0) or (batch_count==NUM_EPOCHS*len(train_loader))):
                # Print Losses
                print("========Iteration {}/{}========".format(batch_count, NUM_EPOCHS*len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum/batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum/batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum/batch_count))
                print("Time elapsed:\t{} seconds".format(time.time()-start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(batch_count-1) + ".pth"
                torch.save(TransformerNetwork1.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = sample_tensor.clone().detach().squeeze().cpu().numpy().transpose(1,2,0)
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(batch_count-1) + ".png"
#                 np.save(sample_image, sample_image_path)
#                 utils.saveimg(sample_image, sample_image_path)
                sample_image = sample_image.clip(0, 255)
                cv2.imwrite(sample_image_path, sample_image)
                print("Saved sample tranformed image at {}".format(sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_total_loss_sum/batch_count)
                style_loss_history.append(batch_style_loss_sum/batch_count)
                total_loss_history.append(batch_total_loss_sum/batch_count)

            # Iterate Batch Counter
            batch_count+=1


    # Save TransformerNetwork weights
    TransformerNetwork1.eval()
    TransformerNetwork1.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork1.state_dict(), final_path)


In [None]:
train()



tensor(3084150.2500, device='cuda:0', grad_fn=<MulBackward0>)
	Content Loss:	628.96
	Style Loss:	3084150.25
	Total Loss:	3084779.25
Time elapsed:	0.6199982166290283 seconds
Saved TransformerNetwork checkpoint file at C:/Users/awast/dashtoon/models/checkpoint_0.pth
Saved sample tranformed image at C:/Users/awast/dashtoon/images/out/sample0_0.png
tensor(3050835., device='cuda:0', grad_fn=<MulBackward0>)
tensor(3012108.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2969441.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2926210.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2890251.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2844592., device='cuda:0', grad_fn=<MulBackward0>)
tensor(2800557., device='cuda:0', grad_fn=<MulBackward0>)
tensor(2731735.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2700549.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2687741., device='cuda:0', grad_fn=<MulBackward0>)
tensor(2681563., device='cuda:0', grad_fn=<MulBac

tensor(1987772.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2018502.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2025779.6250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1986368.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2007494., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1965761.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1971101.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1951400.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1980376.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(2012253., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1986196.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1959255.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1991935.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1983224.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1981140.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1996721.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1967741.5

tensor(1791236.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1708350.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1743488.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1769300.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1748281.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1741592.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1777139.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1739145.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1727909.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1735427.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1762792., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1740666.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1706101.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1718616.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1714055.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1730870.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(17392

tensor(1601839.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1592740.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1599640., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1622984.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1643923.6250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1636341.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1595251.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1617130.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1604843.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1639093.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1629182.1250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1600806.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1578248.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1607383.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1588572.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1596396.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(16157

tensor(1571141.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1574780.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1583520.3750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1562287.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1548270.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1558366.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1535390.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1579666.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1582967.5000, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1550784., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1571886.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1571067., device='cuda:0', grad_fn=<MulBackward0>)
tensor(1579177.6250, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1567447.7500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1574515.2500, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1562440.8750, device='cuda:0', grad_fn=<MulBackward0>)
tensor(1578631.5

In [None]:
# train(total_steps=10000,alpha=0.1,beta=0.01,optimizer=optimizer)