<a href="https://colab.research.google.com/github/ElenaBianchini/ColoringGrayscaleImages/blob/main/ProgettoLabIA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Progetto di Laboratorio di Intelligenza Artificiale e Grafica Interattiva**

# Import

In [2]:
%matplotlib inline

In [17]:
# Mount Google Drive under /content/drive; the dataset paths defined below live on that mount.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision 
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage import color
from PIL import Image

# Impostazione dei parametri

In [54]:
# Training configuration.
# Number of passes of the training loop.
num_epochs = 2
# Mini-batch size shared by every DataLoader in this notebook.
batch_size = 32
# Learning rate handed to the Adam optimizer.
learning_rate = 1e-2
# True when PyTorch can see a CUDA device.
use_gpu = torch.cuda.is_available()

In [6]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if use_gpu else "cpu")

In [7]:
device

device(type='cuda', index=0)

# Paths

In [8]:
# Root folder of the COCO copy stored on the mounted Google Drive.
root_path = "/content/drive/MyDrive/COCO"

In [9]:
# One folder per dataset split, each kept with its trailing slash.
train_folder = f"{root_path}/train2014/"
val_folder = f"{root_path}/val2014/"
test_folder = f"{root_path}/test2014/"

# Dataset

In [10]:
def loadImagesName(dir_path, num):
    """Collect the paths of at most ``num`` entries of ``dir_path``.

    Entries are taken in sorted name order. ``os.listdir`` returns names in
    arbitrary order, so without sorting the selected subset could change from
    one run (or machine) to the next and the experiment would not be
    reproducible.

    Args:
        dir_path: Directory to scan.
        num: Maximum number of paths to return; values <= 0 give an empty list.

    Returns:
        A list of ``os.path.join(dir_path, name)`` strings, at most ``num`` long.
    """
    images_list = []
    selected_names = sorted(os.listdir(dir_path))[:max(num, 0)]
    for count, image_name in enumerate(selected_names, start=1):
        images_list.append(os.path.join(dir_path, image_name))
        # "\r" rewrites the same console line so the counter acts as a progress display.
        print("\rImage num: {}".format(count), end='')

    return images_list


In [43]:
# Collect up to 5000 image paths from the training folder.
train_list = loadImagesName(train_folder, 5000)

Image num: 5000

In [44]:
# Collect up to 2500 image paths from the validation folder.
val_list = loadImagesName(val_folder, 2500)

Image num: 1Image num: 2Image num: 3Image num: 4Image num: 5Image num: 6Image num: 7Image num: 8Image num: 9Image num: 10Image num: 11Image num: 12Image num: 13Image num: 14Image num: 15Image num: 16Image num: 17Image num: 18Image num: 19Image num: 20Image num: 21Image num: 22Image num: 23Image num: 24Image num: 25Image num: 26Image num: 27Image num: 28Image num: 29Image num: 30Image num: 31Image num: 32Image num: 33Image num: 34Image num: 35Image num: 36Image num: 37Image num: 38Image num: 39Image num: 40Image num: 41Image num: 42Image num: 43Image num: 44Image num: 45Image num: 46Image num: 47Image num: 48Image num: 49Image num: 50Image num: 51Image num: 52Image num: 53Image num: 54Image num: 55Image num: 56Image num: 57Image num: 58Image num: 59Image num: 60Image num: 61Image num: 62Image num: 63Image num: 64Image num: 65Image num: 66Image num: 67Image num: 68Image num: 69Image num: 70Image num: 71Image num: 72

In [45]:
# Collect up to 1875 image paths from the test folder.
test_list = loadImagesName(test_folder, 1875)

Image num: 1Image num: 2Image num: 3Image num: 4Image num: 5Image num: 6Image num: 7Image num: 8Image num: 9Image num: 10Image num: 11Image num: 12Image num: 13Image num: 14Image num: 15Image num: 16Image num: 17Image num: 18Image num: 19Image num: 20Image num: 21Image num: 22Image num: 23Image num: 24Image num: 25Image num: 26Image num: 27Image num: 28Image num: 29Image num: 30Image num: 31Image num: 32Image num: 33Image num: 34Image num: 35Image num: 36Image num: 37Image num: 38Image num: 39Image num: 40Image num: 41Image num: 42Image num: 43Image num: 44Image num: 45Image num: 46Image num: 47Image num: 48Image num: 49Image num: 50Image num: 51Image num: 52Image num: 53Image num: 54Image num: 55Image num: 56Image num: 57Image num: 58Image num: 59Image num: 60Image num: 61Image num: 62Image num: 63Image num: 64Image num: 65Image num: 66Image num: 67Image num: 68Image num: 69Image num: 70Image num: 71Image num: 72

In [19]:
class ImageDataset(torch.utils.data.Dataset):
  """Dataset that turns each image path into (RGB image, ab channels, grayscale image) tensors."""

  def __init__(self, images_list):
    """Store the image paths and build the Resize(256) + CenterCrop(224) preprocessing."""
    self.images_list = images_list
    preprocessing_steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    self.img_transform = transforms.Compose(preprocessing_steps)

  def __len__(self):
    """Number of image paths held by the dataset."""
    return len(self.images_list)

  def __getitem__(self, idx):
    """Load the image at position ``idx``.

    Returns:
      A tuple ``(img, img_ab, img_gray)`` of float tensors: the RGB image in
      channel-first layout, the two ab channels after the ``(x + 128) / 255``
      rescaling, and the ``rgb2gray`` output with a leading channel axis.
    """
    pil_rgb = Image.open(self.images_list[idx]).convert('RGB')
    rgb = np.asarray(self.img_transform(pil_rgb))

    # rgb2lab returns a numpy array. The ab channels of Lab range from -128 to 127,
    # so shifting by 128 and dividing by 255 brings them into [0, 1].
    lab_scaled = (color.rgb2lab(rgb) + 128) / 255
    # Keep only the a and b planes: shape [224, 224, 2] before the transpose.
    ab_planes = lab_scaled[:, :, 1:3]
    img_ab = torch.from_numpy(ab_planes.transpose((2, 0, 1))).float()

    # Single-channel grayscale version, with a channel axis added in front.
    gray_plane = color.rgb2gray(rgb)
    img_gray = torch.from_numpy(gray_plane).unsqueeze(0).float()

    # Original RGB image, channel-first.
    img = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
    return img, img_ab, img_gray


In [48]:
# Training data: reshuffled at every epoch.
train_dataset = ImageDataset(images_list=train_list)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [49]:
# Validation data is read in a fixed order (shuffle=False). The loss does not depend on
# the order, and keeping the first batch identical across epochs means the sample image
# saved by validation() shows the same picture every epoch, so results can be compared.
val_dataset = ImageDataset(val_list)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [50]:
# Test data is read in a fixed order (shuffle=False) so evaluation runs are repeatable
# and a given batch index always refers to the same images.
test_dataset = ImageDataset(test_list)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Regressione

## Modello

In [23]:
class ColorizationRNet(nn.Module):
  """Regression colorizer: the first ResNet-18 stages on a 1-channel input, then a conv/upsample decoder with 2 output channels.

  ``input_size`` is accepted by the constructor but not used by the code.
  """

  def __init__(self, input_size = 128):
    super().__init__()

    # ResNet-18 is used to extract features from the images.
    backbone = torchvision.models.resnet18()
    # Sum the first convolution's kernels over the input-channel axis so the
    # layer accepts a single-channel input.
    single_channel_kernels = backbone.conv1.weight.sum(dim=1).unsqueeze(1)
    backbone.conv1.weight = nn.Parameter(single_channel_kernels)
    # Keep only the first six child modules as the feature extractor.
    feature_layers = list(backbone.children())[0:6]
    self.midlevel_resnet = nn.Sequential(*feature_layers)

    def conv_bn_relu(in_channels, out_channels):
      """3x3 convolution (stride 1, padding 1) followed by batch norm and ReLU."""
      return [
          nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(),
      ]

    # Decoder: layers are created in the same order as their positions in the Sequential.
    self.deconv = nn.Sequential(
        *conv_bn_relu(128, 128),
        nn.Upsample(scale_factor=2),
        *conv_bn_relu(128, 64),
        *conv_bn_relu(64, 64),
        nn.Upsample(scale_factor=2),
        *conv_bn_relu(64, 32),
        nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
        nn.Upsample(scale_factor=2),
    )

  def forward(self, x):
    """Extract features from ``x`` and decode them into the 2-channel prediction."""
    features = self.midlevel_resnet(x)
    return self.deconv(features)

In [24]:
# Build the regression network and place it on the selected device.
reg_net = ColorizationRNet().to(device)

## Funzione di costo e di ottimizzazione

In [25]:
# Mean-squared error as the regression loss, Adam over all parameters of the network.
criterion = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(reg_net.parameters(), lr=learning_rate)

## Train

In [46]:
def train(epoch, loss_avg):
  """Run one training epoch of ``reg_net`` and record its mean batch loss.

  Uses the notebook-level ``reg_net``, ``train_dataloader``, ``criterion``,
  ``optimizer`` and ``device``.

  Args:
    epoch: Epoch number, used only in the progress messages.
    loss_avg: List of per-epoch training losses; one value is appended in place.
  """
  # Put the model in training mode.
  reg_net.train()

  print('\nStarting training epoch {}\n'.format(epoch))

  num_batches = len(train_dataloader)
  running_loss = 0.0

  # The RGB image returned by the dataset is not needed here, so it is not moved to the device.
  for batch_idx, (_, img_ab, img_gray) in enumerate(train_dataloader):
    img_ab = img_ab.to(device)
    img_gray = img_gray.to(device)

    # Predict the ab channels from the grayscale image.
    predicted = reg_net(img_gray)

    # L2 error between predicted and true colors.
    loss = criterion(predicted, img_ab)

    # Backpropagation and weight update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

    if batch_idx % 12 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]  Loss: {:.6f}'.format(
            epoch, batch_idx*len(img_gray), len(train_dataloader.dataset), 100. * batch_idx / num_batches, loss.item()))

  # Each loss.item() is already a mean over the batch (nn.MSELoss default), so the
  # epoch average is the sum divided by the number of batches, not by batch_size.
  loss_avg.append(running_loss / max(num_batches, 1))
  print('\nFinished training epoch {}\n'.format(epoch))


## Validation

In [63]:
def validation(epoch, val_loss_avg):
  """Evaluate ``reg_net`` on the validation set and save one sample colorization.

  Uses the notebook-level ``reg_net``, ``val_dataloader``, ``criterion`` and
  ``device``. For the first batch, three JPEG files are written for its first
  image: the grayscale input, the recolored prediction and the real image.

  Args:
    epoch: Epoch number, used in the saved file names.
    val_loss_avg: List of per-epoch validation losses; one value is appended in place.

  Returns:
    The mean batch loss of this validation pass (the value just appended).
  """
  # Put the model in evaluation mode.
  reg_net.eval()

  gray_path = '/content/drive/MyDrive/ProgettoLab/immagini/grayscale_{}.jpg'.format(epoch)
  recolored_path = '/content/drive/MyDrive/ProgettoLab/immagini/recolored_{}.jpg'.format(epoch)
  real_path = '/content/drive/MyDrive/ProgettoLab/immagini/real_{}.jpg'.format(epoch)

  num_batches = len(val_dataloader)
  running_loss = 0.0

  # Evaluation needs no gradients; disabling them avoids building the autograd graph.
  with torch.no_grad():
    for batch_idx, (img, img_ab, img_gray) in enumerate(val_dataloader):
      img_ab = img_ab.to(device)
      img_gray = img_gray.to(device)

      # Predict the ab channels from the grayscale image.
      predicted = reg_net(img_gray)

      # L2 error between predicted and true colors.
      loss = criterion(predicted, img_ab)
      running_loss += loss.item()

      # Save the first picture of every epoch.
      if batch_idx == 0:
        os.makedirs(os.path.dirname(gray_path), exist_ok=True)

        input_gray = img_gray[0].cpu().squeeze().numpy()

        # Stack grayscale + predicted ab into an H x W x 3 array and undo the
        # scaling: first channel * 100, ab channels * 255 - 128.
        output_color = torch.cat((img_gray[0].cpu(), predicted[0].cpu()), 0).numpy()
        output_color = output_color.transpose((1, 2, 0))
        output_color[:, :, 0:1] = output_color[:, :, 0:1] * 100
        output_color[:, :, 1:3] = output_color[:, :, 1:3] * 255 - 128
        output_color = color.lab2rgb(output_color.astype(np.float64))

        # The dataset returns the RGB image as floats holding 0..255 values; imsave
        # rejects float RGB data outside 0..1, so convert back to uint8.
        real_color = img[0].numpy().transpose((1, 2, 0)).astype(np.uint8)

        # A 2-D array is colormapped by imsave: use a gray map with a fixed 0..1
        # range so the file really is the grayscale input.
        plt.imsave(arr=input_gray, fname=gray_path, cmap='gray', vmin=0, vmax=1)
        plt.imsave(arr=output_color, fname=recolored_path)
        plt.imsave(arr=real_color, fname=real_path)

  # Each loss.item() is already a mean over the batch, so the epoch average is the
  # sum divided by the number of batches, not by batch_size.
  val_loss_avg.append(running_loss / max(num_batches, 1))
  print('\nValidation set: Average loss: {:.4f}\n'.format(val_loss_avg[-1]))
  return val_loss_avg[-1]


## Allenamento

In [None]:
# Best validation loss seen so far. Starting from +inf guarantees that the first
# epoch always produces a checkpoint, whatever the scale of the loss.
best_losses = float('inf')
train_loss_avg = []
val_loss_avg = []

checkpoint_pattern = '/content/drive/MyDrive/ProgettoLab/checkpoints/model-epoch-{}-losses-{:.3f}.pth'

for epoch in range(1, num_epochs+1):
  train(epoch, train_loss_avg)
  losses = validation(epoch, val_loss_avg)

  # Save the weights whenever the validation loss improves.
  if losses < best_losses:
    best_losses = losses
    os.makedirs(os.path.dirname(checkpoint_pattern), exist_ok=True)
    # epoch is already 1-based, so it goes into the file name unchanged.
    torch.save(reg_net.state_dict(), checkpoint_pattern.format(epoch, losses))



Starting training epoch 1



## Grafico della curva di apprendimento

In [None]:
# Training loss: one point per entry of train_loss_avg.
plt.ion()
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training loss')
plt.show()

In [None]:
# Validation loss: one point per entry of val_loss_avg.
plt.ion()
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(val_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Validation loss')
plt.show()

## Risultato su un'immagine del Test Set

# Classificazione 

## Modello

In [None]:
class ColorizationCNet(nn.Module):
  def __init__(self):
    super(ColorizationCNet, self).__init__()

    self.network = nn.Sequential(
        nn.Conv2d(1,64,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(64,64,kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(64),

        nn.Conv2d(64,128,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(128,128,kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(128),

        nn.Conv2d(128,256,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(256,256,kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(256),

        nn.Conv2d(256,512,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(512,512,kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(512),

        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=2, dilatation=2),
        nn.ReLU(),
        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=2, dilatation=2),
        nn.ReLU(),
        nn.BatchNorm2d(512),

        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=2, dilatation=2),
        nn.ReLU(),
        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=2, dilatation=2),
        nn.ReLU(),
        nn.BatchNorm2d(512),

        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(512,512,kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(512),

        nn.ConvTransposed2d(512, 256, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(256,256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),

        nn.Conv2d(256, 313, kenel_size=1, stride=1, padding=0),
        nn.Softmax(dim=1),
        nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1),
        nn.Upsample(scale_factor=4)
    )


  def forward(self, x):
    output = self.network(x)
    return output

In [None]:
# Build the classification network and place it on the selected device,
# as is done for the regression network.
clas_net = ColorizationCNet().to(device)

## Funzione di costo e di ottimizzazione

## Funzione per ottenere le label

In [None]:
def get_labels(batch_ab, knn_model=None, sigma=None, num_bins=313, num_neighbors=5):
  """Turn a batch of ab images into soft per-pixel distributions over color bins.

  For every pixel the ``num_neighbors`` nearest bins are looked up with
  ``knn_model.kneighbors`` and weighted with a Gaussian of the returned
  distances; the weights of each pixel are normalized to sum to 1.

  Args:
    batch_ab: Iterable of ab images. Each one is assumed to be channel-last,
      i.e. (height, width, 2) — TODO confirm against the caller, since the
      dataset in this notebook yields channel-first tensors.
    knn_model: Object exposing ``kneighbors(points, k)`` and returning
      ``(distances, indices)``. Defaults to the notebook-level ``knn``.
    sigma: Width of the Gaussian weighting. Defaults to the notebook-level ``SIGMA``.
    num_bins: Number of color bins (size of the last axis of each label).
    num_neighbors: Number of nearest bins that receive a non-zero weight.

  Returns:
    A list with one ``(height, width, num_bins)`` numpy array per input image.
  """
  # Fall back to the notebook-level objects the function originally relied on.
  if knn_model is None:
    knn_model = knn
  if sigma is None:
    sigma = SIGMA

  labels = []
  for img in batch_ab:
    img = np.asarray(img)
    # Sizes come from the image itself instead of hard-coded 224x224 constants.
    height, width = img.shape[:2]
    pixels = np.reshape(img, (-1, 2))
    n_pixels = pixels.shape[0]

    distances, indices = knn_model.kneighbors(pixels, num_neighbors)
    weights = np.exp(-distances ** 2 / (2 * sigma ** 2))
    weights = weights / np.sum(weights, axis=1)[:, np.newaxis]

    # Column vector of row indices so that row i of `weights` is written to row i of `label`.
    pixel_idx = np.arange(n_pixels)[:, np.newaxis]
    label = np.zeros((n_pixels, num_bins))
    label[pixel_idx, indices] = weights

    labels.append(np.reshape(label, (height, width, num_bins)))
  return labels

## Train

## Grafico della curva di apprendimento

## Valutazione sul Test Set

## Risultato su un'immagine del Test Set