In [None]:
# Cell 1: Install required packages
# NOTE(review): in the recorded output of this cell, installing mediapipe
# pulled numpy down to 1.26.4 (mediapipe 0.10.21 requires numpy<2), and the
# `numpy --upgrade` on the next line then reinstalled numpy 2.3.3, after which
# pip reported mediapipe as incompatible. mediapipe is imported further down
# but not otherwise referenced in the visible cells, so consider dropping
# either the mediapipe install or the numpy upgrade — TODO confirm.
!pip install mediapipe
!pip install numpy --upgrade
!pip install torchvision --upgrade


# Cell 2: Import necessary libraries
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn
from PIL import Image
import numpy as np
import cv2
import mediapipe as mp
from google.colab import files
import matplotlib.pyplot as plt

# Cell 3: Set up device
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Cell 4: Upload and display images
# Two separate upload prompts: the first selects the content image, the
# second the style image.
content_file = files.upload()
style_file = files.upload()

# Only the first key of each upload result is used, and it is treated as a
# path that PIL can open.
# NOTE(review): presumably the upload also writes the file to the working
# directory under that name — confirm.
content_image = Image.open(list(content_file)[0]).convert('RGB')
style_image = Image.open(list(style_file)[0]).convert('RGB')

# Side-by-side preview of both inputs, without axes.
plt.figure(figsize=(10, 5))
_previews = [(content_image, 'Content Image'), (style_image, 'Style Image')]
for _slot, (_picture, _caption) in enumerate(_previews, start=1):
    plt.subplot(1, 2, _slot)
    plt.imshow(_picture)
    plt.title(_caption)
    plt.axis('off')
plt.show()

# Cell 5: Load Mask R-CNN model
# Build the pretrained Mask R-CNN, place it on the selected device, and
# switch it to evaluation mode since it is only used for inference here.
mask_rcnn = maskrcnn_resnet50_fpn(pretrained=True)
mask_rcnn = mask_rcnn.to(device)
mask_rcnn.eval()

# Cell 6: Define segmentation function using Mask R-CNN
def segment_foreground_maskrcnn(image):
    """Return a black/white mask image for the first label-1 detection.

    The PIL ``image`` is converted to a tensor, run through the module-level
    ``mask_rcnn`` model, and the first mask whose label equals 1 (the class
    this code treats as "person") is thresholded at 0.5. When no such
    detection exists, an all-zero mask with the image's height and width is
    produced instead.

    Returns:
        A PIL image built from a uint8 array holding 255 where the mask
        exceeds the threshold and 0 elsewhere.
    """
    # Add a leading batch dimension and move the input to the model's device.
    batch = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        detections = mask_rcnn(batch)
    first_result = detections[0]

    is_person = first_result['labels'] == 1
    person_masks = first_result['masks'][is_person]

    if len(person_masks) == 0:
        # PIL reports size as (width, height); the array needs (height, width).
        mask_values = torch.zeros(image.size[::-1])
    else:
        mask_values = (person_masks[0].squeeze() > 0.5).float() * 255

    pixels = mask_values.cpu().numpy().astype(np.uint8)
    return Image.fromarray(pixels)

# Cell 7: Perform segmentation and display result
# Segment the uploaded content image; the result is a PIL mask image
# (255 inside the detected mask, 0 elsewhere).
foreground_mask = segment_foreground_maskrcnn(content_image)

# Show the mask in grayscale as a visual check before it is used downstream.
plt.imshow(foreground_mask, cmap='gray')
plt.title('Segmentation Mask')
plt.axis('off')
plt.show()

# Cell 8: Define preprocessing functions
def preprocess_image(image, size=512, is_mask=False):
    """Turn a PIL image into a batched tensor resized to ``size`` x ``size``.

    Args:
        image: PIL image to convert.
        size: Edge length of the square output; the aspect ratio is not
            preserved.
        is_mask: When true, per-channel normalisation is skipped so the
            values stay in the range produced by ``ToTensor``.

    Returns:
        The transformed tensor with a leading batch dimension of 1.
    """
    steps = [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ]
    if not is_mask:
        # Same mean/std that denormalize_image later undoes.
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )
    pipeline = transforms.Compose(steps)
    return pipeline(image).unsqueeze(0)

# Cell 9: Preprocess images and mask
# Content and style images go through the normalising pipeline; they are not
# moved to the device here.
content_tensor, style_tensor = map(preprocess_image, (content_image, style_image))

# The mask skips normalisation and is moved to the device right away.
mask_tensor = preprocess_image(foreground_mask, is_mask=True)
mask_tensor = mask_tensor.to(device)

# Complement of the mask.
inverted_mask = 1 - mask_tensor

# Cell 10: Load VGG19 model
# `vgg19` is not among the notebook's top-level imports, so it is imported
# here; without this the cell fails with NameError.
from torchvision.models import vgg19

# Keep only the `.features` part of the network, place it on the selected
# device, and put it in evaluation mode.
vgg = vgg19(pretrained=True).features.to(device)
vgg.eval()

# The network is a fixed feature extractor: only the image is optimised, so
# gradients for the weights are switched off.
for param in vgg.parameters():
    param.requires_grad_(False)

# Cell 11: Define helper functions for style transfer
def get_features(image, model, layers):
    """Run ``image`` through ``model`` child by child and collect outputs.

    Args:
        image: Input passed to the first child module.
        model: Module whose direct children are applied in registration order.
        layers: Collection of child names (strings) whose outputs to keep.

    Returns:
        Dict mapping each requested child name to that child's output.
        The stored value is the output object itself, not a copy, so a later
        child that modifies its input in place will also change it.
    """
    captured = {}
    activation = image
    for layer_name, module in model._modules.items():
        activation = module(activation)
        if layer_name not in layers:
            continue
        captured[layer_name] = activation
    return captured

# Names of the feature extractor's child modules whose outputs are captured
# by get_features (string keys, as stored in `model._modules`).
content_layers = ['21']
# NOTE(review): style_layers is not referenced by the visible code —
# style_transfer derives its layer list from its own `style_weights` keys.
style_layers = ['0', '5', '10', '19', '28']

def gram_matrix(tensor):
    """Return the unnormalised Gram matrix of a single feature map.

    Args:
        tensor: 4-D tensor unpacked as (batch, c, h, w). The batch size must
            be 1, because the data is reshaped to exactly ``(c, h * w)``.

    Returns:
        A ``(c, c)`` tensor holding the inner products between the flattened
        channels.

    Raises:
        RuntimeError: If the element count does not equal ``c * h * w``
            (for example when the batch size is larger than 1).
    """
    _, c, h, w = tensor.size()
    # reshape, unlike view, also accepts non-contiguous inputs (e.g. permuted
    # or channels-last feature maps); it copies only when it has to.
    flat = tensor.reshape(c, h * w)
    return torch.mm(flat, flat.t())

def content_loss(content_features, target_features):
    """Return the mean squared difference between two feature tensors."""
    difference = content_features - target_features
    return difference.pow(2).mean()

def style_loss(style_features, target_features, weight=1.0):
    """Return the weighted mean squared difference of two Gram matrices.

    Both feature tensors are converted with ``gram_matrix`` first; ``weight``
    scales the resulting mean.
    """
    gram_gap = gram_matrix(style_features) - gram_matrix(target_features)
    mean_squared_gap = gram_gap.pow(2).mean()
    return weight * mean_squared_gap

def total_variation_loss(image):
    """Return the sum of mean squared neighbour differences along the last two axes.

    ``image`` is indexed with four axes; differences are taken between
    adjacent entries along axis 2 and along axis 3.
    """
    along_rows = image[:, :, 1:, :] - image[:, :, :-1, :]
    along_cols = image[:, :, :, 1:] - image[:, :, :, :-1]
    return along_rows.pow(2).mean() + along_cols.pow(2).mean()

# Cell 12: Define style transfer function
def style_transfer(content_img, style_img, mask_img, num_steps=700, *,
                   lr=0.02, style_weight=1e6, tv_weight=1e-5):
    """Optimise a copy of the content image toward the style, then blend by mask.

    A clone of ``content_img`` is updated with Adam to minimise the content
    loss, the weighted per-layer style losses and a total-variation term, all
    computed on features from the module-level ``vgg``. The optimised image
    is then combined with the original content using ``mask_img``.

    Args:
        content_img: Normalised content tensor, already on the target device.
            A batch size of 1 is required because ``gram_matrix`` reshapes
            feature maps to ``(c, h * w)``.
        style_img: Normalised style tensor on the same device.
        mask_img: Blend mask that broadcasts against ``content_img``; where it
            is 1 the optimised pixels are kept, where it is 0 the original
            content is kept.
        num_steps: Number of optimiser steps.
        lr: Adam learning rate.
        style_weight: Multiplier applied to the summed style loss.
        tv_weight: Multiplier applied to the total-variation loss.

    Returns:
        The blended tensor, with the same shape as ``content_img``.
    """
    style_weights = {'0': 1.0, '5': 0.8, '10': 0.5, '19': 0.3, '28': 0.1}
    style_layer_names = list(style_weights)
    wanted_layers = content_layers + style_layer_names
    content_layer = content_layers[0]

    # The targets are constants of the optimisation, so they are computed
    # once and without building an autograd graph.
    with torch.no_grad():
        content_target = get_features(content_img, vgg, content_layers)[content_layer]
        style_targets = get_features(style_img, vgg, style_layer_names)

    input_img = content_img.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([input_img], lr=lr)

    for step in range(num_steps):
        features = get_features(input_img, vgg, wanted_layers)

        content_l = content_loss(features[content_layer], content_target)
        style_l = 0
        for layer, weight in style_weights.items():
            style_l += style_loss(features[layer], style_targets[layer], weight)

        tv_l = total_variation_loss(input_img)

        total_loss = content_l + style_weight * style_l + tv_weight * tv_l

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if (step + 1) % 50 == 0:
            print(f"Step {step + 1}: Content Loss: {content_l.item():.4f}, Style Loss: {style_l.item():.4f}")

    with torch.no_grad():
        # Derive the complement from the mask argument rather than relying on
        # a global, so the function works with whatever mask it is given.
        output = input_img * mask_img + content_img * (1 - mask_img)

    return output

# Cell 13: Perform style transfer
# Move all three inputs to the selected device (in content, style, mask
# order) and run the optimisation.
output = style_transfer(
    *(tensor.to(device) for tensor in (content_tensor, style_tensor, mask_tensor))
)

# Cell 14: Define denormalization function and display results
def denormalize_image(tensor):
    """Undo the mean/std normalisation and convert a tensor to a PIL image.

    Args:
        tensor: Channel-first image tensor with three channels, normalised
            with the same mean/std that ``preprocess_image`` applies. It may
            live on any device and may carry autograd history.

    Returns:
        A PIL image built from the de-normalised values, clamped to [0, 1]
        and scaled to 8-bit.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

    # Detach first so no autograd graph is built for display-only arithmetic.
    restored = tensor.detach().cpu() * std + mean
    restored = torch.clamp(restored, 0, 1)

    # Round before the integer cast: a plain cast truncates, so float error
    # from the normalise/denormalise round trip (e.g. 99.99999) would drop a
    # pixel by one intensity level.
    scaled = torch.round(restored * 255)
    pixels = scaled.numpy().transpose(1, 2, 0).astype(np.uint8)
    return Image.fromarray(pixels)

plt.figure(figsize=(15, 5))

# Convert the tensors back to displayable images; the mask is already one.
content_display = denormalize_image(content_tensor.squeeze(0))
style_display = denormalize_image(style_tensor.squeeze(0))
output_display = denormalize_image(output.squeeze(0))

# One row of four panels: inputs, mask, and result. Only the mask needs an
# explicit colormap.
_panels = [
    (content_display, 'Content Image', {}),
    (style_display, 'Style Image', {}),
    (foreground_mask, 'Segmentation Mask', {'cmap': 'gray'}),
    (output_display, 'Output Image', {}),
]
for _column, (_picture, _caption, _imshow_kwargs) in enumerate(_panels, start=1):
    plt.subplot(1, 4, _column)
    plt.imshow(_picture, **_imshow_kwargs)
    plt.title(_caption)
    plt.axis('off')

plt.tight_layout()
plt.show()

# Cell 15: Save the output image
output_display.save('style_transfer_output.png')


Collecting numpy<2 (from mediapipe)
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.3.3
    Uninstalling numpy-2.3.3:
      Successfully uninstalled numpy-2.3.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ydf 0.13.0 requires protobuf<7.0.0,>=5.29.1, but you have protobuf 4.25.8 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but 

Collecting numpy
  Using cached numpy-2.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Using cached numpy-2.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mediapipe 0.10.21 requires numpy<2, but you have numpy 2.3.3 which is incompatible.
ydf 0.13.0 requires protobuf<7.0.0,>=5.29.1, but you have protobuf 4.25.8 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 2.3.3 which is incompatible.
tensorflow 2.19.0 requires numpy<2.2.0,>=1.26.0, but you have numpy 2.3.3 which is incompatible.
cupy-c

