In [None]:
# Import necessary libraries
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import copy

# Load the convolutional feature extractor of an ImageNet-pretrained VGG19;
# the classifier head is not needed for style transfer.
try:
    # torchvision >= 0.13 API; ``pretrained=True`` is deprecated there.
    vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
except AttributeError:
    # Older torchvision releases have no ``*_Weights`` enums.
    vgg = models.vgg19(pretrained=True).features

# Define the device for computation (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the network to the device and switch it to inference mode. Freeze its
# weights: only the input image is optimised, so gradients for VGG's own
# parameters would be wasted computation and memory.
vgg.to(device).eval()
vgg.requires_grad_(False)

# Image loading / preprocessing
def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized, batched tensor on ``device``.

    The image is converted to 3-channel RGB, resized, turned into a float
    tensor, normalized with the ImageNet per-channel mean/std and given a
    leading batch dimension.

    Args:
        img_path: Path or file object accepted by ``PIL.Image.open``.
        max_size: Used only when ``shape`` is None. The image is resized to
            a square whose side is the smaller of ``max_size`` and the
            image's longest side.
        shape: Optional explicit output size. An int gives a square
            ``(shape, shape)`` output. A ``(height, width)`` sequence, e.g.
            ``[t.size(2), t.size(3)]`` from another image tensor, is used
            as-is.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` on the module-level ``device``.
    """
    # convert("RGB") guarantees exactly three channels. Grayscale or RGBA
    # inputs would otherwise fail in the 3-channel Normalize below. The
    # context manager releases the underlying file handle.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    if shape is None:
        side = min(max_size, max(image.size))
        size = (side, side)
    else:
        try:
            # (height, width) pair.
            size = tuple(shape)
        except TypeError:
            # Single int: square output, as before.
            size = (shape, shape)

    img_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    return img_transform(image).unsqueeze(0).to(device)

def im_convert(tensor):
    """Convert a normalized image tensor back into a displayable PIL image.

    Reverses the ImageNet mean/std normalization that ``load_image`` applies,
    clamps the result to [0, 1] and converts it with ``ToPILImage``.

    Args:
        tensor: Image tensor of shape ``(1, 3, H, W)`` or ``(3, H, W)``,
            normalized with the ImageNet channel statistics. It may live on
            any device and may require grad. It is not modified.

    Returns:
        A ``PIL.Image.Image``.
    """
    image = tensor.detach().cpu().clone().squeeze(0)
    # Undo Normalize per channel: x = x_norm * std + mean.
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=image.dtype).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=image.dtype).view(3, 1, 1)
    image = image * std + mean
    # ToPILImage multiplies float tensors by 255 and casts to uint8. Values
    # outside [0, 1] would overflow that cast and produce corrupted colors.
    image = image.clamp(0, 1)
    return transforms.ToPILImage()(image)

# Load the content image, then load the style image, requesting the content
# image's (height, width) as its output size.
content_image = load_image("path_to_content_image.jpg")
_, _, content_height, content_width = content_image.shape
style_image = load_image("path_to_style_image.jpg", shape=[content_height, content_width])

# Show the two inputs side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (ax1, content_image, "Content Image"),
    (ax2, style_image, "Style Image"),
)
for axis, img_tensor, caption in panels:
    axis.imshow(im_convert(img_tensor))
    axis.set_title(caption)
plt.show()

# Feature-matching (MSE) loss module; style_transfer below does not use it.
class StyleTransferLoss(torch.nn.Module):
    """Pass-through layer that records an MSE loss against a fixed target.

    In the forward pass the module returns its input unchanged. As a side
    effect it stores ``mse_loss(input, target)`` in ``self.loss``. The
    target is detached at construction, so no gradient flows into it.

    This compares raw feature maps directly; it does not build Gram
    matrices.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.loss = None  # set on every forward() call
        self.target = target_feature.detach()

    def forward(self, input):
        mse = torch.nn.functional.mse_loss
        self.loss = mse(input, self.target)
        return input

# Function to perform the style transfer
def _extract_features(image, model, indices):
    """Return ``{index: output}`` for the layers of *model* listed in *indices*.

    *model* is iterated as an ordered sequence of layers, e.g. a
    ``torch.nn.Sequential`` such as VGG's ``features``. Iteration stops at
    the deepest requested index, so later layers are never evaluated.
    Outputs are cloned when recorded because a following in-place layer
    (e.g. ``ReLU(inplace=True)``) would otherwise overwrite the stored
    tensor.
    """
    wanted = set(indices)
    deepest = max(wanted)
    features = {}
    x = image
    for idx, layer in enumerate(model):
        x = layer(x)
        if idx in wanted:
            features[idx] = x.clone()
        if idx == deepest:
            break
    return features


def _gram_matrix(feature):
    """Return the per-sample Gram matrix ``(B, C, C)`` of a ``(B, C, H, W)`` map."""
    batch, channels, height, width = feature.shape
    flat = feature.reshape(batch, channels, height * width)
    return flat @ flat.transpose(1, 2)


def style_transfer(content_image, style_image, vgg, steps=500, style_weight=1e6,
                   content_weight=1, *, lr=0.003, content_layer=21,
                   style_layers=(0, 5, 10, 19, 28), log_every=100):
    """Optimise a copy of *content_image* toward the style of *style_image*.

    Content is matched by the MSE between feature maps at ``content_layer``.
    Style is matched by the MSE between Gram matrices at each index in
    ``style_layers``, each term divided by that layer's ``C * H * W``. The
    default indices refer to torchvision's VGG19 ``features``:
    0=conv1_1, 5=conv2_1, 10=conv3_1, 19=conv4_1, 21=conv4_2, 28=conv5_1.

    Args:
        content_image: Normalized ``(1, 3, H, W)`` tensor; it is not modified.
        style_image: Normalized ``(1, 3, H', W')`` tensor on the same device.
        vgg: Indexable, ordered layer sequence (e.g. ``vgg19().features``).
            Its parameters accumulate ``.grad`` unless the caller froze them.
        steps: Number of Adam steps.
        style_weight: Weight of the summed style loss.
        content_weight: Weight of the content loss.
        lr: Adam learning rate (keyword-only).
        content_layer: Layer index used for the content loss (keyword-only).
        style_layers: Layer indices used for the style loss (keyword-only).
        log_every: Print the loss every this many steps; 0 disables logging.

    Returns:
        The optimised image tensor, a leaf with ``requires_grad=True`` on
        the same device as *content_image*.

    Raises:
        ValueError: If ``style_layers`` is empty or an index is out of range
            for *vgg*.
    """
    mse = torch.nn.functional.mse_loss
    style_layers = tuple(style_layers)
    if not style_layers:
        raise ValueError("style_layers must contain at least one layer index")
    needed = (content_layer, *style_layers)
    depth = len(vgg)
    if any(not 0 <= idx < depth for idx in needed):
        raise ValueError(f"layer indices {needed} must lie in [0, {depth}) for this model")

    # Targets are fixed for the whole run: compute them once, without autograd.
    with torch.no_grad():
        content_target = _extract_features(content_image, vgg, (content_layer,))[content_layer]
        style_grams = {
            idx: _gram_matrix(feat)
            for idx, feat in _extract_features(style_image, vgg, style_layers).items()
        }

    # detach/clone first, then requires_grad_, so the optimised tensor is a leaf.
    target = content_image.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=lr)

    for step in range(steps):
        # A single forward pass provides every feature map needed this step.
        feats = _extract_features(target, vgg, needed)
        content_loss = mse(feats[content_layer], content_target)

        style_loss = 0
        for idx in style_layers:
            feat = feats[idx]
            _, channels, height, width = feat.shape
            layer_loss = mse(_gram_matrix(feat), style_grams[idx])
            style_loss = style_loss + layer_loss / (channels * height * width)

        total_loss = content_weight * content_loss + style_weight * style_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if log_every and step % log_every == 0:
            print(f"Step {step}/{steps}, Loss: {total_loss.item()}")

    return target

# Run the optimisation with the default settings, then display the result.
output_image = style_transfer(content_image, style_image, vgg)

stylized = im_convert(output_image)
plt.imshow(stylized)
plt.title("Stylized Image")
plt.show()
