In [1]:
!pip install pycocotools

Collecting pycocotools
  Downloading pycocotools-2.0.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (455 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m455.0/455.0 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
Installing collected packages: pycocotools
Successfully installed pycocotools-2.0.10
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import os 
os.listdir('/kaggle/input/')
base_path = "/kaggle/input/coco2017"
print("Annotations:", os.listdir(os.path.join(base_path, "annotations")))
print("Training images:", len(os.listdir(os.path.join(base_path, "train2017"))))
print("Validation images:", len(os.listdir(os.path.join(base_path, "val2017"))))
print("Test images:", len(os.listdir(os.path.join(base_path, "test2017"))))

# Data visualization

Annotations: ['person_keypoints_train2017.json', 'instances_val2017.json', 'instances_train2017.json', 'person_keypoints_val2017.json', 'captions_train2017.json', 'captions_val2017.json']
Training images: 118287
Validation images: 5000
Test images: 40670


In [3]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset

IMAGE_DIRS = [
    f"{base_path}/train2017",
    f"{base_path}/test2017", 
    f"{base_path}/val2017"
]

# Memory-efficient approach: Just collect image paths, don't load images yet!
image_paths = []
for dir_path in IMAGE_DIRS:
    for fname in os.listdir(dir_path):
        if fname.endswith(".jpg"):
            image_paths.append(os.path.join(dir_path, fname))

print(f"Found {len(image_paths)} images total")

# Limit to a reasonable number for memory constraints
max_images = 10000  # Reduced for memory efficiency
image_paths = image_paths[:max_images]
print(f"Using {len(image_paths)} images for training")


Found 163957 images total
Using 10000 images for training


In [4]:
# Memory-efficient Dataset class - loads images on-the-fly!
class ColorizationDataset(Dataset):
    def __init__(self, image_paths, color_bins):
        self.image_paths = image_paths
        self.color_bins = color_bins
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        
        try:
            # Load and process image on-demand
            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                # Return a dummy image if loading fails
                L = np.zeros((256, 256), dtype=np.float32)
                ab_class = np.zeros((256, 256), dtype=np.int64)
                return torch.tensor(L).unsqueeze(0), torch.tensor(ab_class)
            
            # Resize to 256x256
            img_bgr = cv2.resize(img_bgr, (256, 256), interpolation=cv2.INTER_AREA)
            
            # Convert to LAB and split
            img_lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab)
            L, a, b = cv2.split(img_lab)
            ab = np.stack([a, b], axis=-1)
            
            # Convert AB to class indices
            ab_tensor = torch.tensor(ab, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)  # (1, 2, H, W)
            ab_classes = self.ab_to_class_indices(ab_tensor, self.color_bins)
            
            # Normalize L channel to [0, 1]
            L = L.astype(np.float32) / 255.0
            
            return torch.tensor(L).unsqueeze(0), ab_classes.squeeze(0)
            
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Return dummy data on error
            L = np.zeros((256, 256), dtype=np.float32)
            ab_class = np.zeros((256, 256), dtype=np.int64)
            return torch.tensor(L).unsqueeze(0), torch.tensor(ab_class)
    
    def ab_to_class_indices(self, ab_batch, color_bins):
        """Convert AB values to class indices"""
        if isinstance(ab_batch, np.ndarray):
            ab_batch = torch.from_numpy(ab_batch)
        
        N, _, H, W = ab_batch.shape
        ab_pixels = ab_batch.permute(0, 2, 3, 1).reshape(-1, 2)  # (N*H*W, 2)
        ab_np = ab_pixels.numpy()
        
        # Broadcasting for distance calculation
        diff = ab_np[:, None, :] - color_bins[None, :, :]  # (N*H*W, 313, 2)
        distances = np.sum(diff**2, axis=2)  # (N*H*W, 313)
        class_indices = np.argmin(distances, axis=1)  # (N*H*W,)
        
        return torch.tensor(class_indices.reshape(N, H, W), dtype=torch.long)


# Training process
## Model definition

## Convolution Output Size Formula

Given:
- Input size `X`: `m × n` (Height × Width)
- Kernel size: `k × k`
- Stride: `s`
- Padding: `p`

The output size `Y` after applying a 2D convolution is:

$$
\text{Output Height} = \left\lfloor \frac{m - k + 2p}{s} \right\rfloor + 1
$$

$$
\text{Output Width} = \left\lfloor \frac{n - k + 2p}{s} \right\rfloor + 1
$$

So the output matrix `Y` has size:

$$
Y \in \mathbb{R}^{\left(\left\lfloor \frac{m - k + 2p}{s} \right\rfloor + 1\right) \times \left(\left\lfloor \frac{n - k + 2p}{s} \right\rfloor + 1\right)}
$$

---

### ✅ Example:

If:
- Input: `m = 32`, `n = 32`
- Kernel size: `k = 5`
- Padding: `p = 0`
- Stride: `s = 1`

Then:

$$
\text{Output Height} = \left\lfloor \frac{32 - 5 + 0}{1} \right\rfloor + 1 = 28
$$

$$
\text{Output Width} = \left\lfloor \frac{32 - 5 + 0}{1} \right\rfloor + 1 = 28
$$

Output size: `28 × 28`

---

### 🔎 Note:

To keep the same size as the input:
- Use **"same" padding**:
  
  $$
  p = \frac{k - 1}{2} \quad \text{(when } k \text{ is odd)}
  $$



In [5]:
import torch.nn as nn

class ColorizationNet(nn.Module):
    def __init__(self):
        super(ColorizationNet, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.ReLU(),
            nn.BatchNorm2d(512)
        )

        # Optional middle block (refinement / deeper processing)
        self.refiner = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1), nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 313, 4, stride=2, padding=1)
        )
    def forward(self, x):
        encoder = self.encoder(x)
        for _ in range(3):  # Apply refiner block 3 times
            encoder = self.refiner(encoder)
        decoder = self.decoder(encoder)
        return decoder

## Change ground truth from ab-channel to 714 channels corresponding to 714 distinct colors

In [6]:
# Fast pre-computed quantization (much faster than K-means!)
def create_ab_quantization(n_bins=313):
    """
    Create pre-computed AB color space quantization bins
    This is much faster than K-means clustering
    """
    # Define AB color space range: A and B channels are typically [-110, 110]
    a_range = np.linspace(-110, 110, int(np.sqrt(n_bins)))
    b_range = np.linspace(-110, 110, int(np.sqrt(n_bins)))
    
    # Create grid of all AB combinations
    a_grid, b_grid = np.meshgrid(a_range, b_range)
    color_bins = np.column_stack([a_grid.flatten(), b_grid.flatten()])
    
    # Take only the first n_bins (in case of rounding)
    return color_bins[:n_bins]

n_bins = 313
color_bins = create_ab_quantization(n_bins)  # shape (313, 2) - MUCH faster!
print(f"Created {len(color_bins)} color bins in AB space")

def ab_to_class_indices(ab_batch, color_bins):
    """
    ab_batch: (N, 2, H, W) or (N, H, W, 2)
    color_bins: (313, 2)
    returns: (N, H, W) — class indices
    """
    if isinstance(ab_batch, np.ndarray):
        ab_batch = torch.from_numpy(ab_batch)
    
    # Handle different input formats
    if ab_batch.shape[1] == 2:  # (N, 2, H, W) format
        N, _, H, W = ab_batch.shape
        ab_pixels = ab_batch.permute(0, 2, 3, 1).reshape(-1, 2)  # (N*H*W, 2)
    else:  # (N, H, W, 2) format
        N, H, W, _ = ab_batch.shape
        ab_pixels = ab_batch.reshape(-1, 2)  # (N*H*W, 2)
    
    # Convert to numpy for distance calculation (faster than cdist with torch)
    ab_np = ab_pixels.numpy()
    
    # Use broadcasting for faster distance calculation than scipy.spatial.distance.cdist
    # ab_np: (N*H*W, 2), color_bins: (313, 2)
    diff = ab_np[:, None, :] - color_bins[None, :, :]  # (N*H*W, 313, 2)
    distances = np.sum(diff**2, axis=2)  # (N*H*W, 313) - squared distances
    class_indices = np.argmin(distances, axis=1)  # (N*H*W,)
    
    return torch.tensor(class_indices.reshape(N, H, W), dtype=torch.long)


Created 289 color bins in AB space


## Split data 80 - 20 (memory-efficient way)

In [7]:
# Split image paths instead of loaded images
from sklearn.model_selection import train_test_split

train_paths, test_paths = train_test_split(image_paths, test_size=0.2, random_state=42)

print(f"Training images: {len(train_paths)}")
print(f"Testing images: {len(test_paths)}")

# Create datasets (these don't load images yet!)
train_dataset = ColorizationDataset(train_paths, color_bins)
test_dataset = ColorizationDataset(test_paths, color_bins)

Training images: 8000
Testing images: 2000


## Fiting model

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Create DataLoaders (memory-efficient!)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)  # Reduced batch size
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

Using device: cpu


In [9]:
# Initialize model
model = ColorizationNet()
model = model.to(device)

# Loss function and optimizer  
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [10]:
# 🔍 Training Status Checker - Run this cell to check progress
# This replaces the need for resume_training.py in Kaggle

import glob
import os
from safetensors.torch import load_file

def check_kaggle_training_status():
    """Check the current status of training and provide guidance for Kaggle."""
    
    print("🔍 Kaggle Training Session Status Check")
    print("=" * 50)
    
    # Check for checkpoints
    checkpoint_dir = "checkpoints"
    if os.path.exists(checkpoint_dir):
        checkpoint_files = glob.glob(f"{checkpoint_dir}/checkpoint_epoch_*.safetensors")
        
        if checkpoint_files:
            # Get latest checkpoint
            latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
            print(f"✅ Found {len(checkpoint_files)} checkpoint(s)")
            print(f"📍 Latest checkpoint: {os.path.basename(latest_checkpoint)}")
            
            # Load checkpoint info
            try:
                checkpoint_data = load_file(latest_checkpoint)
                epoch = int(checkpoint_data.get('epoch', 0))
                loss = float(checkpoint_data.get('loss', 0))
                timestamp = checkpoint_data.get('timestamp', 'unknown')
                
                print(f"🎯 Last completed epoch: {epoch}")
                print(f"📉 Last loss: {loss:.4f}")
                print(f"⏰ Checkpoint time: {timestamp}")
                
                # Calculate progress
                total_epochs = 50  # From updated configuration
                progress = (epoch + 1) / total_epochs * 100
                remaining_epochs = total_epochs - (epoch + 1)
                
                print(f"📊 Training progress: {progress:.1f}% ({epoch + 1}/{total_epochs} epochs)")
                print(f"🎯 Remaining epochs: {remaining_epochs}")
                
                if remaining_epochs > 0:
                    print(f"\n💡 Next steps:")
                    print(f"   1. Continue running the training cells below")
                    print(f"   2. Training will automatically resume from epoch {epoch + 1}")
                    print(f"   3. Estimated time: ~{remaining_epochs * 0.2:.1f} hours remaining")
                else:
                    print(f"\n🎉 Training appears to be complete!")
                    
                return epoch, loss, remaining_epochs
                    
            except Exception as e:
                print(f"⚠️  Could not read checkpoint details: {e}")
                return None, None, None
                
        else:
            print("⚠️  No checkpoints found in checkpoints/ directory")
            print("💡 Training will start from epoch 0")
            return 0, None, 50
    else:
        print("📁 No checkpoints directory found")
        print("💡 Training will start from the beginning")
        return 0, None, 50

# Run the status check
print("🚀 Checking training status before starting...")
last_epoch, last_loss, remaining = check_kaggle_training_status()

if last_epoch is not None and last_epoch > 0:
    print(f"\n✨ Will resume training from epoch {last_epoch + 1}")
else:
    print("\nWill start fresh training from epoch 0")

🚀 Checking training status before starting...
🔍 Kaggle Training Session Status Check
📁 No checkpoints directory found
💡 Training will start from the beginning

Will start fresh training from epoch 0


In [None]:
# Enhanced Training loop with checkpointing for Kaggle's 12-hour limit
import time
import glob
import os
from datetime import datetime
from safetensors.torch import save_file, load_file

# Kaggle-optimized configuration
num_epochs = 50  # Reduced from 1000 to fit within 12-hour limit
checkpoint_interval = 1  # Save checkpoint every 10 epochs
max_training_hours = 11  # Stop before 12-hour limit

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_dir="checkpoints"):
    """Save training checkpoint"""
    os.makedirs(checkpoint_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_path = f"{checkpoint_dir}/checkpoint_epoch_{epoch:03d}_{timestamp}.safetensors"
    
    checkpoint_data = {
        **model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
        'timestamp': timestamp
    }
    
    save_file(checkpoint_data, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")
    return checkpoint_path

def load_latest_checkpoint(model, optimizer, checkpoint_dir="checkpoints"):
    """Load the latest checkpoint if available"""
    if not os.path.exists(checkpoint_dir):
        return 0, None
    
    checkpoint_files = glob.glob(f"{checkpoint_dir}/checkpoint_epoch_*.safetensors")
    if not checkpoint_files:
        return 0, None
    
    # Get the latest checkpoint
    latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
    print(f"🔄 Loading checkpoint: {latest_checkpoint}")
    
    checkpoint_data = load_file(latest_checkpoint)
    
    # Separate model weights from other data
    model_state = {k: v for k, v in checkpoint_data.items() if not k.startswith(('optimizer_', 'epoch', 'loss', 'timestamp'))}
    model.load_state_dict(model_state)
    optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
    
    start_epoch = int(checkpoint_data['epoch']) + 1
    last_loss = float(checkpoint_data['loss'])
    print(f"✅ Resumed from epoch {start_epoch-1}, loss: {last_loss:.4f}")
    return start_epoch, last_loss

# Try to load from checkpoint
start_epoch, last_loss = load_latest_checkpoint(model, optimizer)
print(f"📊 Starting training from epoch {start_epoch}")

# Training loop with time monitoring
training_start_time = time.time()
max_training_seconds = max_training_hours * 3600

for epoch in range(start_epoch, num_epochs):
    # Check time limit
    elapsed_time = time.time() - training_start_time
    if elapsed_time > max_training_seconds:
        print(f"⏰ Reached time limit ({max_training_hours} hours). Stopping training...")
        print(f"💾 Saving final checkpoint...")
        save_checkpoint(model, optimizer, epoch-1, avg_loss)
        break
    
    model.train()
    running_loss = 0.0
    epoch_start_time = time.time()
    
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Print progress every 50 batches
        if (i + 1) % 50 == 0:
            elapsed = time.time() - training_start_time
            remaining = max_training_seconds - elapsed
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}, Time remaining: {remaining/3600:.1f}h')
    
    avg_loss = running_loss / len(train_loader)
    epoch_time = time.time() - epoch_start_time
    total_elapsed = time.time() - training_start_time
    
    print(f"Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.1f}s, Average Loss: {avg_loss:.4f}, Total time: {total_elapsed/3600:.2f}h")
    
    # Save checkpoint at intervals
    if (epoch + 1) % checkpoint_interval == 0:
        save_checkpoint(model, optimizer, epoch, avg_loss)
        print(f"📈 Progress: {(epoch+1)/num_epochs*100:.1f}% complete")

print("\n🎉 Training completed!")
print(f"⏱️  Total training time: {(time.time() - training_start_time)/3600:.2f} hours")


📊 Starting training from epoch 0
Epoch [1/50], Batch [50/1000], Loss: 0.5045, Time remaining: 11.0h
Epoch [1/50], Batch [100/1000], Loss: 1.1799, Time remaining: 10.9h
Epoch [1/50], Batch [150/1000], Loss: 0.1648, Time remaining: 10.9h
Epoch [1/50], Batch [200/1000], Loss: 0.1915, Time remaining: 10.8h
Epoch [1/50], Batch [250/1000], Loss: 0.4012, Time remaining: 10.8h
Epoch [1/50], Batch [300/1000], Loss: 0.1492, Time remaining: 10.7h
Epoch [1/50], Batch [350/1000], Loss: 0.3735, Time remaining: 10.7h
Epoch [1/50], Batch [400/1000], Loss: 0.1863, Time remaining: 10.6h
Epoch [1/50], Batch [450/1000], Loss: 0.1544, Time remaining: 10.6h
Epoch [1/50], Batch [500/1000], Loss: 0.1915, Time remaining: 10.5h
Epoch [1/50], Batch [550/1000], Loss: 0.4390, Time remaining: 10.5h
Epoch [1/50], Batch [600/1000], Loss: 0.4974, Time remaining: 10.5h
Epoch [1/50], Batch [650/1000], Loss: 0.1582, Time remaining: 10.4h


# 🚀 Kaggle Session Management Guide

## 📋 How to Handle Kaggle's 12-Hour Session Limit

### ✅ What This Enhanced Training Does:
- **Automatic Checkpointing**: Saves progress every 10 epochs
- **Time Monitoring**: Tracks elapsed time and stops before 12-hour limit
- **Resume Capability**: Automatically resumes from latest checkpoint when restarted
- **Optimized Epochs**: Reduced to 50 epochs (from 1000) to fit time constraint

### 🔄 If Your Session Gets Interrupted:

1. **Re-run the notebook from the beginning** (to reload libraries and data)
2. **Run the "Training Status Checker" cell** to see your progress
3. **The training will automatically resume** from the latest checkpoint
4. **No need for external scripts** - everything works within the notebook!

### 📊 Monitoring Progress:
- **Real-time time tracking**: Shows remaining time in each batch
- **Checkpoint intervals**: Progress saved every 10 epochs
- **Memory-efficient**: Uses on-demand image loading

### 🛠️ Manual Checkpoint Management:

```python
# The "Training Status Checker" cell above handles this automatically!
# But you can also manually check checkpoints:

import glob
checkpoints = glob.glob('checkpoints/*.safetensors')
for cp in sorted(checkpoints):
    print(cp)

# The training loop automatically finds and loads the latest checkpoint
```

### ⚡ Performance Tips for Kaggle:
- **Batch size**: Optimized to 8 for memory efficiency
- **Epochs**: Reduced to 50 to ensure completion
- **GPU utilization**: Monitor GPU usage in Kaggle sidebar
- **Storage**: Checkpoints and models saved to persistent storage

# 🛠️ Additional Optimizations & Troubleshooting

## ⚡ Further Optimizations for Kaggle

### Memory Management:
```python
# If you encounter memory issues, try these:

# 1. Reduce batch size further
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)

# 2. Clear GPU cache periodically
import torch
torch.cuda.empty_cache()

# 3. Reduce image count for testing
image_paths = image_paths[:10000]  # Use subset for faster testing
```

### Performance Monitoring:
```python
# Monitor GPU usage
import GPUtil
gpus = GPUtil.getGPUs()
for gpu in gpus:
    print(f"GPU {gpu.id}: {gpu.memoryUtil*100:.1f}% memory, {gpu.load*100:.1f}% load")
```

### Alternative Training Strategies:

#### 1. **Progressive Training** (if time is very limited):
```python
# Train on a smaller subset first, then gradually increase
small_paths = image_paths[:5000]  # Start with 5k images
# After initial training, increase to full dataset
```

#### 2. **Transfer Learning** (faster convergence):
```python
# Use a pre-trained encoder (like ResNet) instead of training from scratch
# This can significantly reduce training time
```

#### 3. **Learning Rate Scheduling**:
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
# Add to training loop: scheduler.step(avg_loss)
```

## 🐛 Common Issues & Solutions

| Issue | Solution |
|-------|----------|
| Out of memory | Reduce batch_size to 4 or 2 |
| Training too slow | Reduce num_workers to 1, use smaller dataset subset |
| Checkpoint loading fails | Check file permissions, ensure safetensors is installed |
| Session timeout | Run `python resume_training.py` to check status |
| Model not saving | Ensure `saved_models/` directory has write permissions |

## 📊 Expected Results

After **50 epochs** with the COCO dataset, you should expect:
- **Loss**: Should decrease from ~5.0 to ~2.0-3.0
- **Quality**: Basic colorization with some realistic colors
- **Time**: ~8-10 hours on Kaggle GPU
- **File sizes**: ~100MB for model, ~10MB per checkpoint

## 🎆 Next Steps After Training

1. **Test the model** on new images using the inference cells
2. **Fine-tune** with different learning rates or architectures
3. **Experiment** with different color quantization strategies
4. **Deploy** the model for real-world applications

---

*This enhanced notebook is optimized for Kaggle's constraints while maintaining training quality!*

In [None]:
# Import libraries for model saving
from safetensors.torch import save_file, load_file
import os
import json
from datetime import datetime

In [None]:
# Save the trained model
print("\n" + "="*50)
print("💾 Saving trained model...")
print("="*50)

# Create a directory for saved models if it doesn't exist
os.makedirs("saved_models", exist_ok=True)

# Create a timestamp for the model filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"saved_models/colorization_model_{timestamp}.safetensors"

# Save the model state dict using safetensors
save_file(model.state_dict(), model_filename)
print(f"✅ Model saved as: {model_filename}")

# Save training metadata
metadata = {
    "model_name": "Colorization CNN",
    "epochs": num_epochs,
    "batch_size": 8,
    "learning_rate": 0.001,
    "optimizer": "Adam",
    "loss_function": "CrossEntropyLoss",
    "num_color_bins": 313,
    "input_size": "256x256",
    "architecture": "encoder_decoder_with_refiner",
    "dataset": "COCO 2017",
    "num_training_images": len(train_paths),
    "num_testing_images": len(test_paths),
    "train_test_split": "80/20",
    "color_space": "LAB",
    "training_date": datetime.now().isoformat(),
    "final_avg_loss": avg_loss
}

# Save metadata as JSON
metadata_filename = f"saved_models/training_metadata_{timestamp}.json"
with open(metadata_filename, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"✅ Training metadata saved as: {metadata_filename}")
print("\n🎉 Model and metadata saved successfully!")
print(f"📁 Check the 'saved_models' directory for your files.")

In [None]:
# Functions to load saved models
def load_colorization_model(model_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Load a saved colorization model from safetensors file
    
    Args:
        model_path: Path to the .safetensors file
        device: Device to load the model on ('cuda' or 'cpu')
    
    Returns:
        Loaded model ready for inference
    """
    print(f"Loading model from: {model_path}")
    
    # Initialize the model architecture (same as training)
    model = ColorizationNet()
    
    # Load the saved state dict
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    
    # Move to device and set to evaluation mode
    model = model.to(device)
    model.eval()
    
    print(f"✅ Model loaded successfully!")
    print(f"🎯 Model is on device: {next(model.parameters()).device}")
    print(f"🔧 Model is in evaluation mode: {not model.training}")
    
    return model

def demo_load_model():
    """
    Demo function showing how to load the most recent saved model
    """
    # List available saved models
    if os.path.exists("saved_models"):
        saved_models = [f for f in os.listdir("saved_models") if f.endswith(".safetensors")]
        if saved_models:
            print("Available saved models:")
            for i, model_file in enumerate(saved_models):
                print(f"  {i+1}. {model_file}")
            
            # Load the most recent model (last in list when sorted)
            latest_model = sorted(saved_models)[-1]
            model_path = f"saved_models/{latest_model}"
            
            print(f"\nLoading latest model: {latest_model}")
            loaded_model = load_colorization_model(model_path)
            return loaded_model
        else:
            print("No saved models found in 'saved_models' directory")
            return None
    else:
        print("'saved_models' directory doesn't exist yet. Train the model first!")
        return None

def test_saved_model():
    """
    Test function to verify the saved model can be loaded and used
    """
    try:
        # Try to load the most recent model
        loaded_model = demo_load_model()
        
        if loaded_model is not None:
            print("\n🧪 Testing loaded model...")
            
            # Test with a dummy input
            dummy_input = torch.randn(1, 1, 256, 256).to(device)
            
            with torch.no_grad():
                output = loaded_model(dummy_input)
                print(f"✅ Model inference successful!")
                print(f"   Input shape: {dummy_input.shape}")
                print(f"   Output shape: {output.shape}")
                print(f"   Expected output shape: (1, 313, 256, 256)")
                
                if output.shape == (1, 313, 256, 256):
                    print("🎉 Model loaded and working perfectly!")
                else:
                    print("⚠️  Output shape doesn't match expected dimensions")
        
    except Exception as e:
        print(f"❌ Error testing saved model: {e}")

print("🔧 Model saving and loading functions defined!")
print("💡 After training completes, your model will be automatically saved!")
print("💡 To load a saved model later, use: loaded_model = demo_load_model()")
print("💡 To test a saved model, use: test_saved_model()")

# Test phase - using memory-efficient approach

In [None]:
# Test with a single image from the test dataset
model.eval()

# Get one batch from test loader
test_iter = iter(test_loader)
test_batch = next(test_iter)
gray_input, ground_truth = test_batch

# Use the first image in the batch
gray_input = gray_input[0:1].to(device)  # shape: [1, 1, H, W]

print(f"Input shape: {gray_input.shape}")

with torch.no_grad():
    output = model(gray_input)  # shape: [1, 313, H, W]
    print(f"Output shape: {output.shape}")

In [None]:
# Get predicted color classes
pred_classes = output.argmax(dim=1)  # shape: [1, H, W]
print(f"Predicted classes shape: {pred_classes.shape}")

In [None]:
pred_classes_np = pred_classes.squeeze().cpu().numpy()  # shape: (H, W)
H, W = pred_classes_np.shape
ab_decoded = color_bins[pred_classes_np.flatten()]      # shape: (H*W, 2)
ab_decoded = ab_decoded.reshape(H, W, 2).astype(np.float32)

In [None]:
gray_input_np = gray_input.squeeze().cpu().numpy()  # (H, W)
lab_img = np.zeros((H, W, 3), dtype=np.float32)
lab_img[..., 0] = gray_input_np * 255.0  # Set L channel (denormalize from [0,1] to [0,255])
lab_img[..., 1:] = ab_decoded   # Set A and B channels

In [None]:
import cv2
rgb_img = cv2.cvtColor(lab_img, cv2.COLOR_LAB2RGB)
rgb_img = np.clip(rgb_img, 0, 1)

import matplotlib.pyplot as plt
plt.imshow(rgb_img)
plt.title("Colorized Image")
plt.axis("off")
plt.show()

In [None]:
# Optional: Test the saved model loading functionality
print("\n" + "="*50)
print("🧪 Testing saved model functionality...")
print("="*50)

# Test loading the saved model
test_saved_model()