In [None]:

# from google.colab import files
# files.upload()
# !mkdir -p ~/.kaggle
# !mv kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle datasets download -d "ikarus777/best-artworks-of-all-time"
# !kaggle datasets download -dgopalbhattrai/pascal-voc-2012-dataset

In [None]:
# colab
# !unzip best-artworks-of-all-time.zip
# !unzip pascal-voc-2012-dataset.zip
# kaggle
import os
import shutil
src1 = '/kaggle/input/best-artworks-of-all-time'
# src2 = '/kaggle/input/best-artworks-of-all-time/resized' #don't use
src3 = '/kaggle/input/pascal-voc-2012-dataset'
dst = '.'
# for folder in os.listdir(src1):
#     folder_path = os.path.join(src1, folder)
#     if os.path.isdir(folder_path):
#         shutil.copytree(folder_path, os.path.join(dst, folder), dirs_exist_ok=True)
        
# for folder in os.listdir(src2):
#     folder_path = os.path.join(src2, folder)
#     if os.path.isdir(folder_path):
#         shutil.copytree(folder_path, os.path.join(dst, folder), dirs_exist_ok=True)
        
# for folder in os.listdir(src3):
#     folder_path = os.path.join(src3, folder)
#     if os.path.isdir(folder_path):
#         shutil.copytree(folder_path, os.path.join(dst, folder), dirs_exist_ok=True)
# !cp /kaggle/input/best-artworks-of-all-time/artists.csv ./artists.csv

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms
from torchvision import  models
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import matplotlib.pyplot as plt
import numpy as np
import os

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [5]:
class AdaIN(nn.Module):
  def __init__(self,eps):
    super().__init__()
    self.eps = eps
  def forward(self,x,y):
    if x.shape != y.shape:
      raise ValueError(f'x and y must have the same shape x.shape {x.shape} y.shape {y.shape}')

    normal_x = (x - x.mean(axis = (2,3),keepdim=True)) / (x.std(axis = (2,3),keepdim=True) + self.eps)
    y_mean = y.mean(axis = (2,3),keepdim=True) #for each sample and each channel
    y_std = y.std(axis = (2,3),keepdim=True) #for each sample and each channel
    return normal_x * y_std + y_mean
class ContentLoss(nn.Module):
  def __init__(self,encoder):
    super().__init__()
    self.encoder = encoder
  def forward(self,y_gen,t):
    y_gen = self.encoder(y_gen)
    return ((y_gen-t)**2).mean()

class StyleLoss(nn.Module):
  def __init__(self,encoder,to_layer:int):
    super().__init__()
    self.encoder = encoder
    self.layers_range = range(0,to_layer)
  def forward(self,y_gen,y_style):
    loss = 0.0
    for i,layer in enumerate(self.encoder):
      if i in self.layers_range:
        y_gen = layer(y_gen)
        y_style = layer(y_style)
        mean_g = y_gen.mean(dim=(2, 3))
        mean_s = y_style.mean(dim=(2, 3))
        std_g = y_gen.std(dim=(2, 3))
        std_s = y_style.std(dim=(2, 3))
        # loss += ((y_gen.mean(dim = (2,3)) - y_style.mean(dim = (2,3)))**2) + ((y_gen.std(dim = (2,3)) - y_style.std(dim = (2,3)))**2)
        loss += ((mean_g - mean_s) ** 2 + (std_g - std_s) ** 2).sum()
    return loss



In [None]:
# !gdown --fuzzy https://drive.google.com/file/d/11lWUMPPMinaxDBSQKgphFBGOeQHJbHOB/view?usp=sharing
# !mv ./AdaIN.pth ./model.pth

In [6]:
import copy
from collections import OrderedDict

vgg19 = models.vgg19(weights = models.VGG19_Weights.IMAGENET1K_V1)

encoder = copy.deepcopy(vgg19.features[:22])

decoder = nn.Sequential(
    nn.Conv2d(512, 256, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(inplace=False),
    nn.Conv2d(64, 3, 3, 1, 1),
)
ada = AdaIN(.000001)

class AdaINModel(nn.Module):
  """
  Note freeze encoder before starting
  """
  def __init__(self,encoder,decoder,ada):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.ada = ada
  def forward(self,x_content,x_style):
    x_content = self.encoder(x_content)
    x_style = self.encoder(x_style)
    ada_out = self.ada(x_content,x_style)
    x_gen = self.decoder(ada_out)

    return {
        'x_gen':x_gen,
        'ada_out':ada_out,
    }

c_loss,s_loss = ContentLoss(encoder),StyleLoss(encoder,to_layer=len(encoder))
model = AdaINModel(encoder,decoder,ada)
for i, layer in enumerate(model.encoder):
    if isinstance(layer, nn.ReLU):
        model.encoder[i] = nn.ReLU(inplace=False)

model = torch.compile(model) # u forget this idiot
model.load_state_dict(torch.load('model_weightsV5.pth'))
c_loss,s_loss = torch.compile(c_loss),torch.compile(s_loss)

c_loss.to(device)
s_loss.to(device)
model.to(device)


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:08<00:00, 64.6MB/s] 


OptimizedModule(
  (_orig_mod): AdaINModel(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU()
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU()
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU()
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU()
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU()
      (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [None]:
import pandas as pd
df = pd.read_csv('./artists.csv')

In [None]:
df.hist('paintings',bins=50)

In [7]:
import random
random.seed(42)
style_paths = {}
for root,folders,files in os.walk('resized/resized'):
  for file in files:
    artist = file.rindex('_')
    artist = file[:artist]
    if file not in style_paths:
      style_paths[artist] = style_paths.get(artist,[])+[os.path.join(root,file)]

max_images = 250
li = []
for artist,images in style_paths.items():
  if len(images) > max_images:
    images = random.sample(images,max_images)
  li += images

style_images = li
random.shuffle(style_images)
len(style_images)

7045

In [8]:
train_image_paths = []
for root,folders,files in os.walk('./VOC2012_train_val/VOC2012_train_val/JPEGImages'):
  for file in files:
    train_image_paths.append(os.path.join(root,file))
random.shuffle(train_image_paths)
val_image_paths = train_image_paths[:1000]
train_image_paths = train_image_paths[1000:]
len(train_image_paths)

16125

In [9]:
from torchvision.transforms import ToTensor,Resize,Compose,RandomCrop,Lambda
from PIL import Image
import copy

class ImagesDataset(Dataset):
  def __init__(self,image_paths,transform=None):
    self.image_paths = image_paths
    self.transform = transform
  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, index):
    image = self.image_paths[index]
    image = Image.open(image).convert("RGB")
    if self.transform:
      image = Compose(self.transform)(image)
    return image

def resizeWithAspectRatio(image):
  # print('shape',image.shape)
  width,height = image.shape[-1:-3:-1]
  short_d = -1 if width < height else -2
  long_d = -2 if short_d == -1 else -1
  # print(f'short {short_d}, long {long_d}')
  ratio = image.shape[long_d]/image.shape[short_d]
  long_new_size = int(ratio*image.shape[long_d])
  short_new_size = 256
  if long_new_size < 256:
    long_new_size = 256
    short_new_size = 256 # todo keep aspect ratio
  new_shape = list(image.shape[-2::1])
  new_shape[long_d] = long_new_size
  new_shape[short_d] = short_new_size
  # print(f'newshape {new_shape}')

  return Resize(new_shape)(image)
transform = [
    ToTensor(),
    Lambda(resizeWithAspectRatio),
    RandomCrop((224,224)),
]

style_dataset = ImagesDataset(style_images,transform)
train_dataset = ImagesDataset(train_image_paths,transform)
sample = ImagesDataset(train_image_paths[:1000],transform)
val_dataset = ImagesDataset(val_image_paths,transform)



In [10]:
batch_size = 32
epochs = 50

train_data = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,drop_last=True)
trainSample_data = DataLoader(sample,batch_size=batch_size,shuffle=True,drop_last=True)
val_data = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,drop_last=True)
style_data = DataLoader(style_dataset,batch_size=batch_size,shuffle=True,drop_last=True)

In [None]:
optmizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=.01)
scheduler = torch.optim.lr_scheduler.LinearLR(optmizer,.1,.001,total_iters= epochs*len(train_data))


In [None]:
# pip install line-profiler

In [11]:
from line_profiler import LineProfiler
from queue import Queue
from accelerate import Accelerator

accelerator = Accelerator()

def train(data,y_styles,alpha,val_data = None,evaluate_every = None,accum_loss = 1):
  import tqdm
  model.train()
  model.encoder.requires_grad_(False)
  model.ada.requires_grad_(False)
  # loop = tqdm.tqdm(range(epochs))
  styles_iter = iter(y_styles)
  evaluate_every = evaluate_every or len(data)
  for epoch in range(epochs):
    loop = tqdm.tqdm(data, desc=f"Epoch {epoch+1}/{epochs}", unit="batch")
    con_loss,sty_loss,eloss = '','',''
    for i,images in enumerate(loop):
      try:

        y_style = next(styles_iter)
        images = images.to(device)
        y_style = y_style.to(device)

        y = model(images,y_style)
        y_gen,ada_out = y['x_gen'],y['ada_out']

        content_loss = c_loss(y_gen,ada_out)/accum_loss
        style_loss = s_loss(y_gen,y_style)/accum_loss
        loss = content_loss + alpha*style_loss
        loss.backward()

        if i % accum_loss == 0:
            optmizer.step()
            if scheduler:
                for i in range(accum_loss):
                    scheduler.step()
            optmizer.zero_grad()
            
        if val_data and ((i+1) % evaluate_every == 0):
          con_loss,sty_loss,eloss = evaluate(val_data,y_styles,alpha)

        # loop.set_description(f'Epoch {epoch}/{epochs} Batch {i/len(data)}')
        loop.set_postfix(content_loss = content_loss.item(),style_loss = f'{style_loss.item()*alpha}/ {style_loss.item()}',loss=loss.item(),val_loss = f'{con_loss},{sty_loss},{eloss}')
        # loop.set_postfix(content_loss = sum(closs_queue.queue)/queue_size,style_loss = sum(sloss_queue.queue)/queue_size,loss= sum(loss_queue.queue)/queue_size)
      except StopIteration:
        styles_iter = iter(y_styles)
        continue
def evaluate(data,y_styles,alpha):
  model.eval()
  loss = 0.0
  con_loss = 0.0
  sty_loss = 0.0
  styles_iter = iter(y_styles)
  with torch.no_grad():

    for i,images in enumerate(data):
      try:
        y_style = next(styles_iter)
        images = images.to(device)
        y_style = y_style.to(device)

        y = model(images,y_style)
        y_gen,ada_out = y['x_gen'],y['ada_out']
        content_loss = c_loss(y_gen,ada_out)
        style_loss = s_loss(y_gen,y_style)
        loss += content_loss.item() + alpha*style_loss.item()
        con_loss += content_loss.item()
        sty_loss += style_loss.item()

      except StopIteration:
        styles_iter = iter(y_styles)
        continue
    loss = loss/len(data)
    con_loss = con_loss/len(data)
    sty_loss = sty_loss/len(data)
    print(f'content loss {con_loss}, style loss {sty_loss}, total loss {loss}')
    return con_loss,sty_loss,loss



In [None]:
# #warmup
def warmup_lr(step):
    total_steps = epochs * len(train_data)
    return step / total_steps
epochs = 5
optmizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=0.01)

# scheduler = torch.optim.lr_scheduler.LambdaLR(optmizer,lr_lambda=warmup_lr)
scheduler = torch.optim.lr_scheduler.LinearLR(optmizer,.001,1,total_iters= epochs*len(train_data))


train(train_data,style_data,10,val_data,len(train_data)//8,2)


Epoch 1/50:   0%|          | 0/503 [00:00<?, ?batch/s]W0930 21:28:27.902000 36 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode
Epoch 1/50:  12%|█▏        | 62/503 [04:43<3:26:08, 28.05s/batch, content_loss=7.89, loss=1.48e+5, style_loss=148051.5625/ 14805.15625, val_loss=17.28585756978681,34221.23708417339,342229.65669930365]

content loss 17.28585756978681, style loss 34221.23708417339, total loss 342229.65669930365


Epoch 1/50:  25%|██▍       | 124/503 [07:39<1:42:54, 16.29s/batch, content_loss=8.84, loss=1.74e+5, style_loss=174011.484375/ 17401.1484375, val_loss=16.99770579799529,34346.89598034274,343485.9575092254]     

content loss 16.99770579799529, style loss 34346.89598034274, total loss 343485.9575092254


Epoch 1/50:  37%|███▋      | 186/503 [10:34<1:26:02, 16.29s/batch, content_loss=9.62, loss=1.86e+5, style_loss=185580.8984375/ 18558.08984375, val_loss=16.977176050986014,33164.559475806454,331662.5719341155]

content loss 16.977176050986014, style loss 33164.559475806454, total loss 331662.5719341155


Epoch 1/50:  49%|████▉     | 248/503 [13:27<1:09:12, 16.28s/batch, content_loss=8.25, loss=1.52e+5, style_loss=151761.689453125/ 15176.1689453125, val_loss=17.125981576981083,34162.474609375,341641.872075327]  

content loss 17.125981576981083, style loss 34162.474609375, total loss 341641.872075327


Epoch 1/50:  62%|██████▏   | 310/503 [16:21<52:04, 16.19s/batch, content_loss=7.86, loss=1.46e+5, style_loss=145918.154296875/ 14591.8154296875, val_loss=17.064253622485744,33565.509135584674,335672.1556094693]

content loss 17.064253622485744, style loss 33565.509135584674, total loss 335672.1556094693


Epoch 1/50:  74%|███████▍  | 372/503 [19:16<35:28, 16.25s/batch, content_loss=9.05, loss=1.75e+5, style_loss=175228.53515625/ 17522.853515625, val_loss=17.318275882351784,33901.33083417339,339030.6266176162]   

content loss 17.318275882351784, style loss 33901.33083417339, total loss 339030.6266176162


Epoch 1/50:  86%|████████▋ | 434/503 [22:11<18:41, 16.25s/batch, content_loss=8.1, loss=1.45e+5, style_loss=145053.02734375/ 14505.302734375, val_loss=16.72066786981398,31760.455708165322,317621.27774952305]  

content loss 16.72066786981398, style loss 31760.455708165322, total loss 317621.27774952305


Epoch 1/50:  99%|█████████▊| 496/503 [25:04<01:53, 16.18s/batch, content_loss=8.66, loss=1.6e+5, style_loss=159773.486328125/ 15977.3486328125, val_loss=17.60108200196297,34319.155367943546,343209.15476143744] 

content loss 17.60108200196297, style loss 34319.155367943546, total loss 343209.15476143744


Epoch 1/50: 100%|██████████| 503/503 [25:18<00:00,  3.02s/batch, content_loss=9.31, loss=2.15e+5, style_loss=215447.96875/ 21544.796875, val_loss=17.60108200196297,34319.155367943546,343209.15476143744]        
Epoch 2/50:  12%|█▏        | 62/503 [02:54<1:58:24, 16.11s/batch, content_loss=8.38, loss=1.55e+5, style_loss=155468.45703125/ 15546.845703125, val_loss=17.071331147224672,33356.63942792339,333583.4656103811]

content loss 17.071331147224672, style loss 33356.63942792339, total loss 333583.4656103811


Epoch 2/50:  24%|██▎       | 119/503 [04:52<13:24,  2.10s/batch, content_loss=8.94, loss=1.79e+5, style_loss=178660.0/ 17866.0, val_loss=17.071331147224672,33356.63942792339,333583.4656103811]                  

In [None]:
#train schedula
epochs = 5
# optmizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=0.005)
# scheduler = torch.optim.lr_scheduler.LinearLR(optmizer,.1,.001,total_iters= epochs*len(train_data))
train(train_data,style_data,10,val_data,len(train_data)//8,2)


In [None]:
 evaluate(val_data,style_data,18)


In [None]:
torch.save(model.state_dict(), "model_weightsV5.pth")

In [None]:
from google.colab import drive
drive.mount('./drive')

In [None]:
import shutil

src = '/content/model_weights.pth'
dst = '/content/drive/My Drive/models/AdaIN.pth'

shutil.copy(src, dst)
print("File uploaded successfully!")


In [None]:
import torch

def crop_to_same_size(img1: torch.Tensor, img2: torch.Tensor):
    """
    Crop two images to have exactly the same size (largest crop possible).
    Assumes input images are tensors of shape [1, C, H, W].
    Crops centered.
    """
    _, _, h1, w1 = img1.shape
    _, _, h2, w2 = img2.shape

    target_h = min(h1, h2)
    target_w = min(w1, w2)

    def center_crop(img, target_h, target_w):
        _, _, h, w = img.shape
        top = (h - target_h) // 2
        left = (w - target_w) // 2
        return img[:, :, top:top+target_h, left:left+target_w]

    img1_cropped = center_crop(img1, target_h, target_w)
    img2_cropped = center_crop(img2, target_h, target_w)

    return img1_cropped, img2_cropped

# Usage



In [None]:
model.to(device)
content = ToTensor()(Image.open('./content.jpg').convert("RGB"))
style = ToTensor()(Image.open('./style.jpg').convert("RGB"))
orignal_image = copy.deepcopy(content).permute(1, 2, 0).numpy()
content = torch.unsqueeze(content,0)
style = torch.unsqueeze(style,0)
content_cropped, style_cropped = crop_to_same_size(content, style)
content = content_cropped.to(device)
style = style_cropped.to(device)
y_gen = model(content, style)

In [None]:
print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")


In [None]:
import gc
# optmizer.zero_grad(set_to_none=True)
gc.collect()

torch.cuda.empty_cache()


In [None]:
import torch
from torchvision.transforms import ToTensor, Resize, CenterCrop
from PIL import Image
import glob
model.eval()
model.to(device)
# Load content image
content = ToTensor()(Image.open('/kaggle/input/test-data/content.jpg').convert("RGB"))
content_c, content_h, content_w = content.shape
content = content.unsqueeze(0)  # [1, C, H, W]

# Load style images
style_files = sorted(glob.glob('/kaggle/input/test-data/style*.jpg'))
style_list = []

for f in style_files:
    img = Image.open(f).convert("RGB")
    style_h, style_w = img.size[1], img.size[0]

    # If style smaller than content, resize to content size (ignore aspect ratio)
    if style_h < content_h or style_w < content_w:
        img = img.resize((content_w, content_h))
    # If style larger than content, center crop
    elif style_h > content_h or style_w > content_w:
        img = CenterCrop((content_h, content_w))(img)

    style_list.append(ToTensor()(img))

# Stack into batch
style_batch = torch.stack(style_list, dim=0)  # [n, C, H, W]

# Broadcast content to match batch size
content_batch = content.repeat(style_batch.size(0), 1, 1, 1)

# Move to device
style_iter = iter(style_data)
content_batch = next( iter(train_data) )
style_batch = next(style_iter)

style_batch = style_batch.to(device)
content_batch = content_batch.to(device)

# Generate stylized images
y = None
content_loss,style_loss = None, None
with torch.no_grad():
    y = model(content_batch, style_batch)
    y_gen,ada_out = y['x_gen'],y['ada_out']
    
content_loss = c_loss(y_gen,ada_out)
style_loss = s_loss(y_gen,style_batch)
    

In [None]:
content_loss,style_loss

In [None]:
n = 16
plt.figure(figsize=(15, 5 * n))  # Wider for side-by-side

for i in range(n):
    # Content image
    content_img = content_batch[i].permute(1, 2, 0).detach().cpu().numpy()
    content_img = (content_img * 255).astype(np.uint8)

    # Style image
    style_img = style_batch[i].permute(1, 2, 0).detach().cpu().numpy()
    style_img = (style_img * 255).astype(np.uint8)

    # Generated image
    gen_img = y_gen[i].permute(1, 2, 0).detach().cpu().numpy()
    gen_img = (gen_img * 255).astype(np.uint8)

    # Plot row: content | style | generated
    plt.subplot(n, 3, i * 3 + 1)
    plt.imshow(content_img)
    plt.title("Content")
    plt.axis("off")

    plt.subplot(n, 3, i * 3 + 2)
    plt.imshow(style_img)
    plt.title("Style")
    plt.axis("off")

    plt.subplot(n, 3, i * 3 + 3)
    plt.imshow(gen_img)
    plt.title("Generated")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import math
plt.figure(figsize=(20,20))
plt.subplot(math.ceil(len(style_files)/4)+1,2,1)
# plt.imshow(original_image)

for i in range(2,len(content_batch)+1):
  image = y_gen[i-1].permute(1, 2, 0).detach().cpu().numpy()
  image = (image*255).astype(np.uint8)
  plt.subplot(math.ceil(len(style_files)/4)+1,2,i)
  plt.imshow(image)

In [None]:
math.ceil(len(style_files)/4)