# Style Transfer with Vision Transformer
Original paper: [StyTr²: Image Style Transfer with Transformers](https://arxiv.org/abs/2105.14576)

### Environment Setup

In [1]:
# Import the drive module from Google Colab to access files stored in Google Drive
from google.colab import drive
import sys
drive.mount('/content/drive')  # Mount Google Drive

# Navigate to project folder (make sure you put the correct path here!)
%cd drive/MyDrive/Deep_Learning/StyleTransfer_ViT
sys.path.append("drive/MyDrive/Deep_Learning/StyleTransfer_ViT")  # Add to Python path--for sanity
%ls  # List folder contents

Mounted at /content/drive
/content/drive/MyDrive/Deep_Learning/StyleTransfer_ViT
[0m[01;34mbest_model[0m/  [01;34mlightning_logs_old[0m/  [01;34mpredictions2[0m/  training_pipeline_0.ipynb
[01;34mdata[0m/        [01;34mmodel[0m/               [01;34mraw_data[0m/      training_pipeline.ipynb
LICENSE      [01;34mpredictions[0m/         README.md      visualizing_results.ipynb


In [2]:
%%capture
# Install PyTorch Requirements (Colab only)
!pip install torch==2.5.1 torchvision==0.20.1 pytorch-lightning==2.0.2

# Standard imports
import os, glob, gc, random
from os.path import join
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch core
import torch

# PyTorch Lightning for easier training loops
import pytorch_lightning as pl  # High-level PyTorch interface
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint  # Useful training callbacks

# Progress bar
from tqdm.notebook import tqdm  # Pretty progress bars in notebooks

In [3]:
SEED = 42

# Seed every RNG the run touches. pl.seed_everything below seeds these as
# well; the explicit calls keep the intent visible in this cell.
# NOTE: setting PYTHONHASHSEED at runtime does not change hash randomisation of
# the already-running interpreter; it only affects processes started later.
os.environ["PYTHONHASHSEED"] = str(SEED)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(SEED)

pl.seed_everything(SEED, workers=True)

# Ask cuDNN for deterministic kernels and turn off benchmark-mode autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

INFO:lightning_fabric.utilities.seed:Global seed set to 42


In [4]:
# Run-mode switches for this session.
TRAIN_MODEL = True    # fit a new model, then test its best checkpoint
TEST_MODEL = False    # evaluate the checkpoint at CHECKPOINT_FILEPATH instead
MODEL_VERSION = 3     # 2 -> StyTR2 (model.network), 3 -> StyTR3 (model.network2)

### Data Module Setup

In [5]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Data dirs (relative to the project folder)
STYLE_PATH = "raw_data/style"
CONTENT_PATH = "raw_data/content"

# Setting up data
NUM_STYLES_PER_IMAGE = 10       # choose either 10 or 20 for simplicity

# Train/val/test fractions: validation is fixed at 0.3, test is
# 1/NUM_STYLES_PER_IMAGE, and train gets the rest of the remaining 0.7.
TVT_SPLIT = [0.7 - 1 / NUM_STYLES_PER_IMAGE, 0.3, 1 / NUM_STYLES_PER_IMAGE]
# Fail fast on settings that yield an empty/negative split
# (e.g. NUM_STYLES_PER_IMAGE <= 1 makes the train fraction <= 0).
assert all(frac > 0 for frac in TVT_SPLIT), f"invalid train/val/test split: {TVT_SPLIT}"
assert abs(sum(TVT_SPLIT) - 1.0) < 1e-9, f"split must sum to 1, got {sum(TVT_SPLIT)}"

# Dataloader
BATCH_SIZE = 8  # 8, as per paper. change accumulate gradients accordingly
NUM_WORKERS = 6
PREFETCH_FACTOR = 3
PIN_MEMORY = True

In [6]:
from data.data_module import StyleTransferDM

# The data module is only built when this session trains or evaluates.
if TRAIN_MODEL or TEST_MODEL:
    dm_kwargs = dict(
        num_styles_per_image=NUM_STYLES_PER_IMAGE,
        train_val_test_split=TVT_SPLIT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        pin_memory=PIN_MEMORY,
    )
    data_module = StyleTransferDM(STYLE_PATH, CONTENT_PATH, **dm_kwargs)

### Model Setup

In [7]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Universal
D_MODEL = 512

# VGG
IMG_HEIGHT = 224
IMG_WIDTH = 224
EXTRACTION_LAYERS = [8, 15, 20, 26, 31, 35]

# Patching
N = 18
S = 1
PATCH_SIZE = 8  # used for both content and style patching

# Content Encoder
N_HEAD_CONTENT = 8  # as per GH
DIM_FEEDFORWARD_CONTENT = 2048  # as per GH
DROPOUT_CONTENT = 0.1  # as per GH
N_LAYERS_CONTENT = 3  # as per GH

# Style Encoder
N_HEAD_STYLE = 8  # as per GH
DIM_FEEDFORWARD_STYLE = 2048  # as per GH
DROPOUT_STYLE = 0.1  # as per GH
N_LAYERS_STYLE = 3  # as per GH

# Decoder
N_HEAD_DEC = 8  # as per GH
DIM_FEEDFORWARD_DEC = 2048  # as per GH
DROPOUT_DEC = 0.1  # as per GH
N_LAYERS_DEC = 3  # as per GH

# Optimization
LR = 0.0005  # as per paper
if MODEL_VERSION == 2:
    LAMBDAS = [10, 7, 70, 1]  # as per paper: [10, 7, 50, 1]; as per GH: [10, 7, 70, 1]
elif MODEL_VERSION == 3:
    LAMBDAS = [10, 7, 70, 1, 10]    # additional lambda for sep_loss
    MARGIN = 2.0    # minimum gap between stylized and reverse stylized images
else:
    raise ValueError("Invalid model version")

# Optimizer / LR schedule
TRAINING_STYLE = "original"     # "original" or "plateau"
if TRAINING_STYLE == "plateau":
    LR_DECAY = 0.5
    LR_PATIENCE = 500
elif TRAINING_STYLE == "original":
    LR_DECAY = 0.05
    LR_PATIENCE = 1
else:
    # Fail here rather than with a NameError on LR_DECAY when the model is built.
    raise ValueError(f"Invalid training style: {TRAINING_STYLE!r}")
BETAS = (0.9, 0.999)  # for Adam optimizer

# Predictions: each model version writes to its own folder.
# MODEL_VERSION was validated above, so the lookup cannot miss.
RESULTS_PATH = {2: "predictions", 3: "predictions2"}[MODEL_VERSION]
os.makedirs(RESULTS_PATH, exist_ok=True)

In [8]:
# Pick the network implementation that matches MODEL_VERSION.
if MODEL_VERSION == 2:
    from model.network import StyTR2 as StyTR
elif MODEL_VERSION == 3:
    from model.network2 import StyTR3 as StyTR
else:
    raise ValueError("Invalid model version")

if TRAIN_MODEL:
    # Constructor arguments shared by StyTR2 and StyTR3.
    model_kwargs = dict(
        d_model=D_MODEL,
        img_height=IMG_HEIGHT,
        img_width=IMG_WIDTH,
        extraction_layers=EXTRACTION_LAYERS,
        n=N,
        s=S,
        patch_size_content=PATCH_SIZE,
        n_head_content=N_HEAD_CONTENT,
        dim_feedforward_content=DIM_FEEDFORWARD_CONTENT,
        dropout_content=DROPOUT_CONTENT,
        n_layers_content=N_LAYERS_CONTENT,
        patch_size_style=PATCH_SIZE,
        n_head_style=N_HEAD_STYLE,
        dim_feedforward_style=DIM_FEEDFORWARD_STYLE,
        dropout_style=DROPOUT_STYLE,
        n_layers_style=N_LAYERS_STYLE,
        n_head_dec=N_HEAD_DEC,
        dim_feedforward_dec=DIM_FEEDFORWARD_DEC,
        dropout_dec=DROPOUT_DEC,
        n_layers_dec=N_LAYERS_DEC,
        lr=LR,
        lambdas=LAMBDAS,
        lr_patience=LR_PATIENCE,
        lr_decay=LR_DECAY,
        betas=BETAS,
        results_path=RESULTS_PATH,
    )
    if MODEL_VERSION == 3:
        # StyTR3 additionally takes the separation-loss margin and the LR-schedule selector.
        model_kwargs.update(margin=MARGIN, training_style=TRAINING_STYLE)

    # Model Instance
    model = StyTR(**model_kwargs)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 226MB/s]


### Model Training

In [9]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Early Stopping and Checkpoints
MAX_EPOCHS = 30
EARLY_STOPPING_PATIENCE = 3    # EarlyStopping patience on val_loss
SAVE_TOP_K = 3                 # keep the SAVE_TOP_K lowest-val_loss checkpoints
CHECKPOINT_FILENAME = 'model-{epoch:02d}-{val_loss:.4f}'

# Trainer
ACCELERATOR = 'gpu'  # choose from "cpu", "gpu", "tpu", "hpu", "mps", "auto"
# More about the strategy parameter (docs for the pinned pytorch-lightning 2.x):
# https://lightning.ai/docs/pytorch/stable/extensions/strategy.html
STRATEGY = "auto"
DEVICES = 1  # number of accelerators for distributed training
PRECISION = '32-true'   # alternative: '16-mixed'
ACCUMULATE_GRAD_BATCHES = 1
GRADIENT_CLIP_VAL = None

In [None]:
if TRAIN_MODEL:
    # Callbacks: stop once val_loss stalls, and keep the SAVE_TOP_K
    # checkpoints with the lowest val_loss.
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=EARLY_STOPPING_PATIENCE,
    )
    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        filename=CHECKPOINT_FILENAME,
        save_top_k=SAVE_TOP_K,
        save_last=False,
    )

    trainer_kwargs = dict(
        max_epochs=MAX_EPOCHS,
        accelerator=ACCELERATOR,
        strategy=STRATEGY,
        devices=DEVICES,
        precision=PRECISION,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
        gradient_clip_val=GRADIENT_CLIP_VAL,
    )
    trainer = pl.Trainer(callbacks=[early_stopping, checkpoint], **trainer_kwargs)

    # Allow faster, reduced-precision float32 matmul kernels.
    torch.set_float32_matmul_precision('high')

    trainer.fit(model, datamodule=data_module)
    # Evaluate the best checkpoint recorded during fit and store the predictions.
    trainer.test(model=model, datamodule=data_module, ckpt_path="best")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type                      | Params
--------------------------------------------------------------
0 | content_patcher | ContentPatching           | 98.8 K
1 | style_patcher   | StylePatching             | 98.8 K
2 | cape            | CAPE                      | 262 K 
3 | content_encoder | ContentTransformerEncoder | 12.6 M
4 | style_encoder   | StyleTransformerEncoder   | 12.6 M
5 | decoder         | TransformerDecoder        | 16.8 M
6 | cnn_decoder     | CNNDecoder                | 3.5 M 
7 | vgg_extractor   | VGGF

Sanity Checking: 0it [00:00, ?it/s]

[0] c:2.1413 s:10.2983 i1:0.5915 i2:21.0071 sep:10.4107


Training: 0it [00:00, ?it/s]

[0] c:2.1416 s:13.4728 i1:0.6002 i2:24.2436 sep:7.5880
[step 0] grad_norm: 1299.5956


In [None]:
# Launch TensorBoard inline on the lightning_logs/ folder (where the version_N
# run directories and checkpoints are written).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

## Load Model from Checkpoint (runs only when TEST_MODEL is True)

In [None]:
CHECKPOINT_FILEPATH = ("lightning_logs/version_16/checkpoints/"  # run directory
                       "model-epoch=08-val_loss=43.1301.ckpt")     # checkpoint file used when TEST_MODEL is True

In [None]:
# Release Python garbage and cached CUDA memory left by earlier cells before
# (optionally) loading a checkpoint for evaluation.
gc.collect()
torch.cuda.empty_cache()

if TEST_MODEL:
    # Rebuild the StyTR class selected by MODEL_VERSION from the saved checkpoint.
    model = StyTR.load_from_checkpoint(CHECKPOINT_FILEPATH)

    # Evaluation-only trainer: logging disabled.
    eval_trainer_kwargs = dict(logger=False, accelerator=ACCELERATOR, devices=DEVICES)
    trainer = pl.Trainer(**eval_trainer_kwargs)

    # Same float32 matmul precision setting as the training cell.
    torch.set_float32_matmul_precision('high')
    trainer.test(model=model, datamodule=data_module)

### Visualizing Results

In [None]:
NUM_SAMPLES: int = 20  # number of random examples (grid rows) to visualize

In [None]:
# Memory-map the saved prediction arrays from RESULTS_PATH (presumably written
# by the model's test step via results_path). The grid below only reads them,
# so open read-only: mode "r" needs no write permission and cannot flush
# accidental in-memory edits back to the .npy files, unlike "r+".
content, style, stylized, reverse_stylized = (
    np.load(join(RESULTS_PATH, f"{name}.npy"), mmap_mode="r")
    for name in ("content", "style", "stylized", "reverse_stylized")
)

In [None]:
# Show random examples, one per row: content | style | stylized | reverse stylized.
# Sampling without replacement cannot draw more rows than there are examples.
num_rows = min(NUM_SAMPLES, content.shape[0])
assert num_rows > 0, f"no predictions found in {RESULTS_PATH!r}"
indices = np.random.choice(content.shape[0], size=num_rows, replace=False)

panels = (
    ("Content", content),
    ("Style", style),
    ("Stylized", stylized),
    ("Reverse Stylized", reverse_stylized),
)
# squeeze=False keeps `axes` 2-D even when num_rows == 1.
fig, axes = plt.subplots(
    nrows=num_rows,
    ncols=len(panels),
    figsize=(len(panels) * 3, num_rows * 3),
    squeeze=False,
)
for row, idx in enumerate(indices):
    for col, (title, images) in enumerate(panels):
        # Move axis 0 (assumed to be channels) last for imshow, and clip before
        # the uint8 cast so out-of-range values saturate instead of wrapping.
        img = np.clip(np.transpose(images[idx], axes=(1, 2, 0)), 0, 255).astype(np.uint8)
        ax = axes[row, col]
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
plt.tight_layout()
plt.show()