In [1]:
import dataset
from visualize import *

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
from zipfile import ZipFile
from PIL import Image
from skimage import color
from datasets import load_dataset
from dotenv import load_dotenv
import os
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from torch import nn
import pandas as pd
from fastprogress.fastprogress import master_bar, progress_bar
from torch import nn

import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_built() else "cpu")
print(f"device: {device}")

device: mps


In [2]:
import lpips
perceptual_loss_fn = lpips.LPIPS(net='alex').to(device)
# d = loss_fn.forward(im0,im1)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /Users/matteom/miniconda3/envs/torch/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth


### Architecture

In [3]:
mse_loss_fn = nn.MSELoss()
def perceptual_and_MSE_loss(reproduced_image, original_image):
    '''
    reproduced_image: output of the model
    original_image: ground truth
    '''
    # Perceptual loss
    perceptual_loss = perceptual_loss_fn(reproduced_image, original_image).mean()
    
    # MSE loss
    mse_loss = mse_loss_fn(reproduced_image, original_image)
    
    return perceptual_loss + mse_loss

In [4]:
class Encoder(nn.Module):
    def __init__(self, in_C = 3):
        super(Encoder, self).__init__()

   
        self.conv1_3 = nn.Sequential(
            nn.Conv2d(in_C, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.maxpool_1to2 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.maxpool_2to3 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
        )
        self.maxpool_3to4 = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )
        self.maxpool_4to5 = nn.MaxPool2d(2, 2)
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )




    def forward(self,content_image):
        content_image=content_image.to(device)

        # Encoder
        conv1 = self.conv1(content_image)

        maxpooled_1to2 = self.maxpool_1to2(conv1)
        conv2 = self.conv2(maxpooled_1to2)

        maxpooled_2to3 = self.maxpool_2to3(conv2)
        conv3 = self.conv3(maxpooled_2to3)

        maxpooled_3to4 = self.maxpool_3to4(conv3)
        conv4 = self.conv4(maxpooled_3to4)

        maxpooled_4to5 = self.maxpool_4to5(conv4)
        conv5 = self.conv5(maxpooled_4to5)

        return conv1, conv2, conv3, conv4, conv5

In [5]:
class Decoder(nn.Module):
    def __init__(self, out_C=3):
        super(Decoder, self).__init__()


        # Decoder
        self.conv_transpose_5to6 = nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1)
        self.conv6 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
        )
        self.conv_transpose_6to7 = nn.ConvTranspose2d(256, 256, 4, stride=2, padding=1)
        self.conv7 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.conv_transpose_7to8 = nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1)
        self.conv8 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.conv_transpose8to9 = nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1)
        self.conv9 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),      # Simmetry broken here: keeps being 64 (from paper)
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.conv10 = nn.Conv2d(64, out_C, 1)



    def forward(self,latent_features):
        # Decoder
        conv_transpose_5to6=self.conv_transpose_5to6(latent_features)
        conv6 = self.conv6(conv_transpose_5to6)

        conv_transpose_6to7 = self.conv_transpose_6to7(conv6)
        conv7 = self.conv7(conv_transpose_6to7)

        conv_transpose_7to8 = self.conv_transpose_7to8(conv7)
        conv8 = self.conv8(conv_transpose_7to8)

        concatenation_8_to9 = self.conv_transpose8to9(conv8)
        conv9 = self.conv9(concatenation_8_to9)

        output = self.conv10(conv9)
        return output


In [6]:
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def adjust_learning_rate(opts, iteration_count, lr,lr_decay):
    """Imitating the original implementation"""
    lr = lr / (1.0 + lr_decay * iteration_count)
    for opt in opts:
        for param_group in opt.param_groups:
            param_group['lr'] = lr

In [7]:
class Net(nn.Module):
    def __init__(self, checkpoint_path=None):
        super(Net, self).__init__()

        encoder = Encoder(in_C=3)
        if checkpoint_path is not None:
            checkpoint = torch.load(checkpoint_path)
            encoder.load_state_dict(checkpoint, strict=False)
        self.encoder = encoder.to(device)

        self.decoder_L = Decoder(out_C=1).to(device)
        self.decoder_ab = Decoder(out_C=2).to(device)
    
        self.mse_loss = nn.MSELoss()

    def encode_with_intermediate(self, input):
        conv1, conv2, conv3, conv4, conv5 = self.encoder(input)
        return [conv1, conv2, conv3, conv4, conv5]

    # extract relu4_1 from input image
    def encode(self, input):
        _, _, _, _, conv5 = self.encoder(input)
        return conv5

    def calc_content_loss(self, input, target):
        assert (input.size() == target.size())
        #assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_texture_loss(self, input, target):
        assert (input.size() == target.size())
        #assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def ct_t_loss(self, pred_l, content_l, texture_l):
        input_feats = self.encode_with_intermediate(pred_l)
        target_ct = self.encode(content_l)
        target_t = self.encode_with_intermediate(texture_l)

        loss_ct = self.calc_content_loss(input_feats[-1], target_ct)
        loss_t = self.calc_texture_loss(input_feats[0], target_t[0])
        for i in range(1, len(input_feats) - 1):
            loss_t += self.calc_texture_loss(input_feats[i], target_t[i])

        return loss_ct, loss_t

    def cr_loss(self, pred_ab, color_ab):
        zero = torch.zeros(pred_ab.shape[0], 1, pred_ab.shape[2], pred_ab.shape[3]).to(device)
        pred_ab = torch.cat([zero, pred_ab], dim=1)

        input_cr = self.encode_with_intermediate(pred_ab)
        target_cr = self.encode_with_intermediate(color_ab)

        loss_cr = self.calc_texture_loss(input_cr[0], target_cr[0])
        for i in range(1, len(input_cr) - 1):
            loss_cr += self.calc_texture_loss(input_cr[i], target_cr[i])

        return loss_cr

    def run_L_path(self, content_l, texture_l, alpha = 1.0):
        ct_l_feat = self.encode(content_l)
        t_l_feat = self.encode(texture_l)
        o_l_feat = adain(ct_l_feat, t_l_feat)
        o_l_feat = alpha *  o_l_feat + (1.0 - alpha) * ct_l_feat
        l_pred = self.L_path(o_l_feat)

        return l_pred

    def run_AB_path(self, content_ab, color_ab, alpha = 1.0):
        ct_ab_feat = self.encode(content_ab)
        cr_ab_feat = self.encode(color_ab)
        o_ab_feat = adain(ct_ab_feat, cr_ab_feat)
        o_ab_feat = alpha * o_ab_feat + (1.0 - alpha) * ct_ab_feat
        ab_pred = self.AB_path(o_ab_feat)

        return ab_pred

    def forward(self, content_ab, color_ab, alpha_ab=1.0):
       

        ct_ab_feat = self.encode(content_ab)
        cr_ab_feat = self.encode(color_ab)

        o_ab_feat = adain(ct_ab_feat, cr_ab_feat)
        o_ab_feat = alpha_ab * o_ab_feat + (1.0 - alpha_ab) * ct_ab_feat

        ab_pred = self.decoder_ab(o_ab_feat)

        return ab_pred
    

    def train_model(self, 
              train_loader, val_loader,
              lr=1e-4, lr_decay=1e-5,
              epochs=100):

        self.train()

        # Training options

        opt_AB = torch.optim.Adam(self.decoder_ab.parameters(), lr=lr)
        opts = [opt_AB]

        # For stats purposes
        best_val_loss = 99999
        loss_archive = {"training": [], "validation": []}

        mb = master_bar(range(epochs))
        for epoch in mb:
            train_loss = 0

            adjust_learning_rate(opts, iteration_count=epoch, lr=lr, lr_decay=lr_decay)

            total=len(train_loader)
            for i, (context_batch_data, style_batch_data) in progress_bar(enumerate(train_loader), total=total, parent=mb):
                content_ab= context_batch_data[:,1:,:,:].to(device)
                color_ab= style_batch_data[:,1:,:,:].to(device)

                # S2: Forward

                ab_pred = self(content_ab, color_ab)

                # S3: Calculate loss
                loss_cr = self.cr_loss(ab_pred, color_ab)
                train_loss+=loss_cr
                loss = loss_cr

                # S4: Backward

                for opt in opts:
                    opt.zero_grad()
                loss.backward()
                for opt in opts:
                    opt.step()

            # Stats:
            train_loss /= len(train_loader)
            loss_archive["training"].append(train_loss)      
            

In [8]:
# Kept apart as it must be the SAME for images and styles!
BATCH_SIZE = 10
RESOLUTION = (128,128)
COLORSPACE = 'RGB'
TRAIN_SIZE = 60
VAL_SIZE = 10

In [9]:
""" # Actual images:
TRAIN_SIZE = 60
VAL_SIZE = 10

train_data, validation_data = dataset.prepare_dataset(train_size=TRAIN_SIZE, test_size=VAL_SIZE, batch_size=BATCH_SIZE,colorspace=COLORSPACE,resolution=RESOLUTION)
train_loader, validation_loader = dataset.prepare_dataloader(train_data, validation_data, batch_size=BATCH_SIZE)
 """

' # Actual images:\nTRAIN_SIZE = 60\nVAL_SIZE = 10\n\ntrain_data, validation_data = dataset.prepare_dataset(train_size=TRAIN_SIZE, test_size=VAL_SIZE, batch_size=BATCH_SIZE,colorspace=COLORSPACE,resolution=RESOLUTION)\ntrain_loader, validation_loader = dataset.prepare_dataloader(train_data, validation_data, batch_size=BATCH_SIZE)\n '

In [10]:
""" # Styles:
S_TRAIN_SIZE = 60
S_VAL_SIZE = 10

style_train_data, style_validation_data = dataset.prepare_styles_dataset(train_size=S_TRAIN_SIZE, test_size=S_VAL_SIZE, batch_size=BATCH_SIZE,colorspace=COLORSPACE,resolution=RESOLUTION)
style_train_loader, style_validation_loader = dataset.prepare_styles_dataloader(style_train_data, style_validation_data, batch_size=BATCH_SIZE) """

' # Styles:\nS_TRAIN_SIZE = 60\nS_VAL_SIZE = 10\n\nstyle_train_data, style_validation_data = dataset.prepare_styles_dataset(train_size=S_TRAIN_SIZE, test_size=S_VAL_SIZE, batch_size=BATCH_SIZE,colorspace=COLORSPACE,resolution=RESOLUTION)\nstyle_train_loader, style_validation_loader = dataset.prepare_styles_dataloader(style_train_data, style_validation_data, batch_size=BATCH_SIZE) '

In [11]:
""" class ContextDataset(Dataset):
    def __init__(self, data):
        super(ContextDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]['image']
        return image

class StylesDataset(Dataset):
    def __init__(self, data):
        super(StylesDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]['image']
        return image
    
class StyleTransferTensorDataset(torch.utils.data.TensorDataset):
    def __init__(self, dataset1, dataset2, *args, **kwargs):
        super(StylesDataset, self).__init__(*args, **kwargs)
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __len__(self):
        len1 = len(self.dataset1)
        len2 = len(self.dataset2)
        assert len1 == len2
        return len1

    def __getitem__(self, idx):
        image1 = self.dataset1[idx]
        image2 = self.dataset2[idx]
        return (image1,image2) """

class StyleTransferDataset(torch.utils.data.Dataset):
    def __init__(self, data1, data2):
        super(StyleTransferDataset, self).__init__()
        self.data1 = data1
        self.data2 = data2

    def __len__(self):
        len1 = len(self.data1)
        len2 = len(self.data2)
        if len1<len2:
            return len1
        else:
            return len2

    def __getitem__(self, idx):
        image1 = self.data1[idx]['image']
        image2 = self.data2[idx]['image']
        return (image1,image2)


def prepare_joint_dataloaders(train_size=10,test_size=10,batch_size=4,colorspace='RGB',resolution=(128,128)):
    context_train, context_test = dataset.prepare_dataset(train_size,test_size,batch_size,colorspace,resolution)
    styles_train, styles_test = dataset.prepare_styles_dataset(train_size,test_size,batch_size,colorspace,resolution)
    # Filter context images
    filtered_context_train_data = []
    filtered_context_test_data = []
    for entry in list(context_train):
        if entry['image'].shape[0] == 3:
            filtered_context_train_data.append(entry)
    for entry in list(context_test):
        if entry['image'].shape[0] == 3:
            filtered_context_test_data.append(entry)
    # Filter styles images
    filtered_styles_train_data = []
    filtered_styles_test_data = []
    for entry in styles_train:
        if entry['image'].shape[0] == 3:
            filtered_styles_train_data.append(entry)
    for entry in styles_test:
        if entry['image'].shape[0] == 3:
            filtered_styles_test_data.append(entry)
    # Datasets
    train_dataset = StyleTransferDataset(filtered_context_train_data, filtered_styles_train_data)
    test_dataset = StyleTransferDataset(filtered_context_test_data, filtered_styles_test_data)
    # Dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader, test_dataloader

train_dataloader, test_dataloader = prepare_joint_dataloaders(TRAIN_SIZE,VAL_SIZE,BATCH_SIZE,COLORSPACE,RESOLUTION)

Dataset loaded successfully


Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

Dataset loaded successfully


In [12]:
print("Device is: ", device)
model = Net(checkpoint_path="./colorisation/UNetREG_trsize1000_valsize100_best_model.pth.tar").to(device)

#model = load_model_from_checkpoint("./colorisation/UNetREG_trsize50_valsize5nope_best_model.pth.tar")
model = model.to(device)

model.train_model(train_loader=train_dataloader, val_loader=test_dataloader, 
                        epochs=300, lr=1e-4, 
                        #save_path= 'style-transfer', 
                        #save_name_prefix='/AB_trsize60-60',
                        #val_check_every=5,
                        #plot_every=15,
                        #plotting_samples=list(validation_data)[:3]
                        )

   

Device is:  mps


NameError: name 'train_loader' is not defined