In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# %pip install -q torchvision pandas
# %matplotlib inline

In [3]:
from charts.common.dataset import LabeledImage
from charts.common.timer import Timer
import charts.pytorch.color_regression as cr
from charts.pytorch.utils import Experiment, num_trainable_parameters, is_google_colab

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from torchvision.utils import make_grid
from torchvision.io import read_image
from torchvision.transforms import ToTensor, ToPILImage
from torchvision import transforms

import torch_lr_finder
import timm
from ptflops import get_model_complexity_info

import pandas as pd
import matplotlib.pyplot as plt

import numpy as np

from PIL import Image

from icecream import ic
from tqdm.notebook import tqdm

import os
from pathlib import Path
import time

In [4]:
# Use the first GPU when CUDA is available; otherwise run everything on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
display(f"Use CUDA: {use_cuda}")

'Use CUDA: True'

In [12]:
preprocessor = cr.ImagePreprocessor(device)

# The dataset lives under /content on Colab and in the repo's generated/ folder locally.
if is_google_colab():
    dataset_path = Path("/content/datasets/drawings")
else:
    dataset_path = Path('../../generated/drawings')
dataset = cr.ColorRegressionImageDataset(dataset_path, preprocessor)

# Contiguous split: the first 75% of indices (at least one) train, the rest validate.
# Each sampler shuffles within its own subset using a seeded generator.
n_train = max(int(len(dataset) * 0.75), 1)
generator = torch.Generator().manual_seed(42)

train_indices = range(n_train)
val_indices = range(n_train, len(dataset))
train_sampler = SubsetRandomSampler(train_indices, generator=generator)
val_sampler = SubsetRandomSampler(val_indices, generator=generator)

# Larger batches on Colab's GPU, small ones for local runs.
if is_google_colab():
    BATCH_SIZE = 16
else:
    BATCH_SIZE = 4

train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

# Sample 0 (in the training split) is tracked across epochs in TensorBoard;
# its third element is used to spot it inside collated batches.
monitored_sample = dataset[0]
monitored_sample_inputs = monitored_sample[0].unsqueeze(dim=0)
monitored_sample_json = monitored_sample[2]

In [13]:
xp = Experiment("2022-Jan31-CR1", clear_previous_results=True)

net = cr.RegressionNet_Unet1()
ic(num_trainable_parameters(net))
net.to(device)

criterion = nn.MSELoss()

# One param group per sub-network (index 0 = encoder, index 1 = decoder) so each
# can follow its own learning-rate schedule. The OneCycleLR scheduler built in
# `train` overwrites each group's lr with max_lr[i] / div_factor, so these
# starting values only matter if the optimizer is stepped without a scheduler.
# They are ordered to agree with the (encoder, decoder) max_lr tuples used for
# training: a small encoder lr and a larger decoder lr.
optimizer = optim.Adam([
    {'params': net.encoder.parameters(), 'lr': 1e-5},
    {'params': net.decoder.parameters(), 'lr': 1e-3},
])

xp.prepare (net, optimizer, device, monitored_sample_inputs)

def _log_monitored_sample(inputs, outputs, labels, json_files, epoch):
    """Write input/prediction/target images of the monitored sample to TensorBoard, if it is in this batch.

    The monitored sample is located by searching `json_files` (the batch's third
    collated element) for `monitored_sample_json`. Batches without it are skipped.
    """
    try:
        idx = json_files.index(monitored_sample_json)
    except ValueError:  # monitored sample is not part of this batch
        return
    xp.writer.add_image("Sample output", preprocessor.denormalize_and_clip_as_tensor(outputs[idx]), epoch)
    xp.writer.add_image("Target output", preprocessor.denormalize_and_clip_as_tensor(labels[idx]), epoch)
    xp.writer.add_image("Sample input",  preprocessor.denormalize_and_clip_as_tensor(inputs[idx]), epoch)


def train (first_epoch, end_epoch, optimizer, max_lr):
    """Train `net` on `train_dataloader` for epochs [first_epoch, end_epoch) with a OneCycleLR schedule.

    Args:
        first_epoch: index of the first epoch to run (e.g. `xp.first_epoch` when resuming).
        end_epoch: exclusive upper bound on the epoch index.
        optimizer: optimizer whose param groups are driven by the scheduler.
        max_lr: OneCycleLR peak learning rate, a float or one value per param group.

    Each batch's loss is logged with a per-batch global step. At every epoch end the
    training/validation losses, validation accuracy, elapsed time and two weight
    histograms are written to `xp.writer`. A checkpoint is saved when
    `epoch % 5 == 1` and after the last epoch of the range.

    Returns immediately when the range is empty, e.g. when resuming past `end_epoch`.
    """
    num_epochs = end_epoch - first_epoch
    if num_epochs <= 0:
        # OneCycleLR rejects a non-positive epoch count; this phase is already done.
        return

    steps_per_epoch = len(train_dataloader)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, steps_per_epoch=steps_per_epoch, epochs=num_epochs)
    pbar = tqdm(range(first_epoch, end_epoch))
    for epoch in pbar:  # loop over the dataset multiple times
        net.train()
        cumulated_training_loss = 0.0
        tstart = time.time()

        for i, (inputs, labels, json_files) in enumerate(train_dataloader):
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            # Per-batch global step. Logging at `epoch` would put every batch of an
            # epoch on the same x value in TensorBoard.
            xp.writer.add_scalar("Single Batch Loss", batch_loss, epoch * steps_per_epoch + i)
            cumulated_training_loss += batch_loss

            # OneCycleLR is stepped once per batch (steps_per_epoch * epochs steps in total).
            scheduler.step()

            _log_monitored_sample(inputs, outputs, labels, json_files, epoch)

        # Very important for batch norm layers.
        net.eval()

        training_loss = cumulated_training_loss / steps_per_epoch
        xp.writer.add_scalar("Training Loss", training_loss, epoch)

        val_loss = cr.compute_average_loss (val_dataloader, net, criterion)
        xp.writer.add_scalar("Validation Loss", val_loss, epoch)

        val_accuracy = cr.compute_accuracy (val_dataloader, net, criterion)
        xp.writer.add_scalar("Validation Accuracy", val_accuracy, epoch)

        elapsedSecs = (time.time() - tstart)
        xp.writer.add_scalar("Elapsed Time (s)", elapsedSecs, epoch)

        # NOTE(review): both histograms read from `net.decoder`, including the one
        # tagged "enc0". Confirm this is intended rather than `net.encoder`.
        xp.writer.add_histogram("enc0", net.decoder.enc0.block[3].weight, global_step=epoch)
        xp.writer.add_histogram("dec0", net.decoder.dec0.block[3].weight, global_step=epoch)

        pbar.set_postfix({'train_loss': training_loss, 'val_loss': val_loss, 'val_accuracy': val_accuracy})

        # Also checkpoint the final epoch so the end state of each phase is never lost.
        if epoch % 5 == 1 or epoch == end_epoch - 1:
            xp.save_checkpoint(epoch)

FROZEN_EPOCHS = 100
TOTAL_EPOCHS = 200

# Phase 1: encoder frozen (the trainable-parameter count reported by `ic` drops).
# max_lr is given per param group, in the optimizer's (encoder, decoder) order.
net.freeze_encoder()
ic(num_trainable_parameters(net))
train(first_epoch=xp.first_epoch, end_epoch=FROZEN_EPOCHS, optimizer=optimizer, max_lr=(1e-5, 1e-3))

# Phase 2: encoder unfrozen, fine-tuning with a lower decoder peak learning rate.
net.unfreeze_encoder()
ic(num_trainable_parameters(net))
train(first_epoch=FROZEN_EPOCHS, end_epoch=TOTAL_EPOCHS, optimizer=optimizer, max_lr=(1e-5, 3e-4))

print('Finished Training!')

Will store the experiment data to /content/drive/MyDrive/DaltonLens-Colab/DaltonLensPrivate/charts/pytorch/experiments/2022-Jan31-CR1


ic| num_trainable_parameters(net): 29508035




ic| num_trainable_parameters(net): 18331523


  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
def load_specific_checkpoint (name):
    """Restore `net` and `optimizer` from the checkpoint file `name` inside `xp.log_path`.

    Tensors are mapped onto the current `device`. The model state is loaded first,
    then the optimizer state.
    """
    checkpoint_path = xp.log_path / name
    state = torch.load(checkpoint_path, map_location=device)
    for target, key in ((net, 'model_state_dict'), (optimizer, 'optimizer_state_dict')):
        target.load_state_dict(state[key])

# load_specific_checkpoint ("checkpoint-00701.pt")
# torch.save (net, "regression_unet_v1.pt")

In [None]:
# Visual sanity check on one training batch: input, prediction and target side by side.
net.eval()  # use running batch-norm statistics, as during validation
with torch.no_grad():
    sample_inputs, sample_labels, _ = next(iter(train_dataloader))
    sample_outputs = net(sample_inputs)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    ("Input", sample_inputs[0]),
    ("Prediction", sample_outputs[0]),
    ("Target", sample_labels[0]),
)
for ax, (title, image) in zip(axes, panels):
    ax.imshow(preprocessor.denormalize_and_clip_as_numpy(image))
    ax.set_title(title)
    ax.axis("off")
plt.show()