In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import zipfile
import os
import shutil
from pathlib import Path
import torch.nn as nn

In [None]:
import torch

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

# Implementing the AdaIN Image Style Transfer Paper

This notebook recreates the paper [Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) (AdaIN).


I re-implement their training setup: style images come from the Painter by Numbers dataset, and content images come from MS-COCO 2014.

Later in this notebook, I plan to try distributed training.

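Another interesting aspect of this paper is that no ground-truth stylized images are needed. Training only requires a collection of content images and a collection of style images, which are paired at random.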

In [None]:
# --- Optimisation hyper-parameters ---
batch_size = 8
ori_lr = 1e-4          # base learning rate
lr_decay = 5e-5        # adjust_learning_rate uses ori_lr / (1 + lr_decay * iteration)
max_iter = 70000
save_iter = 10000      # checkpoint interval (iterations)
content_weight = 1.0
style_weight = 10.0
img_size = 512         # NOTE(review): train_transform hard-codes 512 and does not read this; keep in sync

# --- Data locations ---
# NOTE(review): content_zip / content_dir_name are only referenced by the commented-out extraction call;
# content images are read directly from the COCO folder in content_dir.
content_zip = "/kaggle/input/painter-by-numbers/train_1.zip"
style_zip = "/kaggle/input/painter-by-numbers/test.zip"

content_dir_name = "content"
style_dir_name = "style"

content_dir = "/kaggle/input/coco-2014-dataset-for-yolov3/coco2014/images/test2014"
style_dir = os.path.join(os.getcwd(), style_dir_name)
# NOTE(review): not used by the training loop, which writes checkpoints to the working directory.
save_dir = Path.cwd() / 'save'

In [None]:
#Get the images out of zip file

def extract_images_from_zip(zip_path, output_dir):
    """Extract every image in a ZIP archive into one flat directory.

    Folder structure inside the archive is discarded: ``a/b/img.jpg`` is written
    to ``output_dir/img.jpg``. Each member is streamed straight to its final
    path, so no intermediate sub-directories are created or left behind, and
    re-running the function simply overwrites the previously extracted files.

    Args:
        zip_path: path to the ``.zip`` archive.
        output_dir: destination directory; created if it does not exist.

    Note:
        Members that share a base name overwrite each other (last one wins).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Common image extensions (compared case-insensitively).
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for member in zip_ref.infolist():
            if member.is_dir() or not member.filename.lower().endswith(image_extensions):
                continue
            # basename() flattens the archive layout and also drops any "../" components.
            final_path = os.path.join(output_dir, os.path.basename(member.filename))
            with zip_ref.open(member) as src, open(final_path, 'wb') as dst:
                shutil.copyfileobj(src, dst)

    print(f"\nExtraction complete! Images saved to: {output_dir}")

# Only the style archive is extracted; content images are read directly from content_dir (COCO).
extract_images_from_zip(zip_path=style_zip, output_dir=style_dir_name)

In [None]:
#transforms
from torchvision import transforms
def train_transform():
    """Training-time preprocessing pipeline.

    Resizes to a fixed 512x512 (aspect ratio is not preserved), takes a random
    256x256 crop and converts to a float tensor. No normalization is applied,
    so RGB inputs become (3, 256, 256) tensors with values in [0, 1].
    """
    return transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ])


# Content and style images use separate (identical) pipeline instances.
content_tf = train_transform()
style_tf = train_transform()

In [None]:
#first we need to process the Dataset using torch.utils.Dataset
import glob
import torch
from torch.utils.data import Dataset
from PIL import Image

class Style_Dataset(Dataset):
    """Flat image-folder dataset that returns transformed RGB tensors.

    Every regular file directly inside ``img_dir`` (non-recursive; dotfiles are
    skipped by ``glob``) is treated as an image. Oversized or unreadable files
    are replaced by a fallback sample so training is not interrupted.

    Args:
        img_dir: directory containing the images.
        transform: callable applied to each PIL RGB image (e.g. ``train_transform()``).
        fallback_path: image returned in place of oversized/unreadable files. It is
            loaded and transformed once at construction, so every substitution is
            the same tensor. Pass ``None`` to return ``None`` instead; ``collate_fn``
            drops ``None`` samples from a batch.
    """

    # Pillow's default decompression-bomb threshold (Image.MAX_IMAGE_PIXELS).
    MAX_PIXELS = 89_478_485

    def __init__(self, img_dir, transform,
                 fallback_path="/kaggle/input/coco-2014-dataset-for-yolov3/coco2014/images/train2014/COCO_train2014_000000000009.jpg"):
        self.transform = transform
        # Sorted so the index -> file mapping is reproducible across runs and filesystems.
        self.img_files = sorted(f for f in glob.glob(f"{img_dir}/*") if os.path.isfile(f))
        if fallback_path is None:
            self.extra_image = None
        else:
            with Image.open(fallback_path) as fallback:
                self.extra_image = self.transform(fallback.convert("RGB"))

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        try:
            with Image.open(img_path) as img:
                # Image.open reads only the header, so the size check happens before decoding pixels.
                if img.size[0] * img.size[1] > self.MAX_PIXELS:
                    print(f"Skipping large image: {img_path}, size: {img.size}")
                    return self.extra_image
                return self.transform(img.convert("RGB"))
        except Exception as e:
            # Best-effort: one bad file must not abort a long training run.
            print(f"Skipping corrupt file: {img_path}, Error: {e}")
            return self.extra_image

    def __len__(self):
        return len(self.img_files)

In [None]:
import torch

def collate_fn(batch):
    """Stack the non-None samples of ``batch`` into one tensor.

    Returns ``None`` when no usable sample remains (empty batch or all ``None``).
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    return torch.stack(kept) if kept else None

In [None]:
# Sanity-check the datasets by showing the first content and the first style sample.
import matplotlib.pyplot as plt
content_dataset = Style_Dataset(content_dir, content_tf)
style_dataset = Style_Dataset(style_dir, style_tf)

# CHW tensors -> HWC arrays for matplotlib.
content_img = content_dataset[0].permute(1, 2, 0).numpy()
style_img = style_dataset[0].permute(1, 2, 0).numpy()

plt.figure(figsize=(10, 5))
for position, (image, title) in enumerate([(content_img, "Content Image"), (style_img, "Style Image")], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")

In [None]:
num_content_images = len(content_dataset)  # number of content images available for training
num_content_images

In [None]:
#Infinite Sampler which will be used in the dataloader
from torch.utils.data.sampler import Sampler

# An Infinite Sampler is a type of PyTorch Sampler that 
# provides an endless stream of shuffled data indices. Unlike
# standard samplers that stop after one full pass through the dataset
# ,an infinite sampler continuously reshuffles and reuses the dataset
# without explicitly restarting an epoch.

def InfiniteSampler(n):
    """Yield dataset indices forever, reshuffling after every full pass.

    Each pass is a fresh ``np.random.permutation(n)``, so every index appears
    exactly once per pass. Uses NumPy's global RNG; seed ``np.random`` for
    reproducible order.

    Args:
        n: number of items in the dataset.

    Yields:
        NumPy integer indices in ``[0, n)``.

    Raises:
        ValueError: if ``n`` is not positive (raised on the first ``next()``,
            because this is a generator).
    """
    if n <= 0:
        raise ValueError(f"InfiniteSampler needs a non-empty dataset, got n={n}")
    while True:
        yield from np.random.permutation(n)

class InfiniteSamplerWrapper(Sampler):
    """``Sampler`` adapter that feeds ``InfiniteSampler`` indices to a ``DataLoader``.

    Only the dataset length is stored. Iteration never ends; ``__len__``
    reports a large nominal value.
    """

    def __init__(self, data_source):
        # Only the size is needed to draw permutations.
        self.num_samples = len(data_source)

    def __iter__(self):
        yield from InfiniteSampler(self.num_samples)

    def __len__(self):
        # Nominal "infinite" length (2**31).
        return 1 << 31


In [None]:
# Two independent endless loaders, so content and style images are paired at random.
from torch.utils.data import DataLoader


def _infinite_batches(dataset):
    """Return an endless iterator of collated batches drawn from ``dataset``."""
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=InfiniteSamplerWrapper(dataset),
        collate_fn=collate_fn,
    )
    return iter(loader)


content_iter = _infinite_batches(content_dataset)
style_iter = _infinite_batches(style_dataset)

In [None]:
#Adaptive Instance Normalization
def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation of a 4-D feature map.

    Args:
        feat: tensor of shape (N, C, H, W); need not be contiguous.
        eps: added to the variance before the square root so the std is never
            zero (avoids divide-by-zero in ``adain``).

    Returns:
        (mean, std), each of shape (N, C, 1, 1) so they broadcast against ``feat``.
        The variance is the unbiased estimate computed by ``Tensor.var``.
    """
    size = feat.size()
    assert len(size) == 4, f"expected a 4-D (N, C, H, W) tensor, got shape {tuple(size)}"
    N, C = size[:2]
    # reshape (unlike view) also accepts non-contiguous inputs such as transposed maps.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adain(content_feat, style_feat):
    """Adaptive Instance Normalization.

    Normalizes each channel of ``content_feat`` to zero mean and unit std, then
    rescales it with the per-channel std and mean of ``style_feat``.

    Args:
        content_feat: (N, C, H, W) content features.
        style_feat: (N, C, H', W') style features; N and C must match, spatial
            size may differ.

    Returns:
        Tensor with the same shape as ``content_feat``.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2], "batch/channel dims of content and style must match"
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # (N, C, 1, 1) statistics broadcast over the spatial dimensions.
    normalized_feat = (content_feat - content_mean) / content_std
    return normalized_feat * style_std + style_mean

In [None]:
# VGG-19 encoder: convolutional part only (no fully-connected layers).
# The conv blocks have 2, 2, 4, 4, 4 layers with 64/128/256/512/512 channels,
# which is VGG-19 (not VGG-16). ReflectionPad2d(1) + unpadded 3x3 conv keeps
# the spatial size within a block; each 2x2 max-pool halves it.
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),  # 1x1 conv on the raw image; NOTE(review): presumably encodes the input normalisation of the "normalised" VGG checkpoint — confirm
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)
# Load pretrained weights. nn.Sequential state-dict keys are positional
# ("0.weight", "2.weight", ...), so the layer list above must not be reordered.
# NOTE(review): no map_location — a checkpoint holding CUDA tensors would need one on a CPU-only machine.
vgg.load_state_dict(torch.load("/kaggle/input/vgg_normalised/pytorch/default/1/vgg_normalised.pth"))
# Keep children 0..30, i.e. everything up to and including relu4-1; deeper layers are unused.
vgg = nn.Sequential(*list(vgg.children())[:31])

In [None]:
# Decoder: maps relu4_1 features (512 channels) back to an RGB image.
# It roughly mirrors the encoder, with max-pooling replaced by nearest-neighbour
# upsampling (x2, three times), so an (N, 512, H, W) input becomes (N, 3, 8H, 8W).
# The final conv has no activation, so outputs are not bounded to [0, 1].
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)
# Uncomment to start from pretrained decoder weights instead of random initialisation.
# decoder.load_state_dict(torch.load("/kaggle/input/decoder_for_image_stle_transfer/pytorch/default/1/decoder.pth"))

In [None]:
#Entire model with encoder and decoder
import torch.optim.swa_utils as swa_utils

class network(nn.Module):
    """AdaIN style-transfer network: frozen encoder stages plus a trainable decoder.

    The encoder's children are split into four stages ending at relu1_1,
    relu2_1, relu3_1 and relu4_1 (child indices 4, 11, 18, 31). Encoder
    parameters are frozen, so only ``decoder`` is trained. ``forward`` returns
    the content and style losses.
    """

    # [start, stop) child-index ranges of the four encoder stages.
    _ENC_SLICES = ((0, 4), (4, 11), (11, 18), (18, 31))

    def __init__(self, encoder, decoder):
        super().__init__()
        enc_layers = list(encoder.children())
        # enc_1: input -> relu1_1, enc_2: -> relu2_1, enc_3: -> relu3_1, enc_4: -> relu4_1
        for stage_no, (start, stop) in enumerate(self._ENC_SLICES, start=1):
            setattr(self, f'enc_{stage_no}', nn.Sequential(*enc_layers[start:stop]))
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()

        # The encoder is a fixed feature extractor.
        for stage in self._encoder_stages():
            for param in stage.parameters():
                param.requires_grad = False

    def _encoder_stages(self):
        """The four encoder stages, in forward order."""
        return [self.enc_1, self.enc_2, self.enc_3, self.enc_4]

    def encode_with_intermediate(self, input):
        """Return the [relu1_1, relu2_1, relu3_1, relu4_1] features of ``input``."""
        feats = []
        x = input
        for stage in self._encoder_stages():
            x = stage(x)
            feats.append(x)
        return feats

    def encode(self, input):
        """Return only the relu4_1 features of ``input``."""
        x = input
        for stage in self._encoder_stages():
            x = stage(x)
        return x

    def calc_content_loss(self, input, target):
        """MSE between generated features and a target that must not require grad."""
        assert input.size() == target.size()
        assert not target.requires_grad
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """MSE of channel-wise means plus MSE of channel-wise stds of two feature maps."""
        assert input.size() == target.size()
        assert not target.requires_grad
        (in_mean, in_std), (tg_mean, tg_std) = calc_mean_std(input), calc_mean_std(target)
        return self.mse_loss(in_mean, tg_mean) + self.mse_loss(in_std, tg_std)

    def forward(self, content, style, alpha=1.0):
        """Compute ``(content_loss, style_loss)`` for a batch of content/style images.

        ``alpha`` in [0, 1] blends the AdaIN target with the raw content features
        (1.0 = full stylization).
        """
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        t = alpha * adain(content_feat, style_feats[-1]) + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # Content loss on relu4_1 against the AdaIN target; style loss summed over all four levels.
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for gen_feat, sty_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s = loss_s + self.calc_style_loss(gen_feat, sty_feat)
        return loss_c, loss_s

In [None]:
# Move encoder and decoder to the target device, then wrap them in the training network.
vgg = vgg.to(device)
decoder = decoder.to(device)
model = network(vgg, decoder).to(device)

In [None]:
#optimizer
import torch.optim.lr_scheduler as lr_scheduler

# Only the decoder is optimised; the encoder inside `model` is frozen.
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=ori_lr)


def adjust_learning_rate(optimizer, iteration_count):
    """Set every param group's LR to ``ori_lr / (1 + lr_decay * iteration_count)``.

    Inverse-time decay, imitating the original implementation.
    """
    new_lr = ori_lr / (1.0 + lr_decay * iteration_count)
    for group in optimizer.param_groups:
        group.update(lr=new_lr)

In [None]:
# Loss history for plotting; the training loop appends to these lists.
content_losses, style_losses, total_losses, iterations = [], [], [], []

In [None]:
# Adaptive gradient clipping, tried because the loss was oscillating too much.
def adaptive_clip_grad(parameters, clip_factor=0.01, epsilon=1e-3):
    """Rescale gradients in place so that ||grad|| <= clip_factor * max(||param||, epsilon).

    Norms are L2 norms over the whole tensor (not per unit). Gradients already
    under the limit are left unchanged, and parameters without a gradient are
    skipped. The epsilon floor lets zero-valued parameters still receive
    non-zero updates.
    """
    for param in filter(lambda q: q.grad is not None, parameters):
        max_norm = clip_factor * torch.norm(param, p=2).clamp(min=epsilon)
        scale = (max_norm / torch.norm(param.grad, p=2)).clamp(max=1.0)
        param.grad *= scale

In [None]:
import torch
import random
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

def tensor_to_image(tensor):
    """Convert an image tensor into an (H, W, C) NumPy array for ``plt.imshow``.

    Args:
        tensor: image of shape (C, H, W) or batch of shape (N, C, H, W), in which
            case the first image is used. It may live on the GPU and may be
            attached to an autograd graph; it is never modified.

    Returns:
        float NumPy array of shape (H, W, C) with values clipped to [0, 1].
    """
    image = tensor.detach().cpu()
    if image.dim() == 4:
        # Take the first image of the batch: (N, C, H, W) -> (C, H, W).
        image = image[0]
    # No denormalization is needed: the training transform has no Normalize step.
    # The decoder's output is unbounded and imshow clips float RGB to [0, 1] anyway
    # (with a warning), so clip explicitly here.
    image = image.clamp(0.0, 1.0)
    # Channels-last layout expected by matplotlib: (C, H, W) -> (H, W, C).
    return image.permute(1, 2, 0).numpy()
    
def save_one_test(model, i):
    """Stylize one random content/style pair and save the comparison to ``output{i}.png``.

    Random samples are taken from ``content_dataset`` and ``style_dataset`` (so
    they go through the training transform, including its random crop). They are
    passed through encoder -> AdaIN -> decoder without gradients and plotted as
    content | style | output. The figure is saved at 300 dpi and then closed, so
    repeated calls during training do not accumulate open figures.

    Args:
        model: the ``network`` instance (uses ``model.encode`` and ``model.decoder``).
        i: iteration number, used only in the output file name.
    """
    content_image = content_dataset[random.randint(0, len(content_dataset) - 1)]
    style_image = style_dataset[random.randint(0, len(style_dataset) - 1)]

    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device
    content_batch = content_image.unsqueeze(0).to(device)  # add batch dimension
    style_batch = style_image.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model.decoder(adain(model.encode(content_batch), model.encode(style_batch)))

    panels = ((content_batch, "Content Image"), (style_batch, "Style Image"), (output, "Stylized Output"))
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (tensor, title) in zip(axes, panels):
        ax.imshow(tensor_to_image(tensor))
        ax.set_title(title)
        ax.axis("off")

    fig.savefig(f"output{i}.png", dpi=300, bbox_inches='tight')
    plt.close(fig)


save_one_test(model, 1)

In [None]:
# Training loop: resumes from a saved checkpoint and continues the iteration count.
from tqdm import tqdm

# Iteration of the checkpoint being resumed. Iteration numbers (used by the LR
# schedule, log lines and checkpoint names) continue from here.
start_iter = 130000
end_iter = start_iter + max_iter  # inclusive: the loop runs max_iter + 1 steps
resume_path = "/kaggle/input/trained_1.3l/pytorch/default/1/network_130000.pth"
use_grad_clip = False  # set True to apply adaptive_clip_grad to the decoder gradients

# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
model.load_state_dict(torch.load(resume_path, map_location=device))
# NOTE: only model weights are restored; the Adam optimizer state starts fresh.
model.train()
for i in tqdm(range(start_iter, end_iter + 1)):
    optimizer.zero_grad()
    adjust_learning_rate(optimizer, iteration_count=i)
    content_image = next(content_iter).to(device)
    style_image = next(style_iter).to(device)
    content_loss, style_loss = model(content_image, style_image)
    loss = content_weight * content_loss + style_weight * style_loss
    loss.backward()
    if use_grad_clip:
        adaptive_clip_grad(model.decoder.parameters())
    optimizer.step()

    if i % 50 == 0:
        # Store loss values for the loss plot.
        content_losses.append(content_loss.item())
        style_losses.append(style_loss.item())
        total_losses.append(loss.item())
        iterations.append(i)

    if i % 100 == 0:
        tqdm.write(f'Iteration: {i}, Content Loss: {content_loss.item()}, Style Loss: {style_loss.item()}')
        save_one_test(model, i)
    # Checkpoint periodically, and always at the final iteration.
    if i % save_iter == 0 or i == end_iter:
        torch.save(model.state_dict(), f'network_{i}.pth')



In [None]:
# Set model to evaluation mode
model.eval()


def show_one_test(model, i):
    """Stylize one random content/style pair and display the comparison inline.

    Random samples are taken from ``content_dataset`` and ``style_dataset`` (so
    they go through the training transform, including its random crop). They are
    passed through encoder -> AdaIN -> decoder without gradients and shown as
    content | style | output.

    Args:
        model: the ``network`` instance (uses ``model.encode`` and ``model.decoder``).
        i: unused; kept so the signature matches ``save_one_test``.
    """
    content_image = content_dataset[random.randint(0, len(content_dataset) - 1)]
    style_image = style_dataset[random.randint(0, len(style_dataset) - 1)]

    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device
    content_batch = content_image.unsqueeze(0).to(device)  # add batch dimension
    style_batch = style_image.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model.decoder(adain(model.encode(content_batch), model.encode(style_batch)))

    panels = ((content_batch, "Content Image"), (style_batch, "Style Image"), (output, "Stylized Output"))
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (tensor, title) in zip(axes, panels):
        ax.imshow(tensor_to_image(tensor))
        ax.set_title(title)
        ax.axis("off")

    plt.show()
    plt.close(fig)


show_one_test(model, -1)

In [None]:
# Plot the loss history recorded by the training loop (one point every 50 iterations).
loss_series = (
    ("Content Loss", content_losses, 'blue'),
    ("Style Loss", style_losses, 'red'),
    ("Total Loss", total_losses, 'green'),
)

plt.figure(figsize=(10, 5))
for label, values, color in loss_series:
    plt.plot(iterations, values, label=label, color=color)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Training Losses Over Iterations")
plt.legend()
plt.grid()
plt.show()