# DDColor Image Colorization Training

Training pipeline for the DDColor model. See README.md for a detailed description of the architecture and the underlying theory.

## Setup & Installation

```bash
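pip install torch torchvision torchaudio
pip install opencv-python pillow numpy pyyaml scipy scikit-image timm tensorboard
pip install lmdb lpips
# Also imported by the cells below: tqdm (progress bars), matplotlib (sample
# plots) and basicsr (PerceptualLoss / GANLoss).
pip install tqdm matplotlib basicsr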
```

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.amp import autocast, GradScaler

import cv2
import numpy as np
from PIL import Image
import yaml
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
import time
import os
import shutil
import urllib.request

from ddcolor_model import DDColor
from basicsr.losses import PerceptualLoss, GANLoss

## Configuration

In [None]:
class Config:
    """Hyper-parameters, paths and runtime switches for one training run.

    Everything is a plain class attribute, so values can be read from the
    class itself or from an instance.
    """

    # --- Filesystem layout (adjust to your environment) ---
    train_data_dir = './dataset/train'
    val_data_dir = './dataset/test'
    save_dir = './experiments/ddcolor_custom'
    pretrain_dir = './pretrain'

    # --- Network architecture ---
    encoder_name = 'convnext-l'
    decoder_name = 'MultiScaleColorDecoder'
    num_queries = 100
    num_scales = 3
    dec_layers = 9
    input_size = 512

    # --- Training schedule ---
    batch_size = 4
    num_workers = 4
    num_epochs = 100
    total_iters = 400000

    # --- AdamW settings ---
    lr = 1e-4
    weight_decay = 0.01
    betas = (0.9, 0.99)

    # --- Relative weight of each generator loss term ---
    lambda_pixel = 0.1
    lambda_perceptual = 5.0
    lambda_gan = 1.0
    lambda_colorfulness = 0.5

    # --- How often (in iterations) to log, checkpoint and sample ---
    print_freq = 100
    save_freq = 5000
    val_freq = 2000

    # --- Runtime device ---
    # Mixed precision is enabled exactly when CUDA is available, and the
    # device string follows from the same check.
    use_amp = torch.cuda.is_available()
    device = 'cuda' if use_amp else 'cpu'

    # --- Multi-GPU settings ---
    use_distributed = False
    gpu_ids = [0]

config = Config()

# Make sure every output location exists before anything tries to write to it.
for subdir in ('', '/checkpoints', '/samples'):
    os.makedirs(config.save_dir + subdir, exist_ok=True)
os.makedirs(config.pretrain_dir, exist_ok=True)

# Echo the settings that most often explain a misbehaving run.
print("Configuration loaded")
print(f"Device: {config.device}")
print(f"Mixed precision: {config.use_amp}")
print(f"Training data: {config.train_data_dir}")

## Download Pretrained Weights

In [None]:
def download_pretrained_weights():
    """Download the pretrained checkpoint files into ``config.pretrain_dir``.

    Files that already exist are left untouched. Each download is written to
    a ``.part`` file and renamed only after the transfer has finished, so an
    interrupted download cannot leave a truncated checkpoint that a later run
    would mistake for a complete one.

    Raises:
        OSError: if a download fails (``urllib.error.URLError`` is a subclass).
    """
    weights = {
        'convnext': {
            'url': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth',
            'path': f'{config.pretrain_dir}/convnext_large_22k_224.pth'
        },
        'inception': {
            'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
            'path': f'{config.pretrain_dir}/inception_v3_google-1a9a5a14.pth'
        }
    }

    # urlretrieve does not create missing parent directories.
    os.makedirs(config.pretrain_dir, exist_ok=True)

    for name, info in weights.items():
        target = info['path']
        if os.path.exists(target):
            print(f"{name} weights found: {target}")
            continue

        print(f"Downloading {name} weights...")
        partial = f"{target}.part"
        try:
            urllib.request.urlretrieve(info['url'], partial)
            # Atomic rename: the final path only ever holds a complete file.
            os.replace(partial, target)
        finally:
            # Only still present when the download or rename failed.
            if os.path.exists(partial):
                os.remove(partial)
        print(f"Downloaded: {target}")

download_pretrained_weights()

## Dataset

In [None]:
class ColorizeDataset(Dataset):
    """Image-folder dataset yielding Lab targets and a grayscale RGB input.

    Each item is a dict of float32 CHW tensors scaled to [0, 1]:

    * ``'gray_rgb'`` - the image's lightness rendered as a 3-channel RGB image
    * ``'l'``        - lightness channel (OpenCV 8-bit Lab encoding / 255)
    * ``'ab'``       - chroma channels (OpenCV 8-bit Lab encoding / 255)
    * ``'rgb'``      - the resized colour image

    Any code that feeds the trained model at inference time must build its
    input with the same encoding.

    Args:
        image_dir: directory scanned (non-recursively) for images.
        input_size: images are resized to ``input_size x input_size``.
        is_train: stored on the instance; this class applies no augmentation,
            so it does not change what is returned.

    Raises:
        ValueError: if ``image_dir`` contains no image files.
    """

    # Suffixes are compared lower-cased, so '.JPG', '.Jpeg', ... all match.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png'})

    # OpenCV's 8-bit Lab encoding stores a/b with a +128 offset, so "no colour"
    # is 128 rather than 0. Using 0 would mean a = b = -128, a strongly tinted
    # image instead of a gray one.
    NEUTRAL_AB = 128.0 / 255.0

    def __init__(self, image_dir, input_size=256, is_train=True):
        self.image_dir = Path(image_dir)
        self.input_size = input_size
        self.is_train = is_train

        # One pass with a case-insensitive suffix test: no duplicates on
        # case-insensitive filesystems, and sorting makes the order stable.
        self.image_paths = sorted(
            path for path in self.image_dir.glob('*')
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if not self.image_paths:
            raise ValueError(f"No images found in {image_dir}")

        print(f"Found {len(self.image_paths)} images in {image_dir}")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_uint8(img):
        """Scale a [0, 1] float array to uint8, rounding instead of truncating."""
        return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)

    def rgb_to_lab(self, img):
        """Convert an HxWx3 RGB float image in [0, 1] to 8-bit Lab scaled to [0, 1]."""
        img_lab = cv2.cvtColor(self._to_uint8(img), cv2.COLOR_RGB2LAB)
        return img_lab.astype(np.float32) / 255.0

    def lab_to_rgb(self, img_lab):
        """Inverse of :meth:`rgb_to_lab`: 8-bit Lab in [0, 1] to RGB in [0, 1]."""
        img_rgb = cv2.cvtColor(self._to_uint8(img_lab), cv2.COLOR_LAB2RGB)
        return img_rgb.astype(np.float32) / 255.0

    @staticmethod
    def _to_chw_tensor(array):
        """Turn an HxWxC numpy array into a float CxHxW tensor."""
        return torch.from_numpy(array.transpose(2, 0, 1)).float()

    def __getitem__(self, idx):
        # The context manager closes the file; convert() has already loaded it.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert('RGB')
        img = img.resize((self.input_size, self.input_size), Image.BICUBIC)
        img = np.array(img).astype(np.float32) / 255.0

        img_lab = self.rgb_to_lab(img)
        img_l = img_lab[:, :, 0:1]
        img_ab = img_lab[:, :, 1:3]

        # Grayscale RGB for the model input: keep L, replace a/b by the
        # neutral value so the result carries lightness only.
        neutral_ab = np.full_like(img_ab, self.NEUTRAL_AB)
        img_gray_lab = np.concatenate([img_l, neutral_ab], axis=-1)
        img_gray_rgb = self.lab_to_rgb(img_gray_lab)

        return {
            'gray_rgb': self._to_chw_tensor(img_gray_rgb),
            'l': self._to_chw_tensor(img_l),
            'ab': self._to_chw_tensor(img_ab),
            'rgb': self._to_chw_tensor(img)
        }

# Build the two datasets in order (train first, then validation).
train_dataset, val_dataset = (
    ColorizeDataset(root, config.input_size, is_train=training)
    for root, training in ((config.train_data_dir, True), (config.val_data_dir, False))
)

# Training batches are shuffled and incomplete trailing batches are dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# Validation is read one image at a time, in a fixed order.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    num_workers=2,
    shuffle=False,
    pin_memory=True,
)

print(f"Dataset loaded: {len(train_dataset)} train, {len(val_dataset)} val")

## Loss Functions

In [None]:
class ColorfulnessLoss(nn.Module):
    """Penalty that decreases as the predicted chroma gets more varied/larger.

    For every sample and channel the spatial standard deviation and the
    spatial mean of the absolute value are combined as ``std + 0.3 * mean``;
    the loss is the reciprocal of the batch average of that score.

    NOTE(review): ``mean(|ab|)`` measures distance from 0. If ``ab_pred`` uses
    an offset encoding where neutral chroma is not 0, this term rewards large
    values rather than saturation - confirm against the target encoding.
    """
    def __init__(self):
        super().__init__()

    def forward(self, ab_pred):
        """Return the scalar colorfulness penalty for an (N, C, H, W) tensor."""
        # Compute in float32: in half precision the 1e-8 guard below rounds
        # to zero, so a flat prediction would produce 1 / 0 = inf.
        ab = ab_pred.float()
        std = torch.std(ab, dim=[2, 3])
        mean = torch.mean(torch.abs(ab), dim=[2, 3])
        colorfulness = torch.mean(std + 0.3 * mean)
        return 1.0 / (colorfulness + 1e-8)

class PixelLoss(nn.Module):
    """Mean absolute error (L1) between a prediction and its target."""

    def __init__(self):
        super().__init__()
        # 'mean' is nn.L1Loss's default reduction; spelled out for clarity.
        self.loss_fn = nn.L1Loss(reduction='mean')

    def forward(self, pred, target):
        l1 = self.loss_fn(pred, target)
        return l1

# Instantiate every criterion and place it on the training device.
pixel_loss = PixelLoss()
pixel_loss = pixel_loss.to(config.device)

# Perceptual term: VGG19 'conv5_4' features compared with L1, no style term.
perceptual_loss = PerceptualLoss(
    vgg_type='vgg19',
    layer_weights={'conv5_4': 1.0},
    criterion='l1',
    perceptual_weight=1.0,
    style_weight=0,
    use_input_norm=True,
)
perceptual_loss = perceptual_loss.to(config.device)

colorfulness_loss = ColorfulnessLoss()
colorfulness_loss = colorfulness_loss.to(config.device)

print("Loss functions initialized")

## Discriminator

In [None]:
class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator: a conv stack ending in a 1-channel logit map.

    Args:
        input_nc: number of channels of the input image.
        ndf: number of filters of the first convolution; later stages use
            2x, 4x and 8x this width.
    """
    def __init__(self, input_nc=3, ndf=64):
        super().__init__()

        # Stem: strided conv without normalisation.
        stack = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]

        # Three conv -> batch-norm -> LeakyReLU stages, given as
        # (input width multiplier, output width multiplier, stride).
        for mult_in, mult_out, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            stack.append(nn.Conv2d(ndf * mult_in, ndf * mult_out,
                                   kernel_size=4, stride=stride, padding=1))
            stack.append(nn.BatchNorm2d(ndf * mult_out))
            stack.append(nn.LeakyReLU(0.2, True))

        # Head: one logit per spatial position.
        stack.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1))

        # Layers are created in the same order as before, so the indices
        # inside the Sequential (and therefore state-dict keys) are unchanged.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        logits = self.model(x)
        return logits

# Adversarial pair: the patch discriminator and the criterion that scores it.
discriminator = PatchDiscriminator()
discriminator = discriminator.to(config.device)

gan_loss = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0)
gan_loss = gan_loss.to(config.device)

print("Discriminator initialized")

## Model Initialization

In [None]:
model = DDColor(
    encoder_name=config.encoder_name,
    decoder_name=config.decoder_name,
    input_size=[config.input_size, config.input_size],
    num_output_channels=2,
    last_norm='Spectral',
    do_normalize=False,
    num_queries=config.num_queries,
    num_scales=config.num_scales,
    dec_layers=config.dec_layers,
).to(config.device)


def _select_pretrained_tensors(pretrained, model_state):
    """Map checkpoint tensors onto parameter names of ``model_state``.

    A checkpoint key is accepted when it equals a model key, or when exactly
    one model key ends with ``'.' + key`` (the backbone nested under some
    prefix inside the full model). The shapes must agree in both cases, so a
    single incompatible tensor is skipped instead of aborting the whole load.

    Returns:
        dict mapping model key -> checkpoint tensor.
    """
    selected = {}
    for key, tensor in pretrained.items():
        if not torch.is_tensor(tensor):
            continue
        if key in model_state:
            candidates = [key]
        else:
            suffix = '.' + key
            candidates = [name for name in model_state if name.endswith(suffix)]
        candidates = [name for name in candidates
                      if model_state[name].shape == tensor.shape]
        # Ambiguous suffix matches are skipped rather than guessed.
        if len(candidates) == 1:
            selected[candidates[0]] = tensor
    return selected


# Load pretrained encoder
print("Loading pretrained encoder...")
try:
    pretrained_dict = torch.load(
        f'{config.pretrain_dir}/convnext_large_22k_224.pth',
        map_location=config.device
    )
    # NOTE(review): the downloaded checkpoint may keep its weights under a
    # top-level 'model' entry instead of being a flat state dict - unwrap it
    # when present; confirm against the actual file.
    if isinstance(pretrained_dict, dict) and isinstance(pretrained_dict.get('model'), dict):
        pretrained_dict = pretrained_dict['model']

    model_dict = model.state_dict()
    pretrained_dict = _select_pretrained_tensors(pretrained_dict, model_dict)
    if not pretrained_dict:
        # Previously this case was reported as a successful load.
        raise ValueError("no checkpoint tensor matches a model parameter")

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)
    print("Pretrained weights loaded")
    print(f"Matched {len(pretrained_dict)} of {len(model_dict)} model tensors")
except Exception as e:
    # Deliberate best-effort: a missing or unusable checkpoint only means the
    # encoder starts from random initialisation.
    print(f"Warning: Could not load pretrained weights ({e})")
    print("Training from scratch")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## Optimizers & Schedulers

In [None]:
# Generator and discriminator use identically configured AdamW optimizers.
optimizer_g, optimizer_d = (
    optim.AdamW(
        network.parameters(),
        lr=config.lr,
        betas=config.betas,
        weight_decay=config.weight_decay,
    )
    for network in (model, discriminator)
)

# Halve each learning rate every 80000 scheduler steps.
scheduler_g, scheduler_d = (
    optim.lr_scheduler.StepLR(optimizer, step_size=80000, gamma=0.5)
    for optimizer in (optimizer_g, optimizer_d)
)

# A gradient scaler is only needed when mixed precision is enabled.
scaler = None
if config.use_amp:
    scaler = GradScaler()

print("Optimizers initialized")

## Training Functions

In [None]:
def lab_to_rgb_tensor(l, ab):
    """Convert normalised Lab tensors to RGB, keeping the autograd graph.

    The inputs use OpenCV's 8-bit Lab encoding divided by 255, i.e.
    ``L* = 100 * l`` and ``a*, b* = 255 * ab - 128``. The conversion is done
    with torch operations only, so losses computed on the returned image
    back-propagate into ``ab``; a numpy/OpenCV round trip would detach it.
    Inputs are clamped to [0, 1] first, so out-of-range predictions saturate
    instead of wrapping around.

    Args:
        l: (N, 1, H, W) lightness tensor in [0, 1].
        ab: (N, 2, H, W) chroma tensor in [0, 1].

    Returns:
        (N, 3, H, W) float32 sRGB tensor in [0, 1] on the inputs' device.
    """
    lab = torch.cat([l.float(), ab.float()], dim=1).clamp(0.0, 1.0)

    # Undo the 8-bit encoding to get real CIE L*a*b* values.
    lightness = lab[:, 0:1] * 100.0
    a_star = lab[:, 1:2] * 255.0 - 128.0
    b_star = lab[:, 2:3] * 255.0 - 128.0

    # L*a*b* -> XYZ (D65 white point).
    fy = (lightness + 16.0) / 116.0
    fx = fy + a_star / 500.0
    fz = fy - b_star / 200.0

    delta = 6.0 / 29.0

    def f_inv(t):
        # Cubic above the knee, linear segment below it.
        return torch.where(t > delta, t ** 3, 3.0 * delta ** 2 * (t - 4.0 / 29.0))

    x = 0.950456 * f_inv(fx)
    y = f_inv(fy)
    z = 1.088754 * f_inv(fz)

    # XYZ -> linear sRGB.
    red = 3.240479 * x - 1.537150 * y - 0.498535 * z
    green = -0.969256 * x + 1.875991 * y + 0.041556 * z
    blue = 0.055648 * x - 0.204043 * y + 1.057311 * z
    linear = torch.cat([red, green, blue], dim=1).clamp(0.0, 1.0)

    # Linear -> gamma-encoded sRGB. The power branch is evaluated on a value
    # clamped away from zero so its gradient stays finite; torch.where then
    # picks the linear branch for the small values.
    threshold = 0.0031308
    srgb = torch.where(
        linear <= threshold,
        12.92 * linear,
        1.055 * linear.clamp(min=threshold) ** (1.0 / 2.4) - 0.055,
    )
    return srgb.clamp(0.0, 1.0)

def _backward_and_step(loss, optimizer):
    """Back-propagate ``loss`` and step ``optimizer``, via the AMP scaler if enabled."""
    if config.use_amp:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
    else:
        loss.backward()
        optimizer.step()


def train_one_epoch(epoch, iteration):
    """Run one pass over ``train_loader``, alternating generator/discriminator updates.

    Args:
        epoch: current epoch index (used for the progress bar and file names).
        iteration: number of iterations completed before this epoch.

    Returns:
        The updated iteration count. The epoch stops early once
        ``config.total_iters`` is reached.
    """
    model.train()
    discriminator.train()

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch in pbar:
        iteration += 1

        gray_rgb = batch['gray_rgb'].to(config.device)
        l_gt = batch['l'].to(config.device)
        ab_gt = batch['ab'].to(config.device)

        # ---- Generator update ----
        # The discriminator is only a fixed critic here. Freezing it avoids
        # accumulating gradients in its weights that would be thrown away.
        for param in discriminator.parameters():
            param.requires_grad_(False)

        optimizer_g.zero_grad()

        # With enabled=False the context is a no-op, so one code path serves
        # both the mixed-precision and the full-precision configuration.
        with autocast(device_type=config.device, enabled=config.use_amp):
            ab_pred = model(gray_rgb)

            loss_pixel = pixel_loss(ab_pred, ab_gt) * config.lambda_pixel

            rgb_pred = lab_to_rgb_tensor(l_gt, ab_pred)
            rgb_gt = lab_to_rgb_tensor(l_gt, ab_gt)
            loss_perceptual = perceptual_loss(rgb_pred, rgb_gt)[0] * config.lambda_perceptual

            loss_color = colorfulness_loss(ab_pred) * config.lambda_colorfulness

            fake_pred = discriminator(rgb_pred)
            loss_gan_g = gan_loss(fake_pred, True) * config.lambda_gan

            loss_g = loss_pixel + loss_perceptual + loss_color + loss_gan_g

        _backward_and_step(loss_g, optimizer_g)

        # ---- Discriminator update ----
        for param in discriminator.parameters():
            param.requires_grad_(True)

        optimizer_d.zero_grad()

        with autocast(device_type=config.device, enabled=config.use_amp):
            # Both inputs are detached: this step must not reach the generator.
            real_pred = discriminator(rgb_gt.detach())
            loss_d_real = gan_loss(real_pred, True)

            fake_pred = discriminator(rgb_pred.detach())
            loss_d_fake = gan_loss(fake_pred, False)

            loss_d = (loss_d_real + loss_d_fake) * 0.5

        _backward_and_step(loss_d, optimizer_d)

        # One shared scaler serves both optimizers, so its scale factor is
        # updated once per iteration, after both have stepped.
        if config.use_amp:
            scaler.update()

        scheduler_g.step()
        scheduler_d.step()

        if iteration % config.print_freq == 0:
            pbar.set_postfix({
                'iter': iteration,
                'L_pix': f'{loss_pixel.item():.4f}',
                'L_per': f'{loss_perceptual.item():.4f}',
                'L_col': f'{loss_color.item():.4f}',
                'L_gan': f'{loss_gan_g.item():.4f}',
                'L_d': f'{loss_d.item():.4f}',
            })

        if iteration % config.val_freq == 0:
            save_samples(epoch, iteration)

        if iteration % config.save_freq == 0:
            save_checkpoint(epoch, iteration)

        if iteration >= config.total_iters:
            break

    return iteration

def save_samples(epoch, iteration):
    """Colorize the first validation item and save an input/prediction/target figure.

    The figure is written to
    ``{config.save_dir}/samples/epoch{epoch}_iter{iteration}.png``.
    The model is put in eval mode for the forward pass and returned to the
    mode it was in when the function was called, even if plotting fails.
    """
    was_training = model.training
    model.eval()
    fig = None

    try:
        with torch.no_grad():
            batch = next(iter(val_loader))
            gray_rgb = batch['gray_rgb'].to(config.device)
            l_gt = batch['l'].to(config.device)
            ab_gt = batch['ab'].to(config.device)

            ab_pred = model(gray_rgb)
            rgb_pred = lab_to_rgb_tensor(l_gt, ab_pred)
            rgb_gt = lab_to_rgb_tensor(l_gt, ab_gt)

        panels = (
            ('Input', gray_rgb),
            ('Predicted', rgb_pred),
            ('Ground Truth', rgb_gt),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (title, images) in zip(axes, panels):
            # imshow expects HxWxC floats in [0, 1]; clamp so stray values
            # are not silently clipped with a warning.
            picture = images[0].detach().float().clamp(0.0, 1.0).cpu().permute(1, 2, 0)
            ax.imshow(picture.numpy())
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(f'{config.save_dir}/samples/epoch{epoch}_iter{iteration}.png')
    finally:
        # Release the figure and restore the caller's mode on every exit path.
        if fig is not None:
            plt.close(fig)
        model.train(was_training)

def save_checkpoint(epoch, iteration):
    """Write a resumable checkpoint and refresh ``latest.pth``.

    The checkpoint holds the generator, discriminator, both optimizers and
    both schedulers, plus the AMP gradient scaler when one is in use. Files
    are written under a temporary name and renamed afterwards, so a crash
    during saving cannot leave a truncated checkpoint under the final name.
    """
    checkpoint = {
        'epoch': epoch,
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_g_state_dict': optimizer_g.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict(),
        'scheduler_g_state_dict': scheduler_g.state_dict(),
        'scheduler_d_state_dict': scheduler_d.state_dict(),
    }
    # Extra key; readers that only use the entries above are unaffected.
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()

    checkpoint_dir = f'{config.save_dir}/checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    path = f'{checkpoint_dir}/checkpoint_iter{iteration}.pth'
    tmp_path = f'{path}.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)
    print(f"\nCheckpoint saved: {path}")

    latest_path = f'{checkpoint_dir}/latest.pth'
    latest_tmp = f'{latest_path}.tmp'
    shutil.copy(path, latest_tmp)
    os.replace(latest_tmp, latest_path)

print("Training functions defined")

## Main Training Loop

In [None]:
def train():
    """Run the full training schedule.

    Trains for at most ``config.num_epochs`` epochs and stops earlier once
    ``config.total_iters`` iterations have been completed. A checkpoint of
    the final state is written at the end unless the periodic save already
    covered that exact iteration.
    """
    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80)
    print(f"Epochs: {config.num_epochs}")
    print(f"Iterations: {config.total_iters}")
    print(f"Batch size: {config.batch_size}")
    print(f"Device: {config.device}")
    print(f"Output: {config.save_dir}")
    print("="*80 + "\n")

    iteration = 0
    last_epoch = None

    for epoch in range(config.num_epochs):
        print(f"\n{'='*80}")
        print(f"EPOCH {epoch + 1}/{config.num_epochs}")
        print(f"{'='*80}")

        last_epoch = epoch
        iteration = train_one_epoch(epoch, iteration)

        if iteration >= config.total_iters:
            print(f"\nReached maximum iterations: {config.total_iters}")
            break

    # Periodic saving only fires on multiples of save_freq. Without this the
    # last iterations - or the whole run, if it is shorter than save_freq -
    # would never reach disk.
    if last_epoch is not None and iteration > 0 and iteration % config.save_freq != 0:
        save_checkpoint(last_epoch, iteration)

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    if iteration > 0:
        print(f"Final checkpoint: {config.save_dir}/checkpoints/latest.pth")
    else:
        print("No training iterations were run; no checkpoint was written")

train()

## Inference

In [None]:
def inference(image_path, checkpoint_path):
    """Colorize one image with a trained checkpoint and display the result.

    Args:
        image_path: path of the image to colorize; only its lightness is fed
            to the model.
        checkpoint_path: checkpoint file containing a ``'model_state_dict'``
            entry.

    Returns:
        The colorized image as a (1, 3, H, W) RGB tensor, where H and W equal
        ``config.input_size``.
    """
    model = DDColor(
        encoder_name=config.encoder_name,
        decoder_name=config.decoder_name,
        input_size=[config.input_size, config.input_size],
        num_output_channels=2,
        last_norm='Spectral',
        do_normalize=False,
        num_queries=config.num_queries,
        num_scales=config.num_scales,
        dec_layers=config.dec_layers,
    ).to(config.device)

    checkpoint = torch.load(checkpoint_path, map_location=config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Reuse the training dataset's preprocessing instead of re-implementing it
    # here, so the network always receives the encoding it was trained on.
    # __init__ is bypassed because it scans a directory; __getitem__ only
    # reads image_paths and input_size.
    preprocessor = ColorizeDataset.__new__(ColorizeDataset)
    preprocessor.image_paths = [Path(image_path)]
    preprocessor.input_size = config.input_size
    preprocessor.is_train = False
    sample = preprocessor[0]

    img_gray_rgb = sample['gray_rgb'].unsqueeze(0).to(config.device)
    l_tensor = sample['l'].unsqueeze(0).to(config.device)

    with torch.no_grad():
        ab_pred = model(img_gray_rgb)
        rgb_pred = lab_to_rgb_tensor(l_tensor, ab_pred)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, title, images in zip(axes, ('Input', 'Colorized'), (img_gray_rgb, rgb_pred)):
        ax.imshow(images[0].cpu().permute(1, 2, 0))
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return rgb_pred

# Example usage
# result = inference('./test_image.jpg', f'{config.save_dir}/checkpoints/latest.pth')