In [None]:
import torch
from ultralytics import YOLO
import numpy as np

# Prefer the Apple-silicon GPU (MPS) when present, otherwise run on the CPU.
device: str = "cpu"
if torch.backends.mps.is_available():
    device = "mps"

# Custom-trained YOLOv8 segmentation checkpoint. Its inner layer stack is
# reused below as the feature extractor (historically named `vgg`).
model = YOLO('./models/custom-trained-yolov8-seg.pt')
vgg = model.model.model

In [None]:
# Freeze every parameter of the extractor, so only the target image is optimised.
# `requires_grad_` already recurses into all sub-modules, so a single call is
# sufficient; there is no need to repeat it once per layer.
vgg.requires_grad_(False)

In [None]:
# Flatten the three levels of nested modules (YOLO wrapper -> model -> layer
# stack) into {"0": layer0, "1": layer1, ...}, numbered in traversal order.
_flat_layers = [
    leaf
    for wrapper in model._modules.values()
    for container in wrapper._modules.values()
    for leaf in container._modules.values()
]
model_layers = {str(idx): leaf for idx, leaf in enumerate(_flat_layers)}
i = len(_flat_layers)

In [None]:
# Re-derive the device instead of hard-coding 'mps'. Forcing MPS raises on
# machines without Apple-silicon GPU support; fall back to the CPU there.
device = "mps" if torch.backends.mps.is_available() else "cpu"
vgg.to(device)

In [None]:
from PIL import Image
from torchvision import transforms as T

In [None]:
def preprocess(img_path, max_size = 640):
  """Load an image file and turn it into a normalised 4-D tensor.

  Args:
    img_path: path of the image to read. It is converted to 3-channel RGB.
    max_size: currently unused. Resizing is disabled, so the image keeps its
      native resolution. Kept for call compatibility.

  Returns:
    A float tensor of shape (1, 3, H, W), normalised with the channel
    mean/std below.
  """
  # NOTE(review): Ultralytics YOLO models are normally fed [0, 1] inputs
  # without ImageNet mean/std normalisation -- confirm this is intended.
  to_tensor_normalised = T.Compose([
      T.ToTensor(),  # PIL (H, W, 3) -> tensor (3, H, W), scaled to [0, 1]
      T.Normalize(mean = [0.485, 0.456, 0.406],
                  std = [0.229, 0.224, 0.225]),
  ])

  with Image.open(img_path) as raw:
    rgb = raw.convert('RGB')

  batch = to_tensor_normalised(rgb).unsqueeze(0)  # add batch dim -> (1, 3, H, W)
  return batch

In [None]:
# Load the content and style images and move them onto the compute device.
content_p = preprocess('../images/af00bd10-d7ef-4076-90c8-28f2d6ff6aa2___RS_HL 8188.JPG').to(device)
style_p = preprocess('../images/4dadb9f1-27b1-4d3c-8111-d1602febd585___JR_FrgE.S 8632.JPG').to(device)

for label, batch in (("Content shape", content_p), ("Style shape", style_p)):
  print(label, batch.shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def deprocess(tensor):
  """Convert a normalised (1, 3, H, W) image tensor back to a displayable array.

  Undoes the mean/std normalisation applied by `preprocess`.

  Args:
    tensor: image tensor on any device, with a leading batch dimension of
      size 1 that is squeezed away. It may require grad; it is detached first.

  Returns:
    A float64 numpy array of shape (H, W, 3), clipped to [0, 1].
  """
  # detach() lets tensors that belong to an autograd graph (e.g. the optimised
  # target) be converted; .numpy() raises on tensors that require grad.
  image = tensor.detach().to('cpu').clone().numpy()
  image = image.squeeze(0).transpose(1, 2, 0)  # (1, 3, H, W) -> (H, W, 3)

  # denormalizing the image
  image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])

  return image.clip(0, 1)

In [None]:
# Round-trip both inputs back to displayable (H, W, 3) arrays.
content_d, style_d = deprocess(content_p), deprocess(style_p)

for caption, arr in (("Deprocess content:", content_d), ("Deprocess style:", style_d)):
  print(caption, arr.shape)

In [None]:
# Show the de-processed inputs side by side, labelled so the panels are unambiguous.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

for ax, img, title in ((ax1, content_d, "Content image"), (ax2, style_d, "Style image")):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")

plt.show()

In [None]:
def get_features(image, model, layers=None):
  """Run `image` through the leading layers of `model`, collecting activations.

  Args:
    image: input batch, e.g. the (1, 3, H, W) tensor from `preprocess`.
    model: an iterable stack of layers applied in order (here the YOLO layer
      stack `vgg`). The layers run are taken from this argument.
    layers: optional mapping {layer index as str: feature name}. Defaults to
      the six layers used by the content/style losses.

  Returns:
    dict {feature name: activation tensor}. Iteration stops right after the
    deepest requested layer.
  """
  if layers is None:
    # VGG-style labels kept for compatibility; the indices are YOLO layers.
    layers = {
        '0' : 'conv1_1',
        '1' : 'conv2_1',
        '2' : 'conv3_1',
        '4' : 'conv4_1',
        '7' : 'conv4_2',
        '8' : 'conv5_1'
    }

  features = {}
  if not layers:
    return features

  # Stop right after the deepest requested layer (layer 8 by default). Deeper
  # YOLO layers are presumably not plain single-input sequential ops, which is
  # why they must not be chained here -- TODO confirm.
  last_index = max(int(idx) for idx in layers)

  x = image
  for idx, layer in enumerate(model):
    x = layer(x)
    key = str(idx)
    if key in layers:
      features[layers[key]] = x
    if idx >= last_index:
      break

  return features

In [None]:
content_f = get_features(content_p, vgg)
style_f = get_features(style_p, vgg)

In [None]:
def gram_matrix(tensor):
  """Return the (C, C) Gram matrix of a (1, C, H, W) feature map.

  Entry [i, j] is the dot product of channel maps i and j, each flattened over
  H*W. The result is not normalised.
  """
  _, channels, height, width = tensor.size()
  flat = tensor.view(channels, height * width)
  return flat @ flat.t()

In [None]:
# Pre-compute the style image's Gram matrices once; they stay fixed during optimisation.
style_grams = {name: gram_matrix(feat) for name, feat in style_f.items()}

In [None]:
def content_loss(target_conv4_2, content_conv4_2):
  """Mean squared error between the target and content activations."""
  diff = target_conv4_2 - content_conv4_2
  return diff.pow(2).mean()

In [None]:
# Per-layer weights for the style loss. Keys must match the feature names
# produced by get_features; conv4_2 is used only by the content loss.
style_weights = dict(
    conv1_1=0.2,
    conv2_1=0.2,
    conv3_1=0.5,
    conv4_1=1.0,
    conv5_1=0.2,
)

In [None]:
def style_loss(style_weights, target_features, style_grams):
  """Weighted style loss summed over the layers in `style_weights`.

  For each layer: weight * mean((G_target - G_style)^2) / (C * H * W), where
  G_* are Gram matrices and C, H, W come from the target activation.
  """
  total = 0
  for layer_name, weight in style_weights.items():
    feats = target_features[layer_name]
    _, channels, height, width = feats.shape
    diff = gram_matrix(feats) - style_grams[layer_name]
    total += weight * torch.mean(diff ** 2) / (channels * height * width)
  return total

In [None]:
# Optimise a copy of the content image. Move it to the device *before* enabling
# grad so `target` stays a leaf tensor: calling `.to()` after `requires_grad_`
# returns a non-leaf copy whenever a transfer happens, which torch.optim rejects.
target = content_p.to(device).clone().requires_grad_(True)
target_f = get_features(target, vgg)
print("Content Loss: ", content_loss(target_f['conv4_2'], content_f['conv4_2']))
print("Style Loss: ", style_loss(style_weights, target_f, style_grams))

In [None]:
from torch import optim

# Loss weighting: the content term is scaled by alpha, the style term by beta.
alpha, beta = 1, 1e5

# Optimisation schedule: total Adam steps, and how often to log/snapshot.
epochs, show_every = 5000, 500

# Only the pixels of the target image are optimised.
optimizer = optim.Adam([target], lr=0.008)

In [None]:
def total_loss(c_loss, s_loss, alpha, beta):
  """Return the combined objective alpha * content loss + beta * style loss."""
  weighted_content = alpha * c_loss
  weighted_style = beta * s_loss
  return weighted_content + weighted_style

In [None]:
# Optimise the target image, keeping a de-processed snapshot every
# `show_every` steps for later inspection.
# NOTE(review): `results` is reassigned to the YOLO predictions further down,
# which discards these snapshots -- consider giving it a distinct name.
results = []
for i in range(epochs):
  target_f = get_features(target, vgg)

  c_loss = content_loss(target_f['conv4_2'], content_f['conv4_2'])
  s_loss = style_loss(style_weights, target_f, style_grams)
  t_loss = total_loss(c_loss, s_loss, alpha, beta)

  optimizer.zero_grad()
  t_loss.backward()
  optimizer.step()

  if i % show_every == 0:
    # .item() prints the plain scalar instead of a tensor repr with device/grad_fn.
    print("Total loss at epoch {}: {}".format(i, t_loss.item()))
    results.append(deprocess(target.detach()))

In [None]:
# Bring the original content image and the optimised image back to
# (H, W, 3) arrays in [0, 1].
content_copy = deprocess(content_p)
target_copy = deprocess(target.detach())
print(target_copy.shape)

# Side-by-side comparison, currently disabled:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (10, 5))
# ax1.imshow(target_copy)
# ax2.imshow(content_copy)

In [None]:
# Cast the de-processed image to float64. `np.float` was a deprecated alias of
# the builtin float and was removed in NumPy 1.24 (AttributeError on current NumPy).
target_copy_1 = target_copy.astype(np.float64)
print(target_copy.shape)
# NOTE(review): stacking three copies of an (H, W, 3) array yields (H, W, 9),
# and slicing the first three channels below just gives target_copy_1 back.
img = np.dstack((target_copy_1,target_copy_1,target_copy_1))
print(img.shape)
img = target_copy_1[:,:,:3]
print(img.shape)

In [None]:
import matplotlib.pyplot as plt

In [None]:
import cv2
import numpy as np
from PIL import Image
def imread(path):
    """Read an image from disk as a float64 array in OpenCV's BGR channel order.

    The default cv2.imread flag already returns 3 channels. The branches below
    defensively expand grayscale input and drop a 4th (alpha) channel.

    Raises:
        FileNotFoundError: if OpenCV cannot read `path` (missing or unreadable).
    """
    raw = cv2.imread(path)
    if raw is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {path}")
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    img = raw.astype(np.float64)
    if len(img.shape) == 2:
        # grayscale
        img = np.dstack((img,img,img))
    elif img.shape[2] == 4:
        # PNG with alpha channel
        img = img[:,:,:3]
    return img

def imsave(path, img):
    """Save `img` to `path` via PIL, after clipping to [0, 255] and casting to uint8.

    PIL interprets 3-channel arrays as RGB, so pass RGB (not OpenCV BGR) data.
    """
    pixels = np.clip(img, 0, 255)
    Image.fromarray(pixels.astype(np.uint8)).save(path, quality=95)

In [None]:
import torch
from ultralytics import YOLO
import numpy as np

# Pick the Apple-silicon GPU (MPS) when available, else the CPU.
if torch.backends.mps.is_available():
    device: str = "mps"
else:
    device = "cpu"

# Reload the custom-trained YOLOv8 segmentation checkpoint for inference.
model = YOLO('./models/custom-trained-yolov8-seg.pt')

In [None]:
import cv2
# Load the image to composite onto. Ultralytics treats numpy inputs as BGR,
# so the model gets the BGR array straight from OpenCV. The RGB copy is used
# for the final composite, which is saved through PIL (an RGB library).
imgcon_bgr = cv2.imread('content.png')
if imgcon_bgr is None:
    raise FileNotFoundError("could not read 'content.png'")
imgcon = cv2.cvtColor(imgcon_bgr, cv2.COLOR_BGR2RGB)
H, W, _ = imgcon.shape
results = model(imgcon_bgr)

In [None]:
# Save the first instance mask of the (single-image) prediction as ./mask.png,
# resized to the content image size. Only that one mask is used downstream.
first_result = results[0]
if first_result.masks is None or len(first_result.masks.data) == 0:
    raise RuntimeError("YOLO found no objects to mask in 'content.png'")
mask = first_result.masks.data[0].cpu().numpy() * 255
# NOTE(review): masks.data is at the model's input resolution, which may include
# letterbox padding; a plain resize can misalign it on non-square images.
# Consider result.masks.xy or ultralytics' scale_image -- TODO confirm.
mask = cv2.resize(mask, (W, H))
# Write an explicit 8-bit PNG instead of relying on imwrite's float conversion.
cv2.imwrite('./mask.png', np.clip(np.rint(mask), 0, 255).astype(np.uint8))

In [None]:
# Composite: inside the mask take pixels from target.png (resized to the
# content image size); outside it keep the content image. Then save.
# NOTE(review): this notebook never writes target.png or content.png itself;
# presumably they are produced elsewhere -- confirm their provenance.
imgtar_bgr = cv2.imread("target.png")
if imgtar_bgr is None:
    raise FileNotFoundError("could not read 'target.png'")
imgtar = cv2.cvtColor(imgtar_bgr, cv2.COLOR_BGR2RGB)
imgtar = cv2.resize(imgtar, (W, H))

mask = cv2.imread('./mask.png', cv2.IMREAD_GRAYSCALE)
if mask is None:
    raise FileNotFoundError("could not read './mask.png'")
# THRESH_BINARY with thresh=1: pixels strictly greater than 1 become 255, the rest 0.
_, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

masked_overlay = cv2.bitwise_and(imgtar, imgtar, mask=binary_mask)  # target pixels inside the mask
inverted_mask = cv2.bitwise_not(binary_mask)
roi = cv2.bitwise_and(imgcon, imgcon, mask=inverted_mask)  # content pixels outside the mask
# The two parts are disjoint, so the saturating add simply merges them.
result_image = cv2.add(roi, masked_overlay)
# imsave writes through PIL, which expects RGB -- matching imgcon/imgtar here.
imsave('./final_output.png', result_image)