Подготовка:

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from tqdm import tqdm
from time import sleep

for i in tqdm(range(10)):
    sleep(0.5)

100%|██████████| 10/10 [00:05<00:00,  1.99it/s]


In [13]:
import os
from PIL import Image

class ColorizationDataset(Dataset):
    def __init__(self, path, transform_x, transform_y):
        self.transform_x = transform_x
        self.transform_y = transform_y
      
        filenames = []
        for root, dirs, files in os.walk(path):
            for file in files:
                if file.endswith('.jpg') or file.endswith('.JPG'):
                    filenames.append(os.path.join(root, file))
                    filenames.append(os.path.join(root, file))
                    filenames.append(os.path.join(root, file))
                    filenames.append(os.path.join(root, file))

        self.images = []
        for filename in tqdm(filenames):
            try:
                with Image.open(filename) as image:
                    self.images.append(image.copy())
            except:
                pass
                #print('Could not load image:', filename)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        Y = self.transform_y(img)
        X = self.transform_x(Y)
        return X, Y

In [7]:
!wget http://sereja.me/f/universum_compressed.tar
!tar xf universum_compressed.tar

--2020-11-11 09:14:45--  http://sereja.me/f/universum_compressed.tar
Resolving sereja.me (sereja.me)... 213.159.215.132
Connecting to sereja.me (sereja.me)|213.159.215.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 72028160 (69M) [application/x-tar]
Saving to: ‘universum_compressed.tar’


2020-11-11 09:15:03 (3.88 MB/s) - ‘universum_compressed.tar’ saved [72028160/72028160]



In [8]:
device = torch.device('cuda:0')

In [9]:
transform_all = transforms.Compose([
    transforms.RandomResizedCrop(128),
    transforms.RandomRotation(degrees=(-45, 45)),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

def to_grayscale(x):
    return (x[0] * 0.299 + x[1] * 0.587 + x[2] * 0.114).view(1, 128, 128)

In [14]:
dataset = ColorizationDataset('universum-photos', to_grayscale, transform_all)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

100%|██████████| 5016/5016 [00:05<00:00, 841.50it/s] 


In [15]:
len(dataset)

4060

In [16]:
class Colorizer(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.preconcat = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3), padding=1), # по дефолту stride = 1, а следовательно размерность не меняется
            nn.BatchNorm2d(32),                                     # меняем мы только количество каналов
            nn.MaxPool2d((2, 2), stride=(2, 2)), # а вот тут мы уменьшаем и высоту, и ширину, в два раза
            nn.LeakyReLU(),

            nn.Conv2d(32, 64, (3, 3), padding=1), 
            nn.BatchNorm2d(64),                                     
            nn.MaxPool2d((2, 2), stride=(2, 2)), 
            nn.LeakyReLU(),

            nn.Conv2d(64, 128, (3, 3), padding=1), 
            nn.BatchNorm2d(128),                                     
            nn.MaxPool2d((2, 2), stride=(2, 2)), 
            nn.LeakyReLU(),

            nn.Conv2d(128, 256, (3, 3), padding=1), 
            nn.BatchNorm2d(256),                                     
            nn.LeakyReLU(),
           
            nn.Upsample(scale_factor=2),         # увеличиваем высоту и ширину в два раза
            nn.Conv2d(256, 128, (3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),

            nn.Upsample(scale_factor=2),         
            nn.Conv2d(128, 64, (3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),

            nn.Upsample(scale_factor=2),         
            nn.Conv2d(64, 64, (3, 3), padding=1),
            nn.LeakyReLU(),
        )
         
        self.postconcat = nn.Sequential(         # эту сетку можно особо не увеличивать - она не должна быть очень умной
            nn.Conv2d(65, 32, (3, 3), padding=1),# подумайте, откуда у автора тут 65
            nn.ReLU(),
            nn.Conv2d(32, 3, (3, 3), padding=1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        h = self.preconcat(x)
        # исходное чб изображение -- просто дополнительным слоем
        h = torch.cat((h, x), 1)
        h = self.postconcat(h)
        return h

In [17]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

In [18]:
class AverageMeter(object):
  def __init__(self):
    self.reset()
  def reset(self):
    self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    self.history = []
  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count
    self.history.append(self.avg)

  def plot(self):
    plt.plot(self.history)
    plt.show()
  
  def truncate(self, n ):
    self.history = self.history[:-(n % len(self.history))]

In [19]:
def show_img(sample):
    img = sample[1]
    img = img / 2 + 0.5  
    npimg = img.numpy()
    npimg = np.clip(npimg, 0., 1.)
    npimg_color = np.transpose(npimg, (1, 2, 0))
    X = sample[0]
    _, H, W = X.shape
    img = np.zeros((H, W, 3))
    img[:,:,1] = img[:,:,2] = img[:,:,0] = X
    f = plt.figure(figsize=(20,20))
    f.add_subplot(1,3, 1)
    plt.imshow(img)
    f.add_subplot(1,3, 2)
    plt.imshow(npimg_color)
    if len(sample) > 2:
      rest = sample[2]
      npimg = rest.numpy()
      npimg = np.clip(npimg, 0., 1.)
      npimg_color = np.transpose(npimg, (1, 2, 0))
      f.add_subplot(1,3, 3)
      plt.imshow(npimg_color)


    
    plt.show(block=True)

In [20]:
num_epochs = 350
lr = 4e-3

model = Colorizer().to(device)
model.apply(weights_init)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
criterion = nn.L1Loss()  
predicting = dataset[0][0].reshape((1,1,128,128)).to(device)

In [21]:
mkdir gif

In [22]:
pip install colorama

Collecting colorama
  Downloading https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl
Installing collected packages: colorama
Successfully installed colorama-0.4.4


In [23]:
from colorama import Fore, Back, Style
y_ = Fore.YELLOW
r_ = Fore.RED
g_ = Fore.GREEN
b_ = Fore.BLUE
m_ = Fore.MAGENTA
c_ = Fore.CYAN
sr_ = Style.RESET_ALL

In [None]:
losses = AverageMeter()
for epoch in range(0,num_epochs):
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)  
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred,y)
        loss.backward()
        optimizer.step()
        losses.update(loss.cpu().item(), x.size(0))
        xx = x[0]
        yy = y[0]
        ppred = pred[0]
        del pred 
        del x 
        del y 
        torch.cuda.empty_cache()
    scheduler.step()
    model.eval()
    tm = model(predicting)[0].detach().cpu()       
    model.train()
    npimg = tm.numpy()
    npimg = np.clip(npimg, 0., 1.)
    npimg_color = np.transpose(npimg, (1, 2, 0))
    Image.fromarray((npimg_color*255).astype(np.uint8)).save(f"gif/{epoch}.png")
    print(f"{b_}Epoch {epoch} \t LR: {scheduler.get_last_lr()}{sr_}")
    if epoch % 5 == 0:
      losses.plot()
      print(f'{g_}Epoch {epoch} of {num_epochs} \n Loss: {losses.avg} \n Iters all: {losses.count}{sr_}')
      show_img((xx.detach().cpu(), yy.detach().cpu(), ppred.detach().cpu()))