## 1) Setup & Data Preperation

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# prompt: copy from "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch/netG.pth" to "/content/weights"

import shutil
import os

# Create the destination directory if it doesn't exist
os.makedirs("/content/weights", exist_ok=True)

shutil.copy("/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch/netG.pth", "/content/weights")


In [22]:
# !cp "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/data/ColoredMangaMain.zip" "/content/"
# !unzip -q "/content/ColoredMangaMain.zip" -d "/content/ColoredMangaMain"

import os
import glob
import random
import functools
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.color import rgb2lab, lab2rgb
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Locate images inside the unzipped folder.
# Adjust this path if your images are in subfolders, etc.
all_image_paths = glob.glob("/content/ColoredMangaMain/**/*.png", recursive=True)
all_image_paths += glob.glob("/content/ColoredMangaMain/**/*.jpg", recursive=True)
all_image_paths += glob.glob("/content/ColoredMangaMain/**/*.jpeg", recursive=True)

print(f"Found {len(all_image_paths)} images.")

Using device: cuda
Found 36285 images.


In [23]:
np.random.seed(42)
np.random.shuffle(all_image_paths)
N = min(len(all_image_paths), 15000)
all_image_paths = all_image_paths[:N]
train_size = int(0.8 * len(all_image_paths))
train_paths = all_image_paths[:train_size]
val_paths   = all_image_paths[train_size:]

print(f"Train set size: {len(train_paths)}")
print(f"Val set size:   {len(val_paths)}")

Train set size: 12000
Val set size:   3000


## 2) Dataset Definition

In [24]:
class MangaColorizationDataset(Dataset):
    """
    Dataset for manga colorization.
    - Reads a colored image from disk
    - Converts to grayscale with CV2 (BGR->GRAY)
    - Threshold the grayscale to produce an "uncolored" L
    - Convert original image to Lab using skimage
    - Output: (l_input, l_diff, a, b)
    """
    def __init__(self, image_paths, threshold=100, img_size=256):
        self.image_paths = image_paths
        self.threshold = threshold
        self.img_size = img_size  # We will resize images to (img_size, img_size)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Read image with CV2, which is BGR by default
        bgr_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr_img is None:
            raise ValueError(f"Error loading image at: {img_path}")

        # Resize if desired
        if self.img_size is not None:
            bgr_img = cv2.resize(bgr_img, (self.img_size, self.img_size))

        # Convert to grayscale with CV2
        gray_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)

        # Threshold
        #   "If grayscale >= threshold => 255 else grayscale"
        l_input = np.where(gray_img >= self.threshold, 255, gray_img).astype(np.float32)

        # The difference between the "true grayscale" and the thresholded one
        #   true grayscale in range [0..255]
        #   l_input in range [0..255]
        # => difference in range [-255..+255]
        l_difference = (gray_img.astype(np.float32) - l_input).astype(np.float32)

        # Convert original BGR->RGB->Lab for the ground truth color
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        rgb_img_float = rgb_img.astype(np.float32) / 255.0  # [0..1]

        lab_img = rgb2lab(rgb_img_float)  # returns L:[0..100], a,b:[-128..127]
        L_lab = lab_img[...,0]  # in [0..100]
        A_lab = lab_img[...,1]  # in [-128..127]
        B_lab = lab_img[...,2]  # in [-128..127]

        # Normalize the inputs
        # l_input in [0..255], => /255 => [0..1]
        l_input_norm = l_input / 255.0
        # l_difference in [-255..255], => /255 => [-1..1]
        l_diff_norm  = l_difference / 255.0

        # For A and B, normalize to [-1..+1]
        a_norm = A_lab / 128.0
        b_norm = B_lab / 128.0

        # Convert to tensors, shape: (1, H, W)
        t_l_input = torch.tensor(l_input_norm, dtype=torch.float32).unsqueeze(0)
        t_l_diff  = torch.tensor(l_diff_norm, dtype=torch.float32).unsqueeze(0)
        t_a       = torch.tensor(a_norm,      dtype=torch.float32).unsqueeze(0)
        t_b       = torch.tensor(b_norm,      dtype=torch.float32).unsqueeze(0)

        return t_l_input, t_l_diff, t_a, t_b

#### Create datasets/dataloaders

In [35]:
train_dataset = MangaColorizationDataset(train_paths, threshold=100, img_size=512)
val_dataset   = MangaColorizationDataset(val_paths,   threshold=100, img_size=512)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=16, shuffle=False, num_workers=4)

print("Train/Val loader ready.")

Train/Val loader ready.


## 3) Model Definition (U-Net)

In [26]:
class UnetSkipConnectionBlock(nn.Module):
    """Defines the U-Net submodule with skip connections."""
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4,
                stride=2, padding=1
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=4,
                stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4,
                stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv, downnorm]
            up   = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

In [27]:
class UnetGenerator(nn.Module):
    """Original style U-Net with flexible in/out channels."""
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # construct unet structure from innermost to outermost
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None,
            submodule=None, norm_layer=norm_layer, innermost=True
        )
        # add intermediate layers with 8 * ngf filters
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None,
                submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout
            )
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None,
            submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None,
            submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None,
            submodule=unet_block, norm_layer=norm_layer
        )
        # outermost
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc,
            submodule=unet_block, outermost=True,
            norm_layer=norm_layer
        )

    def forward(self, x):
        return self.model(x)

## 4) Partial Weight Loading

In [28]:
def create_pretrained_unet_for_manga(
    anime2sketch_ckpt="weights/netG.pth",
    norm_layer=None,
    num_downs=8,
    ngf=64,
    use_dropout=False
):
    """
    Create a new U-Net generator with input=1, output=3
    and partially load from the anime2sketch checkpoint (which is 3->1).
    """
    if norm_layer is None:
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

    # 1 input channel -> 3 output channels
    net_manga = UnetGenerator(
        input_nc=1,
        output_nc=3,
        num_downs=num_downs,
        ngf=ngf,
        norm_layer=norm_layer,
        use_dropout=use_dropout
    )

    # Load the checkpoint
    ckpt = torch.load(anime2sketch_ckpt, map_location="cpu")
    # Remove any 'module.' prefixes
    for key in list(ckpt.keys()):
        if "module." in key:
            ckpt[key.replace("module.", "")] = ckpt[key]
            del ckpt[key]

    # Only load matching layers
    model_dict = net_manga.state_dict()
    pretrained_dict = {}
    for k, v in ckpt.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            pretrained_dict[k] = v

    model_dict.update(pretrained_dict)
    net_manga.load_state_dict(model_dict)

    return net_manga

## 5) Create model, loss, optimizer

In [31]:
anime2sketch_ckpt_path = "/content/weights/netG.pth"

net_manga = create_pretrained_unet_for_manga(
    anime2sketch_ckpt=anime2sketch_ckpt_path,
    norm_layer=functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
    num_downs=8,
    ngf=64,
    use_dropout=False
)
net_manga.to(device)
net_manga.train()

optimizer = torch.optim.Adam(net_manga.parameters(), lr=1e-4)
criterion = nn.L1Loss()

  ckpt = torch.load(anime2sketch_ckpt, map_location="cpu")


## 6) Training + Validation + Visualization

In [32]:
def denormalize_lab_tensors(l_in, l_diff, a_pred, b_pred):
    """
    Convert network outputs (normalized L-diff, A, B) back to an RGB image.
    Steps:
      1) final L = l_in + l_diff
      2) Scale back to Lab domain
      3) Convert Lab->RGB
    """
    # l_in in [0,1]
    # l_diff in [-1,1]
    # final L in [0,1] (some potential for going out of range if the diff is big, but typically in [-1..1].)
    L_final = (l_in + l_diff).clamp(0,1)  # shape [H,W]

    # A in [-1..1], B in [-1..1]
    # We map them back: A' = A*128, B' = B*128
    A_lab = a_pred * 128.0
    B_lab = b_pred * 128.0

    # L_final is in [0..1], for lab we want [0..100].
    L_lab = L_final * 100.0

    # Combine into [H,W,3]
    lab_img = torch.stack([L_lab, A_lab, B_lab], dim=-1).cpu().numpy()
    # Convert LAB -> RGB with skimage
    rgb_img = lab2rgb(lab_img.astype(np.float64))
    return rgb_img


def visualize_validation_samples(net, val_loader, num_samples=3, device=device):
    """
    Grab a few batches from validation, display predictions vs ground truth.
    """
    net.eval()
    with torch.no_grad():
        # Just take one batch
        for (l_in, l_diff_gt, a_gt, b_gt) in val_loader:
            # Move to GPU if available
            l_in     = l_in.to(device)
            l_diff_gt= l_diff_gt.to(device)
            a_gt     = a_gt.to(device)
            b_gt     = b_gt.to(device)

            # Forward pass
            output = net(l_in)
            # output has shape [B,3,H,W]
            pred_l_diff = output[:,0:1,...]
            pred_a      = output[:,1:2,...]
            pred_b      = output[:,2:3,...]

            # We'll visualize up to num_samples from this batch
            batch_size = l_in.size(0)
            n_show = min(batch_size, num_samples)

            for i in range(n_show):
                # Extract single image
                l_in_i     = l_in[i,0,:,:].detach()     # [H,W]
                l_diff_i   = pred_l_diff[i,0,:,:].detach()
                a_i        = pred_a[i,0,:,:].detach()
                b_i        = pred_b[i,0,:,:].detach()

                # Ground truth Lab
                gt_l_diff_i = l_diff_gt[i,0,:,:].detach()
                gt_a_i      = a_gt[i,0,:,:].detach()
                gt_b_i      = b_gt[i,0,:,:].detach()

                # Convert predictions to RGB
                pred_rgb = denormalize_lab_tensors(l_in_i, l_diff_i, a_i, b_i)
                # Convert ground truth to RGB
                gt_rgb   = denormalize_lab_tensors(l_in_i, gt_l_diff_i, gt_a_i, gt_b_i)

                # Prepare for plotting
                # We'll plot:
                #   (1) the "input" grayscale thresholded (for display),
                #   (2) predicted colorization,
                #   (3) ground truth
                # They are all in numpy HxWx3 or HxW
                inp_vis = l_in_i.cpu().numpy()  # [H,W], in [0..1]

                fig, axes = plt.subplots(1,3, figsize=(12,4))
                # 1) input
                axes[0].imshow(inp_vis, cmap='gray', vmin=0, vmax=1)
                axes[0].set_title("Input L (thresholded)")

                # 2) prediction
                axes[1].imshow(pred_rgb)
                axes[1].set_title("Prediction (RGB)")

                # 3) ground truth
                axes[2].imshow(gt_rgb)
                axes[2].set_title("Ground Truth (RGB)")

                for ax in axes:
                    ax.axis('off')
                plt.tight_layout()
                plt.show()

            break  # only visualize one batch
    net.train()


def run_validation(net, val_loader, criterion, device=device):
    """
    Compute average validation L1 loss over the entire val_loader.
    """
    net.eval()
    val_loss = 0.0
    count = 0
    with torch.no_grad():
        for (l_in, l_diff_gt, a_gt, b_gt) in val_loader:
            l_in     = l_in.to(device)
            l_diff_gt= l_diff_gt.to(device)
            a_gt     = a_gt.to(device)
            b_gt     = b_gt.to(device)

            output = net(l_in)
            pred_l_diff = output[:,0:1,...]
            pred_a      = output[:,1:2,...]
            pred_b      = output[:,2:3,...]

            loss_l_diff = criterion(pred_l_diff, l_diff_gt)
            loss_a      = criterion(pred_a, a_gt)
            loss_b      = criterion(pred_b, b_gt)

            loss = loss_l_diff + loss_a + loss_b
            val_loss += loss.item() * l_in.size(0)
            count    += l_in.size(0)

    avg_val_loss = val_loss / count
    net.train()
    return avg_val_loss

In [33]:
num_epochs = 20
print_every = 3  # visualize & show val loss every X epochs

for epoch in range(num_epochs):
    running_loss = 0.0
    total_samples = 0

    # TQDM progress bar over training
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
    for (l_in, l_diff_gt, a_gt, b_gt) in pbar:
        l_in     = l_in.to(device)
        l_diff_gt= l_diff_gt.to(device)
        a_gt     = a_gt.to(device)
        b_gt     = b_gt.to(device)

        optimizer.zero_grad()
        output = net_manga(l_in)
        pred_l_diff = output[:,0:1,...]
        pred_a      = output[:,1:2,...]
        pred_b      = output[:,2:3,...]

        loss_l_diff = criterion(pred_l_diff, l_diff_gt)
        loss_a      = criterion(pred_a, a_gt)
        loss_b      = criterion(pred_b, b_gt)

        lambda_l_diff = 1   # weight for the l_diff loss
        loss = lambda_l_diff * loss_l_diff + loss_a + loss_b

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * l_in.size(0)
        total_samples += l_in.size(0)
        pbar.set_postfix({"loss": f"{(running_loss/total_samples):.4f}"})

    # End of epoch
    train_epoch_loss = running_loss / total_samples
    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_epoch_loss:.4f}")

    # Every 3 epochs (or as desired), run validation and visualize
    if (epoch+1) % print_every == 0:
        val_loss = run_validation(net_manga, val_loader, criterion, device=device)
        print(f"[Epoch {epoch+1}/{num_epochs}] Validation Loss: {val_loss:.4f}")

        # Visualize some predictions on val set
        visualize_validation_samples(net_manga, val_loader, num_samples=2, device=device)


# Save the fine-tuned model
save_path = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch_colorization/anime2sketch_finetuned.pth"
model_dir = os.path.dirname(save_path)
os.makedirs(model_dir, exist_ok=True)
torch.save(net_manga, save_path)
print("Finished training! Model saved to:", save_path)

Output hidden; open in https://colab.research.google.com to view.

In [34]:
save_path = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch_colorization/anime2sketch_finetuned.pth"
model_dir = os.path.dirname(save_path)
os.makedirs(model_dir, exist_ok=True)
torch.save(net_manga, save_path)
print("Finished training! Model saved to:", save_path)

Finished training! Model saved to: /content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch_colorization/anime2sketch_finetuned.pth


## Training with higher resolution

In [36]:
anime2sketch_ckpt_path = "/content/weights/netG.pth"

net_manga1 = create_pretrained_unet_for_manga(
    anime2sketch_ckpt=anime2sketch_ckpt_path,
    norm_layer=functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
    num_downs=8,
    ngf=64,
    use_dropout=False
)
net_manga1.to(device)
net_manga1.train()

optimizer = torch.optim.Adam(net_manga1.parameters(), lr=1e-4)
criterion = nn.L1Loss()

  ckpt = torch.load(anime2sketch_ckpt, map_location="cpu")


In [37]:
num_epochs = 20
print_every = 3  # visualize & show val loss every X epochs

for epoch in range(num_epochs):
    running_loss = 0.0
    total_samples = 0

    # TQDM progress bar over training
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
    for (l_in, l_diff_gt, a_gt, b_gt) in pbar:
        l_in     = l_in.to(device)
        l_diff_gt= l_diff_gt.to(device)
        a_gt     = a_gt.to(device)
        b_gt     = b_gt.to(device)

        optimizer.zero_grad()
        output = net_manga1(l_in)
        pred_l_diff = output[:,0:1,...]
        pred_a      = output[:,1:2,...]
        pred_b      = output[:,2:3,...]

        loss_l_diff = criterion(pred_l_diff, l_diff_gt)
        loss_a      = criterion(pred_a, a_gt)
        loss_b      = criterion(pred_b, b_gt)

        lambda_l_diff = 0.5   # weight for the l_diff loss
        loss = lambda_l_diff * loss_l_diff + loss_a + loss_b

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * l_in.size(0)
        total_samples += l_in.size(0)
        pbar.set_postfix({"loss": f"{(running_loss/total_samples):.4f}"})

    # End of epoch
    train_epoch_loss = running_loss / total_samples
    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_epoch_loss:.4f}")


    val_loss = run_validation(net_manga1, val_loader, criterion, device=device)
    print(f"[Epoch {epoch+1}/{num_epochs}] Validation Loss: {val_loss:.4f}")

    # Every 3 epochs (or as desired), run validation and visualize
    if (epoch+1) % print_every == 0:
        # Visualize some predictions on val set
        visualize_validation_samples(net_manga1, val_loader, num_samples=2, device=device)


# Save the fine-tuned model
save_path = "/content/drive/MyDrive/Okul/Eğitim/Ders/5. Dönem/YZV 303E - Deep Learning/Project/models/Anime2Sketch_colorization/anime2sketch_finetuned_512p.pth"
model_dir = os.path.dirname(save_path)
os.makedirs(model_dir, exist_ok=True)
torch.save(net_manga1, save_path)
print("Finished training! Model saved to:", save_path)

Output hidden; open in https://colab.research.google.com to view.

## 7) Inference Function

In [40]:
def colorize_manga_page(model, image_path, threshold=100, img_size=512, device=device):
    """
    Let the user input their own manga page (uncolored),
    The model predicts L-diff, A, B -> return colorized RGB.
    """
    model.eval()
    with torch.no_grad():
        bgr_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if bgr_img is None:
            raise ValueError(f"Could not load image at {image_path}")

        # Resize if needed
        if img_size is not None:
            bgr_img = cv2.resize(bgr_img, (img_size, img_size))

        gray_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)
        l_input = np.where(gray_img >= threshold, 255, gray_img).astype(np.float32)
        l_in_norm = l_input / 255.0  # [0..1]
        l_in_torch = torch.tensor(l_in_norm, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

        # Forward
        output = model(l_in_torch)
        pred_l_diff = output[:,0:1,...]
        pred_a      = output[:,1:2,...]
        pred_b      = output[:,2:3,...]

        # Convert to final RGB
        l_in_i = l_in_torch[0,0,:,:]
        colorized_rgb = denormalize_lab_tensors(l_in_i, pred_l_diff[0,0,:,:], pred_a[0,0,:,:], pred_b[0,0,:,:])

    return colorized_rgb

In [None]:
# Example usage (after training):
colorized_image = colorize_manga_page(net_manga1, "/content/test_images/Bleach v1-063.jpg")
plt.imshow(colorized_image)
plt.axis('off')
plt.show()