In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from tqdm import tqdm

# === Residual Block ===
class ResidualBlock(nn.Module):
    """Conv3x3 -> ReLU -> Conv3x3 branch added back onto its input (identity skip).

    Both convolutions keep the channel count and, with padding=1 and stride 1,
    the spatial size, so the output has exactly the input's shape.
    """

    def __init__(self, channels):
        super().__init__()
        # The attribute names below become state_dict keys (conv1.*, conv2.*);
        # renaming them would break loading of saved checkpoints.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # The in-place ReLU only touches conv1's fresh output, never ``x`` itself.
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch

# === Student Model ===
class StudentDerain(nn.Module):
    """Compact encoder / residual-middle / decoder network, 3 channels in and 3 out.

    The encoder downsamples once (stride-2 conv, 32 -> 64 channels), four
    ``ResidualBlock(64)`` modules process the features, and the decoder upsamples
    once (2x2 stride-2 transposed conv) before projecting back to 3 channels.
    For even H and W the output matches the input's spatial size.
    NOTE(review): an odd H or W yields an output one pixel larger in that
    dimension; the callers in this notebook resize to 128x128, so it does not arise.
    """

    def __init__(self):
        super().__init__()
        # Positions inside each Sequential define the state_dict keys
        # (encoder.0/2/4, middle.0-3, decoder.0/2) -- keep the order unchanged.
        encoder_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.middle = nn.Sequential(*[ResidualBlock(64) for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        )

    def forward(self, x):
        features = self.middle(self.encoder(x))
        return self.decoder(features)

# === Dataset ===
class DerainDataset(Dataset):
    """Paired (rainy, clean) image dataset read from two directories.

    For each regular file in ``rainy`` (sorted by name) a partner in
    ``sharp_dir`` is chosen as follows:

    1. a file with the identical name, if present;
    2. otherwise, with ``base`` = the rainy name up to its first ``'_'``, the
       first (sorted) clean file that starts with ``base`` where ``base`` is
       followed by the end of the name or a non-alphanumeric character. This
       pairs ``'1_x.png'`` with ``'1_gt.png'`` rather than ``'10_gt.png'``
       (which sorts first because ``'0' < '_'``).

    Rainy files without a partner are dropped. Sub-directories (e.g. notebook
    checkpoint folders) are ignored in both directories.

    Args:
        rainy: directory of degraded input images.
        sharp_dir: directory of ground-truth images.
        transform: callable applied to both PIL images in ``__getitem__``.
    """

    def __init__(self, rainy, sharp_dir, transform):
        self.rainy = rainy
        self.sharp_dir = sharp_dir
        self.transform = transform

        rainy_files = self._list_files(rainy)
        sharp_files = self._list_files(sharp_dir)
        sharp_set = set(sharp_files)

        self.pairs = []
        for b in rainy_files:
            if b in sharp_set:
                sharp_match = b
            else:
                sharp_match = self._prefix_match(b.split('_')[0], sharp_files)
            if sharp_match is not None:
                self.pairs.append((b, sharp_match))

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files (not sub-directories) in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _prefix_match(base, candidates):
        """Return the first candidate starting with ``base`` at a token boundary, else None."""
        if not base:
            # An empty prefix would match every file; treat it as "no partner".
            return None
        n = len(base)
        for s in candidates:
            if s.startswith(base) and (len(s) == n or not s[n].isalnum()):
                return s
        return None

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        rain_name, sharp_name = self.pairs[idx]
        # Context managers release the file handles as soon as the pixels are decoded.
        with Image.open(os.path.join(self.rainy, rain_name)) as img:
            rain = img.convert("RGB")
        with Image.open(os.path.join(self.sharp_dir, sharp_name)) as img:
            sharp = img.convert("RGB")
        return self.transform(rain), self.transform(sharp)

# === Evaluation ===
def evaluate(student, dataloader, device):
    """Print (and return) the student's mean per-image SSIM and PSNR over ``dataloader``.

    Every image in every batch is scored, and the sums are divided by the number
    of images, so the averages do not depend on the batch size.

    Args:
        student: model called on each rainy batch; switched to eval mode here.
        dataloader: iterable of ``(rain, sharp)`` tensor batches laid out as
            (B, C, H, W).
        device: device the rainy batch is moved to before the forward pass.

    Returns:
        ``(mean_ssim, mean_psnr)``; both are NaN when ``dataloader`` yields no images.
    """
    student.eval()
    s_ssim, s_psnr, n_images = 0.0, 0.0, 0
    with torch.no_grad():
        for rain, sharp in tqdm(dataloader, desc="🔍 Evaluating Student"):
            pred = student(rain.to(device))

            # (B, C, H, W) -> (B, H, W, C) numpy, clipped to the [0, 1] data_range used below.
            outs = np.clip(pred.permute(0, 2, 3, 1).cpu().numpy(), 0, 1)
            gts = np.clip(sharp.permute(0, 2, 3, 1).cpu().numpy(), 0, 1)

            for gt, out in zip(gts, outs):
                s_ssim += compare_ssim(gt, out, channel_axis=2, data_range=1.0)
                s_psnr += compare_psnr(gt, out, data_range=1.0)
                n_images += 1

    if n_images == 0:
        print("\n⚠️ No images to evaluate.")
        return float("nan"), float("nan")

    mean_ssim = s_ssim / n_images
    mean_psnr = s_psnr / n_images
    print(f"\n📊 Student SSIM: {mean_ssim:.4f} | PSNR: {mean_psnr:.2f} dB")
    return mean_ssim, mean_psnr

# === Training ===
def train(student, dataloader, device, epochs=35, lr=1e-4, save_path="student_kd.pth"):
    """Optimize ``student`` with an L1 loss against the clean targets, then save it.

    NOTE(review): despite the ``student_kd`` checkpoint name, no teacher is used
    here -- this is plain supervised training on (rain, sharp) pairs.

    Args:
        student: model to train in place (it is not moved to ``device`` here).
        dataloader: iterable of ``(rain, sharp)`` batches that supports ``len()``.
        device: device each batch is moved to.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        save_path: file the ``{'params': state_dict}`` checkpoint is written to.

    Raises:
        ValueError: if ``epochs > 0`` but ``dataloader`` has no batches, since
            the per-epoch average loss would be undefined.
    """
    n_batches = len(dataloader)
    if epochs > 0 and n_batches == 0:
        raise ValueError("train() received an empty dataloader; nothing to optimize.")

    student.train()
    optimizer = optim.Adam(student.parameters(), lr=lr)
    criterion = nn.L1Loss()

    for epoch in range(epochs):
        total_loss = 0.0
        for rain, sharp in tqdm(dataloader, desc=f"🧠 Training Epoch {epoch+1}/{epochs}"):
            rain, sharp = rain.to(device), sharp.to(device)
            pred = student(rain)
            loss = criterion(pred, sharp)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / n_batches
        print(f"✅ Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}")
    torch.save({'params': student.state_dict()}, save_path)
    print(f"📦 Model saved to {save_path}")

# === Run ===
if __name__ == "__main__":
    # Prefer the GPU whenever one is visible to torch.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"🚀 Using device: {device}")

    # Every image is resized to 128x128 and converted to a tensor by ToTensor.
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

    data_root = "Dataset/train/Rain13K"
    train_set = DerainDataset(f"{data_root}/input", f"{data_root}/target", transform)
    loader = DataLoader(train_set, batch_size=8, shuffle=True)

    model = StudentDerain().to(device)
    train(model, loader, device, epochs=35)
    # NOTE(review): this reuses the training loader, so the printed metrics are
    # training-set scores, not a held-out evaluation.
    evaluate(model, loader, device)

🚀 Using device: cuda


🧠 Training Epoch 1/35: 100%|██████████| 1714/1714 [01:51<00:00, 15.43it/s]


✅ Epoch 1 completed | Avg Loss: 0.0737


🧠 Training Epoch 2/35: 100%|██████████| 1714/1714 [01:49<00:00, 15.71it/s]


✅ Epoch 2 completed | Avg Loss: 0.0520


🧠 Training Epoch 3/35: 100%|██████████| 1714/1714 [01:59<00:00, 14.35it/s]


✅ Epoch 3 completed | Avg Loss: 0.0453


🧠 Training Epoch 4/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.13it/s]


✅ Epoch 4 completed | Avg Loss: 0.0416


🧠 Training Epoch 5/35: 100%|██████████| 1714/1714 [02:00<00:00, 14.20it/s]


✅ Epoch 5 completed | Avg Loss: 0.0388


🧠 Training Epoch 6/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.09it/s]


✅ Epoch 6 completed | Avg Loss: 0.0371


🧠 Training Epoch 7/35: 100%|██████████| 1714/1714 [02:03<00:00, 13.90it/s]


✅ Epoch 7 completed | Avg Loss: 0.0358


🧠 Training Epoch 8/35: 100%|██████████| 1714/1714 [01:58<00:00, 14.42it/s]


✅ Epoch 8 completed | Avg Loss: 0.0346


🧠 Training Epoch 9/35: 100%|██████████| 1714/1714 [01:52<00:00, 15.22it/s]


✅ Epoch 9 completed | Avg Loss: 0.0338


🧠 Training Epoch 10/35: 100%|██████████| 1714/1714 [01:48<00:00, 15.86it/s]


✅ Epoch 10 completed | Avg Loss: 0.0329


🧠 Training Epoch 11/35: 100%|██████████| 1714/1714 [01:55<00:00, 14.90it/s]


✅ Epoch 11 completed | Avg Loss: 0.0323


🧠 Training Epoch 12/35: 100%|██████████| 1714/1714 [02:02<00:00, 14.00it/s]


✅ Epoch 12 completed | Avg Loss: 0.0316


🧠 Training Epoch 13/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.06it/s]


✅ Epoch 13 completed | Avg Loss: 0.0310


🧠 Training Epoch 14/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.07it/s]


✅ Epoch 14 completed | Avg Loss: 0.0306


🧠 Training Epoch 15/35: 100%|██████████| 1714/1714 [02:05<00:00, 13.66it/s]


✅ Epoch 15 completed | Avg Loss: 0.0301


🧠 Training Epoch 16/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.12it/s]


✅ Epoch 16 completed | Avg Loss: 0.0297


🧠 Training Epoch 17/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.08it/s]


✅ Epoch 17 completed | Avg Loss: 0.0293


🧠 Training Epoch 18/35: 100%|██████████| 1714/1714 [02:00<00:00, 14.21it/s]


✅ Epoch 18 completed | Avg Loss: 0.0288


🧠 Training Epoch 19/35: 100%|██████████| 1714/1714 [02:04<00:00, 13.71it/s]


✅ Epoch 19 completed | Avg Loss: 0.0285


🧠 Training Epoch 20/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.09it/s]


✅ Epoch 20 completed | Avg Loss: 0.0283


🧠 Training Epoch 21/35: 100%|██████████| 1714/1714 [01:49<00:00, 15.70it/s]


✅ Epoch 21 completed | Avg Loss: 0.0279


🧠 Training Epoch 22/35: 100%|██████████| 1714/1714 [01:50<00:00, 15.50it/s]


✅ Epoch 22 completed | Avg Loss: 0.0278


🧠 Training Epoch 23/35: 100%|██████████| 1714/1714 [01:51<00:00, 15.33it/s]


✅ Epoch 23 completed | Avg Loss: 0.0276


🧠 Training Epoch 24/35: 100%|██████████| 1714/1714 [01:48<00:00, 15.80it/s]


✅ Epoch 24 completed | Avg Loss: 0.0273


🧠 Training Epoch 25/35: 100%|██████████| 1714/1714 [01:48<00:00, 15.86it/s]


✅ Epoch 25 completed | Avg Loss: 0.0271


🧠 Training Epoch 26/35: 100%|██████████| 1714/1714 [01:58<00:00, 14.52it/s]


✅ Epoch 26 completed | Avg Loss: 0.0268


🧠 Training Epoch 27/35: 100%|██████████| 1714/1714 [01:57<00:00, 14.62it/s]


✅ Epoch 27 completed | Avg Loss: 0.0266


🧠 Training Epoch 28/35: 100%|██████████| 1714/1714 [01:59<00:00, 14.31it/s]


✅ Epoch 28 completed | Avg Loss: 0.0265


🧠 Training Epoch 29/35: 100%|██████████| 1714/1714 [01:57<00:00, 14.53it/s]


✅ Epoch 29 completed | Avg Loss: 0.0262


🧠 Training Epoch 30/35: 100%|██████████| 1714/1714 [01:59<00:00, 14.30it/s]


✅ Epoch 30 completed | Avg Loss: 0.0261


🧠 Training Epoch 31/35: 100%|██████████| 1714/1714 [01:59<00:00, 14.36it/s]


✅ Epoch 31 completed | Avg Loss: 0.0260


🧠 Training Epoch 32/35: 100%|██████████| 1714/1714 [02:01<00:00, 14.13it/s]


✅ Epoch 32 completed | Avg Loss: 0.0257


🧠 Training Epoch 33/35: 100%|██████████| 1714/1714 [01:58<00:00, 14.50it/s]


✅ Epoch 33 completed | Avg Loss: 0.0257


🧠 Training Epoch 34/35: 100%|██████████| 1714/1714 [01:59<00:00, 14.29it/s]


✅ Epoch 34 completed | Avg Loss: 0.0254


🧠 Training Epoch 35/35: 100%|██████████| 1714/1714 [01:50<00:00, 15.46it/s]


✅ Epoch 35 completed | Avg Loss: 0.0254
📦 Model saved to student_kd.pth


🔍 Evaluating Student: 100%|██████████| 1714/1714 [01:37<00:00, 17.60it/s]


📊 Student SSIM: 0.9078 | PSNR: 30.52 dB





In [None]:
import os
import sys
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import numpy as np

# --- Import teacher model architecture ---
sys.path.append(os.path.abspath('./Restormer/basicsr/models/archs'))
from restormer_arch import Restormer

# --- Student Model ---
class ResidualBlock(nn.Module):
    """Conv3x3 -> ReLU -> Conv3x3 branch added back onto its input (identity skip).

    Channel count and spatial size are preserved (3x3 kernels, padding=1,
    stride 1), so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        # conv1 / conv2 names are the state_dict keys of saved checkpoints; keep them.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch

class StudentDerain(nn.Module):
    """Compact encoder / residual-middle / decoder network, 3 channels in and 3 out.

    Must stay architecturally identical to the model that produced the student
    checkpoint so that ``load_state_dict`` finds the same keys and shapes.
    The encoder downsamples once (stride-2 conv), four ``ResidualBlock(64)``
    modules process the features, and the decoder upsamples once (2x2 stride-2
    transposed conv) before projecting to 3 channels.
    NOTE(review): an odd H or W yields an output one pixel larger in that
    dimension; inputs here are resized to 128x128, so it does not arise.
    """

    def __init__(self):
        super().__init__()
        # Positions inside each Sequential define the state_dict keys
        # (encoder.0/2/4, middle.0-3, decoder.0/2) -- keep the order unchanged.
        encoder_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.middle = nn.Sequential(*[ResidualBlock(64) for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        )

    def forward(self, x):
        features = self.middle(self.encoder(x))
        return self.decoder(features)

# --- Dataset ---
class DerainDataset(Dataset):
    """Paired (rainy, clean) image dataset read from two directories.

    For each regular file in ``rainy`` (sorted by name) a partner in
    ``sharp_dir`` is chosen as follows:

    1. a file with the identical name, if present;
    2. otherwise, with ``base`` = the rainy name up to its first ``'_'``, the
       first (sorted) clean file that starts with ``base`` where ``base`` is
       followed by the end of the name or a non-alphanumeric character. This
       pairs ``'1_x.png'`` with ``'1_gt.png'`` rather than ``'10_gt.png'``
       (which sorts first because ``'0' < '_'``).

    Rainy files without a partner are dropped. Sub-directories (e.g. notebook
    checkpoint folders) are ignored in both directories.

    Args:
        rainy: directory of degraded input images.
        sharp_dir: directory of ground-truth images.
        transform: callable applied to both PIL images in ``__getitem__``.
    """

    def __init__(self, rainy, sharp_dir, transform):
        self.rainy = rainy
        self.sharp_dir = sharp_dir
        self.transform = transform

        rainy_files = self._list_files(rainy)
        sharp_files = self._list_files(sharp_dir)
        sharp_set = set(sharp_files)

        self.pairs = []
        for b in rainy_files:
            if b in sharp_set:
                sharp_match = b
            else:
                sharp_match = self._prefix_match(b.split('_')[0], sharp_files)
            if sharp_match is not None:
                self.pairs.append((b, sharp_match))

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files (not sub-directories) in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _prefix_match(base, candidates):
        """Return the first candidate starting with ``base`` at a token boundary, else None."""
        if not base:
            # An empty prefix would match every file; treat it as "no partner".
            return None
        n = len(base)
        for s in candidates:
            if s.startswith(base) and (len(s) == n or not s[n].isalnum()):
                return s
        return None

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        rain_name, sharp_name = self.pairs[idx]
        rain_path = os.path.join(self.rainy, rain_name)
        sharp_path = os.path.join(self.sharp_dir, sharp_name)
        # Context managers release the file handles as soon as the pixels are decoded.
        with Image.open(rain_path) as img:
            rain = img.convert("RGB")
        with Image.open(sharp_path) as img:
            sharp = img.convert("RGB")
        return self.transform(rain), self.transform(sharp)

# --- Evaluation on a single dataset ---
def evaluate_on_folder(student, teacher, rainy, sharp_dir, transform, device, dataset_name):
    """Score teacher and student against the ground truth of one test folder.

    Images are paired by ``DerainDataset`` and fed one at a time (batch size 1).
    SSIM and PSNR are computed per image on the clipped RGB outputs and averaged.
    NOTE(review): metrics are on RGB at the resolution produced by ``transform``;
    numbers from a different protocol (e.g. Y channel, full resolution) are not
    directly comparable -- confirm before quoting against published results.

    Args:
        student, teacher: models called on each rainy batch; set to eval mode here.
        rainy, sharp_dir: input and ground-truth directories.
        transform: applied to both images of each pair.
        device: device the rainy image is moved to for inference.
        dataset_name: label used in the printed report.

    Returns:
        ``(teacher_ssim, teacher_psnr, student_ssim, student_psnr)`` averaged per
        image; all four are NaN when no (rainy, sharp) pairs were found.
    """
    dataset = DerainDataset(rainy, sharp_dir, transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    teacher.eval()
    student.eval()

    total_images = len(dataset)
    print(f"\n🔎 Now evaluating dataset: 📂 **{dataset_name}**\n{'-'*60}")
    if total_images == 0:
        # Averaging over zero images would raise ZeroDivisionError; report and bail out.
        print(f"⚠️ No matching input/target pairs found in {dataset_name}.")
        nan = float("nan")
        return nan, nan, nan, nan

    def to_hwc(batch):
        # (1, C, H, W) tensor -> (H, W, C) numpy array clipped to the [0, 1] data_range.
        return np.clip(batch[0].permute(1, 2, 0).cpu().numpy(), 0, 1)

    teacher_ssim_total = student_ssim_total = 0.0
    teacher_psnr_total = student_psnr_total = 0.0

    with torch.no_grad():
        for rain, sharp in tqdm(dataloader, desc="📷 Images", total=total_images):
            rain = rain.to(device)
            t_img = to_hwc(teacher(rain))
            s_img = to_hwc(student(rain))
            gt_img = to_hwc(sharp)

            teacher_ssim_total += compare_ssim(gt_img, t_img, channel_axis=2, data_range=1.0)
            student_ssim_total += compare_ssim(gt_img, s_img, channel_axis=2, data_range=1.0)
            teacher_psnr_total += compare_psnr(gt_img, t_img, data_range=1.0)
            student_psnr_total += compare_psnr(gt_img, s_img, data_range=1.0)

    teacher_ssim_avg = teacher_ssim_total / total_images
    student_ssim_avg = student_ssim_total / total_images
    teacher_psnr_avg = teacher_psnr_total / total_images
    student_psnr_avg = student_psnr_total / total_images

    print(f"\n📁 Completed Evaluation: {dataset_name} ({total_images} images)")
    print(f"🎓 Teacher  | SSIM: {teacher_ssim_avg:.4f} | PSNR: {teacher_psnr_avg:.2f} dB")
    print(f"👶 Student  | SSIM: {student_ssim_avg:.4f} | PSNR: {student_psnr_avg:.2f} dB")

    return teacher_ssim_avg, teacher_psnr_avg, student_ssim_avg, student_psnr_avg

# --- Main evaluation loop ---
def evaluate_all(student_ckpt, teacher_ckpt, test_root):
    """Evaluate the Restormer teacher and the student on every test set under ``test_root``.

    Each immediate sub-folder of ``test_root`` holding both ``input/`` and
    ``target/`` directories is scored with ``evaluate_on_folder``. The final
    summary is the unweighted mean of the per-dataset averages: every dataset
    counts equally, regardless of its image count.

    Args:
        student_ckpt: path to the student weights, stored either as
            ``{'params': state_dict}`` or as a bare state_dict.
        teacher_ckpt: path to a Restormer checkpoint of the form ``{'params': state_dict}``.
        test_root: directory whose sub-folders are individual test sets.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    # Load teacher. These hyper-parameters must describe the architecture the
    # checkpoint was saved from, otherwise the strict load_state_dict fails.
    teacher = Restormer(
        inp_channels=3, out_channels=3, dim=48,
        num_blocks=[4, 6, 6, 8], num_refinement_blocks=4,
        heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
        bias=False, LayerNorm_type='WithBias', dual_pixel_task=False
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (or pass weights_only=True where the installed torch supports it).
    teacher.load_state_dict(torch.load(teacher_ckpt, map_location=device)['params'])

    # Load student. Accept {'params': state_dict} or a bare state_dict, and strip a
    # leading "params." from keys -- as a prefix only, never in the middle of a key.
    student = StudentDerain().to(device)
    ckpt = torch.load(student_ckpt, map_location=device)
    state_dict = ckpt['params'] if isinstance(ckpt, dict) and 'params' in ckpt else ckpt
    prefix = "params."
    corrected_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    student.load_state_dict(corrected_state_dict)
    print(f"📦 Student model parameters: {sum(p.numel() for p in student.parameters())}")

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    # Accumulate per-dataset averages.
    total_datasets = 0
    all_teacher_ssim = all_teacher_psnr = 0
    all_student_ssim = all_student_psnr = 0

    for subfolder in sorted(os.listdir(test_root)):
        rainy = os.path.join(test_root, subfolder, "input")
        sharp_dir = os.path.join(test_root, subfolder, "target")

        if not os.path.isdir(rainy) or not os.path.isdir(sharp_dir):
            print(f"⚠️ Skipping {subfolder}: missing input/target folders.")
            continue

        results = evaluate_on_folder(
            student, teacher, rainy, sharp_dir, transform, device, subfolder
        )
        # Keep NaN averages (e.g. a folder with no matched pairs) out of the
        # overall means so one bad folder cannot poison the summary.
        if any(np.isnan(r) for r in results):
            print(f"⚠️ Skipping {subfolder}: no images were scored.")
            continue

        t_ssim, t_psnr, s_ssim, s_psnr = results
        all_teacher_ssim += t_ssim
        all_teacher_psnr += t_psnr
        all_student_ssim += s_ssim
        all_student_psnr += s_psnr
        total_datasets += 1

    if total_datasets > 0:
        print("\n📊📊📊 Overall Average Results Across Datasets 📊📊📊")
        print(f"🎓 Teacher  | Avg SSIM: {all_teacher_ssim/total_datasets:.4f} | Avg PSNR: {all_teacher_psnr/total_datasets:.2f} dB")
        print(f"👶 Student  | Avg SSIM: {all_student_ssim/total_datasets:.4f} | Avg PSNR: {all_student_psnr/total_datasets:.2f} dB")
    else:
        print("❌ No valid datasets found to evaluate.")

# --- Entry point ---
if __name__ == "__main__":
    # All paths are relative to the working directory: the student checkpoint
    # written by the training cell, the Restormer checkpoint, and a root folder
    # with one sub-folder per test set (each containing input/ and target/).
    run_config = dict(
        student_ckpt="student_kd.pth",
        teacher_ckpt="deraining.pth",
        test_root="Dataset/test",
    )
    evaluate_all(**run_config)


🚀 Using device: cuda
📦 Student model parameters: 360835

🔎 Now evaluating dataset: 📂 **Rain100H**
------------------------------------------------------------


📷 Images: 100%|██████████| 100/100 [00:24<00:00,  4.13it/s]



📁 Completed Evaluation: Rain100H (100 images)
🎓 Teacher  | SSIM: 0.8498 | PSNR: 25.97 dB
👶 Student  | SSIM: 0.7820 | PSNR: 23.85 dB

🔎 Now evaluating dataset: 📂 **Rain100L**
------------------------------------------------------------


📷 Images: 100%|██████████| 100/100 [00:20<00:00,  4.85it/s]



📁 Completed Evaluation: Rain100L (100 images)
🎓 Teacher  | SSIM: 0.9614 | PSNR: 35.66 dB
👶 Student  | SSIM: 0.8805 | PSNR: 28.30 dB

🔎 Now evaluating dataset: 📂 **Test100**
------------------------------------------------------------


📷 Images: 100%|██████████| 98/98 [00:27<00:00,  3.55it/s]



📁 Completed Evaluation: Test100 (98 images)
🎓 Teacher  | SSIM: 0.8579 | PSNR: 23.96 dB
👶 Student  | SSIM: 0.8584 | PSNR: 23.82 dB

🔎 Now evaluating dataset: 📂 **Test1200**
------------------------------------------------------------


📷 Images: 100%|██████████| 1200/1200 [03:00<00:00,  6.64it/s]



📁 Completed Evaluation: Test1200 (1200 images)
🎓 Teacher  | SSIM: 0.8763 | PSNR: 25.53 dB
👶 Student  | SSIM: 0.9087 | PSNR: 28.18 dB

🔎 Now evaluating dataset: 📂 **Test2800**
------------------------------------------------------------


📷 Images: 100%|██████████| 2800/2800 [04:40<00:00,  9.97it/s]


📁 Completed Evaluation: Test2800 (2800 images)
🎓 Teacher  | SSIM: 0.9255 | PSNR: 29.06 dB
👶 Student  | SSIM: 0.9323 | PSNR: 30.93 dB

📊📊📊 Overall Average Results Across Datasets 📊📊📊
🎓 Teacher  | Avg SSIM: 0.8942 | Avg PSNR: 28.03 dB
👶 Student  | Avg SSIM: 0.8724 | Avg PSNR: 27.02 dB



