In [1]:
from fastai.vision import *
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim
# from torchsummary import summary
import gc
from fastai.callback import *
from fastai.utils.mem import *

In [2]:
from fastai.callbacks.hooks import hook_outputs
import numpy as np
import loader
import fastai

In [3]:
# pre-trained VGG16 (with  batch norm) for feature loss
from torchvision.models import vgg16_bn

In [4]:
# use the ranger optimizer
# https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
# code: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
from ranger import Ranger

In [5]:
bs=8 #batch size
size=128 # image size (size x size)

# image paths
path_out = 'data/JackieChan2A'
path_inp = 'data/A2JackieChan'

In [6]:
# use gpu if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

# Original dataloading (doesn't work)

In [7]:
'''# create image 2 image databunch
# code from: ???
src = ImageImageList.from_folder(path_inp).split_by_rand_pct(0.05, seed=42)
def get_data(bs,size):
    data = (src.label_from_func(lambda x: path_out+'/'+x.name)
           .transform(get_transforms(), size=size, tfm_y=True)
           .databunch(bs=bs).normalize(imagenet_stats, do_y=True))

    data.c = 3
    return data
data = get_data(bs,size)
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))
data'''

"# create image 2 image databunch\n# code from: ???\nsrc = ImageImageList.from_folder(path_inp).split_by_rand_pct(0.05, seed=42)\ndef get_data(bs,size):\n    data = (src.label_from_func(lambda x: path_out+'/'+x.name)\n           .transform(get_transforms(), size=size, tfm_y=True)\n           .databunch(bs=bs).normalize(imagenet_stats, do_y=True))\n\n    data.c = 3\n    return data\ndata = get_data(bs,size)\ndata.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))\ndata"

Looks like some errors made it into the data set. The first target image isn't Jackie Chan, but his son (I think).

# Loading data/dataloaders (MD/PS work)

In [8]:
class Img_Dataset(Dataset):
    def __init__(self, data_set, patch_size, width, height, seed=1234):
        """
        Parameters:
        -----------
        data: np.ndarray
            Array that contains image/label pairs ie. corrupted image/clean image.
            Shape = (P, N, C, H, W):
                P = corrupted/uncorrupted image pair 
                N = number of samples
                C = number of channels
                H = image height
                W = image width
        patch_size: int
            Size of randomly chosen image patch the model uses for training
        width: int
            Width of the chosen sample.
            NOTE: It's a parameter because you can input a larger image and choose
                  to look at only portions of said image for more training samples.
        height: int
            Height of the chosen sample.
        seed: int 
            Randomized seed used for the random slicing used to create the image patch.
        """
        self.data_set = data_set
        self.patch_size = patch_size
        self.width = width
        self.height = height
        self.seed = seed

    def __len__(self):
        return len(self.data_set[0])

    def __getitem__(self, idx):
        """
        Function that returns the PyTorch Dataloader compatible dataset.
        
        Parameters:
        -----------
        idx: var
            Variable used in PyTorch Dataloader to be able to sample from the dataset
            to create minibatches of the data for us automatically.
        """
        # Loading the dataset and then slicing the image/label pairs 
        # ie. corrupted/uncorrupted images. 
        # Note the use of the idx in the image/label variables. This allows the
        # PyTorch Dataloader to get all the important data info eg. (N, C, H, W)
        data = self.data_set
        image = data[0, idx]
        label = data[1, idx]
        
        # Setting the patch size and the randomized seed for the image patch
        patch_size = self.patch_size
        seed = self.seed
        rng = np.random.RandomState(seed)

        img_width = self.width
        img_height = self.height
        
        #randomly crop patch from training set
        x1 = rng.randint(img_width - patch_size)
        y1 = rng.randint(img_height - patch_size)
        S = (slice(y1, y1 + patch_size), slice(x1, x1 + patch_size))
        
        # create new arrays for training patchs
        image_patch = image[0][S]
        label_patch = label[0][S]
        

        image_patch = image_patch[np.newaxis, :, :]
        label_patch = label_patch[np.newaxis, :, :]
        
        image_patch = np.concatenate((image_patch,)*3, axis=0)
        label_patch = np.concatenate((label_patch,)*3, axis=0)

        
        # Turning our image/label to a PyTorch Tensor with dtype = float 
        # and then putting it onto the GPU for faster training/inference
        
        image = torch.from_numpy(image_patch).float().to(device)
        label = torch.from_numpy(label_patch).float().to(device)
        # image = torch.from_numpy(image_patch).float()
        # label = torch.from_numpy(label_patch).float()
            
        return image, label

In [9]:
#Load the actual data that we're working on & print the shape of this data
train_data = loader.load('training_data610-2000.npy')
test_data = loader.load('test_data200-2000.npy')
print('Shape of train set=', train_data.shape)

Shape of train set= (2, 610, 1, 2000, 2000)


In [10]:
train_ds = Img_Dataset(data_set=train_data,
                       patch_size=64,
                       height=2000,
                       width=2000)

test_ds = Img_Dataset(data_set=test_data,
                       patch_size=64,
                       height=2000,
                       width=2000)

train_dataloader = DataLoader(train_ds, batch_size=56, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=56, shuffle=True)

data = fastai.basic_data.DataBunch(train_dataloader, test_dataloader)

# Back to original work

In [11]:
# predictive filter flow layer
# https://arxiv.org/abs/1811.11482
# Kong, S., & Fowlkes, C. (2018). Image reconstruction with predictive filter flow. arXiv preprint arXiv:1811.11482.
# - learn and apply individual filters (ksize x ksize) for each spatial position in the input
# - i.e. when using softmax activation it is basically image warping, but instead of offsets, we learn filters
class pFF(nn.Module):
    def __init__(self,ni, ksize=3,stride=1,softmax = True,upsample=1):
        super(pFF, self).__init__()
        # size of the learned filter: ksize x ksize
        self.ksize=ksize
        # use softmax or tanh
        self.softmax = softmax
        # upsampling of the learned filters (gives smoother result)
        self.upsample = upsample
        # train conv layer to output filter flow and use reflection padding
        self.get_filter = nn.Conv2d(ni,ksize**2,3,padding=1,stride=upsample,padding_mode='reflect')
        self.pad = nn.ReflectionPad2d(padding=int((ksize-1)/2)*stride)
        # apply learned filters
        self.uf1 = nn.Unfold(ksize, dilation=stride, padding=0, stride=1)
        self.uf2 = nn.Unfold(1, dilation=1, padding=0, stride=1)
        if upsample>1:
            self.us = nn.UpsamplingBilinear2d(scale_factor=upsample)
        
    def forward(self, features,inpt):
        # features: features learned by CNN, inpt: input that filters should be applied to
        # 1: get filter
        ff = self.get_filter(features)
        # 2: apply activation function
        if self.softmax:
            ff = F.softmax(ff,dim=1)
        else:
            ff = torch.tanh(ff)
        if self.upsample>1:
            ff = self.us(ff)
            
        # apply learned filters
        inp_pad = self.pad(inpt)
        
        # use filter on each channel/feature of the input
        ff = torch.cat([ff]*inpt.shape[1],dim=1)
        out = self.uf1(inp_pad) * self.uf2(ff)
        out = out.view(-1,inpt.shape[1],self.ksize**2,inpt.shape[2],inpt.shape[3])
        return out.sum(dim=2)
   

In [12]:
# Build U-Net for face2face translation usining predictive filter flow
# U-Net adapted from https://github.com/milesial/Pytorch-UNet
# - use residual blocks instead of double convolution
# - replaced transpose conv with PixelShuffle for efficiency
# - replaced conv2d with depthwise separable conv for efficiency
# - Use Mish activation function: code by https://github.com/lessw2020/mish
# - multiple residual blocks in the middle
# - predictive filter flow


class Mish(nn.Module):
    # source: https://github.com/lessw2020/mish
    def __init__(self):
        super().__init__()
    def forward(self, x):
        #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
        return x *( torch.tanh(F.softplus(x)))
    

class depthwise_separable_conv(nn.Module):
    # source: "shicai": https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843)
    def __init__(self, nin, nout,stride=1,ksize=3):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=ksize, 
                                   padding=int((ksize-1)/2), groups=nin,
                                   stride=stride,padding_mode='reflect')
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

class block1(nn.Module):
    def __init__(self,ni,no,stride=2,last_act=True):
        super(block1, self).__init__()
        self.bottleneck = ni != no
        self.last_act = last_act
        self.stride=stride
        self.sconv1 = depthwise_separable_conv(ni,no,stride)
        self.sconv2 = depthwise_separable_conv(no,no)
        self.sconv3 = depthwise_separable_conv(no,no)
        self.normact0 = nn.Sequential(nn.BatchNorm2d(no),Mish())
        self.normact1 = nn.Sequential(nn.BatchNorm2d(no),Mish())
        self.normact2 = nn.Sequential(nn.BatchNorm2d(no),Mish())
    def forward(self,x):
        if self.stride>1 or self.bottleneck:
            x = self.normact0(self.sconv1(x))
        residual = x
        out = self.normact1(self.sconv2(x))
        out = self.sconv3(out)
        out += residual
        if self.last_act:
            out = self.normact2(out)
        return out


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = nn.Sequential(
            depthwise_separable_conv(in_ch, out_ch,1,7),
            Mish()
        )
    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.conv = block1(in_ch,out_ch,2)
    def forward(self, x):
        x = self.conv(x)
        return x

    
class res_block(nn.Module):
    def __init__(self, in_ch):
        super(res_block, self).__init__()
        self.resconv = nn.Sequential(
            block1(in_ch,in_ch,1,False),
            block1(in_ch,in_ch,1,False),
            block1(in_ch,in_ch,1,False),
            block1(in_ch,in_ch,1,False)
        )
    def forward(self, x):
        x = self.resconv(x)
        return x
    

class up(nn.Module):
    def __init__(self, in_ch,mid_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        # theoretically it makes sense to add another conv layer before pixel
        # shuffle since we need a positional encoding, but here I did not find
        # it to be necessary
#         self.conv1 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.ups = nn.modules.PixelShuffle(2)
        self.norm = nn.BatchNorm2d(in_ch//4+mid_ch)
        self.conv2 = block1(in_ch//4+mid_ch, out_ch,1)
    def forward(self, x1, x2):
        # double spatial resolution via pixel shuffle
#         x1 = self.conv1(x1) 
        x1 = self.ups(x1)
        # combine information of high- and low-level features
        x = torch.cat([x2, x1], dim=1)
        x = self.norm(x)
        x = self.conv2(x)
        return x  
    
    

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.inc = inconv(3, 64)
        self.down1 = down(64, 64)
        self.down2 = down(64, 128)
        self.down3 = down(128,256)
        self.down4 = down(256,256)
        self.res1 = res_block(256)
        
        self.up1 = up(256,256, 256)
        self.up2 = up(256,128, 128)
        self.up3 = up(128,64,64)
        self.up4 = up(64,64, 64)
        
        self.filter0 = pFF(64, ksize=9)
        self.filter1 = pFF(64, ksize=9,stride=4)
        self.filter2 = pFF(64, ksize=9,stride=8)
        self.filter3 = pFF(64, ksize=9)
        
    def forward(self, inp):
        x1 = self.inc(inp)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.res1(x5)
        
        x = self.up1(x, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        filtered = self.filter0(x,inp)
        filtered = self.filter1(x,filtered)
        filtered = self.filter2(x,filtered)
        filtered = self.filter3(x,filtered)
    
        return filtered#torch.sigmoid(x)*4-1.6 #x

model = UNet()
model = model.to(device)
# summary(model, input_size=(3, size,size))


In [13]:
data

DataBunch;

Train: <__main__.Img_Dataset object at 0x7fd7d46cb2b0>;

Valid: <__main__.Img_Dataset object at 0x7fd7d46cb310>;

Test: None

In [14]:
x = model.state_dict()
x['inc.conv.0.depthwise.weight']

tensor([[[[ 0.0735,  0.0923,  0.1194,  0.0260, -0.0513,  0.1205, -0.0798],
          [ 0.0316,  0.1201,  0.1090,  0.0218,  0.1378, -0.0743, -0.0557],
          [ 0.0671, -0.0775,  0.0115,  0.1103, -0.0321,  0.0138,  0.1135],
          [ 0.0686,  0.1268,  0.0813, -0.1070, -0.1342, -0.0828,  0.0453],
          [-0.1066, -0.0013,  0.1090,  0.1390, -0.0726,  0.0512,  0.0838],
          [-0.0030, -0.0791, -0.0056,  0.0673,  0.1043,  0.0451, -0.0156],
          [ 0.0361, -0.0817,  0.1336,  0.1220, -0.0308,  0.1421, -0.0596]]],


        [[[ 0.1381,  0.0836, -0.0523,  0.1165, -0.1349,  0.0113, -0.0289],
          [-0.0737, -0.0852,  0.0037,  0.0063, -0.0229,  0.0264, -0.0277],
          [ 0.0249, -0.1395,  0.0210, -0.1040, -0.1241, -0.1128,  0.0267],
          [-0.0699, -0.1338, -0.1163, -0.0858, -0.0710, -0.0193, -0.0326],
          [-0.0585, -0.0615,  0.0898, -0.0370, -0.0671, -0.0285, -0.1272],
          [ 0.1026, -0.0805,  0.0980, -0.0636,  0.1009,  0.1081,  0.1229],
          [-0.0817, -

In [15]:
def charbonnier(y_pred, y_true):
    epsilon = 1e-3
    error = y_true - y_pred
    p = torch.sqrt(error**2 + epsilon**2)
    return torch.mean(p)


# Perceptual Loss:
# original code from: fast.ai lesson ???
# modifications:
# - Instance Normalization of low-level features to remove influence of "style"
def gram_matrix(x):
    n,c,h,w = x.size()
    x = x.view(n, c, -1)
    return (x @ x.transpose(1,2))/(c*h*w)
#vgg_m = vgg16_bn(True).features.cuda().eval()
vgg_m = vgg16_bn(True).features.to(device).eval()
requires_grad(vgg_m, False)
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]

base_loss = F.mse_loss

class FeatureLoss(nn.Module):
    def __init__(self, m_feat, layer_ids, layer_wgts,without_instancenorm=1):
        super().__init__()
        self.m_feat = m_feat
        # how many layers are not subjected to instance norm (starting from high-level, i.e. later layers)
        self.without_instancenorm = without_instancenorm
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts # This is a list. How to get this onto GPU
        self.metric_names = [f'feat_{i}' for i in range(len(layer_ids))
                                         ]+ [f'gram_{i}' for i in range(len(layer_ids))]
              

    def make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        try:
            # instance normalization for all but last layer
            for l in range(len(in_feat)-self.without_instancenorm):
                in_feat[l] = nn.InstanceNorm2d(in_feat[l][1],momentum=0)(in_feat[l])
                out_feat[l] = nn.InstanceNorm2d(out_feat[l][1],momentum=0)(out_feat[l])
                
            self.feat_losses = [base_loss(f_in, f_out)*w
                                 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
            self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                                 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        except IndexError: # No idea why this tends to happen (only during validation)
            self.feat_losses = [torch.tensor(1).float().to(device)]
            for k in range(6):
                self.feat_losses += [torch.tensor(1).float().to(device)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()
        
        
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [2,2,3])

In [16]:
G = Learner(data,model,loss_func=feat_loss, opt_func=Ranger)
gc.collect()
# show output before training
# G.show_results(rows=2, imgsize=5)

0

In [17]:
# find good learning rate
G.lr_find()
G.recorder.plot()

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


epoch,train_loss,valid_loss,time


set state called
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

In [None]:
gc.collect()
G.fit_fc(15,4e-2)
# show intermediate results
G.show_results(rows=2, imgsize=5)

In [None]:
gc.collect()
G.fit_fc(15,4e-3)
# show intermediate result
G.show_results(rows=2, imgsize=5)

In [None]:
gc.collect()
G.fit_fc(15,4e-4)
# show final result
G.show_results(rows=2, imgsize=5)

In [None]:
# save model
G.save(path_inp.split('/')[-1])
G.export()

In [None]:
# from Show import live_swap
# model2 = load_learner(path_inp)
# live_swap(model2,size)