In [1]:
import os
import time 
from tqdm import tqdm_notebook as tqdm
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
import random  
import numpy as np

import torch 
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

In [2]:
BATCH_SIZE = 1
INIT_LR = .0002

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.ins0 = nn.InstanceNorm2d(3,affine=True)
        self.rpad = nn.ReflectionPad2d(15)
        self.conv1 = nn.Conv2d(3,32,kernel_size=3,stride=1)
        self.ins1 = nn.InstanceNorm2d(32,affine=True)
        self.conv2 = nn.Conv2d(32,32,kernel_size=3,stride=2)
        self.ins2 = nn.InstanceNorm2d(32,affine=True)
        self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=2)
        self.ins3 = nn.InstanceNorm2d(64,affine=True)
        self.conv4 = nn.Conv2d(64,128,kernel_size=3,stride=2)
        self.ins4 = nn.InstanceNorm2d(128,affine=True)
        self.conv5 = nn.Conv2d(128,256,kernel_size=3,stride=2)
        self.ins5 = nn.InstanceNorm2d(256,affine=True)
        self.relu = torch.nn.ReLU()
        
    def forward(self,x):
        x = self.ins0(x)
        x = self.rpad(x)
        x = self.relu(self.ins1(self.conv1(x)))
        x = self.relu(self.ins2(self.conv2(x)))
        x = self.relu(self.ins3(self.conv3(x)))
        x = self.relu(self.ins4(self.conv4(x)))
        x = self.relu(self.ins5(self.conv5(x)))
        return x

In [4]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.rpad = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.rpad(x)
        out = self.relu(self.in1(self.conv1(out)))
        out = self.in2(self.conv2(self.rpad(out)))
        out = out + residual
        return out
    
class TransforBlock(nn.Module): #different from tf implementation but same as paper
    def __init__(self):
        super(TransforBlock,self).__init__()
        self.rpad = nn.ReflectionPad2d(4)
        self.conv = nn.AvgPool2d(kernel_size=10,stride=1) #nn.Conv2d(3,1,kernel_size=10,stride=1)
    
    def forward(self,x):
        return self.rpad(self.conv(x))

In [5]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.res1 = ResidualBlock(256)
        self.res2 = ResidualBlock(256)
        self.res3 = ResidualBlock(256)
        self.res4 = ResidualBlock(256)
        self.res5 = ResidualBlock(256)
        self.res6 = ResidualBlock(256)
        self.res7 = ResidualBlock(256)
        self.res8 = ResidualBlock(256)
        self.res9 = ResidualBlock(256)
        
        self.upconv1 = nn.Conv2d(256,256,kernel_size=3,stride=1)
        self.ins1 = nn.InstanceNorm2d(256,affine=True)
        self.upconv2 = nn.Conv2d(256,128,kernel_size=3,stride=1)
        self.ins2 = nn.InstanceNorm2d(128,affine=True)
        self.upconv3 = nn.Conv2d(128,64,kernel_size=3,stride=1)
        self.ins3 = nn.InstanceNorm2d(64,affine=True)
        self.upconv4 = nn.Conv2d(64,32,kernel_size=3,stride=1)
        self.ins4 = nn.InstanceNorm2d(32,affine=True)
        self.rpad = nn.ReflectionPad2d(3)
        self.upconv5 = nn.Conv2d(32,3,kernel_size=7,stride=1)
        
        self.relu = nn.ReLU()
        self.sig = nn.Sigmoid()
        self.zpad = nn.ZeroPad2d(1)
        
    def forward(self,x):
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)
        x = self.res6(x)
        x = self.res7(x)
        x = self.res8(x)
        x = self.res9(x)
        
        x = self.relu(self.ins1(self.upconv1(self.zpad(F.interpolate(x,scale_factor=2)))))
        x = self.relu(self.ins2(self.upconv2(self.zpad(F.interpolate(x,scale_factor=2)))))
        x = self.relu(self.ins3(self.upconv3(self.zpad(F.interpolate(x,scale_factor=2)))))
        x = self.relu(self.ins4(self.upconv4(self.zpad(F.interpolate(x,scale_factor=2)))))
        
        x = self.rpad(x)
        x = self.upconv5(x)
        return x

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.conv0 = nn.Conv2d(3,128,kernel_size=5,stride=2)
        self.in0 = nn.InstanceNorm2d(128)
        self.conv1 = nn.Conv2d(128,128,kernel_size=2,stride=2)
        self.in1 = nn.InstanceNorm2d(128)
        self.conv2 = nn.Conv2d(128,256,kernel_size=5,stride=2)
        self.in2 = nn.InstanceNorm2d(256)
        self.conv3 = nn.Conv2d(256,512,kernel_size=5,stride=2)
        self.in3 = nn.InstanceNorm2d(512)
        self.conv4 = nn.Conv2d(512,512,kernel_size=5,stride=2)
        self.in4 = nn.InstanceNorm2d(512)
        self.conv5 = nn.Conv2d(512,1024,kernel_size=5,stride=2)
        self.in5 = nn.InstanceNorm2d(1024)
        self.conv6 = nn.Conv2d(1024,1024,kernel_size=5,stride=2)
        self.in6 = nn.InstanceNorm2d(1024)
        
        self.conv0_pred = nn.Conv2d(128,1,kernel_size=5,stride=1)
        self.conv1_pred = nn.Conv2d(128,1,kernel_size=10,stride=1)
        self.conv3_pred = nn.Conv2d(512,1,kernel_size=10,stride=1)
        self.conv5_pred = nn.Conv2d(1024,1,kernel_size=6,stride=1)
        self.conv6_pred = nn.Conv2d(1024,1,kernel_size=3,stride=1)
        
    def forward(self,x):
        x = F.leaky_relu(self.in0(self.conv0(x)),negative_slope=.2)
        p0 = self.conv0_pred(x)
        x = F.leaky_relu(self.in1(self.conv1(x)),negative_slope=.2)
        p1 = self.conv1_pred(x)
        x = F.leaky_relu(self.in2(self.conv2(x)),negative_slope=.2)
        x = F.leaky_relu(self.in3(self.conv3(x)),negative_slope=.2)
        p3 = self.conv3_pred(x)
        x = F.leaky_relu(self.in4(self.conv4(x)),negative_slope=.2)
        x = F.leaky_relu(self.in5(self.conv5(x)),negative_slope=.2)
        p5 = self.conv5_pred(x)
        x = F.leaky_relu(self.in6(self.conv6(x)),negative_slope=.2)
        p6 = self.conv6_pred(x)
        return {"scale_0": p0,
                "scale_1": p1,
                "scale_3": p3, 
                "scale_5": p5,
                "scale_6": p6}

In [7]:
transformer = transforms.Compose([transforms.Resize(800, Image.BICUBIC),
                transforms.RandomCrop((768,768)),
                transforms.ToTensor()])                
                #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [8]:
def take(path):
    im = Image.open(path).convert("RGB")
    arr = transformer(im)
    return arr.cuda()

In [9]:
CDIR = "d:/COCO17/val/"
PDIR = "d:/A Style Aware Content Loss/data/monet_water-lilies/"

cpaths = [CDIR+f for f in os.listdir(CDIR)]
pool = ThreadPoolExecutor(max_workers=4)
ppaths = [PDIR+f for f in os.listdir(PDIR)][:100]

In [10]:
win_rate = .8
discr_succ_rate = .8
alpha = .05
encoder = Encoder().cuda()
decoder = Decoder().cuda()
discriminator = Discriminator().cuda()
transformerBlock = TransforBlock().cuda()

sceLoss = nn.BCEWithLogitsLoss()
mseLoss = nn.MSELoss()
absLoss = nn.L1Loss()
Gopt = torch.optim.Adam(list(encoder.parameters())+list(decoder.parameters())+list(transformerBlock.parameters()),INIT_LR)
Dopt = torch.optim.Adam(discriminator.parameters(),INIT_LR)

writer = SummaryWriter(log_dir="d:/Visual/")

In [11]:
Glosses,Dlosses = [.5],[.5]
for i in tqdm(range(50)):
    photo = torch.stack(list(pool.map(take,random.choices(cpaths,k=BATCH_SIZE)))).cuda()
    painting = torch.stack(list(pool.map(take,random.choices(ppaths,k=BATCH_SIZE)))).cuda()
    
    #These Tensors are common for both G & D training
    input_photo_features = encoder(photo)
    output_photo = decoder(input_photo_features)
    output_photo_features = encoder(output_photo)
    output_photo_discr_predictions = discriminator(output_photo) #a dict containing predictions at 6 different scales
    
    scale_weight = {"scale_0": 1.,
                    "scale_1": 1.,
                    "scale_3": 1.,
                    "scale_5": 1.,
                    "scale_6": 1.}
    
    if discr_succ_rate>=.8:
        #-----TRAINING GENERATOR--------------------------
        #G Losses
        output_photo_gener_loss = {key: sceLoss(pred, torch.ones_like(pred)) * scale_weight[key]
                                            for key, pred in output_photo_discr_predictions.items()}
        gener_loss = sum(output_photo_gener_loss.values()) 
        img_loss = mseLoss(transformerBlock(output_photo),transformerBlock(photo))
        feature_loss = absLoss(input_photo_features,output_photo_features)
        
        Gloss = gener_loss + 100*img_loss + 200*feature_loss
        
        Gopt.zero_grad()
        Gloss.backward()
        Gopt.step()
    
    else:
        #-----TRAINING DISCRIMINATOR--------------------------
        #Generate D's output for real painting & fake input photo
        input_painting_discr_predictions = discriminator(painting)
        input_photo_discr_predictions = discriminator(photo) 
        
        #Compute D's loss for inp photo (fake),out photo (fake), inp painting (real)
        
        input_painting_discr_loss = {key: sceLoss(pred,torch.ones_like(pred))*scale_weight[key] #Do we need dict, list enough?
                                              for key, pred in input_painting_discr_predictions.items()}
        input_photo_discr_loss = {key: sceLoss(pred,torch.zeros_like(pred))*scale_weight[key]
                                           for key, pred in input_photo_discr_predictions.items()}
        output_photo_discr_loss = {key: sceLoss(pred, torch.zeros_like(pred)) * scale_weight[key]
                                            for key, pred in output_photo_discr_predictions.items()}
        
        Dloss = sum(input_painting_discr_loss.values()) + sum(input_photo_discr_loss.values()) + \
                    sum(output_photo_discr_loss.values())
        
        Dopt.zero_grad()
        Dloss.backward()
        Dopt.step()
        
    if i%100==0:
        with torch.no_grad():
            photo = torch.stack([take("d:/Images/dancing.jpg"),take("d:/Images/amber.jpg")])
            enc = encoder(photo)
            out = decoder(enc)
            save_image(out[0],f"d:/outs/dancing{i}.jpg",normalize=True)
            save_image(out[1],f"d:/outs/amber{i}.jpg",normalize=True)
    if i%10==8:
        writer.add_scalar("Train/Gloss",Gloss.item()/1000,i//2)
        writer.add_scalar("Train/Dloss",Dloss.item(),i//2)
    
        with torch.no_grad():
            inp_paint_discr_acc=sum((pr.detach().cpu().numpy()>.5).mean() for pr in input_painting_discr_predictions.values())\
                                             /len(input_painting_discr_predictions.values())
            inp_photo_discr_acc=sum((pr.detach().cpu().numpy()<.5).mean() for pr in input_photo_discr_predictions.values())\
                                             /len(input_photo_discr_predictions.values())
            out_photo_discr_acc=sum((pr.detach().cpu().numpy()<.5).mean() for pr in output_photo_discr_predictions.values())\
                                             /len(output_photo_discr_predictions.values())
            discr_succ_rate = (inp_paint_discr_acc+inp_photo_discr_acc+out_photo_discr_acc)/3
            print("Discr succ:",discr_succ_rate)
writer.close()

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




NameError: name 'Dloss' is not defined

In [None]:
photo = torch.stack(list(pool.map(take,random.choices(cpaths,k=BATCH_SIZE)))).cuda()
painting = torch.stack(list(pool.map(take,random.choices(ppaths,k=BATCH_SIZE)))).cuda()

In [None]:
with torch.no_grad():
    photo = torch.stack([take("d:/Images/amber.jpg")])
    enc = encoder(photo)
    out = (decoder(enc)+1)/2
    save_image(out[0],"d:/1.jpg")

In [None]:
input_photo_features = encoder(photo)
output_photo = decoder(input_photo_features)
output_photo_features = encoder(output_photo)

In [None]:
input_painting_discr_predictions = discriminator(painting)
input_photo_discr_predictions = discriminator(photo)  
output_photo_discr_predictions = discriminator(output_photo)

In [None]:
scale_weight = {"scale_0": 1.,
                        "scale_1": 1.,
                        "scale_3": 1.,
                        "scale_5": 1.,
                        "scale_6": 1.}

input_painting_discr_loss = {key: loss(pred,torch.ones_like(pred))*scale_weight[key]
                                              for key, pred in input_painting_discr_predictions.items()}
input_photo_discr_loss = {key: loss(pred,torch.zeros_like(pred))*scale_weight[key]
                                   for key, pred in input_photo_discr_predictions.items()}
output_photo_discr_loss = {key: loss(pred, torch.zeros_like(pred)) * scale_weight[key]
                                    for key, pred in output_photo_discr_predictions.items()}

In [None]:
disc_loss = sum(input_painting_discr_loss.values()) + sum(input_photo_discr_loss.values()) + \
                    sum(output_photo_discr_loss.values())

In [None]:

disc_loss.backward()

In [None]:
input_photo_discr_predictions.shape,input_painting_discr_predictions.shape

In [None]:
torch.cuda.memory_allocated()/(2**20),torch.cuda.memory_cached()/(2**20)

In [None]:
torch.cuda.empty_cache()

In [None]:
for k,pred in input_painting_discr_predictions.items():
    print(pred.shape,(pred.cpu().detach().numpy()>.5).mean())

In [None]:
with torch.no_grad():
    print((pred.cpu().numpy()>.5).mean())

In [None]:
foo = lambda x:(x.mean(),x.var()**.5,x.min(),x.max())


In [None]:
p = take("d:/Images/amber.jpg")

In [None]:
foo(p),foo(out)

In [None]:
a = np.array(img)
b = transforms.ToTensor()(img)
c = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])(b)
foo(a),foo(b),foo(c)

In [None]:
img = Image.open("d:/Images/amber.jpg")