In [2]:
import kagglehub

# Download the latest version of both datasets. `path` is the local dataset
# directory reported by the print below; `path1` is the WikiArt download that
# the split cell further down passes to split_dataset as the style source.
path = kagglehub.dataset_download("duttadebadri/image-classification")
path1 = kagglehub.dataset_download("steubk/wikiart") # NOTE: this dataset is 31.4GB -- a big problem for disk/time
# Variant that also reports the WikiArt location:
#print("Path to dataset files:", path, path1)
print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/image-classification


In [3]:
import os
import shutil
import random
from pathlib import Path
from PIL import Image
import numpy as np

## Splitting the dataset (used once `path1`, the WikiArt download, is added)

In [6]:
def _list_images(root, suffixes=(".jpg", ".jpeg", ".png")):
    """Return a sorted list of image files found recursively under ``root``.

    Suffix matching is case-insensitive so files such as ``photo.JPG`` or
    ``scan.jpeg`` are not silently dropped. Sorting makes the listing
    independent of filesystem enumeration order.
    """
    return sorted(
        p for p in Path(root).rglob("*")
        if p.suffix.lower() in suffixes and p.is_file()
    )


def split_dataset(content_path, style_path, max_total_images=10000, content_ratio=0.8, output_dir="dataset_splits", seed=None):
    """Sample content and style images and split them into train/val/test lists.

    Args:
        content_path: Root directory searched recursively for content images.
        style_path: Root directory searched recursively for style images.
        max_total_images: Upper bound on the number of images sampled overall.
        content_ratio: Fraction (0..1) of ``max_total_images`` taken from the
            content set; the remainder is taken from the style set.
        output_dir: Not used by this function; kept so existing calls that
            pass it keep working.
        seed: Optional seed for a private RNG so that sampling and splitting
            are reproducible. When ``None`` the global ``random`` state is
            used, as before.

    Returns:
        ``(train_images, val_images, test_images)``: lists of ``Path`` objects
        holding 70%, 20% and the remaining (~10%) sampled images.

    Raises:
        ValueError: If ``content_ratio`` is outside ``[0, 1]`` or
            ``max_total_images`` is negative (either would turn the slice
            bounds below negative and silently drop the wrong images).
    """
    if not 0.0 <= content_ratio <= 1.0:
        raise ValueError(f"content_ratio must be within [0, 1], got {content_ratio}")
    if max_total_images < 0:
        raise ValueError(f"max_total_images must be non-negative, got {max_total_images}")

    rng = random if seed is None else random.Random(seed)

    # Collect image files (sorted first so a seed fully determines the result)
    content_images = _list_images(content_path)
    style_images = _list_images(style_path)

    # Shuffle
    rng.shuffle(content_images)
    rng.shuffle(style_images)

    # Sample according to max_total_images and the content/style ratio
    content_count = int(max_total_images * content_ratio)
    style_count = max_total_images - content_count

    # Combine and shuffle again so each split mixes both sources
    all_images = content_images[:content_count] + style_images[:style_count]
    rng.shuffle(all_images)

    # Split: 70% train, 20% val, remainder (~10%) test
    train_end = int(len(all_images) * 0.7)
    val_end = train_end + int(len(all_images) * 0.2)

    train_images = all_images[:train_end]
    val_images = all_images[train_end:val_end]
    test_images = all_images[val_end:]
    return train_images, val_images, test_images



In [12]:
def copy_files(images, split_name, output_dir):
    """Copy ``images`` into ``output_dir/split_name`` without losing any file.

    The source images are gathered recursively from several folders, so two
    different files can share a basename. Copying both to the same flat
    directory would silently overwrite the first one; colliding names within
    a single call therefore get a numeric suffix (``name_1.jpg``, ...).

    Args:
        images: Iterable of source image paths.
        split_name: Name of the sub-directory to create (e.g. ``"train"``).
        output_dir: Parent directory that holds all split sub-directories.
    """
    split_dir = Path(output_dir) / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    # Names already written during this call; re-running the cell still
    # overwrites earlier output instead of piling up duplicates.
    used_names = set()
    for img_path in images:
        img_path = Path(img_path)
        dest_name = img_path.name
        counter = 1
        while dest_name in used_names:
            dest_name = f"{img_path.stem}_{counter}{img_path.suffix}"
            counter += 1
        used_names.add(dest_name)
        shutil.copy(img_path, split_dir / dest_name)

# Sample and split the downloaded datasets once, then copy each split to disk.
# (A second split_dataset call and a duplicated summary print used to follow;
# the extra call rescanned both datasets, produced a different random split
# than the one copied, and dumped the full path lists as cell output.)
train_images, val_images, test_images = split_dataset(path, path1)

copy_files(train_images, "train", "dataset_splits")
copy_files(val_images, "val", "dataset_splits")
copy_files(test_images, "test", "dataset_splits")

print(f"Dataset split completed: {len(train_images)} train, {len(val_images)} val, {len(test_images)} test.")

In [11]:
def preprocess_image(img_path, image_size=(256, 256)):
    """Load an image file as a float32 RGB array with values scaled by 1/255.

    Args:
        img_path: Path of the image file to open.
        image_size: Target size handed to ``Image.resize``.

    Returns:
        ``np.float32`` array of the resized RGB image divided by 255.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; this function is called for thousands of files.
    with Image.open(img_path) as src:
        img = src.convert("RGB").resize(image_size)
    img_arr = np.array(img) / 255.0
    return img_arr.astype(np.float32)

In [None]:
def preprocess_split_dataset(split_dir="dataset_splits", image_size=(256, 256), output_dir="processed_dataset"):
    """Turn each split folder into a compressed archive of training triplets.

    For every split (``train``, ``val``, ``test``) the images in
    ``split_dir/<split>`` are shuffled and cut in half; each image of the
    first half is paired with a random image of the second half and a random
    threshold in ``[0.2, 1.0)``. The triplets are written to
    ``output_dir/<split>_triplets.npz`` under the key ``data``.

    Args:
        split_dir: Directory containing the ``train``/``val``/``test`` folders.
        image_size: Size passed to ``preprocess_image`` for both images.
        output_dir: Directory that receives the ``*_triplets.npz`` files
            (created if missing).
    """
    split_dir = Path(split_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    image_suffixes = (".jpg", ".jpeg", ".png")

    for split in ['train', 'val', 'test']:
        split_path = split_dir / split
        # Case-insensitive suffix match; sorted so the shuffle below is the
        # only source of ordering randomness.
        images = sorted(
            p for p in split_path.glob("*")
            if p.suffix.lower() in image_suffixes and p.is_file()
        )
        random.shuffle(images)

        # NOTE(review): the split folders mix content and style images, so
        # this halves the pool arbitrarily rather than by origin -- confirm
        # that is intended.
        half = len(images) // 2
        content_imgs = images[:half]
        style_imgs = images[half:]

        triplets = []
        for content_img in content_imgs:
            style_img = random.choice(style_imgs)
            try:
                c_img = preprocess_image(content_img, image_size)
                s_img = preprocess_image(style_img, image_size)
                threshold = np.random.uniform(0.2, 1.0)
                triplets.append((c_img, s_img, threshold))
            except Exception as e:
                # Best effort: unreadable/corrupt images are reported and skipped.
                print(f"[{split}] Skipped image due to error: {e}")

        # Pack into an explicit (N, 3) object array. Passing the ragged list
        # of (array, array, float) tuples straight to NumPy relies on implicit
        # ragged-array creation, which raises ValueError on NumPy >= 1.24.
        packed = np.empty((len(triplets), 3), dtype=object)
        for row, (c_img, s_img, threshold) in enumerate(triplets):
            packed[row, 0] = c_img
            packed[row, 1] = s_img
            packed[row, 2] = threshold

        split_out = output_dir / f"{split}_triplets.npz"
        np.savez_compressed(split_out, data=packed)
        print(f"[{split}] Saved {len(triplets)} triplets to {split_out}")


In [None]:
# Then preprocess the split images into (content, style, threshold) triplets,
# using the default "dataset_splits" input and "processed_dataset" output folders.
preprocess_split_dataset()


## Adding the model

In [15]:
import torch
import torch.nn as nn
from torchvision import models

class VGGFeatureExtractor(nn.Module):
    """Frozen, pretrained VGG-19 feature stack returning selected activations.

    Args:
        layers: Names (indices into ``vgg19().features``) of the layers whose
            outputs are collected. Defaults to ``['0', '5', '10', '19', '28']``.
            Integers are accepted and converted to strings.
    """

    def __init__(self, layers=None):
        super(VGGFeatureExtractor, self).__init__()
        if layers is None:
            # Layers to extract: relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
            # NOTE(review): these indices are the conv layers preceding each
            # ReLU. torchvision's VGG appears to use in-place ReLUs, which
            # would overwrite the stored conv output with the activated
            # values -- verify before changing the indices.
            layers = ['0', '5', '10', '19', '28']
        # Previously a caller-supplied `layers` was never stored, so any
        # non-default argument crashed with AttributeError below.
        self.layers = [str(name) for name in layers]

        # Prefer the `weights=` API (torchvision >= 0.13); fall back to the
        # deprecated `pretrained=True` flag on older releases.
        if hasattr(models, "VGG19_Weights"):
            vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        else:
            vgg = models.vgg19(pretrained=True).features.eval()
        # The extractor is only used to compute losses; its weights stay fixed.
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.selected_layers = self.layers

    def forward(self, x):
        """Run ``x`` through the VGG stack and return the selected layer outputs as a list."""
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.selected_layers:
                features.append(x)
        return features


In [16]:
def gram_matrix(tensor):
    """Compute the normalised Gram matrix of a batch of feature maps.

    Args:
        tensor: 4-D tensor unpacked as ``(b, c, h, w)``.

    Returns:
        Tensor of shape ``(b, c, c)``: per-sample channel correlations divided
        by ``c * h * w``.
    """
    b, c, h, w = tensor.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # channels-last feature maps, are handled instead of raising.
    flat = tensor.reshape(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (c * h * w)


In [17]:
class StyleTransferModel(nn.Module):
    """Weighted content + style loss computed on VGG feature maps.

    Args:
        content_weight: Multiplier applied to the content term.
        style_weight: Multiplier applied to the style (Gram matrix) term.
    """

    def __init__(self, content_weight=1e5, style_weight=1e10):
        super().__init__()
        self.vgg = VGGFeatureExtractor()
        self.content_weight = content_weight
        self.style_weight = style_weight

    def forward(self, content_img, style_img, generated_img):
        """Return the scalar loss of ``generated_img`` against both references."""
        feats_content = self.vgg(content_img)
        feats_style = self.vgg(style_img)
        feats_generated = self.vgg(generated_img)

        # Content term: mean squared difference at the fourth selected layer.
        content_diff = feats_generated[3] - feats_content[3]
        content_loss = torch.mean(content_diff ** 2)

        # Style term: mean squared Gram-matrix difference, accumulated over
        # every selected layer.
        style_loss = 0
        for feat_gen, feat_style in zip(feats_generated, feats_style):
            gram_diff = gram_matrix(feat_gen) - gram_matrix(feat_style)
            style_loss += torch.mean(gram_diff ** 2)

        return self.content_weight * content_loss + self.style_weight * style_loss


In [22]:
def closure(model, optimizer, content_img, style_img, generated_img):
    """Run one loss evaluation for the optimizer: clear grads, compute, backprop.

    Returns the loss produced by ``model(content_img, style_img, generated_img)``
    after ``backward()`` has been called on it.
    """
    optimizer.zero_grad()
    current_loss = model(content_img, style_img, generated_img)
    current_loss.backward()
    return current_loss

def stylize_image(content_img, style_img, model, num_steps=300, style_threshold=1.0):
    """Optimise a copy of ``content_img`` against ``model`` and blend the result.

    Args:
        content_img: Starting image; also the content reference for the loss.
        style_img: Style reference handed to the loss.
        model: Callable loss ``model(content, style, generated)``.
        num_steps: Number of ``optimizer.step`` calls to perform.
        style_threshold: Blend factor; 1.0 returns the optimised image, 0.0
            returns the content image.

    Returns:
        Detached tensor ``style_threshold * generated + (1 - style_threshold) * content_img``.
    """
    generated = content_img.clone().requires_grad_(True)
    optimizer = torch.optim.LBFGS([generated])

    # LBFGS re-evaluates the loss itself, so it is given a zero-argument
    # callable bound to the module-level `closure` helper.
    evaluate_loss = lambda: closure(model, optimizer, content_img, style_img, generated)
    for _ in range(num_steps):
        optimizer.step(evaluate_loss)

    # Apply style threshold blending
    with torch.no_grad():
        styled_part = style_threshold * generated
        content_part = (1 - style_threshold) * content_img
        final_img = styled_part + content_part

    return final_img.detach()


In [21]:
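# Earlier version of stylize_image, kept for reference only. It defined the
# LBFGS closure inside the loop; we want to avoid functions inside functions,
# so the active version uses the module-level `closure` helper instead.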
# def stylize_image(content_img, style_img, model, num_steps=300, style_threshold=1.0):
#     generated = content_img.clone().requires_grad_(True)
#     optimizer = torch.optim.LBFGS([generated])

#     for step in range(num_steps):
#         def closure():
#             optimizer.zero_grad()
#             loss = model(content_img, style_img, generated)
#             loss.backward()
#             return loss
#         optimizer.step(closure)

#     return generated.detach()


In [23]:
#Dataset loader
class StyleTransferDataset(torch.utils.data.Dataset):
    """Dataset of (content, style, threshold) triplets stored in an ``.npz`` file.

    Args:
        npz_path: Archive whose ``data`` entry holds the triplets.
    """

    def __init__(self, npz_path):
        # SECURITY: allow_pickle=True executes pickled payloads on load; only
        # open archives produced by this notebook or another trusted source.
        # The context manager closes the archive's file handle once the data
        # has been read (np.load otherwise keeps it open).
        with np.load(npz_path, allow_pickle=True) as archive:
            self.triplets = archive['data'].tolist()

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return ``(content, style, threshold)`` tensors for triplet ``idx``."""
        c, s, threshold = self.triplets[idx]
        return (
            torch.tensor(c).permute(2, 0, 1),   # (H,W,3) -> (3,H,W)
            torch.tensor(s).permute(2, 0, 1),
            torch.tensor(threshold, dtype=torch.float32)
        )


In [24]:
class ConvBlock(nn.Module):
    """Conv2d -> InstanceNorm2d (affine) -> ReLU, optionally preceded by 2x upsampling.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: Forwarded to
            ``nn.Conv2d``.
        upsample: When true, a nearest-neighbour ``nn.Upsample(scale_factor=2)``
            is placed in front of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, upsample=False):
        super().__init__()
        # Build the layer list once; the only difference between the two
        # variants is the leading upsampling stage.
        stages = []
        if upsample:
            stages.append(nn.Upsample(scale_factor=2, mode='nearest'))
        stages.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        stages.append(nn.InstanceNorm2d(out_channels, affine=True))
        stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages with a skip connection.

    Args:
        channels: Number of input and output channels (kept constant).
    """

    def __init__(self, channels):
        super().__init__()
        first_conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        first_norm = nn.InstanceNorm2d(channels, affine=True)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        second_norm = nn.InstanceNorm2d(channels, affine=True)
        self.block = nn.Sequential(first_conv, first_norm, activation, second_conv, second_norm)

    def forward(self, x):
        """Return ``x`` plus the output of the conv stack applied to ``x``."""
        residual = self.block(x)
        return x + residual


In [None]:
import torch
import torch.nn as nn

class TransformerNet(nn.Module):
    """Feed-forward stylisation network: downsample, residual stack, upsample.

    ``forward`` blends the network output with its input according to
    ``style_threshold``.
    """

    def __init__(self):
        super().__init__()
        # Encoder: one 9x9 stride-1 block followed by two stride-2 blocks.
        encoder_stages = [
            ConvBlock(3, 32, 9, 1, 4),    # 256 -> 256
            ConvBlock(32, 64, 3, 2, 1),   # 256 -> 128
            ConvBlock(64, 128, 3, 2, 1),  # 128 -> 64
        ]
        self.down = nn.Sequential(*encoder_stages)

        # Bottleneck: five residual blocks at 128 channels.
        bottleneck = [ResidualBlock(128) for _ in range(5)]
        self.residuals = nn.Sequential(*bottleneck)

        # Decoder: two 2x upsampling blocks, then a 9x9 conv back to 3 channels.
        decoder_stages = [
            ConvBlock(128, 64, 3, 1, 1, upsample=True),  # 64 -> 128
            ConvBlock(64, 32, 3, 1, 1, upsample=True),   # 128 -> 256
            nn.Conv2d(32, 3, 9, 1, 4),                   # final conv
            nn.Tanh(),  # bounds outputs to [-1, 1]
        ]
        self.up = nn.Sequential(*decoder_stages)

    def forward(self, x, style_threshold):
        """Stylise ``x`` and mix the result with ``x`` itself.

        NOTE(review): the Tanh output lies in [-1, 1]; if ``x`` is scaled to
        [0, 1] the two blended ranges differ -- confirm this is intended.
        """
        encoded = self.down(x)
        transformed = self.residuals(encoded)
        stylized = self.up(transformed)
        return style_threshold * stylized + (1 - style_threshold) * x


In [None]:
class NSTLoss(nn.Module):
    """Neural-style-transfer loss: weighted content term plus Gram-matrix style term.

    Args:
        content_weight: Multiplier applied to the content term.
        style_weight: Multiplier applied to the style term.
    """

    def __init__(self, content_weight=1e5, style_weight=1e10):
        super().__init__()
        self.vgg = VGGFeatureExtractor()
        self.content_weight = content_weight
        self.style_weight = style_weight

    def forward(self, content_img, style_img, stylized_img):
        """Return the scalar loss of ``stylized_img`` against both references."""
        content_feats = self.vgg(content_img)
        style_feats = self.vgg(style_img)
        stylized_feats = self.vgg(stylized_img)

        # Content: mean squared difference at the fourth selected layer.
        content_gap = stylized_feats[3] - content_feats[3]
        content_loss = torch.mean(content_gap ** 2)

        # Style: mean squared Gram-matrix difference summed over all layers.
        style_loss = 0
        for stylized_f, style_f in zip(stylized_feats, style_feats):
            gram_gap = gram_matrix(stylized_f) - gram_matrix(style_f)
            style_loss += torch.mean(gram_gap ** 2)

        weighted_content = self.content_weight * content_loss
        weighted_style = self.style_weight * style_loss
        return weighted_content + weighted_style


In [None]:
def train(model, dataloader, optimizer, loss_fn, device, num_epochs=10, val_loader=None):
    """Train ``model`` and report the average train/validation loss per epoch.

    Args:
        model: Module called as ``model(content, threshold)``.
        dataloader: Iterable of ``(content, style, threshold)`` batches.
        optimizer: Optimizer over the model parameters.
        loss_fn: Callable ``loss_fn(content, style, generated)`` returning a
            scalar tensor.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        val_loader: Optional validation batches evaluated after every epoch.
            When omitted, a notebook-global ``val_loader`` is used if one
            exists (the previous behaviour); otherwise validation is skipped
            instead of failing with ``NameError``.
    """
    if val_loader is None:
        # Legacy fallback for callers that relied on the global variable.
        val_loader = globals().get("val_loader")

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: validation may have switched the
        # model to eval mode at the end of the previous epoch.
        model.train()
        total_loss = 0.0
        num_batches = 0
        for content, style, threshold in dataloader:
            content = content.to(device)
            style = style.to(device)
            # One threshold per sample, broadcast over channels and pixels.
            threshold = threshold.to(device).view(-1, 1, 1, 1)

            optimizer.zero_grad()
            generated = model(content, threshold)
            loss = loss_fn(content, style, generated)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty loader yields NaN rather than a ZeroDivisionError.
        avg_train_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {avg_train_loss:.4f}")

        # Evaluate on validation set every epoch
        if val_loader is not None:
            val_loss = evaluate(model, val_loader, loss_fn, device)
            print(f"[Epoch {epoch+1}] Validation Loss: {val_loss:.4f}")

In [None]:
def evaluate(model, dataloader, loss_fn, device):
    """Return the average loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: Module called as ``model(content, threshold)``.
        dataloader: Iterable of ``(content, style, threshold)`` batches.
        loss_fn: Callable ``loss_fn(content, style, output)`` returning a
            scalar tensor.
        device: Device the batches are moved to.

    Returns:
        Mean per-batch loss as a float, or NaN when ``dataloader`` yields no
        batches.
    """
    # Remember the caller's mode so a training loop that calls this between
    # epochs is not silently left in eval mode.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for content, style, threshold in dataloader:
                content = content.to(device)
                style = style.to(device)
                # One threshold per sample, broadcast over channels and pixels.
                threshold = threshold.to(device).view(-1, 1, 1, 1)

                output = model(content, threshold)
                loss = loss_fn(content, style, output)
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)

    return total_loss / num_batches if num_batches else float("nan")