In [None]:
import os
import glob
import random
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed Python's and PyTorch's RNGs so decoder initialisation
# and DataLoader shuffling are repeatable across "Restart & Run All".
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Paths
content_dir = "dataset/content"
style_dir = "dataset/style"
save_decoder_path = "decoder.pth"

# Hyperparameters
batch_size = 8
lr = 1e-4
num_epochs = 10
content_weight = 1.0  # multiplier on the content loss term
style_weight = 10.0   # multiplier on the style loss term

# Side length (pixels) that every image is resized to before encoding
image_size = 256


In [None]:
# Per-channel normalization statistics, kept on `device` and shaped
# (1, 3, 1, 1) so they broadcast over a batched (N, 3, H, W) tensor.
vgg_mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
vgg_std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)

def vgg_preprocess(img):
    """Resize an image, convert it to a tensor and normalize its channels.

    The pipeline is: resize to (image_size, image_size), `ToTensor`, then
    per-channel `Normalize` with the mean/std listed below.

    NOTE(review): `ToTensor` is normally fed a PIL image (or ndarray); passing
    an existing tensor is presumably not supported here — confirm before
    relying on it.

    Args:
        img: the image to preprocess.

    Returns:
        The transformed image as produced by the composed torchvision pipeline.
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img)

def deprocess(tensor):
    """Undo the channel normalization applied by preprocessing and clamp to [0, 1].

    The statistics are created on the input's own device and shaped (3, 1, 1),
    so the function works for both a single image (3, H, W) and a batch
    (N, 3, H, W), on CPU or GPU, and the output keeps the input's shape.
    (Module-level stats pinned to the training device and shaped (1, 3, 1, 1)
    would fail on a CPU tensor when training on CUDA, and would silently turn
    a 3-D image into a 4-D tensor.)

    Args:
        tensor: normalized image tensor with the channel axis third from last.

    Returns:
        A tensor of the same shape and device as `tensor`, with values in [0, 1].
    """
    # Integer inputs cannot hold the fractional statistics; fall back to float32.
    stats_dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    mean = torch.tensor([0.485, 0.456, 0.406],
                        dtype=stats_dtype, device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225],
                       dtype=stats_dtype, device=tensor.device).view(3, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)


In [None]:
class ImageDataset(Dataset):
    """Dataset yielding (content, style) image pairs from two directories.

    Item `idx` pairs the idx-th content file with the idx-th style file, both
    taken from sorted listings so the pairing is deterministic across runs.
    The dataset length is the size of the smaller directory.

    Args:
        content_dir: directory containing content images.
        style_dir: directory containing style images.
        transform: optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: if either directory contains no files.
    """

    def __init__(self, content_dir, style_dir, transform=None):
        self.content_images = self._list_files(content_dir)
        self.style_images = self._list_files(style_dir)
        # Fail early with a clear message instead of a confusing sampler error later.
        if not self.content_images:
            raise ValueError(f"No files found in content_dir: {content_dir!r}")
        if not self.style_images:
            raise ValueError(f"No files found in style_dir: {style_dir!r}")
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted regular files directly inside `directory`."""
        # glob order is filesystem-dependent and "*" also matches sub-directories.
        candidates = glob.glob(os.path.join(directory, "*"))
        return sorted(path for path in candidates if os.path.isfile(path))

    @staticmethod
    def _load_rgb(path):
        """Load `path` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a fully loaded copy, safe to use after the file closes.
            return img.convert("RGB")

    def __len__(self):
        return min(len(self.content_images), len(self.style_images))

    def __getitem__(self, idx):
        # The modulo keeps out-of-range indices wrapping around each list.
        content_path = self.content_images[idx % len(self.content_images)]
        style_path = self.style_images[idx % len(self.style_images)]

        content = self._load_rgb(content_path)
        style = self._load_rgb(style_path)

        if self.transform:
            content = self.transform(content)
            style = self.transform(style)

        return content, style

# Pass the preprocessing function directly: a lambda wrapper adds nothing and
# cannot be pickled if DataLoader worker processes are enabled later.
dataset = ImageDataset(content_dir, style_dir, transform=vgg_preprocess)
# drop_last=True keeps every batch at exactly `batch_size` items.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


In [None]:
def mean_variance_norm(feat, eps=1e-5):
    """Return the per-channel spatial mean and standard deviation of `feat`.

    Both statistics are reduced over dimensions 2 and 3 with `keepdim=True`.
    `eps` is added to the variance before the square root.

    Returns:
        (mean, std) tuple of tensors.
    """
    spatial_dims = [2, 3]
    mu = feat.mean(dim=spatial_dims, keepdim=True)
    sigma = (feat.var(dim=spatial_dims, keepdim=True) + eps).sqrt()
    return mu, sigma

def adain(content_feat, style_feat, eps=1e-5):
    """Adaptive instance normalization of content features with style statistics.

    The content features are shifted and scaled by their own per-channel
    mean/std (from `mean_variance_norm`), then rescaled by the style std and
    shifted by the style mean.
    """
    content_mean, content_std = mean_variance_norm(content_feat, eps)
    style_mean, style_std = mean_variance_norm(style_feat, eps)
    whitened = (content_feat - content_mean) / content_std
    return whitened * style_std + style_mean

# Encoder: VGG19 up to relu4_1
def vgg_encoder():
    """Build a frozen VGG19 feature extractor that ends at relu4_1.

    NOTE: a later cell of this notebook redefines `vgg_encoder`; the later
    definition is the one in effect after Run All, so keep both in sync.

    Returns:
        nn.Sequential holding layers 0..20 of torchvision's VGG19 `features`
        (keys "0".."20"), with `requires_grad=False` on every parameter and
        the module in eval mode.
    """
    # `pretrained=True` is deprecated in newer torchvision; prefer the
    # `weights=` API when it exists and fall back otherwise.
    if hasattr(models, "VGG19_Weights"):
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
    else:
        vgg = models.vgg19(pretrained=True).features

    # Freeze weights: the encoder is only used as a fixed feature extractor.
    for param in vgg.parameters():
        param.requires_grad = False

    # torchvision VGG19 `features` layout: conv4_1 is index 19 and relu4_1 is
    # index 20. Slicing to 23 would run on through conv4_2/relu4_2.
    encoder = vgg[:21]
    return encoder.eval()

# Decoder (mirroring VGG up to relu4_1)
class Decoder(nn.Module):
    """Convolutional decoder mapping 512-channel features to a 3-channel image.

    The network is a stack of 3x3 convolutions, each preceded by reflection
    padding of 1 and (except the last) followed by an in-place ReLU. Three
    nearest-neighbour 2x upsampling layers enlarge the spatial size by 8x.
    """

    # One row per convolution: (upsample_first, in_channels, out_channels, relu_after)
    _STAGES = (
        (False, 512, 512, True),
        (True, 512, 256, True),
        (False, 256, 256, True),
        (False, 256, 256, True),
        (False, 256, 256, True),
        (False, 256, 128, True),
        (True, 128, 128, True),
        (False, 128, 64, True),
        (True, 64, 64, True),
        (False, 64, 3, False),
    )

    def __init__(self):
        super().__init__()
        modules = []
        for upsample_first, c_in, c_out, relu_after in self._STAGES:
            if upsample_first:
                modules.append(nn.Upsample(scale_factor=2, mode='nearest'))
            modules.append(nn.ReflectionPad2d((1, 1, 1, 1)))
            modules.append(nn.Conv2d(c_in, c_out, 3))
            if relu_after:
                modules.append(nn.ReLU(inplace=True))
        # Positional construction keeps the same "layers.<idx>" state_dict keys.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


In [None]:
def mean_variance_norm(feat, eps=1e-5):
    """Compute channel-wise mean and std over the last two axes of `feat`.

    The reduction runs over dims (2, 3) and keeps them as size-1 axes; `eps`
    is added to the variance before taking the square root.

    Returns:
        A `(mean, std)` pair of tensors.
    """
    channel_mean = torch.mean(feat, dim=(2, 3), keepdim=True)
    channel_var = torch.var(feat, dim=(2, 3), keepdim=True)
    channel_std = torch.sqrt(channel_var + eps)
    return channel_mean, channel_std

def adain(content_feat, style_feat, eps=1e-5):
    """Re-normalize `content_feat` so its channel statistics match `style_feat`.

    Statistics come from `mean_variance_norm`; the content features are
    standardized with their own mean/std and then scaled/shifted with the
    style std/mean.
    """
    c_stats = mean_variance_norm(content_feat, eps)
    s_stats = mean_variance_norm(style_feat, eps)
    return (content_feat - c_stats[0]) / c_stats[1] * s_stats[1] + s_stats[0]

# Encoder: VGG19 up to relu4_1
def vgg_encoder():
    """Build a frozen VGG19 feature extractor that ends at relu4_1.

    NOTE: this notebook also defines `vgg_encoder` in an earlier cell; this
    later definition is the one in effect after Run All.

    Returns:
        nn.Sequential holding layers 0..20 of torchvision's VGG19 `features`
        (keys "0".."20"), with `requires_grad=False` on every parameter and
        the module in eval mode.
    """
    # `pretrained=True` is deprecated in newer torchvision; prefer the
    # `weights=` API when it exists and fall back otherwise.
    if hasattr(models, "VGG19_Weights"):
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
    else:
        vgg = models.vgg19(pretrained=True).features

    # Freeze weights: the encoder is only used as a fixed feature extractor.
    for param in vgg.parameters():
        param.requires_grad = False

    # torchvision VGG19 `features` layout: conv4_1 is index 19 and relu4_1 is
    # index 20. Slicing to 23 would run on through conv4_2/relu4_2.
    encoder = vgg[:21]
    return encoder.eval()

# Decoder (mirroring VGG up to relu4_1)
class Decoder(nn.Module):
    """Decoder turning 512-channel feature maps into 3-channel images.

    Built from reflection-padded 3x3 convolutions with in-place ReLUs and
    three nearest-neighbour 2x upsampling steps (8x overall); the final
    convolution has no activation.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out, activate=True):
            """Return [pad, conv, (relu)] for one reflection-padded 3x3 conv."""
            block = [nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(c_in, c_out, 3)]
            if activate:
                block.append(nn.ReLU(inplace=True))
            return block

        def upsample():
            return [nn.Upsample(scale_factor=2, mode='nearest')]

        # Modules are created in network order, so positional indices (and the
        # resulting "layers.<idx>" state_dict keys) follow that order.
        sequence = (
            conv3x3(512, 512)
            + upsample()
            + conv3x3(512, 256)
            + conv3x3(256, 256)
            + conv3x3(256, 256)
            + conv3x3(256, 256)
            + conv3x3(256, 128)
            + upsample()
            + conv3x3(128, 128)
            + conv3x3(128, 64)
            + upsample()
            + conv3x3(64, 64)
            + conv3x3(64, 3, activate=False)
        )
        self.layers = nn.Sequential(*sequence)

    def forward(self, x):
        return self.layers(x)


In [None]:
def extract_features(x, model):
    """Collect the relu1_1, relu2_1, relu3_1 and relu4_1 activations of `x`.

    Layer positions follow torchvision's VGG19 `features` layout:
    relu1_1 = 1, relu2_1 = 6, relu3_1 = 11, relu4_1 = 20 (conv4_1 is 19).
    The input is pushed through the model once and the intermediate outputs
    are captured on the way, instead of re-running overlapping prefixes.

    Args:
        x: input batch fed to the first layer of `model`.
        model: an indexable/iterable `nn.Sequential` with at least 21 layers.

    Returns:
        List `[relu1_1, relu2_1, relu3_1, relu4_1]` of activation tensors.

    Raises:
        ValueError: if `model` has fewer than 21 layers.
    """
    capture_at = (1, 6, 11, 20)
    last = capture_at[-1]
    if len(model) <= last:
        raise ValueError(
            f"model has {len(model)} layers; need at least {last + 1} to reach relu4_1"
        )

    feats = []
    for idx, layer in enumerate(model):
        x = layer(x)
        if idx in capture_at:
            # In VGG19 each captured ReLU is followed by a conv (not an
            # in-place op), so the stored tensor is not overwritten later.
            feats.append(x)
        if idx == last:
            break
    return feats

def calc_content_loss(out_feat, target_feat):
    """Mean squared difference between two feature tensors."""
    diff = out_feat - target_feat
    return diff.pow(2).mean()

def calc_style_loss(out_feats, target_feats):
    """Sum, over paired layers, the squared gaps in channel mean and std.

    For every (output, target) feature pair the per-channel statistics are
    obtained from `mean_variance_norm`; the layer's contribution is the mean
    squared mean difference plus the mean squared std difference.
    """
    total = 0.0
    for generated, reference in zip(out_feats, target_feats):
        gen_mean, gen_std = mean_variance_norm(generated)
        ref_mean, ref_std = mean_variance_norm(reference)
        mean_term = torch.mean((gen_mean - ref_mean) ** 2)
        std_term = torch.mean((gen_std - ref_std) ** 2)
        total = total + (mean_term + std_term)
    return total

############################################################
# Training Setup
############################################################

# The frozen encoder is used by the training loop and by inference; it must be
# created here, otherwise a fresh "Restart & Run All" fails with a NameError.
encoder = vgg_encoder().to(device).eval()

# Only the decoder is trained, so only its parameters go to the optimizer.
decoder = Decoder().to(device)
optimizer = optim.Adam(decoder.parameters(), lr=lr)



In [None]:
for epoch in range(num_epochs):
    decoder.train()
    for i, (content_imgs, style_imgs) in enumerate(dataloader):
        content_imgs = content_imgs.to(device)
        style_imgs = style_imgs.to(device)

        # Everything derived only from the input images is a fixed target, so
        # compute it without tracking gradients.
        with torch.no_grad():
            content_feats = encoder(content_imgs)
            style_feats = encoder(style_imgs)
            style_feats_multi = extract_features(style_imgs, encoder)
            # AdaIN target features (alpha = 1 during training, i.e. no blending)
            t = adain(content_feats, style_feats)

        # Decode the stylized features back to image space
        out = decoder(t)

        # Re-encode the generated image; gradients flow through the frozen
        # encoder back into the decoder.
        out_feats = encoder(out)

        # Content loss: deepest encoder features of the output vs. the content image
        c_loss = calc_content_loss(out_feats, content_feats)

        # Style loss: channel statistics matched at several encoder depths
        out_feats_multi = extract_features(out, encoder)
        s_loss = calc_style_loss(out_feats_multi, style_feats_multi)

        total_loss = content_weight * c_loss + style_weight * s_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if (i+1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], "
                  f"Content Loss: {c_loss.item():.4f}, Style Loss: {s_loss.item():.4f}, "
                  f"Total Loss: {total_loss.item():.4f}")

# Save trained decoder
torch.save(decoder.state_dict(), save_decoder_path)



In [None]:
@torch.no_grad()
def style_transfer(encoder, decoder, content, style, alpha=1.0):
    """Stylize `content` with the statistics of `style`, without gradients.

    Both inputs are encoded, the content features are re-normalized with
    `adain`, and the result is blended with the plain content features as
    `alpha * stylized + (1 - alpha) * content` before decoding.

    Args:
        encoder: callable producing feature maps from an image batch.
        decoder: callable mapping feature maps back to images.
        content: content image batch.
        style: style image batch.
        alpha: blend weight of the stylized features (1.0 = fully stylized).

    Returns:
        The decoder output for the blended features.
    """
    content_feat = encoder(content)
    style_feat = encoder(style)
    stylized = adain(content_feat, style_feat)
    blended = alpha * stylized + (1 - alpha) * content_feat
    return decoder(blended)

# Example inference usage:
# Point these at a real content and style image before running this cell.
content_image_path = "path_to_a_content_image.jpg"
style_image_path = "path_to_a_style_image.jpg"

# Load inside `with` so the underlying files are closed once converted to RGB.
with Image.open(content_image_path) as img:
    content_image = img.convert("RGB")
with Image.open(style_image_path) as img:
    style_image = img.convert("RGB")

content_tensor = vgg_preprocess(content_image).unsqueeze(0).to(device)
style_tensor = vgg_preprocess(style_image).unsqueeze(0).to(device)

decoder.eval()
# Load trained decoder weights if needed
# decoder.load_state_dict(torch.load(save_decoder_path))

# style_transfer already disables gradient tracking internally.
output = style_transfer(encoder, decoder, content_tensor, style_tensor, alpha=1.0)

# Convert output to PIL Image for saving.
# De-normalize while the tensor is still batched and on the compute device,
# then drop the batch axis and move to CPU: ToPILImage needs a (C, H, W) tensor.
out_img = deprocess(output).squeeze(0).cpu()
out_img_pil = transforms.ToPILImage()(out_img)
out_img_pil.save("stylized_output.jpg")

print("Style transfer complete! Check stylized_output.jpg.")
