In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm


In [4]:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = self.gram_matrix(target_feature).detach()

    def forward(self, input):
        G = self.gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input

    def gram_matrix(self, input):
        batch_size, channels, height, width = input.size()
        features = input.view(batch_size * channels, height * width)
        G = torch.mm(features, features.t())
        return G.div(batch_size * channels * height * width)



In [5]:
def load_img(path, imsize):
    img = Image.open(path).convert('RGB')
    loader = transforms.Compose([
        transforms.Resize(imsize),
        transforms.CenterCrop(imsize),  # Center crop to ensure exact size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img = loader(img).unsqueeze(0)
    return img.to(device, torch.float)

# Device and image size setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_size = 512 if torch.cuda.is_available() else 128

# Load and preprocess images
content_img = load_img('/content/lionn.jpg', img_size)
style_img = load_img('/content/firee.jpg', img_size)
target_img = content_img.clone().requires_grad_(True)



In [6]:
# VGG19 model and normalization
cnn = models.vgg19(pretrained=True).features.to(device).eval()
normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std

# Build model with content and style losses
model = nn.Sequential(Normalization(normalization_mean, normalization_std).to(device))
content_losses = []
style_losses = []
i = 0

for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        name = 'conv_{}'.format(i)
    elif isinstance(layer, nn.ReLU):
        name = 'relu_{}'.format(i)
        layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
        name = 'pool_{}'.format(i)
    elif isinstance(layer, nn.BatchNorm2d):
        name = 'bn_{}'.format(i)
    else:
        raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

    model.add_module(name, layer)

    if name in ['conv_4']:  # Content layers
        target = model(content_img).detach()
        content_loss = ContentLoss(target)
        model.add_module("content_loss_{}".format(i), content_loss)
        content_losses.append(content_loss)

    if name in ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']:  # Style layers
        target_feature = model(style_img).detach()
        style_loss = StyleLoss(target_feature)
        model.add_module("style_loss_{}".format(i), style_loss)
        style_losses.append(style_loss)

# Remove layers after the last content and style loss
for i in range(len(model) - 1, -1, -1):
    if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
        break

model = model[:i+1]

# Optimizer
optimizer = optim.Adam([target_img], lr=0.001)

# Training parameters
epochs = 2
alpha = 1  # Content weight
beta = 10000  # Style weight



In [8]:
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    print('-' * 10)

    for step in tqdm(range(100)):
        target_feature = model(target_img)
        content_feature = model(content_img)
        style_feature = model(style_img)

        style_loss = 0
        content_loss = 0

        for sl, target, content, style in zip(style_losses, target_feature, content_feature, style_feature): # iterate over the StyleLoss modules
            content_loss += nn.functional.mse_loss(target, content)
            style_loss += sl.loss # use the precomputed loss from the forward pass

        # for target, content, style in zip(target_feature, content_feature, style_feature):
        #     content_loss += nn.functional.mse_loss(target, content)
        #     style_loss += nn.functional.mse_loss(StyleLoss.gram_matrix(None, target), StyleLoss.gram_matrix(None, style)) # this line caused the error

        total_loss = alpha * content_loss + beta * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()



    # Save intermediate result after each epoch
    save_image(target_img, f'result_epoch_{epoch + 1}.png')

# Save final result
save_image(target_img, 'final_result.png')


Epoch 1/1
----------


100%|██████████| 100/100 [01:06<00:00,  1.50it/s]
