In [None]:
# Mount Google Drive at /content/drive (Colab only).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work from the project folder on Drive; presumably needed so the local modules
# (StyleNet, utils, template) import and the ./test_set paths resolve.
%cd /content/drive/MyDrive/Generative\ Artisan/

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import os
import sys 
import matplotlib.pyplot as plt 

from PIL import Image
import numpy as np
import sys
import torch
import torch.nn
import torch.optim as optim
from torchvision import transforms, models

import StyleNet
import utils
import clip
import torch.nn.functional as F
from template import imagenet_templates

from PIL import Image 
import PIL 
from torchvision import utils as vutils
import argparse
from torchvision.transforms.functional import adjust_contrast
import random
import copy

In [None]:
def img_denormalize(image, device=None):
    """Undo ImageNet normalization per channel: ``image * std + mean``.

    Args:
        image: tensor with the 3 colour channels on dim 1 (e.g. (N, 3, H, W));
            mean/std are broadcast as shape (1, 3, 1, 1).
        device: where to create mean/std. Defaults to ``image.device`` so the
            function does not depend on a global ``device`` variable.

    Returns:
        A new de-normalized tensor; ``image`` itself is not modified.
    """
    if device is None:
        device = image.device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, -1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, -1, 1, 1)
    return image * std + mean

def img_normalize(image, device=None):
    """Apply ImageNet normalization per channel: ``(image - mean) / std``.

    Inverse of ``img_denormalize``.

    Args:
        image: tensor with the 3 colour channels on dim 1 (e.g. (N, 3, H, W));
            mean/std are broadcast as shape (1, 3, 1, 1).
        device: where to create mean/std. Defaults to ``image.device`` so the
            function does not depend on a global ``device`` variable.

    Returns:
        A new normalized tensor; ``image`` itself is not modified.
    """
    if device is None:
        device = image.device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, -1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, -1, 1, 1)
    return (image - mean) / std

def clip_normalize(image, device):
    """Prepare an image batch for CLIP's image encoder.

    Bicubically resizes ``image`` to 224x224 (``size=224`` fixes both spatial
    dims), then normalizes each channel with CLIP's mean/std.

    Args:
        image: tensor with the 3 colour channels on dim 1 (N, 3, H, W).
        device: device on which the mean/std constants are created.

    Returns:
        The resized, normalized tensor of shape (N, 3, 224, 224).
    """
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(device).view(1, -1, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(device).view(1, -1, 1, 1)
    resized = F.interpolate(image, size=224, mode='bicubic', align_corners=False)
    return (resized - clip_mean) / clip_std

def get_image_prior_losses(inputs_jit):
    """Total-variation style smoothness prior.

    Sums the L2 norms of the differences between neighbouring pixels along the
    last two axes in four directions: horizontal, vertical, and both diagonals.
    Expects a 4-D tensor (indexing uses four axes).

    Returns:
        A scalar tensor (differentiable w.r.t. ``inputs_jit``).
    """
    x = inputs_jit
    neighbour_pairs = (
        (x[:, :, :, :-1], x[:, :, :, 1:]),      # horizontal
        (x[:, :, :-1, :], x[:, :, 1:, :]),      # vertical
        (x[:, :, 1:, :-1], x[:, :, :-1, 1:]),   # anti-diagonal
        (x[:, :, :-1, :-1], x[:, :, 1:, 1:]),   # diagonal
    )
    return sum(torch.norm(left - right) for left, right in neighbour_pairs)

In [None]:
def decode_segmap(image, nc=21):
    """Colorize a map of class labels with the Pascal VOC palette.

    Args:
        image: numpy array of class labels (e.g. the argmax of FCN logits).
        nc: number of classes to paint; labels 0..nc-1 get their palette
            colour and any other value stays black (0, 0, 0). Values above 21
            index past the 21-entry palette and raise IndexError.

    Returns:
        uint8 array: the R, G, B planes (each shaped like ``image``) stacked
        on axis 2, i.e. (H, W, 3) for a 2-D label map.
    """
    palette = np.array([
        (0, 0, 0),  # 0=background
        # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
        (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
        # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
        (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
        # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
        (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
        # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
    ])
    planes = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for label in range(nc):
        hit = image == label
        for plane, value in zip(planes, palette[label]):
            plane[hit] = value
    return np.stack(planes, axis=2)

In [None]:
def compose_text_with_templates(text, templates=imagenet_templates):
    """Build one prompt per template by filling its ``{}`` slot with ``text``.

    Args:
        text: phrase to insert (e.g. a style description).
        templates: iterable of ``str.format`` patterns; defaults to the
            project's ``imagenet_templates``.

    Returns:
        A list of prompts, in template order.
    """
    prompts = []
    for pattern in templates:
        prompts.append(pattern.format(text))
    return prompts

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained FCN-ResNet101 segmenter, in eval mode, used below to locate the
# portrait (person class) in the content image.
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval().to(device)

In [None]:
# Segment the content image with FCN; `seg` is an (H, W) array of class labels
# (15 = person) and `rgb` its colour visualization.
# NOTE: path and size must match `image_dir` / `img_size` used for training.
image = utils.load_image2('./test_set/lena.png', img_size=512)
img = img_normalize(image.to(device))
# Inference only: no_grad avoids building an autograd graph through ResNet-101.
with torch.no_grad():
    # [0] selects the single batch item -> (num_classes, H, W); unlike
    # .squeeze() it never drops a spatial dimension of size 1.
    seg = fcn(img)['out'][0].argmax(dim=0).cpu().numpy()
rgb = decode_segmap(seg)

In [None]:
# Return cached, unused GPU memory from PyTorch's allocator after the FCN pass.
torch.cuda.empty_cache()

In [None]:
# Show the colourized FCN segmentation of the content image.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(rgb)
plt.show()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Frozen VGG-19 convolutional trunk: a fixed feature extractor for the
# content loss (all parameters have requires_grad=False).
VGG = models.vgg19(pretrained=True).features.to(device).requires_grad_(False)

In [None]:
# Load CLIP ViT-B/32 on `device` with jit=False. `preprocess` is not used in
# this section; images are prepared with clip_normalize instead.
clip_model, preprocess = clip.load('ViT-B/32', device, jit=False)

In [None]:
from IPython.display import display
from argparse import Namespace

# Source and target text for the CLIP directional losses: the stylized image
# is pushed along the CLIP direction text(source) -> text(text).
source = "a Photo"
text = "Starry Night by Vincent van gogh"
#text = "The great wave off Kanagawa by Hokusai"
#text = "The scream by edvard munch"

crop_size = 128
image_dir = "./test_set/lena.png"

training_iterations = 400

training_args = {
    "lambda_tv": 2e-3,        # weight of the total-variation prior
    "lambda_patch": 9000,     # weight of the per-crop (patch) CLIP loss
    "lambda_dir": 500,        # weight of the whole-image (global) CLIP loss
    "lambda_c": 150,          # weight of the VGG content loss
    "crop_size": crop_size,   # side length of the random crops
    "num_crops": 64,          # crops sampled per iteration
    "img_size": 512,          # passed to utils.load_image2 for the content image
    "max_step": training_iterations,
    "lr": 5e-4,               # Adam learning rate
    "thresh": 0.7,            # baseline: crops with patch loss below this are ignored
    "content_path": image_dir,
    "text": text
}

args = Namespace(**training_args)

In [None]:
# Load the content image and cache its VGG features, the fixed reference for
# the content loss.
content_path = args.content_path
content_image = utils.load_image2(content_path, img_size=args.img_size).to(device)
content_features = utils.get_features(img_normalize(content_image), VGG)
# Initial stylized image; the training loop reassigns `target` every step.
target = content_image.clone().requires_grad_(True).to(device)

In [None]:
# Stylization network and its optimizer; StepLR halves the learning rate
# every 100 scheduler steps (one step per training iteration).
style_net = StyleNet.UNet()
style_net.to(device)

optimizer = optim.Adam(style_net.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

content_weight = args.lambda_c
crop_size = args.crop_size
steps = args.max_step

# Loss history lists.
content_loss_epoch, style_loss_epoch, total_loss_epoch = [], [], []

# Patch pipeline: random crop, then a random perspective warp and a resize
# to 224 before each crop is fed to CLIP.
cropper = transforms.RandomCrop(args.crop_size)
augment = transforms.Compose([
    transforms.RandomPerspective(fill=0, p=1, distortion_scale=0.5),
    transforms.Resize(224),
])

source = "a Photo"
prompt = args.text

In [None]:
# Set optimize = False to run the baseline.
# Set optimize = True to run the optimized (portrait-aware) version.
# When running at 1024 x 1024, change peo_num to 0.1.

In [None]:
# Settings for the portrait-aware ("optimized") patch loss in the training loop.
optimize = True      # True: portrait-aware loss; False: baseline loss
people_scale = 0.3   # loss weight for crops that touch the portrait
back_scale = 1.0     # loss weight for background crops
window_width = 0.2   # fraction of the crop height (bottom strip) checked for portrait pixels
back_thres = 0.7     # background crops with patch loss below this are ignored
people_thres = 0.7   # portrait crops with patch loss below this are ignored
peo_num = 0.2        # max fraction of num_crops that may be portrait crops

In [None]:
# Boolean portrait mask (label 15 = person in decode_segmap's palette),
# repeated over 3 channels: shape (1, 3, H, W) taken from `seg` itself rather
# than a hard-coded 512x512.
person_pixels = seg == 15
mask = torch.tensor(np.repeat(person_pixels[None, None], 3, axis=1)).to(device)
# masked_fill on the content/stylized image requires matching spatial size.
assert mask.shape[-2:] == content_image.shape[-2:], (
    "segmentation size %s does not match content image size %s; load both "
    "with the same img_size" % (tuple(mask.shape[-2:]), tuple(content_image.shape[-2:]))
)

In [None]:
# Fixed CLIP references, computed once without gradients:
#   text_features / text_source: template-averaged, L2-normalized embeddings
#       of the style prompt and of the neutral `source` text;
#   source_features1 / source_features: the content image with its portrait
#       pixels zeroed out, and the full content image.
with torch.no_grad():

    template_text = compose_text_with_templates(prompt, imagenet_templates)
    tokens = clip.tokenize(template_text).to(device)
    text_features = clip_model.encode_text(tokens).detach()
    text_features = text_features.mean(axis=0, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    template_source = compose_text_with_templates(source, imagenet_templates)
    tokens_source = clip.tokenize(template_source).to(device)
    text_source = clip_model.encode_text(tokens_source).detach()
    text_source = text_source.mean(axis=0, keepdim=True)
    text_source /= text_source.norm(dim=-1, keepdim=True)

    source_features1 = clip_model.encode_image(clip_normalize(content_image.masked_fill(mask, 0), device))
    source_features1 /= (source_features1.clone().norm(dim=-1, keepdim=True))
    source_features = clip_model.encode_image(clip_normalize(content_image, device))
    source_features /= (source_features.clone().norm(dim=-1, keepdim=True))

num_crops = args.num_crops
img_size = args.img_size
log_every = 20  # print losses and show the current image every N iterations


def _sample_crops(image):
    """Sample ``args.num_crops`` augmented random crops of ``image``.

    Returns ``(crops, scales, thresholds)`` as three lists in the SAME order,
    so ``scales[k]`` / ``thresholds[k]`` are the loss weight and rejection
    threshold of ``crops[k]``.

    Baseline (``optimize`` False): every crop gets weight 1.0 and threshold
    ``args.thresh``. Optimized: a crop whose bottom ``window_width`` strip
    contains portrait pixels (label 15 in ``seg``) gets
    ``people_scale``/``people_thres`` and is kept only while fewer than
    ``int(args.num_crops * peo_num)`` portrait crops were taken; every other
    crop gets ``back_scale``/``back_thres``.

    Raises:
        RuntimeError: if the batch cannot be filled after many attempts
            (e.g. the portrait covers almost the whole image).
    """
    crops, scales, thresholds = [], [], []
    people_quota = int(args.num_crops * peo_num)
    n_people = 0
    attempts = 0
    max_attempts = 100 * args.num_crops
    while len(crops) < args.num_crops:
        attempts += 1
        if attempts > max_attempts:
            raise RuntimeError(
                "could not sample %d crops after %d attempts: too many candidate "
                "crops touch the portrait and the portrait quota (peo_num) is "
                "full; lower crop_size/window_width or raise peo_num"
                % (args.num_crops, max_attempts)
            )
        i, j, h, w = cropper.get_params(image, (crop_size, crop_size))
        if not optimize:
            scale, thresh = 1.0, args.thresh
        elif 15 in seg[i + h - int(h * window_width):i + h, j:j + w]:
            if n_people >= people_quota:
                continue  # portrait quota full: reject before paying for augment
            n_people += 1
            scale, thresh = people_scale, people_thres
        else:
            scale, thresh = back_scale, back_thres
        crops.append(augment(transforms.functional.crop(image, i, j, h, w)))
        scales.append(scale)
        thresholds.append(thresh)
    return crops, scales, thresholds


def _patch_loss(per_crop_loss, scales, thresholds):
    """Sum of ``scale * loss`` over crops with loss >= threshold, / num_crops.

    Crops below their threshold contribute nothing. Always returns a tensor
    (never a bare float), so ``.item()`` works even when no crop passes.
    """
    like = dict(device=per_crop_loss.device, dtype=per_crop_loss.dtype)
    scale_t = torch.tensor(scales, **like)
    thresh_t = torch.tensor(thresholds, **like)
    kept = per_crop_loss >= thresh_t
    weighted = torch.where(kept, per_crop_loss * scale_t, torch.zeros_like(per_crop_loss))
    return weighted.sum() / num_crops


for epoch in range(0, steps + 1):

    target = style_net(content_image, use_sigmoid=True).to(device)

    # Content loss: mean squared VGG feature difference at conv4_2 and conv5_2.
    target_features = utils.get_features(img_normalize(target), VGG)
    content_loss = 0
    content_loss += torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
    content_loss += torch.mean((target_features['conv5_2'] - content_features['conv5_2']) ** 2)

    # Patch loss: CLIP directional loss on random augmented crops.
    crops, scales, thresholds = _sample_crops(target)
    img_aug = torch.cat(crops, dim=0)

    image_features = clip_model.encode_image(clip_normalize(img_aug, device))
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    img_direction = image_features - source_features
    img_direction = img_direction / img_direction.norm(dim=-1, keepdim=True)

    text_direction = (text_features - text_source).repeat(image_features.size(0), 1)
    text_direction /= text_direction.norm(dim=-1, keepdim=True)
    loss_temp = 1 - torch.cosine_similarity(img_direction, text_direction, dim=1)
    loss_patch = _patch_loss(loss_temp, scales, thresholds)

    # Global loss: CLIP direction of the whole image (portrait zeroed out in
    # optimized mode) against the same text direction.
    if optimize:
        glob_input, glob_source = target.masked_fill(mask, 0), source_features1
    else:
        glob_input, glob_source = target, source_features
    glob_features = clip_model.encode_image(clip_normalize(glob_input, device))
    glob_features = glob_features / glob_features.norm(dim=-1, keepdim=True)
    glob_direction = glob_features - glob_source
    glob_direction = glob_direction / glob_direction.norm(dim=-1, keepdim=True)
    loss_glob = (1 - torch.cosine_similarity(glob_direction, text_direction, dim=1)).mean()

    reg_tv = args.lambda_tv * get_image_prior_losses(target)
    total_loss = (args.lambda_patch * loss_patch + content_weight * content_loss
                  + reg_tv + args.lambda_dir * loss_glob)
    # Store a detached copy: keeping the live tensor would keep each
    # iteration's autograd history reachable for the whole run.
    total_loss_epoch.append(total_loss.detach())

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    scheduler.step()

    if epoch % log_every == 0:
        print("After %d iters:" % epoch)
        print('Total loss: ', total_loss.item())
        print('Content loss: ', content_loss.item())
        print('patch loss: ', loss_patch.item())
        print('dir loss: ', loss_glob.item())
        print('TV loss: ', reg_tv.item())

        # Display only: no gradients needed for the contrast-adjusted copy.
        with torch.no_grad():
            output_image = adjust_contrast(torch.clamp(target, 0, 1), 1.5)
        plt.figure(figsize=(15, 15))
        plt.imshow(utils.im_convert2(output_image))
        plt.show()
    
        

In [None]:
# Return cached, unused GPU memory from PyTorch's allocator after training.
torch.cuda.empty_cache()