In [12]:
# IMPORTS AND CONFIGURATION
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import shutil
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
from torchvision.transforms import functional as TF
import random
from sklearn.model_selection import train_test_split


# For reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
print("✓ Imports loaded")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

✓ Imports loaded
PyTorch version: 2.5.1+cu121
CUDA available: True


In [13]:
# CONFIGURATION
class EnhancedConfig:
    def __init__(self, test_mode=False):
    
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        

        self.BATCH_SIZE = 8 if torch.cuda.is_available() else 4
        self.EPOCHS = 50
        self.LR = 0.0002
        self.BETA1 = 0.5
        self.BETA2 = 0.999
        
        # Loss weights
        self.LAMBDA_L1 = 100
        self.LAMBDA_PERCEPTUAL = 10  # Perceptual loss weight
        self.LAMBDA_GP = 10  # Gradient penalty
        
        # Model parameters
        self.IN_CHANNELS = 1
        self.OUT_CHANNELS = 3
        self.USE_ATTENTION = True  # Enable attention mechanisms
        self.USE_SE_BLOCKS = True  # Enable Squeeze-and-Excitation
        
        # Data parameters
        self.IMG_SIZE = 256
        self.DATA_DIR = './pix2pix_dataset'
        
        # Multi-scale discriminator
        self.NUM_D_SCALES = 3  # NEW: 3 discriminators at different scales
        
        # Training settings
        self.SAVE_INTERVAL = 5
        self.VAL_INTERVAL = 1
        self.NUM_WORKERS = 4
        self.PIN_MEMORY = True
        self.USE_AMP = True
        self.GRAD_CLIP = 1.0
        self.WEIGHT_DECAY = 0.0001
        
        # Stability features
        self.USE_LABEL_SMOOTHING = 0.1
        self.ADD_INPUT_NOISE = 0.02
        self.D_TRAIN_RATIO = 1
        
        # Learning rate scheduler (NEW)
        self.USE_COSINE_ANNEALING = True
        self.T_0 = 10  # Initial restart period
        self.T_MULT = 2  # Period multiplier after restart
        
        # Paths
        self.PROJECT_ROOT = './enhanced_pix2pix_project'
        
        if test_mode:
            self.EXPERIMENT_NAME = f"ENHANCED_TEST_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        else:
            self.EXPERIMENT_NAME = f"ENHANCED_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    def to_dict(self):
        config_dict = {}
        for key in dir(self):
            if not key.startswith('__') and not callable(getattr(self, key)):
                value = getattr(self, key)
                if isinstance(value, torch.device):
                    config_dict[key] = str(value)
                else:
                    config_dict[key] = value
        return config_dict

# Setup directories
def setup_directories(config):
    dirs = [
        config.PROJECT_ROOT,
        os.path.join(config.PROJECT_ROOT, 'checkpoints'),
        os.path.join(config.PROJECT_ROOT, 'losses'),
        os.path.join(config.PROJECT_ROOT, 'samples'),
        os.path.join(config.PROJECT_ROOT, 'logs')
    ]
    
    for dir_path in dirs:
        os.makedirs(dir_path, exist_ok=True)
    
    print(f"Directories created in: {config.PROJECT_ROOT}")
    print(f"Device: {config.device}")
    return config

config = EnhancedConfig(test_mode=False)
config = setup_directories(config)
print(f"\n{'='*60}")
print("ENHANCED PIX2PIX CONFIGURATION")
print(f"{'='*60}")
print(f"Batch Size: {config.BATCH_SIZE}")
print(f"Epochs: {config.EPOCHS}")
print(f"Image Size: {config.IMG_SIZE}x{config.IMG_SIZE}")
print(f"Attention Enabled: {config.USE_ATTENTION}")
print(f"SE Blocks Enabled: {config.USE_SE_BLOCKS}")
print(f"Multi-Scale Discriminator: {config.NUM_D_SCALES} scales")
print(f"Perceptual Loss Weight: {config.LAMBDA_PERCEPTUAL}")
print(f"{'='*60}")

Directories created in: ./enhanced_pix2pix_project
Device: cuda

ENHANCED PIX2PIX CONFIGURATION
Batch Size: 8
Epochs: 50
Image Size: 256x256
Attention Enabled: True
SE Blocks Enabled: True
Multi-Scale Discriminator: 3 scales
Perceptual Loss Weight: 10


In [14]:
# ATTENTION MECHANISMS

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        
        # Use 1x1 convolutions to reduce computation
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # Query, Key, Value projections
        proj_query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, width * height)
        proj_value = self.value(x).view(batch_size, -1, width * height)
        
        # Attention map
        attention = torch.bmm(proj_query, proj_key)
        attention = self.softmax(attention)
        
        # Apply attention to values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        
        # Residual connection with learnable weight
        out = self.gamma * out + x
        return out


class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

In [15]:
# GENERATOR WITH ATTENTION

class EnhancedUNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0, use_se=False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, 
                     stride=2, padding=1, bias=not normalize)
        ]
        
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        
        self.model = nn.Sequential(*layers)
        
        # Add SE block if enabled
        self.se_block = SEBlock(out_channels) if use_se else None
    
    def forward(self, x):
        x = self.model(x)
        if self.se_block is not None:
            x = self.se_block(x)
        return x


class EnhancedUNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0, use_se=False):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 
                              kernel_size=4, stride=2, 
                              padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        
        self.model = nn.Sequential(*layers)
        
        # Add SE block if enabled
        self.se_block = SEBlock(out_channels) if use_se else None
    
    def forward(self, x, skip_input=None):
        x = self.model(x)
        
        if self.se_block is not None:
            x = self.se_block(x)
        
        if skip_input is not None:
            x = torch.cat((x, skip_input), 1)
        
        return x


class EnhancedGeneratorUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=3, use_attention=True, use_se=True):
        super().__init__()
        
        # Encoder
        self.down1 = EnhancedUNetDown(in_channels, 64, normalize=False, use_se=use_se)
        self.down2 = EnhancedUNetDown(64, 128, use_se=use_se)
        self.down3 = EnhancedUNetDown(128, 256, use_se=use_se)
        self.down4 = EnhancedUNetDown(256, 512, use_se=use_se)
        self.down5 = EnhancedUNetDown(512, 512, use_se=use_se)
        self.down6 = EnhancedUNetDown(512, 512, use_se=use_se)
        self.down7 = EnhancedUNetDown(512, 512, use_se=use_se)
        self.down8 = EnhancedUNetDown(512, 512, normalize=False, use_se=use_se)
        
        # Self-attention at bottleneck
        self.attention = SelfAttention(512) if use_attention else nn.Identity()
        
        # Decoder
        self.up1 = EnhancedUNetUp(512, 512, dropout=0.5, use_se=use_se)
        self.up2 = EnhancedUNetUp(1024, 512, dropout=0.5, use_se=use_se)
        self.up3 = EnhancedUNetUp(1024, 512, dropout=0.5, use_se=use_se)
        self.up4 = EnhancedUNetUp(1024, 512, dropout=0.0, use_se=use_se)
        self.up5 = EnhancedUNetUp(1024, 256, dropout=0.0, use_se=use_se)
        self.up6 = EnhancedUNetUp(512, 128, dropout=0.0, use_se=use_se)
        self.up7 = EnhancedUNetUp(256, 64, dropout=0.0, use_se=use_se)
        
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        
        # Apply attention at bottleneck
        d8 = self.attention(d8)
        
        # Decoder with skip connections
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        
        return self.final(u7)


In [16]:
# MULTI-SCALE DISCRIMINATOR

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=4):  # 1 (gray) + 3 (RGB)
        super().__init__()
        
        self.model = nn.Sequential(
            # C64
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # C128
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # C256
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # C512
            nn.Conv2d(256, 512, 4, 1, 1, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Output
            nn.Conv2d(512, 1, 4, 1, 1)
        )
    
    def forward(self, gray, rgb):
        x = torch.cat([gray, rgb], dim=1)
        return self.model(x)


class MultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        
        # Create discriminators at different scales
        self.discriminators = nn.ModuleList([
            PatchDiscriminator(in_channels + out_channels)
            for _ in range(num_scales)
        ])
        
        # Downsampling layer
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
    
    def forward(self, gray, rgb):
        outputs = []
        
        # Run through each discriminator at different scales
        for i, discriminator in enumerate(self.discriminators):
            outputs.append(discriminator(gray, rgb))
            
            # Downsample for next scale (except for last scale)
            if i < self.num_scales - 1:
                gray = self.downsample(gray)
                rgb = self.downsample(rgb)
        
        return outputs

In [17]:
#PERCEPTUAL LOSS

class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Load pretrained VGG16
        vgg = models.vgg16(pretrained=True).features
        
        # Use features from multiple layers
        self.slice1 = nn.Sequential(*list(vgg[:4]))   # relu1_2
        self.slice2 = nn.Sequential(*list(vgg[4:9]))  # relu2_2
        self.slice3 = nn.Sequential(*list(vgg[9:16])) # relu3_3
        self.slice4 = nn.Sequential(*list(vgg[16:23]))# relu4_3
        
        # Freeze VGG parameters
        for param in self.parameters():
            param.requires_grad = False
        
        # Normalization for ImageNet
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    
    def normalize(self, x):
        # Denormalize from [-1, 1] to [0, 1]
        x = (x + 1) / 2
        # Normalize for VGG
        return (x - self.mean) / self.std
    
    def forward(self, fake, real):
        # Normalize inputs
        fake = self.normalize(fake)
        real = self.normalize(real)
        
        # Extract features
        fake_features = []
        real_features = []
        
        x_fake = fake
        x_real = real
        
        for slice_layer in [self.slice1, self.slice2, self.slice3, self.slice4]:
            x_fake = slice_layer(x_fake)
            x_real = slice_layer(x_real)
            fake_features.append(x_fake)
            real_features.append(x_real)
        
        # Calculate loss across all layers
        loss = 0
        for fake_feat, real_feat in zip(fake_features, real_features):
            loss += F.l1_loss(fake_feat, real_feat)
        
        return loss / len(fake_features)


In [18]:
# def organize_dataset(grayscale_dir='dataset/grayscale', 
#                      colored_dir='dataset/colored',
#                      output_root='pix2pix_dataset',
#                      train_ratio=0.80,
#                      val_ratio=0.15,
#                      test_ratio=0.05):

#     print("=" * 60)
#     print("ORGANIZING DATASET")
#     print("=" * 60)
#     print(f"Source directories:")
#     print(f"  Grayscale: {grayscale_dir}")
#     print(f"  Colored:   {colored_dir}")
#     print(f"\nOutput root: {output_root}")
#     print(f"Split ratios - Train: {train_ratio:.0%}, Val: {val_ratio:.0%}, Test: {test_ratio:.0%}")
#     print("=" * 60 + "\n")
    
#     # Create output directory structure
#     splits = ['train', 'val', 'test']
#     for split in splits:
#         os.makedirs(os.path.join(output_root, split, 'gray'), exist_ok=True)
#         os.makedirs(os.path.join(output_root, split, 'rgb'), exist_ok=True)
    
#     # Get all image files from grayscale directory
#     image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
#     gray_files = [f for f in os.listdir(grayscale_dir) 
#                   if f.lower().endswith(image_extensions)]
    
#     # Filter to only include files that have matching colored versions
#     # Handle naming convention: chop_1_grey.png <-> chop_1_color.png
#     valid_pairs = []
#     missing_colored = []
    
#     print("Scanning for valid image pairs...")
#     for gray_file in tqdm(gray_files, desc="Processing"):
#         gray_path = os.path.join(grayscale_dir, gray_file)
        
#         # Convert grey filename to color filename
#         # chop_1_grey.png -> chop_1_color.png
#         if '_grey.' in gray_file:
#             colored_file = gray_file.replace('_grey.', '_color.')
#         elif '_gray.' in gray_file:
#             colored_file = gray_file.replace('_gray.', '_color.')
#         else:
#             # If no _grey or _gray in filename, assume same filename
#             colored_file = gray_file
        
#         colored_path = os.path.join(colored_dir, colored_file)
        
#         if os.path.exists(colored_path):
#             valid_pairs.append({
#                 'gray_file': gray_file,
#                 'colored_file': colored_file
#             })
#         else:
#             missing_colored.append(gray_file)
    
#     print(f"\nFound {len(gray_files)} grayscale images")
#     print(f"Valid pairs (with matching colored): {len(valid_pairs)}")
    
#     if missing_colored:
#         print(f"⚠ Missing colored versions: {len(missing_colored)}")
#         print(f"  (These will be skipped)")
    
#     if len(valid_pairs) == 0:
#         raise ValueError("No valid image pairs found! Check your folder structure.")
    
#     # Split data
#     # First split: train vs (val+test)
#     train_files, temp_files = train_test_split(
#         valid_pairs, 
#         test_size=(val_ratio + test_ratio),
#         random_state=42
#     )
    
#     # Second split: val vs test
#     val_size = val_ratio / (val_ratio + test_ratio)
#     val_files, test_files = train_test_split(
#         temp_files,
#         test_size=(1 - val_size),
#         random_state=42
#     )
    
#     print(f"\n{'='*60}")
#     print("SPLIT SUMMARY")
#     print(f"{'='*60}")
#     print(f"Train: {len(train_files):4d} images ({len(train_files)/len(valid_pairs)*100:.1f}%)")
#     print(f"Val:   {len(val_files):4d} images ({len(val_files)/len(valid_pairs)*100:.1f}%)")
#     print(f"Test:  {len(test_files):4d} images ({len(test_files)/len(valid_pairs)*100:.1f}%)")
#     print(f"Total: {len(valid_pairs):4d} images")
#     print(f"{'='*60}\n")
    
#     # Copy files to splits
#     def copy_split_files(file_list, split_name):
#         gray_out = os.path.join(output_root, split_name, 'gray')
#         rgb_out = os.path.join(output_root, split_name, 'rgb')
        
#         for pair in tqdm(file_list, desc=f"Copying {split_name:5s} files"):
#             gray_file = pair['gray_file']
#             colored_file = pair['colored_file']
            
#             # Create a common base name for both files (remove _grey and _color suffixes)
#             # chop_1_grey.png -> chop_1.png
#             if '_grey.' in gray_file:
#                 base_name = gray_file.replace('_grey.', '.')
#             elif '_gray.' in gray_file:
#                 base_name = gray_file.replace('_gray.', '.')
#             else:
#                 base_name = gray_file
            
#             # Copy grayscale
#             shutil.copy2(
#                 os.path.join(grayscale_dir, gray_file),
#                 os.path.join(gray_out, base_name)
#             )
#             # Copy colored
#             shutil.copy2(
#                 os.path.join(colored_dir, colored_file),
#                 os.path.join(rgb_out, base_name)
#             )
    
#     copy_split_files(train_files, 'train')
#     copy_split_files(val_files, 'val')
#     copy_split_files(test_files, 'test')
    
#     # Save dataset statistics
#     stats = {
#         'total_pairs': len(valid_pairs),
#         'train': len(train_files),
#         'val': len(val_files),
#         'test': len(test_files),
#         'missing_colored': len(missing_colored),
#         'split_ratios': {
#             'train': train_ratio,
#             'val': val_ratio,
#             'test': test_ratio
#         },
#         'source_dirs': {
#             'grayscale': grayscale_dir,
#             'colored': colored_dir
#         },
#         'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
#     }
    
#     stats_path = os.path.join(output_root, 'dataset_stats.json')
#     with open(stats_path, 'w') as f:
#         json.dump(stats, f, indent=2)
    
#     print(f"\n{'='*60}")
#     print("✓ DATASET ORGANIZATION COMPLETE")
  
    
#     return output_root, stats



# dataset_path, stats = organize_dataset(
#         grayscale_dir='dataset/grayscale',
#         colored_dir='dataset/colored',
#         output_root='pix2pix_dataset',
#         train_ratio=0.80,
#         val_ratio=0.15,
#         test_ratio=0.05
# )

In [19]:
class EnhancedPairedDataset(Dataset):
    def __init__(self, root_dir, mode='train', config=None, use_percentage=1.0):
        self.root_dir = root_dir
        self.mode = mode
        self.config = config
        self.use_percentage = use_percentage
        
        # Get all image pairs
        self.pairs = []
        gray_dir = os.path.join(root_dir, mode, 'gray')
        rgb_dir = os.path.join(root_dir, mode, 'rgb')
        
        # Collect all valid pairs
        all_pairs = []
        for img_name in os.listdir(gray_dir):
            if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                gray_path = os.path.join(gray_dir, img_name)
                rgb_path = os.path.join(rgb_dir, img_name)
                
                if os.path.exists(rgb_path):
                    all_pairs.append((gray_path, rgb_path))
        
        # Shuffle and select percentage
        np.random.shuffle(all_pairs)
        n_total = len(all_pairs)
        n_select = int(n_total * use_percentage)
        self.pairs = all_pairs[:n_select]
        
        print(f"{mode.capitalize()} dataset: {len(self.pairs)}/{n_total} pairs ({use_percentage*100:.1f}%)")
        
        self.transform = self._get_transforms()
    
    def _get_transforms(self):
        img_size = self.config.IMG_SIZE if self.config else 256
        
        if self.mode == 'train':
            return {
                'gray': transforms.Compose([
                    transforms.Lambda(lambda img: self.center_crop_to_square(img)),  
                    transforms.Resize((img_size, img_size)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.2),
                    transforms.RandomRotation(degrees=10),
                    transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]),
                'rgb': transforms.Compose([
                    transforms.Lambda(lambda img: self.center_crop_to_square(img)),  
                    transforms.Resize((img_size, img_size)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.2),
                    transforms.RandomRotation(degrees=10),
                    transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
                    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
            }
        else:
            return {
                'gray': transforms.Compose([
                    transforms.Lambda(lambda img: self.center_crop_to_square(img)),  
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]),
                'rgb': transforms.Compose([
                    transforms.Lambda(lambda img: self.center_crop_to_square(img)),  
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
            }
    
    def center_crop_to_square(self, img):
        width, height = img.size
        
        if width == height:
            return img
        
        # Crop to the smaller dimension
        crop_size = min(width, height)
        
        # Calculate crop coordinates (center crop)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        right = left + crop_size
        bottom = top + crop_size
        
        return img.crop((left, top, right, bottom))
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        gray_path, rgb_path = self.pairs[idx]
        
        # Apply same random seed for paired augmentation
        seed = np.random.randint(2147483647)
        
        gray_img = Image.open(gray_path).convert('L')
        rgb_img = Image.open(rgb_path).convert('RGB')
        
        # Set random seed for reproducible augmentation
        torch.manual_seed(seed)
        gray_tensor = self.transform['gray'](gray_img)
        
        torch.manual_seed(seed)
        rgb_tensor = self.transform['rgb'](rgb_img)
        
        return gray_tensor, rgb_tensor

In [20]:
# ENHANCED TRAINER

class EnhancedPix2PixTrainer:
    def __init__(self, generator, discriminator, config):
        self.config = config
        self.device = config.device
        
        # Models
        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)
        
        # Losses
        self.criterion_gan = nn.BCEWithLogitsLoss()
        self.criterion_l1 = nn.L1Loss()
        self.perceptual_loss = PerceptualLoss().to(self.device)
        
        # Optimizers
        self.optimizer_G = optim.Adam(
            self.generator.parameters(), 
            lr=config.LR, 
            betas=(config.BETA1, config.BETA2),
            weight_decay=config.WEIGHT_DECAY
        )
        self.optimizer_D = optim.Adam(
            self.discriminator.parameters(), 
            lr=config.LR, 
            betas=(config.BETA1, config.BETA2),
            weight_decay=config.WEIGHT_DECAY
        )
        
        # AMP GradScalers
        self.scaler_G = GradScaler(enabled=config.USE_AMP)
        self.scaler_D = GradScaler(enabled=config.USE_AMP)
        
        # Cosine Annealing with Warm Restarts
        if config.USE_COSINE_ANNEALING:
            self.scheduler_G = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optimizer_G, T_0=config.T_0, T_mult=config.T_MULT
            )
            self.scheduler_D = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optimizer_D, T_0=config.T_0, T_mult=config.T_MULT
            )
        else:
            self.scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_G, mode='min', factor=0.5, patience=10
            )
            self.scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_D, mode='min', factor=0.5, patience=10
            )
        
        # Tracking
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.losses = {
            'train': {'D': [], 'G': [], 'G_GAN': [], 'G_L1': [], 'G_Perceptual': []},
            'val': {'D': [], 'G': [], 'G_GAN': [], 'G_L1': [], 'G_Perceptual': []}
        }
        
        # Early stopping
        self.patience = 50
        self.best_epoch = 0
        
        # Create experiment directory
        self.exp_dir = os.path.join(config.PROJECT_ROOT, config.EXPERIMENT_NAME)
        self.checkpoint_dir = os.path.join(self.exp_dir, 'checkpoints')
        self.samples_dir = os.path.join(self.exp_dir, 'samples')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.samples_dir, exist_ok=True)
        
        # Save config
        self._save_config()
        
        print(f"\n{'='*60}")
        print(f"Experiment directory: {self.exp_dir}")
        print(f"Using AMP: {config.USE_AMP}")
        print(f"Using Cosine Annealing: {config.USE_COSINE_ANNEALING}")
        print(f"Batch size: {config.BATCH_SIZE}")
        print(f"{'='*60}\n")
    
    def _save_config(self):
        config_dict = self.config.to_dict()
        with open(os.path.join(self.exp_dir, 'config.json'), 'w') as f:
            json.dump(config_dict, f, indent=2)
    
    def set_requires_grad(self, model, requires_grad):
        for param in model.parameters():
            param.requires_grad = requires_grad
    
    def gradient_penalty(self, real_gray, real_rgb, fake_rgb):
        batch_size = real_gray.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
        interpolated = (alpha * real_rgb + (1 - alpha) * fake_rgb).requires_grad_(True)
        
        # Get discriminator outputs at all scales
        d_interpolated_list = self.discriminator(real_gray, interpolated)
        
        gp_loss = 0
        for d_interpolated in d_interpolated_list:
            gradients = torch.autograd.grad(
                outputs=d_interpolated,
                inputs=interpolated,
                grad_outputs=torch.ones_like(d_interpolated),
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )[0]
            
            gradients = gradients.view(gradients.size(0), -1)
            gp_loss += ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        
        return gp_loss / len(d_interpolated_list)
    
    def train_step(self, real_gray, real_rgb):
        real_gray, real_rgb = real_gray.to(self.device), real_rgb.to(self.device)
        batch_size = real_gray.size(0)
        
        # Label smoothing
        smooth = self.config.USE_LABEL_SMOOTHING
        
        # Add input noise
        if self.config.ADD_INPUT_NOISE > 0:
            noise_level = self.config.ADD_INPUT_NOISE
            real_rgb_noisy = real_rgb + torch.randn_like(real_rgb) * noise_level
        else:
            real_rgb_noisy = real_rgb
        
        # ===== Train Discriminator =====
        self.set_requires_grad(self.discriminator, True)
        self.optimizer_D.zero_grad()
        
        with autocast(enabled=self.config.USE_AMP):
            fake_rgb = self.generator(real_gray)
            
            # Multi-scale discriminator outputs
            pred_real_list = self.discriminator(real_gray, real_rgb_noisy)
            pred_fake_list = self.discriminator(real_gray, fake_rgb.detach())
            
            # Calculate GAN loss for each scale
            loss_D_real = 0
            loss_D_fake = 0
            
            for pred_real, pred_fake in zip(pred_real_list, pred_fake_list):
                real_labels = torch.ones_like(pred_real) * (1 - smooth)
                fake_labels = torch.zeros_like(pred_fake) + smooth
                
                loss_D_real += self.criterion_gan(pred_real, real_labels)
                loss_D_fake += self.criterion_gan(pred_fake, fake_labels)
            
            loss_D_real = loss_D_real / len(pred_real_list)
            loss_D_fake = loss_D_fake / len(pred_fake_list)
            
            # Gradient penalty
            gp = self.gradient_penalty(real_gray, real_rgb, fake_rgb)
            
            loss_D = loss_D_real + loss_D_fake + self.config.LAMBDA_GP * gp
        
        self.scaler_D.scale(loss_D).backward()
        
        if self.config.GRAD_CLIP > 0:
            self.scaler_D.unscale_(self.optimizer_D)
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.config.GRAD_CLIP)
        
        self.scaler_D.step(self.optimizer_D)
        self.scaler_D.update()
        
        # ===== Train Generator =====
        self.set_requires_grad(self.discriminator, False)
        self.optimizer_G.zero_grad()
        
        with autocast(enabled=self.config.USE_AMP):
            fake_rgb = self.generator(real_gray)
            
            # Multi-scale GAN loss
            pred_fake_list = self.discriminator(real_gray, fake_rgb)
            
            loss_G_gan = 0
            for pred_fake in pred_fake_list:
                real_labels = torch.ones_like(pred_fake)
                loss_G_gan += self.criterion_gan(pred_fake, real_labels)
            loss_G_gan = loss_G_gan / len(pred_fake_list)
            
            # L1 loss
            loss_G_l1 = self.criterion_l1(fake_rgb, real_rgb)
            
            # Perceptual loss
            loss_G_perceptual = self.perceptual_loss(fake_rgb, real_rgb)
            
            # Total generator loss
            loss_G = loss_G_gan + \
                     self.config.LAMBDA_L1 * loss_G_l1 + \
                     self.config.LAMBDA_PERCEPTUAL * loss_G_perceptual
        
        self.scaler_G.scale(loss_G).backward()
        
        if self.config.GRAD_CLIP > 0:
            self.scaler_G.unscale_(self.optimizer_G)
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.GRAD_CLIP)
        
        self.scaler_G.step(self.optimizer_G)
        self.scaler_G.update()
        
        return {
            'D': loss_D.item(),
            'G': loss_G.item(),
            'G_GAN': loss_G_gan.item(),
            'G_L1': loss_G_l1.item(),
            'G_Perceptual': loss_G_perceptual.item()
        }
    
    @torch.no_grad()
    def validate(self, val_loader):
        self.generator.eval()
        self.discriminator.eval()
        
        val_losses = {'D': [], 'G': [], 'G_GAN': [], 'G_L1': [], 'G_Perceptual': []}
        
        for real_gray, real_rgb in tqdm(val_loader, desc="Validating", leave=False):
            real_gray, real_rgb = real_gray.to(self.device), real_rgb.to(self.device)
            
            fake_rgb = self.generator(real_gray)
            
            # Multi-scale discriminator
            pred_real_list = self.discriminator(real_gray, real_rgb)
            pred_fake_list = self.discriminator(real_gray, fake_rgb)
            
            # D loss
            loss_D_real = sum([self.criterion_gan(pred, torch.ones_like(pred)) 
                              for pred in pred_real_list]) / len(pred_real_list)
            loss_D_fake = sum([self.criterion_gan(pred, torch.zeros_like(pred)) 
                              for pred in pred_fake_list]) / len(pred_fake_list)
            loss_D = loss_D_real + loss_D_fake
            
            # G losses
            loss_G_gan = sum([self.criterion_gan(pred, torch.ones_like(pred)) 
                             for pred in pred_fake_list]) / len(pred_fake_list)
            loss_G_l1 = self.criterion_l1(fake_rgb, real_rgb)
            loss_G_perceptual = self.perceptual_loss(fake_rgb, real_rgb)
            loss_G = loss_G_gan + self.config.LAMBDA_L1 * loss_G_l1 + \
                     self.config.LAMBDA_PERCEPTUAL * loss_G_perceptual
            
            val_losses['D'].append(loss_D.item())
            val_losses['G'].append(loss_G.item())
            val_losses['G_GAN'].append(loss_G_gan.item())
            val_losses['G_L1'].append(loss_G_l1.item())
            val_losses['G_Perceptual'].append(loss_G_perceptual.item())
        
        self.generator.train()
        self.discriminator.train()
        
        return {k: np.mean(v) for k, v in val_losses.items()}
    
    def train(self, train_loader, val_loader=None):
        print(f"\n{'='*60}")
        print("STARTING ENHANCED TRAINING")
        print(f"{'='*60}\n")
        
        for epoch in range(self.current_epoch, self.config.EPOCHS):
            self.current_epoch = epoch + 1
            
            # Training
            self.generator.train()
            self.discriminator.train()
            
            train_losses = {'D': [], 'G': [], 'G_GAN': [], 'G_L1': [], 'G_Perceptual': []}
            
            pbar = tqdm(train_loader, desc=f"Epoch {self.current_epoch}/{self.config.EPOCHS}")
            for batch_idx, (real_gray, real_rgb) in enumerate(pbar):
                losses = self.train_step(real_gray, real_rgb)
                
                for k, v in losses.items():
                    train_losses[k].append(v)
                
                # Update progress bar
                pbar.set_postfix({
                    'D': f"{losses['D']:.4f}",
                    'G': f"{losses['G']:.4f}",
                    'L1': f"{losses['G_L1']:.4f}",
                    'Perceptual': f"{losses['G_Perceptual']:.4f}"
                })
            
            # Calculate epoch averages
            avg_train_losses = {k: np.mean(v) for k, v in train_losses.items()}
            for k, v in avg_train_losses.items():
                self.losses['train'][k].append(v)
            
            # Validation
            if val_loader and (self.current_epoch % self.config.VAL_INTERVAL == 0):
                avg_val_losses = self.validate(val_loader)
                for k, v in avg_val_losses.items():
                    self.losses['val'][k].append(v)
                
                print(f"\nEpoch {self.current_epoch} - Val Loss: {avg_val_losses['G']:.4f} "
                      f"(L1: {avg_val_losses['G_L1']:.4f}, Perceptual: {avg_val_losses['G_Perceptual']:.4f})")
                
                # Save best model
                if avg_val_losses['G'] < self.best_val_loss:
                    self.best_val_loss = avg_val_losses['G']
                    self.best_epoch = self.current_epoch
                    self.save_checkpoint('best')
                    print(f"✓ New best model saved! (Val Loss: {self.best_val_loss:.4f})")
                
                # Early stopping
                if self.current_epoch - self.best_epoch > self.patience:
                    print(f"\nEarly stopping triggered. No improvement for {self.patience} epochs.")
                    break
            
            # Learning rate scheduling
            if self.config.USE_COSINE_ANNEALING:
                self.scheduler_G.step()
                self.scheduler_D.step()
            else:
                if val_loader:
                    self.scheduler_G.step(avg_val_losses['G'])
                    self.scheduler_D.step(avg_val_losses['D'])
            
            # Periodic checkpoint
            if self.current_epoch % self.config.SAVE_INTERVAL == 0:
                self.save_checkpoint()
                
                # Generate samples
                if batch_idx > 0:
                    self.generate_samples(real_gray, real_rgb, self.current_epoch)
            
            # Save losses
            self.save_losses()
        
        print("\n" + "="*60)
        print("TRAINING COMPLETE!")
        print(f"Best epoch: {self.best_epoch}")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        print("="*60)
        
        self.save_checkpoint('final')
        self.plot_losses()
    
    def save_checkpoint(self, name=None):
        if name is None:
            name = f'epoch_{self.current_epoch:04d}'
        
        checkpoint = {
            'epoch': self.current_epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'optimizer_D_state_dict': self.optimizer_D.state_dict(),
            'scheduler_G_state_dict': self.scheduler_G.state_dict(),
            'scheduler_D_state_dict': self.scheduler_D.state_dict(),
            'scaler_G_state_dict': self.scaler_G.state_dict(),
            'scaler_D_state_dict': self.scaler_D.state_dict(),
            'losses': self.losses,
            'best_val_loss': self.best_val_loss,
            'config': self.config.to_dict()
        }
        
        path = os.path.join(self.checkpoint_dir, f'{name}.pth')
        torch.save(checkpoint, path)
        print(f"Checkpoint saved: {path}")
    
    def save_losses(self):
        losses_path = os.path.join(self.exp_dir, 'losses.json')
        with open(losses_path, 'w') as f:
            json.dump(self.losses, f, indent=2)
    
    def generate_samples(self, real_gray, real_rgb, epoch, num_samples=4):
        self.generator.eval()
        
        with torch.no_grad():
            real_gray = real_gray[:num_samples].to(self.device)
            real_rgb = real_rgb[:num_samples].to(self.device)
            fake_rgb = self.generator(real_gray)
            
            real_gray_np = real_gray.cpu().numpy()
            real_rgb_np = real_rgb.cpu().numpy()
            fake_rgb_np = fake_rgb.cpu().numpy()
        
        def denormalize(img):
            return (img.transpose(0, 2, 3, 1) * 0.5 + 0.5).clip(0, 1)
        
        fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
        
        for i in range(num_samples):
            # Input
            axes[i, 0].imshow(real_gray_np[i, 0], cmap='gray')
            axes[i, 0].set_title('Input (Grayscale)')
            axes[i, 0].axis('off')
            
            # Generated
            axes[i, 1].imshow(denormalize(fake_rgb_np)[i])
            axes[i, 1].set_title('Generated (RGB)')
            axes[i, 1].axis('off')
            
            # Ground Truth
            axes[i, 2].imshow(denormalize(real_rgb_np)[i])
            axes[i, 2].set_title('Ground Truth (RGB)')
            axes[i, 2].axis('off')
        
        plt.tight_layout()
        sample_path = os.path.join(self.samples_dir, f'samples_epoch_{epoch:04d}.png')
        plt.savefig(sample_path, dpi=150, bbox_inches='tight')
        plt.close()
        
        self.generator.train()
    
    def plot_losses(self):
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        
        epochs = range(1, len(self.losses['train']['G']) + 1)
        
        # Generator vs Discriminator (Train)
        axes[0, 0].plot(epochs, self.losses['train']['G'], label='Generator', linewidth=2)
        axes[0, 0].plot(epochs, self.losses['train']['D'], label='Discriminator', linewidth=2)
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Generator vs Discriminator (Train)')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Generator Components (Train)
        axes[0, 1].plot(epochs, self.losses['train']['G_GAN'], label='GAN Loss', linewidth=2)
        axes[0, 1].plot(epochs, self.losses['train']['G_L1'], label='L1 Loss', linewidth=2)
        axes[0, 1].plot(epochs, self.losses['train']['G_Perceptual'], label='Perceptual Loss', linewidth=2)
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].set_title('Generator Loss Components (Train)')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # L1 Loss Detail
        axes[0, 2].plot(epochs, self.losses['train']['G_L1'], linewidth=2, color='green')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('L1 Loss')
        axes[0, 2].set_title('L1 Loss (Train)')
        axes[0, 2].grid(True, alpha=0.3)
        
        if self.losses['val']['G']:
            val_epochs = range(1, len(self.losses['val']['G']) + 1)
            
            # Validation losses
            axes[1, 0].plot(val_epochs, self.losses['val']['G'], label='Generator', linewidth=2)
            axes[1, 0].plot(val_epochs, self.losses['val']['D'], label='Discriminator', linewidth=2)
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Loss')
            axes[1, 0].set_title('Generator vs Discriminator (Validation)')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)
            
            axes[1, 1].plot(val_epochs, self.losses['val']['G_GAN'], label='GAN Loss', linewidth=2)
            axes[1, 1].plot(val_epochs, self.losses['val']['G_L1'], label='L1 Loss', linewidth=2)
            axes[1, 1].plot(val_epochs, self.losses['val']['G_Perceptual'], label='Perceptual Loss', linewidth=2)
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Loss')
            axes[1, 1].set_title('Generator Loss Components (Validation)')
            axes[1, 1].legend()
            axes[1, 1].grid(True, alpha=0.3)
            
            axes[1, 2].plot(val_epochs, self.losses['val']['G_Perceptual'], linewidth=2, color='red')
            axes[1, 2].set_xlabel('Epoch')
            axes[1, 2].set_ylabel('Perceptual Loss')
            axes[1, 2].set_title('Perceptual Loss (Validation)')
            axes[1, 2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        loss_plot_path = os.path.join(self.exp_dir, 'loss_curves.png')
        plt.savefig(loss_plot_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"\nLoss curves saved: {loss_plot_path}")

print("✓ Enhanced Trainer defined")

✓ Enhanced Trainer defined


In [21]:
#INITIALIZE MODELS AND DATALOADERS

print("\n" + "="*60)
print("INITIALIZING MODELS")
print("="*60)
num_workers = 0 if os.name == 'nt' else min(config.NUM_WORKERS, 2)

# Initialize models
generator = EnhancedGeneratorUNet(
    in_channels=config.IN_CHANNELS,
    out_channels=config.OUT_CHANNELS,
    use_attention=config.USE_ATTENTION,
    use_se=config.USE_SE_BLOCKS
)

discriminator = MultiScaleDiscriminator(
    in_channels=config.IN_CHANNELS,
    out_channels=config.OUT_CHANNELS,
    num_scales=config.NUM_D_SCALES
)

# Count parameters
g_params = sum(p.numel() for p in generator.parameters())
d_params = sum(p.numel() for p in discriminator.parameters())

print(f"Generator parameters: {g_params:,}")
print(f"Discriminator parameters: {d_params:,}")
print(f"Total parameters: {g_params + d_params:,}")

# Create datasets
print("\n" + "="*60)
print("LOADING DATASETS")
print("="*60)

train_dataset = EnhancedPairedDataset(
    root_dir=config.DATA_DIR,
    mode='train',
    config=config,
    use_percentage=0.2
)

val_dataset = EnhancedPairedDataset(
    root_dir=config.DATA_DIR,
    mode='val',
    config=config,
    use_percentage=1.0
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    pin_memory=config.PIN_MEMORY,
    num_workers=num_workers,  
    persistent_workers=False, 
    drop_last=True  # (train only)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,  
    persistent_workers=False, 
    pin_memory=config.PIN_MEMORY
)

print(f"✓ Training batches: {len(train_loader)}")
print(f"✓ Validation batches: {len(val_loader)}")
print("="*60)


INITIALIZING MODELS
Generator parameters: 55,057,220
Discriminator parameters: 8,294,595
Total parameters: 63,351,815

LOADING DATASETS
Train dataset: 1369/6847 pairs (20.0%)
Val dataset: 1283/1283 pairs (100.0%)
✓ Training batches: 171
✓ Validation batches: 161


In [22]:
# Cell 10: TRAIN THE MODEL
# ===========================================

# Initialize trainer
trainer = EnhancedPix2PixTrainer(generator, discriminator, config)

# Start training
trainer.train(train_loader, val_loader)

print("\n✓ Training completed successfully!")


Experiment directory: ./enhanced_pix2pix_project\ENHANCED_20251225_203059
Using AMP: True
Using Cosine Annealing: True
Batch size: 8


STARTING ENHANCED TRAINING



Epoch 1/50: 100%|██████████| 171/171 [03:59<00:00,  1.40s/it, D=2312.3518, G=16.5275, L1=0.1055, Perceptual=0.5287]
                                                             


Epoch 1 - Val Loss: 19.4316 (L1: 0.1243, Perceptual: 0.6305)
Checkpoint saved: ./enhanced_pix2pix_project\ENHANCED_20251225_203059\checkpoints\best.pth
✓ New best model saved! (Val Loss: 19.4316)


Epoch 2/50: 100%|██████████| 171/171 [06:31<00:00,  2.29s/it, D=2853.8835, G=13.8212, L1=0.0871, Perceptual=0.4419]
                                                             


Epoch 2 - Val Loss: 12.6130 (L1: 0.0750, Perceptual: 0.4418)
Checkpoint saved: ./enhanced_pix2pix_project\ENHANCED_20251225_203059\checkpoints\best.pth
✓ New best model saved! (Val Loss: 12.6130)


Epoch 3/50: 100%|██████████| 171/171 [06:25<00:00,  2.26s/it, D=5468.8423, G=12.4780, L1=0.0829, Perceptual=0.3497]
                                                             


Epoch 3 - Val Loss: 11.9520 (L1: 0.0723, Perceptual: 0.4028)
Checkpoint saved: ./enhanced_pix2pix_project\ENHANCED_20251225_203059\checkpoints\best.pth
✓ New best model saved! (Val Loss: 11.9520)


Epoch 4/50: 100%|██████████| 171/171 [05:44<00:00,  2.02s/it, D=2265.5813, G=13.0931, L1=0.0869, Perceptual=0.3710]
                                                             


Epoch 4 - Val Loss: 11.3484 (L1: 0.0680, Perceptual: 0.3854)
Checkpoint saved: ./enhanced_pix2pix_project\ENHANCED_20251225_203059\checkpoints\best.pth
✓ New best model saved! (Val Loss: 11.3484)


Epoch 5/50: 100%|██████████| 171/171 [03:24<00:00,  1.20s/it, D=nan, G=nan, L1=0.0836, Perceptual=0.3592]         
                                                             


Epoch 5 - Val Loss: nan (L1: 0.0668, Perceptual: 0.3700)
Checkpoint saved: ./enhanced_pix2pix_project\ENHANCED_20251225_203059\checkpoints\epoch_0005.pth


Epoch 6/50: 100%|██████████| 171/171 [02:31<00:00,  1.13it/s, D=nan, G=nan, L1=0.1120, Perceptual=0.3943]
                                                             


Epoch 6 - Val Loss: nan (L1: 0.0678, Perceptual: 0.3690)


Epoch 7/50: 100%|██████████| 171/171 [02:35<00:00,  1.10it/s, D=nan, G=nan, L1=0.0950, Perceptual=0.3991]
                                                             


Epoch 7 - Val Loss: nan (L1: 0.0680, Perceptual: 0.3707)


Epoch 8/50: 100%|██████████| 171/171 [06:06<00:00,  2.14s/it, D=nan, G=nan, L1=0.0894, Perceptual=0.3284]
                                                             


Epoch 8 - Val Loss: nan (L1: 0.0674, Perceptual: 0.3692)


Epoch 9/50: 100%|██████████| 171/171 [07:50<00:00,  2.75s/it, D=nan, G=nan, L1=0.0971, Perceptual=0.3760]
                                                             


Epoch 9 - Val Loss: nan (L1: 0.0674, Perceptual: 0.3741)


Epoch 10/50: 100%|██████████| 171/171 [07:16<00:00,  2.55s/it, D=nan, G=nan, L1=0.0986, Perceptual=0.3808]
                                                             

KeyboardInterrupt: 

In [None]:
# Cell 11: INFERENCE FUNCTION
# ===========================================

def run_enhanced_inference(checkpoint_path, input_image_path, output_dir='./enhanced_inference_results'):

    os.makedirs(output_dir, exist_ok=True)
    
    # Load checkpoint
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config_dict = checkpoint.get('config', {})
    
    # Create config
    class InferenceConfig:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    inf_config = InferenceConfig(config_dict)
    
    # Initialize generator
    generator = EnhancedGeneratorUNet(
        in_channels=inf_config.IN_CHANNELS,
        out_channels=inf_config.OUT_CHANNELS,
        use_attention=getattr(inf_config, 'USE_ATTENTION', True),
        use_se=getattr(inf_config, 'USE_SE_BLOCKS', True)
    ).to(inf_config.device)
    
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()
    
    print(f"✓ Model loaded on {inf_config.device}")
    
    # Prepare input
    transform = transforms.Compose([
        transforms.Resize((inf_config.IMG_SIZE, inf_config.IMG_SIZE)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    input_img = Image.open(input_image_path).convert('L')
    input_tensor = transform(input_img).unsqueeze(0).to(inf_config.device)
    
    # Generate output
    print("Generating colorized image...")
    with torch.no_grad():
        output_tensor = generator(input_tensor)
    
    # Save output
    output_img = output_tensor.squeeze(0).cpu()
    output_img = (output_img * 0.5 + 0.5).clamp(0, 1)
    output_img = transforms.ToPILImage()(output_img)
    
    input_name = os.path.splitext(os.path.basename(input_image_path))[0]
    output_path = os.path.join(output_dir, f'{input_name}_colorized.png')
    output_img.save(output_path)
    
    # Save input for comparison
    input_img.save(os.path.join(output_dir, f'{input_name}_input.png'))
    
    # Create side-by-side comparison
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(input_img, cmap='gray')
    axes[0].set_title('Input (Grayscale)')
    axes[0].axis('off')
    
    axes[1].imshow(output_img)
    axes[1].set_title('Generated (RGB)')
    axes[1].axis('off')
    
    plt.tight_layout()
    comparison_path = os.path.join(output_dir, f'{input_name}_comparison.png')
    plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"\n✓ Results saved:")
    print(f"  - Input: {os.path.join(output_dir, f'{input_name}_input.png')}")
    print(f"  - Output: {output_path}")
    print(f"  - Comparison: {comparison_path}")
    
    return output_img


In [None]:
# BATCH INFERENCE FOR EVENT HALL IMAGES

def batch_inference(checkpoint_path, input_dir, output_dir='./batch_colorized_results', max_images=None):

    os.makedirs(output_dir, exist_ok=True)
    
    # Get all image files
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp')
    image_files = [f for f in os.listdir(input_dir) 
                   if f.lower().endswith(image_extensions)]
    
    if max_images:
        image_files = image_files[:max_images]
    
    print(f"\nProcessing {len(image_files)} images from {input_dir}")
    print("="*60)
    
    # Load model once
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config_dict = checkpoint.get('config', {})
    
    class InferenceConfig:
        def __init__(self, config_dict):
            for key, value in config_dict.items():
                setattr(self, key, value)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    inf_config = InferenceConfig(config_dict)
    
    generator = EnhancedGeneratorUNet(
        in_channels=inf_config.IN_CHANNELS,
        out_channels=inf_config.OUT_CHANNELS,
        use_attention=getattr(inf_config, 'USE_ATTENTION', True),
        use_se=getattr(inf_config, 'USE_SE_BLOCKS', True)
    ).to(inf_config.device)
    
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()
    
    transform = transforms.Compose([
        transforms.Resize((inf_config.IMG_SIZE, inf_config.IMG_SIZE)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    # Process images
    for img_file in tqdm(image_files, desc="Colorizing images"):
        try:
            input_path = os.path.join(input_dir, img_file)
            input_img = Image.open(input_path).convert('L')
            input_tensor = transform(input_img).unsqueeze(0).to(inf_config.device)
            
            with torch.no_grad():
                output_tensor = generator(input_tensor)
            
            output_img = output_tensor.squeeze(0).cpu()
            output_img = (output_img * 0.5 + 0.5).clamp(0, 1)
            output_img = transforms.ToPILImage()(output_img)
            
            # Save output
            output_name = os.path.splitext(img_file)[0] + '_colorized.png'
            output_path = os.path.join(output_dir, output_name)
            output_img.save(output_path)
            
        except Exception as e:
            print(f"Error processing {img_file}: {str(e)}")
            continue
    
    print(f"\n✓ Batch inference complete!")
    print(f"✓ Colorized images saved to: {output_dir}")


In [None]:
# Cell 13: EXAMPLE USAGE
# ===========================================

# AFTER TRAINING IS COMPLETE, use these functions for inference:

# Example 1: Single image inference
# run_enhanced_inference(
#     checkpoint_path='./enhanced_pix2pix_project/ENHANCED_YYYYMMDD_HHMMSS/checkpoints/best.pth',
#     input_image_path='path/to/your/grayscale/image.jpg',
#     output_dir='./colorized_results'
# )

# Example 2: Batch inference on all 5000 event hall images
# batch_inference(
#     checkpoint_path='./enhanced_pix2pix_project/ENHANCED_YYYYMMDD_HHMMSS/checkpoints/best.pth',
#     input_dir='pix2pix_dataset/test/gray',
#     output_dir='./all_colorized_results',
#     max_images=None  # Process all images
# )
