In [None]:
from torch_snippets import *
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.optim import Adam, lr_scheduler
from tqdm import tqdm
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# Pick the compute target once; every tensor and the model are moved to it later.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
class Colorize(datasets.CIFAR10):
  """CIFAR-10 variant that yields (grayscale, colour) image pairs.

  The class label of each sample is discarded: the network input is the
  grayscale rendering of the picture and the target is the original picture.

  The base class is referenced through the `datasets` name imported from
  torchvision at the top of the notebook, so the definition does not rely on a
  bare `torchvision` name being available.

  Args:
    root: directory holding the CIFAR-10 files. No download is requested here,
      so the files must already be present.
    train: True for the training split, False for the test split.
  """
  def __init__(self, root, train):
    super().__init__(root, train)

  def __getitem__(self, ix):
    """Return `(bw, im)` for sample `ix`.

    Both are float tensors in channel-first layout, scaled to [0, 1] and
    already placed on `device`. `bw` is the grayscale image replicated over
    three channels so that it has the same shape as the colour target `im`.
    """
    # No transform is configured, so the parent hands back a PIL image + label.
    im, _ = super().__getitem__(ix)
    # 'L' drops the colour; converting back to 'RGB' restores three channels.
    bw = im.convert('L').convert('RGB')
    bw, im = np.array(bw)/255., np.array(im)/255.
    # HWC -> CHW, then move to the target device as float32.
    bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]
    return bw, im

In [None]:
class Identity(nn.Module):
    """No-op layer: hands its input back untouched.

    Useful as a placeholder wherever a module is required but no
    transformation is wanted.
    """

    def forward(self, x):
        # The very same object is returned; nothing is copied or modified.
        return x

class DownConv(nn.Module):
  """Encoder stage: optional 2x2 max-pool, then two 3x3 Conv-BatchNorm-ReLU units.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by both convolutions.
    maxPool: when True the input is first downsampled by a factor of 2;
      when False the input resolution is kept.
  """

  def __init__(self, in_channels, out_channels, maxPool = True):
    super().__init__()
    # Slot 0 is always filled (pooling or a pass-through) so the positions of
    # the parametrised layers inside the Sequential never depend on `maxPool`.
    if maxPool:
      stages = [nn.MaxPool2d(2)]
    else:
      stages = [Identity()]

    width = in_channels
    for _ in range(2):
      stages.extend([
          nn.Conv2d(width, out_channels, 3, padding = 1),  # padding=1 keeps H x W
          nn.BatchNorm2d(out_channels),
          nn.ReLU(),
      ])
      width = out_channels

    self.model = nn.Sequential(*stages)

  def forward(self, x):
    """Apply the whole stage to `x`."""
    return self.model(x)

class UpConv(nn.Module):
  """Decoder stage: upsample, merge with a skip connection, then refine.

  The transposed convolution (kernel 2, stride 2) doubles the spatial size and
  reduces the channels to `out_channels`. The result is concatenated with the
  skip tensor along the channel axis and passed through two 3x3
  Conv-BatchNorm-ReLU units.

  Args:
    in_channels: channels of the low-resolution input.
    out_channels: channels of the output; the skip tensor is expected to carry
      this many channels as well, since the first convolution takes
      `2 * out_channels` inputs.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    def conv_unit(n_in):
      # One 3x3 convolution (resolution-preserving) + normalisation + activation.
      return [
          nn.Conv2d(n_in, out_channels, 3, padding = 1),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(),
      ]

    self.convtranspose = nn.ConvTranspose2d(in_channels, out_channels, 2, stride = 2)
    self.convLayer = nn.Sequential(
        *conv_unit(2 * out_channels),  # upsampled map + skip map
        *conv_unit(out_channels),
    )

  def forward(self, x, y):
    """Upsample `x`, concatenate the skip tensor `y` and refine the result."""
    upsampled = self.convtranspose(x)
    merged = torch.cat((upsampled, y), dim = 1)
    return self.convLayer(merged)

class UNET(nn.Module):
  """U-Net mapping a 3-channel image to a 3-channel image of the same size.

  Five encoder stages (the first one without pooling) widen the features from
  64 to 1024 channels; four decoder stages mirror them and reuse the encoder
  outputs as skip connections. A final 1x1 convolution projects the 64 decoder
  channels back to 3 output channels.
  """

  def __init__(self):
    super().__init__()
    # Encoder: every stage after the first halves the resolution.
    self.d1 = DownConv(3, 64, maxPool = False)
    self.d2 = DownConv(64, 128)
    self.d3 = DownConv(128, 256)
    self.d4 = DownConv(256, 512)
    self.d5 = DownConv(512, 1024)

    # Decoder: every stage doubles the resolution and halves the channels.
    self.u5 = UpConv(1024, 512)
    self.u4 = UpConv(512, 256)
    self.u3 = UpConv(256, 128)
    self.u2 = UpConv(128, 64)
    # 1x1 projection to the three output channels.
    self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride = 1)

  def forward(self, x):
    """Run the encoder-decoder; side lengths in the comments assume a 32x32 input."""
    # Contracting path.
    enc1 = self.d1(x)           # 32
    enc2 = self.d2(enc1)        # 16
    enc3 = self.d3(enc2)        # 8
    enc4 = self.d4(enc3)        # 4
    bottom = self.d5(enc4)      # 2

    # Expanding path, each step fused with the encoder map of equal resolution.
    dec4 = self.u5(bottom, enc4)  # 4
    dec3 = self.u4(dec4, enc3)    # 8
    dec2 = self.u3(dec3, enc2)    # 16
    dec1 = self.u2(dec2, enc1)    # 32

    return self.u1(dec1)          # 3 output channels

In [None]:
def get_model(lr = 1e-6):
  """Build the network, its optimiser and the training loss.

  Args:
    lr: learning rate handed to Adam. Defaults to the previously hard-coded
      value so existing calls behave exactly as before.
      NOTE(review): 1e-6 is a very small step for Adam - confirm it is intended.

  Returns:
    (model, optimizer, loss_fn): a UNET placed on `device`, an Adam optimiser
    over all of its parameters, and a mean-squared-error loss.
  """
  model = UNET().to(device)
  optimizer = Adam(model.parameters(), lr = lr)
  loss_fn = nn.MSELoss()
  return model, optimizer, loss_fn

def get_data_loaders(data_folder, batch_size = 256):
  """Create the training and validation loaders for the colourisation task.

  Args:
    data_folder: directory where CIFAR-10 is stored (it is downloaded there
      when missing).
    batch_size: samples per batch for both loaders; defaults to the value that
      used to be hard-coded.

  Returns:
    (tr_dl, vl_dl): loaders over the training and the test split.
  """
  # Colorize's constructor takes no download flag, so the plain dataset is
  # instantiated once purely for its side effect of fetching the files.
  datasets.CIFAR10(data_folder, download = True, train = True)
  tr_set = Colorize(data_folder, train = True)
  vl_set = Colorize(data_folder, train = False)
  tr_dl = DataLoader(tr_set, batch_size = batch_size, shuffle = True, drop_last = True)
  # Validation must see every sample exactly once: no shuffling, and the last
  # (smaller) batch is kept instead of being silently discarded.
  vl_dl = DataLoader(vl_set, batch_size = batch_size, shuffle = False, drop_last = False)
  return tr_dl, vl_dl

def train_batch(model, optimizer, loss_function, data):
  """Perform one optimisation step on a single batch.

  Args:
    model: network to update; switched to training mode here.
    optimizer: optimiser holding the model's parameters.
    loss_function: callable `(prediction, target) -> scalar tensor`.
    data: `(inputs, targets)` pair.

  Returns:
    The batch loss as a Python float.
  """
  model.train()
  inputs, targets = data
  prediction = model(inputs)
  # Clear gradients left over from the previous step before back-propagating.
  optimizer.zero_grad()
  batch_loss = loss_function(prediction, targets)
  batch_loss.backward()
  optimizer.step()
  return batch_loss.item()

def valid_batch(model, loss_function, data):
  """Evaluate the loss on one batch without updating the model.

  Args:
    model: network to evaluate; switched to evaluation mode here.
    loss_function: callable `(prediction, target) -> scalar tensor`.
    data: `(inputs, targets)` pair.

  Returns:
    The batch loss as a Python float.
  """
  model.eval()
  x, y = data
  # No backward pass follows, so gradient tracking is disabled: this avoids
  # building and holding an autograd graph for every validation batch.
  with torch.no_grad():
    _y = model(x)
    loss_value = loss_function(_y, y)
  return loss_value.item()

def save_plot(tr_list, val_list, title, ylabel = "Accuracy"):
  """Plot a training and a validation curve, save the figure, then display it.

  Args:
    tr_list: per-epoch values for the training curve.
    val_list: per-epoch values for the validation curve.
    title: figure title; also used as the file name (`<title>.png`).
    ylabel: label of the y axis. Defaults to the previously hard-coded text;
      pass e.g. "Loss" when plotting losses.
  """
  plt.plot(tr_list, label = "Training")
  plt.plot(val_list, label = "Validation")
  plt.xlabel("Epoch")
  plt.ylabel(ylabel)
  plt.legend()
  plt.title(title)
  # Save BEFORE showing: once show() has run the current figure is typically
  # finished/closed, and a later savefig would write an empty image.
  plt.savefig(f"{title}.png")
  print(f"Saved {title}.png")
  plt.show()

def run(epoch_range, data_folder, model_dict = None):
  """Train the colourisation network and checkpoint the best epoch.

  For every epoch the model is trained on all training batches, evaluated on
  all validation batches, and the learning-rate schedule is advanced. Whenever
  the mean validation loss improves on the best value seen so far, the model
  weights are written to a file named `Epoch <n>`.

  Args:
    epoch_range: number of epochs to train for.
    data_folder: directory holding (or receiving) the CIFAR-10 data.
    model_dict: currently unused.
      NOTE(review): presumably meant for resuming from saved weights - confirm.
  """
  model, optimizer, loss_fn = get_model()
  tr_dl, vl_dl = get_data_loaders(data_folder)
  scheduler = lr_scheduler.StepLR(optimizer, step_size = 10, gamma = 0.1)
  log = Report(epoch_range)
  # Start from +inf so the first evaluated epoch always produces a checkpoint.
  min_loss = float('inf')
  n_train, n_val = len(tr_dl), len(vl_dl)
  for epoch in range(epoch_range):
    for bx, data in enumerate(tr_dl):
      loss = train_batch(model, optimizer, loss_fn, data)
      # "\r" rewinds the cursor so the progress line is overwritten in place.
      log.record(epoch + (1+bx)/n_train, loss = loss, end = "\r")

    val_loss_sum = 0.
    for bx, data in enumerate(vl_dl):
      val_loss = valid_batch(model, loss_fn, data)
      val_loss_sum += val_loss
      log.record(epoch + (bx + 1)/n_val, val_loss = val_loss, end = "\r")
    log.report_avgs(epoch+1)

    # Advance the schedule once per epoch, after this epoch's optimiser steps.
    scheduler.step()

    # Judge the epoch by the mean over ALL validation batches, not the last one.
    # With an empty validation loader nothing can improve, so nothing is saved.
    val_loss_avg = val_loss_sum / n_val if n_val else float('inf')
    if min_loss > val_loss_avg:
      min_loss = val_loss_avg
      torch.save(model.state_dict(), f"Epoch {epoch+1}")