In [None]:
# Imports
import numpy
import pandas
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch import nn
from torch import optim
from torchvision import models
#from torchsummary import summary

In [None]:
####################################################### RAE Encoder Definitions #######################################################

# reference to code: https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
# https://github.com/Alvinhech/resnet-autoencoder/blob/cdcaab6c6c9792f76f46190c2b6407a28702f7af/autoencoder1.py#L142

class Bottleneck(nn.Module):
    """ResNet-50 bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``out_channels * expansion`` channels; ``stride`` is applied by
    the 3x3 convolution. ``downsample`` (optional) projects the input so it can be
    added to the residual branch when the shape changes.

    The residual branch ends with ``bn3`` and no activation; ReLU is applied only
    after the shortcut addition. This is the torchvision ResNet-50 layout, so the
    torchvision checkpoint ('resnet50-19c8e357.pth') used for the encoder produces
    the features it was trained for (an extra ReLU before the addition would not).
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x, return_intermediates=False):
        """Apply the block.

        Args:
            x: input feature map with ``in_channels`` channels.
            return_intermediates: if True, also return the list of activations
                [after conv1/bn1/ReLU, after conv2/bn2/ReLU, after conv3/bn3, block output].

        Returns:
            The block output tensor, or ``(output, intermediates)`` when
            ``return_intermediates`` is True. The default (a single tensor) keeps the
            block usable inside ``nn.Sequential``.
        """
        shortcut = x
        intermediates = []

        out = self.relu(self.bn1(self.conv1(x)))
        intermediates.append(out)

        out = self.relu(self.bn2(self.conv2(out)))
        intermediates.append(out)

        out = self.bn3(self.conv3(out))
        intermediates.append(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        # Out-of-place add so the recorded bn3 activation is not overwritten.
        out = self.relu(out + shortcut)
        intermediates.append(out)

        if return_intermediates:
            return out, intermediates
        return out

In [None]:
####################################################### RAE Decoder Definitions #######################################################

class DeconvBottleneck(nn.Module):
    """Decoder residual block: 1x1 conv -> 3x3 conv (transposed when strided) -> 1x1 conv.

    With ``stride == 1`` the middle layer is a regular 3x3 convolution; otherwise it is
    a 3x3 transposed convolution with padding=1 and output_padding=1, which doubles the
    spatial size when ``stride == 2``. The output has ``out_channels * expansion``
    channels. ``upsample`` (optional) maps the input to that shape for the shortcut.
    """
    def __init__(self, in_channels, out_channels, expansion=2, stride=1, upsample=None):
        super(DeconvBottleneck, self).__init__()
        self.expansion = expansion
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        if stride == 1:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                   stride=stride, bias=False, padding=1)
        else:
            self.conv2 = nn.ConvTranspose2d(out_channels, out_channels,
                                            kernel_size=3,
                                            stride=stride, bias=False,
                                            padding=1,
                                            output_padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.upsample = upsample

    def forward(self, x):
        """Return relu(residual(x) + shortcut(x)).

        The residual branch is passed through ReLU after bn3 and again after the
        shortcut addition. The shortcut is ``x`` itself unless ``upsample`` is set.
        """
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.relu(self.bn3(self.conv3(out)))

        if self.upsample is not None:
            shortcut = self.upsample(x)

        out += shortcut
        return self.relu(out)

In [None]:
####################################################### ResNet Encoder Definition #######################################################

class ResNet_encoder(nn.Module):
    """ResNet encoder (stem + four residual stages) that exposes every block's output.

    ``forward`` returns a dict of feature maps:
      * ``'re-l0-1'``: stem output after conv1/bn1/ReLU (before max-pooling);
      * ``'re-l{K}-{i}'``: output of the i-th block (1-based) of stage K (1..4), so
        stage K contributes ``num_layers[K-1]`` entries.
    Entries are inserted in network order, so the last one is the final feature map
    (``'re-l4-3'`` for ``num_layers=[3, 4, 6, 3]``).

    Args:
        downblock: residual block class with an ``expansion`` attribute, called as
            ``downblock(in_channels, channels, stride, downsample)``.
        num_layers: number of blocks in each of the four stages, e.g. [3, 4, 6, 3].
    """
    def __init__(self, downblock, num_layers):
        super(ResNet_encoder, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_downlayer(downblock, 64, num_layers[0])
        self.layer2 = self._make_downlayer(downblock, 128, num_layers[1],
                                           stride=2)
        self.layer3 = self._make_downlayer(downblock, 256, num_layers[2],
                                           stride=2)
        self.layer4 = self._make_downlayer(downblock, 512, num_layers[3],
                                           stride=2)

    def _make_downlayer(self, block, init_channels, num_layer, stride=1):
        """Build one stage of ``num_layer`` blocks; only the first may change shape.

        A 1x1 conv + BN projection shortcut is created when the stride or the channel
        count changes. Updates ``self.in_channels`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.in_channels != init_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, init_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(init_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, init_channels, stride, downsample))
        self.in_channels = init_channels * block.expansion
        for _ in range(1, num_layer):
            layers.append(block(self.in_channels, init_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and return the dict of named intermediate feature maps."""
        output = {}
        x = self.relu(self.bn1(self.conv1(x)))
        output['re-l0-1'] = x
        x = self.maxpool(x)

        # Run each stage block by block so every block's output tensor is recorded.
        # (Indexing a stage's output tensor would select batch items, not blocks,
        # and each stage must be applied exactly once.)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for stage_idx, stage in enumerate(stages, start=1):
            for block_idx, block in enumerate(stage, start=1):
                x = block(x)
                output['re-l{}-{}'.format(stage_idx, block_idx)] = x

        return output

In [None]:
####################################################### ResNet Decoder Definition #######################################################

class ResNet_decoder(nn.Module):
    """Mirror of the ResNet encoder: four x2 up-sampling stages, a x2 top block and a 1x1 head.

    Args:
        upblock: block class called as
            ``upblock(in_channels, channels, expansion, stride, upsample)``.
        num_layers: blocks per stage in encoder order; consumed in reverse.
        n_classes: number of output channels of the final 1x1 transposed conv.
        in_channels: channel count of the incoming latent. Defaults to 2048, the
            output width of the ResNet-50 encoder (512 * Bottleneck.expansion).
    """
    def __init__(self, upblock, num_layers, n_classes, in_channels=2048):
        super(ResNet_decoder, self).__init__()

        # Must equal the width of the tensor fed to `forward` (was hard-coded to 64,
        # which does not match the encoder's final feature map).
        self.in_channels = in_channels

        self.uplayer1 = self._make_up_block(
            upblock, 512,  num_layers[3], stride=2)
        self.uplayer2 = self._make_up_block(
            upblock, 256, num_layers[2], stride=2)
        self.uplayer3 = self._make_up_block(
            upblock, 128, num_layers[1], stride=2)
        self.uplayer4 = self._make_up_block(
            upblock, 64,  num_layers[0], stride=2)

        # After the four stages self.in_channels == 128.
        upsample = nn.Sequential(
            nn.ConvTranspose2d(self.in_channels,
                               64,
                               kernel_size=1, stride=2,
                               bias=False, output_padding=1),
            nn.BatchNorm2d(64),
        )
        # Final x2 block with expansion 1 -> 64 channels, matching conv1_1's input.
        self.uplayer_top = upblock(self.in_channels, 64, 1, 2, upsample)

        self.conv1_1 = nn.ConvTranspose2d(64, n_classes, kernel_size=1, stride=1,
                                          bias=False)

    def _make_up_block(self, block, init_channels, num_layer, stride=1):
        """Build one decoder stage.

        ``num_layer - 1`` width-preserving blocks (expansion 4) are followed by one
        block with expansion 2 and the given stride. That last block gets a 1x1
        transposed-conv projection shortcut. Updates ``self.in_channels`` to
        ``init_channels * 2``.
        """
        upsample = None
        if stride != 1 or self.in_channels != init_channels * 2:
            upsample = nn.Sequential(
                nn.ConvTranspose2d(self.in_channels, init_channels * 2,
                                   kernel_size=1, stride=stride,
                                   bias=False, output_padding=1),
                nn.BatchNorm2d(init_channels * 2),
            )
        layers = []
        for _ in range(1, num_layer):
            layers.append(block(self.in_channels, init_channels, 4))
        layers.append(
            block(self.in_channels, init_channels, 2, stride, upsample))
        self.in_channels = init_channels * 2
        return nn.Sequential(*layers)

    def forward(self, x, image_size):
        """Decode latent ``x`` to ``n_classes`` channels at spatial size ``image_size``.

        ``image_size`` is passed to the final transposed conv as ``output_size``
        (e.g. the ``.size()`` of the original input image).
        """
        x = self.uplayer1(x)
        x = self.uplayer2(x)
        x = self.uplayer3(x)
        x = self.uplayer4(x)
        # Previously skipped, which left 128 channels for the 64-channel conv1_1.
        x = self.uplayer_top(x)

        x = self.conv1_1(x, output_size=image_size)

        return x


In [None]:
### Autoencoder wrapper: `forward` encodes an image and decodes it back to the input size

class ResNet_autoencoder(nn.Module):
    """ResNet encoder/decoder pair that reconstructs its input image.

    ``self.encoder`` and ``self.decoder`` are registered submodules and are the
    public entry points: ``model.encoder(x)`` returns the encoder's dict of feature
    maps and ``model.decoder(z, image_size)`` decodes a latent to ``image_size``.
    """

    def __init__(self, downblock, upblock, num_layers, n_classes):
        super(ResNet_autoencoder, self).__init__()

        # No wrapper methods named `encoder`/`decoder`: nn.Module keeps submodules in
        # `_modules`, consulted only after normal attribute lookup fails, so a method
        # of the same name would shadow the submodule (self.encoder.forward would
        # then raise AttributeError).
        self.encoder = ResNet_encoder(downblock, num_layers)
        self.decoder = ResNet_decoder(upblock, num_layers, n_classes)

    def forward(self, x):
        """Encode ``x`` and decode the deepest feature map back to ``x``'s size."""
        features = self.encoder(x)
        # The encoder fills its dict in network order, so the last entry is the
        # final feature map ('re-l4-3' for num_layers=[3, 4, 6, 3]).
        latent = list(features.values())[-1]
        return self.decoder(latent, x.size())

In [None]:
class CNN(nn.Module):
    """Predict a flattened (matrixSize x matrixSize) matrix from a 512-channel feature map.

    Three 3x3 convolutions (512 -> 256 -> 128 -> matrixSize channels) are followed
    by a Gram matrix over spatial positions, normalised by their count. A linear
    layer then maps the matrixSize**2 Gram entries to matrixSize**2 outputs.
    """

    def __init__(self, matrixSize=32):
        super(CNN, self).__init__()

        conv_stack = [
            nn.Conv2d(512, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, matrixSize, 3, 1, 1),
        ]
        self.convs = nn.Sequential(*conv_stack)

        n_entries = matrixSize * matrixSize
        self.fc = nn.Linear(n_entries, n_entries)

    def forward(self, x):
        """Return a (batch, matrixSize**2) tensor."""
        feats = self.convs(x)
        batch, chans, height, width = feats.size()
        flat = feats.view(batch, chans, -1)              # (batch, chans, positions)
        gram = torch.bmm(flat, flat.transpose(1, 2)).div(height * width)
        gram_flat = gram.view(gram.size(0), -1)          # (batch, chans * chans)
        return self.fc(gram_flat)

class VAE(nn.Module):
    """Variational bottleneck over a 512-value vector (used on channel-wise style means).

    ``encode`` maps the flattened input to ``2 * z_dim`` values that are split into
    ``mu`` and ``logvar``. A latent is sampled, decoded back to 512 values and
    reshaped to the input's shape.
    """

    def __init__(self, z_dim):
        super(VAE, self).__init__()

        # Input is flattened to (b, 512) in forward.
        self.encode = nn.Sequential(nn.Linear(512, 2 * z_dim))
        self.bn = nn.BatchNorm1d(z_dim)
        self.decode = nn.Sequential(nn.Linear(z_dim, 512),
                                    nn.BatchNorm1d(512),
                                    nn.ReLU(),
                                    nn.Linear(512, 512),
                                    )

    def reparameterize(self, mu, logvar):
        """Sample z = bn(mu) + eps * std with eps ~ N(0, I).

        ``std`` is ``exp(logvar)``, i.e. ``logvar`` acts as a log standard deviation.
        The KL term in ``forward`` uses the same convention.
        """
        mu = self.bn(mu)
        std = torch.exp(logvar)
        eps = torch.randn_like(std)
        # Previously returned `mu + std`: eps was drawn but never used, so nothing
        # was sampled.
        return mu + eps * std

    def forward(self, x):
        """Return (reconstruction shaped like ``x``, summed KL divergence).

        Args:
            x: tensor of shape (b, c, h) with c * h == 512 (flattened for ``encode``).
        """
        b, c, h = x.size()
        x = x.view(b, -1)

        z_q_mu, z_q_logvar = self.encode(x).chunk(2, dim=1)
        # reparameterize
        z_q = self.reparameterize(z_q_mu, z_q_logvar)
        out = self.decode(z_q)
        out = out.view(b, c, h)

        # KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(z_q_logvar):
        # 0.5 * (mu^2 + sigma^2 - 1) - log(sigma), summed over batch and latent dims.
        # NOTE(review): uses the raw mu, while the sample uses bn(mu) -- confirm intent.
        KL = torch.sum(0.5 * (z_q_mu.pow(2) + z_q_logvar.exp().pow(2) - 1) - z_q_logvar)

        return out, KL

class MulLayer(nn.Module):
    """Linear feature-transformation layer with a VAE-perturbed style mean.

    Content and style features (512 channels) are centred on their channel-wise
    spatial means. The style mean is replaced by its VAE reconstruction. The centred
    content is compressed to ``matrixSize`` channels and multiplied by
    ``sMatrix @ cMatrix`` (predicted by ``snet``/``cnet``), expanded back to 512
    channels, and the reconstructed style mean is added back.
    """

    def __init__(self, z_dim, matrixSize=32):
        super(MulLayer, self).__init__()
        self.snet = CNN(matrixSize)
        self.cnet = CNN(matrixSize)
        self.VAE = VAE(z_dim=z_dim)
        self.matrixSize = matrixSize

        # 1x1 convolutions between 512 feature channels and matrixSize channels.
        self.compress = nn.Conv2d(512, matrixSize, 1, 1, 0)
        self.unzip = nn.Conv2d(matrixSize, 512, 1, 1, 0)

        self.transmatrix = None

    def forward(self, cF, sF, trans=True):
        """Transform content features ``cF`` using style features ``sF``.

        Returns ``(out, transmatrix, KL)`` when ``trans`` is true; otherwise only the
        reconstructed content features ``out`` (no transform, content mean restored).
        """
        # Content: subtract per-channel spatial mean.
        n_c, ch_c, _, _ = cF.size()
        content_mean = torch.mean(cF.view(n_c, ch_c, -1), dim=2, keepdim=True)
        content_mean = content_mean.unsqueeze(3).expand_as(cF)
        cF = cF - content_mean

        # Style: the mean goes through the VAE; its reconstruction is removed from the
        # style features and later added to the transferred output.
        n_s, ch_s, _, _ = sF.size()
        style_mean = torch.mean(sF.view(n_s, ch_s, -1), dim=2, keepdim=True)
        style_mean, KL = self.VAE(style_mean)
        style_mean = style_mean.unsqueeze(3)
        style_mean_for_content = style_mean.expand_as(cF)
        sF = sF - style_mean.expand_as(sF)

        # Compress centred content and flatten its spatial dimensions.
        squeezed = self.compress(cF)
        b, c, h, w = squeezed.size()
        squeezed = squeezed.view(b, c, -1)

        if not trans:
            return self.unzip(squeezed.view(b, c, h, w)) + content_mean

        size = self.matrixSize
        content_mat = self.cnet(cF)
        style_mat = self.snet(sF)
        style_mat = style_mat.view(style_mat.size(0), size, size)
        content_mat = content_mat.view(content_mat.size(0), size, size)

        transmatrix = torch.bmm(style_mat, content_mat)
        transferred = torch.bmm(transmatrix, squeezed).view(b, c, h, w)
        out = self.unzip(transferred) + style_mean_for_content
        return out, transmatrix, KL

In [None]:
####################################################################################### Training Loop #######################################################################################


In [None]:
# Create the model: [3, 4, 6, 3] is the number of residual blocks in each of the
# four encoder stages (the ResNet-50 configuration) and the final 3 is n_classes,
# the number of output channels of the decoder's last layer (3 for an RGB image).
model = ResNet_autoencoder(Bottleneck, DeconvBottleneck, [3, 4, 6, 3], 3).cuda()


latent_dim = 33  # dummy value
# Feature-transformation module; the only part optimised below. It must live on the
# same device as the encoder features it receives.
matrix = MulLayer(z_dim=latent_dim).cuda()

# Load Dataset
# TODO: build `training_data_loader`, which train() iterates over.

# Load a pretrained ResNet 50 (torchvision checkpoint)
pretrained_dict = torch.load('./resnet50-19c8e357.pth')

# Create model dictionary
model_dict = model.state_dict()
# 1. Map checkpoint keys onto the autoencoder. The torchvision checkpoint stores
#    un-prefixed names ('conv1.weight', 'layer1.0.bn1.bias', ...), but the encoder is
#    registered as `model.encoder`, so its state-dict keys start with 'encoder.'.
#    Without the prefix nothing matched and no weights were loaded. Checkpoint keys
#    with no counterpart in the model are dropped.
pretrained_dict = {
    'encoder.' + k: v
    for k, v in pretrained_dict.items()
    if 'encoder.' + k in model_dict
}
assert pretrained_dict, "no pretrained ResNet-50 weights matched the encoder"
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print("load pretrained model success")
print("loaded {} pretrained tensors into the encoder".format(len(pretrained_dict)))

# Freeze the whole autoencoder; only `matrix` is passed to the optimizer.
for param in model.parameters():
    param.requires_grad = False

################# LOSS & OPTIMIZER #################
# TODO: `LossCriterion` and the `opt` options object are not defined in this
# notebook yet; this cell raises NameError until they are.
criterion = LossCriterion(opt.style_layers, opt.content_layers, opt.style_weight, opt.content_weight)
optimizer = optim.Adam(matrix.parameters(), opt.lr)


def train(epoch):
    """Run one pass over `training_data_loader` (loss and optimisation are still TODO).

    Each batch is indexed as (content, target, style) at positions 0, 1, 2. Content
    and style are encoded with the frozen `model.encoder`. One intermediate feature
    map of each is transformed by `matrix`, and the result is decoded.

    Args:
        epoch: epoch number; only used by the commented-out progress logging.
    """
    epoch_loss = 0

    for iteration, batch in enumerate(training_data_loader, 1):
        # Plain tensors: `Variable` is a no-op wrapper in current PyTorch and was
        # never imported in this notebook.
        content = batch[0].cuda()
        target = batch[1].cuda()
        style = batch[2].cuda()

        optimizer.zero_grad()

        # forward: encode style and content separately (the content features were
        # previously computed from the style image).
        sF = model.encoder(style)
        cF = model.encoder(content)

        # Pick which encoder feature map to transform; the available keys are the
        # ones returned by ResNet_encoder.forward. This needs to be carefully chosen.
        # A single key must be used for both: a tuple key such as
        # sF['re-l1-1', 're-l1-2'] raises KeyError.
        # NOTE(review): MulLayer's `compress` and CNN layers take 512 input channels,
        # so the chosen features must have 512 channels.
        cF_intermediate = cF['re-l1-1']
        sF_intermediate = sF['re-l1-1']
        feature, transmatrix, KL = matrix(cF_intermediate, sF_intermediate)

        # The decoder requires the output image size as its second argument.
        # NOTE(review): confirm `feature` has the channel count the decoder's first
        # up-block expects.
        transfer = model.decoder(feature, content.size())

        ## Need to find a suitable loss network
        #sF_loss = vgg5(style)
        #cF_loss = vgg5(content)
        #tF = vgg5(transfer)
        #loss, styleLoss, contentLoss, KL_loss = criterion(tF, sF_loss, cF_loss, KL)

        # backward & optimization
        #loss.backward()
        #optimizer.step()

        #print("===> Epoch[{}]({}/{}): loss: {:.4f} || content: {:.4f} || style: {:.4f} KL: {:.4f}.".format(epoch, iteration, len(training_data_loader), loss, contentLoss, styleLoss, KL_loss,))

    #print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))

    #return content, style, transfer


