In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as v2
from torchvision import transforms
from torchvision.datasets import Food101

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import wandb

from FoodData import FoodDataset, FoodColorizationDataset
from utils import collect_train_val, get_classes_map, split_in_out_domain, split_train
from Autoencoder import Autoencoder

ROOT_DIR = '../'
CHECKPOINT_DIR = './checkpoint_small'

In [2]:
train_dataset = Food101(ROOT_DIR, split='train', transform=transforms.ToTensor(), download=True)
test_dataset = Food101(ROOT_DIR, split='test', transform=transforms.ToTensor(), download=True)

In [3]:
np.random.seed(42)

class_to_id, id_to_class = get_classes_map(ROOT_DIR)
classes = np.array(list(class_to_id.keys()))
out_id = np.random.choice(len(classes), 20)
out_classes = classes[out_id]
out_classes_id = list([class_to_id[x] for x in out_classes])
out_classes

array(['guacamole', 'spring_rolls', 'carrot_cake', 'paella',
       'lobster_bisque', 'chicken_wings', 'ravioli', 'sashimi',
       'peking_duck', 'peking_duck', 'scallops', 'tuna_tartare',
       'churros', 'baklava', 'chocolate_cake', 'gyoza', 'baby_back_ribs',
       'scallops', 'cup_cakes', 'filet_mignon'], dtype='<U23')

In [4]:
train_files, val_files, train_target, val_target =\
    collect_train_val(ROOT_DIR)

train_in_files, train_in_target, train_out_files, train_out_target =\
    split_in_out_domain(train_files, train_target, out_classes_id)

_, train_class_files, _, train_class_target =\
    split_train(train_in_files, train_in_target)

_, train_class_files_out, _, train_class_target_out =\
    split_train(train_out_files, train_out_target)

val_in_files, val_in_target, val_out_files, val_out_target =\
    split_in_out_domain(val_files, val_target, out_classes_id)

In [5]:
transform_colorization_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True),
            transforms.RandomHorizontalFlip(),
            transforms.Grayscale()
        ])
transform_colorization_valid = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True),
            transforms.Grayscale()
        ])

transform_classic_train = transforms.Compose([
            transforms.TrivialAugmentWide(),
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True),
            transforms.RandomHorizontalFlip()
        ])
transform_classic_valid = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True),
        ])

In [6]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, hidden, out_channels):
        super().__init__()
        self.basic_block = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
    
    def forward(self, x):
        return self.basic_block(x)

class ResNetFoodClassifierSmall(nn.Module):
    def __init__(self, prefix, in_channels, classes):
        super().__init__()
        
        self.back_bone_prefix = prefix
        
        for parametr in self.back_bone_prefix.parameters():
            parametr.requires_grad = False
            
        
        self.basic_block = BasicBlock(in_channels, 128, 256)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.out = nn.Linear(256, classes)
        
    def forward(self, x):
        x = self.back_bone_prefix(x)
        x = self.basic_block(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), x.size(1))
        x = self.out(x)
        
        return x

In [7]:
@torch.no_grad()
def accuracy(logits, target):
    pred = torch.argmax(logits, dim=1)
    return torch.sum(pred == target).item() / pred.size(0)


@torch.no_grad()
def validation_epoch(model, valid_loader, loss_fn, metric_fn, step):
    model.eval()
    
    loss, metric = 0, 0
    count = 0
    for X_batch, target in valid_loader:
        X_batch, target = X_batch.cuda(), target.cuda()

        out = model(X_batch)
        loss += loss_fn(out, target).item() * out.size(0)
        metric += metric_fn(out, target) * out.size(0)
        count += out.size(0)
        
    loss /= count
    metric /= count
    wandb.log({"eval/loss": loss, "eval/metric": metric}, step=step)
    
    return loss, metric
    

def train_epoch(model, optimizer, train_loader, loss_fn, metric_fn, step):
    model.train()
    for X_batch, target in train_loader:
        X_batch, target = X_batch.cuda(), target.cuda()

        optimizer.zero_grad()
        out = model(X_batch)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        
        metric = metric_fn(out, target)

        wandb.log({"train/loss": loss.item(), "train/metric": metric}, step=step)
        
        step += 1
    
    return step

def train(model, optimizer, train_loader, valid_loader, loss_fn, metric_fn, checkpoint_path, epoch_num=10):
    torch.backends.cudnn.benchmark = True
    
    best_metric = 0
    step = 0
    for epoch in tqdm(range(epoch_num)):
        
        step = train_epoch(model, optimizer, train_loader, loss_fn, metric_fn, step)
        loss, metric = validation_epoch(model, valid_loader, loss_fn, metric_fn, step)
        
        if best_metric < metric:
            best_metric = metric
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'metric': metric
                }, checkpoint_path)
        step += 1
        

# Без использования предобученной модели для раскраски изображений

In [8]:
class_to_id, id_to_class = get_classes_map(ROOT_DIR)


train_dataset = FoodDataset(ROOT_DIR, train_class_files, train_class_target, \
                            transform_classic_train, class_to_id, id_to_class)
train_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle=True)

valid_dataset = FoodDataset(ROOT_DIR, val_in_files, val_in_target, \
                            transform_classic_valid, class_to_id, id_to_class)
valid_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle=False)


train_out_dataset = FoodDataset(ROOT_DIR, train_class_files_out, train_class_target_out, \
                                transform_classic_train, class_to_id, id_to_class)
train_out_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle=True)

valid_out_dataset = FoodDataset(ROOT_DIR, val_out_files, val_out_target, \
                                transform_classic_valid, class_to_id, id_to_class)
valid_out_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle=False)

In [9]:
model = ResNetFoodClassifierSmall(nn.Identity(), 3, 101)
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=2e-05)
loss_fn = nn.CrossEntropyLoss()

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

325093

In [10]:
# wandb.init(project="colorization-classifier_small_in", name="classifier_origin_in")
# train(model, optimizer, train_loader, valid_loader, loss_fn, accuracy, checkpoint_path=f"{CHECKPOINT_DIR}/origin_in.pth")
# wandb.finish()

In [11]:
model = ResNetFoodClassifierSmall(nn.Identity(), 3, 101)
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=2e-05)
loss_fn = nn.CrossEntropyLoss()

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

325093

In [12]:
# wandb.init(project="colorization-classifier_small_out", name="classifier_origin_out")
# train(model, optimizer, train_out_loader, valid_out_loader, loss_fn, accuracy, checkpoint_path=f"{CHECKPOINT_DIR}/origin_out.pth")
# wandb.finish()

# С использованием предобученной модели для раскраски изображений

In [13]:
class_to_id, id_to_class = get_classes_map(ROOT_DIR)

train_dataset = FoodDataset(ROOT_DIR, train_class_files, train_class_target, \
                            transform_colorization_train, class_to_id, id_to_class)
train_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle=True)

valid_dataset = FoodDataset(ROOT_DIR, val_in_files, val_in_target, \
                            transform_colorization_valid, class_to_id, id_to_class)
valid_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle=False)


train_out_dataset = FoodDataset(ROOT_DIR, train_class_files_out, train_class_target_out, \
                                transform_colorization_train, class_to_id, id_to_class)
train_out_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle=True)

valid_out_dataset = FoodDataset(ROOT_DIR, val_out_files, val_out_target, \
                                transform_colorization_valid, class_to_id, id_to_class)
valid_out_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle=False)

In [14]:
checkpoint = torch.load("./checkpoint_ae/ae.pth")
checkpoint.keys(), checkpoint["epoch"], checkpoint["loss"]

(dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss']),
 14,
 0.19743407670991966)

In [15]:
AE = Autoencoder()
AE.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [16]:
prefix_backbone = nn.Sequential(AE.encoder, AE.quant_conv)

In [17]:
model = ResNetFoodClassifierSmall(prefix_backbone, 4, 101)
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=2e-05)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

326245

In [18]:
# wandb.init(project="colorization-classifier_small_in", name="classifier_color_in")
# train(model, optimizer, train_loader, valid_loader, loss_fn, accuracy, checkpoint_path=f"{CHECKPOINT_DIR}/color_in.pth")
# wandb.finish()

In [19]:
model = ResNetFoodClassifierSmall(prefix_backbone, 4, 101)
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=2e-05)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

326245

In [20]:
# wandb.init(project="colorization-classifier_small_out", name="classifier_color_out")
# train(model, optimizer, train_out_loader, valid_out_loader, loss_fn, accuracy, checkpoint_path=f"{CHECKPOINT_DIR}/color_out.pth")
# wandb.finish()