In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

In [14]:
styleImage = Image.open('../data/style/cubism.jpg')
contentImage = Image.open('../data/content/nalax.jpg')

In [12]:
transform = transforms.Compose([
    transforms.Resize((512,512)),  # Resize the images to a suitable size
    transforms.ToTensor(),  # Convert the images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # Normalize the images
])

In [2]:
class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        F = input.view(b, c, h * w)
        G = torch.bmm(F, F.transpose(1, 2))
        return G.div(h * w)

In [4]:
class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        
        vgg = models.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential(*list(vgg.children())[:2])
        self.slice2 = nn.Sequential(*list(vgg.children())[2:7])
        self.slice3 = nn.Sequential(*list(vgg.children())[7:12])
        self.slice4 = nn.Sequential(*list(vgg.children())[12:21])
        
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        loss = nn.MSELoss()
        input_slices = [self.slice1(input), self.slice2(input), self.slice3(input), self.slice4(input)]
        target_slices = [self.slice1(target), self.slice2(target), self.slice3(target), self.slice4(target)]
        losses = [loss(input_slice, target_slice) for input_slice, target_slice in zip(input_slices, target_slices)]
        return sum(losses)

In [5]:
class StyleTransfer(nn.Module):
    def __init__(self, style, content, alpha, beta, lr=0.01):
        super(StyleTransfer, self).__init__()
        
        self.style = style
        self.content = content
        
        self.content_layers = ['slice3']
        self.style_layers = ['slice1', 'slice2', 'slice3', 'slice4']
        self.content_weight = alpha
        self.style_weight = beta
        
        self.loss_network = nn.Sequential()
        self.gram = GramMatrix()
        self.perceptual_loss = PerceptualLoss()
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.loss_network.cuda()
            self.gram.cuda()
            self.perceptual_loss.cuda()
            self.content = self.content.cuda()
            self.style = self.style.cuda()
        else:
            self.loss_network.cpu()
            self.gram.cpu()
            self.perceptual_loss.cpu()
            self.content = self.content.cpu()
            self.style = self.style.cpu()

    def train(self, steps):
        for i in range(steps):
            self.optimizer.zero_grad()

            pastiche = self.content.clone().requires_grad_()
            
            content_features = {}
            style_features = {}
            output_features = {}
            for name, module in self.loss_network._modules.items():
                pastiche = module(pastiche)
                
                if name in self.content_layers:
                    content_features[name] = pastiche
                
                if name in self.style_layers:
                    style_features[name] = self.gram(pastiche)
                    
                output_features[name] = pastiche
            
            content_loss = self.perceptual_loss(content_features['slice3'], self.content)
            style_loss = sum([self.perceptual_loss(self.gram(output_features[name]), style_features[name]) for name in style_features])
            
            total_loss = self.content_weight * content_loss + self.style_weight * style_loss
            total_loss.backward()
            
            self.optimizer.step()
            
        return pastiche

In [13]:
alpha = 1 # Content weight
beta = 1000 # Style weight
num_steps = 2000 # Number of iterations to run the optimization
learning_rate = 0.01

In [17]:
style_tensor = transform(styleImage).unsqueeze(0)
content_tensor = transform(contentImage).unsqueeze(0)

In [18]:
model = StyleTransfer(style_tensor, content_tensor, alpha, beta)



In [19]:
canva_tensor = model.train(num_steps)

KeyError: 'slice3'