ATTENTION: If you are testing this notebook, make sure the dataset directories (and the other hard-coded file paths) are changed to the appropriate locations for your environment.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [None]:
class Generator(nn.Module):
  """U-Net generator in the style of pix2pix, used here for colorization.

  Seven stride-2 encoder blocks plus a stride-2 bottleneck halve the spatial
  size eight times. Eight stride-2 transposed-conv stages then double it back.
  Each decoder stage after the first receives its predecessor's output
  concatenated (channel-wise) with the mirrored encoder output. Because of the
  eight halvings and the skip concatenations, the input height and width must
  be multiples of 256 (e.g. 256x256).

  Args:
    num_input_channels: channels of the input image.
    num_output_channels: channels of the generated image.
    num_features: base channel width; deeper layers use multiples of it.
  """
  def __init__(self, num_input_channels = 1, num_output_channels = 3, num_features = 64):
    super(Generator, self).__init__()

    # Encoder: every block halves H and W. BatchNorm is skipped on the first block.
    self.enc1 = self._encoder(num_input_channels, num_features, kernel_size = 4, stride = 2, padding = 1, first_layer = True)
    self.enc2 = self._encoder(num_features, num_features * 2, kernel_size = 4, stride = 2, padding = 1, first_layer = False)
    self.enc3 = self._encoder(num_features * 2, num_features * 4, kernel_size = 4, stride = 2, padding = 1, first_layer = False)
    self.enc4 = self._encoder(num_features * 4, num_features * 8, kernel_size = 4, stride = 2, padding = 1, first_layer = False)
    self.enc5 = self._encoder(num_features * 8, num_features * 8, kernel_size = 4, stride = 2, padding = 1, first_layer = False)
    self.enc6 = self._encoder(num_features * 8, num_features * 8, kernel_size = 4, stride = 2, padding = 1, first_layer = False)
    self.enc7 = self._encoder(num_features * 8, num_features * 8, kernel_size = 4, stride = 2, padding = 1, first_layer = False)

    # Bottleneck: one more stride-2 conv (no BatchNorm) followed by a plain ReLU.
    self.bottleneck = nn.Conv2d(num_features * 8, num_features * 8, kernel_size = 4, stride = 2, padding = 1, padding_mode = "reflect")
    self.relu = nn.ReLU()

    # Decoder: each stage doubles H and W. Stages after the first take
    # 2x channels because of the skip concatenation. Dropout on the first three.
    self.dec1 = self._decoder(num_features * 8, num_features * 8, kernel_size = 4, stride = 2, padding = 1, use_dropout = True)
    self.dec2 = self._decoder(num_features * 8 * 2, num_features * 8, kernel_size = 4, stride = 2, padding = 1, use_dropout = True)
    self.dec3 = self._decoder(num_features * 8 * 2, num_features * 8, kernel_size = 4, stride = 2, padding = 1, use_dropout = True)
    self.dec4 = self._decoder(num_features * 8 * 2, num_features * 8, kernel_size = 4, stride = 2, padding = 1, use_dropout = False)
    self.dec5 = self._decoder(num_features * 8 * 2, num_features * 4, kernel_size = 4, stride = 2, padding = 1, use_dropout = False)
    self.dec6 = self._decoder(num_features * 4 * 2, num_features * 2, kernel_size = 4, stride = 2, padding = 1, use_dropout = False)
    self.dec7 = self._decoder(num_features * 2 * 2, num_features, kernel_size = 4, stride = 2, padding = 1, use_dropout = False)
    self.dec8 = nn.ConvTranspose2d(num_features * 2, num_output_channels, kernel_size = 4, stride = 2, padding = 1)
    # Tanh maps the output to [-1, 1]. That is the same range as the images in this
    # notebook after Normalize(mean=0.5, std=0.5), so no extra rescaling is needed.
    self.tanh = nn.Tanh()

  def _encoder(self, in_channels, out_channels, kernel_size, stride, padding, first_layer = True):
    """Build a Conv2d(reflect) -> [BatchNorm2d] -> LeakyReLU(0.2) block.

    BatchNorm is omitted when ``first_layer`` is truthy. The Sequential indices
    (0 conv, 1 BN, 2 activation) are unchanged, so saved state dicts still load.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = False, padding_mode = "reflect")]
    if not first_layer:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)

  def _decoder(self, in_channels, out_channels, kernel_size, stride, padding, use_dropout = True):
    """Build a ConvTranspose2d -> BatchNorm2d -> [Dropout2d(0.5)] -> ReLU block.

    The decoder uses a plain (non-leaky, not in-place) ReLU. nn.ReLU's first
    positional argument is ``inplace``, not a slope, so nothing is passed to it.
    """
    layers = [
      nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
      nn.BatchNorm2d(out_channels),
    ]
    if use_dropout:
      layers.append(nn.Dropout2d(0.5))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

  def forward(self, x):
    """Map an input batch (N, num_input_channels, H, W) to (N, num_output_channels, H, W) in [-1, 1]."""
    c1 = self.enc1(x)
    c2 = self.enc2(c1)
    c3 = self.enc3(c2)
    c4 = self.enc4(c3)
    c5 = self.enc5(c4)
    c6 = self.enc6(c5)
    c7 = self.enc7(c6)
    bottle = self.relu(self.bottleneck(c7))
    # Each decoder stage consumes the previous stage concatenated with the mirrored encoder output.
    d1 = self.dec1(bottle)
    d2 = self.dec2(torch.cat([d1, c7], dim = 1))
    d3 = self.dec3(torch.cat([d2, c6], dim = 1))
    d4 = self.dec4(torch.cat([d3, c5], dim = 1))
    d5 = self.dec5(torch.cat([d4, c4], dim = 1))
    d6 = self.dec6(torch.cat([d5, c3], dim = 1))
    d7 = self.dec7(torch.cat([d6, c2], dim = 1))
    d8 = self.dec8(torch.cat([d7, c1], dim = 1))
    final = self.tanh(d8)

    return final
    

In [None]:
def test():
  """Smoke-test the Generator: push one random 3x256x256 image through it and print the output shape."""
  sample = torch.randn((1, 3, 256, 256))
  generator = Generator(num_input_channels = 3, num_output_channels = 3, num_features = 64)
  output = generator(sample)
  print(output.shape)


test()

In [None]:
# 286 x 286 discriminator architecture:
class Discriminator(nn.Module):
  """Conditional PatchGAN-style discriminator.

  ``forward(x, y)`` concatenates the condition image ``x`` and the candidate
  image ``y`` along the channel axis. It returns a single-channel map of
  per-patch scores.

  Args:
    input_channels: channels of each of ``x`` and ``y`` (the first conv sees twice this).
    features: channel widths of the conv blocks. Every block except the last
      listed one downsamples by 2. Two extra stride-1 blocks at ``features[-1]``
      width follow.
  """
  def __init__(self, input_channels = 3, features = (64, 128, 256, 512)):
    super(Discriminator, self).__init__()
    features = list(features)
    # x and y are concatenated channel-wise in forward, hence input_channels * 2.
    self.l1 = self._block(input_channels * 2, features[0], kernel_size = 4, stride = 2, padding = 1, first_layer = True)
    layers = []
    in_channels = features[0]
    last_index = len(features) - 1
    for index, feature in enumerate(features[1:], start = 1):
      # Only the last entry of `features` keeps stride 1. Decided by position
      # rather than value, so repeated widths are handled correctly.
      stride = 1 if index == last_index else 2
      layers.append(self._block(in_channels, feature, kernel_size = 4, stride = stride, padding = 1, first_layer = False))
      in_channels = feature
    # Two additional stride-1 blocks at the final width (was hard-coded to 512).
    for _ in range(2):
      layers.append(self._block(features[-1], features[-1], kernel_size = 4, stride = 1, padding = 1, first_layer = False))
    self.model = nn.Sequential(*layers)
    # NOTE(review): kernel_size=1 with padding=1 adds a zero border, so the
    # outermost ring of the score map sees only the bias. Kept unchanged so
    # checkpoints saved with this architecture still load.
    self.final = nn.Conv2d(features[-1], 1, kernel_size = 1, stride = 1, padding = 1)
    self.sigmoid = nn.Sigmoid()

  def _block(self, in_channels, out_channels, kernel_size, stride, padding, first_layer = True):
    """Standard conv block: Conv2d(reflect) -> [BatchNorm2d] -> LeakyReLU(0.2).

    BatchNorm is skipped when ``first_layer`` is truthy. The Sequential indices
    match the previous layout, so saved state dicts still load.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, padding_mode = "reflect")]
    if not first_layer:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)

  def forward(self, x, y, apply_sigmoid = False):
    """Score (x, y) pairs patch-wise.

    By default this returns raw logits. The training loop scores the outputs
    with ``nn.BCEWithLogitsLoss``, which applies the sigmoid itself. Applying
    it here as well would squash every score into (0.5, 0.73) and cripple the
    adversarial loss. Pass ``apply_sigmoid=True`` to get probabilities instead.
    """
    x = torch.cat([x, y], dim = 1)
    x = self.l1(x)
    x = self.model(x)
    x = self.final(x)
    if apply_sigmoid:
      x = self.sigmoid(x)

    return x



In [None]:
def test():
  """Smoke-test the Discriminator on a random 286x286 (input, target) pair and print the score-map shape."""
  x, y = torch.randn((1, 3, 286, 286)), torch.randn((1, 3, 286, 286))
  discriminator = Discriminator()
  scores = discriminator(x, y)
  print(scores.shape)

In [None]:
# Script entry point. `test` was most recently redefined as the Discriminator
# smoke test, so that is the one that runs here.
if __name__ == "__main__":
  test()

In [None]:
import os
os.environ['KAGGLE_USERNAME'] = 'atasgeld'
os.environ['KAGGLE_KEY'] = '531dce420032bafdfc6a47e0ce5e1fcf'

!kaggle datasets download -d ashwingupta3012/human-faces

In [None]:
from PIL import Image
from torch.utils.data import Dataset
import zipfile

# Zipped image folder on the mounted Google Drive.
dataset_path = "/content/drive/MyDrive/FINALDATA/higher.zip"

# Unpack every member of the archive into the current working directory.
with zipfile.ZipFile(dataset_path, mode = 'r') as zip_file:
  zip_file.extractall(path = '.')

In [None]:
# Preprocessing pipeline (list of torchvision transforms) for ImageDataset_color:
# resize to 256x256, convert to a float tensor in [0, 1], then map each RGB
# channel to [-1, 1] via (v - 0.5) / 0.5 to match the generator's Tanh range.
transform = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Explicit per-channel mean/std for 3 channels. Normalize's third positional
    # parameter is `inplace`, so no third tuple is passed.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [None]:
from torchvision.utils import save_image
import torchvision.utils as vutils

def save_some_examples(gen, val_loader, epoch, folder, num_examples = 8):
    """Save a grid of random validation inputs and their generator outputs.

    Draws ``num_examples`` random (x, y) items from ``val_loader.dataset`` and
    runs ``gen`` on each x. Writes ``<folder>/examples_<epoch>.png`` with the
    inputs grid stacked above the outputs grid. ``folder`` is created if
    missing. The generator is put in eval mode while sampling and switched back
    to train mode afterwards, even if saving fails.

    Args:
        gen: generator module.
        val_loader: loader whose ``.dataset`` supports ``len`` and indexing,
            returning (input, target) tensors.
        epoch: used in the output file name.
        folder: output directory.
        num_examples: number of random samples to draw.
    """
    os.makedirs(folder, exist_ok = True)
    # Sample on whatever device the generator's weights live on.
    device = next(gen.parameters()).device
    dataset = val_loader.dataset
    gen.eval()
    try:
        with torch.no_grad():
            inputs = []
            outputs = []
            for _ in range(num_examples):
                x, _target = dataset[np.random.randint(len(dataset))]
                x = x.to(device)
                y_fake = gen(x.unsqueeze(0)).squeeze(0)
                # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
                inputs.append(x * 0.5 + 0.5)
                outputs.append(y_fake * 0.5 + 0.5)
            inputs_grid = vutils.make_grid(inputs, nrow = 4, normalize = True, scale_each = True)
            outputs_grid = vutils.make_grid(outputs, nrow = 4, normalize = True, scale_each = True)
            # dim=1 is height for a (C, H, W) grid: inputs on top, outputs below.
            grid = torch.cat([inputs_grid, outputs_grid], dim = 1)
            vutils.save_image(grid, os.path.join(folder, f"examples_{epoch}.png"))
    finally:
        gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state to ``filename`` with ``torch.save``.

    The saved object is a dict with the keys "state_dict" (model weights) and
    "optimizer" (optimizer state), which is the layout load_checkpoint reads.
    """
    print("=> Saving checkpoint")
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model/optimizer state written by save_checkpoint, then force a learning rate.

    Args:
        checkpoint_file: path to a file holding a dict with "state_dict" and
            "optimizer" entries.
        model: module whose weights are replaced.
        optimizer: optimizer whose state is replaced.
        lr: learning rate written into every optimizer param group after loading.
    """
    print("=> Loading checkpoint")
    # Map the stored tensors onto the model's own device. This avoids relying on
    # a global `device` variable that is only defined in a later cell.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location = map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The optimizer state carries the learning rate of the old run; override it
    # so training resumes with `lr` rather than whatever was saved.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [None]:
import glob
import cv2
from PIL import Image

class ImageDataset_color(Dataset):
    """Paired (grayscale, color) images built from every file in ``root``.

    Each item is ``(gray, color)``. The color image is read with OpenCV and
    converted BGR -> RGB. The grayscale image is derived from it and replicated
    to 3 channels. Both are flipped horizontally together with probability 0.5
    and then passed through the same transform pipeline.

    Args:
        root: directory whose files (glob pattern ``*.*``) are loaded in sorted order.
        transforms_: list of torchvision transforms composed into one pipeline.
            ``None`` falls back to ``[transforms.ToTensor()]``.
        mode: accepted for API compatibility; not used.
    """

    def __init__(self, root, transforms_=None, mode="train"):
        if transforms_ is None:
            transforms_ = [transforms.ToTensor()]
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root, "*.*")))
        print(f"Number of files found: {len(self.files)}")
        print(self.files[:10])

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            raise OSError(f"Could not read image file: {path}")
        color = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        gray = cv2.cvtColor(cv2.cvtColor(color, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
        # Same horizontal flip for both images so the pair stays aligned.
        if np.random.random() < 0.5:
            color = np.ascontiguousarray(color[:, ::-1, :])
            gray = np.ascontiguousarray(gray[:, ::-1, :])

        img_A = self.transform(Image.fromarray(color, "RGB"))
        img_B = self.transform(Image.fromarray(gray, "RGB"))

        # (network input, target): grayscale first, color second.
        return img_B, img_A

    def __len__(self):
        return len(self.files)

In [None]:
# Hyperparameter setup:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation. lr = 2e-4 (with Adam betas (0.5, 0.999) in the training setup)
# matches the settings from the pix2pix paper.
learning_rate = 2e-4
batch_size = 4
num_workers = 2
num_epochs = 100
l1_lambda = 100  # weight of the L1 reconstruction term in the generator loss

# Image geometry.
image_size = 256
channels_img = 3

# Checkpointing.
load_model = True
save_model = True
checkpoint_disc = "/content/drive/MyDrive/FINALDATA/disc.pth.tar"
checkpoint_gen = "/content/drive/MyDrive/FINALDATA/gen.pth.tar"

# Albumentations pipelines. Note: these are not referenced elsewhere in this
# notebook; ImageDataset_color uses the torchvision `transform` list instead.
both_transform = A.Compose(
    [A.Resize(width=256, height=256), A.HorizontalFlip(p=0.5)],
    additional_targets={"image0": "image"},
)

transform_only_input = A.Compose([
    A.ColorJitter(p=0.2),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])

transform_only_mask = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])

In [None]:
train_dataset = ImageDataset_color(root = '/content/framesito', transforms_ = transform)
val_dataset = ImageDataset_color(root = '/content/higher', transforms_ = transform)

# Show a few (input, target) pairs from the training set. The tensors are
# normalized to [-1, 1], and imshow clips float RGB data outside [0, 1].
# Undo the normalization and clamp before displaying so the colors are correct.
for i in range(3):
    input_image, target_image = train_dataset[i]
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, image, title in zip(axes, (input_image, target_image), ('Input Image', 'Target Image')):
        ax.imshow((image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy())
        ax.set_title(title)
    plt.show()

In [None]:
# Training:

disc = Discriminator(input_channels = 3).to(device)
gen = Generator(num_input_channels = 3, num_output_channels = 3).to(device)
# Adam with beta1 = 0.5 for both networks.
opt_disc = optim.Adam(disc.parameters(), lr = learning_rate, betas = (0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr = learning_rate, betas = (0.5, 0.999))
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

# Resume only when both checkpoint files exist, so a first run does not crash.
if load_model:
  if os.path.exists(checkpoint_gen) and os.path.exists(checkpoint_disc):
    load_checkpoint(checkpoint_gen, gen, opt_gen, learning_rate)
    load_checkpoint(checkpoint_disc, disc, opt_disc, learning_rate)
  else:
    print("=> No checkpoint found, training from scratch")

# train_loader is consumed by the epoch loop (train_fn); it must be defined.
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
# Loss scaling only applies with CUDA mixed precision; disable it explicitly on CPU.
g_scaler = torch.cuda.amp.GradScaler(enabled = device == 'cuda')
d_scaler = torch.cuda.amp.GradScaler(enabled = device == 'cuda')

In [None]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce_loss, g_scaler, d_scaler):
  """Run one epoch of adversarial + L1 training over ``loader``.

  For every (x, y) batch (x = network input, y = target):
    1. The discriminator is updated on real pairs (x, y) vs. generated pairs
       (x, gen(x)), with the two BCE terms averaged.
    2. The generator is updated to make the discriminator score its output as
       real, plus ``l1(y_fake, y) * l1_lambda``.
  Forward passes run under ``torch.cuda.amp.autocast`` and the backward/step
  calls go through the GradScalers.

  Args:
    disc, gen: discriminator and generator modules.
    loader: iterable of (x, y) tensor batches.
    opt_disc, opt_gen: optimizers for ``disc`` and ``gen``.
    l1: L1 reconstruction loss module (e.g. ``nn.L1Loss()``).
    bce_loss: adversarial loss applied to discriminator outputs.
    g_scaler, d_scaler: GradScalers for the generator / discriminator steps.

  Reads the notebook globals ``device`` and ``l1_lambda``.
  """
  loop = tqdm(loader, leave = True)

  for index, (x, y) in enumerate(loop):
    x = x.to(device)
    y = y.to(device)

    # Train discriminator:
    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      d_real = disc(x, y)
      # detach(): the discriminator loss must not backpropagate into the generator.
      # It also keeps gen's graph intact for the generator step below, so
      # retain_graph is unnecessary.
      d_fake = disc(x, y_fake.detach())
      d_real_loss = bce_loss(d_real, torch.ones_like(d_real))
      d_fake_loss = bce_loss(d_fake, torch.zeros_like(d_fake))
      # Halving the loss slows the discriminator down relative to the generator.
      d_loss = (d_real_loss + d_fake_loss) / 2

    opt_disc.zero_grad()
    d_scaler.scale(d_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # Train generator:
    with torch.cuda.amp.autocast():
      d_fake = disc(x, y_fake)
      # Reward the generator when the discriminator scores its output as real.
      g_fake_loss = bce_loss(d_fake, torch.ones_like(d_fake))
      # Use the L1 loss passed in as `l1` (no longer shadowed by a local of the same name).
      l1_term = l1(y_fake, y) * l1_lambda
      g_loss = g_fake_loss + l1_term

    opt_gen.zero_grad()  # clear accumulated gradients (not weights)
    g_scaler.scale(g_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()


In [None]:
# save_some_examples writes into this folder; make sure it exists.
os.makedirs("./evaluation", exist_ok = True)

for epoch in range(num_epochs):
  train_fn(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, bce_loss, g_scaler, d_scaler)

  # Checkpoint both networks after every epoch.
  if save_model:
    save_checkpoint(gen, opt_gen, filename = checkpoint_gen)
    save_checkpoint(disc, opt_disc, filename = checkpoint_disc)

  save_some_examples(gen, val_loader, epoch, folder = "./evaluation")


In [None]:
from PIL import Image

def test_image(gen, img_dir, folder, image_size = 1024):
  """Run the generator on a single image file and save the input and output.

  The image at ``img_dir`` is loaded as RGB and resized to
  ``image_size`` x ``image_size``. It is normalized to [-1, 1] and passed
  through ``gen`` as a batch of one. ``colorized_image.png`` (generator output)
  and ``original_image.png`` (the resized input) are written to ``folder``,
  which is created if needed. The generator is put in eval mode for inference
  and returned to train mode afterwards, even on error.

  NOTE(review): during training the generator sees grayscale images replicated
  to 3 channels. A color photo passed here is fed as-is, which differs from that
  distribution. ``image_size`` must be a multiple of 256 for the generator's
  skip connections to line up.
  """
  os.makedirs(folder, exist_ok = True)
  with Image.open(img_dir) as img:
    input_image = img.convert('RGB')
  transform = transforms.Compose([
      transforms.Resize((image_size, image_size)),
      transforms.ToTensor(),
      transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
  ])
  # transform() yields (channels, height, width); unsqueeze(0) adds a batch
  # dimension -> (1, channels, height, width), i.e. a batch of one image.
  input_tensor = transform(input_image).unsqueeze(0).to(device)
  gen.eval()
  try:
    with torch.no_grad():
      y_fake = gen(input_tensor)
      y_fake = y_fake * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
      save_image(y_fake, os.path.join(folder, "colorized_image.png"))
      save_image(input_tensor.squeeze() * 0.5 + 0.5, os.path.join(folder, "original_image.png"))
  finally:
    gen.train()

In [None]:
test_image(gen = gen, img_dir = "/content/5d4d43d585600a0462410713.jpg", folder = "testing")  # colorize one sample image

In [None]:
save_some_examples(gen, val_loader, epoch = 1, folder = "./evaluation")  # dump a validation grid tagged as epoch 1

In [None]:
# Reload the generator and discriminator weights from their checkpoint files.
# load_checkpoint requires (file, model, optimizer, lr); calling it with no
# arguments raises TypeError.
load_checkpoint(checkpoint_gen, gen, opt_gen, learning_rate)
load_checkpoint(checkpoint_disc, disc, opt_disc, learning_rate)