In [5]:
import os
import cv2
import math
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.metrics import structural_similarity as compare_ssim, peak_signal_noise_ratio as compare_psnr
from torch.utils.tensorboard import SummaryWriter

# ------------------- Model Components -------------------
class DamageRestorer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv2d(3, 3, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 3, 3, padding=1),
            nn.Sigmoid()
        )
        self.ffn = nn.Sequential(
            nn.Conv2d(3, 16, 1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 1)
        )

    def forward(self, x):
        x_attn = x * self.attn(x)
        return x + self.ffn(x_attn)

class IlluminationEstimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

class RetinexMamba(nn.Module):
    def __init__(self, boost=0.7):
        super().__init__()
        self.ie = IlluminationEstimator()
        self.restorer = DamageRestorer()
        self.boost = boost

    def forward(self, x):
        illum_map = self.ie(x)
        illum_map_boosted = torch.clamp(illum_map + self.boost, 0, 1.2)
        illuminated = x * illum_map_boosted
        enhanced = self.restorer(illuminated)
        return torch.clamp(enhanced, 0, 1)

# ------------------- Dataset -------------------
class ImagePairDataset(Dataset):
    def __init__(self, input_dir, target_dir, size=(256, 256)):
        self.input_paths = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        self.target_paths = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        self.size = size
        self.transform = T.Compose([
            T.Resize(size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        inp_img = Image.open(self.input_paths[idx]).convert("RGB")
        tgt_img = Image.open(self.target_paths[idx]).convert("RGB")
        return self.transform(inp_img), self.transform(tgt_img), self.input_paths[idx]

# ------------------- Loss Functions -------------------
def ssim_loss(pred, target):
    pred = pred.detach().cpu()
    target = target.detach().cpu()
    loss = 0
    for i in range(pred.shape[0]):
        pred_np = pred[i].permute(1, 2, 0).numpy()
        target_np = target[i].permute(1, 2, 0).numpy()
        ssim = compare_ssim(pred_np, target_np, channel_axis=2, data_range=1.0)
        loss += 1 - ssim
    return loss / pred.shape[0]

# ------------------- Training -------------------
def train(model, loader, optimizer, criterion, device, writer, epoch):
    model.train()
    total_loss = 0
    for i, (input_img, target_img, _) in enumerate(tqdm(loader, desc="Training")):
        input_img = input_img.to(device)
        target_img = target_img.to(device)
        optimizer.zero_grad()
        output = model(input_img)
        loss = criterion(output, target_img) + 0.1 * ssim_loss(output, target_img)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        writer.add_scalar("Loss/Batch", loss.item(), epoch * len(loader) + i)
    avg_loss = total_loss / len(loader)
    writer.add_scalar("Loss/Epoch", avg_loss, epoch)
    return avg_loss

# ------------------- Evaluation -------------------
def evaluate(model, loader, device, writer, epoch, output_dir):
    model.eval()
    total_psnr, total_ssim = 0, 0
    with torch.no_grad():
        for idx, (input_img, target_img, img_path) in enumerate(tqdm(loader, desc="Evaluating")):
            input_img = input_img.to(device)
            target_img = target_img.to(device)
            output = model(input_img)

            # Use the first image in batch (batch size is 1 for val)
            inp_np = input_img[0].permute(1, 2, 0).cpu().numpy()
            out_np = output[0].permute(1, 2, 0).cpu().numpy()
            tgt_np = target_img[0].permute(1, 2, 0).cpu().numpy()

            psnr = compare_psnr(tgt_np, out_np, data_range=1.0)
            ssim = compare_ssim(tgt_np, out_np, channel_axis=2, data_range=1.0)
            total_psnr += psnr
            total_ssim += ssim

            if idx < 10:
                concat_img = np.hstack((
                    (inp_np * 255).astype(np.uint8),
                    (out_np * 255).astype(np.uint8),
                    (tgt_np * 255).astype(np.uint8)
                ))
                img_name = os.path.basename(img_path[0])
                os.makedirs(output_dir, exist_ok=True)
                cv2.imwrite(os.path.join(output_dir, f"compare_epoch{epoch}_{img_name}"), cv2.cvtColor(concat_img, cv2.COLOR_RGB2BGR))

    avg_psnr = total_psnr / len(loader)
    avg_ssim = total_ssim / len(loader)
    writer.add_scalar("Val/PSNR", avg_psnr, epoch)
    writer.add_scalar("Val/SSIM", avg_ssim, epoch)
    return avg_psnr, avg_ssim

# ------------------- Main -------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', default='/content/drive/MyDrive/data/LOLv2/Real_captured/Test/Low')
    parser.add_argument('--target_dir', default='/content/drive/MyDrive/data/LOLv2/Real_captured/Test/Normal')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--save_path', default='/content/retinexmamba_trained.pth')
    parser.add_argument('--log_dir', default='/content/runs')
    parser.add_argument('--output_vis_dir', default='/content/validation_outputs')
    args = parser.parse_args([])  # Allows Colab use

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = RetinexMamba().to(device)

    dataset = ImagePairDataset(args.input_dir, args.target_dir)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    writer = SummaryWriter(args.log_dir)

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        train_loss = train(model, train_loader, optimizer, criterion, device, writer, epoch)
        val_psnr, val_ssim = evaluate(model, val_loader, device, writer, epoch, args.output_vis_dir)
        print(f"Loss: {train_loss:.4f} | Val PSNR: {val_psnr:.2f}, SSIM: {val_ssim:.4f}")

    torch.save(model.state_dict(), args.save_path)
    print(f"\n✅ Model saved to {args.save_path}")
    writer.close()

if __name__ == '__main__':
    main()



Epoch 1/20


Training: 100%|██████████| 23/23 [00:26<00:00,  1.14s/it]
Evaluating: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s]


Loss: 0.0904 | Val PSNR: 14.18, SSIM: 0.5922

Epoch 2/20


Training: 100%|██████████| 23/23 [00:07<00:00,  2.99it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.20it/s]


Loss: 0.0857 | Val PSNR: 14.48, SSIM: 0.6069

Epoch 3/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.76it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 14.00it/s]


Loss: 0.0819 | Val PSNR: 14.76, SSIM: 0.6201

Epoch 4/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.66it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.80it/s]


Loss: 0.0789 | Val PSNR: 15.03, SSIM: 0.6319

Epoch 5/20


Training: 100%|██████████| 23/23 [00:07<00:00,  3.05it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.67it/s]


Loss: 0.0745 | Val PSNR: 15.28, SSIM: 0.6422

Epoch 6/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.96it/s]


Loss: 0.0730 | Val PSNR: 15.51, SSIM: 0.6511

Epoch 7/20


Training: 100%|██████████| 23/23 [00:07<00:00,  2.89it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.38it/s]


Loss: 0.0697 | Val PSNR: 15.72, SSIM: 0.6591

Epoch 8/20


Training: 100%|██████████| 23/23 [00:07<00:00,  2.89it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.50it/s]


Loss: 0.0674 | Val PSNR: 15.91, SSIM: 0.6661

Epoch 9/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.65it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 14.23it/s]


Loss: 0.0650 | Val PSNR: 16.09, SSIM: 0.6721

Epoch 10/20


Training: 100%|██████████| 23/23 [00:07<00:00,  3.01it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.82it/s]


Loss: 0.0634 | Val PSNR: 16.25, SSIM: 0.6777

Epoch 11/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.62it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.44it/s]


Loss: 0.0610 | Val PSNR: 16.41, SSIM: 0.6824

Epoch 12/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.81it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.40it/s]


Loss: 0.0594 | Val PSNR: 16.55, SSIM: 0.6867

Epoch 13/20


Training: 100%|██████████| 23/23 [00:07<00:00,  2.89it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.58it/s]


Loss: 0.0581 | Val PSNR: 16.68, SSIM: 0.6905

Epoch 14/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.54it/s]


Loss: 0.0567 | Val PSNR: 16.80, SSIM: 0.6938

Epoch 15/20


Training: 100%|██████████| 23/23 [00:07<00:00,  3.08it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.45it/s]


Loss: 0.0557 | Val PSNR: 16.92, SSIM: 0.6967

Epoch 16/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.73it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 14.27it/s]


Loss: 0.0546 | Val PSNR: 17.02, SSIM: 0.6994

Epoch 17/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.70it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.45it/s]


Loss: 0.0541 | Val PSNR: 17.12, SSIM: 0.7017

Epoch 18/20


Training: 100%|██████████| 23/23 [00:07<00:00,  3.01it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 14.10it/s]


Loss: 0.0531 | Val PSNR: 17.21, SSIM: 0.7039

Epoch 19/20


Training: 100%|██████████| 23/23 [00:08<00:00,  2.61it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 13.51it/s]


Loss: 0.0530 | Val PSNR: 17.30, SSIM: 0.7057

Epoch 20/20


Training: 100%|██████████| 23/23 [00:07<00:00,  2.93it/s]
Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.91it/s]

Loss: 0.0516 | Val PSNR: 17.37, SSIM: 0.7075

✅ Model saved to /content/retinexmamba_trained.pth



