In [1]:
# %pip install pytorch-msssim
# %pip install torchmetrics
# %pip install lpips
# %pip install h5py
# %pip install pathos
# %pip install multiprocess
# %pip install lmdb

In [2]:
!nvidia-smi

Thu Jul  3 04:28:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 576.52                 Driver Version: 576.52         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650 Ti   WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   55C    P5              5W /   50W |      89MiB /   4096MiB |     27%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import torch
import os
import gc
import cv2
import lpips
import lmdb
import time
import shutil
import random
import torchmetrics
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.multiprocessing as mp
import torchvision.utils as vutils
from tqdm import tqdm
from functools import partial
from torchvision import transforms
from torch.utils.data import Dataset
from pytorch_msssim import MS_SSIM
from multiprocess.pool import ThreadPool
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
mp.set_start_method('spawn', force=True)

In [5]:
# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available(
) else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

COLORBLIND_TYPES = ["protanopia", "deuteranopia", "tritanopia"]
NUM_EPOCHS = 50
BATCH_SIZE = 16
IMG_SIZE = 256

Using device: cuda


In [6]:
def process_batch(file_batch, data_dir, cb_type, img_size):
    original_dir = os.path.join(data_dir, "original")
    type_dir = os.path.join(data_dir, cb_type)

    orig_batch = []
    cb_batch = []

    for fname in file_batch:
        try:
            orig_img = cv2.imread(os.path.join(original_dir, fname))
            if orig_img is None:
                raise FileNotFoundError(f"Original image {fname} not found")
            orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
            orig_img = cv2.resize(orig_img, (img_size, img_size))

            cb_img = cv2.imread(os.path.join(type_dir, fname))
            if cb_img is None:
                raise FileNotFoundError(f"Colorblind image {fname} not found")
            cb_img = cv2.cvtColor(cb_img, cv2.COLOR_BGR2RGB)
            cb_img = cv2.resize(cb_img, (img_size, img_size))

            orig_batch.append(orig_img)
            cb_batch.append(cb_img)
        except Exception as e:
            print(f"\nError processing {fname}: {str(e)}")
            raise

    return np.array(orig_batch), np.array(cb_batch)


def preprocess_type(data_dir, cb_type, output_path, num_workers=6, batch_size=1000):
    """Preprocess and write LMDB for one colorblind type"""
    type_dir = os.path.join(data_dir, cb_type)
    files = [f for f in os.listdir(type_dir)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    if not files:
        raise ValueError(f"No images found in {type_dir}")

    batches = [files[i:i + batch_size]
               for i in range(0, len(files), batch_size)]

    env = None
    try:
        # Estimate needed map size
        map_size = 15 * 1024 ** 3  # 15 GB max
        env = lmdb.open(output_path, map_size=map_size)

        with env.begin(write=True) as txn:
            with ThreadPool(num_workers) as p:
                with tqdm(total=len(batches), desc=f"Processing {cb_type}") as pbar:
                    process_func = partial(process_batch,
                                           data_dir=data_dir,
                                           cb_type=cb_type,
                                           img_size=IMG_SIZE)

                    total_idx = 0
                    for orig_data, cb_data in p.imap(process_func, batches):
                        for img_idx in range(len(orig_data)):
                            txn.put(f'original_{total_idx}'.encode(),
                                    orig_data[img_idx].tobytes())
                            txn.put(f'{cb_type}_{total_idx}'.encode(),
                                    cb_data[img_idx].tobytes())
                            total_idx += 1
                        pbar.update(1)

    except Exception as e:
        print(f"Error during preprocessing {cb_type}: {str(e)}")
        if os.path.exists(output_path):
            print(f"Deleting incomplete LMDB {output_path}")
            # Add retry loop for Windows
            for _ in range(3):
                try:
                    shutil.rmtree(output_path, ignore_errors=True)
                    break
                except PermissionError:
                    time.sleep(0.5)
        raise
    finally:
        if env is not None:
            env.close()
            # Add this for Windows compatibility
            if os.name == 'nt':
                try:
                    os.remove(os.path.join(output_path, 'lock.mdb'))
                except Exception as e:
                    pass

In [7]:
class ColorblindDataset(Dataset):
    def __init__(self, lmdb_path, cb_type, transform=None):
        self.lmdb_path = os.path.abspath(lmdb_path)
        self.cb_type = cb_type
        self.transform = transform or transforms.Compose([
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        self._env = None  # Will be initialized per worker

        # Use temporary environment to count entries
        with lmdb.open(self.lmdb_path, readonly=True, lock=False) as env:
            with env.begin() as txn:
                self._len = txn.stat()['entries'] // 2
                if self._len == 0:
                    raise ValueError(f"LMDB {lmdb_path} contains 0 samples.")

    def _init_env(self):
        if self._env is None:
            self._env = lmdb.open(
                self.lmdb_path, readonly=True, lock=False, max_readers=128)

    def __getitem__(self, idx):
        try:
            self._init_env()
            with self._env.begin(buffers=True) as txn:
                orig_bytes = txn.get(f'original_{idx}'.encode())
                cb_bytes = txn.get(f'{self.cb_type}_{idx}'.encode())

                if orig_bytes is None or cb_bytes is None:
                    raise RuntimeError(f"Missing data for index {idx}")

                orig = np.frombuffer(orig_bytes, dtype=np.uint8).reshape(
                    IMG_SIZE, IMG_SIZE, 3)
                cb = np.frombuffer(cb_bytes, dtype=np.uint8).reshape(
                    IMG_SIZE, IMG_SIZE, 3)

            orig = torch.from_numpy(orig.copy()).permute(
                2, 0, 1).float().div(255)
            cb = torch.from_numpy(cb.copy()).permute(2, 0, 1).float().div(255)

            if self.transform:
                orig = self.transform(orig)
                cb = self.transform(cb)

            return orig, cb
        except Exception as e:
            print(f"Error loading index {idx}: {e}")
            raise

    def __len__(self):
        return self._len

    def close(self):
        if self._env is not None:
            self._env.close()
            self._env = None

In [8]:
# Autoencoder Model
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True))
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True))
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True))
        self.enc4 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True))
        self.enc5 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(True))

        # Decoder with Skip Connections
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True))
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True))
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(128, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)  # 128
        e2 = self.enc2(e1)  # 64
        e3 = self.enc3(e2)  # 32
        e4 = self.enc4(e3)  # 16
        e5 = self.enc5(e4)  # 8

        # Decoder
        d5 = self.dec5(e5)
        d4 = self.dec4(torch.cat([d5, e4], 1))
        d3 = self.dec3(torch.cat([d4, e3], 1))
        d2 = self.dec2(torch.cat([d3, e2], 1))
        d1 = self.dec1(torch.cat([d2, e1], 1))
        return d1

In [9]:
# Training Utilities
class EarlyStopping:
    def __init__(self, patience=5, delta=0.001):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if (self.best_loss - val_loss) > self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


def process_image(tensor):
    """Universal image processing function for proper denormalization"""
    img = tensor.detach().cpu().numpy()
    img = img.transpose(1, 2, 0)  # CHW to HWC
    img = np.clip(img, -1, 1)     # Ensure values are in [-1, 1]
    img = (img * 0.5 + 0.5)       # Scale to [0, 1]
    img = (img * 255).astype(np.uint8)
    return img

# Visualization with Saving


def visualize_and_save(orig, target, output, cb_type, epoch, batch_idx, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, img, title in zip(axes, [orig, target, output], ['Original', 'TargetDaltonized', 'Generated']):
        processed_img = process_image(img)
        ax.imshow(processed_img)
        ax.set_title(f"{title} ({cb_type})")
        ax.axis('off')

    plt.savefig(os.path.join(save_dir, f"{cb_type}_e{epoch+1}_b{batch_idx}.png"),
                bbox_inches='tight',
                dpi=100)
    plt.close()

In [10]:
def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [11]:
# Modified DataLoader setup
def get_loader(dataset, batch_size, num_workers, shuffle):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        persistent_workers=False
    )

In [12]:
def train_colorblind_type(cb_type):
    train_lmdb_path = f"temp_train_{cb_type}.lmdb"
    val_lmdb_path = f"temp_val_{cb_type}.lmdb"
    metric_path = f"metrics_{cb_type}.npz"
    checkpoint_path = f"checkpoint_{cb_type}.pth"

    try:
        # Check for existing checkpoint
        resume = os.path.exists(checkpoint_path)

        if not resume:
            # Fresh training setup
            print(f"\n{'='*40}\nPreprocessing {cb_type}\n{'='*40}")
            if not os.path.exists(train_lmdb_path):
                preprocess_type("data/train", cb_type, train_lmdb_path)
            if not os.path.exists(val_lmdb_path):
                preprocess_type("data/val", cb_type, val_lmdb_path)

            # Validate LMDBs only for fresh start
            for path in [train_lmdb_path, val_lmdb_path]:
                with lmdb.open(path, readonly=True) as env:
                    with env.begin() as txn:
                        entries = txn.stat()["entries"]
                        if entries == 0:
                            raise RuntimeError(f"LMDB {path} is empty!")
                        print(f"{path} has {entries//2} samples")

        # Initialize model and components
        model = Autoencoder().to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        scaler = GradScaler()
        early_stopper = EarlyStopping(patience=5)
        ms_ssim = MS_SSIM(data_range=2.0).to(DEVICE)
        psnr = PeakSignalNoiseRatio().to(DEVICE)
        lpips_model = lpips.LPIPS(net='alex').to(DEVICE)

        start_epoch = 0
        metrics = {
            'train_loss': [], 'val_loss': [],
            'train_psnr': [], 'val_psnr': [],
            'train_lpips': [], 'val_lpips': []
        }

        # Resume logic
        if resume:
            print(f"\n{'='*40}\nResuming {cb_type} training\n{'='*40}")
            checkpoint = torch.load(checkpoint_path)

            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
            early_stopper.counter = checkpoint['early_stop_counter']
            start_epoch = checkpoint['epoch'] + 1
            metrics = checkpoint['metrics']

            print(f"Resuming from epoch {start_epoch} with:")
            print(f"- Best val loss: {min(metrics['val_loss']):.4f}")
            print(f"- Current early stop counter: {early_stopper.counter}")

        # Data loading (always needed)
        train_dataset = ColorblindDataset(train_lmdb_path, cb_type)
        val_dataset = ColorblindDataset(val_lmdb_path, cb_type)
        train_loader = get_loader(
            train_dataset, BATCH_SIZE, num_workers=0, shuffle=True)
        val_loader = get_loader(val_dataset, BATCH_SIZE,
                                num_workers=0, shuffle=False)

        # Training loop
        for epoch in range(start_epoch, NUM_EPOCHS):
            model.train()
            epoch_loss = epoch_psnr = epoch_lpips = 0.0

            # Training phase
            for batch_idx, (orig, target) in tqdm(enumerate(train_loader),
                                                  desc=f"{cb_type} Epoch {epoch+1}",
                                                  total=len(train_loader)):
                orig, target = orig.to(DEVICE), target.to(DEVICE)

                with autocast():
                    output = model(orig)
                    loss = 1 - ms_ssim(output, target)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                with torch.no_grad():
                    output_scaled = (output + 1) / 2
                    target_scaled = (target + 1) / 2
                    psnr_val = psnr(output_scaled, target_scaled)
                    lpips_val = lpips_model(output, target).mean()

                epoch_loss += loss.item()
                epoch_psnr += psnr_val.item()
                epoch_lpips += lpips_val.item()

                if batch_idx % 500 == 0:
                    visualize_and_save(orig[0], target[0], output[0], cb_type,
                                       epoch, batch_idx, "training_samples")

            # Update metrics
            metrics['train_loss'].append(epoch_loss / len(train_loader))
            metrics['train_psnr'].append(epoch_psnr / len(train_loader))
            metrics['train_lpips'].append(epoch_lpips / len(train_loader))

            # Validation phase
            model.eval()
            val_loss = val_psnr = val_lpips = 0.0
            with torch.no_grad():
                for orig, target in val_loader:
                    orig, target = orig.to(DEVICE), target.to(DEVICE)
                    output = model(orig)

                    loss = 1 - ms_ssim(output, target)
                    output_scaled = (output + 1) / 2
                    target_scaled = (target + 1) / 2

                    val_loss += loss.item()
                    val_psnr += psnr(output_scaled, target_scaled).item()
                    val_lpips += lpips_model(output, target).mean().item()

            metrics['val_loss'].append(val_loss / len(val_loader))
            metrics['val_psnr'].append(val_psnr / len(val_loader))
            metrics['val_lpips'].append(val_lpips / len(val_loader))

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'early_stop_counter': early_stopper.counter,
                'metrics': metrics
            }, checkpoint_path)

            # Save best model
            if metrics['val_loss'][-1] == min(metrics['val_loss']):
                torch.save(model.state_dict(), f"best_model_{cb_type}.pth")
                visualize_and_save(orig[0], target[0], output[0], cb_type,
                                   epoch, 0, "best_results")

            # Early stopping check
            early_stopper(metrics['val_loss'][-1])
            if early_stopper.early_stop:
                print(f"Early stopping {cb_type} at epoch {epoch+1}")
                break

            # Print progress
            print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] "
                  f"Train Loss: {metrics['train_loss'][-1]:.4f}, "
                  f"Val Loss: {metrics['val_loss'][-1]:.4f}, "
                  f"Train PSNR: {metrics['train_psnr'][-1]:.2f}, "
                  f"Val PSNR: {metrics['val_psnr'][-1]:.2f}, "
                  f"Train LPIPS: {metrics['train_lpips'][-1]:.4f}, "
                  f"Val LPIPS: {metrics['val_lpips'][-1]:.4f}")

        # Final cleanup after successful completion
        np.savez(metric_path, **metrics)
        torch.save(model.state_dict(), f"final_model_{cb_type}.pth")

        # Remove temporary files
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)

        print(f"\n{'='*40}\nTraining completed for {cb_type}\n{'='*40}")

    finally:
        # Cleanup resources without deleting LMDBs
        if 'train_dataset' in locals():
            train_dataset.close()
        if 'val_dataset' in locals():
            val_dataset.close()

        if 'train_loader' in locals():
            del train_loader
        if 'val_loader' in locals():
            del val_loader

        torch.cuda.empty_cache()
        gc.collect()

In [13]:
def force_delete(path):
    """Robust deletion with Windows lock handling"""
    for _ in range(5):  # 5 retries
        try:
            if os.path.exists(path):
                shutil.rmtree(path, ignore_errors=True)
                # Handle Windows lock file explicitly
                if os.name == 'nt':
                    lock_file = os.path.join(path, 'lock.mdb')
                    if os.path.exists(lock_file):
                        os.remove(lock_file)
                print(f"Successfully deleted {path}")
                return
        except Exception as e:
            print(f"Attempt {_+1}/5 failed for {path}: {str(e)}")
            time.sleep(1)

In [14]:
if __name__ == "__main__":
    os.makedirs("training_samples", exist_ok=True)
    os.makedirs("best_results", exist_ok=True)

    for cb_type in COLORBLIND_TYPES:
        train_colorblind_type(cb_type)
        current_files = [
            f"temp_train_{cb_type}.lmdb",
            f"temp_val_{cb_type}.lmdb"
        ]
        for path in current_files:
            print(f"\nCleaning up {path} before next type...")
            force_delete(path)

    print("\nTraining completed for all types!")


Preprocessing protanopia


Processing protanopia: 100%|██████████| 30/30 [03:14<00:00,  6.49s/it]
Processing protanopia: 100%|██████████| 3/3 [00:23<00:00,  7.89s/it]


temp_train_protanopia.lmdb has 29670 samples
temp_val_protanopia.lmdb has 2967 samples
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: c:\Users\Taneem\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth


protanopia Epoch 1: 100%|██████████| 1855/1855 [19:21<00:00,  1.60it/s]


[Epoch 1/50] Train Loss: 0.1130, Val Loss: 0.0491, Train PSNR: 19.19, Val PSNR: 22.14, Train LPIPS: 0.2612, Val LPIPS: 0.1582


protanopia Epoch 2: 100%|██████████| 1855/1855 [18:31<00:00,  1.67it/s]


[Epoch 2/50] Train Loss: 0.0610, Val Loss: 0.0295, Train PSNR: 22.44, Val PSNR: 23.87, Train LPIPS: 0.1380, Val LPIPS: 0.1141


protanopia Epoch 3: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 3/50] Train Loss: 0.0519, Val Loss: 0.0284, Train PSNR: 23.48, Val PSNR: 24.72, Train LPIPS: 0.1150, Val LPIPS: 0.1024


protanopia Epoch 4: 100%|██████████| 1855/1855 [18:31<00:00,  1.67it/s]


[Epoch 4/50] Train Loss: 0.0488, Val Loss: 0.0239, Train PSNR: 24.00, Val PSNR: 25.27, Train LPIPS: 0.1062, Val LPIPS: 0.0931


protanopia Epoch 5: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 5/50] Train Loss: 0.0429, Val Loss: 0.0260, Train PSNR: 24.51, Val PSNR: 25.57, Train LPIPS: 0.0974, Val LPIPS: 0.0926


protanopia Epoch 6: 100%|██████████| 1855/1855 [18:23<00:00,  1.68it/s]


[Epoch 6/50] Train Loss: 0.0672, Val Loss: 0.0244, Train PSNR: 23.46, Val PSNR: 25.37, Train LPIPS: 0.1169, Val LPIPS: 0.0888


protanopia Epoch 7: 100%|██████████| 1855/1855 [18:26<00:00,  1.68it/s]


[Epoch 7/50] Train Loss: 0.0547, Val Loss: 0.0224, Train PSNR: 24.25, Val PSNR: 25.65, Train LPIPS: 0.1007, Val LPIPS: 0.0872


protanopia Epoch 8: 100%|██████████| 1855/1855 [18:31<00:00,  1.67it/s]


[Epoch 8/50] Train Loss: 0.0505, Val Loss: 0.0235, Train PSNR: 24.60, Val PSNR: 25.99, Train LPIPS: 0.0950, Val LPIPS: 0.0866


protanopia Epoch 9: 100%|██████████| 1855/1855 [18:27<00:00,  1.68it/s]


[Epoch 9/50] Train Loss: 0.0430, Val Loss: 0.0241, Train PSNR: 25.04, Val PSNR: 25.75, Train LPIPS: 0.0885, Val LPIPS: 0.0812


protanopia Epoch 10: 100%|██████████| 1855/1855 [18:14<00:00,  1.70it/s]


[Epoch 10/50] Train Loss: 0.0360, Val Loss: 0.0183, Train PSNR: 25.41, Val PSNR: 26.62, Train LPIPS: 0.0845, Val LPIPS: 0.0759


protanopia Epoch 11: 100%|██████████| 1855/1855 [18:21<00:00,  1.68it/s]


[Epoch 11/50] Train Loss: 0.0290, Val Loss: 0.0173, Train PSNR: 25.83, Val PSNR: 26.81, Train LPIPS: 0.0800, Val LPIPS: 0.0739


protanopia Epoch 12: 100%|██████████| 1855/1855 [18:17<00:00,  1.69it/s]


[Epoch 12/50] Train Loss: 0.0308, Val Loss: 0.0181, Train PSNR: 25.97, Val PSNR: 27.08, Train LPIPS: 0.0790, Val LPIPS: 0.0715


protanopia Epoch 13: 100%|██████████| 1855/1855 [18:22<00:00,  1.68it/s]


[Epoch 13/50] Train Loss: 0.0276, Val Loss: 0.0158, Train PSNR: 26.23, Val PSNR: 27.27, Train LPIPS: 0.0754, Val LPIPS: 0.0678


protanopia Epoch 14: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 14/50] Train Loss: 0.0267, Val Loss: 0.0156, Train PSNR: 26.37, Val PSNR: 27.18, Train LPIPS: 0.0742, Val LPIPS: 0.0650


protanopia Epoch 15: 100%|██████████| 1855/1855 [18:20<00:00,  1.69it/s]


[Epoch 15/50] Train Loss: 0.0230, Val Loss: 0.0145, Train PSNR: 26.73, Val PSNR: 27.53, Train LPIPS: 0.0704, Val LPIPS: 0.0656


protanopia Epoch 16: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 16/50] Train Loss: 0.0220, Val Loss: 0.0137, Train PSNR: 26.90, Val PSNR: 27.62, Train LPIPS: 0.0682, Val LPIPS: 0.0614


protanopia Epoch 17: 100%|██████████| 1855/1855 [18:29<00:00,  1.67it/s]


[Epoch 17/50] Train Loss: 0.0226, Val Loss: 0.0134, Train PSNR: 26.98, Val PSNR: 27.81, Train LPIPS: 0.0670, Val LPIPS: 0.0598


protanopia Epoch 18: 100%|██████████| 1855/1855 [18:19<00:00,  1.69it/s]


[Epoch 18/50] Train Loss: 0.0201, Val Loss: 0.0129, Train PSNR: 27.23, Val PSNR: 27.92, Train LPIPS: 0.0644, Val LPIPS: 0.0579


protanopia Epoch 19: 100%|██████████| 1855/1855 [18:23<00:00,  1.68it/s]


[Epoch 19/50] Train Loss: 0.0195, Val Loss: 0.0130, Train PSNR: 27.35, Val PSNR: 28.02, Train LPIPS: 0.0626, Val LPIPS: 0.0565


protanopia Epoch 20: 100%|██████████| 1855/1855 [18:21<00:00,  1.68it/s]


[Epoch 20/50] Train Loss: 0.0193, Val Loss: 0.0145, Train PSNR: 27.46, Val PSNR: 28.03, Train LPIPS: 0.0611, Val LPIPS: 0.0552


protanopia Epoch 21: 100%|██████████| 1855/1855 [18:22<00:00,  1.68it/s]


[Epoch 21/50] Train Loss: 0.0182, Val Loss: 0.0123, Train PSNR: 27.61, Val PSNR: 28.35, Train LPIPS: 0.0593, Val LPIPS: 0.0538


protanopia Epoch 22: 100%|██████████| 1855/1855 [18:24<00:00,  1.68it/s]


[Epoch 22/50] Train Loss: 0.0181, Val Loss: 0.0124, Train PSNR: 27.70, Val PSNR: 28.17, Train LPIPS: 0.0580, Val LPIPS: 0.0524


protanopia Epoch 23: 100%|██████████| 1855/1855 [18:20<00:00,  1.69it/s]


[Epoch 23/50] Train Loss: 0.0177, Val Loss: 0.0123, Train PSNR: 27.81, Val PSNR: 28.63, Train LPIPS: 0.0566, Val LPIPS: 0.0511


protanopia Epoch 24: 100%|██████████| 1855/1855 [18:36<00:00,  1.66it/s]


[Epoch 24/50] Train Loss: 0.0170, Val Loss: 0.0125, Train PSNR: 27.92, Val PSNR: 28.24, Train LPIPS: 0.0552, Val LPIPS: 0.0500


protanopia Epoch 25: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 25/50] Train Loss: 0.0170, Val Loss: 0.0117, Train PSNR: 28.00, Val PSNR: 28.43, Train LPIPS: 0.0541, Val LPIPS: 0.0484


protanopia Epoch 26: 100%|██████████| 1855/1855 [18:22<00:00,  1.68it/s]


Early stopping protanopia at epoch 26

Training completed for protanopia

Cleaning up temp_train_protanopia.lmdb before next type...
Successfully deleted temp_train_protanopia.lmdb

Cleaning up temp_val_protanopia.lmdb before next type...
Successfully deleted temp_val_protanopia.lmdb

Preprocessing deuteranopia


Processing deuteranopia: 100%|██████████| 30/30 [08:04<00:00, 16.16s/it]
Processing deuteranopia: 100%|██████████| 3/3 [00:21<00:00,  7.33s/it]


temp_train_deuteranopia.lmdb has 29670 samples
temp_val_deuteranopia.lmdb has 2967 samples
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\Taneem\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth


deuteranopia Epoch 1: 100%|██████████| 1855/1855 [19:01<00:00,  1.63it/s]


[Epoch 1/50] Train Loss: 0.1124, Val Loss: 0.0404, Train PSNR: 18.78, Val PSNR: 22.52, Train LPIPS: 0.2796, Val LPIPS: 0.1437


deuteranopia Epoch 2: 100%|██████████| 1855/1855 [18:34<00:00,  1.66it/s]


[Epoch 2/50] Train Loss: 0.0662, Val Loss: 0.0300, Train PSNR: 22.49, Val PSNR: 24.06, Train LPIPS: 0.1356, Val LPIPS: 0.1100


deuteranopia Epoch 3: 100%|██████████| 1855/1855 [18:23<00:00,  1.68it/s]


[Epoch 3/50] Train Loss: 0.0539, Val Loss: 0.0252, Train PSNR: 23.55, Val PSNR: 24.97, Train LPIPS: 0.1123, Val LPIPS: 0.0983


deuteranopia Epoch 4: 100%|██████████| 1855/1855 [18:18<00:00,  1.69it/s]


[Epoch 4/50] Train Loss: 0.0508, Val Loss: 0.0224, Train PSNR: 24.12, Val PSNR: 25.66, Train LPIPS: 0.1051, Val LPIPS: 0.0935


deuteranopia Epoch 5: 100%|██████████| 1855/1855 [18:27<00:00,  1.68it/s]


[Epoch 5/50] Train Loss: 0.0466, Val Loss: 0.0210, Train PSNR: 24.64, Val PSNR: 25.94, Train LPIPS: 0.0988, Val LPIPS: 0.0858


deuteranopia Epoch 6: 100%|██████████| 1855/1855 [18:23<00:00,  1.68it/s]


[Epoch 6/50] Train Loss: 0.0454, Val Loss: 0.0265, Train PSNR: 24.88, Val PSNR: 26.09, Train LPIPS: 0.0958, Val LPIPS: 0.0848


deuteranopia Epoch 7: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 7/50] Train Loss: 0.0508, Val Loss: 0.0275, Train PSNR: 24.73, Val PSNR: 26.06, Train LPIPS: 0.0989, Val LPIPS: 0.0860


deuteranopia Epoch 8: 100%|██████████| 1855/1855 [18:37<00:00,  1.66it/s]


[Epoch 8/50] Train Loss: 0.0401, Val Loss: 0.0227, Train PSNR: 25.29, Val PSNR: 26.59, Train LPIPS: 0.0878, Val LPIPS: 0.0793


deuteranopia Epoch 9: 100%|██████████| 1855/1855 [18:37<00:00,  1.66it/s]


[Epoch 9/50] Train Loss: 0.0502, Val Loss: 0.0215, Train PSNR: 24.98, Val PSNR: 26.62, Train LPIPS: 0.0907, Val LPIPS: 0.0798


deuteranopia Epoch 10: 100%|██████████| 1855/1855 [18:30<00:00,  1.67it/s]


[Epoch 10/50] Train Loss: 0.0415, Val Loss: 0.0185, Train PSNR: 25.50, Val PSNR: 26.97, Train LPIPS: 0.0836, Val LPIPS: 0.0735


deuteranopia Epoch 11: 100%|██████████| 1855/1855 [18:30<00:00,  1.67it/s]


[Epoch 11/50] Train Loss: 0.0326, Val Loss: 0.0173, Train PSNR: 25.92, Val PSNR: 26.98, Train LPIPS: 0.0786, Val LPIPS: 0.0696


deuteranopia Epoch 12: 100%|██████████| 1855/1855 [18:35<00:00,  1.66it/s]


[Epoch 12/50] Train Loss: 0.0325, Val Loss: 0.0164, Train PSNR: 26.11, Val PSNR: 27.38, Train LPIPS: 0.0777, Val LPIPS: 0.0674


deuteranopia Epoch 13: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 13/50] Train Loss: 0.0274, Val Loss: 0.0152, Train PSNR: 26.50, Val PSNR: 27.59, Train LPIPS: 0.0723, Val LPIPS: 0.0640


deuteranopia Epoch 14: 100%|██████████| 1855/1855 [18:27<00:00,  1.68it/s]


[Epoch 14/50] Train Loss: 0.0308, Val Loss: 0.0238, Train PSNR: 26.45, Val PSNR: 26.97, Train LPIPS: 0.0719, Val LPIPS: 0.0740


deuteranopia Epoch 15: 100%|██████████| 1855/1855 [18:28<00:00,  1.67it/s]


[Epoch 15/50] Train Loss: 0.0261, Val Loss: 0.0148, Train PSNR: 26.80, Val PSNR: 27.85, Train LPIPS: 0.0672, Val LPIPS: 0.0585


deuteranopia Epoch 16: 100%|██████████| 1855/1855 [18:34<00:00,  1.66it/s]


[Epoch 16/50] Train Loss: 0.0267, Val Loss: 0.0194, Train PSNR: 26.91, Val PSNR: 27.52, Train LPIPS: 0.0663, Val LPIPS: 0.0599


deuteranopia Epoch 17: 100%|██████████| 1855/1855 [18:29<00:00,  1.67it/s]


[Epoch 17/50] Train Loss: 0.0231, Val Loss: 0.0157, Train PSNR: 27.20, Val PSNR: 28.08, Train LPIPS: 0.0628, Val LPIPS: 0.0561


deuteranopia Epoch 18: 100%|██████████| 1855/1855 [18:28<00:00,  1.67it/s]


[Epoch 18/50] Train Loss: 0.0224, Val Loss: 0.0139, Train PSNR: 27.34, Val PSNR: 28.34, Train LPIPS: 0.0611, Val LPIPS: 0.0534


deuteranopia Epoch 19: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 19/50] Train Loss: 0.0212, Val Loss: 0.0135, Train PSNR: 27.52, Val PSNR: 28.33, Train LPIPS: 0.0589, Val LPIPS: 0.0523


deuteranopia Epoch 20: 100%|██████████| 1855/1855 [18:40<00:00,  1.66it/s]


[Epoch 20/50] Train Loss: 0.0208, Val Loss: 0.0130, Train PSNR: 27.60, Val PSNR: 28.34, Train LPIPS: 0.0576, Val LPIPS: 0.0494


deuteranopia Epoch 21: 100%|██████████| 1855/1855 [18:41<00:00,  1.65it/s]


[Epoch 21/50] Train Loss: 0.0191, Val Loss: 0.0134, Train PSNR: 27.82, Val PSNR: 28.42, Train LPIPS: 0.0552, Val LPIPS: 0.0487


deuteranopia Epoch 22: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 22/50] Train Loss: 0.0182, Val Loss: 0.0127, Train PSNR: 27.98, Val PSNR: 28.80, Train LPIPS: 0.0537, Val LPIPS: 0.0468


deuteranopia Epoch 23: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 23/50] Train Loss: 0.0179, Val Loss: 0.0130, Train PSNR: 28.07, Val PSNR: 28.90, Train LPIPS: 0.0525, Val LPIPS: 0.0473


deuteranopia Epoch 24: 100%|██████████| 1855/1855 [18:39<00:00,  1.66it/s]


[Epoch 24/50] Train Loss: 0.0170, Val Loss: 0.0129, Train PSNR: 28.22, Val PSNR: 28.84, Train LPIPS: 0.0509, Val LPIPS: 0.0447


deuteranopia Epoch 25: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 25/50] Train Loss: 0.0166, Val Loss: 0.0119, Train PSNR: 28.33, Val PSNR: 29.08, Train LPIPS: 0.0496, Val LPIPS: 0.0437


deuteranopia Epoch 26: 100%|██████████| 1855/1855 [18:40<00:00,  1.65it/s]


[Epoch 26/50] Train Loss: 0.0160, Val Loss: 0.0119, Train PSNR: 28.44, Val PSNR: 29.23, Train LPIPS: 0.0483, Val LPIPS: 0.0427


deuteranopia Epoch 27: 100%|██████████| 1855/1855 [18:40<00:00,  1.65it/s]


[Epoch 27/50] Train Loss: 0.0156, Val Loss: 0.0117, Train PSNR: 28.55, Val PSNR: 29.11, Train LPIPS: 0.0473, Val LPIPS: 0.0420


deuteranopia Epoch 28: 100%|██████████| 1855/1855 [18:40<00:00,  1.66it/s]


[Epoch 28/50] Train Loss: 0.0153, Val Loss: 0.0115, Train PSNR: 28.64, Val PSNR: 29.30, Train LPIPS: 0.0464, Val LPIPS: 0.0401


deuteranopia Epoch 29: 100%|██████████| 1855/1855 [18:40<00:00,  1.66it/s]


[Epoch 29/50] Train Loss: 0.0148, Val Loss: 0.0119, Train PSNR: 28.74, Val PSNR: 29.38, Train LPIPS: 0.0453, Val LPIPS: 0.0411


deuteranopia Epoch 30: 100%|██████████| 1855/1855 [18:30<00:00,  1.67it/s]


[Epoch 30/50] Train Loss: 0.0147, Val Loss: 0.0119, Train PSNR: 28.82, Val PSNR: 29.55, Train LPIPS: 0.0446, Val LPIPS: 0.0418


deuteranopia Epoch 31: 100%|██████████| 1855/1855 [18:33<00:00,  1.67it/s]


[Epoch 31/50] Train Loss: 0.0144, Val Loss: 0.0113, Train PSNR: 28.91, Val PSNR: 29.49, Train LPIPS: 0.0438, Val LPIPS: 0.0388


deuteranopia Epoch 32: 100%|██████████| 1855/1855 [18:43<00:00,  1.65it/s]


Early stopping deuteranopia at epoch 32

Training completed for deuteranopia

Cleaning up temp_train_deuteranopia.lmdb before next type...
Successfully deleted temp_train_deuteranopia.lmdb

Cleaning up temp_val_deuteranopia.lmdb before next type...
Successfully deleted temp_val_deuteranopia.lmdb

Preprocessing tritanopia


Processing tritanopia: 100%|██████████| 30/30 [10:21<00:00, 20.73s/it]
Processing tritanopia: 100%|██████████| 3/3 [00:25<00:00,  8.42s/it]


temp_train_tritanopia.lmdb has 29670 samples
temp_val_tritanopia.lmdb has 2967 samples
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\Taneem\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth


tritanopia Epoch 1: 100%|██████████| 1855/1855 [19:19<00:00,  1.60it/s]


[Epoch 1/50] Train Loss: 0.1196, Val Loss: 0.0442, Train PSNR: 19.25, Val PSNR: 22.71, Train LPIPS: 0.2238, Val LPIPS: 0.1178


tritanopia Epoch 2: 100%|██████████| 1855/1855 [18:30<00:00,  1.67it/s]


[Epoch 2/50] Train Loss: 0.0680, Val Loss: 0.0377, Train PSNR: 22.42, Val PSNR: 24.00, Train LPIPS: 0.1178, Val LPIPS: 0.0992


tritanopia Epoch 3: 100%|██████████| 1855/1855 [18:41<00:00,  1.65it/s]


[Epoch 3/50] Train Loss: 0.0642, Val Loss: 0.0498, Train PSNR: 23.14, Val PSNR: 23.82, Train LPIPS: 0.1073, Val LPIPS: 0.1189


tritanopia Epoch 4: 100%|██████████| 1855/1855 [18:39<00:00,  1.66it/s]


[Epoch 4/50] Train Loss: 0.0616, Val Loss: 0.0455, Train PSNR: 23.48, Val PSNR: 24.46, Train LPIPS: 0.1011, Val LPIPS: 0.0955


tritanopia Epoch 5: 100%|██████████| 1855/1855 [18:31<00:00,  1.67it/s]


[Epoch 5/50] Train Loss: 0.0554, Val Loss: 0.0370, Train PSNR: 23.94, Val PSNR: 24.98, Train LPIPS: 0.0916, Val LPIPS: 0.0895


tritanopia Epoch 6: 100%|██████████| 1855/1855 [18:36<00:00,  1.66it/s]


[Epoch 6/50] Train Loss: 0.0566, Val Loss: 0.0260, Train PSNR: 24.02, Val PSNR: 25.78, Train LPIPS: 0.0912, Val LPIPS: 0.0751


tritanopia Epoch 7: 100%|██████████| 1855/1855 [18:45<00:00,  1.65it/s]


[Epoch 7/50] Train Loss: 0.0516, Val Loss: 0.0254, Train PSNR: 24.35, Val PSNR: 25.77, Train LPIPS: 0.0858, Val LPIPS: 0.0718


tritanopia Epoch 8: 100%|██████████| 1855/1855 [18:29<00:00,  1.67it/s]


[Epoch 8/50] Train Loss: 0.0508, Val Loss: 0.0264, Train PSNR: 24.45, Val PSNR: 25.81, Train LPIPS: 0.0845, Val LPIPS: 0.0746


tritanopia Epoch 9: 100%|██████████| 1855/1855 [18:28<00:00,  1.67it/s]


[Epoch 9/50] Train Loss: 0.0494, Val Loss: 0.0242, Train PSNR: 24.62, Val PSNR: 26.24, Train LPIPS: 0.0834, Val LPIPS: 0.0697


tritanopia Epoch 10: 100%|██████████| 1855/1855 [18:26<00:00,  1.68it/s]


[Epoch 10/50] Train Loss: 0.0411, Val Loss: 0.0216, Train PSNR: 25.05, Val PSNR: 26.22, Train LPIPS: 0.0758, Val LPIPS: 0.0635


tritanopia Epoch 11: 100%|██████████| 1855/1855 [18:23<00:00,  1.68it/s]


[Epoch 11/50] Train Loss: 0.0380, Val Loss: 0.0230, Train PSNR: 25.31, Val PSNR: 26.29, Train LPIPS: 0.0728, Val LPIPS: 0.0656


tritanopia Epoch 12: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 12/50] Train Loss: 0.0377, Val Loss: 0.0226, Train PSNR: 25.46, Val PSNR: 26.76, Train LPIPS: 0.0703, Val LPIPS: 0.0606


tritanopia Epoch 13: 100%|██████████| 1855/1855 [18:24<00:00,  1.68it/s]


[Epoch 13/50] Train Loss: 0.0364, Val Loss: 0.0208, Train PSNR: 25.57, Val PSNR: 26.89, Train LPIPS: 0.0688, Val LPIPS: 0.0583


tritanopia Epoch 14: 100%|██████████| 1855/1855 [18:36<00:00,  1.66it/s]


[Epoch 14/50] Train Loss: 0.0308, Val Loss: 0.0207, Train PSNR: 25.95, Val PSNR: 26.62, Train LPIPS: 0.0649, Val LPIPS: 0.0547


tritanopia Epoch 15: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 15/50] Train Loss: 0.0301, Val Loss: 0.0189, Train PSNR: 26.07, Val PSNR: 27.02, Train LPIPS: 0.0642, Val LPIPS: 0.0545


tritanopia Epoch 16: 100%|██████████| 1855/1855 [18:27<00:00,  1.68it/s]


[Epoch 16/50] Train Loss: 0.0273, Val Loss: 0.0198, Train PSNR: 26.35, Val PSNR: 27.05, Train LPIPS: 0.0608, Val LPIPS: 0.0534


tritanopia Epoch 17: 100%|██████████| 1855/1855 [18:38<00:00,  1.66it/s]


[Epoch 17/50] Train Loss: 0.0263, Val Loss: 0.0186, Train PSNR: 26.49, Val PSNR: 27.48, Train LPIPS: 0.0590, Val LPIPS: 0.0516


tritanopia Epoch 18: 100%|██████████| 1855/1855 [18:19<00:00,  1.69it/s]


[Epoch 18/50] Train Loss: 0.0254, Val Loss: 0.0181, Train PSNR: 26.66, Val PSNR: 27.43, Train LPIPS: 0.0570, Val LPIPS: 0.0493


tritanopia Epoch 19: 100%|██████████| 1855/1855 [18:22<00:00,  1.68it/s]


[Epoch 19/50] Train Loss: 0.0244, Val Loss: 0.0174, Train PSNR: 26.80, Val PSNR: 27.63, Train LPIPS: 0.0554, Val LPIPS: 0.0484


tritanopia Epoch 20: 100%|██████████| 1855/1855 [18:28<00:00,  1.67it/s]


[Epoch 20/50] Train Loss: 0.0253, Val Loss: 0.0173, Train PSNR: 26.78, Val PSNR: 27.71, Train LPIPS: 0.0553, Val LPIPS: 0.0469


tritanopia Epoch 21: 100%|██████████| 1855/1855 [18:18<00:00,  1.69it/s]


[Epoch 21/50] Train Loss: 0.0233, Val Loss: 0.0169, Train PSNR: 27.02, Val PSNR: 27.73, Train LPIPS: 0.0524, Val LPIPS: 0.0451


tritanopia Epoch 22: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 22/50] Train Loss: 0.0238, Val Loss: 0.0174, Train PSNR: 27.01, Val PSNR: 27.76, Train LPIPS: 0.0522, Val LPIPS: 0.0453


tritanopia Epoch 23: 100%|██████████| 1855/1855 [18:32<00:00,  1.67it/s]


[Epoch 23/50] Train Loss: 0.0222, Val Loss: 0.0163, Train PSNR: 27.22, Val PSNR: 27.88, Train LPIPS: 0.0502, Val LPIPS: 0.0429


tritanopia Epoch 24: 100%|██████████| 1855/1855 [18:20<00:00,  1.69it/s]


[Epoch 24/50] Train Loss: 0.0218, Val Loss: 0.0168, Train PSNR: 27.30, Val PSNR: 27.99, Train LPIPS: 0.0491, Val LPIPS: 0.0425


tritanopia Epoch 25: 100%|██████████| 1855/1855 [18:33<00:00,  1.67it/s]


[Epoch 25/50] Train Loss: 0.0216, Val Loss: 0.0172, Train PSNR: 27.37, Val PSNR: 28.19, Train LPIPS: 0.0480, Val LPIPS: 0.0422


tritanopia Epoch 26: 100%|██████████| 1855/1855 [18:29<00:00,  1.67it/s]


[Epoch 26/50] Train Loss: 0.0208, Val Loss: 0.0176, Train PSNR: 27.48, Val PSNR: 27.37, Train LPIPS: 0.0466, Val LPIPS: 0.0422


tritanopia Epoch 27: 100%|██████████| 1855/1855 [18:25<00:00,  1.68it/s]


[Epoch 27/50] Train Loss: 0.0204, Val Loss: 0.0157, Train PSNR: 27.57, Val PSNR: 28.01, Train LPIPS: 0.0457, Val LPIPS: 0.0392


tritanopia Epoch 28: 100%|██████████| 1855/1855 [18:33<00:00,  1.67it/s]


Early stopping tritanopia at epoch 28

Training completed for tritanopia

Cleaning up temp_train_tritanopia.lmdb before next type...
Successfully deleted temp_train_tritanopia.lmdb

Cleaning up temp_val_tritanopia.lmdb before next type...
Successfully deleted temp_val_tritanopia.lmdb

Training completed for all types!


In [15]:
def plot_metrics(metrics, title_suffix="", figsize=(18, 6), dpi=300):
    plt.figure(figsize=figsize, dpi=dpi)

    # Loss Plot
    plt.subplot(1, 3, 1)
    plt.plot(metrics['train_loss'], 'b--', label='Train Loss')
    plt.plot(metrics['val_loss'], 'r-', label='Val Loss')
    plt.title('1 - MS-SSIM Loss', fontsize=12)
    plt.xlabel('Epoch', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.legend()

    # PSNR Plot
    plt.subplot(1, 3, 2)
    plt.plot(metrics['train_psnr'], 'g--', label='Train PSNR')
    plt.plot(metrics['val_psnr'], 'm-', label='Val PSNR')
    plt.title('PSNR (dB)', fontsize=12)
    plt.xlabel('Epoch', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.legend()

    # LPIPS Plot
    plt.subplot(1, 3, 3)
    plt.plot(metrics['train_lpips'], 'c--', label='Train LPIPS')
    plt.plot(metrics['val_lpips'], 'y-', label='Val LPIPS')
    plt.title('LPIPS (Lower = Better)', fontsize=12)
    plt.xlabel('Epoch', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.suptitle(f'Training Metrics {title_suffix}', y=1.02, fontsize=14)
    plt.tight_layout()

    save_path = f"training_metrics_{title_suffix.replace(' ', '_')}.png"
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
    print(f"Saved metrics plot to {save_path}")

In [16]:
for cb_type in COLORBLIND_TYPES:
    metrics = np.load(f"metrics_{cb_type}.npz")
    plot_metrics(
        metrics,
        title_suffix=f"({cb_type.capitalize()})",
        figsize=(20, 6),
        dpi=150
    )

Saved metrics plot to training_metrics_(Protanopia).png
Saved metrics plot to training_metrics_(Deuteranopia).png
Saved metrics plot to training_metrics_(Tritanopia).png


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 16
colorblind_types = ["protanopia", "deuteranopia", "tritanopia"]

# Initialize metrics with correct data ranges
ms_ssim = MS_SSIM(data_range=2.0).to(device)  # For [-1, 1] range
psnr = PeakSignalNoiseRatio().to(device)
lpips_model = lpips.LPIPS(net='alex').to(device).eval()  # LPIPS in eval mode
ssim = torchmetrics.StructuralSimilarityIndexMeasure(
    data_range=2.0).to(device)  # Corrected data range


for cb_type in colorblind_types:
    print(f"\n{'='*40}\nTesting for {cb_type.upper()}\n{'='*40}")

    test_lmdb_path = f"temp_test_{cb_type}.lmdb"
    if not os.path.exists(test_lmdb_path):
        preprocess_type("data/test", cb_type, test_lmdb_path)

    test_dataset = ColorblindDataset(test_lmdb_path, cb_type)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, num_workers=0,
        shuffle=False, pin_memory=True)

    model = Autoencoder().to(device)
    model_path = f"best_model_{cb_type}.pth"
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Initialize metrics
    metrics = {
        'loss': 0.0,
        'psnr': 0.0,
        'lpips': 0.0,
        'ssim': 0.0,
        'times': []
    }

    save_dir = f"test_results/{cb_type}"
    os.makedirs(save_dir, exist_ok=True)
    img_counter = 0  # For sequential image naming

    with torch.no_grad():
        for batch_idx, (orig, target) in enumerate(tqdm(test_loader, desc=f"Testing {cb_type}")):
            orig, target = orig.to(device, non_blocking=True), target.to(
                device, non_blocking=True)

            # Synchronize CUDA for accurate timing
            if device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.time()

            output = model(orig)

            if device.type == 'cuda':
                torch.cuda.synchronize()
            metrics['times'].append(time.time() - start)

            # Calculate metrics
            metrics['loss'] += (1 - ms_ssim(output, target)).item()
            metrics['ssim'] += ssim(output, target).item()

            # For PSNR: scale to [0, 1]
            output_scaled = (output + 1) / 2
            target_scaled = (target + 1) / 2
            metrics['psnr'] += psnr(output_scaled, target_scaled).item()

            # LPIPS calculation (keep in [-1, 1] range)
            metrics['lpips'] += lpips_model(output, target).mean().item()

            # Save first 100 images for comparison
            if img_counter < 100:
                for i in range(output.size(0)):
                    visualize_and_save(
                        orig[i],
                        target[i],
                        output[i],
                        cb_type,
                        epoch=0,  # Use 0 for testing phase
                        batch_idx=img_counter,
                        save_dir=save_dir
                    )

                    # Save individual images
                    orig_img = process_image(orig[i])
                    gen_img = process_image(output[i])

                    # Convert to BGR for OpenCV and save
                    cv2.imwrite(
                        os.path.join(
                            save_dir, f"original_{img_counter:04d}.png"),
                        cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR)
                    )
                    cv2.imwrite(
                        os.path.join(
                            save_dir, f"generated_{img_counter:04d}.png"),
                        cv2.cvtColor(gen_img, cv2.COLOR_RGB2BGR)
                    )

                    img_counter += 1

    # Calculate final metrics
    n_batches = len(test_loader)
    print(f"\n{cb_type.upper()} Test Results:")
    print(f"- Avg MS-SSIM Loss: {metrics['loss'] / n_batches:.4f}")
    print(f"- Avg SSIM: {metrics['ssim'] / n_batches:.4f}")
    print(f"- Avg PSNR: {metrics['psnr'] / n_batches:.2f} dB")
    print(f"- Avg LPIPS: {metrics['lpips'] / n_batches:.4f}")

    # Timing statistics
    avg_time = np.mean(metrics['times']) * 1000
    fps = batch_size / (avg_time / 1000)
    print(f"- Inference Speed: {avg_time:.2f} ms/batch ({fps:.1f} FPS)")

    if 'test_dataset' in locals():
        test_dataset.close()

    current_file = f"temp_test_{cb_type}.lmdb"
    print(f"\nCleaning up {current_file} before next type...")
    force_delete(current_file)

    torch.cuda.empty_cache()

print("\nAll tests completed. Results saved in test_results/ directory.")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\Taneem\AppData\Local\Programs\Python\Python312\Lib\site-packages\lpips\weights\v0.1\alex.pth





Testing for PROTANOPIA


Processing protanopia: 100%|██████████| 3/3 [00:28<00:00,  9.49s/it]
Testing protanopia: 100%|██████████| 186/186 [01:22<00:00,  2.26it/s]



PROTANOPIA Test Results:
- Avg MS-SSIM Loss: 0.0113
- Avg SSIM: 0.9125
- Avg PSNR: 28.39 dB
- Avg LPIPS: 0.0478
- Inference Speed: 52.73 ms/batch (303.4 FPS)

Cleaning up temp_test_protanopia.lmdb before next type...
Successfully deleted temp_test_protanopia.lmdb

Testing for DEUTERANOPIA


Processing deuteranopia: 100%|██████████| 3/3 [00:21<00:00,  7.17s/it]
Testing deuteranopia: 100%|██████████| 186/186 [01:21<00:00,  2.29it/s]



DEUTERANOPIA Test Results:
- Avg MS-SSIM Loss: 0.0109
- Avg SSIM: 0.9174
- Avg PSNR: 29.51 dB
- Avg LPIPS: 0.0382
- Inference Speed: 52.59 ms/batch (304.2 FPS)

Cleaning up temp_test_deuteranopia.lmdb before next type...
Successfully deleted temp_test_deuteranopia.lmdb

Testing for TRITANOPIA


Processing tritanopia: 100%|██████████| 3/3 [00:20<00:00,  6.76s/it]
Testing tritanopia: 100%|██████████| 186/186 [01:21<00:00,  2.28it/s]



TRITANOPIA Test Results:
- Avg MS-SSIM Loss: 0.0157
- Avg SSIM: 0.9041
- Avg PSNR: 27.98 dB
- Avg LPIPS: 0.0392
- Inference Speed: 55.71 ms/batch (287.2 FPS)

Cleaning up temp_test_tritanopia.lmdb before next type...
Successfully deleted temp_test_tritanopia.lmdb

All tests completed. Results saved in test_results/ directory.
