In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from torchvision.models import resnet18

# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# Collect the evaluation images; sorting keeps the order stable between runs.
test_paths = sorted(glob.glob("/home/huuthanhvy.nguyen001/Image-Colorization/Testllama/model_gan_unet/test_100images/*.jpg"))

print("Found test images:", len(test_paths))  # sanity check: number of .jpg files matched by the glob above

Found test images: 101


In [8]:
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a U-Net generator whose encoder is a pretrained ResNet-18.

    Args:
        n_input: number of input channels, forwarded to ``create_body`` as ``n_in``.
        n_output: number of output channels requested from ``DynamicUnet``.
        size: side length; the network is built for ``(size, size)`` inputs.

    Returns:
        The ``DynamicUnet`` instance, moved to the module-level ``device``.
    """
    backbone = resnet18(pretrained=True)
    # cut=-2 is passed straight to fastai; presumably it drops the final
    # pooling/classifier layers of the backbone -- confirm against fastai docs.
    encoder = create_body(backbone, pretrained=True, n_in=n_input, cut=-2)
    image_size = (size, size)
    generator = DynamicUnet(encoder, n_output, image_size)
    return generator.to(device)

In [11]:
def init_weights(net, init='norm', gain=0.02):
    """Re-initialise the conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: module whose sub-modules are visited with ``net.apply``.
        init: ``'norm'`` (normal, std=``gain``), ``'xavier'`` (Xavier normal
            with ``gain``) or ``'kaiming'`` (Kaiming normal, fan_in). Any other
            value leaves conv weights untouched; conv biases are still zeroed.
        gain: std for ``'norm'`` and for the BatchNorm2d scale, gain for ``'xavier'``.

    Returns:
        The same ``net`` object, after printing a one-line confirmation.
    """
    def init_func(m):
        layer_type = m.__class__.__name__
        is_conv = hasattr(m, 'weight') and 'Conv' in layer_type

        if not is_conv:
            # BatchNorm2d: scale drawn around 1, shift set to 0.
            if 'BatchNorm2d' in layer_type:
                nn.init.normal_(m.weight.data, 1., gain)
                nn.init.constant_(m.bias.data, 0.)
            return

        if init == 'norm':
            nn.init.normal_(m.weight.data, mean=0.0, std=gain)
        elif init == 'xavier':
            nn.init.xavier_normal_(m.weight.data, gain=gain)
        elif init == 'kaiming':
            nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

        # Conv layers created with bias=False expose ``bias`` as None.
        has_bias = hasattr(m, 'bias') and m.bias is not None
        if has_bias:
            nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    """Move ``model`` to ``device`` and initialise it with ``init_weights`` defaults.

    Returns the initialised model (the object returned by ``init_weights``).
    """
    return init_weights(model.to(device))

class GANLoss(nn.Module):
    """Adversarial loss that builds the real/fake target tensor itself.

    Args:
        gan_mode: ``'vanilla'`` (``BCEWithLogitsLoss``) or ``'lsgan'`` (``MSELoss``).
        real_label: target value used when the prediction should be "real".
        fake_label: target value used when the prediction should be "fake".

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers (not parameters) so the labels follow .to(device) and are
        # stored in the state dict under the same keys as before.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail at construction instead of with an AttributeError on first use.
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label broadcast to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against an all-real or all-fake target.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module.__call__`` keeps running its hooks; ``criterion(preds, flag)``
        works exactly as before.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

In [2]:
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a grid of per-patch logits.

    Args:
        input_c: number of channels of the input image.
        num_filters: channel count of the first block; doubled in each later block.
        n_down: number of intermediate blocks after the first one.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First block: conv + activation, no normalization.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            in_ch = num_filters * 2 ** depth
            out_ch = num_filters * 2 ** (depth + 1)
            # Every intermediate block halves the resolution except the last one.
            stride = 1 if depth == (n_down - 1) else 2
            blocks.append(self.get_layers(in_ch, out_ch, s=stride))
        # Output block: single channel, no normalization and no activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one repetitive block: Conv2d, optional BatchNorm2d, optional LeakyReLU."""
        # The conv bias is dropped whenever a BatchNorm layer follows it.
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        stack = [conv]
        if norm:
            stack.append(nn.BatchNorm2d(nf))
        if act:
            stack.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the stacked blocks and return the raw outputs."""
        return self.model(x)

In [3]:
class MainModel(nn.Module):
    """Generator + patch discriminator pair with their losses and optimizers.

    One training step is ``setup_input(batch)`` followed by ``optimize()``.
    The individual loss terms are stored on the instance as ``loss_D_fake``,
    ``loss_D_real``, ``loss_D``, ``loss_G_GAN``, ``loss_G_L1`` and ``loss_G``.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        """Create the networks, criteria and Adam optimizers.

        Args:
            net_G: generator to use; when None a ``Unet`` is built and initialised.
            lr_G: learning rate of the generator optimizer.
            lr_D: learning rate of the discriminator optimizer.
            beta1: first Adam beta, shared by both optimizers.
            beta2: second Adam beta, shared by both optimizers.
            lambda_L1: multiplier applied to the L1 term of the generator loss.
        """

        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        if net_G is None:
            # NOTE(review): `Unet` is not defined in the visible part of this
            # notebook -- confirm it exists before relying on net_G=None.
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)

        # The discriminator receives L concatenated with ab along the channel
        # axis (see backward_D / backward_G), hence input_c=3.
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        """Store the batch's ``'L'`` and ``'ab'`` tensors on ``self.device``."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        """Run the generator on the stored ``L`` and keep the result as ``fake_color``."""
        self.fake_color = self.net_G(self.L)
    
    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate it."""
        
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator loss must not send gradients into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        # Average of the fake and real terms.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        """Compute the generator loss (adversarial + weighted L1) and backpropagate it."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when the discriminator labels its output as real.
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        """Run one training step: update the discriminator, then the generator."""
        self.forward()
        # --- discriminator update ---
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        # --- generator update; discriminator parameters are frozen meanwhile ---
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from skimage.color import rgb2lab
from torchvision import transforms


from PIL import Image
import numpy as np
from skimage.color import rgb2lab
from torchvision import transforms

# Dataset directories (absolute, machine-specific paths).
train_dir = "/hpcstor6/scratch01/h/huuthanhvy.nguyen001/imagenet_train"
val_dir = "/hpcstor6/scratch01/h/huuthanhvy.nguyen001/imagenet_val"

# Count the entries whose name ends with '.jpg' in each directory.
train_count = sum(1 for f in os.listdir(train_dir) if f.endswith('.jpg'))
val_count = sum(1 for f in os.listdir(val_dir) if f.endswith('.jpg'))

print(f"Train images: {train_count}")
print(f"Validation images: {val_count}")


# Side length (pixels) every image is resized to before Lab conversion.
SIZE = 256

class ColorizationAugmentor(Dataset):
    """Helper that loads RGB images and converts them to normalised L / ab tensors.

    NOTE(review): a second class with the same name is defined further down in
    this cell and replaces this one as soon as the cell has run.
    """

    def __init__(self, image_paths, split='train', num_augmentations=1):
        self.paths = image_paths
        self.split = split
        self.num_augmentations = num_augmentations

    def get_transforms(self, random=True):
        """Return the PIL transform for the current split (None for unknown splits)."""
        if self.split == 'val':
            return transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        if self.split != 'train':
            return None
        steps = [transforms.Resize((SIZE, SIZE), Image.BICUBIC)]
        if random:
            steps.append(transforms.RandomRotation(degrees=(0, 180)))
        return transforms.Compose(steps)

    def load_original_image(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        return Image.open(path).convert("RGB")

    def transform_and_convert_to_lab(self, img, transform):
        """Apply ``transform`` to ``img`` and return the scaled ``(L, ab)`` tensors."""
        lab_array = rgb2lab(np.array(transform(img))).astype("float32")
        lab_tensor = transforms.ToTensor()(lab_array)
        # Channel 0 is divided by 50 and shifted by -1; channels 1-2 are divided by 110.
        L = lab_tensor[[0], ...] / 50. - 1.
        ab = lab_tensor[[1, 2], ...] / 110.
        return L, ab

    def get_item_with_original(self, idx):
        """Return one sample dict (or a list of them) that also carries the source image.

        A single dict is returned when ``num_augmentations == 1``; otherwise a
        list with one dict per augmentation.
        """
        img = self.load_original_image(self.paths[idx])
        samples = []
        for _ in range(self.num_augmentations):
            transform = self.get_transforms(random=True)
            L, ab = self.transform_and_convert_to_lab(img, transform)
            samples.append({'original': img, 'L': L, 'ab': ab})
        return samples[0] if self.num_augmentations == 1 else samples

    def dataset_length(self):
        """Number of samples: one per (image, augmentation) pair."""
        return self.num_augmentations * len(self.paths)

SIZE = 256

class ColorizationAugmentor(Dataset):
    """Dataset yielding normalised ``L`` / ``ab`` tensors for colorization.

    Every image path is exposed ``num_augmentations`` times; for the
    ``'train'`` split each access applies a fresh random rotation.

    Args:
        paths: sequence of image file paths.
        split: ``'train'`` (resize + random rotation) or ``'val'`` (resize only).
        num_augmentations: how many samples each image contributes (>= 1).

    Raises:
        ValueError: if ``split`` is unknown or ``num_augmentations`` is < 1.
    """

    _SPLITS = ('train', 'val')

    def __init__(self, paths, split='train', num_augmentations=1):
        # Validate up front: an unknown split used to surface much later as
        # "'NoneType' object is not callable" inside a DataLoader worker, and
        # num_augmentations=0 as a ZeroDivisionError in __getitem__.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        if num_augmentations < 1:
            raise ValueError(f"num_augmentations must be >= 1, got {num_augmentations!r}")
        self.paths = paths
        self.split = split
        self.num_augmentations = num_augmentations

    def get_transforms(self, random=True):
        """Return the PIL transform for the current split.

        ``random`` only matters for the ``'train'`` split, where it adds a
        random rotation after the resize.
        """
        if self.split == 'train':
            transform_list = [
                transforms.Resize((SIZE, SIZE), Image.BICUBIC),
            ]
            if random:
                transform_list.append(transforms.RandomRotation(degrees=(0, 180)))
            return transforms.Compose(transform_list)
        elif self.split == 'val':
            return transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        # Reached only if ``self.split`` was changed after construction.
        raise ValueError(f"split must be one of {self._SPLITS}, got {self.split!r}")

    def load_original_image(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        return Image.open(path).convert("RGB")

    def transform_and_convert_to_lab(self, img, transform):
        """Apply ``transform`` to ``img``, convert to Lab and return scaled ``(L, ab)``."""
        img_transformed = transform(img)
        img_np = np.array(img_transformed)
        img_lab = rgb2lab(img_np).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)
        # Channel 0 is divided by 50 and shifted by -1; channels 1-2 are divided
        # by 110 (presumably to bring both to roughly [-1, 1] -- confirm).
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return L, ab

    def __len__(self):
        return len(self.paths) * self.num_augmentations

    def __getitem__(self, idx):
        """Return ``{'L': ..., 'ab': ...}`` for sample ``idx``."""
        # Consecutive indices map onto the same image, one per augmentation.
        image_idx = idx // self.num_augmentations
        img = self.load_original_image(self.paths[image_idx])
        transform = self.get_transforms(random=True)
        L, ab = self.transform_and_convert_to_lab(img, transform)
        return {'L': L, 'ab': ab}

def make_dataloaders(paths, split='train', batch_size=16, num_augmentations=1, num_workers=4):
    """Build a single ``DataLoader`` over a ``ColorizationAugmentor`` dataset.

    Args:
        paths: image file paths handed to the dataset.
        split: dataset split; batches are shuffled only when it equals ``'train'``.
        batch_size: number of samples per batch.
        num_augmentations: forwarded to the dataset.
        num_workers: number of loader worker processes.

    Returns:
        The configured ``DataLoader`` (``pin_memory`` is always enabled).
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(
        ColorizationAugmentor(paths=paths, split=split, num_augmentations=num_augmentations),
        **loader_options,
    )

# Sorted so the file order (and therefore the val loader order) is reproducible.
train_paths = sorted(glob.glob("/hpcstor6/scratch01/h/huuthanhvy.nguyen001/imagenet_train/*.jpg"))
val_paths = sorted(glob.glob("/hpcstor6/scratch01/h/huuthanhvy.nguyen001/imagenet_val/*.jpg"))

train_dl = make_dataloaders(paths=train_paths, split='train', num_augmentations=2)
val_dl = make_dataloaders(paths=val_paths, split='val', num_augmentations=1)

# Pull one batch to confirm the tensor shapes coming out of the pipeline.
data = next(iter(train_dl))
print("L shape:", data['L'].shape)
print("ab shape:", data['ab'].shape)

# len(DataLoader) is the number of batches per epoch, not the batch size,
# so the labels say "batches".
print("Train loader batches:", len(train_dl))
print("Val loader batches:", len(val_dl))

class AverageMeter:
    """Running weighted average: tracks ``count``, ``sum`` and ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all three statistics back to 0.0."""
        self.count = self.avg = self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name."""
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}

def update_losses(model, loss_meter_dict, count):
    """Feed the model's current loss values into the matching meters.

    For every key in ``loss_meter_dict`` the attribute of the same name is read
    from ``model``, converted with ``.item()`` and added with weight ``count``.
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)

def lab_to_rgb(L, ab):
    """Convert a batch of scaled L and ab tensors to RGB images.

    ``L`` is mapped back with ``(L + 1) * 50`` and ``ab`` with ``ab * 110``;
    the two are concatenated along dim 1, moved to channels-last and converted
    image by image with ``lab2rgb``.

    Returns:
        A numpy array stacking the per-image RGB results along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_channels_last = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1)
    lab_np = lab_channels_last.cpu().numpy()
    return np.stack([lab2rgb(single) for single in lab_np], axis=0)
    


In [5]:
import torch.optim as optim
from skimage.color import lab2rgb
from tqdm import tqdm

def visualize(model, data, save=True, max_images=5):
    """Plot grayscale input, predicted colors and ground truth for one batch.

    Args:
        model: object exposing ``net_G``, ``setup_input``, ``forward`` and the
            ``L`` / ``ab`` / ``fake_color`` attributes they populate.
        data: batch dict passed to ``model.setup_input``.
        save: when True the figure is also written to
            ``colorization_<timestamp>.png`` in the working directory.
        max_images: upper bound on the number of columns; fewer are drawn when
            the batch is smaller.
    """
    # Remember the generator's mode so an eval-mode model is not silently
    # switched to train mode by a visualization call.
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        model.net_G.train(was_training)

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L

    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    # The last batch of an epoch (or a tiny test folder) can hold fewer than
    # ``max_images`` samples; never index past the batch.
    n_show = min(max_images, L.size(0))
    if n_show < 1:
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single column.
    fig, axes = plt.subplots(3, n_show, figsize=(15, 8), squeeze=False)
    for i in range(n_show):
        axes[0][i].imshow(L[i][0].cpu(), cmap='gray')
        axes[1][i].imshow(fake_imgs[i])
        axes[2][i].imshow(real_imgs[i])
        for row in range(3):
            axes[row][i].axis("off")

    # Save before show(): some backends release the canvas once it is shown.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")

    plt.show()
    # Close explicitly so repeated calls during training do not pile up figures.
    plt.close(fig)
        
def log_results(loss_meter_dict):
    """Print every meter's running average as ``name: value`` with 5 decimals."""
    for name in loss_meter_dict:
        average = loss_meter_dict[name].avg
        print(f"{name}: {average:.5f}")


def train_model(model, train_dl, epochs, display_every=200):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    Every ``display_every`` iterations the running losses are printed and the
    current batch is visualized (and saved); the epoch averages are printed at
    the end of each epoch.
    """
    for e in range(epochs):
        # Fresh meters so the averages cover this epoch only.
        loss_meter_dict = create_loss_meters()
        progress = tqdm(train_dl, desc=f"Epoch {e+1}/{epochs} [Training]")
        for i, data in enumerate(progress, start=1):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))

            if i % display_every != 0:
                continue

            print(f"\nEpoch {e+1}/{epochs} - Iteration {i}/{len(train_dl)}")
            log_results(loss_meter_dict)
            visualize(model, data, save=True)

        print(f"\n✅ End of Epoch {e+1} - Training Losses:")
        log_results(loss_meter_dict)


In [None]:
#import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# net_G = build_res_unet(n_input=1, n_output=2, size=256)

# net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))

# model = MainModel(net_G=net_G)

# train_model(model, train_dl, epochs=50)

# # TODO: extend train_model so that it saves the best model, i.e. the one with the lowest loss.

# torch.save(model.state_dict(), "final_model_weights.pt")

In [None]:
'''
Generate output on your own set of images

'''
img_path = '/home/huuthanhvy.nguyen001/Image-Colorization/Testllama/model_gan_unet/test'
print(img_path)
paths = glob.glob(img_path + "/*")
idxs = np.arange(len(paths))

# NOTE(review): `model` is created in a later cell of this notebook; that cell
# has to run before this one.
test_dl = make_dataloaders(paths=paths, split='val')
for data in tqdm(test_dl):
  # Inference only. The previous model.optimize() call ran a full GAN training
  # step on every test batch, i.e. it fitted the model to the test images.
  # visualize() already calls setup_input() and forward() under torch.no_grad().
  visualize(model, data, save=False)

In [None]:
import torch
import glob
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import lab2rgb

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------------
# Load test images
# ------------------------------

test_paths = sorted(glob.glob("/home/huuthanhvy.nguyen001/Image-Colorization/Testllama/model_gan_unet/test_100images/*.jpg"))

print("Found test images:", len(test_paths))  # sanity check: number of .jpg files matched by the glob above

# batch_size=1 with split='val' (no shuffling) keeps batches aligned one-to-one with test_paths.
test_dl = make_dataloaders(paths=test_paths, split='val', batch_size=1, num_augmentations=1)

# ------------------------------
# Load final trained model
# ------------------------------
net_G = build_res_unet(n_input=1, n_output=2, size=256)
model = MainModel(net_G=net_G)
# The checkpoint holds the state dict of the whole MainModel, not just the generator.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("/home/huuthanhvy.nguyen001/Image-Colorization/Testllama/model_gan_unet/final_model_weights.pt", map_location=device))
model.to(device)
model.eval()

# ------------------------------
# LAB to RGB conversion
# ------------------------------
def lab_to_rgb(L, ab):
    """Convert a batch of scaled L and ab tensors to a stacked array of RGB images.

    ``L`` is mapped back with ``(L + 1) * 50`` and ``ab`` with ``ab * 110``
    before each image is converted with ``lab2rgb``.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # Channels-last numpy layout, one image per entry along axis 0.
    lab_np = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    converted = []
    for single in lab_np:
        converted.append(lab2rgb(single))
    return np.stack(converted)



In [None]:
import matplotlib.pyplot as plt

def visualize_all_test_images(model, test_dl, test_paths=None, chunk_size=10):
    """Colorize every image in ``test_dl`` and plot input / prediction / ground truth.

    Args:
        model: object exposing ``eval``, ``setup_input``, ``forward`` and the
            ``L`` / ``ab`` / ``fake_color`` attributes they populate.
        test_dl: iterable of batch dicts accepted by ``model.setup_input``;
            batches of any size are supported.
        test_paths: optional file paths in loader order, used for row titles.
            Rows without a matching path are titled ``"Image N"``.
        chunk_size: number of images (rows) per figure.
    """
    model.eval()
    L_list, fake_list, real_list, names = [], [], [], []

    for data in test_dl:
        # setup_input moves the batch to the model's own device, so no
        # module-level ``device`` is needed here.
        with torch.no_grad():
            model.setup_input(data)
            model.forward()

        fake_imgs = lab_to_rgb(model.L, model.fake_color.detach())
        real_imgs = lab_to_rgb(model.L, model.ab)
        gray = model.L[:, 0].cpu()

        # Keep every image of the batch (previously only index 0 was kept,
        # silently dropping the rest when batch_size > 1).
        for j in range(gray.size(0)):
            position = len(names)  # running index across all batches
            L_list.append(gray[j])
            fake_list.append(fake_imgs[j])
            real_list.append(real_imgs[j])
            if test_paths and position < len(test_paths):
                names.append(os.path.basename(test_paths[position]))
            else:
                names.append(f"Image {position+1}")

    # One figure per chunk so a large test set does not produce a single huge plot.
    total = len(fake_list)
    for start in range(0, total, chunk_size):
        end = min(start + chunk_size, total)
        n_rows = end - start
        # squeeze=False keeps ``axs`` two-dimensional even for a single row.
        fig, axs = plt.subplots(n_rows, 3, figsize=(12, n_rows * 4), squeeze=False)

        for row, i in enumerate(range(start, end)):
            panels = (
                (L_list[i], f"{names[i]}: Grayscale Input", 'gray'),
                (fake_list[i], "Predicted Color", None),
                (real_list[i], "Ground Truth", None),
            )
            for ax, (img, title, cmap) in zip(axs[row], panels):
                ax.imshow(img, cmap=cmap)
                ax.set_title(title)
                ax.axis("off")

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise every chunk stays alive in memory.
        plt.close(fig)


# 🔍 Run it: plots every test image; test_paths is not passed, so rows are titled "Image N".
visualize_all_test_images(model, test_dl)