In [1]:
import os
from time import perf_counter

try:
    from time import clock
except ImportError:  # time.clock was removed in Python 3.8
    clock = perf_counter

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

In [4]:
# Directory holding the converted VGG encoder / feature-inverter checkpoints.
# Can be overridden with the WCT_MODEL_DIR environment variable instead of
# editing this hard-coded local path. A trailing separator is guaranteed
# because the checkpoint file names are appended by plain string concatenation.
model_loc = os.path.join(os.environ.get("WCT_MODEL_DIR", "d:/UniversalStyle/models/"), "")
class encoder1(nn.Module):
    """Normalised VGG prefix for level 1: 3 -> 64 channels, no pooling.

    Weights are read from ``vgg_normalised_conv1_1.pth`` under ``model_loc``.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "vgg_normalised_conv1_1.pth")

        def loaded_conv(in_ch, out_ch, ksize, key):
            # Conv2d (stride 1, no padding) whose weight/bias are replaced by
            # the checkpoint entries '<key>.weight' and '<key>.bias'.
            layer = nn.Conv2d(in_ch, out_ch, ksize, 1, 0)
            layer.weight = torch.nn.Parameter(ckpt.get(f'{key}.weight').float())
            layer.bias = torch.nn.Parameter(ckpt.get(f'{key}.bias').float())
            return layer

        self.conv1 = loaded_conv(3, 3, 1, 0)                 # 1x1 conv
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))  # pads for conv2
        self.conv2 = loaded_conv(3, 64, 3, 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv2(self.reflecPad1(self.conv1(x))))
    
class decoder1(nn.Module):
    """Maps level-1 encoder features (64 channels) back to 3 channels.

    A single reflection-padded 3x3 convolution whose weights are read from
    ``feature_invertor_conv1_1.pth`` under ``model_loc``.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "feature_invertor_conv1_1.pth")
        self.reflecPad2 = nn.ReflectionPad2d((1, 1, 1, 1))
        conv3 = nn.Conv2d(64, 3, 3, 1, 0)
        conv3.weight = torch.nn.Parameter(ckpt.get('1.weight').float())
        conv3.bias = torch.nn.Parameter(ckpt.get('1.bias').float())
        self.conv3 = conv3

    def forward(self, x):
        return self.conv3(self.reflecPad2(x))


class encoder2(nn.Module):
    """Normalised VGG prefix for level 2: 3 -> 128 channels, one 2x max-pool.

    Weights are read from ``vgg_normalised_conv2_1.pth`` under ``model_loc``.
    Spatial sizes in the comments assume a 224 x 224 input.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "vgg_normalised_conv2_1.pth")

        def loaded_conv(in_ch, out_ch, ksize, key):
            # Conv2d (stride 1, no padding) whose weight/bias are replaced by
            # the checkpoint entries '<key>.weight' and '<key>.bias'.
            layer = nn.Conv2d(in_ch, out_ch, ksize, 1, 0)
            layer.weight = torch.nn.Parameter(ckpt.get(f'{key}.weight').float())
            layer.bias = torch.nn.Parameter(ckpt.get(f'{key}.bias').float())
            return layer

        self.conv1 = loaded_conv(3, 3, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))  # 226 x 226
        self.conv2 = loaded_conv(3, 64, 3, 2)
        self.relu2 = nn.ReLU(inplace=True)                  # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = loaded_conv(64, 64, 3, 5)
        self.relu3 = nn.ReLU(inplace=True)                  # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = loaded_conv(64, 128, 3, 9)
        self.relu4 = nn.ReLU(inplace=True)                  # 112 x 112

    def forward(self, x):
        h = self.relu2(self.conv2(self.reflecPad1(self.conv1(x))))
        h = self.relu3(self.conv3(self.reflecPad3(h)))
        h = self.maxPool(h)[0]  # pooling indices are not used
        return self.relu4(self.conv4(self.reflecPad4(h)))

class decoder2(nn.Module):
    """Maps level-2 encoder features (128 channels) back to 3 channels.

    Three reflection-padded 3x3 convolutions with one 2x nearest-neighbour
    upsampling. Weights are read from ``feature_invertor_conv2_1.pth`` under
    ``model_loc``; spatial sizes in the comments assume a 224 x 224 image.
    """

    def __init__(self):
        super(decoder2, self).__init__()
        d = torch.load(model_loc + "feature_invertor_conv2_1.pth")

        def conv(in_ch, out_ch, key):
            # 3x3 conv initialised from checkpoint entries '<key>.weight' /
            # '<key>.bias'. Indexing (rather than .get) makes a missing entry
            # fail with a KeyError naming it instead of an opaque
            # "'NoneType' object has no attribute 'float'".
            layer = nn.Conv2d(in_ch, out_ch, 3, 1, 0)
            layer.weight = torch.nn.Parameter(d[f'{key}.weight'].float())
            layer.bias = torch.nn.Parameter(d[f'{key}.bias'].float())
            return layer

        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = conv(128, 64, 1)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        # nn.UpsamplingNearest2d is deprecated; nn.Upsample(mode='nearest')
        # is the equivalent supported module.
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
        # 224 x 224
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = conv(64, 64, 5)
        self.relu6 = nn.ReLU(inplace=True)
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = conv(64, 3, 8)

    def forward(self, x):
        out = self.relu5(self.conv5(self.reflecPad5(x)))
        out = self.unpool(out)
        out = self.relu6(self.conv6(self.reflecPad6(out)))
        return self.conv7(self.reflecPad7(out))

class encoder3(nn.Module):
    """Normalised VGG prefix for level 3: 3 -> 256 channels, two 2x max-pools.

    Weights are read from ``vgg_normalised_conv3_1.pth`` under ``model_loc``.
    Spatial sizes in the comments assume a 224 x 224 input.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "vgg_normalised_conv3_1.pth")

        def loaded_conv(in_ch, out_ch, ksize, key):
            # Conv2d (stride 1, no padding) whose weight/bias are replaced by
            # the checkpoint entries '<key>.weight' and '<key>.bias'.
            layer = nn.Conv2d(in_ch, out_ch, ksize, 1, 0)
            layer.weight = torch.nn.Parameter(ckpt.get(f'{key}.weight').float())
            layer.bias = torch.nn.Parameter(ckpt.get(f'{key}.bias').float())
            return layer

        self.conv1 = loaded_conv(3, 3, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))  # 226 x 226
        self.conv2 = loaded_conv(3, 64, 3, 2)
        self.relu2 = nn.ReLU(inplace=True)                  # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = loaded_conv(64, 64, 3, 5)
        self.relu3 = nn.ReLU(inplace=True)                  # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)   # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = loaded_conv(64, 128, 3, 9)
        self.relu4 = nn.ReLU(inplace=True)
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = loaded_conv(128, 128, 3, 12)
        self.relu5 = nn.ReLU(inplace=True)                  # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = loaded_conv(128, 256, 3, 16)
        self.relu6 = nn.ReLU(inplace=True)                  # 56 x 56

    def forward(self, x):
        h = self.relu2(self.conv2(self.reflecPad1(self.conv1(x))))
        h = self.relu3(self.conv3(self.reflecPad3(h)))
        h = self.maxPool(h)[0]  # pooling indices are not used
        h = self.relu4(self.conv4(self.reflecPad4(h)))
        h = self.relu5(self.conv5(self.reflecPad5(h)))
        h = self.maxPool2(h)[0]
        return self.relu6(self.conv6(self.reflecPad6(h)))

class decoder3(nn.Module):
    """Maps level-3 encoder features (256 channels) back to 3 channels.

    Five reflection-padded 3x3 convolutions with two 2x nearest-neighbour
    upsamplings. Weights are read from ``feature_invertor_conv3_1.pth`` under
    ``model_loc``; spatial sizes in the comments assume a 224 x 224 image.
    """

    def __init__(self):
        super(decoder3, self).__init__()
        d = torch.load(model_loc + "feature_invertor_conv3_1.pth")

        def conv(in_ch, out_ch, key):
            # 3x3 conv initialised from checkpoint entries '<key>.weight' /
            # '<key>.bias'. Indexing (rather than .get) makes a missing entry
            # fail with a KeyError naming it instead of an opaque
            # "'NoneType' object has no attribute 'float'".
            layer = nn.Conv2d(in_ch, out_ch, 3, 1, 0)
            layer.weight = torch.nn.Parameter(d[f'{key}.weight'].float())
            layer.bias = torch.nn.Parameter(d[f'{key}.bias'].float())
            return layer

        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = conv(256, 128, 1)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        # nn.UpsamplingNearest2d is deprecated; nn.Upsample(mode='nearest')
        # is the equivalent supported module.
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
        # 112 x 112
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = conv(128, 128, 5)
        self.relu8 = nn.ReLU(inplace=True)
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = conv(128, 64, 8)
        self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.Upsample(scale_factor=2, mode='nearest')
        # 224 x 224
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = conv(64, 64, 12)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = conv(64, 3, 15)

    def forward(self, x):
        out = self.relu7(self.conv7(self.reflecPad7(x)))
        out = self.unpool(out)
        out = self.relu8(self.conv8(self.reflecPad8(out)))
        out = self.relu9(self.conv9(self.reflecPad9(out)))
        out = self.unpool2(out)
        out = self.relu10(self.conv10(self.reflecPad10(out)))
        return self.conv11(self.reflecPad11(out))

class encoder4(nn.Module):
    """Normalised VGG prefix for level 4: 3 -> 512 channels, three 2x max-pools.

    Weights are read from ``vgg_normalised_conv4_1.pth`` under ``model_loc``.
    Spatial sizes in the comments assume a 224 x 224 input.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "vgg_normalised_conv4_1.pth")

        def loaded_conv(in_ch, out_ch, ksize, key):
            # Conv2d (stride 1, no padding) whose weight/bias are replaced by
            # the checkpoint entries '<key>.weight' and '<key>.bias'.
            layer = nn.Conv2d(in_ch, out_ch, ksize, 1, 0)
            layer.weight = torch.nn.Parameter(ckpt.get(f'{key}.weight').float())
            layer.bias = torch.nn.Parameter(ckpt.get(f'{key}.bias').float())
            return layer

        self.conv1 = loaded_conv(3, 3, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))  # 226 x 226
        self.conv2 = loaded_conv(3, 64, 3, 2)
        self.relu2 = nn.ReLU(inplace=True)                  # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = loaded_conv(64, 64, 3, 5)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)   # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = loaded_conv(64, 128, 3, 9)
        self.relu4 = nn.ReLU(inplace=True)
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = loaded_conv(128, 128, 3, 12)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = loaded_conv(128, 256, 3, 16)
        self.relu6 = nn.ReLU(inplace=True)
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = loaded_conv(256, 256, 3, 19)
        self.relu7 = nn.ReLU(inplace=True)
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = loaded_conv(256, 256, 3, 22)
        self.relu8 = nn.ReLU(inplace=True)
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = loaded_conv(256, 256, 3, 25)
        self.relu9 = nn.ReLU(inplace=True)
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = loaded_conv(256, 512, 3, 29)
        self.relu10 = nn.ReLU(inplace=True)

    def forward(self, x):
        h = self.relu2(self.conv2(self.reflecPad1(self.conv1(x))))
        h = self.relu3(self.conv3(self.reflecPad3(h)))
        h = self.maxPool(h)[0]  # pooling indices are not used
        h = self.relu4(self.conv4(self.reflecPad4(h)))
        h = self.relu5(self.conv5(self.reflecPad5(h)))
        h = self.maxPool2(h)[0]
        h = self.relu6(self.conv6(self.reflecPad6(h)))
        h = self.relu7(self.conv7(self.reflecPad7(h)))
        h = self.relu8(self.conv8(self.reflecPad8(h)))
        h = self.relu9(self.conv9(self.reflecPad9(h)))
        h = self.maxPool3(h)[0]
        return self.relu10(self.conv10(self.reflecPad10(h)))

class decoder4(nn.Module):
    """Maps level-4 encoder features (512 channels) back to 3 channels.

    Nine reflection-padded 3x3 convolutions with three 2x nearest-neighbour
    upsamplings. Weights are read from ``feature_invertor_conv4_1.pth`` under
    ``model_loc``; spatial sizes in the comments assume a 224 x 224 image.
    """

    def __init__(self):
        super(decoder4, self).__init__()
        d = torch.load(model_loc + "feature_invertor_conv4_1.pth")

        def conv(in_ch, out_ch, key):
            # 3x3 conv initialised from checkpoint entries '<key>.weight' /
            # '<key>.bias'. Indexing (rather than .get) makes a missing entry
            # fail with a KeyError naming it instead of an opaque
            # "'NoneType' object has no attribute 'float'".
            layer = nn.Conv2d(in_ch, out_ch, 3, 1, 0)
            layer.weight = torch.nn.Parameter(d[f'{key}.weight'].float())
            layer.bias = torch.nn.Parameter(d[f'{key}.bias'].float())
            return layer

        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = conv(512, 256, 1)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28
        # nn.UpsamplingNearest2d is deprecated; nn.Upsample(mode='nearest')
        # is the equivalent supported module.
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
        # 56 x 56
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = conv(256, 256, 5)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = conv(256, 256, 8)
        self.relu13 = nn.ReLU(inplace=True)
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = conv(256, 256, 11)
        self.relu14 = nn.ReLU(inplace=True)
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = conv(256, 128, 14)
        self.relu15 = nn.ReLU(inplace=True)
        self.unpool2 = nn.Upsample(scale_factor=2, mode='nearest')
        # 112 x 112
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = conv(128, 128, 18)
        self.relu16 = nn.ReLU(inplace=True)
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = conv(128, 64, 21)
        self.relu17 = nn.ReLU(inplace=True)
        self.unpool3 = nn.Upsample(scale_factor=2, mode='nearest')
        # 224 x 224
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = conv(64, 64, 25)
        self.relu18 = nn.ReLU(inplace=True)
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = conv(64, 3, 28)

    def forward(self, x):
        out = self.relu11(self.conv11(self.reflecPad11(x)))
        out = self.unpool(out)
        out = self.relu12(self.conv12(self.reflecPad12(out)))
        out = self.relu13(self.conv13(self.reflecPad13(out)))
        out = self.relu14(self.conv14(self.reflecPad14(out)))
        out = self.relu15(self.conv15(self.reflecPad15(out)))
        out = self.unpool2(out)
        out = self.relu16(self.conv16(self.reflecPad16(out)))
        out = self.relu17(self.conv17(self.reflecPad17(out)))
        out = self.unpool3(out)
        out = self.relu18(self.conv18(self.reflecPad18(out)))
        return self.conv19(self.reflecPad19(out))
class encoder5(nn.Module):
    """Normalised VGG prefix for level 5: 3 -> 512 channels, four 2x max-pools.

    Weights are read from ``vgg_normalised_conv5_1.pth`` under ``model_loc``.
    Spatial sizes in the comments assume a 224 x 224 input.
    """

    def __init__(self):
        super().__init__()
        ckpt = torch.load(model_loc + "vgg_normalised_conv5_1.pth")

        def loaded_conv(in_ch, out_ch, ksize, key):
            # Conv2d (stride 1, no padding) whose weight/bias are replaced by
            # the checkpoint entries '<key>.weight' and '<key>.bias'.
            layer = nn.Conv2d(in_ch, out_ch, ksize, 1, 0)
            layer.weight = torch.nn.Parameter(ckpt.get(f'{key}.weight').float())
            layer.bias = torch.nn.Parameter(ckpt.get(f'{key}.bias').float())
            return layer

        self.conv1 = loaded_conv(3, 3, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))  # 226 x 226
        self.conv2 = loaded_conv(3, 64, 3, 2)
        self.relu2 = nn.ReLU(inplace=True)                  # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = loaded_conv(64, 64, 3, 5)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)   # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = loaded_conv(64, 128, 3, 9)
        self.relu4 = nn.ReLU(inplace=True)
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = loaded_conv(128, 128, 3, 12)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = loaded_conv(128, 256, 3, 16)
        self.relu6 = nn.ReLU(inplace=True)
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = loaded_conv(256, 256, 3, 19)
        self.relu7 = nn.ReLU(inplace=True)
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = loaded_conv(256, 256, 3, 22)
        self.relu8 = nn.ReLU(inplace=True)
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = loaded_conv(256, 256, 3, 25)
        self.relu9 = nn.ReLU(inplace=True)
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = loaded_conv(256, 512, 3, 29)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = loaded_conv(512, 512, 3, 32)
        self.relu11 = nn.ReLU(inplace=True)
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = loaded_conv(512, 512, 3, 35)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = loaded_conv(512, 512, 3, 38)
        self.relu13 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # 14 x 14
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = loaded_conv(512, 512, 3, 42)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self, x):
        h = self.relu2(self.conv2(self.reflecPad1(self.conv1(x))))
        h = self.relu3(self.conv3(self.reflecPad3(h)))
        h = self.maxPool(h)[0]  # pooling indices are not used
        h = self.relu4(self.conv4(self.reflecPad4(h)))
        h = self.relu5(self.conv5(self.reflecPad5(h)))
        h = self.maxPool2(h)[0]
        h = self.relu6(self.conv6(self.reflecPad6(h)))
        h = self.relu7(self.conv7(self.reflecPad7(h)))
        h = self.relu8(self.conv8(self.reflecPad8(h)))
        h = self.relu9(self.conv9(self.reflecPad9(h)))
        h = self.maxPool3(h)[0]
        h = self.relu10(self.conv10(self.reflecPad10(h)))
        h = self.relu11(self.conv11(self.reflecPad11(h)))
        h = self.relu12(self.conv12(self.reflecPad12(h)))
        h = self.relu13(self.conv13(self.reflecPad13(h)))
        h = self.maxPool4(h)[0]
        return self.relu14(self.conv14(self.reflecPad14(h)))


class decoder5(nn.Module):
    """Maps level-5 encoder features (512 channels) back to 3 channels.

    Thirteen reflection-padded 3x3 convolutions with four 2x nearest-neighbour
    upsamplings. Weights are read from ``feature_invertor_conv5_1.pth`` under
    ``model_loc``; spatial sizes in the comments assume a 224 x 224 image.
    """

    def __init__(self):
        super(decoder5, self).__init__()
        d = torch.load(model_loc + "feature_invertor_conv5_1.pth")

        def conv(in_ch, out_ch, key):
            # 3x3 conv initialised from checkpoint entries '<key>.weight' /
            # '<key>.bias'. Indexing (rather than .get) makes a missing entry
            # fail with a KeyError naming it instead of an opaque
            # "'NoneType' object has no attribute 'float'".
            layer = nn.Conv2d(in_ch, out_ch, 3, 1, 0)
            layer.weight = torch.nn.Parameter(d[f'{key}.weight'].float())
            layer.bias = torch.nn.Parameter(d[f'{key}.bias'].float())
            return layer

        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = conv(512, 512, 1)
        self.relu15 = nn.ReLU(inplace=True)
        # nn.UpsamplingNearest2d is deprecated; nn.Upsample(mode='nearest')
        # is the equivalent supported module.
        self.unpool = nn.Upsample(scale_factor=2, mode='nearest')
        # 28 x 28
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = conv(512, 512, 5)
        self.relu16 = nn.ReLU(inplace=True)
        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = conv(512, 512, 8)
        self.relu17 = nn.ReLU(inplace=True)
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = conv(512, 512, 11)
        self.relu18 = nn.ReLU(inplace=True)
        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = conv(512, 256, 14)
        self.relu19 = nn.ReLU(inplace=True)
        self.unpool2 = nn.Upsample(scale_factor=2, mode='nearest')
        # 56 x 56
        self.reflecPad20 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv20 = conv(256, 256, 18)
        self.relu20 = nn.ReLU(inplace=True)
        self.reflecPad21 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv21 = conv(256, 256, 21)
        self.relu21 = nn.ReLU(inplace=True)
        self.reflecPad22 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv22 = conv(256, 256, 24)
        self.relu22 = nn.ReLU(inplace=True)
        self.reflecPad23 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv23 = conv(256, 128, 27)
        self.relu23 = nn.ReLU(inplace=True)
        self.unpool3 = nn.Upsample(scale_factor=2, mode='nearest')
        # 112 x 112
        self.reflecPad24 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv24 = conv(128, 128, 31)
        self.relu24 = nn.ReLU(inplace=True)
        self.reflecPad25 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv25 = conv(128, 64, 34)
        self.relu25 = nn.ReLU(inplace=True)
        self.unpool4 = nn.Upsample(scale_factor=2, mode='nearest')
        # 224 x 224
        self.reflecPad26 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv26 = conv(64, 64, 38)
        self.relu26 = nn.ReLU(inplace=True)
        self.reflecPad27 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv27 = conv(64, 3, 41)

    def forward(self, x):
        out = self.relu15(self.conv15(self.reflecPad15(x)))
        out = self.unpool(out)
        out = self.relu16(self.conv16(self.reflecPad16(out)))
        out = self.relu17(self.conv17(self.reflecPad17(out)))
        out = self.relu18(self.conv18(self.reflecPad18(out)))
        out = self.relu19(self.conv19(self.reflecPad19(out)))
        out = self.unpool2(out)
        out = self.relu20(self.conv20(self.reflecPad20(out)))
        out = self.relu21(self.conv21(self.reflecPad21(out)))
        out = self.relu22(self.conv22(self.reflecPad22(out)))
        out = self.relu23(self.conv23(self.reflecPad23(out)))
        out = self.unpool3(out)
        out = self.relu24(self.conv24(self.reflecPad24(out)))
        out = self.relu25(self.conv25(self.reflecPad25(out)))
        out = self.unpool4(out)
        out = self.relu26(self.conv26(self.reflecPad26(out)))
        return self.conv27(self.reflecPad27(out))

In [5]:
class WCT(nn.Module):
    """Container holding the five encoder / decoder pairs.

    ``e1``..``e5`` are ``encoder1``..``encoder5`` and ``d1``..``d5`` the matching
    decoders. The class defines no ``forward``; callers fetch the pair for a
    level by attribute name.
    """

    def __init__(self):
        super().__init__()
        pairs = (
            (encoder1, decoder1),
            (encoder2, decoder2),
            (encoder3, decoder3),
            (encoder4, decoder4),
            (encoder5, decoder5),
        )
        # Construct and register in the order e1, d1, e2, d2, ... e5, d5.
        for level, (make_encoder, make_decoder) in enumerate(pairs, start=1):
            setattr(self, f'e{level}', make_encoder())
            setattr(self, f'd{level}', make_decoder())
        
# Build every encoder/decoder pair and move it to the GPU.
# This notebook only runs inference, so freeze the parameters. Otherwise each
# forward pass records an autograd graph and keeps its activations alive on
# the GPU until the output tensor is released.
wct = WCT().to('cuda')
for _param in wct.parameters():
    _param.requires_grad_(False)

In [3]:
def _effective_rank(singular_values, eps=0.00001):
    """Return the index of the first singular value below ``eps``.

    ``torch.svd`` returns singular values in descending order, so this is the
    number of leading components kept. If none is below ``eps``, the full
    length is returned.
    """
    for idx in range(singular_values.size(0)):
        if singular_values[idx] < eps:
            return idx
    return singular_values.size(0)


def transform(cF, sF, alpha=.6):
    """Whitening-and-colouring transform (WCT) of a content feature map.

    The content features are centred and whitened with their own channel
    covariance. They are then re-coloured with the channel covariance of the
    style features, shifted to the style mean, and finally blended with the
    original content features.

    Args:
        cF: content features, shape ``(1, C, H, W)``.
        sF: style features, shape ``(1, C, H2, W2)``; the spatial size may
            differ from ``cF`` but the channel count must match.
        alpha: blend weight. ``1.0`` returns the pure WCT result and ``0.0``
            returns ``cF`` unchanged.

    Returns:
        A detached float32 tensor with the shape of ``cF``, on ``cF``'s device.

    Raises:
        ValueError: if ``cF`` and ``sF`` have different channel counts.
    """
    out_shape = cF.size()
    out_device = cF.device

    # The decompositions are done on the CPU in double precision.
    content = cF.detach().squeeze(0).cpu().double()
    style = sF.detach().squeeze(0).cpu().double()
    channels = content.size(0)
    if style.size(0) != channels:
        raise ValueError(
            f"content has {channels} channels but style has {style.size(0)}")
    content = content.reshape(channels, -1)   # C x (H*W)
    style = style.reshape(channels, -1)       # C x (H2*W2)

    # Whitening: centre the content features and apply V diag(e^-1/2) V^T of
    # their covariance. Adding the identity makes every eigenvalue >= 1, so the
    # inverse square root stays bounded. max(..., 1) avoids dividing by zero
    # when there is a single spatial position.
    content_centred = content - content.mean(1, keepdim=True)
    content_cov = (content_centred.mm(content_centred.t())
                   .div(max(content.size(1) - 1, 1))
                   + torch.eye(channels, dtype=torch.double))
    _, c_e, c_v = torch.svd(content_cov, some=False)
    k_c = _effective_rank(c_e)
    c_basis = c_v[:, :k_c]
    whitened = (c_basis.mm(torch.diag(c_e[:k_c].pow(-0.5)))
                .mm(c_basis.t())
                .mm(content_centred))

    # Colouring: apply V diag(e^1/2) V^T of the style covariance, then add
    # the style mean back.
    style_mean = style.mean(1, keepdim=True)
    style_centred = style - style_mean
    style_cov = style_centred.mm(style_centred.t()).div(max(style.size(1) - 1, 1))
    _, s_e, s_v = torch.svd(style_cov, some=False)
    k_s = _effective_rank(s_e)
    s_basis = s_v[:, :k_s]
    target = (s_basis.mm(torch.diag(s_e[:k_s].pow(0.5)))
              .mm(s_basis.t())
              .mm(whitened)) + style_mean

    # Blend with the ORIGINAL (un-centred) content features, not the centred
    # copy used for whitening; otherwise the (1 - alpha) term loses the
    # content mean.
    blended = alpha * target + (1 - alpha) * content
    return blended.view(out_shape).float().to(out_device)

In [2]:
def load_image(path, device='cuda'):
    """Read an image file into a ``(1, 3, H, W)`` float tensor with values in [0, 1].

    The image is converted to RGB so that grayscale, palette or RGBA files
    still give the 3 channels the encoders' first conv expects. The ``with``
    block closes the file handle that ``Image.open`` keeps open.
    """
    with Image.open(path) as img:
        tensor = transforms.ToTensor()(img.convert('RGB'))
    return tensor.to(device).unsqueeze(0)


Ic = load_image("d:/Images/amber.jpg")   # content image
Is = load_image("d:/Images/mosaic.jpg")  # style image

In [6]:
# Coarse-to-fine multi-level WCT: stylise at the deepest level (5) first and
# feed each decoded image into the next, shallower level. Every level's
# result is written to d:/out{i}.jpg.
# NOTE: this rebinds Ic, so re-running the cell restylises the already
# stylised image. Re-run the image-loading cell first to start over.
with torch.no_grad():  # inference only: don't build autograd graphs on the GPU
    for i in range(5, 0, -1):
        # perf_counter is a wall-clock timer; time.clock was removed in
        # Python 3.8 (and measured CPU time on Unix).
        start = perf_counter()
        encoder = getattr(wct, f'e{i}')
        decoder = getattr(wct, f'd{i}')
        cF = encoder(Ic)
        sF = encoder(Is)
        cF = transform(cF, sF)
        Ic = decoder(cF)
        if Ic.is_cuda:
            # CUDA kernels run asynchronously; wait so the timing includes the decoder.
            torch.cuda.synchronize()
        print(cF.size(), Ic.size(), perf_counter() - start)
        vutils.save_image(Ic.cpu().float(), f'd:/out{i}.jpg')



torch.Size([1, 512, 67, 67]) torch.Size([1, 3, 1072, 1072]) 1.2457705943525246
torch.Size([1, 512, 134, 134]) torch.Size([1, 3, 1072, 1072]) 0.8940290543015976
torch.Size([1, 256, 268, 268]) torch.Size([1, 3, 1072, 1072]) 0.7731224954659544
torch.Size([1, 128, 536, 536]) torch.Size([1, 3, 1072, 1072]) 1.1859291551087376
torch.Size([1, 64, 1072, 1072]) torch.Size([1, 3, 1072, 1072]) 2.0976500899982478


In [None]:
# Scratch check: MiB currently occupied by tensors on the current CUDA device.
torch.cuda.memory_allocated() / (1 << 20)

In [None]:
# Scratch check: the Python class of Ic (after the stylisation loop, the last decoded image).
Ic.__class__

In [None]:
# FIXME(review): `w` is not assigned in any visible cell of this notebook, so on
# a fresh kernel this leftover scratch cell raises NameError. Remove it, or
# define `w` in an earlier cell.
q = w.float()