In [1]:
import warnings
import numpy as np
import random
import torch

# Single source of truth for every RNG seeded in this cell.
SEED = 42

# NOTE(review): this silences *all* warnings, including deprecation notices
# from the libraries used further down; consider narrowing the filter.
warnings.filterwarnings("ignore")

# Seed each library's global RNG.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Install/upgrade fastai quietly (-q) to the newest release (-U).
# NOTE(review): no version is pinned, so a fresh run may pull a different
# fastai release than the one this notebook was developed against -- consider
# pinning the known-good version.
!pip install -qU fastai

[K     |████████████████████████████████| 189 kB 14.5 MB/s 
[K     |████████████████████████████████| 55 kB 4.8 MB/s 
[?25h

In [3]:
from fastai.data.external import untar_data, URLs
import glob

# Download (or reuse the cached copy of) the COCO sample and point at its images.
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"

# glob returns files in arbitrary, filesystem-dependent order; sort so that
# the positional split below picks the same images on every run and machine.
paths = sorted(glob.glob(coco_path + "/*.jpg"))

# Fixed positional split: GAN training / generator pre-training / validation.
train_paths = paths[:12000]
pretrain_paths = paths[12000:20000]
val_paths = paths[20000:]

# Fail early (instead of much later, while plotting) if the download is incomplete.
assert val_paths, f"expected more than 20000 images in {coco_path}, found {len(paths)}"

In [4]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage.color import rgb2lab

class ColorizationDataset(Dataset):
    """Dataset of (L, ab) tensor pairs in CIE-Lab space for image colorization.

    Args:
        img_path1: iterable of image file paths; copied into a list.
        transform: optional callable applied to the RGB PIL image. It may
            return either a PIL image / H x W x C array or a C x H x W tensor
            (e.g. a pipeline ending in ``ToTensor``).

    Each item is a tuple ``(L, ab)``:
        L  -- 1 x H x W float32 tensor, the lightness channel as ``L / 50 - 1``.
        ab -- 2 x H x W float32 tensor, the colour channels divided by 110.
    """

    def __init__(self, img_path1, transform=None):
        self.img_paths = list(img_path1)
        # Always define the attribute so __getitem__ also works when no
        # transform was supplied (previously this raised AttributeError).
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # convert() returns an independent, fully loaded image, so the file
        # handle can be closed as soon as the conversion is done.
        with Image.open(self.img_paths[idx]) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if torch.is_tensor(img):
            # Tensor pipelines yield C x H x W; rgb2lab expects channels last.
            img = np.moveaxis(np.array(img), 0, 2)
        else:
            # PIL images / arrays are already H x W x C.
            img = np.array(img)

        img_lab = rgb2lab(img).astype("float32")
        # Back to a C x H x W tensor for the network.
        img_lab = transforms.ToTensor()(img_lab)

        # Index with lists so the channel dimension is kept.
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return L, ab

from torchvision import transforms

# Resize every image to 256 x 256 and convert it to a C x H x W tensor.
# No Normalize step here: ColorizationDataset feeds the result straight into
# rgb2lab and applies its own scaling to the Lab channels.
transform = transforms.Compose(
    [
        # Use torchvision's InterpolationMode enum rather than the PIL integer
        # constant, which torchvision deprecated for this argument.
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ]
)

train_ds = ColorizationDataset(train_paths, transform=transform)
pretrain_ds = ColorizationDataset(pretrain_paths, transform=transform)
val_ds = ColorizationDataset(val_paths, transform=transform)

batch_size = 16
# Training loaders are shuffled; the validation loader keeps a fixed order so
# the same images are shown every time results are visualised.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
pretrain_loader = DataLoader(pretrain_ds, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

In [5]:
import torch.nn as nn

class ConvBlock(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and LeakyReLU(0.2).

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1).
        padding: zero padding added on each side (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()

        # The conv bias is disabled; the BatchNorm that follows has its own shift.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Positional registration keeps the sub-module names "0", "1", "2".
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv -> norm -> activation to ``x``."""
        features = self.block(x)
        return features

In [6]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

class Generator(nn.Module):
    """U-Net generator built by fastai's DynamicUnet on a ResNet-18 body.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the prediction (default 2).

    The output is passed through ``tanh``, so every value lies in (-1, 1).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        # Pretrained ResNet-18 cut at -2, with its first layer adapted to
        # ``in_channels`` inputs.
        encoder = create_body(resnet18, pretrained=True, n_in=in_channels, cut=-2)
        # The U-Net is built for a fixed 256 x 256 image size.
        self.model = DynamicUnet(encoder, out_channels, (256, 256))

    def forward(self, x):
        """Return the tanh-squashed U-Net prediction for ``x``."""
        raw = self.model(x)
        return torch.tanh(raw)

# Build the generator on the selected device and print its architecture.
gen = Generator().to(device)
print(gen)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

Generator(
  (model): DynamicUnet(
    (layers): ModuleList(
      (0): Sequential(
        (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 

In [7]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of probabilities.

    Four stride-2 ``ConvBlock`` stages (64, 128, 256, 512 channels, 4x4
    kernels) are followed by a final 4x4 convolution and a sigmoid.

    Args:
        in_channels: channels of the input (default 2).
        out_channels: channels of the output score map (default 1).
    """

    def __init__(self, in_channels=2, out_channels=1):
        super().__init__()

        widths = [in_channels, 64, 128, 256, 512]

        # One down-sampling stage per consecutive pair of channel widths.
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(ConvBlock(c_in, c_out, 4, 2))

        self.model = nn.Sequential(*stages)
        self.model.add_module("final_conv", nn.Conv2d(widths[-1], out_channels, 4, padding=1))

    def forward(self, x):
        """Return sigmoid scores in (0, 1) for ``x``."""
        scores = self.model(x)
        return torch.sigmoid(scores)

# Build the discriminator directly on the selected device.
dis = Discriminator().to(device)

In [8]:
class AverageMeter:
    """Running weighted average of a scalar (e.g. a per-batch loss).

    Attributes:
        count: total weight accumulated so far.
        sum:   weighted sum of all recorded values.
        avg:   ``sum / count``, or 0 while nothing has been recorded.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.count, self.avg, self.sum = 0, 0, 0

    def update(self, val, count=1):
        """Record ``val`` with weight ``count`` (e.g. the batch size)."""
        self.count += count
        self.sum += count * val
        # Guard against dividing by zero when only zero-weight updates have
        # been recorded (for instance an empty batch on a fresh meter).
        self.avg = self.sum / self.count if self.count else 0

In [9]:
from skimage.color import lab2rgb
import matplotlib.pyplot as plt
%matplotlib inline

def show_result(gen, val_loader, n_rows=8, n_col=8):
    """Plot grayscale inputs next to the generator's colorized predictions.

    The grid has ``n_rows`` x ``n_col`` axes filled with (input, prediction)
    pairs, i.e. ``n_col // 2`` pairs per row. Only as many batches as needed
    to fill the grid are pulled from ``val_loader``; if the loader holds fewer
    samples, the remaining axes are left blank.

    Args:
        gen: generator mapping a normalised L batch to a normalised ab batch.
        val_loader: iterable yielding ``(L, ab)`` batches; ``ab`` is ignored.
        n_rows: number of grid rows (default 8).
        n_col: number of grid columns (default 8).

    Side effects:
        Puts ``gen`` into eval mode (it is not switched back) and shows a
        matplotlib figure.
    """
    pairs_per_row = n_col // 2
    n_images = n_rows * pairs_per_row

    gen.eval()
    inputs, predictions = [], []
    with torch.no_grad():
        for L, _ in val_loader:
            if len(inputs) >= n_images:
                break
            L = L.to(device)
            pred = gen(L)
            # Undo the dataset's normalisation to get back to Lab units.
            L = (L + 1.) * 50.
            ab = pred * 110.
            Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
            gray = L.cpu()
            for i, img in enumerate(Lab):
                if len(inputs) >= n_images:
                    break
                inputs.append(gray[i].squeeze())
                predictions.append(lab2rgb(img))

    # squeeze=False keeps ``axs`` two-dimensional even for a single row/column.
    _, axs = plt.subplots(n_rows, n_col, figsize=(25, 25), squeeze=False)
    for ax in axs.ravel():
        ax.axis("off")

    idx = 0
    for i in range(n_rows):
        for j in range(0, 2 * pairs_per_row, 2):
            if idx < len(inputs):
                axs[i][j].imshow(inputs[idx], cmap="gray")
                axs[i][j + 1].imshow(predictions[idx])
            idx += 1
    plt.show()

In [10]:
import os
from google.colab import drive

# Mount Google Drive (forcing a remount) and work from the project folder so
# that relative paths used later resolve inside Drive.
drive.mount("/content/drive/", force_remount=True)
os.chdir("/content/drive/My Drive/Colab Notebooks/lessons/project")

KeyboardInterrupt: ignored

In [None]:
from tqdm.notebook import tqdm
import os

# Directory (relative to the current working directory) that receives the
# model checkpoints saved by this notebook; a no-op if it already exists.
os.makedirs("models", exist_ok=True)

def pretrain_generator(gen, train_loader, optimizer, criterion, epochs):
    """Pre-train the generator alone with a supervised reconstruction loss.

    Args:
        gen: generator mapping an L batch to a predicted ab batch.
        train_loader: iterable yielding ``(L, ab)`` batches.
        optimizer: optimizer over ``gen``'s parameters.
        criterion: loss called as ``criterion(pred, ab)``.
        epochs: number of passes over ``train_loader``.

    Prints the sample-weighted mean loss after every epoch.
    """
    for epoch in range(1, epochs + 1):
        # Make sure layers such as BatchNorm are in training mode even if an
        # earlier call (e.g. show_result) left the model in eval mode.
        gen.train()
        loss_meter = AverageMeter()
        for L, ab in tqdm(train_loader):
            L = L.to(device)
            ab = ab.to(device)

            pred = gen(L)
            loss = criterion(pred, ab)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch does not skew the mean.
            loss_meter.update(loss.item(), L.size(0))

        print(f"Epoch: {epoch}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")

# optimizer = torch.optim.Adam(gen.parameters(), lr=1e-4)
# criterion = nn.L1Loss()
# pretrain_generator(gen, pretrain_loader, optimizer, criterion, 20)
# show_result(gen, val_loader)
# torch.save(gen.state_dict(), "models/pretrain_gen.pt")

In [None]:
# Reuse the cached pre-trained generator when it exists; otherwise pre-train
# now and cache the result, so the notebook also runs from a fresh checkout.
pretrain_ckpt = "models/pretrain_gen.pt"
if os.path.exists(pretrain_ckpt):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    gen.load_state_dict(torch.load(pretrain_ckpt, map_location=device))
else:
    pretrain_generator(
        gen,
        pretrain_loader,
        torch.optim.Adam(gen.parameters(), lr=1e-4),
        nn.L1Loss(),
        20,
    )
    torch.save(gen.state_dict(), pretrain_ckpt)

In [None]:
import torch.optim as optim

# Schedule: keep the learning rate constant for `decay_epoch` epochs, then
# decay it linearly so the multiplier reaches 0 at `epochs`.
epochs = 100
decay_epoch = 20


def lambda_func(epoch):
    """Learning-rate multiplier for scheduler step ``epoch``."""
    epochs_past_decay = max(0, epoch - decay_epoch)
    return 1 - epochs_past_decay / (epochs - decay_epoch)


# Adam with identical hyper-parameters for both networks.
beta1 = 0.5
lr = 2e-4
optimizer_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_dis = optim.Adam(dis.parameters(), lr=lr, betas=(beta1, 0.999))

lr_scheduler_gen = optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda=lambda_func)
lr_scheduler_dis = optim.lr_scheduler.LambdaLR(optimizer_dis, lr_lambda=lambda_func)

# L1 for the generator's reconstruction term, BCE for the adversarial terms.
criterion_gen = nn.L1Loss().to(device)
criterion_dis = nn.BCELoss().to(device)

In [None]:
# Adversarial training: each batch updates the discriminator once, then the
# generator once. Every 20 epochs results are plotted and both networks are
# checkpointed.
for epoch in range(1, epochs + 1):
    # show_result() leaves the generator in eval mode, so re-enable training
    # mode at the start of every epoch.
    gen.train()
    dis.train()
    for L, ab in train_loader:
        L = L.to(device)
        ab = ab.to(device)

        # ---- Discriminator step: real ab -> 1, generated ab -> 0 ----
        optimizer_dis.zero_grad()

        real_out = dis(ab)
        real_label = torch.ones_like(real_out)
        real_loss = criterion_dis(real_out, real_label)

        fake = gen(L)
        # detach() keeps this step from back-propagating into the generator.
        fake_out = dis(fake.detach())
        fake_label = torch.zeros_like(fake_out)
        fake_loss = criterion_dis(fake_out, fake_label)

        dis_loss = (real_loss + fake_loss) * 0.5

        dis_loss.backward()
        optimizer_dis.step()

        # ---- Generator step: fool the discriminator + L1 to the target ----
        optimizer_gen.zero_grad()

        out = dis(fake)
        gen_loss1 = criterion_dis(out, real_label)
        gen_loss2 = criterion_gen(fake, ab)

        # The L1 term is weighted 10x relative to the adversarial term.
        gen_loss = gen_loss1 + gen_loss2 * 10

        gen_loss.backward()
        optimizer_gen.step()

        # No manual `del` / torch.cuda.empty_cache() here: the per-batch
        # tensors are released when the names are rebound on the next
        # iteration, and emptying the CUDA cache every step only slows
        # training without giving PyTorch more usable memory.

    if epoch % 20 == 0:
        print(f"Epoch {epoch}/{epochs}")
        show_result(gen, val_loader)
        torch.save(gen.state_dict(), "models/generator{}.pt".format(epoch))
        torch.save(dis.state_dict(), "models/discriminator{}.pt".format(epoch))

    lr_scheduler_gen.step()
    lr_scheduler_dis.step()