In [None]:
# Make Google Drive visible to this Colab runtime under /content/drive.
from google.colab import drive

drive.mount("/content/drive")

# Base paths
# Dataset root on Drive; the content, style and output directories used
# further down are all built from it.
GDRIVE_BASE = "/content/drive/MyDrive/NST Dataset/filter_preproc"

Mounted at /content/drive


In [None]:
!pip install torch torchvision pillow tqdm



In [None]:
import os
import zlib

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

In [None]:
def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation of a feature map.

    Args:
        feat: tensor shaped (N, C, *spatial) with at least one spatial dim.
        eps: added to the variance before the square root, so the std that
            `adain` divides by is never zero.

    Returns:
        (mean, std), each shaped (N, C, 1, ..., 1) with the same number of
        dims as `feat`, so both broadcast directly against it.

    Raises:
        ValueError: if `feat` has fewer than 3 dims.
    """
    if feat.dim() < 3:
        raise ValueError(
            f"expected a (N, C, *spatial) tensor, got shape {tuple(feat.shape)}"
        )
    N, C = feat.shape[:2]
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted tensors.
    flat = feat.reshape(N, C, -1)
    stat_shape = (N, C) + (1,) * (feat.dim() - 2)
    # torch's default unbiased variance, as in the original code.
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(stat_shape)
    feat_mean = flat.mean(dim=2).view(stat_shape)
    return feat_mean, feat_std

def adain(content_feat, style_feat):
    """Adaptive Instance Normalization (Huang & Belongie, 2017).

    Normalises each channel of `content_feat` to zero mean and (almost) unit
    std, then rescales it to the per-channel mean/std of `style_feat`.
    Spatial sizes may differ. Channel counts must match; batch dims follow
    torch broadcasting (e.g. a style batch of 1 applies to every content item).
    """
    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean


In [6]:
# Earlier approach, kept for reference only (this whole cell is commented out):
# VGG encoder from torchvision, truncated at relu4_1
# from torchvision.models import vgg19

# vgg = vgg19(pretrained=False)
# vgg.load_state_dict(torch.load("vgg_normalised.pth"))
# encoder = nn.Sequential(*list(vgg.features.children())[:21])

# decoder = torch.load("decoder.pth")

# device = "cuda" if torch.cuda.is_available() else "cpu"
# encoder = encoder.to(device).eval()
# decoder = decoder.to(device).eval()

In [7]:
def build_vgg_encoder():
    """Build a VGG-19-style convolutional encoder that ends at relu4_1.

    Layout: conv1_1..relu1_2, pool, conv2_1..relu2_2, pool, conv3_1..relu3_4,
    pool, conv4_1, relu4_1. That is 9 3x3 convs with zero padding 1 and
    three 2x2/stride-2 max-pools, so the output has 512 channels at 1/8 of
    the input resolution.

    NOTE(review): unlike the `vgg` network defined further down (leading 1x1
    conv, reflection padding, ceil_mode pooling), this layout has different
    parameter shapes at the same indices, so vgg_normalised.pth will not load
    into it as-is.
    """
    layers = []
    in_channels = 3
    # Integers are conv output channels; 'M' is a 2x2 max-pool.
    # AdaIN only needs relu4_1, so only the first conv of block 4 is kept.
    cfg = [64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 256, 'M',
           512]  # stop at relu4_1

    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv, nn.ReLU(inplace=True)]
            in_channels = v

    return nn.Sequential(*layers)


In [9]:
# encoder = build_vgg_encoder()
# state_dict = torch.load("vgg_normalised.pth")

# encoder.load_state_dict(state_dict)
# encoder.eval()

In [15]:
# AdaIN decoder: maps 512-channel relu4_1 features back to a 3-channel image.
# Every 3x3 conv is preceded by a 1-pixel reflection pad, so the convs keep the
# spatial size. The three nearest-neighbour 2x upsamples undo the encoder's
# three 2x max-pools.
# Do not reorder, insert or remove layers: the decoder.pth state_dict is
# matched to these Sequential indices (keys like "1.weight").
# There is no output activation; stylize() clamps the result to [0, 1].
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # 1/8 -> 1/4 of image size
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # 1/4 -> 1/2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # 1/2 -> full resolution
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),  # 3 output channels (RGB), no activation
)

# Full VGG-19 feature stack (through relu5_4) in the layer layout expected by
# vgg_normalised.pth. Only the first 31 children (through relu4_1) are later
# sliced off as the encoder.
# Do not reorder, insert or remove layers: load_state_dict matches parameters
# by Sequential index (keys like "0.weight", "2.weight", ...).
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),  # NOTE(review): presumably folds the input normalisation of the "normalised" weights into a 1x1 conv — confirm
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),  # ceil_mode: odd sizes round up
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1 (child index 30), this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


In [16]:
# map_location="cpu": the networks are moved to `device` in a later cell, so load
# onto the CPU regardless of the device the checkpoint was saved from.
# weights_only=True refuses arbitrary pickled objects; a tensor state_dict is all
# this file needs to contain.
state_dict = torch.load("vgg_normalised.pth", map_location="cpu", weights_only=True)
vgg.load_state_dict(state_dict)

<All keys matched successfully>

In [17]:
# Children 0-30 of `vgg` end at relu4_1, the only layer AdaIN needs. Slicing an
# nn.Sequential returns a new Sequential that shares (not copies) those modules,
# so the weights loaded into `vgg` above are reused as-is.
encoder = vgg[:31]
encoder.eval()

Sequential(
  (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (1): ReflectionPad2d((1, 1, 1, 1))
  (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): ReflectionPad2d((1, 1, 1, 1))
  (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
  (8): ReflectionPad2d((1, 1, 1, 1))
  (9): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (10): ReLU()
  (11): ReflectionPad2d((1, 1, 1, 1))
  (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (13): ReLU()
  (14): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
  (15): ReflectionPad2d((1, 1, 1, 1))
  (16): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (17): ReLU()
  (18): ReflectionPad2d((1, 1, 1, 1))
  (19): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (20): ReLU()
  (21): ReflectionPad2d((1, 1, 1, 1))
  (22): Conv2d(256, 256, kernel_size=(

In [18]:
# Load onto the CPU (the decoder is moved to `device` later). weights_only=True
# restricts unpickling to tensors and plain containers, which is all a
# state_dict needs.
decoder.load_state_dict(torch.load("decoder.pth", map_location="cpu", weights_only=True))
decoder.eval()

Sequential(
  (0): ReflectionPad2d((1, 1, 1, 1))
  (1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
  (2): ReLU()
  (3): Upsample(scale_factor=2.0, mode='nearest')
  (4): ReflectionPad2d((1, 1, 1, 1))
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): ReflectionPad2d((1, 1, 1, 1))
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (9): ReLU()
  (10): ReflectionPad2d((1, 1, 1, 1))
  (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (12): ReLU()
  (13): ReflectionPad2d((1, 1, 1, 1))
  (14): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
  (15): ReLU()
  (16): Upsample(scale_factor=2.0, mode='nearest')
  (17): ReflectionPad2d((1, 1, 1, 1))
  (18): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (19): ReLU()
  (20): ReflectionPad2d((1, 1, 1, 1))
  (21): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
  (22): ReLU()
  (23): Upsample(scale_factor=2.0, mode='nearest')
  (24): ReflectionPad2d((1, 1, 1, 1))
  (25): Conv2d(64

### Sanity check

In [19]:
# Random input is enough to check shapes. Create it on whatever device the
# encoder currently lives on, so this cell still works when re-run after the
# networks have been moved to the GPU.
x = torch.randn(1, 3, 256, 256, device=next(encoder.parameters()).device)

with torch.no_grad():
    f = encoder(x)
    y = decoder(f)

print(f.shape)  # should be [1, 512, 32, 32]
print(y.shape)  # should be [1, 3, 256, 256]

# relu4_1 comes after three 2x max-pools (256 / 8 = 32); the decoder's three 2x
# upsamples bring the output back to the input resolution.
assert f.shape == (1, 512, 32, 32), f.shape
assert y.shape == (1, 3, 256, 256), y.shape


torch.Size([1, 512, 32, 32])
torch.Size([1, 3, 256, 256])


In [20]:
# Prefer the GPU when present; both networks are used for inference only.
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder, decoder = [net.to(device).eval() for net in (encoder, decoder)]

### Image loading and preprocessing

In [21]:
transform = transforms.Compose([
    transforms.Resize(512),      # shorter side -> 512 px, aspect ratio kept
    transforms.CenterCrop(512),  # central 512x512 square
    transforms.ToTensor()        # float C x H x W in [0, 1]
])

def load_image(path):
    """Load the image at `path` as a (1, 3, 512, 512) float tensor on `device`.

    The image is converted to RGB, resized so its shorter side is 512 px,
    center-cropped to 512x512 and scaled to [0, 1]. A batch dim is prepended.
    """
    # The context manager closes the underlying file deterministically instead
    # of leaving that to Pillow or the garbage collector.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0).to(device)


In [22]:
@torch.no_grad()
def stylize(content_path, style_path, alpha=1.0):
    """Render the content image at `content_path` in the style of `style_path`.

    Both images pass through `load_image` and the relu4_1 `encoder`. AdaIN
    re-normalises the content features to the style statistics, and `alpha`
    linearly blends those stylised features with the raw content features
    (1.0 = full style, 0.0 = content features only) before decoding.

    Returns the decoded image tensor, clamped to [0, 1].
    """
    images = [load_image(p) for p in (content_path, style_path)]
    content_feat, style_feat = (encoder(img) for img in images)

    stylised = adain(content_feat, style_feat)
    mixed = alpha * stylised + (1 - alpha) * content_feat
    return decoder(mixed).clamp(0, 1)

In [24]:
def _normalise_ext(ext):
    """Lower-case `ext` and add a leading dot if it is missing ('' stays '')."""
    ext = ext.lower()
    return ext if not ext or ext.startswith(".") else "." + ext

def collect_image_files(root, exts=frozenset({".jpg", ".jpeg", ".png"})):
    """Recursively list files under `root` whose extension is in `exts`.

    Args:
        root: directory to walk. A missing directory yields an empty list,
            because os.walk ignores errors by default.
        exts: a single extension string or an iterable of them. Matching is
            case-insensitive and the leading dot is optional (".PNG" == "png").
            The default is immutable, so no call can alter it for later calls.

    Returns:
        Sorted list of full paths (`root` joined with each file's relative path).
    """
    if isinstance(exts, str):
        # A bare string used to be matched as a substring; treat it as one extension.
        exts = (exts,)
    wanted = {_normalise_ext(e) for e in exts}
    matches = [
        os.path.join(dirpath, fn)
        for dirpath, _, filenames in os.walk(root)
        for fn in filenames
        if os.path.splitext(fn)[1].lower() in wanted
    ]
    return sorted(matches)


In [25]:
content_dir = GDRIVE_BASE + "/preproc/images_jpeg"
style_dir = GDRIVE_BASE + "/styles"
out_dir = GDRIVE_BASE + "/stylized"

os.makedirs(out_dir, exist_ok=True)

NUM_CONTENT_IMAGES = 20    # first N content images (sorted by name) to stylize
NUM_STYLES_PER_IMAGE = 2   # configurable

content_files = sorted(
    f for f in os.listdir(content_dir)
    if f.lower().endswith((".jpg", ".png", ".jpeg"))
)[:NUM_CONTENT_IMAGES]

style_files = collect_image_files(style_dir)
if not style_files:
    # Otherwise the modulo below fails with an opaque ZeroDivisionError.
    raise FileNotFoundError(f"No style images found under {style_dir!r}")

to_pil = transforms.ToPILImage()

for cf in tqdm(content_files, desc="AdaIN stylization"):
    content_path = os.path.join(content_dir, cf)
    # Built-in hash() of a str is salted per interpreter session (PYTHONHASHSEED),
    # so it picked different styles on every run. crc32 is stable across runs.
    base_idx = zlib.crc32(cf.encode("utf-8"))
    # splitext handles .png/.jpeg/.JPG names too; cf.replace(".jpg", ...) left
    # those unchanged, so every k overwrote the same output file.
    stem = os.path.splitext(cf)[0]

    for k in range(NUM_STYLES_PER_IMAGE):
        style_path = style_files[(base_idx + k) % len(style_files)]

        out = stylize(content_path, style_path, alpha=1.0)

        out_path = os.path.join(out_dir, f"{stem}_s{k}.jpg")
        to_pil(out.squeeze(0)).save(out_path)


AdaIN stylization: 100%|██████████| 20/20 [05:50<00:00, 17.52s/it]
