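This notebook applies the artistic style of Vincent van Gogh's *The Starry Night* to a random selection of 100 images from the CIFAR-10 test set. It uses a pretrained VGG19 network as the feature extractor for an optimization-based style transfer: each image's pixels are optimized to match its own content features and the painting's Gram-matrix style statistics. The resulting styled images are saved to `./cifar10_styled_100/` for further analysis or use.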

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import os
from PIL import Image
import torchvision.models as models
from torchvision.utils import save_image
from tqdm import tqdm
import torch.optim as optim

# Upscale every image to 224x224 before it reaches VGG19; ToTensor yields float
# tensors with values in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

cifar10_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Draw ONE random permutation and split it, so the 100-image subset and the
# remainder are disjoint and together cover the whole split. Two independent
# randperm calls would select overlapping index sets.
shuffled_indices = torch.randperm(len(cifar10_dataset))
subset_100_indices = shuffled_indices[:100]  # Randomly select 100 images
subset_9900_indices = shuffled_indices[100:]  # Every remaining image

subset_100 = Subset(cifar10_dataset, subset_100_indices)
subset_9900 = Subset(cifar10_dataset, subset_9900_indices)

# batch_size=1: gram_matrix and the per-image filenames below handle one image at a time.
subset_100_loader = DataLoader(subset_100, batch_size=1, shuffle=False)
subset_9900_loader = DataLoader(subset_9900, batch_size=1, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the convolutional feature stack of the ImageNet-pretrained VGG19 is used.
# Its weights stay frozen: the optimization updates image pixels, not the network.
vgg = models.vgg19(pretrained=True).features
vgg.requires_grad_(False)
vgg.to(device)


style_image_path = "./van.jpg"  # The Starry Night (Vincent van Gogh), used as the style reference

# Open inside a context manager so the file handle is released once decoding is done.
# convert("RGB") returns an independent in-memory copy, which stays valid after the close;
# it also normalizes grayscale/alpha modes to 3 channels.
with Image.open(style_image_path) as raw_style_image:
    style_image = raw_style_image.convert("RGB")

style_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
# Add a batch dimension -> (1, 3, 224, 224) and move to the compute device.
style_image = style_transform(style_image).unsqueeze(0).to(device)

def gram_matrix(tensor, normalize=False):
    """Return the channel-by-channel Gram matrix of a single feature map.

    Args:
        tensor: Feature map of shape ``(1, d, h, w)``; the batch size must be 1.
        normalize: If True, divide the result by ``d * h * w``, keeping its magnitude
            independent of the layer size. The default False returns the raw inner
            products, as before.

    Returns:
        A ``(d, d)`` tensor whose ``[i, j]`` entry is the inner product of channels
        ``i`` and ``j`` over all ``h * w`` spatial positions.

    Raises:
        ValueError: If ``tensor`` is not 4-D or its batch dimension is not 1.
    """
    batch, d, h, w = tensor.size()
    if batch != 1:
        raise ValueError(f"gram_matrix expects a batch size of 1, got {batch}")
    # reshape (unlike view) also works on non-contiguous feature maps.
    features = tensor.reshape(d, h * w)
    gram = torch.mm(features, features.t())
    if normalize:
        gram = gram / (d * h * w)
    return gram


# ImageNet per-channel statistics used to normalize the (0-1 scaled) RGB inputs
# of torchvision's pretrained VGG weights.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _vgg_features(image, vgg, imagenet_normalize):
    """Return ``vgg(image)``, first applying ImageNet mean/std normalization if requested."""
    if imagenet_normalize:
        # Assumes 3-channel input; broadcast the per-channel stats over (N, C, H, W).
        mean = image.new_tensor(_IMAGENET_MEAN).view(1, 3, 1, 1)
        std = image.new_tensor(_IMAGENET_STD).view(1, 3, 1, 1)
        image = (image - mean) / std
    return vgg(image)


def style_transfer(content_image, style_image, vgg, content_weight=1e5, style_weight=1e10, num_steps=300,
                   lr=0.003, log_every=50, imagenet_normalize=True):
    """Stylize ``content_image`` with the style of ``style_image``.

    A copy of ``content_image`` is optimized with Adam. Two losses drive it:
    - content: MSE between its ``vgg`` features and those of the ORIGINAL content image;
    - style: MSE between the Gram matrices of its features and of ``style_image``'s features.
    Both losses use the final output of ``vgg``.

    Args:
        content_image: Image tensor to stylize, assumed to come from ``transforms.ToTensor``
            (values in [0, 1]) with batch size 1, which ``gram_matrix`` requires.
        style_image: Style reference tensor, with the same conventions as ``content_image``.
        vgg: Frozen feature extractor (here the VGG19 ``features`` stack).
        content_weight: Multiplier on the content loss.
        style_weight: Multiplier on the style loss.
        num_steps: Number of Adam steps.
        lr: Adam learning rate; the default equals the previously hard-coded 0.003.
        log_every: Print the loss every ``log_every`` steps; 0 or None disables printing.
        imagenet_normalize: Normalize inputs with ImageNet mean/std before ``vgg``,
            as pretrained torchvision weights expect. Pass False to feed raw pixels.

    Returns:
        The stylized image on ``device``, clamped to [0, 1] and detached from autograd.
    """
    # Optimization targets are fixed: compute them once, without building a graph.
    with torch.no_grad():
        content_target = _vgg_features(content_image.to(device), vgg, imagenet_normalize)
        style_gram_target = gram_matrix(_vgg_features(style_image.to(device), vgg, imagenet_normalize))

    # Move BEFORE enabling grad so the optimized tensor is a leaf (Adam rejects non-leaf tensors).
    generated = content_image.detach().to(device).clone().requires_grad_(True)
    optimizer = optim.Adam([generated], lr=lr)

    for step in range(num_steps):
        optimizer.zero_grad()

        features = _vgg_features(generated, vgg, imagenet_normalize)
        content_loss = torch.mean((features - content_target) ** 2)
        style_loss = torch.mean((gram_matrix(features) - style_gram_target) ** 2)

        loss = content_weight * content_loss + style_weight * style_loss

        loss.backward()
        optimizer.step()

        # Keep pixels inside the [0, 1] range that ToTensor / save_image work with.
        with torch.no_grad():
            generated.clamp_(0.0, 1.0)

        if log_every and step % log_every == 0:
            print(f"Step [{step}/{num_steps}], Loss: {loss.item()}")

    return generated.detach()

output_dir = './cifar10_styled_100/'
os.makedirs(output_dir, exist_ok=True)

num_styled = 0
for i, (content_image, label) in enumerate(tqdm(subset_100_loader, desc="Applying Style Transfer")):

    content_image = content_image.to(device)
    styled_image = style_transfer(content_image, style_image, vgg)
    # label.item() relies on the loader's batch_size=1; the filename records the
    # CIFAR-10 class index and the image's position within the subset.
    save_image(styled_image, os.path.join(output_dir, f"styled_image_{label.item()}_{i}.png"))
    num_styled += 1

# Report the number actually processed rather than a hard-coded count.
print(f"Style transfer completed on {num_styled} images.")
