In [1]:

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as tf
from PIL import Image
from typing import Optional
from torch.nn.modules.instancenorm import InstanceNorm2d
import cv2 as cv
from google.colab.patches import cv2_imshow
#from torchsummary import  summary
#!pip install torchviz
#from torchviz import make_dot

In [2]:
# The helper functions and the model structure below are adapted from another repository.

In [3]:
def per_channel_normalize(x):
    """Standardize every channel of an image (or batch) to mean 0.5, std 0.5.

    Each channel is shifted by its spatial mean, divided by twice its spatial
    standard deviation (``torch.std`` default, i.e. Bessel-corrected) and then
    offset by 0.5.

    Args:
        x: float tensor of shape (C, H, W) or (B, C, H, W).

    Returns:
        A tensor with the same shape as ``x``. A constant channel (zero
        standard deviation) maps to 0.5 everywhere instead of NaN.

    Raises:
        NotImplementedError: if ``x`` is neither 3-D nor 4-D.
    """
    if x.ndim not in (3, 4):
        raise NotImplementedError(
            f"per_channel_normalize expects a 3-D or 4-D tensor, got {x.ndim}-D"
        )
    # For both supported layouts the last two axes are (H, W); keepdim=True
    # makes the statistics broadcast straight back over x.
    mu = torch.mean(x, dim=(-2, -1), keepdim=True)
    std = torch.std(x, dim=(-2, -1), keepdim=True)
    # Guard against division by zero for constant channels (e.g. solid colours).
    std = std.clamp_min(torch.finfo(std.dtype).eps)
    return 0.5 + (x - mu) / (2 * std)

In [4]:
def get_pillow_transform(image_size: Optional[int]):
    """Return a torchvision ``Compose`` that turns a PIL image into a tensor.

    When ``image_size`` is given, the image first goes through
    ``tf.Resize(image_size)`` and ``tf.CenterCrop(image_size)``; with ``None``
    only ``tf.ToTensor()`` is applied.
    """
    resize_and_crop = (
        [] if image_size is None
        else [tf.Resize(image_size), tf.CenterCrop(image_size)]
    )
    return tf.Compose([*resize_and_crop, tf.ToTensor()])

In [5]:
def load_image(file_name: Optional[str], image=None) -> torch.Tensor:
    """Load an image so that its shape is (B, C, H, W) with B == 1 and it's
    normalized to the range [0, 1].

    Args:
        file_name: path of an image file to read; when given, ``image`` must be
            None. The file is converted to 3-channel RGB so grayscale, palette
            or RGBA (e.g. PNG) files match the 3-channel model input.
        image: an already-loaded image used when ``file_name`` is None; it is
            handed to the transform unchanged.

    Returns:
        The image tensor with a leading batch dimension of size 1.

    Raises:
        ValueError: if neither ``file_name`` nor ``image`` is provided.
    """
    transform = get_pillow_transform(None)
    if file_name is not None:
        assert image is None
        # The context manager closes the file handle; convert() reads the
        # pixel data before the file is closed.
        with Image.open(file_name) as pil_image:
            image = pil_image.convert("RGB")
    elif image is None:
        raise ValueError("load_image needs either file_name or image")
    return transform(image).unsqueeze(0)


In [6]:
"""class residualBlock(nn.Module):
  def __init__(self,inchannels):
    super(residualBlock,self).__init__()
    self.conv1=nn.Conv2d(inchannels,inchannels,3,1,1)
    self.conv2=nn bv 56
"""
"""class style_net(nn.Module):
  def __init__(self,norm=True):
    super(style_net,self).__init__()
    self.downsample= nn.Sequential(
            nn.Conv2d(3, 32, 9, 1,4),nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2,1),nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2,1),nn.ReLU(),
        )
    self.residual=nn.Sequential(*[ residualBlock(128) for i in range(5)])

    self.up=nn.Sequential(
            nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, 1,1),
            nn.Upsample(scale_factor=2),nn.Conv2d(64, 32, 3, 1,1)
        )
    self.conv=nn.Conv2d(32,3,9,1,4)
    self.img_norm=norm
  def forward(self,x):
    if self.img_norm==True:
      x=per_channel_normalize(x)
    y1=self.downsample(x)
    y2=self.residual(y1)
    y3=self.up(y2)
    y=self.conv(y3)

    return torch.tanh(y) * 0.5 + 0.5"""

'class style_net(nn.Module):\n  def __init__(self,norm=True):\n    super(style_net,self).__init__()\n    self.downsample= nn.Sequential(\n            nn.Conv2d(3, 32, 9, 1,4),nn.ReLU(),\n            nn.Conv2d(32, 64, 3, 2,1),nn.ReLU(),\n            nn.Conv2d(64, 128, 3, 2,1),nn.ReLU(),\n        )\n    self.residual=nn.Sequential(*[ residualBlock(128) for i in range(5)])\n\n    self.up=nn.Sequential(\n            nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, 1,1),\n            nn.Upsample(scale_factor=2),nn.Conv2d(64, 32, 3, 1,1)\n        )\n    self.conv=nn.Conv2d(32,3,9,1,4)\n    self.img_norm=norm\n  def forward(self,x):\n    if self.img_norm==True:\n      x=per_channel_normalize(x)\n    y1=self.downsample(x)\n    y2=self.residual(y1)\n    y3=self.up(y2)\n    y=self.conv(y3)\n\n    return torch.tanh(y) * 0.5 + 0.5'

In [7]:
class ConvolutionBlock(nn.Module):
    """Reflection-padded 2-D convolution, optionally followed by an affine
    InstanceNorm2d.

    The input is reflection-padded by ``kernel_size // 2`` on every side, so
    with stride 1 and an odd kernel the spatial size is preserved. With
    ``no_norm=True`` the raw convolution output is returned.
    """

    reflection_pad: nn.ReflectionPad2d
    conv: nn.Conv2d
    instance_norm: InstanceNorm2d
    no_norm: bool

    def __init__(self, in_channels, out_channels, kernel_size, stride, no_norm=False):
        super().__init__()
        # These attribute names are part of the state_dict keys that the
        # pretrained checkpoints are loaded by, so they must not change.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.instance_norm = nn.InstanceNorm2d(out_channels, affine=True)
        self.no_norm = no_norm

    def forward(self, image: torch.Tensor):
        features = self.conv(self.reflection_pad(image))
        return features if self.no_norm else self.instance_norm(features)

In [8]:
class ResidualBlock(nn.Module):
    """Residual block from Johnson et al., "Perceptual Losses for Real-Time
    Style Transfer and Super-Resolution: Supplementary Material"
    (https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf).

    Computes ``conv2(relu(conv1(x))) + x``. Both convolutions are 3x3, stride 1
    and keep ``num_channels`` channels, so the output has the same shape as x.
    """

    conv1: ConvolutionBlock
    relu: nn.ReLU
    conv2: ConvolutionBlock

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1)
        # Created in the order conv1, relu, conv2 so parameter initialization
        # and state_dict keys match the pretrained checkpoints.
        self.conv1 = ConvolutionBlock(num_channels, num_channels, **conv_kwargs)
        self.relu = nn.ReLU()
        self.conv2 = ConvolutionBlock(num_channels, num_channels, **conv_kwargs)

    def forward(self, image: torch.Tensor):
        branch = self.conv1(image)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return branch + image

In [9]:
class UpsampleBlock(nn.Module):
    """Enlarge the feature map with ``nn.Upsample(scale_factor)`` (default
    mode), then apply a stride-1 ConvolutionBlock."""

    upsample: nn.Upsample
    conv: ConvolutionBlock

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        scale_factor: float,
    ) -> None:
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor)
        self.conv = ConvolutionBlock(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, image: torch.Tensor):
        enlarged = self.upsample(image)
        return self.conv(enlarged)

In [10]:

class StylizationModel(nn.Module):
    """Feed-forward stylization network described in Johnson et al.
    (https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf).

    Layout:
      * down_convolution: ConvolutionBlocks 3->32 (9x9), 32->64 (3x3, stride 2)
        and 64->128 (3x3, stride 2), each followed by a ReLU;
      * residual: five 128-channel ResidualBlocks;
      * up_convolution: two 2x UpsampleBlocks (128->64->32) with ReLUs, then a
        9x9 ConvolutionBlock to 3 channels (with instance norm, since no_norm
        is left at its default).
    The result is squashed into (0, 1) by ``0.5 + 0.5 * tanh``.

    Args:
        normalize: if truthy, inputs first go through ``per_channel_normalize``.
    """

    def __init__(self, normalize) -> None:
        super().__init__()
        # Attribute names and Sequential indices form the state_dict keys
        # (e.g. "down_convolution.0.conv.weight") of the pretrained
        # checkpoints: keep the layer order exactly as is.
        down_layers = []
        for c_in, c_out, k, s in ((3, 32, 9, 1), (32, 64, 3, 2), (64, 128, 3, 2)):
            down_layers += [ConvolutionBlock(c_in, c_out, kernel_size=k, stride=s), nn.ReLU()]
        self.down_convolution = nn.Sequential(*down_layers)

        self.residual = nn.Sequential(*(ResidualBlock(128) for _ in range(5)))

        self.up_convolution = nn.Sequential(
            UpsampleBlock(128, 64, 3, 2),
            nn.ReLU(),
            UpsampleBlock(64, 32, 3, 2),
            nn.ReLU(),
            ConvolutionBlock(32, 3, 9, 1),
        )
        self.normalize = normalize

    def forward(self, image: torch.Tensor):
        features = per_channel_normalize(image) if self.normalize else image
        for stage in (self.down_convolution, self.residual, self.up_convolution):
            features = stage(features)
        return 0.5 + 0.5 * torch.tanh(features)


In [11]:
# Every content image is resized to a fixed 256x256 before stylization.
resize = tf.Compose([tf.Resize((256, 256))])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# Show which device was selected ("cuda" when a GPU is available, else "cpu").
device

device(type='cuda')

In [13]:
def processimg(image):
  """Convert a model output array into an 8-bit BGR image for OpenCV.

  Args:
    image: array of shape (1, C, H, W) (a batch of one) or (C, H, W) in RGB
      channel order; values are clipped to [0, 1] before conversion.

  Returns:
    np.ndarray of dtype uint8 and shape (H, W, C), channels reversed to BGR
    (the order cv2_imshow / cv.imwrite expect), C-contiguous.
  """
  chw = np.asarray(image)
  if chw.ndim == 4:
    chw = np.squeeze(chw, axis=0)  # drop the batch dimension (must be size 1)
  hwc = np.transpose(chw, (1, 2, 0))  # (c,h,w) -> (h,w,c)
  hwc = np.clip(hwc, 0, 1)
  # Round rather than truncate, so e.g. 0.999 maps to 255 instead of 254.
  rgb_u8 = np.round(hwc * 255.0).astype(np.uint8)
  # rgb -> bgr by reversing the channel axis (what cv.COLOR_RGB2BGR does for
  # 3 channels); make the result contiguous so OpenCV accepts it.
  return np.ascontiguousarray(rgb_u8[..., ::-1])

In [14]:
def perimage(content_img_path,style_img_path,model_output):
    """Build a side-by-side comparison strip: content | style | stylized output.

    Args:
        content_img_path: path of the content image (read with ``cv.imread``).
        style_img_path: path of the style image (read with ``cv.imread``).
        model_output: stylization output for the content image, in the form
            ``processimg`` accepts (presumably a (1, 3, H, W) array in [0, 1]).

    Returns:
        A float64 array of shape (H, 3 * W, 3) holding 0-255 values, where
        (H, W) is the size of the stylized output; the content and style images
        are resized to that size.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the image files.
    """
    # style applied image; its size decides the size of every panel
    mix_img = processimg(model_output)
    height, width = mix_img.shape[:2]

    def _read_resized(path):
        # cv.imread signals failure by returning None instead of raising.
        img = cv.imread(path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return cv.resize(img, (width, height))

    # target image (content), source image (style), style applied output
    panels = (_read_resized(content_img_path), _read_resized(style_img_path), mix_img)
    # np.zeros defaults to float64, which is kept as the return dtype.
    stack = np.zeros((height, 3 * width, 3))
    for k, panel in enumerate(panels):
        stack[:, k * width:(k + 1) * width, :] = panel

    return stack

In [15]:

def test(contentimg_path_list,styleimg_path_list,pretrained_wets_pathlist):
    """Stylize every content image with every (style image, checkpoint) pair.

    Args:
        contentimg_path_list: paths of the content images to stylize.
        styleimg_path_list: paths of the style images, one per checkpoint; only
            used to build the comparison strips.
        pretrained_wets_pathlist: paths of StylizationModel state_dicts, in the
            same order as ``styleimg_path_list``.

    Returns:
        A list of len(pretrained_wets_pathlist) * len(contentimg_path_list)
        comparison strips built by ``perimage``, ordered style-major (all
        content images for the first style, then for the second, ...).

    Raises:
        AssertionError: if the style and checkpoint lists differ in length.
    """
    assert len(styleimg_path_list)==len(pretrained_wets_pathlist), "for each style image , there must be one pretrained weights"
    model = StylizationModel(True).to(device)
    model.eval()

    output_stacks = []  # len(contentimg_path_list)*len(styleimg_path_list)
    for style_img_path, weights_path in zip(styleimg_path_list, pretrained_wets_pathlist):
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)

        for content_img_path in contentimg_path_list:
            print(content_img_path)
            # Load the image, resize it to the model's working size and stylize it.
            image = resize(load_image(content_img_path))
            # Inference only: no_grad avoids building an autograd graph.
            with torch.no_grad():
                output_test = model(image.to(device)).cpu().numpy()
            output_stacks.append(perimage(content_img_path, style_img_path, output_test))

    return output_stacks

In [27]:
# One pretrained checkpoint per style image, in the same order as the
# style image list below.
pretrained_wets_pathlist = [
    '/content/haring_model_0_09000.pth',
    '/content/stairs_final_model.pth',
]

styleimg_path_list = [
    '/content/haring.jpg',
    '/content/stairs.jpg',
]

contentimg_path_list = [
    '/content/pic10.jpg',
    '/content/pic11.jpg',
    '/content/pic5.jpg',
]

In [None]:
# Stylize every content image with every (style image, checkpoint) pair.
stack_img = test(
    contentimg_path_list=contentimg_path_list,
    styleimg_path_list=styleimg_path_list,
    pretrained_wets_pathlist=pretrained_wets_pathlist,
)

In [36]:
# Save each comparison strip to disk as 6.jpg, 7.jpg, ...
# NOTE(review): the offset 6 presumably avoids overwriting files from an
# earlier run — confirm before changing it.
for file_idx, strip in enumerate(stack_img, start=6):
  out_name = f"{file_idx}.jpg"
  # cv.imwrite reports failure through its return value instead of raising.
  if not cv.imwrite(out_name, strip):
    raise OSError(f"failed to write {out_name}")

In [None]:
# Preview the first comparison strip (content | style | stylized output).
cv2_imshow(stack_img[0])

In [None]:
# Resolved: loading the weights failed because the parameter (state_dict) keys of my custom model
# did not match the keys of the model used during training.
# Switching to the original model definition (and therefore the original parameter keys) fixed it.

In [None]:
# Problem: when displaying the images, the output did not look right (the colours were off).
# Do PyTorch and OpenCV handle colour channels differently?
# Resolved: PIL reads images as RGB, but cv2_imshow (OpenCV) displays images as BGR,
# so the colour space has to be converted before showing the image.