In [None]:
!pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
!pip install optax
!git clone https://github.com/MarcoForte/closed-form-matting.git
import os
os.chdir('closed-form-matting')
!pip3 install .
os.chdir('..')
!pip install jaxopt

In [30]:
from PIL import Image
import jax
import jax.numpy as jnp
import flaxmodels as fm
import matplotlib.pyplot as plt
from jax import jit, random, grad
import numpy as np
import optax
from functools import partial
from tqdm import trange
from jax.example_libraries import optimizers
from closed_form_matting import compute_laplacian
from jax.experimental import sparse
import jaxopt

In [31]:
import os

# JAX reads these variables when it creates its GPU client, i.e. at the first JAX
# computation. They have no effect if a computation already ran; restart the kernel
# and run this cell before any jax/jnp call if in doubt.
# Disable up-front preallocation and let the "platform" allocator allocate and free
# GPU memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# XLA_PYTHON_CLIENT_MEM_FRACTION only matters for the default (preallocating)
# allocator. If you switch back to it, set a real float such as ".75". The old
# placeholder ".XX" makes JAX fail when it parses the value.

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
!wget https://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg
!wget https://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg

In [None]:
# Load image
# Load the content image. Convert to RGB first so the array built from it is always
# (H, W, 3), even if the JPEG is grayscale, CMYK or palette-based; then resize.
img_con = Image.open('dancing.jpg').convert('RGB').resize((256, 256))
display(img_con)

In [None]:
# Load the style image, normalized to RGB and resized like the content image.
img_sty = Image.open('picasso.jpg').convert('RGB').resize((256, 256))
display(img_sty)

In [36]:
# Image should be in range [0, 1]
# 8-bit pixels -> float32 in [0, 1], then a leading batch axis: (H, W, C) -> (1, H, W, C).
image_content = jnp.divide(jnp.array(img_con, dtype=jnp.float32), 255.0)
img_content = image_content[None, ...]


In [37]:
# Same preprocessing as the content image: scale to [0, 1] and add a batch axis.
image_style = jnp.divide(jnp.array(img_sty, dtype=jnp.float32), 255)
img_style = image_style[None, ...]


In [38]:
class StyleTransfer:
  """Photorealistic style transfer by directly optimizing pixels against VGG19.

  The generated image is initialized from the content image and updated with
  Adam to minimize

      style_weight   * sum over style_layers of MSE(Gram(style act), Gram(generated act))
    + content_weight * sum over content_layers of MSE(content act, generated act)
    + photo_weight   * trace(V^T L V)   (L: Laplacian of the content image)
    + tv_weight      * total variation of the generated image

  where "act" are the activations returned by a pretrained flaxmodels VGG19.
  Images are NHWC float arrays with batch size 1; the notebook supplies values in
  [0, 1] and `train` clips the generated image to that range after every step.

  The jitted methods take `self` as a static argument, so attributes they read
  (weights, layer lists, targets, Laplacian) must not be changed after the first
  call. jit caches on the instance and would not see the change.
  """

  def __init__(self, input_content, input_style, content_layers, style_layers,
               style_weight=1000000, content_weight=1, photo_weight=100,
               tv_weight=0.01, lr=1e-2):
    """
    Args:
      input_content: content image, shape (1, H, W, C); also the starting point
        of the optimization and the source of the Laplacian.
      input_style: style image. Only its Gram matrices (C x C per layer) are
        used, so its spatial size may differ from the content image.
      content_layers: names of VGG activations (e.g. 'conv4_2') for the content loss.
      style_layers: names of VGG activations for the style loss.
      style_weight, content_weight, photo_weight, tv_weight: loss-term weights.
      lr: Adam learning rate.
    """
    # Original style and content input
    self.origin_content = input_content
    self.origin_style = input_style

    # Pretrained VGG19 (flaxmodels) returning a dict of named activations, no classifier head.
    self.vgg19 = fm.VGG19(output='activations', pretrained='imagenet', include_head=False)
    self.init_rngs = {'params': jax.random.PRNGKey(0)}
    # Initialize from this instance's own input, not from a notebook-level global.
    self.vggparams = self.vgg19.init(self.init_rngs, self.origin_content)
    # `train` is a plain Python flag; mark it static so jit does not trace it.
    self.fn_out = jit(self.vgg19.apply, static_argnames=('train',))

    # Fixed targets: VGG activations of the original style and content images.
    self.activation_style_origin = self.fn_out(self.vggparams, self.origin_style, train=False)
    self.activation_content_origin = self.fn_out(self.vggparams, self.origin_content, train=False)

    # Layers used by each loss term
    self.layer_style = style_layers
    self.layer_content = content_layers

    # Loss weights
    self.style_weight = style_weight
    self.content_weight = content_weight
    self.photo_weight = photo_weight
    self.tv_weight = tv_weight

    # Optimizer (optax Adam) over the image pixels
    self.lr = lr
    self.optimizer = optax.adam(learning_rate=self.lr)

    # The optimization variable starts as a copy of the content image.
    self.generate_img = self.origin_content.copy()
    self.opt_state = self.optimizer.init(self.generate_img)

    # Sparse Laplacian of the content image (one row/column per pixel),
    # consumed by photo_regularization. compute_laplacian gets a NumPy array.
    self.mat_laplacian = sparse.BCOO.from_scipy_sparse(
        compute_laplacian(np.asarray(self.origin_content[0])))

  @partial(jit, static_argnums=(0,))
  def gram_matrix(self, input):
    """Gram matrix of NHWC feature maps, normalized by N*C*H*W.

    Returns an (N*C, N*C) matrix of inner products between flattened channels.
    """
    feats = jnp.transpose(input, axes=(0, 3, 1, 2))  # NHWC -> NCHW
    a, b, c, d = feats.shape  # batch, channels, height, width
    features = feats.reshape(a * b, c * d)
    G = jnp.matmul(features, features.T)
    return G / (a * b * c * d)

  @partial(jit, static_argnums=(0,))
  def content_loss(self, input_content, img_generated):
    """Mean squared error between two activation arrays (compared flattened)."""
    return jnp.mean((input_content.flatten() - img_generated.flatten()) ** 2)

  @partial(jit, static_argnums=(0,))
  def style_loss(self, input_style, img_generated):
    """Mean squared error between two Gram matrices."""
    return jnp.mean((input_style - img_generated) ** 2)

  @partial(jit, static_argnums=(0,))
  def photo_regularization(self, img_generated):
    """trace(V^T L V) with V the (h*w, c) pixel matrix and L = self.mat_laplacian.

    Equals the sum over channels of v_c^T L v_c. The reshape requires batch size 1.
    """
    _, h, w, c = img_generated.shape
    V_c = img_generated.reshape((h * w, c))  # one row per pixel, one column per channel
    # Sparse @ dense first, then a small (c, h*w) @ (h*w, c) dense product.
    return jnp.trace(V_c.T @ (self.mat_laplacian @ V_c))

  @partial(jit, static_argnums=(0,))
  def tv_loss(self, img_generated):
    """Anisotropic total variation of an NHWC batch.

    Mean absolute difference between vertically adjacent pixels (axis 1) plus
    the same between horizontally adjacent pixels (axis 2). Channels (axis 3)
    are not differenced.
    """
    tv_h = jnp.mean(jnp.abs(img_generated[:, 1:, :, :] - img_generated[:, :-1, :, :]))
    tv_w = jnp.mean(jnp.abs(img_generated[:, :, 1:, :] - img_generated[:, :, :-1, :]))
    return tv_h + tv_w

  @partial(jit, static_argnums=(0,))
  def loss(self, img_generated):
    """Weighted sum of style, content, photo-regularization and TV losses (scalar)."""
    out_generated = self.fn_out(self.vggparams, img_generated, train=False)

    content_score = 0
    for cont_layer in self.layer_content:
      content_score += self.content_loss(self.activation_content_origin[cont_layer],
                                         out_generated[cont_layer])

    style_score = 0
    for sty_layer in self.layer_style:
      gram_sty = self.gram_matrix(self.activation_style_origin[sty_layer])
      gram_gen = self.gram_matrix(out_generated[sty_layer])
      style_score += self.style_loss(gram_sty, gram_gen)

    photo_score = self.photo_regularization(img_generated)
    tv_score = self.tv_loss(img_generated)

    total = (self.style_weight * style_score + self.content_weight * content_score
             + self.photo_weight * photo_score + self.tv_weight * tv_score)
    return total

  @partial(jit, static_argnums=(0,))
  def step(self, optimizer_state, img_generated):
    """One Adam update of the image; returns (new_image, new_optimizer_state)."""
    grads = grad(self.loss)(img_generated)
    updates, opt_state = self.optimizer.update(grads, optimizer_state, img_generated)
    return optax.apply_updates(img_generated, updates), opt_state

  def train(self, iter = 8000):
    """Run `iter` optimization steps, clipping to [0, 1] after each; return the image."""
    for _ in trange(iter):
      self.generate_img, self.opt_state = self.step(self.opt_state, self.generate_img)
      self.generate_img = jnp.clip(self.generate_img, 0, 1)

    return self.generate_img




In [39]:
# VGG19 activations used by each loss: one deeper layer for content, and the first
# conv layer of each of the blocks conv1 .. conv5 for style.
content_layers_default = ['conv4_2']
style_layers_default = [f'conv{block}_1' for block in range(1, 6)]
sty_trans = StyleTransfer(
    input_content=img_content,
    input_style=img_style,
    content_layers=content_layers_default,
    style_layers=style_layers_default,
)

In [None]:
# Optimize the generated image; the result has the same shape as the content image.
output = sty_trans.train(iter=8000)

In [None]:
# Convert the optimized image (values in [0, 1]) to 8-bit pixels. Round to the
# nearest level rather than truncating (truncation biases every pixel darker),
# and clip defensively before the uint8 cast.
img_output = np.clip(np.rint(np.asarray(output[0]) * 255), 0, 255).astype('uint8')
img_out = Image.fromarray(img_output)
display(img_out)
img_out.save('Right1.jpg')