
# 🎨 Neural Style Transfer with VGG19 — *Interactive Notebook* (PyTorch, CUDA/MPS)



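> Optional installation — only needed for the interactive sliders. Uncomment and run if `ipywidgets` is missing; the `nbextension` line applies only to older classic-Notebook setups: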
```bash
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension
```


In [None]:

import os, math, time, numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# ipywidgets is optional. If it cannot be imported for any reason, `widgets`
# is set to None so later cells can detect that and skip the interactive UI.
try:
    import ipywidgets as widgets
except Exception:
    widgets = None

def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    # Older torch builds have no `torch.backends.mps`, hence the guarded lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        kind = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        kind = "mps"
    else:
        kind = "cpu"
    return torch.device(kind)

# Pick the compute device once for the whole notebook and report which
# backend was selected.
DEVICE = get_device()
print("Using device:", DEVICE)
if DEVICE.type == "cuda":
    print("CUDA:", torch.cuda.get_device_name(0))
elif DEVICE.type == "mps":
    print("Apple Metal (MPS) backend active")

# Autograd is switched off globally here; run_style_transfer turns it back on
# for the duration of the optimisation.
torch.set_grad_enabled(False)


In [None]:

# Paths to the content and style images. When a path is None (or does not
# exist on disk) load_pil substitutes a synthetic placeholder image.
CONTENT_PATH = None
STYLE_PATH   = None

def make_placeholder(size=512):
    """Build a synthetic square RGB test image of `size` x `size` pixels.

    The background is a smooth sinusoidal colour field; an orange rectangle
    and a light-blue disc are painted on top at fixed pixel coordinates.

    Returns:
        A PIL image created from a uint8 array of shape (size, size, 3).
    """
    ramp = np.linspace(0, 1, size)
    grid_x, grid_y = np.meshgrid(ramp, ramp)

    # One sinusoid per colour channel, each rescaled from [-1, 1] to [0, 1].
    red = (np.sin(2*np.pi*grid_x*3) + 1) / 2
    green = (np.cos(2*np.pi*grid_y*3) + 1) / 2
    blue = (np.sin(2*np.pi*(grid_x + grid_y)*2) + 1) / 2
    canvas = np.dstack((red, green, blue))

    # Orange rectangle (rows 40-109, columns 40-159).
    canvas[40:110, 40:160] = (1.0, 0.6, 0.2)

    # Light-blue disc of radius 35 centred at row 180, column 190.
    rows = np.arange(size).reshape(-1, 1)
    cols = np.arange(size).reshape(1, -1)
    inside_disc = (rows - 180) ** 2 + (cols - 190) ** 2 <= 35 ** 2
    canvas[inside_disc] = (0.2, 0.7, 1.0)

    pixels = (canvas * 255).astype(np.uint8)
    return Image.fromarray(pixels)

def load_pil(path, fallback_size=512):
    """Load `path` as an RGB PIL image, or return a synthetic placeholder.

    Args:
        path: Filesystem path of the image, or None / empty.
        fallback_size: Edge length passed to make_placeholder when `path`
            is falsy or does not exist.

    Returns:
        An RGB PIL image.
    """
    if path and os.path.exists(path):
        # convert() decodes the pixels and returns an independent image, so
        # the file handle can be released as soon as the `with` block exits
        # instead of lingering until garbage collection.
        with Image.open(path) as img:
            return img.convert("RGB")
    return make_placeholder(fallback_size)

# Load both source images (placeholders when no path is configured).
content_disp = load_pil(CONTENT_PATH, 512)
style_disp   = load_pil(STYLE_PATH,   512)

# Side-by-side preview of the two inputs.
fig, ax = plt.subplots(1, 2, figsize=(8,4))
for panel, picture, caption in zip(ax, (content_disp, style_disp), ("Content", "Style")):
    panel.imshow(picture)
    panel.set_title(caption)
    panel.axis("off")
plt.show()


In [None]:

# Standard ImageNet channel statistics, used whenever the weights' metadata
# does not provide its own "mean"/"std" entries.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

try:
    weights = models.VGG19_Weights.DEFAULT
    vgg_features = models.vgg19(weights=weights).features
    # .get() rather than []: a missing metadata key must not be mistaken for a
    # failed weight load, which would silently replace the pretrained network
    # with a randomly initialised one in the except branch below.
    _mean = weights.meta.get("mean", _IMAGENET_MEAN)
    _std = weights.meta.get("std", _IMAGENET_STD)
    print("Loaded VGG19 pretrained on ImageNet.")
except Exception as e:
    # Best-effort fallback (e.g. the weights could not be downloaded).
    print("WARNING: Using random init; results will be poor.\n", e)
    vgg_features = models.vgg19(weights=None).features
    _mean, _std = _IMAGENET_MEAN, _IMAGENET_STD

# The network is a fixed feature extractor: eval mode, moved to the target
# device, and with its weights frozen so backward passes do not compute or
# store gradients for them.
vgg_features = vgg_features.eval().to(DEVICE).requires_grad_(False)

# Normalisation constants shaped (1, 3, 1, 1) to broadcast over image batches.
imagenet_mean = torch.tensor(_mean).view(1,3,1,1).to(DEVICE)
imagenet_std  = torch.tensor(_std).view(1,3,1,1).to(DEVICE)

def make_preprocess(imsize):
    """Build the PIL -> normalised tensor pipeline for a square `imsize` image.

    Steps: bicubic resize, centre crop to `imsize`, conversion to a tensor,
    then normalisation with the module-level ImageNet mean/std.
    """
    channel_mean = imagenet_mean.flatten().tolist()
    channel_std = imagenet_std.flatten().tolist()
    steps = [
        transforms.Resize(imsize, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(imsize),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def pil_to_tensor(img, imsize):
    """Preprocess a PIL image and return it as a batch of one on DEVICE."""
    pipeline = make_preprocess(imsize)
    single = pipeline(img)
    # Add the leading batch dimension before moving to the compute device.
    batch = single[None]
    return batch.to(DEVICE)

def tensor_to_pil(t):
    """Undo the ImageNet normalisation of `t` and return it as a PIL image.

    The tensor is detached and copied to the CPU, de-normalised, clamped to
    [0, 1], and its leading batch dimension is squeezed away before conversion.
    """
    std_cpu = imagenet_std.cpu()
    mean_cpu = imagenet_mean.cpu()
    pixels = t.detach().cpu().clone() * std_cpu + mean_cpu
    pixels = pixels.clamp(0, 1).squeeze(0)
    to_image = transforms.ToPILImage()
    return to_image(pixels)


In [None]:

# Indices into the `vgg_features` Sequential at which activations are captured
# for the content loss and the style loss respectively.
CONTENT_LAYERS = [21]                         # conv4_2
STYLE_LAYERS   = [0, 5, 10, 19, 28]           # conv1_1..conv5_1
# Per-layer multipliers for the style loss, paired positionally with STYLE_LAYERS.
STYLE_WEIGHTS_PER_LAYER = [1, 1, 1, 1, 1]

def forward_features(x, layers_to_capture, model=None):
    """Run `x` through a sequential model and collect selected activations.

    Args:
        x: Input tensor fed to the first layer.
        layers_to_capture: Collection of integer layer indices whose outputs
            should be returned.
        model: Iterable of layers applied in order. Defaults to the
            module-level `vgg_features`.

    Returns:
        Dict mapping each captured layer index to that layer's output.
        An empty `layers_to_capture` yields an empty dict.
    """
    if model is None:
        model = vgg_features

    wanted = set(layers_to_capture)
    feats = {}
    if not wanted:
        return feats
    # Hoisted out of the loop: nothing past the deepest requested layer is needed.
    last = max(wanted)

    out = x
    for i, layer in enumerate(model):
        if isinstance(layer, nn.ReLU):
            # Apply ReLU out of place. An in-place ReLU would overwrite the
            # tensor captured from the preceding layer, so the stored
            # activation would silently become the post-ReLU one.
            out = F.relu(out)
        else:
            out = layer(out)
        if i in wanted:
            feats[i] = out
        if i >= last:
            break
    return feats

def gram_matrix(x):
    """Return the normalised Gram matrix of a batch of feature maps.

    Args:
        x: Tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, C, C): channel-by-channel inner products over all
        spatial positions, divided by C*H*W.
    """
    B, C, H, W = x.shape
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. transposed or channels-last tensors); reshape() copies only then.
    y = x.reshape(B, C, H*W)
    G = y @ y.transpose(1, 2)
    return G / (C*H*W)

def total_variation(x):
    """Return the total-variation penalty of a 4-D tensor.

    This is the mean absolute difference between vertically adjacent entries
    (dim 2) plus the mean absolute difference between horizontally adjacent
    entries (dim 3).
    """
    vertical_step = x[:, :, 1:, :] - x[:, :, :-1, :]
    horizontal_step = x[:, :, :, 1:] - x[:, :, :, :-1]
    vertical_term = vertical_step.abs().mean()
    horizontal_term = horizontal_step.abs().mean()
    return vertical_term + horizontal_term


In [None]:

def run_style_transfer(content_img, style_img, imsize=512, steps=400, lr=0.03,
                       content_weight=1.0, style_weight=1e3, tv_weight=1e-3,
                       init_from="content", preview_every=50, out_path="stylized.png"):
    """Optimise an image to match the content of one image and the style of another.

    The generated image is updated with Adam to minimise a weighted sum of
    the content loss (feature MSE at CONTENT_LAYERS), the style loss
    (Gram-matrix MSE at STYLE_LAYERS, weighted by STYLE_WEIGHTS_PER_LAYER)
    and a total-variation penalty.

    Args:
        content_img: Preprocessed content tensor (as produced by pil_to_tensor).
        style_img: Preprocessed style tensor (as produced by pil_to_tensor).
        imsize: Not used by the optimisation; kept so existing callers that
            pass it keep working. The inputs are used at whatever size they have.
        steps: Number of optimiser steps.
        lr: Adam learning rate.
        content_weight: Multiplier of the content loss.
        style_weight: Multiplier of the style loss.
        tv_weight: Multiplier of the total-variation penalty.
        init_from: "content" or "style" to start from a copy of that image;
            any other value starts from the content image blended with noise.
        preview_every: Show a preview every this many steps (the first and
            last step are always shown). 0 or None disables periodic previews.
        out_path: File path the final image is saved to.

    Returns:
        The final stylised image as a PIL image.
    """
    # A single capture list shared by every forward pass: one VGG pass per step
    # yields both content and style features. Targets use the same list so
    # target and generated activations are always extracted identically
    # (torchvision's VGG contains in-place ReLUs, so a captured tensor can
    # depend on which later layers were executed).
    capture = sorted(set(CONTENT_LAYERS) | set(STYLE_LAYERS))

    with torch.no_grad():
        content_feats = forward_features(content_img, capture)
        style_feats = forward_features(style_img, capture)
        tgt_c = {i: content_feats[i] for i in CONTENT_LAYERS}
        tgt_s_grams = {i: gram_matrix(style_feats[i]) for i in STYLE_LAYERS}
    # Drop the activations that are not targets before optimisation starts.
    del content_feats, style_feats

    if init_from == "content":
        generated = content_img.clone()
    elif init_from == "style":
        generated = style_img.clone()
    else:
        generated = torch.randn_like(content_img) * 0.1 + content_img * 0.9

    # detach() guarantees a leaf tensor, which requires_grad_ and Adam need.
    generated = generated.detach().requires_grad_(True)
    opt = torch.optim.Adam([generated], lr=lr)

    # Context manager instead of toggling the global switch by hand: the
    # caller's grad mode is restored even if the run raises or is interrupted.
    with torch.enable_grad():
        for step in range(1, steps+1):
            opt.zero_grad()
            feats = forward_features(generated, capture)

            c_loss = sum(F.mse_loss(feats[i], tgt_c[i]) for i in CONTENT_LAYERS)
            s_loss = 0.0
            for w, i in zip(STYLE_WEIGHTS_PER_LAYER, STYLE_LAYERS):
                s_loss = s_loss + w * F.mse_loss(gram_matrix(feats[i]), tgt_s_grams[i])
            tv = total_variation(generated)

            loss = content_weight*c_loss + style_weight*s_loss + tv_weight*tv
            # Only the image is being optimised; restricting the backward pass
            # to it avoids computing gradients for the network weights.
            loss.backward(inputs=[generated])
            opt.step()

            if DEVICE.type == "mps":
                torch.mps.synchronize()

            # `preview_every` of 0/None would otherwise raise on the modulo.
            periodic = bool(preview_every) and step % preview_every == 0
            if periodic or step in (1, steps):
                clear_output(wait=True)
                # float() handles tensors as well as the plain numbers the
                # losses degrade to when a layer list is empty.
                print(f"Step {step}/{steps} | content {float(c_loss):.4f} | style {float(s_loss):.4f} | tv {float(tv):.4f}")
                display(tensor_to_pil(generated))

    out_pil = tensor_to_pil(generated)
    out_pil.save(out_path)
    print("Saved:", out_path)
    return out_pil


In [None]:

# Initial tensors
IMSIZE = 512
content_t = pil_to_tensor(content_disp, IMSIZE)
style_t   = pil_to_tensor(style_disp,   IMSIZE)

# Triptych. The panels are rendered from the preprocessed tensors so they show
# the exact resized and centre-cropped inputs the optimiser receives; a plain
# .resize((IMSIZE, IMSIZE)) would stretch non-square images instead.
content_view = tensor_to_pil(content_t)
style_view   = tensor_to_pil(style_t)
fig, ax = plt.subplots(1,3, figsize=(12,4))
ax[0].imshow(content_view); ax[0].set_title("Content"); ax[0].axis("off")
ax[1].imshow(style_view);   ax[1].set_title("Style");   ax[1].axis("off")
ax[2].imshow(content_view); ax[2].set_title("Init (content)"); ax[2].axis("off")
plt.show()

if widgets is not None:
    # Controls for the loss weights, image size, optimiser and output file.
    content_w = widgets.FloatSlider(value=1.0, min=0.1, max=20.0, step=0.1, description="Content")
    style_w   = widgets.FloatLogSlider(value=1e3, base=10, min=2, max=5, step=0.1, description="Style")
    tv_w      = widgets.FloatLogSlider(value=1e-3, base=10, min=-5, max=-1, step=0.1, description="TV")
    imsize_w  = widgets.IntSlider(value=IMSIZE, min=256, max=768, step=64, description="Size")
    steps_w   = widgets.IntSlider(value=400, min=100, max=1000, step=50, description="Steps")
    # Three decimals in the readout so the 0.005 step size is actually visible.
    lr_w      = widgets.FloatSlider(value=0.03, min=0.005, max=0.1, step=0.005,
                                    readout_format=".3f", description="LR")
    init_w    = widgets.Dropdown(options=["content","style","noise"], value="content", description="Init")
    out_path  = widgets.Text(value="stylized_output.png", description="Save as")
    run_b     = widgets.Button(description="Run style transfer", button_style="primary")
    # Dedicated output area: progress text and previews produced inside the
    # button callback are captured here, and clear_output() only clears this
    # area rather than the controls.
    out_area  = widgets.Output()

    ui = widgets.VBox([
        widgets.HBox([content_w, style_w, tv_w]),
        widgets.HBox([imsize_w, steps_w, lr_w, init_w]),
        out_path,
        run_b,
        out_area,
    ])
    display(ui)

    def _on_click(_):
        """Run style transfer with the current widget values."""
        # Disable the button while a run is in progress so repeated clicks do
        # not queue up additional runs.
        run_b.disabled = True
        try:
            with out_area:
                c_t = pil_to_tensor(content_disp, imsize_w.value)
                s_t = pil_to_tensor(style_disp,   imsize_w.value)
                run_style_transfer(
                    c_t, s_t,
                    imsize=imsize_w.value,
                    steps=steps_w.value,
                    lr=lr_w.value,
                    content_weight=content_w.value,
                    style_weight=style_w.value,
                    tv_weight=tv_w.value,
                    init_from=init_w.value,
                    preview_every=50,
                    out_path=out_path.value
                )
        finally:
            run_b.disabled = False

    run_b.on_click(_on_click)
else:
    print("ipywidgets not available; call run_style_transfer(content_t, style_t, ...) manually.")
