# Style Transfer with Vision Transformer
Original paper: [StyTr²: Image Style Transfer with Transformers](https://arxiv.org/abs/2105.14576)

### Environment Setup

In [3]:
# Import the drive module from Google Colab to access files stored in Google Drive
from google.colab import drive
import sys
drive.mount('/content/drive')  # Mount Google Drive

# Navigate to project folder (make sure you put the correct path here!)
%cd drive/MyDrive/Deep_Learning/StyleTransfer_ViT
sys.path.append("drive/MyDrive/Deep_Learning/StyleTransfer_ViT")  # Add to Python path--for sanity
%ls  # List folder contents

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Errno 2] No such file or directory: 'drive/MyDrive/Deep_Learning/StyleTransfer_ViT'
/content/drive/MyDrive/Deep_Learning/StyleTransfer_ViT
[0m[01;34mdata[0m/            [01;34mmodel[0m/        README.md
LICENSE          [01;34mpredictions[0m/  training_pipeline.ipynb
[01;34mlightning_logs[0m/  [01;34mraw_data[0m/     visualizing_results.ipynb


In [4]:
%%capture
# Install PyTorch Requirements (Colab only)
!pip install torch==2.5.1 torchvision==0.20.1 pytorch-lightning==2.0.2

# Standard imports
import os, glob, gc, random
from os.path import join
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch core
import torch

# PyTorch Lightning for easier training loops
import pytorch_lightning as pl  # High-level PyTorch interface
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint  # Useful training callbacks

# Progress bar
from tqdm.notebook import tqdm  # Pretty progress bars in notebooks

### Data Module Setup

In [5]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# --- Raw image folders (relative paths) ---
CONTENT_PATH = "raw_data/content"
STYLE_PATH = "raw_data/style"

# --- Dataset construction ---
# Handed to the data module as `train_val_test_split`.
TVT_SPLIT = [0.6, 0.3, 0.1]
# Handed to the data module as `num_styles_per_image`.
NUM_STYLES_PER_IMAGE = 10

# --- DataLoader settings ---
# Batch size of 8 follows the paper; if you change it, adjust gradient
# accumulation accordingly.
BATCH_SIZE = 8
PIN_MEMORY = True
NUM_WORKERS = 6
PREFETCH_FACTOR = 3

In [6]:
from data.data_module import StyleTransferDM

# Build the data module from the configuration constants defined above:
# the two image folders go in positionally, everything else by keyword.
data_module = StyleTransferDM(
    STYLE_PATH,
    CONTENT_PATH,
    # dataset construction
    train_val_test_split=TVT_SPLIT,
    num_styles_per_image=NUM_STYLES_PER_IMAGE,
    # dataloader behaviour
    batch_size=BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
)

### Model Setup

In [7]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# "GH" in the notes below refers to the reference GitHub implementation.

# --- Shared embedding width ---
D_MODEL = 512

# --- Input image size (VGG) ---
IMG_WIDTH = 224
IMG_HEIGHT = 224

# --- Patching ---
# N and S are forwarded to the model as `n` and `s`
# (presumably positional-encoding settings -- TODO confirm against model/network.py).
PATCH_SIZE = 8
N = 8
S = 1

# --- Content encoder (all values as per GH) ---
N_LAYERS_CONTENT = 3
N_HEAD_CONTENT = 8
DIM_FEEDFORWARD_CONTENT = 2048
DROPOUT_CONTENT = 0.1

# --- Style encoder (all values as per GH) ---
N_LAYERS_STYLE = 3
N_HEAD_STYLE = 8
DIM_FEEDFORWARD_STYLE = 2048
DROPOUT_STYLE = 0.1

# --- Decoder (all values as per GH) ---
N_LAYERS_DEC = 3
N_HEAD_DEC = 8
DIM_FEEDFORWARD_DEC = 2048
DROPOUT_DEC = 0.1

# --- Optimization ---
LR = 5e-4  # as per paper
# Loss weights. Paper: [10, 7, 50, 1]; GH: [10, 7, 70, 1] (the GH values are used).
LAMBDAS = [10, 7, 70, 1]
BETAS = (0.9, 0.999)  # for Adam optimizer
# Forwarded to the model as `lr_patience` / `lr_decay`.
LR_PATIENCE = 2
LR_DECAY = 0.5

# --- Predictions ---
RESULTS_PATH = "predictions"

In [8]:
from model.network import StyTR2

# Instantiate the network from the configuration constants defined above.
# The same PATCH_SIZE is used for both the content and the style branch.
model = StyTR2(
    # shared dimensions / input size
    d_model=D_MODEL,
    img_height=IMG_HEIGHT, img_width=IMG_WIDTH,
    n=N, s=S,
    # content branch
    patch_size_content=PATCH_SIZE,
    n_layers_content=N_LAYERS_CONTENT,
    n_head_content=N_HEAD_CONTENT,
    dim_feedforward_content=DIM_FEEDFORWARD_CONTENT,
    dropout_content=DROPOUT_CONTENT,
    # style branch
    patch_size_style=PATCH_SIZE,
    n_layers_style=N_LAYERS_STYLE,
    n_head_style=N_HEAD_STYLE,
    dim_feedforward_style=DIM_FEEDFORWARD_STYLE,
    dropout_style=DROPOUT_STYLE,
    # decoder
    n_layers_dec=N_LAYERS_DEC,
    n_head_dec=N_HEAD_DEC,
    dim_feedforward_dec=DIM_FEEDFORWARD_DEC,
    dropout_dec=DROPOUT_DEC,
    # optimization
    lr=LR,
    betas=BETAS,
    lambdas=LAMBDAS,
    lr_patience=LR_PATIENCE,
    lr_decay=LR_DECAY,
    # where predictions are stored
    results_path=RESULTS_PATH,
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 226MB/s]


### Model Training

In [9]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# --- Early stopping and checkpointing ---
MAX_EPOCHS = 30
EARLY_STOPPING_PATIENCE = 3
# Number of best checkpoints to keep, and the pattern used to name them.
SAVE_TOP_K = 3
CHECKPOINT_FILENAME = 'model-{epoch:02d}-{val_loss:.4f}'

# --- Trainer ---
# Accelerator: one of "cpu", "gpu", "tpu", "hpu", "mps", "auto".
ACCELERATOR = 'gpu'
# Number of accelerators for distributed training.
DEVICES = 1
# Details on the strategy parameter:
# https://pytorch-lightning.readthedocs.io/en/1.6.5/extensions/strategy.html
STRATEGY = "auto"
# Alternative: '16-mixed'.
PRECISION = '32-true'
GRADIENT_CLIP_VAL = 800_000.0
ACCUMULATE_GRAD_BATCHES = 1

In [None]:
# Callback 1: stop training once `val_loss` has stopped decreasing for
# EARLY_STOPPING_PATIENCE consecutive validation checks.
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=EARLY_STOPPING_PATIENCE)

# Callback 2: keep the SAVE_TOP_K checkpoints with the lowest `val_loss`
# (no additional "last" checkpoint is written).
checkpoint = ModelCheckpoint(
    filename=CHECKPOINT_FILENAME,
    monitor='val_loss',
    mode='min',
    save_top_k=SAVE_TOP_K,
    save_last=False,
)

# Lightning trainer wired up with the hardware / optimization settings from the config cell.
trainer = pl.Trainer(
    accelerator=ACCELERATOR,
    devices=DEVICES,
    strategy=STRATEGY,
    precision=PRECISION,
    max_epochs=MAX_EPOCHS,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    callbacks=[early_stopping, checkpoint],
)

# Set the float32 matmul precision before training starts, then fit.
torch.set_float32_matmul_precision('high')
trainer.fit(model, datamodule=data_module)

# Evaluate on the test split using the best checkpoint tracked by the checkpoint callback.
trainer.test(model=model, datamodule=data_module, ckpt_path="best")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type                      | Params
--------------------------------------------------------------
0 | content_patcher | ContentPatching           | 98.8 K
1 | style_patcher   | StylePatching             | 98.8 K
2 | cape            | CAPE                      | 262 K 
3 | content_encoder | ContentTransformerEncoder | 12.6 M
4 | style_encoder   | StyleTransformerEncoder   | 12.6 M
5 | decoder         | TransformerDecoder        | 16.8 M
6 | cnn_decoder     | CNNDecoder                | 3.5 M 
7 | vgg_extractor   | VGGF

Sanity Checking: 0it [00:00, ?it/s]

[0] Content Loss: 3.0006, Style Loss: 101084.5469, Identity Loss 1: 35850.1016, Identity Loss 2: 189315.3125


Training: 0it [00:00, ?it/s]

[0] Content Loss: 3.0436, Style Loss: 84715.7734, Identity Loss 1: 34737.4453, Identity Loss 2: 172971.7656
[100] Content Loss: 2.6158, Style Loss: 66399.8906, Identity Loss 1: 3631.1555, Identity Loss 2: 129232.1719
[200] Content Loss: 2.5037, Style Loss: 68239.3047, Identity Loss 1: 5238.9922, Identity Loss 2: 140675.9844
[300] Content Loss: 2.4842, Style Loss: 46038.1289, Identity Loss 1: 3238.2234, Identity Loss 2: 103346.2812
[400] Content Loss: 2.2532, Style Loss: 27333.6152, Identity Loss 1: 2156.6729, Identity Loss 2: 71992.2344
[500] Content Loss: 2.0541, Style Loss: 28971.7441, Identity Loss 1: 1378.4948, Identity Loss 2: 65152.2656
[600] Content Loss: 2.3712, Style Loss: 27900.9336, Identity Loss 1: 2112.4897, Identity Loss 2: 88313.8828
[700] Content Loss: 2.1083, Style Loss: 51566.1953, Identity Loss 1: 2300.3787, Identity Loss 2: 92465.0859
[800] Content Loss: 2.5344, Style Loss: 22086.3809, Identity Loss 1: 1946.3456, Identity Loss 2: 76115.8281
[900] Content Loss: 2.294

Validation: 0it [00:00, ?it/s]

[0] Content Loss: 2.3390, Style Loss: 19898.6934, Identity Loss 1: 671.9366, Identity Loss 2: 47488.9180
[100] Content Loss: 2.4804, Style Loss: 16194.6992, Identity Loss 1: 847.2848, Identity Loss 2: 58657.7461
[200] Content Loss: 2.4601, Style Loss: 17450.5859, Identity Loss 1: 954.2519, Identity Loss 2: 70139.5781
[300] Content Loss: 2.3705, Style Loss: 39942.4062, Identity Loss 1: 1029.9069, Identity Loss 2: 71418.6484
[400] Content Loss: 2.7014, Style Loss: 13911.2490, Identity Loss 1: 843.4259, Identity Loss 2: 53813.3789
[500] Content Loss: 2.5531, Style Loss: 33065.9141, Identity Loss 1: 1277.7793, Identity Loss 2: 90757.0078
[600] Content Loss: 2.5503, Style Loss: 47010.6094, Identity Loss 1: 1053.3079, Identity Loss 2: 81503.6094
[700] Content Loss: 2.6819, Style Loss: 14802.6191, Identity Loss 1: 1194.1908, Identity Loss 2: 80662.2031
[800] Content Loss: 2.5890, Style Loss: 19250.0859, Identity Loss 1: 972.9879, Identity Loss 2: 71302.0938
[900] Content Loss: 2.4460, Style L

In [None]:
# Load the TensorBoard notebook extension and open it on the `lightning_logs/`
# directory (path is relative to the current working directory).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

### Visualizing Results

In [None]:
# Memory-map the prediction arrays stored under RESULTS_PATH instead of reading
# them fully into RAM. They are only read for plotting, so they are opened
# read-only ("r") rather than read-write ("r+"): this rules out accidentally
# modifying the files on disk and also works when the files are not writable.
content = np.load(join(RESULTS_PATH, 'content.npy'), mmap_mode="r")
style = np.load(join(RESULTS_PATH, 'style.npy'), mmap_mode="r")
stylized = np.load(join(RESULTS_PATH, 'stylized.npy'), mmap_mode="r")

In [None]:
# Show up to 5 randomly chosen (content, style, stylized) triplets, one per row.
# The row count is capped by the number of stored predictions, because
# np.random.choice(..., replace=False) raises when asked for more samples than exist.
num_rows = min(5, content.shape[0])
if num_rows == 0:
    raise ValueError("No predictions found to visualize.")
indices = np.random.choice(content.shape[0], size=num_rows, replace=False)

# squeeze=False keeps `axes` two-dimensional even when only a single row is drawn,
# so axes[row][col] indexing is always valid. Height scales with the row count
# (5 rows -> the original 15x25 figure).
fig, axes = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5 * num_rows), squeeze=False)

# One column per image kind; the same processing applies to each.
panels = (("Content", content), ("Style", style), ("Stylized", stylized))
for row, i in enumerate(indices):
    for col, (title, images) in enumerate(panels):
        # Move axis 0 to the end (presumably channels-first -> channels-last for
        # imshow -- TODO confirm) and cast to uint8 for display.
        img = np.transpose(images[i], axes=(1, 2, 0)).astype(np.uint8)
        ax = axes[row][col]
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
plt.tight_layout()
plt.show()