In [None]:
import os
import torch

# Install required libraries
!pip install torch torchvision torchaudio
!pip install numpy opencv-python tqdm scikit-image

# Verify GPU availability
print(f"GPU is available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

# Create project directory
project_dir = "/content/Restormer_Project"
os.makedirs(project_dir, exist_ok=True)
print(f"Project directory created at {project_dir}")

# Set device for PyTorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
import os
import zipfile
from google.colab import files
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define paths
project_dir = "/content/Restormer_Project"
dataset_dir = os.path.join(project_dir, "Rain100H")

# Upload and extract Rain100H dataset.
# The archive is then opened by name from disk, so this relies on files.upload()
# having saved the upload into the current working directory.
print("Please upload the Rain100H ZIP file from your laptop...")
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; please upload the Rain100H ZIP archive.")
# Prefer an actual .zip if several files were selected; otherwise take the first upload.
zip_candidates = [name for name in uploaded if name.lower().endswith(".zip")] or list(uploaded)
zip_path = zip_candidates[0]
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(dataset_dir)
print(f"Extracted Rain100H to {dataset_dir}")

In [None]:
import os

# Define the dataset directory
dataset_dir = "/content/Restormer_Project/Rain100H"

def list_files(startpath):
    """Print an indented tree of every directory and file under ``startpath``.

    Directories and files are visited in sorted order so the listing is stable
    between runs. Nesting depth comes from ``os.path.relpath``. A plain string
    ``replace`` mis-counts depth when ``startpath`` ends with a separator or its
    text reappears deeper in the tree.

    Args:
        startpath: directory to walk. A missing directory prints only the header line.
    """
    print(f"Exploring directory: {startpath}")
    for root, dirs, filenames in os.walk(startpath):
        dirs.sort()  # sorting in place makes os.walk descend in sorted order
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath strips a trailing separator so the top directory still shows its name
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        for f in sorted(filenames):
            print(f"{indent}    {f}")

# List the contents of the dataset directory
list_files(dataset_dir)

In [None]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random

# Define paths
# Directory the Rain100H ZIP was extracted into in an earlier cell; the dataset
# class below lists it for rain-*.png / norain-*.png files.
dataset_dir = "/content/Restormer_Project/Rain100H"

# Define dataset class
class Rain100HDataset(Dataset):
    """Paired (rainy, clean) Rain100H images read from one flat directory.

    ``rain-X.png`` inputs are paired by name with ``norain-X.png`` targets.
    Pairing by name, rather than by position in two independently sorted
    lists, means a missing or extra file raises an error. Positional pairing
    would instead silently shift every later pair.

    Args:
        root_dir: directory containing the PNG files.
        file_indices: positions into the sorted list of rain files to use, so
            train and test splits can share one directory.
        transform: optional callable applied to each RGB uint8 image.
        augment: if True, apply the same random horizontal/vertical flips to
            both images of a pair.

    Raises:
        AssertionError: if the number of rain and norain files differs.
        ValueError: if some rain file has no ``norain-`` counterpart.
    """

    def __init__(self, root_dir, file_indices, transform=None, augment=False):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        all_files = os.listdir(root_dir)
        all_rain_files = sorted(f for f in all_files if f.startswith('rain-') and f.endswith('.png'))
        norain_set = {f for f in all_files if f.startswith('norain-') and f.endswith('.png')}
        assert len(all_rain_files) == len(norain_set), "Mismatched number of rainy and clean images"
        # 'rain-X.png' -> 'norain-X.png'; prefixing keeps the sorted order intact.
        all_norain_files = ['no' + f for f in all_rain_files]
        unpaired = [f for f in all_norain_files if f not in norain_set]
        if unpaired:
            raise ValueError(f"Missing clean images for rainy inputs: {unpaired[:5]}")
        # Use specified indices
        self.rain_files = [all_rain_files[i] for i in file_indices]
        self.norain_files = [all_norain_files[i] for i in file_indices]

    def __len__(self):
        return len(self.rain_files)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB arrays, optionally flip, then apply ``transform``."""
        rain_path = os.path.join(self.root_dir, self.rain_files[idx])
        norain_path = os.path.join(self.root_dir, self.norain_files[idx])
        rain_img = cv2.imread(rain_path, cv2.IMREAD_COLOR)
        norain_img = cv2.imread(norain_path, cv2.IMREAD_COLOR)

        if rain_img is None or norain_img is None:
            raise ValueError(f"Failed to load images: {rain_path} or {norain_path}")

        # OpenCV decodes to BGR; convert so tensors/plots are RGB.
        rain_img = cv2.cvtColor(rain_img, cv2.COLOR_BGR2RGB)
        norain_img = cv2.cvtColor(norain_img, cv2.COLOR_BGR2RGB)

        # Apply augmentations (random flips for training); both images get the same flip.
        if self.augment:
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 1)  # Horizontal flip
                norain_img = cv2.flip(norain_img, 1)
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 0)  # Vertical flip
                norain_img = cv2.flip(norain_img, 0)

        if self.transform:
            rain_img = self.transform(rain_img)
            norain_img = self.transform(norain_img)

        return rain_img, norain_img

# ToTensor turns an HxWxC uint8 array into a CxHxW float tensor in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Reproducible 80/20 train/test split over the 100 Rain100H pairs.
num_images = 100
indices = list(range(num_images))
random.seed(42)
random.shuffle(indices)
train_indices, test_indices = indices[:80], indices[80:]

# Only the training split is augmented.
train_dataset = Rain100HDataset(dataset_dir, train_indices, transform=transform, augment=True)
test_dataset = Rain100HDataset(dataset_dir, test_indices, transform=transform, augment=False)
for _split_name, _split_dataset in (("training", train_dataset), ("test", test_dataset)):
    print(f"Found {len(_split_dataset)} {_split_name} pairs")

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Preview the first training pair side by side.
if len(train_dataset) == 0:
    print("No images loaded. Please check the dataset structure.")
else:
    rain_img, norain_img = train_dataset[0]
    plt.figure(figsize=(10, 5))
    for _pos, (_title, _tensor) in enumerate((("Rainy Image", rain_img), ("Clean Image", norain_img)), start=1):
        plt.subplot(1, 2, _pos)
        plt.title(_title)
        plt.imshow(_tensor.permute(1, 2, 0))
        plt.axis('off')
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Multi-Dconv Head Transposed Attention (MDTA)
class MDTA(nn.Module):
    """Multi-Dconv Head Transposed Attention (Restormer, Zamir et al., CVPR 2022).

    Attention is taken across channels rather than pixels. For each head, a
    (C/heads x C/heads) map is built from L2-normalised queries and keys, so
    the cost grows linearly with H*W.

    Args:
        dim: input/output channels; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, dim, num_heads):
        super(MDTA, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head scale for the cosine-similarity logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=False)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape (b, c, h, w) to a tensor of the same shape."""
        b, c, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        # each of q, k, v: (b, num_heads, c // num_heads, h * w)
        q, k, v = (t.reshape(b, self.num_heads, -1, h * w) for t in qkv.chunk(3, dim=1))

        # Normalising over the spatial axis bounds each logit to [-1, 1] * temperature.
        # Unnormalised dot products over h*w positions grow with image size,
        # saturate the softmax and can overflow under fp16 autocast.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # (b, heads, c', hw) @ (b, heads, hw, c') -> (b, heads, c', c'); rows are queries.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)  # normalise over keys

        out = attn @ v  # (b, heads, c', hw)
        # Head and per-head channel axes are adjacent, so a plain reshape restores
        # (b, c, h, w). Transposing first would interleave channels with pixels.
        out = out.reshape(b, c, h, w)
        return self.project_out(out)

# Gated-Dconv Feed-Forward Network (GDFN)
class GDFN(nn.Module):
    """Gated-Dconv Feed-Forward Network (Restormer).

    A 1x1 conv expands ``dim`` to two branches of ``int(dim * ffn_expansion_factor)``
    channels each, and a depthwise 3x3 conv mixes them locally. GELU of the first
    branch gates the second, and a 1x1 conv projects back to ``dim``.
    """

    def __init__(self, dim, ffn_expansion_factor=2.66):
        super(GDFN, self).__init__()
        hidden = int(dim * ffn_expansion_factor)
        both_branches = hidden * 2
        # Creation order is kept (in -> dwconv -> out) so seeded initialisation is unchanged.
        self.project_in = nn.Conv2d(dim, both_branches, kernel_size=1, bias=False)
        self.dwconv = nn.Conv2d(both_branches, both_branches, kernel_size=3, stride=1, padding=1,
                                groups=both_branches, bias=False)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape (b, dim, h, w) to a tensor of the same shape."""
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)

# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm Restormer block: ``x + MDTA(LN(x))`` followed by ``x + GDFN(LN(x))``.

    ``nn.LayerNorm`` normalises the last axis, so feature maps are moved to
    channels-last for the norm and back to channels-first for the conv sub-layers.
    """

    def __init__(self, dim, num_heads):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MDTA(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = GDFN(dim)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` over the channel axis of a (b, c, h, w) tensor."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Apply the attention and feed-forward residual sub-layers to ``x``."""
        _batch, _channels, _height, _width = x.shape  # unpacking also rejects non-4D input
        attended = x + self.attn(self._channel_norm(self.norm1, x))
        return attended + self.ffn(self._channel_norm(self.norm2, attended))

# Restormer Model
class Restormer(nn.Module):
    """Restormer encoder-decoder for image restoration (Zamir et al., CVPR 2022).

    The network has four levels of TransformerBlocks. Strided convs downsample,
    transposed convs upsample, and each decoder level concatenates the matching
    encoder output before a 1x1 fusion conv. Refinement blocks follow, then a
    3x3 output conv whose result is added to the input image (residual learning).

    Args:
        inp_channels, out_channels: image channels in/out; the residual add
            requires them to be equal.
        dim: channel width at level 1 (levels 2-4 use 2x, 4x, 8x).
        num_blocks: TransformerBlocks per encoder level; decoder levels 3-1 reuse
            entries 2-0.
        num_refinement_blocks: TransformerBlocks after decoder level 1.
        heads: attention heads per level.
        verbose: if True, print intermediate feature shapes on every forward
            pass. Off by default so training logs are not flooded.
    """

    # Three stride-2 downsamplings: H and W must be multiples of 2**3 internally.
    _SIZE_MULTIPLE = 8

    def __init__(self, inp_channels=3, out_channels=3, dim=48, num_blocks=(2, 3, 3, 4),
                 num_refinement_blocks=2, heads=(1, 2, 4, 8), verbose=False):
        super(Restormer, self).__init__()
        self.verbose = verbose
        # Layer creation order is unchanged so seeded initialisation and state_dict keys match.
        self.patch_embed = nn.Conv2d(inp_channels, dim, kernel_size=3, stride=1, padding=1, bias=False)

        # Encoder: each level doubles channels and halves resolution.
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0]) for _ in range(num_blocks[0])])
        self.down1_2 = nn.Conv2d(dim, dim*2, kernel_size=2, stride=2)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim*2, num_heads=heads[1]) for _ in range(num_blocks[1])])
        self.down2_3 = nn.Conv2d(dim*2, dim*4, kernel_size=2, stride=2)
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=dim*4, num_heads=heads[2]) for _ in range(num_blocks[2])])
        self.down3_4 = nn.Conv2d(dim*4, dim*8, kernel_size=2, stride=2)
        self.encoder_level4 = nn.Sequential(*[TransformerBlock(dim=dim*8, num_heads=heads[3]) for _ in range(num_blocks[3])])

        # Decoder: upsample, concatenate the skip connection, fuse channels with a 1x1 conv.
        self.up4_3 = nn.ConvTranspose2d(dim*8, dim*4, kernel_size=2, stride=2)
        self.reduce_chan_level3 = nn.Conv2d(dim*8, dim*4, kernel_size=1)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=dim*4, num_heads=heads[2]) for _ in range(num_blocks[2])])

        self.up3_2 = nn.ConvTranspose2d(dim*4, dim*2, kernel_size=2, stride=2)
        self.reduce_chan_level2 = nn.Conv2d(dim*4, dim*2, kernel_size=1)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim*2, num_heads=heads[1]) for _ in range(num_blocks[1])])

        self.up2_1 = nn.ConvTranspose2d(dim*2, dim, kernel_size=2, stride=2)
        self.reduce_chan_level1 = nn.Conv2d(dim*2, dim, kernel_size=1)
        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0]) for _ in range(num_blocks[0])])

        # Refinement
        self.refinement = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0]) for _ in range(num_refinement_blocks)])
        self.output = nn.Conv2d(dim, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

    def _log(self, label, tensor):
        """Print a feature-map shape when ``verbose`` debugging is enabled."""
        if self.verbose:
            print(f"{label}: {tensor.shape}")

    def forward(self, inp_img):
        """Restore ``inp_img`` of shape (b, c, h, w); any h and w are accepted.

        If h or w is not a multiple of 8, the input is padded on the bottom/right
        before the network and the prediction is cropped back. The output
        therefore always has ``inp_img``'s shape.
        """
        h, w = inp_img.shape[-2:]
        pad_h = -h % self._SIZE_MULTIPLE
        pad_w = -w % self._SIZE_MULTIPLE
        x = inp_img
        if pad_h or pad_w:
            # 'reflect' requires the pad to be smaller than the dimension; fall back for tiny inputs.
            mode = "reflect" if pad_h < h and pad_w < w else "replicate"
            x = F.pad(inp_img, (0, pad_w, 0, pad_h), mode=mode)

        # Encoder
        inp_enc_level1 = self.patch_embed(x)
        self._log("Encoder Level 1", inp_enc_level1)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        self._log("Encoder Level 2", inp_enc_level2)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        self._log("Encoder Level 3", inp_enc_level3)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        self._log("Encoder Level 4", inp_enc_level4)
        out_enc_level4 = self.encoder_level4(inp_enc_level4)

        # Decoder
        inp_dec_level3 = torch.cat([self.up4_3(out_enc_level4), out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        self._log("Decoder Level 3", inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = torch.cat([self.up3_2(out_dec_level3), out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        self._log("Decoder Level 2", inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = torch.cat([self.up2_1(out_dec_level2), out_enc_level1], 1)
        inp_dec_level1 = self.reduce_chan_level1(inp_dec_level1)
        self._log("Decoder Level 1", inp_dec_level1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        # Refinement and output
        out_dec_level1 = self.refinement(out_dec_level1)
        self._log("Refinement", out_dec_level1)
        residual = self.output(out_dec_level1)

        # Residual connection; drop any padding before adding back to the input.
        return inp_img + residual[..., :h, :w]

# Test the model: build it and run one random input through it to check shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Restormer().to(device)
print("Model initialized successfully")

# Dummy input (batch_size=1, channels=3, height=128, width=128)
dummy_input = torch.randn(1, 3, 128, 128).to(device)
# A shape check needs no gradients; no_grad avoids keeping every activation alive.
with torch.no_grad():
    output = model(dummy_input)
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == dummy_input.shape, "Restormer must preserve the input shape"
print(f"Model moved to {device}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import os
import random
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

# Rain100HDataset, redefined for this section with a random `patch_size` crop used by progressive training
class Rain100HDataset(Dataset):
    """Paired (rainy, clean) Rain100H images with random patch cropping.

    ``rain-X.png`` inputs are paired by name with ``norain-X.png`` targets in
    ``root_dir``. ``file_indices`` selects positions in the sorted list of rain
    files, so train and test splits can share one directory.

    Args:
        root_dir: directory containing the PNG pairs.
        file_indices: positions into the sorted rain-file list to use.
        transform: optional callable applied to each RGB uint8 image.
        augment: if True, apply identical random horizontal/vertical flips to both images.
        patch_size: side of the random crop taken from each pair. Each axis is
            cropped independently; an axis shorter than this keeps its full length.

    Raises:
        AssertionError: if the number of rain and norain files differs.
        ValueError: if some rain file has no ``norain-`` counterpart.
    """

    def __init__(self, root_dir, file_indices, transform=None, augment=False, patch_size=128):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.patch_size = patch_size
        all_files = os.listdir(root_dir)
        all_rain_files = sorted(f for f in all_files if f.startswith('rain-') and f.endswith('.png'))
        norain_set = {f for f in all_files if f.startswith('norain-') and f.endswith('.png')}
        assert len(all_rain_files) == len(norain_set), "Mismatched number of rainy and clean images"
        # Pair by name ('rain-X.png' -> 'norain-X.png') rather than by sorted position.
        all_norain_files = ['no' + f for f in all_rain_files]
        unpaired = [f for f in all_norain_files if f not in norain_set]
        if unpaired:
            raise ValueError(f"Missing clean images for rainy inputs: {unpaired[:5]}")
        self.rain_files = [all_rain_files[i] for i in file_indices]
        self.norain_files = [all_norain_files[i] for i in file_indices]

    def __len__(self):
        return len(self.rain_files)

    @staticmethod
    def _paired_random_crop(rain_img, norain_img, patch_size):
        """Cut the same random window (at most ``patch_size`` per axis) from both images.

        Every axis at least ``patch_size`` long is cropped to exactly
        ``patch_size``, including the ``==`` case. Batches therefore have
        uniform size whenever the images are large enough.
        """
        h, w = rain_img.shape[:2]
        crop_h, crop_w = min(patch_size, h), min(patch_size, w)
        top = random.randint(0, h - crop_h)
        left = random.randint(0, w - crop_w)
        return (rain_img[top:top + crop_h, left:left + crop_w],
                norain_img[top:top + crop_h, left:left + crop_w])

    def __getitem__(self, idx):
        """Load pair ``idx``, crop it, optionally flip, then apply ``transform``."""
        rain_path = os.path.join(self.root_dir, self.rain_files[idx])
        norain_path = os.path.join(self.root_dir, self.norain_files[idx])
        rain_img = cv2.imread(rain_path, cv2.IMREAD_COLOR)
        norain_img = cv2.imread(norain_path, cv2.IMREAD_COLOR)

        if rain_img is None or norain_img is None:
            raise ValueError(f"Failed to load images: {rain_path} or {norain_path}")

        # OpenCV decodes to BGR; convert to RGB.
        rain_img = cv2.cvtColor(rain_img, cv2.COLOR_BGR2RGB)
        norain_img = cv2.cvtColor(norain_img, cv2.COLOR_BGR2RGB)

        rain_img, norain_img = self._paired_random_crop(rain_img, norain_img, self.patch_size)

        # Apply augmentations (same flip to both images)
        if self.augment:
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 1)
                norain_img = cv2.flip(norain_img, 1)
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 0)
                norain_img = cv2.flip(norain_img, 0)

        if self.transform:
            rain_img = self.transform(rain_img)
            norain_img = self.transform(norain_img)

        return rain_img, norain_img

# Assume Restormer model from Section 3 is defined
# [Insert the Restormer, TransformerBlock, MDTA, GDFN classes from the previous section here]
# For brevity, I'll assume they're already in the notebook. Copy them if needed.

# Training function
def train_model(model, train_loader, criterion, optimizer, device, num_epochs, scaler):
    """Train ``model`` with mixed precision, printing the mean loss of each epoch.

    Args:
        model: network mapping rainy batches to derained batches.
        train_loader: iterable of (rainy, clean) tensor batches.
        criterion: ``loss = criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string to train on.
        num_epochs: passes over ``train_loader``.
        scaler: GradScaler for loss scaling. One created with ``enabled=False``
            makes every scaler call a pass-through (full precision).

    An empty loader reports a loss of ``nan`` instead of dividing by zero.
    """
    device = torch.device(device)
    # Autocast only where fp16 autocast applies; on CPU it is disabled explicitly.
    use_amp = device.type == "cuda"
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for rain_imgs, norain_imgs in train_loader:
            rain_imgs, norain_imgs = rain_imgs.to(device), norain_imgs.to(device)

            optimizer.zero_grad()
            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(rain_imgs)
                loss = criterion(outputs, norain_imgs)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            num_batches += 1

        avg_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_dir = "/content/Restormer_Project/Rain100H"
num_images = 100
indices = list(range(num_images))
random.seed(42)
random.shuffle(indices)
train_indices = indices[:80]
test_indices = indices[80:]

# Progressive learning: train with increasing patch sizes
patch_sizes = [128, 256]
num_epochs_per_stage = 5

transform = transforms.Compose([transforms.ToTensor()])
# Build a fresh model here instead of reusing whichever `model` an earlier cell
# left in the kernel. Re-running this cell then starts training from a new
# initialisation rather than silently continuing a previous run.
# (Uses the Restormer class defined in an earlier cell.)
model = Restormer().to(device)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
# Loss scaling is only meaningful for fp16 on CUDA. Disabled, the scaler is a pass-through.
scaler = GradScaler(enabled=(device.type == "cuda"))

for patch_size in patch_sizes:
    print(f"Starting training with patch size {patch_size}x{patch_size}")
    train_dataset = Rain100HDataset(dataset_dir, train_indices, transform=transform, augment=True, patch_size=patch_size)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    train_model(model, train_loader, criterion, optimizer, device, num_epochs_per_stage, scaler)

# Persist the trained weights.
model_path = "/content/Restormer_Project/restormer_rain100h.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

# Derain the first held-out pair (a random 256x256 crop) and show it next to input and target.
model.eval()
test_dataset = Rain100HDataset(dataset_dir, test_indices, transform=transform, augment=False, patch_size=256)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
rain_img, norain_img = next(iter(test_loader))
rain_img = rain_img.to(device)
norain_img = norain_img.to(device)

with torch.no_grad():
    output_img = model(rain_img)

_panels = (
    ("Rainy Image", rain_img.cpu().squeeze().permute(1, 2, 0)),
    ("Ground Truth", norain_img.cpu().squeeze().permute(1, 2, 0)),
    ("Derained Output", output_img.cpu().squeeze().permute(1, 2, 0).clamp(0, 1)),
)
plt.figure(figsize=(15, 5))
for _pos, (_title, _image) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _pos)
    plt.title(_title)
    plt.imshow(_image)
    plt.axis('off')
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import os
import random
import matplotlib.pyplot as plt
import numpy as np

# Rain100HDataset (updated with validation)
class Rain100HDataset(Dataset):
    """Paired (rainy, clean) Rain100H images with random patch cropping.

    ``rain-X.png`` inputs are paired by name with ``norain-X.png`` targets in
    ``root_dir``. ``file_indices`` selects positions in the sorted list of rain
    files, so train and test splits can share one directory.

    Args:
        root_dir: directory containing the PNG pairs.
        file_indices: positions into the sorted rain-file list to use.
        transform: optional callable applied to each RGB uint8 image.
        augment: if True, apply identical random horizontal/vertical flips to both images.
        patch_size: side of the random crop taken from each pair. Each axis is
            cropped independently; an axis shorter than this keeps its full length.

    Raises:
        AssertionError: if the number of rain and norain files differs.
        ValueError: if some rain file has no ``norain-`` counterpart.
    """

    def __init__(self, root_dir, file_indices, transform=None, augment=False, patch_size=128):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.patch_size = patch_size
        all_files = os.listdir(root_dir)
        all_rain_files = sorted(f for f in all_files if f.startswith('rain-') and f.endswith('.png'))
        norain_set = {f for f in all_files if f.startswith('norain-') and f.endswith('.png')}
        assert len(all_rain_files) == len(norain_set), "Mismatched number of rainy and clean images"
        # Pair by name ('rain-X.png' -> 'norain-X.png') rather than by sorted position.
        all_norain_files = ['no' + f for f in all_rain_files]
        unpaired = [f for f in all_norain_files if f not in norain_set]
        if unpaired:
            raise ValueError(f"Missing clean images for rainy inputs: {unpaired[:5]}")
        self.rain_files = [all_rain_files[i] for i in file_indices]
        self.norain_files = [all_norain_files[i] for i in file_indices]

    def __len__(self):
        return len(self.rain_files)

    @staticmethod
    def _paired_random_crop(rain_img, norain_img, patch_size):
        """Cut the same random window (at most ``patch_size`` per axis) from both images.

        Every axis at least ``patch_size`` long is cropped to exactly
        ``patch_size``, including the ``==`` case. Batches therefore have
        uniform size whenever the images are large enough.
        """
        h, w = rain_img.shape[:2]
        crop_h, crop_w = min(patch_size, h), min(patch_size, w)
        top = random.randint(0, h - crop_h)
        left = random.randint(0, w - crop_w)
        return (rain_img[top:top + crop_h, left:left + crop_w],
                norain_img[top:top + crop_h, left:left + crop_w])

    def __getitem__(self, idx):
        """Load pair ``idx``, crop it, optionally flip, then apply ``transform``."""
        rain_path = os.path.join(self.root_dir, self.rain_files[idx])
        norain_path = os.path.join(self.root_dir, self.norain_files[idx])
        rain_img = cv2.imread(rain_path, cv2.IMREAD_COLOR)
        norain_img = cv2.imread(norain_path, cv2.IMREAD_COLOR)

        if rain_img is None or norain_img is None:
            raise ValueError(f"Failed to load images: {rain_path} or {norain_path}")

        # OpenCV decodes to BGR; convert to RGB.
        rain_img = cv2.cvtColor(rain_img, cv2.COLOR_BGR2RGB)
        norain_img = cv2.cvtColor(norain_img, cv2.COLOR_BGR2RGB)

        rain_img, norain_img = self._paired_random_crop(rain_img, norain_img, self.patch_size)

        # Apply augmentations (same flip to both images)
        if self.augment:
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 1)
                norain_img = cv2.flip(norain_img, 1)
            if random.random() > 0.5:
                rain_img = cv2.flip(rain_img, 0)
                norain_img = cv2.flip(norain_img, 0)

        if self.transform:
            rain_img = self.transform(rain_img)
            norain_img = self.transform(norain_img)

        return rain_img, norain_img

# Validate dataset
def validate_dataset(root_dir, indices):
    """Return the subset of ``indices`` whose rain/clean pair exists and decodes.

    Indices are interpreted as Rain100HDataset interprets them: positions in the
    sorted list of ``rain-*.png`` files, with the clean partner named
    ``norain-<same suffix>``. Hard-coding ``rain-{idx+1:03d}.png`` only agrees
    with that when files are numbered 001..N with no gaps.

    A pair is reported and dropped when:
    - the index is out of range,
    - the partner file is missing,
    - either image fails to decode, or
    - the two images differ in size (paired cropping would misalign them).
    """
    print("Validating dataset...")
    rain_files = sorted(f for f in os.listdir(root_dir) if f.startswith('rain-') and f.endswith('.png'))
    valid_indices = []
    for idx in indices:
        if not 0 <= idx < len(rain_files):
            print(f"Index {idx} out of range ({len(rain_files)} rainy images found)")
            continue
        rain_path = os.path.join(root_dir, rain_files[idx])
        norain_path = os.path.join(root_dir, 'no' + rain_files[idx])
        if not os.path.isfile(norain_path):
            print(f"Invalid image pair: {rain_path} or {norain_path}")
            continue
        rain_img = cv2.imread(rain_path, cv2.IMREAD_COLOR)
        norain_img = cv2.imread(norain_path, cv2.IMREAD_COLOR)
        if rain_img is None or norain_img is None:
            print(f"Invalid image pair: {rain_path} or {norain_path}")
            continue
        # IMREAD_COLOR yields 8-bit images, which cannot hold NaN/Inf. The useful
        # content check is therefore that both images have the same size.
        if rain_img.shape != norain_img.shape:
            print(f"Size mismatch: {rain_path} {rain_img.shape} vs {norain_path} {norain_img.shape}")
            continue
        valid_indices.append(idx)
    return valid_indices

# Training function
def _show_epoch_sample(model, train_loader, device, epoch):
    """Plot rainy / ground truth / current prediction for the first image of a fresh batch."""
    batch = next(iter(train_loader), None)
    if batch is None:  # empty loader: nothing to show
        return
    model.eval()
    with torch.no_grad():
        sample_rain, sample_norain = batch
        sample_rain, sample_norain = sample_rain[:1].to(device), sample_norain[:1].to(device)
        sample_output = model(sample_rain)

        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title("Rainy Image")
        plt.imshow(sample_rain.cpu().squeeze().permute(1, 2, 0))
        plt.axis('off')
        plt.subplot(1, 3, 2)
        plt.title("Ground Truth")
        plt.imshow(sample_norain.cpu().squeeze().permute(1, 2, 0))
        plt.axis('off')
        plt.subplot(1, 3, 3)
        plt.title(f"Derained (Epoch {epoch+1})")
        plt.imshow(sample_output.cpu().squeeze().permute(1, 2, 0).clamp(0, 1))
        plt.axis('off')
        plt.show()
    model.train()


def train_model(model, train_loader, criterion, optimizer, device, num_epochs, show_samples=True):
    """Train ``model`` in full precision with gradient clipping.

    Batches whose loss is NaN/Inf are reported and skipped. The printed epoch
    loss averages only the batches that were actually optimised, or is ``nan``
    if none were.

    Args:
        model: network mapping rainy batches to derained batches.
        train_loader: iterable of (rainy, clean) tensor batches.
        criterion: ``loss = criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device to train on.
        num_epochs: passes over ``train_loader``.
        show_samples: if True (default), plot one prediction after every epoch.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_valid = 0
        for i, (rain_imgs, norain_imgs) in enumerate(train_loader):
            rain_imgs, norain_imgs = rain_imgs.to(device), norain_imgs.to(device)

            optimizer.zero_grad()
            outputs = model(rain_imgs)
            loss = criterion(outputs, norain_imgs)

            if not torch.isfinite(loss):
                print(f"NaN/Inf loss detected at batch {i+1}, epoch {epoch+1}")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item()
            num_valid += 1

        # Dividing by all batches would bias the mean toward 0 when batches were skipped.
        avg_loss = running_loss / num_valid if num_valid else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        if show_samples:
            _show_epoch_sample(model, train_loader, device, epoch)

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_dir = "/content/Restormer_Project/Rain100H"
output_dir = os.path.join(dataset_dir, "outputs")
os.makedirs(output_dir, exist_ok=True)

# Validate dataset
num_images = 100
indices = list(range(num_images))
random.seed(42)
random.shuffle(indices)
train_indices = indices[:80]
test_indices = indices[80:]
train_indices = validate_dataset(dataset_dir, train_indices)
test_indices = validate_dataset(dataset_dir, test_indices)
print(f"Found {len(train_indices)} valid training pairs, {len(test_indices)} valid test pairs")

# Progressive learning
patch_sizes = [128, 192]
num_epochs_per_stage = 5

transform = transforms.Compose([transforms.ToTensor()])
model = Restormer().to(device)
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

for patch_size in patch_sizes:
    print(f"Starting training with patch size {patch_size}x{patch_size}")
    train_dataset = Rain100HDataset(dataset_dir, train_indices, transform=transform, augment=True, patch_size=patch_size)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    train_model(model, train_loader, criterion, optimizer, device, num_epochs_per_stage)

    # Save model after each stage
    model_path = f"/content/Restormer_Project/restormer_rain100h_patch{patch_size}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

# Visualize final test output
model.eval()
test_dataset = Rain100HDataset(dataset_dir, test_indices, transform=transform, augment=False, patch_size=256)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
rain_img, norain_img = next(iter(test_loader))
rain_img, norain_img = rain_img.to(device), norain_img.to(device)

with torch.no_grad():
    output_img = model(rain_img)
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.title("Rainy Image")
    plt.imshow(rain_img.cpu().squeeze().permute(1, 2, 0))
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.title("Ground Truth")
    plt.imshow(norain_img.cpu().squeeze().permute(1, 2, 0))
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.title("Final Derained Output")
    plt.imshow(output_img.cpu().squeeze().permute(1, 2, 0).clamp(0, 1))
    plt.axis('off')
    plt.show()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import cv2
import os
import random
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Rain100HDataset (from Section 2, updated for evaluation)
class Rain100HDataset(Dataset):
    """Evaluation variant of the Rain100H pair dataset (deterministic center crop).

    ``rain-X.png`` inputs are paired by name with ``norain-X.png`` targets in
    ``root_dir``. Each item is ``(rainy, clean, rain_filename)`` so results can
    be attributed to images.

    Args:
        root_dir: directory containing the PNG pairs.
        file_indices: positions into the sorted rain-file list to use.
        transform: optional callable applied to each RGB uint8 image.
        patch_size: if truthy, center-crop each axis to at most this length;
            an axis shorter than this keeps its full length.

    Raises:
        AssertionError: if the number of rain and norain files differs.
        ValueError: if some rain file has no ``norain-`` counterpart.
    """

    def __init__(self, root_dir, file_indices, transform=None, patch_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.patch_size = patch_size
        all_files = os.listdir(root_dir)
        all_rain_files = sorted(f for f in all_files if f.startswith('rain-') and f.endswith('.png'))
        norain_set = {f for f in all_files if f.startswith('norain-') and f.endswith('.png')}
        assert len(all_rain_files) == len(norain_set), "Mismatched number of rainy and clean images"
        # Pair by name ('rain-X.png' -> 'norain-X.png') rather than by sorted position.
        all_norain_files = ['no' + f for f in all_rain_files]
        unpaired = [f for f in all_norain_files if f not in norain_set]
        if unpaired:
            raise ValueError(f"Missing clean images for rainy inputs: {unpaired[:5]}")
        self.rain_files = [all_rain_files[i] for i in file_indices]
        self.norain_files = [all_norain_files[i] for i in file_indices]

    def __len__(self):
        return len(self.rain_files)

    @staticmethod
    def _paired_center_crop(rain_img, norain_img, patch_size):
        """Cut the same centered window (at most ``patch_size`` per axis) from both images.

        Each axis is handled independently, so an axis exactly ``patch_size``
        long is kept while the other axis is still cropped.
        """
        h, w = rain_img.shape[:2]
        crop_h, crop_w = min(patch_size, h), min(patch_size, w)
        top = (h - crop_h) // 2
        left = (w - crop_w) // 2
        return (rain_img[top:top + crop_h, left:left + crop_w],
                norain_img[top:top + crop_h, left:left + crop_w])

    def __getitem__(self, idx):
        """Load pair ``idx``, optionally center-crop, transform, and attach its filename."""
        rain_path = os.path.join(self.root_dir, self.rain_files[idx])
        norain_path = os.path.join(self.root_dir, self.norain_files[idx])
        rain_img = cv2.imread(rain_path, cv2.IMREAD_COLOR)
        norain_img = cv2.imread(norain_path, cv2.IMREAD_COLOR)

        if rain_img is None or norain_img is None:
            raise ValueError(f"Failed to load images: {rain_path} or {norain_path}")

        # OpenCV decodes to BGR; convert to RGB.
        rain_img = cv2.cvtColor(rain_img, cv2.COLOR_BGR2RGB)
        norain_img = cv2.cvtColor(norain_img, cv2.COLOR_BGR2RGB)

        # Crop to patch_size if specified (deterministic for evaluation)
        if self.patch_size:
            rain_img, norain_img = self._paired_center_crop(rain_img, norain_img, self.patch_size)

        if self.transform:
            rain_img = self.transform(rain_img)
            norain_img = self.transform(norain_img)

        return rain_img, norain_img, self.rain_files[idx]

# Assume Restormer, MDTA, GDFN, TransformerBlock classes are defined (from Section 3)
# [Insert Restormer, MDTA, GDFN, TransformerBlock classes here if not already in notebook]

# Evaluation function
def evaluate_model(model, test_loader, device, output_dir="/content/Restormer_Project/outputs/derained"):
    """Score ``model`` with PSNR/SSIM on ``test_loader`` and save each derained image.

    Args:
        model: network mapping rainy images to derained images.
        test_loader: yields ``(rainy, clean, filenames)`` batches, as produced by
            the evaluation Rain100HDataset. Any batch size is supported.
        device: device to run inference on.
        output_dir: folder that receives ``derained_<filename>`` images; created if needed.

    Returns:
        ``(results, avg_psnr, avg_ssim)``. ``results`` is a list of dicts with
        keys "Image", "PSNR" and "SSIM". Both averages are ``nan`` when the
        loader is empty.
    """
    model.eval()
    results = []
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for rain_img, norain_img, filenames in test_loader:
            output_img = model(rain_img.to(device)).clamp(0, 1).cpu()
            # Iterate over the batch so batch_size > 1 works (squeeze() only fits batch 1).
            for output_t, norain_t, filename in zip(output_img, norain_img, filenames):
                output_np = output_t.permute(1, 2, 0).numpy()
                norain_np = norain_t.permute(1, 2, 0).numpy()

                psnr_score = psnr(norain_np, output_np, data_range=1.0)
                # win_size=3 is kept from the original code; skimage's default window is 7.
                ssim_score = ssim(norain_np, output_np, channel_axis=2, data_range=1.0, win_size=3)
                results.append({"Image": filename, "PSNR": psnr_score, "SSIM": ssim_score})

                plt.imsave(os.path.join(output_dir, f"derained_{filename}"), output_np)

    if not results:
        return results, float("nan"), float("nan")
    avg_psnr = np.mean([r["PSNR"] for r in results])
    avg_ssim = np.mean([r["SSIM"] for r in results])
    return results, avg_psnr, avg_ssim

# Visualize sample results
def visualize_results(model, test_loader, device, num_samples=5):
    """Plot rainy / ground-truth / derained rows for the first ``num_samples`` batches.

    One row per batch, showing its first image. Batches are drawn lazily with
    ``itertools.islice``, so only ``num_samples`` batches are loaded. (Calling
    ``list(iter(loader))`` first would load the whole test set.) The grid is
    sized to the number of batches actually available.
    """
    import itertools

    model.eval()
    samples = list(itertools.islice(test_loader, num_samples))
    if not samples:
        print("No samples to visualize.")
        return

    num_rows = len(samples)
    plt.figure(figsize=(15, 5 * num_rows))
    with torch.no_grad():
        for row, (rain_img, norain_img, filenames) in enumerate(samples):
            output_img = model(rain_img.to(device))
            panels = (
                (f"Rainy: {filenames[0]}", rain_img[0]),
                ("Ground Truth", norain_img[0]),
                ("Derained", output_img[0].clamp(0, 1)),
            )
            for col, (title, image) in enumerate(panels, start=1):
                plt.subplot(num_rows, 3, row * 3 + col)
                plt.title(title)
                plt.imshow(image.cpu().permute(1, 2, 0))
                plt.axis('off')

    plt.tight_layout()
    plt.show()

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_dir = "/content/Restormer_Project/Rain100H"
model_path = "/content/Restormer_Project/restormer_rain100h_patch192.pth"

# Load test dataset
num_images = 100
indices = list(range(num_images))
random.seed(42)
random.shuffle(indices)
test_indices = indices[80:]
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = Rain100HDataset(dataset_dir, test_indices, transform=transform, patch_size=256)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Load model
model = Restormer().to(device)
print(f"Loading model from {model_path}")
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
else:
    raise FileNotFoundError(f"Model file not found: {model_path}")

# Evaluate
print(f"Evaluating on {len(test_dataset)} test pairs...")
results, avg_psnr, avg_ssim = evaluate_model(model, test_loader, device)
print(f"Average PSNR: {avg_psnr:.2f} dB")
print(f"Average SSIM: {avg_ssim:.3f}")

# Save results to CSV
results_df = pd.DataFrame(results)
results_df.to_csv("/content/Restormer_Project/outputs/evaluation_results.csv", index=False)
print(f"Results saved to /content/Restormer_Project/outputs/evaluation_results.csv")

# Visualize 5 sample results
visualize_results(model, test_loader, device, num_samples=5)

In [None]:
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim  # Already imported
import pandas as pd # Import pandas

# New metric functions (add to notebook)
def calculate_mae(pred, gt):
    """Mean absolute error between two equally-shaped arrays.

    Both inputs are promoted to float64 before subtracting, so integer images
    (e.g. uint8 arrays) do not wrap around and produce bogus errors.

    Args:
        pred: predicted array.
        gt: ground-truth array with the same shape as ``pred``.

    Returns:
        The mean of ``|pred - gt|`` over all elements.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    return np.mean(np.abs(pred - gt))

def calculate_nmse(pred, gt):
    """Normalized mean squared error: ``mean((pred - gt)**2) / mean(gt**2)``.

    Inputs are promoted to float64 so integer images do not wrap around when
    subtracted or squared.

    Args:
        pred: predicted array.
        gt: ground-truth array with the same shape as ``pred``.

    Returns:
        The NMSE. If ``gt`` has zero energy (all zeros) the ratio is undefined:
        a perfect prediction returns 0.0 and any error returns ``inf``.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    mse = np.mean((pred - gt) ** 2)
    gt_energy = np.mean(gt ** 2)
    if gt_energy > 0:
        return mse / gt_energy
    # All-zero target: only a perfect reconstruction has a meaningful (zero) error.
    return 0.0 if mse == 0 else float('inf')

# Test-Time Augmentation (TTA) function for better results
def tta_predict(model, img):
    """Average predictions over the identity, horizontal-flip and vertical-flip views.

    Each flipped input goes through ``model`` and the prediction is flipped back
    before the three results are averaged (three forward passes). Dim 3 is flipped
    for the horizontal view and dim 2 for the vertical one, which presumes NCHW
    tensors — confirm for other layouts.
    """
    identity_out = model(img)

    width_axis = [3]
    hflip_out = torch.flip(model(torch.flip(img, width_axis)), width_axis)

    height_axis = [2]
    vflip_out = torch.flip(model(torch.flip(img, height_axis)), height_axis)

    return (identity_out + hflip_out + vflip_out) / 3

# Updated evaluate_model function
def evaluate_model(model, test_loader, device, use_tta=False):
    """Evaluate a deraining model on ``test_loader`` and save the derained images.

    Each batch is unpacked as ``(rain_img, norain_img, filenames)``. The images are
    treated as NCHW tensors and ``filenames`` as one name per batch item, which is what
    the default DataLoader collate produces for a string field. Confirm this against
    Rain100HDataset.

    For every image this computes PSNR, SSIM, MAE and NMSE of the (clamped) model output
    against the clean target. It also computes the same metrics for the unprocessed rainy
    input (the "baseline"). Each output is written to
    ``outputs/derained/derained_<filename>``, and the baseline averages are printed.

    Args:
        model: callable mapping an NCHW batch to an NCHW batch of the same shape.
        test_loader: iterable of ``(rain_img, norain_img, filenames)`` batches.
        device: torch device the images are moved to.
        use_tta: if True, average predictions over identity / h-flip / v-flip
            using ``tta_predict``.

    Returns:
        ``(results, avg_psnr, avg_ssim, avg_mae, avg_nmse)``. ``results`` is a list of
        per-image dicts with keys Image, PSNR, SSIM, MAE and NMSE. The averages are NaN
        when the loader yields no images.
    """
    model.eval()
    psnr_scores, ssim_scores, mae_scores, nmse_scores = [], [], [], []
    baseline_psnr, baseline_ssim, baseline_mae, baseline_nmse = [], [], [], []
    results = []

    output_dir = "/content/Restormer_Project/outputs/derained"
    os.makedirs(output_dir, exist_ok=True)

    def _to_hwc(chw):
        # CHW tensor -> HWC numpy array on the CPU (the layout skimage/matplotlib use).
        return chw.cpu().permute(1, 2, 0).numpy()

    def _mean(values):
        # np.mean([]) emits a RuntimeWarning; report an empty run as NaN explicitly.
        return np.mean(values) if values else float('nan')

    with torch.no_grad():
        for rain_img, norain_img, filenames in test_loader:
            rain_img, norain_img = rain_img.to(device), norain_img.to(device)

            # Generate derained batch (with optional TTA)
            if use_tta:
                output_img = tta_predict(model, rain_img)
            else:
                output_img = model(rain_img)

            # Walk the batch explicitly: squeeze() only worked for batch_size == 1 and
            # would also drop spatial dims of size 1.
            for b in range(rain_img.shape[0]):
                filename = filenames[b]
                output_np = _to_hwc(output_img[b].clamp(0, 1))
                norain_np = _to_hwc(norain_img[b])
                rain_np = _to_hwc(rain_img[b])

                # NOTE(review): win_size=3 is smaller than skimage's default window (7), so
                # these SSIM values are not directly comparable with published numbers.
                psnr_score = psnr(norain_np, output_np, data_range=1.0)
                ssim_score = ssim(norain_np, output_np, channel_axis=2, data_range=1.0, win_size=3)
                mae_score = calculate_mae(output_np, norain_np)
                nmse_score = calculate_nmse(output_np, norain_np)

                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                mae_scores.append(mae_score)
                nmse_scores.append(nmse_score)

                # Baseline metrics (rainy vs clean)
                baseline_psnr.append(psnr(norain_np, rain_np, data_range=1.0))
                baseline_ssim.append(ssim(norain_np, rain_np, channel_axis=2, data_range=1.0, win_size=3))
                baseline_mae.append(calculate_mae(rain_np, norain_np))
                baseline_nmse.append(calculate_nmse(rain_np, norain_np))

                results.append({
                    "Image": filename,
                    "PSNR": psnr_score,
                    "SSIM": ssim_score,
                    "MAE": mae_score,
                    "NMSE": nmse_score
                })

                # Save derained image
                output_path = os.path.join(output_dir, f"derained_{filename}")
                plt.imsave(output_path, output_np)

    avg_psnr = _mean(psnr_scores)
    avg_ssim = _mean(ssim_scores)
    avg_mae = _mean(mae_scores)
    avg_nmse = _mean(nmse_scores)

    print(f"Baseline PSNR (rainy vs clean): {_mean(baseline_psnr):.2f} dB")
    print(f"Baseline SSIM (rainy vs clean): {_mean(baseline_ssim):.4f}")
    print(f"Baseline MAE (rainy vs clean): {_mean(baseline_mae):.4f}")
    print(f"Baseline NMSE (rainy vs clean): {_mean(baseline_nmse):.4f}")

    return results, avg_psnr, avg_ssim, avg_mae, avg_nmse

# Usage in notebook (after loading model and test_loader): run the extended
# evaluation with test-time augmentation enabled and report the averages.
results, avg_psnr, avg_ssim, avg_mae, avg_nmse = evaluate_model(
    model, test_loader, device, use_tta=True
)
print(
    f"Average PSNR: {avg_psnr:.2f} dB",
    f"Average SSIM: {avg_ssim:.4f}",
    f"Average MAE: {avg_mae:.4f}",
    f"Average NMSE: {avg_nmse:.4f}",
    sep="\n",
)

# Persist the per-image metrics as a CSV table
results_df = pd.DataFrame(results)
results_df.to_csv(
    path_or_buf="/content/Restormer_Project/outputs/evaluation_results_extended.csv",
    index=False,
)

In [None]:
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torch
import os
import pandas as pd
import matplotlib.pyplot as plt

# Metric functions
def calculate_mae(pred, gt):
    """Mean absolute error between two equally-shaped arrays.

    Both inputs are promoted to float64 before subtracting, so integer images
    (e.g. uint8 arrays) do not wrap around and produce bogus errors.

    Args:
        pred: predicted array.
        gt: ground-truth array with the same shape as ``pred``.

    Returns:
        The mean of ``|pred - gt|`` over all elements.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    return np.mean(np.abs(pred - gt))

def calculate_nmse(pred, gt):
    """Normalized mean squared error: ``mean((pred - gt)**2) / mean(gt**2)``.

    Inputs are promoted to float64 so integer images do not wrap around when
    subtracted or squared.

    Args:
        pred: predicted array.
        gt: ground-truth array with the same shape as ``pred``.

    Returns:
        The NMSE. If ``gt`` has zero energy (all zeros) the ratio is undefined:
        a perfect prediction returns 0.0 and any error returns ``inf``.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    mse = np.mean((pred - gt) ** 2)
    gt_energy = np.mean(gt ** 2)
    if gt_energy > 0:
        return mse / gt_energy
    # All-zero target: only a perfect reconstruction has a meaningful (zero) error.
    return 0.0 if mse == 0 else float('inf')

# Test-Time Augmentation
def tta_predict(model, img):
    """Average predictions over the identity, horizontal-flip and vertical-flip views.

    Each flipped input goes through ``model`` and the prediction is flipped back
    before the three results are averaged (three forward passes). Dim 3 is flipped
    for the horizontal view and dim 2 for the vertical one, which presumes NCHW
    tensors — confirm for other layouts.
    """
    identity_out = model(img)

    width_axis = [3]
    hflip_out = torch.flip(model(torch.flip(img, width_axis)), width_axis)

    height_axis = [2]
    vflip_out = torch.flip(model(torch.flip(img, height_axis)), height_axis)

    return (identity_out + hflip_out + vflip_out) / 3

# Updated evaluate_model
def evaluate_model(model, test_loader, device, use_tta=False):
    """Evaluate a deraining model, save derained images, and print metric averages.

    Each batch is unpacked as ``(rain_img, norain_img, filenames)``. The images are
    treated as NCHW tensors and ``filenames`` as one name per batch item, which is what
    the default DataLoader collate produces for a string field. Confirm this against
    Rain100HDataset.

    Inputs and outputs are clamped to [0, 1] before scoring. For every image this
    computes PSNR, SSIM, MAE and NMSE of the model output against the clean target. It
    also computes the same metrics for the unprocessed rainy input (the "baseline").

    An image whose metric computation raises is reported and skipped; it is also not
    saved. Every scored output is written to ``outputs/derained/derained_<filename>``.
    Baseline and model averages are printed.

    Args:
        model: callable mapping an NCHW batch to an NCHW batch of the same shape.
        test_loader: iterable of ``(rain_img, norain_img, filenames)`` batches.
        device: torch device the images are moved to.
        use_tta: if True, average predictions over identity / h-flip / v-flip
            using ``tta_predict``.

    Returns:
        ``(results, avg_psnr, avg_ssim, avg_mae, avg_nmse)``. ``results`` is a list of
        per-image dicts that also carry the baseline metrics. The averages are NaN when
        no image was scored.
    """
    model.eval()
    psnr_scores, ssim_scores, mae_scores, nmse_scores = [], [], [], []
    baseline_psnr, baseline_ssim, baseline_mae, baseline_nmse = [], [], [], []
    results = []

    output_dir = "/content/Restormer_Project/outputs/derained"
    os.makedirs(output_dir, exist_ok=True)

    def _to_hwc(chw):
        # CHW tensor -> HWC numpy array on the CPU (the layout skimage/matplotlib use).
        return chw.cpu().permute(1, 2, 0).numpy()

    def _mean(values):
        # np.mean([]) emits a RuntimeWarning; report "nothing scored" as NaN explicitly.
        return np.mean(values) if values else float('nan')

    with torch.no_grad():
        for rain_img, norain_img, filenames in test_loader:
            rain_img, norain_img = rain_img.to(device), norain_img.to(device)

            # Clamping guards against values slightly outside [0, 1]. It does not rescale,
            # so the loader must already yield [0, 1] data (e.g. via ToTensor).
            rain_img = rain_img.clamp(0, 1)
            norain_img = norain_img.clamp(0, 1)

            # Generate derained batch
            output_img = tta_predict(model, rain_img) if use_tta else model(rain_img)
            output_img = output_img.clamp(0, 1)

            # Walk the batch explicitly: squeeze() only worked for batch_size == 1 and
            # would also drop spatial dims of size 1.
            for b in range(rain_img.shape[0]):
                filename = filenames[b]
                output_np = _to_hwc(output_img[b])
                norain_np = _to_hwc(norain_img[b])
                rain_np = _to_hwc(rain_img[b])

                # NOTE(review): win_size=3 is smaller than skimage's default window (7), so
                # these SSIM values are not directly comparable with published numbers.
                try:
                    psnr_score = psnr(norain_np, output_np, data_range=1.0)
                    ssim_score = ssim(norain_np, output_np, channel_axis=2, data_range=1.0, win_size=3)
                    mae_score = calculate_mae(output_np, norain_np)
                    nmse_score = calculate_nmse(output_np, norain_np)

                    base_psnr = psnr(norain_np, rain_np, data_range=1.0)
                    base_ssim = ssim(norain_np, rain_np, channel_axis=2, data_range=1.0, win_size=3)
                    base_mae = calculate_mae(rain_np, norain_np)
                    base_nmse = calculate_nmse(rain_np, norain_np)
                except Exception as e:
                    # Best-effort: report and skip this image instead of aborting the run.
                    print(f"Error computing metrics for {filename}: {e}")
                    continue

                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                mae_scores.append(mae_score)
                nmse_scores.append(nmse_score)
                baseline_psnr.append(base_psnr)
                baseline_ssim.append(base_ssim)
                baseline_mae.append(base_mae)
                baseline_nmse.append(base_nmse)

                results.append({
                    "Image": filename,
                    "PSNR": psnr_score,
                    "SSIM": ssim_score,
                    "MAE": mae_score,
                    "NMSE": nmse_score,
                    "Baseline_PSNR": base_psnr,
                    "Baseline_SSIM": base_ssim,
                    "Baseline_MAE": base_mae,
                    "Baseline_NMSE": base_nmse
                })

                # Save derained image
                output_path = os.path.join(output_dir, f"derained_{filename}")
                plt.imsave(output_path, output_np)

    # Compute averages
    avg_psnr = _mean(psnr_scores)
    avg_ssim = _mean(ssim_scores)
    avg_mae = _mean(mae_scores)
    avg_nmse = _mean(nmse_scores)
    avg_base_psnr = _mean(baseline_psnr)
    avg_base_ssim = _mean(baseline_ssim)
    avg_base_mae = _mean(baseline_mae)
    avg_base_nmse = _mean(baseline_nmse)

    print(f"Baseline PSNR (rainy vs clean): {avg_base_psnr:.2f} dB")
    print(f"Baseline SSIM (rainy vs clean): {avg_base_ssim:.4f}")
    print(f"Baseline MAE (rainy vs clean): {avg_base_mae:.4f}")
    print(f"Baseline NMSE (rainy vs clean): {avg_base_nmse:.4f}")
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average MAE: {avg_mae:.4f}")
    print(f"Average NMSE: {avg_nmse:.4f}")

    return results, avg_psnr, avg_ssim, avg_mae, avg_nmse

# Usage: evaluate with TTA and write the per-image metrics (including baselines) to CSV.
# evaluate_model already prints every average, so only the save location is reported here.
(results, avg_psnr, avg_ssim, avg_mae, avg_nmse) = evaluate_model(
    model, test_loader, device, use_tta=True
)
results_df = pd.DataFrame(results)
results_df.to_csv(
    path_or_buf="/content/Restormer_Project/outputs/evaluation_results_extended.csv",
    index=False,
)
print("Results saved to /content/Restormer_Project/outputs/evaluation_results_extended.csv")

In [None]:
def visualize_results(model, test_loader, device, num_samples=10, use_tta=False):
    """Plot rainy / ground-truth / derained triplets with per-image metrics.

    Takes at most ``num_samples`` batches from ``test_loader`` and draws one row per
    batch, using the first item of each batch. The model (optionally with TTA) is run
    on that item. The PSNR, SSIM, MAE and NMSE of the clamped output against the clean
    image go into the derained panel's title. The figure is saved to
    ``outputs/sample_visualizations.png`` and then shown. Nothing is drawn if the
    loader yields no batches.

    Args:
        model: callable mapping an NCHW batch to an NCHW batch of the same shape.
        test_loader: iterable of ``(rain_img, norain_img, filenames)`` batches.
        device: torch device the images are moved to.
        num_samples: maximum number of rows to plot.
        use_tta: if True, predict with ``tta_predict`` instead of a single pass.
    """
    from itertools import islice

    model.eval()
    # islice stops after num_samples batches; list(iter(loader)) would load and keep the
    # whole test set in memory just to use a few of its batches.
    samples = list(islice(test_loader, num_samples))
    if not samples:
        return
    # Size the grid by what the loader actually yielded, so no blank rows are drawn.
    n_rows = len(samples)

    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    with torch.no_grad():
        for row_axes, (rain_img, norain_img, filenames) in zip(axes, samples):
            # Keep only the first item of the batch. Slicing (rather than squeeze()) works
            # for any batch size and avoids running the model on images that are not shown.
            rain_img, norain_img = rain_img[:1].to(device), norain_img[:1].to(device)
            output_img = tta_predict(model, rain_img) if use_tta else model(rain_img)
            output_img = output_img.clamp(0, 1)

            output_np = output_img[0].cpu().permute(1, 2, 0).numpy()
            norain_np = norain_img[0].cpu().permute(1, 2, 0).numpy()
            rain_np = rain_img[0].cpu().permute(1, 2, 0).numpy()

            psnr_score = psnr(norain_np, output_np, data_range=1.0)
            ssim_score = ssim(norain_np, output_np, channel_axis=2, data_range=1.0, win_size=3)
            mae_score = calculate_mae(output_np, norain_np)
            nmse_score = calculate_nmse(output_np, norain_np)

            panels = (
                (f"Rainy: {filenames[0]}", rain_np),
                ("Ground Truth", norain_np),
                (f"Derained\nPSNR: {psnr_score:.2f}, SSIM: {ssim_score:.4f}\nMAE: {mae_score:.4f}, NMSE: {nmse_score:.4f}", output_np),
            )
            for ax, (title, image) in zip(row_axes, panels):
                ax.set_title(title)
                ax.imshow(image)
                ax.axis('off')

    fig.tight_layout()
    save_path = "/content/Restormer_Project/outputs/sample_visualizations.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path)
    plt.show()

# Usage: plot 10 test samples using TTA-averaged predictions, with per-image metrics.
visualize_results(
    model=model,
    test_loader=test_loader,
    device=device,
    num_samples=10,
    use_tta=True,
)