# Style Transfer with Vision Transformer
Original paper: [StyTr²: Image Style Transfer with Transformers](https://arxiv.org/abs/2105.14576)

In [1]:
# Standard imports
import os, glob, gc, random
from os.path import join
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch core
import torch

# PyTorch Lightning for easier training loops
import pytorch_lightning as pl  # High-level PyTorch interface
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint  # Useful training callbacks

# Progress bar
from tqdm.notebook import tqdm  # Pretty progress bars in notebooks

# Project-local modules
from data.data_module import StyleTransferDM
from model.network import StyTR2

### Data Module Setup

In [2]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Reproducibility: seed the Python, NumPy and PyTorch global RNGs before the data
# module and the model are created, so anything drawn from these generators is
# repeatable after a kernel restart.
# NOTE(review): whether the project code draws from these global generators is
# not visible in this notebook -- confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data dir (passed positionally to StyleTransferDM, style first, then content)
STYLE_PATH = "raw_data/style"
CONTENT_PATH = "raw_data/content"

# Setting up data
NUM_STYLES_PER_IMAGE = 100
TVT_SPLIT = [0.6, 0.3, 0.1]  # forwarded as train_val_test_split; presumably train/val/test order

# Dataloader
BATCH_SIZE = 8  # as per paper
NUM_WORKERS = 4
PREFETCH_FACTOR = 2
# Pinned host memory is a CUDA feature; the recorded run used the MPS backend,
# so the flag is only switched on when CUDA is actually available.
PIN_MEMORY = torch.cuda.is_available()

In [3]:
# Data Module Instance
# StyleTransferDM is imported in the notebook's import cell.
# The two positional arguments are the style and content image directories; the
# keyword arguments forward the split and DataLoader settings from the
# configuration cell above.
data_module = StyleTransferDM(
    STYLE_PATH,
    CONTENT_PATH,
    num_styles_per_image=NUM_STYLES_PER_IMAGE,
    train_val_test_split=TVT_SPLIT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    pin_memory=PIN_MEMORY,
)

### Model Setup

In [4]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Universal: embedding width and input resolution
D_MODEL = 512
IMG_HEIGHT, IMG_WIDTH = 224, 224

# Patching
N, S = 8, 1
PATCH_SIZE = 8  # used for both patch_size_content and patch_size_style

# Transformer stacks -- every line lists
# (attention heads, feed-forward width, dropout, number of layers), all as per GH.
# Content Encoder
N_HEAD_CONTENT, DIM_FEEDFORWARD_CONTENT, DROPOUT_CONTENT, N_LAYERS_CONTENT = 8, 2048, 0.1, 3
# Style Encoder
N_HEAD_STYLE, DIM_FEEDFORWARD_STYLE, DROPOUT_STYLE, N_LAYERS_STYLE = 8, 2048, 0.1, 3
# Decoder
N_HEAD_DEC, DIM_FEEDFORWARD_DEC, DROPOUT_DEC, N_LAYERS_DEC = 8, 2048, 0.1, 3

# Optimization
LR = 5e-4  # as per paper
LAMBDAS = [10, 7, 50, 1]  # as per paper
LR_PATIENCE, LR_DECAY = 5, 0.1
BETAS = (0.9, 0.999)  # for Adam optimizer

In [7]:
# Model Instance
# StyTR2 is imported in the notebook's import cell.
# Every argument comes from the model configuration cell above; PATCH_SIZE is
# deliberately passed twice so content and style use the same patch size.
# The saved output of this cell shows VGG19 weights being downloaded to the
# torch hub cache when the model was constructed.
model = StyTR2(
    # Universal
    d_model=D_MODEL,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    # Patching
    n=N,
    s=S,
    # Content encoder
    patch_size_content=PATCH_SIZE,
    n_head_content=N_HEAD_CONTENT,
    dim_feedforward_content=DIM_FEEDFORWARD_CONTENT,
    dropout_content=DROPOUT_CONTENT,
    n_layers_content=N_LAYERS_CONTENT,
    # Style encoder
    patch_size_style=PATCH_SIZE,
    n_head_style=N_HEAD_STYLE,
    dim_feedforward_style=DIM_FEEDFORWARD_STYLE,
    dropout_style=DROPOUT_STYLE,
    n_layers_style=N_LAYERS_STYLE,
    # Decoder
    n_head_dec=N_HEAD_DEC,
    dim_feedforward_dec=DIM_FEEDFORWARD_DEC,
    dropout_dec=DROPOUT_DEC,
    n_layers_dec=N_LAYERS_DEC,
    # Optimization
    lr=LR,
    lambdas=LAMBDAS,
    lr_patience=LR_PATIENCE,
    lr_decay=LR_DECAY,
    betas=BETAS,
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/chai/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|████████████████████████████████████████| 548M/548M [00:23<00:00, 24.2MB/s]


### Model Training

In [8]:
## UPDATE THE BELOW PARAMETERS AS NEEDED ##
# Stopping criteria
MAX_EPOCHS, EARLY_STOPPING_PATIENCE = 100, 10

# Checkpointing
SAVE_TOP_K = 3
CHECKPOINT_FILENAME = "model-{epoch:02d}-{val_loss:.4f}"

# Hardware
ACCELERATOR = "auto"  # choose from "cpu", "gpu", "tpu", "hpu", "mps", "auto"
DEVICES = 1  # number of accelerators for distributed training

In [10]:
# Both callbacks track the same metric and treat a lower value as better.
monitor_kwargs = dict(monitor='val_loss', mode='min')

# Early Stopping
early_stopping = EarlyStopping(patience=EARLY_STOPPING_PATIENCE, **monitor_kwargs)

# Model Checkpoints
checkpoint = ModelCheckpoint(
    save_top_k=SAVE_TOP_K,
    filename=CHECKPOINT_FILENAME,
    save_last=False,
    **monitor_kwargs,
)

# PyTorch Lightning Trainer
training_callbacks = [early_stopping, checkpoint]
trainer = pl.Trainer(
    devices=DEVICES,
    accelerator=ACCELERATOR,
    callbacks=training_callbacks,
    max_epochs=MAX_EPOCHS,
)

# Precision setting for matrix multiplications, applied right before training
torch.set_float32_matmul_precision('high')

# Train the model
# NOTE(review): the saved run of this cell ended with an
# "MPS backend out of memory" RuntimeError; the cause is not visible from this
# notebook -- investigate before relying on these settings on that machine.
trainer.fit(model, datamodule=data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name            | Type                      | Params | Mode 
----------------------------------------------------------------------
0 | content_patcher | ContentPatching           | 98.8 K | train
1 | style_patcher   | StylePatching             | 98.8 K | train
2 | cape            | CAPE                      | 262 K  | train
3 | content_encoder | ContentTransformerEncoder | 22.1 M | train
4 | style_encoder   | StyleTransformerEncoder   | 22.1 M | train
5 | decoder         | TransformerDecoder        | 29.4 M | train
6 | cnn_decoder     | CNNDecoder                | 4.7 M  | train
7 | vgg_extractor   | VGGFeatureExtractor       | 20.0 M | train
----------------------------------------------------------------------
78.8 M    Trainable params
20.0 M    Non-trainable params
98.8 M    Total params
395.119   Total estimated model params size (MB)
306       Modules in train 

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 43.26 GB, other allocations: 2.36 GB, max allowed: 45.90 GB). Tried to allocate 600.25 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).