In [None]:
# !pip install torch torchvision diffusers transformers matplotlib pillow

In [None]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from diffusers import StableDiffusionPipeline
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import os

# Load and preprocess images
def loadImage(img_path, max_size=400, shape=None):
    """Load an image file as a normalized ``(1, 3, H, W)`` tensor for VGG.

    The image is converted to RGB, resized, converted to a float tensor and
    normalized with the ImageNet channel means/stds used by torchvision's
    pretrained models.

    Args:
        img_path: Path of the image file to read.
        max_size: Upper bound on the output edge length when ``shape`` is not
            given. The image is resized to a square of side
            ``min(max_size, longest input edge)``; the aspect ratio is NOT
            preserved.
        shape: Optional explicit output size. A tuple (``torch.Size`` is a
            tuple subclass) is passed to ``Resize`` unchanged; any other value
            is used as the side of a square output.

    Returns:
        A float tensor of shape ``(1, 3, H, W)``.
    """
    # Read inside a context manager so the file handle is released promptly.
    # convert('RGB') loads the pixel data, so the result stays valid after close.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    size = min(max_size, max(image.size))
    if shape is not None:
        size = shape

    # A scalar size becomes (size, size): the original square-output behavior.
    resize_to = size if isinstance(size, tuple) else (size, size)
    in_transform = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # convert('RGB') already guarantees 3 channels; the slice is a cheap guard.
    return in_transform(image)[:3, :, :].unsqueeze(0)

# Save a tensor as an image file to a specified path
def saveTensorAsImage(tensor, path):
    """Write a normalized image tensor to ``path`` as an 8-bit RGB image.

    The tensor is de-normalized to ``[0, 1]`` floats by ``tensorToImage``, then
    scaled to ``[0, 255]`` and truncated to ``uint8`` before saving. The file
    format is chosen by PIL from the extension of ``path``.
    """
    pixels = (tensorToImage(tensor) * 255).astype('uint8')
    Image.fromarray(pixels).save(path)

# Convert a tensor to an image
def tensorToImage(tensor):
    """Convert a normalized image tensor back to an ``(H, W, 3)`` float array.

    Reverses the ImageNet normalization applied by ``loadImage`` and clips the
    result to ``[0, 1]`` so it can be shown with ``imshow`` or scaled to bytes.

    Args:
        tensor: ``(1, 3, H, W)`` or ``(3, H, W)`` tensor on any device; it may
            require grad.

    Returns:
        A NumPy array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
    """
    # Detach first so no autograd history is copied, then move to CPU. The clone
    # keeps the returned array from sharing memory with the caller's tensor.
    image = tensor.detach().to("cpu").clone()
    # Drop only a leading batch dimension of size 1. A blanket squeeze() would
    # also collapse H or W when they equal 1 and break the transpose below.
    if image.dim() == 4:
        image = image.squeeze(0)
    image = image.numpy().transpose(1, 2, 0)
    image = image * (0.229, 0.224, 0.225) + (0.485, 0.456, 0.406)
    return image.clip(0, 1)

# Define content and style losses
def getFeatures(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected intermediate outputs.

    Args:
        image: Input tensor for the first child module of ``model``.
        model: A module whose direct children are applied in registration order
            (e.g. ``vgg19().features``).
        layers: Mapping from child-module name (as registered on ``model``) to
            the key under which that child's output is stored. Defaults to
            the VGG-19 ``features`` indices of conv1_1 ... conv5_1 plus conv4_2.

    Returns:
        Dict mapping each requested key to the corresponding activation tensor.
        Names not present in ``model`` are silently absent from the result.
    """
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  ## content representation
                  '28': 'conv5_1'}

    features = {}
    # Child names still to be captured; once empty, later layers only cost compute.
    remaining = set(layers)
    x = image
    # _modules preserves registration order and, unlike named_children(), does
    # not de-duplicate a module object that is registered twice.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            remaining.discard(name)
            if not remaining:
                break

    return features

# Calculate the Gram feature matrix
def gramMatrix(tensor):
    """Return the ``(d, d)`` Gram matrix of a single ``(1, d, h, w)`` feature map.

    Each channel is flattened to a row of length ``h * w``, and the result is
    the matrix of inner products between rows. The result is not normalized.
    Batch sizes other than 1 raise, because the flattening assumes one image.
    """
    _, d, h, w = tensor.size()
    # reshape (not view) also accepts non-contiguous feature maps, e.g. ones
    # produced by transposes or channels_last layouts.
    flat = tensor.reshape(d, h * w)
    return torch.mm(flat, flat.t())

# Display images
def showImages(*images, titles=None):
    """Plot the given images side by side in one row, without axes.

    Args:
        *images: Anything ``matplotlib``'s ``imshow`` accepts (PIL images,
            ``(H, W, 3)`` arrays, ...).
        titles: Optional sequence of titles; it is used only when its length
            matches the number of images.
    """
    count = len(images)
    # squeeze=False always yields a 2-D grid of axes, so one image needs no
    # special case.
    _, grid = plt.subplots(1, count, figsize=(15, 5), squeeze=False)
    show_titles = bool(titles) and len(titles) == count
    for idx, (ax, img) in enumerate(zip(grid[0], images)):
        ax.imshow(img)
        if show_titles:
            ax.set_title(titles[idx])
        ax.axis('off')
    plt.show()

# Generate an image using Stable Diffusion
def generateImage(prompt, save_path=None, *, model_id="CompVis/stable-diffusion-v1-4", device=None):
    """Generate one image from a text prompt with a Stable Diffusion pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        save_path: If truthy, the generated image is also saved to this path.
        model_id: Pretrained pipeline identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.
        device: Device to run the pipeline on. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Returns:
        The first image returned by the pipeline (a PIL image).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: the pipeline is loaded on every call; callers that generate many
    # images should consider reusing a single pipeline instance.
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
    image = pipe(prompt).images[0]

    if save_path:
        image.save(save_path)

    return image

# Style transfer from the generated image to the content image
def _loadVggFeatures(device):
    """Return the frozen, eval-mode conv part of a pretrained VGG-19 on ``device``."""
    # torchvision >= 0.13 replaced ``pretrained=True`` with weight enums;
    # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` selected.
    weights_enum = getattr(models, "VGG19_Weights", None)
    if weights_enum is not None:
        vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
    else:
        vgg = models.vgg19(pretrained=True).features

    # Only the target image is optimized; the network stays fixed.
    for param in vgg.parameters():
        param.requires_grad_(False)

    return vgg.to(device).eval()


def _styleLoss(target_features, style_grams, style_weights):
    """Sum per-layer weighted Gram-matrix MSEs, each divided by that layer's d*h*w."""
    style_loss = 0
    for layer, weight in style_weights.items():
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape
        target_gram = gramMatrix(target_feature)
        layer_style_loss = weight * torch.mean((target_gram - style_grams[layer]) ** 2)
        style_loss += layer_style_loss / (d * h * w)
    return style_loss


def styleTransfer(content_img, style_img, device, generated_image, content_image_path,
                  *, steps=4000, lr=0.003, content_weight=1e0, style_weight=1e8,
                  show_every=100, output_path="styled_image.png"):
    """Optimize a copy of ``content_img`` so its VGG-19 features match the style.

    The content loss is the MSE between target and content activations at
    conv4_2. The style loss compares Gram matrices at conv1_1 ... conv5_1.
    The result is saved to ``output_path`` and shown next to the generated
    (style) image and the original content image.

    Args:
        content_img: Normalized content tensor, e.g. from ``loadImage``.
        style_img: Normalized style tensor (ideally the same H, W as content).
        device: Device to run VGG and the optimization on; inputs are moved
            there if needed.
        generated_image: Image shown as the style source in the final figure.
        content_image_path: Path of the content image, re-read for display.
        steps: Number of Adam iterations.
        lr: Adam learning rate.
        content_weight: Weight of the content loss (alpha).
        style_weight: Weight of the style loss (beta).
        show_every: Print the loss and show the current target every this many
            steps; ``0``/``None`` disables the progress display.
        output_path: Where the final styled image is written.

    Returns:
        The optimized target tensor, detached from the autograd graph.
    """
    vgg = _loadVggFeatures(device)

    content_img = content_img.to(device)
    style_img = style_img.to(device)

    # Reference features are constants of the optimization; build no graph.
    with torch.no_grad():
        content_features = getFeatures(content_img, vgg)
        style_features = getFeatures(style_img, vgg)
    style_grams = {layer: gramMatrix(feat) for layer, feat in style_features.items()}

    # Enable grad AFTER the tensor is on its final device so ``target`` is a
    # leaf tensor. Calling ``.to(device)`` after ``requires_grad_`` returns a
    # non-leaf copy whenever a transfer happens, and Adam rejects that.
    target = content_img.clone().requires_grad_(True)

    # Per-layer style weights: shallow layers (fine texture) weigh more.
    style_weights = {'conv1_1': 1.0,
                     'conv2_1': 0.75,
                     'conv3_1': 0.2,
                     'conv4_1': 0.2,
                     'conv5_1': 0.2}

    optimizer = optim.Adam([target], lr=lr)

    for ii in range(1, steps + 1):
        target_features = getFeatures(target, vgg)
        content_loss = torch.mean(
            (target_features['conv4_2'] - content_features['conv4_2']) ** 2)
        style_loss = _styleLoss(target_features, style_grams, style_weights)
        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if show_every and ii % show_every == 0:
            print('Total loss: ', total_loss.item())
            plt.imshow(tensorToImage(target))
            plt.show()

    # Save, then re-read the file so the figure shows exactly what was written.
    saveTensorAsImage(target, output_path)
    with Image.open(output_path) as styled_file:
        styled_image = styled_file.copy()
    with Image.open(content_image_path) as content_file:
        original_image = content_file.copy()
    showImages(generated_image, original_image, styled_image,
               titles=["Generated Image", "Original Image", "Styled Image"])

    return target.detach()


def main(prompt, content_image_path, save_generated_image_path="generated_image.png"):
    """Generate (or reuse) a style image, then transfer its style onto a photo.

    Args:
        prompt: Text prompt for Stable Diffusion.
        content_image_path: Path of the image whose content is preserved.
        save_generated_image_path: Cache path for the generated style image.

    NOTE: if ``save_generated_image_path`` already exists it is reused as-is
    and ``prompt`` is ignored; delete the file to regenerate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if os.path.exists(save_generated_image_path):
        generated_image = Image.open(save_generated_image_path)
    else:
        generated_image = generateImage(prompt, save_path=save_generated_image_path)
        print(f"Saved generated image to {save_generated_image_path}")

    content = loadImage(content_image_path).to(device)
    # Resize the style image to the content tensor's spatial size.
    style = loadImage(save_generated_image_path, shape=content.shape[-2:]).to(device)

    # Run the optimization and display the final comparison figure.
    styleTransfer(content, style, device, generated_image, content_image_path)


# Define the prompt and content image path
prompt = "Line sketch of a business person"
content_image_path = "profile.jpg"

# In a notebook or when run as a script ``__name__`` is "__main__", so this
# still runs there; merely importing the module no longer starts the
# expensive image generation and style-transfer optimization.
if __name__ == "__main__":
    main(prompt, content_image_path)