<a href="https://colab.research.google.com/github/Oleg007003/ColorizationPictures/blob/main/Colorization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from time import sleep
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Download and unpack the two photo archives that serve as training images.
!wget http://sereja.me/f/universum_compressed.tar
!tar xf universum_compressed.tar
!wget http://vis-www.cs.umass.edu/lfw/lfw.tgz
!tar xf lfw.tgz

# Merge the extracted LFW content into universum-photos/, the directory the dataset is later built from.
# NOTE(review): assumes the first archive unpacks to universum-photos/ and the second to lfw/ -- confirm.
!mv lfw/* universum-photos/


In [None]:
# Install the Kaggle command-line client and fetch the flower image dataset.
# NOTE(review): the download presumably needs Kaggle API credentials to be configured first -- confirm.
!pip install -q kaggle
!kaggle datasets download -d olgabelitskaya/flower-color-images
# NOTE(review): the archive is extracted to the absolute path /content/colors/ while the move below
# uses the relative path colors/ -- the two only agree when the working directory is /content.
!unzip flower-color-images.zip -d '/content/colors/'
!mv colors/* universum-photos/

In [None]:
import os
from PIL import Image

class ColorizationDataset(Dataset):
    """In-memory dataset of (network input, colour target) pairs.

    Every file ending in ``.jpg`` or ``.JPG`` below ``path`` (searched
    recursively) is decoded once in the constructor and kept in memory as an
    RGB PIL image. Files that cannot be decoded are skipped and their paths
    are recorded instead of being silently dropped.

    Args:
        path: Root directory that is walked recursively for image files.
        transform_x: Callable applied to the already transformed target to
            derive the network input.
        transform_y: Callable applied to the stored PIL image to produce the
            target.

    Attributes:
        images: The decoded RGB images.
        skipped_files: Paths of the files that failed to load.
    """

    def __init__(self, path, transform_x, transform_y):
        self.transform_x = transform_x
        self.transform_y = transform_y

        filenames = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(path)
            for file in files
            if file.endswith(('.jpg', '.JPG'))
        ]

        self.images = []
        self.skipped_files = []
        for filename in tqdm(filenames):
            try:
                with Image.open(filename) as image:
                    # convert() decodes the pixels while the file is still open and
                    # returns an independent image; forcing RGB guarantees three
                    # channels even for grayscale or CMYK JPEGs.
                    self.images.append(image.convert('RGB'))
            except Exception:
                # Best effort: a broken file must not abort loading, but unlike a
                # bare ``except`` this lets KeyboardInterrupt through and keeps
                # a record of what was skipped.
                self.skipped_files.append(filename)

    def __len__(self):
        """Return the number of successfully loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(X, Y)``: ``Y = transform_y(image)`` and ``X = transform_x(Y)``.

        The input is derived from the transformed target, so both share the
        same random augmentation.
        """
        img = self.images[idx]
        Y = self.transform_y(img)
        X = self.transform_x(Y)
        return X, Y

In [None]:
# Random augmentation pipeline applied to each stored image to produce the colour target:
# random crop resized to 128x128, random rotation, random horizontal mirror, conversion to a tensor.
transform_all = transforms.Compose([
    transforms.RandomResizedCrop(size=128),
    transforms.RandomRotation(degrees=45),  # a single number means the range (-45, +45)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

def to_grayscale(x):
    """Collapse a three-channel image tensor to a single luminance channel.

    Args:
        x: Tensor indexed as (channel, height, width). Channels 0, 1 and 2 are
            weighted with 0.299, 0.587 and 0.114 (the usual luma weights for
            R, G, B -- assumes RGB channel order).

    Returns:
        Tensor of shape (1, height, width). The spatial size follows the
        input, it is no longer fixed to 128x128.
    """
    luminance = x[0] * 0.299 + x[1] * 0.587 + x[2] * 0.114
    # Re-introduce the channel axis without assuming a particular image size.
    return luminance.unsqueeze(0)

In [None]:
# The target is the augmented colour image; the network input is its grayscale version.
dataset = ColorizationDataset(
    path='universum-photos/',
    transform_x=to_grayscale,
    transform_y=transform_all,
)
# Shuffled mini-batches of 50 samples.
loader = DataLoader(dataset=dataset, batch_size=50, shuffle=True)

In [None]:
class Colorizer(nn.Module):
    """Convolutional encoder-decoder mapping a 1-channel image to a 3-channel image.

    ``preconcat`` halves the resolution three times with max-pooling and
    doubles it three times with upsampling, ending with 64 feature channels.
    Those features are concatenated with the raw input along the channel axis
    (65 channels) and ``postconcat`` reduces them to 3 channels squashed to
    [0, 1] by a sigmoid. Because of the three halvings, height and width must
    be divisible by 8 for the concatenation to line up.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 convolution; a padding of 1 keeps the spatial size unchanged.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        def halve():
            # 2x2 max-pooling with stride 2.
            return nn.MaxPool2d((2, 2), stride=(2, 2))

        def double():
            return nn.Upsample(scale_factor=2)

        # Layer order is significant: it fixes the Sequential indices (and so
        # the state_dict keys) as well as the order of random initialisation.
        encoder = [
            conv3x3(1, 8), nn.BatchNorm2d(8), nn.ReLU(),
            conv3x3(8, 8), nn.BatchNorm2d(8), nn.ReLU(),
            conv3x3(8, 32), nn.BatchNorm2d(32), halve(), nn.ReLU(),
            conv3x3(32, 64), nn.BatchNorm2d(64), halve(), nn.ReLU(),
            conv3x3(64, 128), nn.BatchNorm2d(128), nn.LeakyReLU(),
            conv3x3(128, 256), nn.BatchNorm2d(256), halve(), nn.LeakyReLU(),
        ]
        decoder = [
            double(), conv3x3(256, 128), nn.BatchNorm2d(128), nn.ReLU(),
            double(), conv3x3(128, 64), nn.BatchNorm2d(64), nn.LeakyReLU(),
            double(), conv3x3(64, 64), nn.LeakyReLU(),
        ]
        self.preconcat = nn.Sequential(*encoder, *decoder)

        # 65 input channels = 64 decoded feature maps + the 1-channel input.
        head = [
            conv3x3(65, 32), nn.ReLU(),
            conv3x3(32, 8), nn.LeakyReLU(), nn.BatchNorm2d(8),
            conv3x3(8, 8), nn.LeakyReLU(),
            conv3x3(8, 3), nn.Sigmoid(),
        ]
        self.postconcat = nn.Sequential(*head)

    def forward(self, x):
        """Return the 3-channel prediction for a batch ``x`` of 1-channel images."""
        features = self.preconcat(x)
        stacked = torch.cat([features, x], dim=1)
        return self.postconcat(stacked)

In [None]:
def weights_init(m):
  """Initialise one module in place; intended for ``model.apply(weights_init)``.

  Modules whose class name contains ``Conv`` get weights drawn from
  N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
  drawn from N(1, 0.02) and a zero bias. Every other module, and any matching
  module without a weight tensor (e.g. ``affine=False``), is left untouched.

  Args:
    m: The module to initialise.
  """
  classname = m.__class__.__name__
  if getattr(m, 'weight', None) is None:
    return
  if 'Conv' in classname:
    # In-place initialiser; the un-suffixed ``normal`` is deprecated.
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  elif 'BatchNorm' in classname:
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    if getattr(m, 'bias', None) is not None:
      torch.nn.init.zeros_(m.bias)

In [None]:
# --- Training configuration ---
num_epochs = 14
lr = 5e-3

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network on the target device, then run the custom initialiser over every sub-module.
model = Colorizer().to(device)
model.apply(weights_init)

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.SmoothL1Loss()  # you can experiment with different losses here
criterion2 = nn.MSELoss()

# Multiplies the learning rate by 0.8 every 200 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=200, gamma=0.8)

In [None]:
# Train the model and keep the loss of every batch in `history`.
history = []
model.train()  # make sure BatchNorm layers are in training mode, also when this cell is re-run
for epoch in range(num_epochs):
    for x, y in loader:
        # Move the batch to the training device (x: grayscale inputs, y: colour targets).
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()          # reset the gradients of the previous step
        output = model(x)              # forward pass
        loss = criterion2(output, y)   # MSE between prediction and target
        loss.backward()
        optimizer.step()
        # NOTE(review): `scheduler` is created in the configuration cell but is never
        # stepped, so the learning rate stays constant. Add `scheduler.step()` here
        # if the decay is intended.

        history.append(loss.item())
    print(epoch)

# Show one colorization example: first sample of the last processed batch as
# (input, target, prediction). Skipped when no batch was processed.
# `show_img` is defined in a later cell of this notebook, so that cell has to be
# executed before this line is reached.
if history:
    oX = x[0].detach().cpu()
    oY = y[0].detach().cpu()
    oR = output[0].detach().cpu()
    show_img((oX, oY, oR))

In [None]:
def show_img(sample):
  """Plot a grayscale input, its colour target and, optionally, a prediction side by side.

  Args:
    sample: Sequence ``(X, Y)`` or ``(X, Y, prediction)`` of tensors. ``X`` is
      unpacked as (1, H, W); ``Y`` and ``prediction`` are laid out as
      (channel, H, W). Colour values are clipped to [0, 1] before display and
      are shown as they are: target and prediction get the same treatment, so
      the two panels are directly comparable.
  """
  def _to_hwc(t):
    # (C, H, W) tensor -> (H, W, C) array clipped to the displayable range.
    arr = np.clip(t.detach().cpu().numpy(), 0., 1.)
    return np.transpose(arr, (1, 2, 0))

  X = sample[0]
  _, H, W = X.shape
  # Replicate the single gray channel into R, G and B.
  gray = X.detach().cpu().numpy().reshape(H, W)
  img = np.repeat(gray[:, :, None], 3, axis=2)

  f = plt.figure(figsize=(20,20))
  ax = f.add_subplot(1,3,1)
  ax.imshow(img)
  ax = f.add_subplot(1,3,2)
  ax.imshow(_to_hwc(sample[1]))
  if len(sample) > 2:
    ax = f.add_subplot(1,3,3)
    ax.imshow(_to_hwc(sample[2]))
  plt.show(block=True)

In [None]:
def to_numpy_image(img):
    """Convert a 3-channel image tensor to a channel-last numpy array.

    Args:
        img: Tensor holding exactly one image with 3 channels, laid out as
            (3, H, W) or with extra leading singleton axes such as (1, 3, H, W).

    Returns:
        numpy.ndarray of shape (H, W, 3). The spatial size follows the input,
        it is no longer fixed to 128x128.
    """
    img = img.detach().cpu()
    height, width = img.shape[-2:]
    # reshape() drops leading singleton axes; permute moves channels last.
    return img.reshape(3, height, width).permute(1, 2, 0).numpy()

In [None]:
# Show input, prediction and target side by side for the first (up to) ten dataset samples.
# Inference runs in eval mode so BatchNorm uses its running statistics instead of the
# statistics of a single image; the previous mode is restored afterwards.
was_training = model.training
model.eval()
for t in range(min(10, len(dataset))):
    img_gray, img_true = dataset[t]
    with torch.no_grad():  # no gradients are needed for visualisation
        img_pred = model(img_gray.to(device).unsqueeze(0))
    img_pred = to_numpy_image(img_pred)
    # img_pred is now a numpy ndarray of shape (H, W, 3)

    fig, (ax_gray, ax_pred, ax_true) = plt.subplots(1, 3, figsize=(10, 10))

    ax_gray.axis('off')
    # 'gray' maps 0 to black and 1 to white; 'Greys' is the reverse and would show a
    # negative. Passing cmap here also leaves the global default colormap untouched.
    ax_gray.imshow(img_gray[0].cpu().numpy(), cmap='gray', vmin=0., vmax=1.)

    ax_pred.axis('off')
    ax_pred.imshow(img_pred)

    ax_true.axis('off')
    ax_true.imshow(to_numpy_image(img_true))

    plt.show()
model.train(was_training)