In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pprint import pprint
import inspect

In [7]:
class FastStyleTransfer:
    def __init__(self, content_image_path, style_image_path, output_image_path, content_weight=1, style_weight=1000,
                 num_iterations=1000, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        self.content_image_path = content_image_path
        self.style_image_path = style_image_path
        self.output_image_path = output_image_path
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.num_iterations = num_iterations
        self.device = device

        # Load the VGG19 model and set it to evaluation mode
        self.vgg = models.vgg19(pretrained=True).features.to(self.device)
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def load_image(self, image_path):
        image = Image.open(image_path)
        image = transforms.ToTensor()(image)
        image = image.unsqueeze(0).to(self.device)
        return image

    def save_image(self, image, image_path):
        image = image.cpu().clone()
        image = image.squeeze(0)
        image = transforms.ToPILImage()(image)
        image.save(image_path)

    def gram_matrix(self, input):
        b, c, h, w = input.size(0), input.size(1), input.size(2), input.size(3)  # Update here
        features = input.view(b * c, h * w)
        gram_matrix = torch.mm(features, features.t())
        return gram_matrix


    def compute_content_loss(self, content_features, output_features):
        content_loss = F.mse_loss(output_features, content_features)
        return content_loss

    def compute_style_loss(self, style_features, output_features):
        style_loss = 0.0
        for ft_y, gm_s in zip(output_features, style_features):
            gm_y = self.gram_matrix(ft_y)
            style_loss += F.mse_loss(gm_y, gm_s)
        return style_loss

    def build_model(self, content, style):
        target = content.clone().requires_grad_(True)
        optimizer = optim.LBFGS([target])
        content_features = self.vgg(content)
        style_features = self.vgg(style)
        style_grams = [self.gram_matrix(y) for y in style_features]
        return target, optimizer, content_features, style_features, style_grams

    def run(self):
    # Load content and style images
        content = self.load_image(self.content_image_path)
        style = self.load_image(self.style_image_path)

        # Build the model
        target, optimizer, content_features, style_features, style_grams = self.build_model(content, style)

        # Run the optimization loop
        print("Running style transfer...")
        step = 0
        while step <= self.num_iterations:
            def closure():
                nonlocal step
                optimizer.zero_grad()

                output = self.vgg(target)
                content_loss = self.compute_content_loss(content_features, output)
                style_loss = self.compute_style_loss(style_features, output)

                total_loss = self.content_weight * content_loss + self.style_weight * style_loss
                total_loss.backward()

                step += 1
                if step % 100 == 0:
                    print(f'Step [{step}/{self.num_iterations}], Loss: {total_loss.item()}')

                return total_loss

            optimizer.step(closure)

        # Save the output image
        self.save_image(target, self.output_image_path)
        print(f'Style transfer complete. Output image saved at {self.output_image_path}')

    

In [8]:
content_image_path = '../data/content/adri.jpeg'
style_image_path = '../data/style/rain_princess.jpg'
output_image_path = '../data/content/adri.jpeg'


fast_style_transfer = FastStyleTransfer(content_image_path, style_image_path, output_image_path)
fast_style_transfer.run()



ValueError: not enough values to unpack (expected 4, got 3)