# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms as T
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt 


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Edge length (in pixels) of the square image handed to the network.
imsize = 256
# Multiplier applied to the weighted style loss before back-propagation.
beta = 1e5

# Layers whose Gram matrices enter the style loss, in evaluation order,
# together with the relative weight of each layer.
style_layers_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
style_weights = {
    'conv1_1': 1.0,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}

# Layer name -> position of that module among the model's children.
# Positions are kept as strings because get_features() matches them
# against str(index).
layer_name_to_index = {
    'conv1_1': '0',
    'conv2_1': '5',
    'conv3_1': '10',
    'conv4_1': '19',
    'conv4_2': '21',
    'conv5_1': '28',
}

# Positions of the style layers only (conv4_2 is listed above but not used
# as a style layer).
style_layers_indices = set(map(layer_name_to_index.__getitem__, style_layers_names))

# Inverse lookup restricted to the style layers: position -> layer name.
layers_for_inference = {
    position: layer
    for layer, position in layer_name_to_index.items()
    if position in style_layers_indices
}



# Only the convolutional feature extractor of the pretrained VGG19 is kept.
# It is switched to eval mode and frozen: the optimisation below updates the
# image, never the network weights.
model = (
    models.vgg19(weights=models.VGG19_Weights.DEFAULT)
    .features
    .to(device)
    .eval()
)
for param in model.parameters():
    param.requires_grad_(False)

# Precomputed style targets, looked up by layer name in stylize_image()
# (presumably Gram matrices produced by gram_matrix() on the style image).
GRAMS_FILE_PATH = 'style_target_grams.pt' # Adjust this path

# The script cannot do anything useful without the targets, so any loading
# failure terminates the process.
# NOTE(review): torch.load unpickles the file; only load files from a trusted
# source (consider weights_only=True if the installed PyTorch supports it).
try:
    loaded_target_grams = torch.load(GRAMS_FILE_PATH, map_location=device)
    print(f"Style target grams loaded successfully from {GRAMS_FILE_PATH}.")
except FileNotFoundError:
    print(f"Error: {GRAMS_FILE_PATH} not found. Please ensure it's in the correct path.")
    raise SystemExit(f"Required file {GRAMS_FILE_PATH} not found.")
except Exception as e:
    print(f"Error loading style target grams: {e}")
    raise SystemExit(f"Error loading style target grams: {e}")



def image_loader(image: Image.Image, size=256, device=torch.device("cpu")):
    """Convert a PIL image into a normalised, batched float tensor.

    The image is converted to RGB, resized so its shorter side equals
    ``size``, centre-cropped to ``size`` x ``size``, scaled to ``[0, 1]`` by
    ``ToTensor`` and finally normalised per channel with the mean / std
    constants below.

    Args:
        image: Source picture; any PIL mode is accepted.
        size: Target edge length in pixels.
        device: Device the resulting tensor is moved to.

    Returns:
        A ``torch.float`` tensor of shape ``(1, 3, size, size)`` on ``device``.
    """
    preprocess = T.Compose([
        T.Resize(size),
        T.CenterCrop(size),
        T.ToTensor(),
        # Per-channel statistics; im_convert() uses the same values to undo
        # this step.
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    rgb = image.convert('RGB')
    # unsqueeze(0) prepends the batch dimension the network expects.
    batch = preprocess(rgb).unsqueeze(0)
    return batch.to(device, torch.float)

def im_convert(tensor):
    """Turn a normalised ``(1, 3, H, W)`` tensor back into a displayable image.

    Undoes the per-channel normalisation applied by ``image_loader`` and
    clamps the result to the valid display range.

    Args:
        tensor: Image tensor with a leading batch dimension of size 1 and the
            channel dimension second. It may live on any device and may
            require grad; it is not modified.

    Returns:
        A NumPy array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
    """
    std = np.array((0.229, 0.224, 0.225))
    mean = np.array((0.485, 0.456, 0.406))

    # Drop the batch dimension and move channels last: (1, C, H, W) -> (H, W, C).
    image = tensor.detach().to("cpu").numpy().squeeze(0)
    image = image.transpose(1, 2, 0)

    # De-normalise first and clamp afterwards.  Clamping in normalised space
    # (e.g. to [-2.5, 2.5]) is lossy: a fully saturated blue channel maps to
    # (1 - 0.406) / 0.225 ~= 2.64 and would come back as ~0.97 instead of 1.0.
    # The multiplication allocates a new array, so the input is never aliased.
    image = image * std + mean

    return image.clip(0, 1)

def gram_matrix(tensor):
    """Compute the normalised Gram matrix of a single feature map.

    Args:
        tensor: Feature map of shape ``(1, C, H, W)``.  Only a batch size of
            1 is supported; the batch dimension is folded away when the
            features are flattened, so larger batches raise a shape error.

    Returns:
        A ``(C, C)`` tensor ``F @ F.T / (C * H * W)`` where ``F`` is the
        feature map flattened to ``(C, H * W)``.
    """
    _, c, h, w = tensor.size()
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced feature maps, copying only when necessary.
    features = tensor.reshape(c, h * w)
    gram = features.mm(features.t())
    # Normalise by the number of elements so layers of different size
    # contribute on a comparable scale.
    return gram.div(c * h * w)

def get_features(image, model, layers):
    """Run ``image`` through ``model`` and collect selected activations.

    The model's direct children are applied one after another.  Whenever the
    zero-based position of a child (as a string) is a key of ``layers``, the
    output of that child is stored under the corresponding value.

    Args:
        image: Input passed to the first child module.
        model: Module whose ``children()`` form a sequential pipeline.
        layers: Mapping from child position (``str``) to the name under which
            its output is stored.

    Returns:
        A dict mapping the requested names to the captured tensors.

    Note:
        The stored tensors are the very objects passed on to the next child.
        If that child operates in place (e.g. ``ReLU(inplace=True)``), the
        stored activation reflects the in-place modification.  For the same
        reason the whole pipeline is always executed, even after the last
        requested layer.
    """
    captured = {}
    activation = image
    for position, layer in enumerate(model.children()):
        activation = layer(activation)
        key = str(position)
        if key in layers:
            captured[layers[key]] = activation
    return captured



def stylize_image(content_image: Image.Image, *, steps: int = 100, lr: float = 0.002):
    """Apply the precomputed style to ``content_image`` by optimising its pixels.

    The content image is preprocessed, copied, and the copy is updated with
    Adam so that the Gram matrices of its features approach the target Gram
    matrices in ``loaded_target_grams``.  Only a style loss is used; the
    content is preserved solely through the initialisation and the small
    number of steps.

    Args:
        content_image: Picture to stylise.
        steps: Number of optimisation steps (keyword-only, default 100).
        lr: Adam learning rate (keyword-only, default 0.002).

    Returns:
        The stylised image as an ``(H, W, 3)`` NumPy array with values in
        ``[0, 1]``, or ``None`` if any error occurred (the error is printed).
    """
    print("Starting style transfer inference...")

    # Deliberate best-effort boundary: callers test the result for None
    # instead of handling exceptions.
    try:
        # 1. Load and preprocess the new content image
        new_content_img = image_loader(content_image, size=imsize, device=device)

        # 2. Initialise the generated image as a copy of the content image.
        #    requires_grad_ comes last so the tensor handed to the optimizer
        #    is a leaf even if the device move produces a new tensor.
        generated_img = new_content_img.detach().clone().to(device).requires_grad_(True)

        # 3. Setup optimizer for the generated image
        optimizer = optim.Adam([generated_img], lr=lr)

        # The targets do not change between steps, so move them to the
        # device once instead of on every iteration.
        target_grams = {
            layer_name: loaded_target_grams[layer_name].detach().to(device)
            for layer_name in style_layers_names
        }

        # 4. Run optimization loop
        print(f"Running {steps} optimization steps...")

        for step in range(1, steps + 1):
            generated_features = get_features(generated_img, model, layers=layers_for_inference)

            # Weighted sum of per-layer Gram-matrix MSE losses.
            current_style_loss = torch.tensor(0.0, device=device)
            for layer_name in style_layers_names:
                input_gram = gram_matrix(generated_features[layer_name])
                loss = nn.functional.mse_loss(input_gram, target_grams[layer_name])
                current_style_loss = current_style_loss + style_weights[layer_name] * loss

            # Total loss (only style loss in inference mode)
            total_loss = beta * current_style_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if step % 20 == 0: # Print every 20 steps
                print(f"Step {step}/{steps}, Loss: {total_loss.item():.4f}")

        print("Inference finished.")

        # 5. Convert the final tensor to a displayable image format
        return im_convert(generated_img)

    except Exception as e:
        print(f"An error occurred during style transfer: {e}")
        return None


# Script entry point: stylise one image from disk and show the result.
if __name__ == "__main__":
    CONTENT_IMAGE_PATH = 'content.jpg' # Change to content image path

    if not os.path.exists(CONTENT_IMAGE_PATH):
        print(f"Error: Content image not found at {CONTENT_IMAGE_PATH}")
        print("Please update CONTENT_IMAGE_PATH to point to a valid image file.")
    else:
        try:
            # Image.open is lazy and keeps the file handle open; the with
            # block guarantees it is closed once the pixels have been
            # consumed, even if stylisation fails.
            with Image.open(CONTENT_IMAGE_PATH) as content_image:
                print(f"Content image loaded successfully from {CONTENT_IMAGE_PATH}")

                # --- Run the style transfer ---
                stylized_image_np = stylize_image(content_image)

            # --- Display the result ---
            # stylize_image reports its own errors and signals failure
            # with None.
            if stylized_image_np is not None:
                print("Displaying the stylized image:")
                plt.imshow(stylized_image_np)
                plt.axis('off') # Hide axes
                plt.title('Stylized Image')
                plt.show()
            else:
                print("Style transfer failed.")

        except Exception as e:
            print(f"An error occurred during image loading or display: {e}")