<a href="https://colab.research.google.com/github/LoPA607/IE643/blob/main/neural_style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision pillow tqdm




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import copy
from tqdm import tqdm

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [None]:
def load_image(path, max_size=512, shape=None):
    """Load an image file as a batched RGB tensor on `device`.

    Args:
        path: Image location, passed to ``PIL.Image.open``.
        max_size: Upper bound (in pixels) for the longer side. Larger images
            are scaled down with their aspect ratio preserved. Only applied
            when ``shape`` is None, because ``shape`` fixes the final size.
        shape: Optional ``(width, height)`` pair. When given, the image is
            resized to exactly this size.

    Returns:
        Tensor of shape (1, 3, H, W) produced by ``transforms.ToTensor``,
        moved to the module-level ``device``.
    """
    # The context manager releases the file handle; convert() hands back a
    # fully loaded image that stays valid after the file is closed.
    with Image.open(path) as source:
        image = source.convert('RGB')

    if shape is not None:
        # Resample once, straight to the requested size, instead of shrinking
        # to max_size first and then resizing again.
        image = image.resize(tuple(shape), Image.LANCZOS)
    else:
        longest_side = max(image.size)
        if longest_side > max_size:
            scale = max_size / longest_side
            # Keep every dimension at least 1 px for very elongated images.
            new_size = (max(1, int(image.width * scale)),
                        max(1, int(image.height * scale)))
            image = image.resize(new_size, Image.LANCZOS)

    transform = transforms.ToTensor()
    tensor = transform(image).unsqueeze(0).to(device)
    return tensor

def save_image(tensor, path):
    """Write an image tensor to `path`.

    The tensor is copied to the CPU without gradient history, its leading
    dimension is squeezed away if it has size 1, and values are clamped to
    [0, 1] before conversion to a PIL image.
    """
    pixels = tensor.detach().clone().cpu().squeeze(0).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    to_pil(pixels).save(path)

In [None]:
# Per-channel mean and standard deviation subtracted/divided by the
# Normalization layer; created directly on the target device.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device)

class Normalization(nn.Module):
    """Channel-wise normalization layer: ``(img - mean) / std``.

    Args:
        mean: Per-channel means (tensor or sequence of length C).
        std: Per-channel standard deviations (tensor or sequence of length C).
    """

    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers (not plain attributes) so that they follow the
        # module through .to(device) / .cuda() / dtype conversions.
        # Reshaped to (C, 1, 1) so they broadcast against (B, C, H, W) images.
        self.register_buffer(
            "mean", torch.as_tensor(mean).detach().clone().reshape(-1, 1, 1))
        self.register_buffer(
            "std", torch.as_tensor(std).detach().clone().reshape(-1, 1, 1))

    def forward(self, img):
        """Return `img` normalized per channel."""
        return (img - self.mean) / self.std

def gram_matrix(tensor):
    """Compute the normalized Gram matrix of a 4-D feature tensor.

    Args:
        tensor: Feature maps of shape (b, c, h, w).

    Returns:
        A (b*c, b*c) matrix of inner products between the flattened feature
        maps, divided by the total number of elements ``b * c * h * w``.
    """
    b, c, h, w = tensor.size()
    # reshape (not view) so non-contiguous inputs, e.g. channels_last or
    # permuted tensors, are accepted as well.
    features = tensor.reshape(b * c, h * w)
    G = torch.mm(features, features.t())
    return G.div(b * c * h * w)

class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a target.

    The target is detached so no gradient flows into it. The most recent
    loss value is kept in ``self.loss`` for the caller to read.
    """

    def __init__(self, target):
        super().__init__()
        self.loss = 0
        self.target = target.detach()

    def forward(self, input):
        """Store the content loss for `input` and return `input` unchanged."""
        mse = nn.functional.mse_loss(input, self.target, reduction='mean')
        self.loss = mse
        return input

class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices.

    The Gram matrix of the target features is computed once at construction
    and detached. The most recent loss value is kept in ``self.loss``.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.loss = 0
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        """Store the style loss for `input` and return `input` unchanged."""
        input_gram = gram_matrix(input)
        self.loss = nn.functional.mse_loss(input_gram, self.target)
        return input

In [None]:
def get_style_model_and_losses(cnn, norm_mean, norm_std, style_img, content_img,
                               content_layers=('conv_4',),
                               style_layers=('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5')):
    """Build a feature network with content/style loss layers inserted.

    The layers of `cnn` are copied one by one into a new ``nn.Sequential``
    that starts with a ``Normalization`` layer. Layers are named ``conv_i``,
    ``relu_i``, ``pool_i`` (or ``layer_i`` for any other type), where ``i`` is
    the number of convolutions seen so far. After each layer named in
    `content_layers` / `style_layers`, a ``ContentLoss`` / ``StyleLoss``
    module is appended whose target is the network output for `content_img` /
    `style_img` at that point.

    Args:
        cnn: Module whose children are the feature layers. It is deep-copied,
            so the caller's module is not modified.
        norm_mean, norm_std: Per-channel statistics for the Normalization layer.
        style_img, content_img: Image tensors used to compute the targets.
        content_layers: Names of layers followed by a content loss.
        style_layers: Names of layers followed by a style loss.

    Returns:
        ``(model, style_losses, content_losses)``. `model` is truncated
        after the last loss module; the two lists hold the inserted modules.
    """
    cnn = copy.deepcopy(cnn)
    # The weights of the copy are never trained. Freezing them keeps
    # backward() from computing and accumulating gradients for them.
    cnn.requires_grad_(False)

    normalization = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(normalization)

    content_losses = []
    style_losses = []

    i = 0  # number of conv layers seen so far
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            # An in-place ReLU would overwrite the activations that the loss
            # module inserted just before it has seen.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        else:
            name = f'layer_{i}'

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants: no autograd graph is needed for them.
            with torch.no_grad():
                target = model(content_img)
            print(f"Content layer {name} output shape: {target.shape}")
            content_loss = ContentLoss(target)
            model.add_module("content_loss_" + str(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img)
            print(f"Style layer {name} output shape: {target_feature.shape}")
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_" + str(i), style_loss)
            style_losses.append(style_loss)

    # Trim off the layers after the last loss module; they do not affect any
    # loss. With no loss module at all, only the normalization layer is kept.
    last_loss_index = 0
    for j in range(len(model) - 1, -1, -1):
        if isinstance(model[j], (ContentLoss, StyleLoss)):
            last_loss_index = j
            break
    model = model[:(last_loss_index + 1)]

    return model, style_losses, content_losses

In [None]:
def run_style_transfer(cnn, norm_mean, norm_std,
                       content_img, style_img, input_img,
                       num_steps=300, style_weight=1e6, content_weight=1):
    """Optimize `input_img` to match the content and style targets.

    Args:
        cnn: Feature network handed to ``get_style_model_and_losses``.
        norm_mean, norm_std: Per-channel normalization statistics.
        content_img, style_img: Images that define the loss targets.
        input_img: Image tensor to optimize. It is modified in place and has
            ``requires_grad`` switched on.
        num_steps: Target number of loss evaluations. L-BFGS may evaluate the
            closure several times per ``step()``, so the final count can
            slightly exceed this value.
        style_weight, content_weight: Weights of the two loss terms.

    Returns:
        `input_img`, clamped to [0, 1].
    """
    print("Building model...")
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, norm_mean, norm_std, style_img, content_img)

    # Only the image is optimized. `model` holds copies of the cnn layers, so
    # freezing it does not affect the caller's network.
    model.requires_grad_(False)
    input_img.requires_grad_(True)

    optimizer = optim.LBFGS([input_img])

    print("Optimizing...")
    run = [0]  # mutable counter shared with the closure
    # The context manager closes the bar even if optimization raises.
    with tqdm(total=num_steps) as pbar:
        def closure():
            # Keep the image in the valid range without recording the clamp
            # in the autograd graph.
            with torch.no_grad():
                input_img.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # forward pass refreshes every .loss attribute
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = style_weight * style_score + content_weight * content_score
            loss.backward()
            run[0] += 1
            if run[0] <= num_steps:
                # Never push the bar past its declared total.
                pbar.update(1)
            return loss

        # Strict '<': with '<=' one more optimizer step was started after
        # num_steps evaluations had already been performed.
        while run[0] < num_steps:
            optimizer.step(closure)

    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img

In [None]:
# Input images and the location of the stylized result. The output path is
# kept separate from the inputs so the style image is not overwritten.
content_path = "/content/hoovertowernight.jpg"
style_path = "/content/starrynight.jpg"
output_path = "/content/stylized3.jpg"

# Both images are resized to the same (width, height).
content_img = load_image(content_path, shape=(512, 512))
style_img = load_image(style_path, shape=(512, 512))

# Start the optimization from a copy of the content image.
input_img = content_img.clone()

# Newer torchvision releases select pretrained weights through `weights=`;
# fall back to the legacy flag where that API is not available.
if hasattr(models, "VGG19_Weights"):
    vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
else:
    vgg = models.vgg19(pretrained=True)
cnn = vgg.features.to(device).eval()

output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img,
                            num_steps=300)

save_image(output, output_path)
print("Done! Saved as stylized3.jpg")

Building model...
Style layer conv_1 output shape: torch.Size([1, 64, 512, 512])
Style layer conv_2 output shape: torch.Size([1, 64, 512, 512])
Style layer conv_3 output shape: torch.Size([1, 128, 256, 256])
Content layer conv_4 output shape: torch.Size([1, 128, 256, 256])
Style layer conv_4 output shape: torch.Size([1, 128, 256, 256])
Style layer conv_5 output shape: torch.Size([1, 256, 128, 128])
Optimizing...


320it [00:22, 13.97it/s]

Done! Saved as stylized3.jpg



