In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Subset

import torchvision
from torchvision import transforms
from torchvision.datasets import Food101

torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as v2


from torchmetrics.image import lpip

from diffusers.models.vae import Encoder, Decoder

import matplotlib.pyplot as plt
import numpy as np
import wandb

from Autoencoder import Autoencoder
from FoodData import FoodColorizationDataset
from utils import collect_train_val, get_classes_map, split_in_out_domain, split_train
from tqdm import tqdm

ROOT_DIR = '../'
CHECKPOINT_DIR = './checkpoint_ae'

train_dataset = Food101(ROOT_DIR, split='train', transform=transforms.ToTensor(), download=True)
test_dataset = Food101(ROOT_DIR, split='test', transform=transforms.ToTensor(), download=True)

In [2]:
class_to_id, id_to_class = get_classes_map(ROOT_DIR)
classes = np.array(list(class_to_id.keys()))


In [3]:
np.random.seed(42)
out_id = np.random.choice(len(classes), 20)
out_classes = classes[out_id]
out_classes_id = list([class_to_id[x] for x in out_classes])
out_classes

array(['guacamole', 'spring_rolls', 'carrot_cake', 'paella',
       'lobster_bisque', 'chicken_wings', 'ravioli', 'sashimi',
       'peking_duck', 'peking_duck', 'scallops', 'tuna_tartare',
       'churros', 'baklava', 'chocolate_cake', 'gyoza', 'baby_back_ribs',
       'scallops', 'cup_cakes', 'filet_mignon'], dtype='<U23')

In [4]:
train_files, val_files, train_target, val_target =\
    collect_train_val(ROOT_DIR)

train_in_files, train_in_target, train_out_files, train_out_target =\
    split_in_out_domain(train_files, train_target, out_classes_id)

val_in_files, val_in_target, val_out_files, val_out_target =\
    split_in_out_domain(val_files, val_target, out_classes_id)

In [5]:
train_color_files, train_class_files, train_color_target, train_class_target =\
    split_train(train_in_files, train_in_target)

In [6]:
class_to_id, id_to_class = get_classes_map(ROOT_DIR)

train_dataset = FoodColorizationDataset(ROOT_DIR, train_color_files, train_color_target, transforms.Compose(
        [transforms.ToTensor(),
         transforms.Resize((128, 128), antialias=True)
        ]), class_to_id, id_to_class, transforms.Compose([
            v2.RandomPhotometricDistort(),
            transforms.RandomAdjustSharpness(2),
            transforms.RandomInvert(),
            transforms.Grayscale()
        ])
        )

train_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle=True)

valid_dataset = FoodColorizationDataset(ROOT_DIR, val_in_files, val_in_target, transforms.Compose(
        [transforms.ToTensor(),
         transforms.Resize((128, 128), antialias=True)
        ]), class_to_id, id_to_class, transforms.Compose([
            transforms.Grayscale()
        ]))

valid_loader = DataLoader(valid_dataset, batch_size=64, num_workers=0, shuffle=False)

valid_out_dataset = FoodColorizationDataset(ROOT_DIR, val_out_files, val_out_target, transforms.Compose(
        [transforms.ToTensor(),
         transforms.Resize((128, 128), antialias=True)
        ]), class_to_id, id_to_class, transforms.Compose([
            transforms.Grayscale()
        ]))

valid_out_loader = DataLoader(valid_dataset, batch_size=64, num_workers=0, shuffle=False)

In [7]:
np.random.seed(42)
vis_id = np.random.choice(len(train_dataset), 64)
vis_train = DataLoader(Subset(train_dataset, vis_id), batch_size=64, shuffle=False)

vis_id = np.random.choice(len(valid_dataset), 64)
vis_valid = DataLoader(Subset(valid_dataset, vis_id), batch_size=64, shuffle=False)

vis_id = np.random.choice(len(valid_out_dataset), 64)
vis_out_valid = DataLoader(Subset(valid_out_dataset, vis_id), batch_size=64, shuffle=False)

In [8]:
w1 = 1
w2 = 1

pixel_loss = nn.SmoothL1Loss()
perc_loss = lpip.LearnedPerceptualImagePatchSimilarity().cuda()

def loss_fn(x, y):
    l1 = w1 * pixel_loss(x, y)
    l2 = w2 * perc_loss(x, y)
    return l1 + l2, l1.item() / w1, l2.item() / w2



In [9]:
model = Autoencoder()
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=2e-05)

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

1287631

In [10]:
@torch.no_grad()
def validation_epoch(model, valid_loader, loss_fn, step):
    model.eval()
    
    loss = 0
    l1_total, l2_total = 0, 0
    count = 0
    for X_batch, target in valid_loader:
        X_batch, target = X_batch.cuda(), target.cuda()

        out = model(X_batch)
        l, l1, l2 = loss_fn(out, target)
        loss += l.item() * out.size(0)
        l1_total += l1 * out.size(0)
        l2_total += l2 * out.size(0)
        count += out.size(0)
        
    loss /= count
    l1 /= count
    l2 /= count
    wandb.log({"eval/loss": loss, 'eval/pixel_loss': l1, 'eval/lpips': l2}, step=step)
    
    return loss

@torch.no_grad()
def visualization(model, vis_train, vis_valid, vis_out_valid, step):
    model.eval()
    
    def get_img(loader):
        for X_batch, target in loader:
            X_batch, target = X_batch.cuda(), target.cuda()
            out = model(X_batch)
            loss, _, _ = loss_fn(out, target)

            img = torchvision.utils.make_grid(out, normalize=True).cpu()
            return img, loss.item()
        
    img, loss = get_img(vis_train)
    wandb.log({"vis/train_vis": wandb.Image(img, caption=f"mean loss = {loss}")}, step=step)
    
    img, loss = get_img(vis_valid)
    wandb.log({"vis/valid_vis": wandb.Image(img, caption=f"mean loss = {loss}")}, step=step)
    
    img, loss = get_img(vis_out_valid)
    wandb.log({"vis/valid_out_vis": wandb.Image(img, caption=f"mean loss = {loss}")}, step=step)
    
    model.train()
    
def visualization_init(vis_train, vis_valid, vis_out_valid, step):
    
    for _, target in vis_train:
        
        img = torchvision.utils.make_grid(target, normalize=True)
        wandb.log({"vis/train": wandb.Image(img)}, step=step)
        
    for _, target in vis_valid:
        img = torchvision.utils.make_grid(target, normalize=True)
        wandb.log({"vis/valid": wandb.Image(img)}, step=step)
    
    for _, target in vis_out_valid:
        img = torchvision.utils.make_grid(target, normalize=True)
        wandb.log({"vis/valid_out": wandb.Image(img)}, step=step)

def train_epoch(model, optimizer, train_loader, loss_fn,  vis_train, vis_valid,  vis_out_valid, step):
    model.train()
    for X_batch, target in train_loader:
        X_batch, target = X_batch.cuda(), target.cuda()

        optimizer.zero_grad()
        out = model(X_batch)
        loss, l1, l2 = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        
        if step % 100 == 0:
            visualization(model, vis_train, vis_valid,  vis_out_valid, step)
        
        wandb.log({"train/loss": loss.item(), 'train/pixel_loss': l1, 'train/lpips': l2}, step=step)
        
        step += 1
    
    return step

def train(model, optimizer, train_loader, valid_loader, loss_fn, checkpoint_path, 
            vis_train, vis_valid, vis_out_valid, epoch_num=25):
    
    torch.backends.cudnn.benchmark = True
    
    best_loss = 10000
    step = 0
    
    visualization_init(vis_train, vis_valid,  vis_out_valid, step)
    for epoch in tqdm(range(epoch_num)):
        
        step = train_epoch(model, optimizer, train_loader, loss_fn, vis_train, vis_valid, vis_out_valid, step)
        loss = validation_epoch(model, valid_loader, loss_fn, step)
        
        if best_loss > loss:
            best_loss = loss
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, checkpoint_path)
        step += 1
        

In [11]:
# wandb.init(project="colorization-ae", name="2-ae")
# train(model, optimizer, train_loader, valid_loader, loss_fn,
#           checkpoint_path=f"{CHECKPOINT_DIR}/2ae.pth",
#           vis_train=vis_train, vis_valid=vis_valid, vis_out_valid=vis_out_valid)
# wandb.finish()