In [None]:
!pip install einops timm tqdm scikit-image

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->timm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->timm)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->tim

In [None]:
import glob
import os
import random
import warnings
import zipfile

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from tqdm import tqdm

Define the Residual FFT-Conv Block

In [None]:
class ResFFTConvBlock(nn.Module):
    """Residual block that filters the Fourier amplitude spectrum with convolutions.

    The input goes through a 2-D FFT over its last two dimensions. The amplitude is
    passed through two 3x3 convolutions, with a ReLU after the second one only. It is
    then recombined with the untouched phase, brought back with an inverse FFT, and
    the real part is added to the input.

    Args:
        channels: number of input and output channels (both convolutions preserve it).
    """

    def __init__(self, channels):
        super().__init__()
        # Registration order fixes both the state_dict keys and the RNG draw order
        # used for weight initialization.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        spectrum = torch.fft.fft2(x)
        # NOTE(review): there is no activation between conv1 and conv2, so the pair
        # composes into a single linear (affine) map of the amplitude.
        filtered_amp = F.relu(self.conv2(self.conv1(spectrum.abs())))
        # The trailing ReLU keeps the amplitude non-negative before it is paired
        # with the original phase.
        rebuilt = torch.polar(filtered_amp, spectrum.angle())
        # Filtering the amplitude can break Hermitian symmetry, so the inverse FFT
        # may carry an imaginary part; only the real part is kept.
        return x + torch.fft.ifft2(rebuilt).real

Define the Global Context (GC) Block

In [None]:
class GCBlock(nn.Module):
    """Global-context block: adds one attention-pooled context vector to every position.

    A 1x1 conv turns the input into a single logit map. A softmax over all H*W
    positions converts it into pooling weights, which collapse the input into a
    per-channel context of shape (b, c, 1, 1). That context passes through a 1x1
    bottleneck (c -> c // 2 -> c, ReLU in between) and is broadcast-added to the input.

    Args:
        in_channels: channel count of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden = in_channels // 2
        self.conv_mask = nn.Conv2d(in_channels, 1, 1)
        self.softmax = nn.Softmax(dim=2)
        self.transform = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        # (b, 1, h*w) attention weights summing to 1 over the spatial positions.
        weights = self.softmax(self.conv_mask(x).view(b, 1, -1))
        # (b, c, h*w) @ (b, h*w, 1) -> (b, c, 1): weighted average per channel.
        pooled = torch.bmm(x.view(b, c, -1), weights.transpose(1, 2))
        return x + self.transform(pooled.view(b, c, 1, 1))


Define Full Model

In [None]:
class RaindropRemovalNet(nn.Module):
    """Raindrop-removal network: conv stem -> ResFFT-Conv stack -> GC stack -> conv head.

    Args:
        channels: feature width used between the stem and the head.
        n_blocks: number of ``ResFFTConvBlock`` layers.
        gc_blocks: number of ``GCBlock`` layers.
    """

    def __init__(self, channels=64, n_blocks=19, gc_blocks=5):
        super().__init__()
        self.initial = nn.Conv2d(3, channels, 3, padding=1)
        self.res_blocks = nn.Sequential(*(ResFFTConvBlock(channels) for _ in range(n_blocks)))
        self.gc_blocks = nn.Sequential(*(GCBlock(channels) for _ in range(gc_blocks)))
        self.final = nn.Conv2d(channels, 3, 3, padding=1)

    def forward(self, x):
        # 3-channel input -> 3-channel output of the same spatial size. There is no
        # input-to-output skip, so the clean image is predicted directly.
        for stage in (self.initial, self.res_blocks, self.gc_blocks, self.final):
            x = stage(x)
        return x


Define Loss Functions (the PSNR/SSIM metrics are computed later, in `evaluate`)

In [None]:
def msed_loss(pred, target):
    """Pixel-wise mean squared error between ``pred`` and ``target``."""
    return F.mse_loss(pred, target)


def msfr_loss(pred, target):
    """Mean L1 distance between the 2-D FFT amplitude spectra of ``pred`` and ``target``.

    ``torch.fft.fft2`` transforms the last two dimensions; the phase is ignored.
    """
    return F.l1_loss(torch.fft.fft2(pred).abs(), torch.fft.fft2(target).abs())


def total_loss(pred, target, alpha=1.0, beta=0.05):
    """Training objective: ``alpha * msed_loss + beta * msfr_loss``.

    Delegates to the two component losses so they cannot drift apart from the
    combined objective.

    Args:
        pred: model output.
        target: ground-truth tensor with the same shape as ``pred``.
        alpha: weight of the spatial MSE term.
        beta: weight of the frequency-amplitude L1 term.
    """
    return alpha * msed_loss(pred, target) + beta * msfr_loss(pred, target)


Training Loop

In [None]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5, resume=False,
          checkpoint_path="checkpoint.pth", best_model_path="best_raindrop_model.pth"):
    """Train ``model`` with per-epoch validation, best-model saving and resumable checkpoints.

    Each epoch makes one pass over ``train_loader`` (forward, ``criterion``, backward,
    ``optimizer.step``). It then averages ``criterion`` over ``val_loader`` without
    gradients. Whenever the validation loss improves, ``model.state_dict()`` is written
    to ``best_model_path``. After every epoch a full checkpoint (model, optimizer,
    epoch index, best validation loss) is written to ``checkpoint_path``.

    Args:
        model: module to train. Only batches are moved here, so the model must
            already be on ``device``.
        train_loader: loader yielding ``(input, target)`` batches for training.
        val_loader: loader yielding ``(input, target)`` batches for validation.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion(pred, target)`` returning a scalar loss tensor.
        device: device that batches (and a resumed checkpoint) are loaded onto.
        num_epochs: 0-based epoch index to stop before. When resuming, training
            continues after the checkpointed epoch, so a checkpoint that already
            reached ``num_epochs`` trains nothing.
        resume: restore state from ``checkpoint_path`` if that file exists.
        checkpoint_path: file holding the latest-epoch checkpoint.
        best_model_path: file holding the lowest-validation-loss ``state_dict``.

    Raises:
        ValueError: if either loader yields no batches (the per-epoch averages would
            divide by zero).
    """
    # Fail before any training instead of after a full epoch.
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")
    if len(val_loader) == 0:
        raise ValueError("val_loader yields no batches; enlarge the validation split")

    start_epoch = 0
    best_val_loss = float('inf')

    if resume and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on GPU be resumed on CPU (and vice versa).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        best_val_loss = checkpoint["best_val_loss"]
        # Epochs are shown 1-based below, so report the next epoch the same way.
        print(f"Resumed training from epoch {start_epoch + 1}")
        if start_epoch >= num_epochs:
            print(f"Checkpoint already covers {start_epoch} epoch(s) (num_epochs={num_epochs}); nothing to train.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        print(f"\nEpoch {epoch+1}")

        for img, gt in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            img, gt = img.to(device), gt.to(device)
            optimizer.zero_grad()
            loss = criterion(model(img), gt)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"  Training Loss: {running_loss / len(train_loader):.4f}")

        # --------- Validation ---------
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for rain_img, clean_img in val_loader:
                rain_img = rain_img.to(device)
                clean_img = clean_img.to(device)
                val_loss += criterion(model(rain_img), clean_img).item()

        val_loss /= len(val_loader)
        print(f"  Validation Loss: {val_loss:.4f}")

        # --------- Save Best Model ---------
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            print("Saved Best Model!")

        # --------- Save Latest Checkpoint ---------
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
        }, checkpoint_path)


In [None]:
from google.colab import drive

# Mount Google Drive so the dataset archive stored there becomes readable.
drive.mount('/content/drive')

zip_path = "/content/drive/MyDrive/raindrop_dataset.zip"
extract_path = "/content/raindrop_dataset"

# Unpack the outer archive onto the local Colab disk.
with zipfile.ZipFile(zip_path, mode='r') as zip_ref:
    zip_ref.extractall(path=extract_path)

print("Extraction complete.")

Mounted at /content/drive
Extraction complete.


In [None]:
# Per-split archives found inside /content/raindrop_dataset (the outer zip's extraction target).
train_zip, test_a_zip, test_b_zip = (
    "/content/raindrop_dataset/train.zip",
    "/content/raindrop_dataset/test_a.zip",
    "/content/raindrop_dataset/test_b.zip",
)

# Directory each split archive is unpacked into.
train_path, test_a_path, test_b_path = (
    "/content/raindrop_dataset/train",
    "/content/raindrop_dataset/test_a",
    "/content/raindrop_dataset/test_b",
)

def extract(zip_path, extract_to):
    """Unpack every member of the archive at ``zip_path`` into ``extract_to``, then log it.

    Args:
        zip_path: path of the zip archive to read.
        extract_to: destination directory (created by ``extractall`` if missing).
    """
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(path=extract_to)
    print("Extracted: {} → {}".format(zip_path, extract_to))

# Unpack each split archive into its own directory (train, then test_a, then test_b).
for _archive, _target in ((train_zip, train_path), (test_a_zip, test_a_path), (test_b_zip, test_b_path)):
    extract(_archive, _target)

Extracted: /content/raindrop_dataset/train.zip → /content/raindrop_dataset/train
Extracted: /content/raindrop_dataset/test_a.zip → /content/raindrop_dataset/test_a
Extracted: /content/raindrop_dataset/test_b.zip → /content/raindrop_dataset/test_b


In [None]:
# Each split unpacks to <split>/<split>/: "data" holds the raindrop-degraded
# inputs and "gt" the matching clean ground-truth images.
train_rain_path, train_clean_path = (
    "/content/raindrop_dataset/train/train/data",
    "/content/raindrop_dataset/train/train/gt",
)

test_a_rain_path, test_a_clean_path = (
    "/content/raindrop_dataset/test_a/test_a/data",
    "/content/raindrop_dataset/test_a/test_a/gt",
)

test_b_rain_path, test_b_clean_path = (
    "/content/raindrop_dataset/test_b/test_b/data",
    "/content/raindrop_dataset/test_b/test_b/gt",
)

In [None]:
#create pytorch dataset class
class RaindropDataset(Dataset):
    """Paired (rainy, clean) image dataset read from two parallel directories.

    Pairs are formed by position after sorting each directory's image filenames:
    the i-th rainy image is matched with the i-th clean image. This is only correct
    when both directories hold the same number of images and their names sort into
    the same order. NOTE(review): confirm the naming scheme for any new data source.

    Args:
        rain_dir: directory of degraded input images.
        clean_dir: directory of matching ground-truth images.
        transform: optional callable applied to each PIL image separately. It must
            be deterministic, or the two images of a pair would be transformed differently.
    """

    # Extensions (compared case-insensitively) treated as image files.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, rain_dir, clean_dir, transform=None):
        self.rain_dir = rain_dir
        self.clean_dir = clean_dir
        self.transform = transform
        self.rain_images = self._list_images(rain_dir)
        self.clean_images = self._list_images(clean_dir)
        if len(self.rain_images) != len(self.clean_images):
            # Position-based pairing silently misaligns when counts differ; surface it.
            warnings.warn(
                f"RaindropDataset: {len(self.rain_images)} images in {rain_dir!r} but "
                f"{len(self.clean_images)} in {clean_dir!r}; position-based pairs may be "
                f"mismatched and only {len(self)} pairs are used."
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of image files directly inside ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy; the file handle is closed on exit."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return min(len(self.rain_images), len(self.clean_images))

    def __getitem__(self, idx):
        rain_img = self._load_rgb(os.path.join(self.rain_dir, self.rain_images[idx]))
        clean_img = self._load_rgb(os.path.join(self.clean_dir, self.clean_images[idx]))

        if self.transform:
            rain_img = self.transform(rain_img)
            clean_img = self.transform(clean_img)

        return rain_img, clean_img


In [None]:
# Resize to a fixed 256x256, convert to a float tensor in [0, 1], then map each
# channel to [-1, 1] via (x - 0.5) / 0.5. Code that turns tensors back into
# images has to undo this normalization.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

train_dataset = RaindropDataset(train_rain_path, train_clean_path, transform=transform)

Training Setup

In [None]:
# 90% training, 10% validation.
# A fixed generator makes the split identical across sessions. With resume=True
# this keeps validation images out of training and keeps the checkpoint's
# best_val_loss comparable.
SPLIT_SEED = 42
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_set, val_set = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# A lighter configuration than the class defaults (n_blocks=19, gc_blocks=5).
model = RaindropRemovalNet(channels=64, n_blocks=8, gc_blocks=3)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# resume=True continues from checkpoint.pth when it exists. Epochs already recorded
# there are skipped, so re-running after a finished 5-epoch run trains nothing.
train(
    model, train_loader, val_loader, optimizer,
    criterion=total_loss, device=device, num_epochs=5, resume=True,
)


Epoch 1


Training Epoch 1: 100%|██████████| 194/194 [02:18<00:00,  1.40it/s]


  Training Loss: 0.5479
  Validation Loss: 0.4882
Saved Best Model!

Epoch 2


Training Epoch 2: 100%|██████████| 194/194 [02:16<00:00,  1.42it/s]


  Training Loss: 0.4912
  Validation Loss: 0.4744
Saved Best Model!

Epoch 3


Training Epoch 3: 100%|██████████| 194/194 [02:16<00:00,  1.42it/s]


  Training Loss: 0.4813
  Validation Loss: 0.4686
Saved Best Model!

Epoch 4


Training Epoch 4: 100%|██████████| 194/194 [02:15<00:00,  1.43it/s]


  Training Loss: 0.4747
  Validation Loss: 0.4744

Epoch 5


Training Epoch 5: 100%|██████████| 194/194 [02:15<00:00,  1.43it/s]


  Training Loss: 0.4729
  Validation Loss: 0.4636
Saved Best Model!


In [None]:
# Save the weights as they are after the last epoch. train() separately keeps the
# lowest-validation-loss weights in best_raindrop_model.pth.
torch.save(model.state_dict(), f="raindrop_model.pth")

In [None]:
def _to_uint8_hwc(batch):
    """Map a normalized ``(N, C, H, W)`` batch to uint8 ``(N, H, W, C)`` numpy images.

    Undoes ``transforms.Normalize(mean=0.5, std=0.5)`` (``x * 0.5 + 0.5``), clamps to
    [0, 1], scales to [0, 255] and rounds to the nearest integer.
    """
    scaled = (batch * 0.5 + 0.5).clamp(0, 1) * 255
    return scaled.round().to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()


def evaluate(model, loader, device):
    """Report average loss, PSNR and SSIM of ``model`` over ``loader``.

    The loss (``total_loss``) is averaged per batch. PSNR/SSIM are averaged per image,
    computed on uint8 images recovered from the normalized tensors.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: loader yielding ``(input, target)`` batches normalized with
            mean=0.5 / std=0.5 (the notebook's ``transform``).
        device: device the model lives on; batches are moved there.

    Returns:
        ``(avg_loss, avg_psnr, avg_ssim)``.

    Raises:
        ValueError: if ``loader`` yields no images.
    """
    model.eval()
    total_loss_value = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    n_images = 0
    with torch.no_grad():
        for img, gt in tqdm(loader, desc="Evaluating"):
            img, gt = img.to(device), gt.to(device)
            pred = model(img)
            total_loss_value += total_loss(pred, gt).item()

            # The tensors are normalized to roughly [-1, 1]. They must be mapped back
            # to [0, 1] before scaling by 255; scaling directly would clip every
            # below-mid-gray pixel to 0 and distort PSNR/SSIM.
            for pred_img, gt_img in zip(_to_uint8_hwc(pred), _to_uint8_hwc(gt)):
                total_psnr += psnr(gt_img, pred_img, data_range=255)
                total_ssim += ssim(gt_img, pred_img, data_range=255, channel_axis=2, win_size=7)
                n_images += 1

    if n_images == 0:
        raise ValueError("evaluate() received a loader with no images")

    avg_loss = total_loss_value / len(loader)
    avg_psnr = total_psnr / n_images
    avg_ssim = total_ssim / n_images

    print(f"\nResults on Test Set:")
    print(f"  Avg Loss:  {avg_loss:.4f}")
    print(f"  Avg PSNR:  {avg_psnr:.2f} dB")
    print(f"  Avg SSIM:  {avg_ssim:.4f}")
    return avg_loss, avg_psnr, avg_ssim

In [None]:
# Build the Test A / Test B datasets with the same transform used for training.
test_a_dataset, test_b_dataset = (
    RaindropDataset(rain_dir, clean_dir, transform=transform)
    for rain_dir, clean_dir in (
        (test_a_rain_path, test_a_clean_path),
        (test_b_rain_path, test_b_clean_path),
    )
)

# shuffle=False so evaluation visits the images in a fixed order.
test_a_loader, test_b_loader = (
    DataLoader(split, batch_size=4, shuffle=False, num_workers=2)
    for split in (test_a_dataset, test_b_dataset)
)

# Evaluate. The second call is the cell's last expression, so its returned
# (loss, PSNR, SSIM) tuple is what the notebook displays.
evaluate(model, test_a_loader, device)
evaluate(model, test_b_loader, device)

Evaluating: 100%|██████████| 15/15 [00:04<00:00,  3.54it/s]



Results on Test Set:
  Avg Loss:  0.3285
  Avg PSNR:  24.73 dB
  Avg SSIM:  0.7977


Evaluating: 100%|██████████| 63/63 [00:17<00:00,  3.62it/s]


Results on Test Set:
  Avg Loss:  0.4467
  Avg PSNR:  22.81 dB
  Avg SSIM:  0.7506





(0.44665726498951985,
 np.float64(22.81078682848064),
 np.float64(0.7505623299978201))

In [None]:
# Visual output check
def visualize_sample(model, dataset, device, num_samples=4):
    """Plot random (rain input, ground truth, prediction) triplets from ``dataset``.

    Draws distinct indices with ``random.sample`` and shows one row per sample.
    Tensors are mapped back from the mean=0.5/std=0.5 normalization to [0, 1] for display.

    Args:
        model: trained network; switched to eval mode here.
        dataset: indexable dataset returning ``(rain_tensor, clean_tensor)`` pairs.
        device: device the model lives on; each input is moved there.
        num_samples: number of rows to plot; capped at ``len(dataset)``.

    Raises:
        ValueError: if there is nothing to plot (empty dataset or ``num_samples < 1``).
    """
    def unnormalize(img_tensor):
        # Inverse of transforms.Normalize(mean=0.5, std=0.5), clamped to [0, 1].
        return (img_tensor * 0.5 + 0.5).clamp(0, 1)

    def to_display(img_tensor):
        # (C, H, W) tensor -> (H, W, C) numpy array for imshow.
        return unnormalize(img_tensor.detach()).cpu().permute(1, 2, 0).numpy()

    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("visualize_sample needs num_samples >= 1 and a non-empty dataset")

    model.eval()
    indices = random.sample(range(len(dataset)), num_samples)

    # squeeze=False keeps axs 2-D even for a single row, so axs[row] is always a row of Axes.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    titles = ("Rain Image", "Ground Truth", "Prediction")

    for row, idx in enumerate(indices):
        rain_img, clean_img = dataset[idx]
        with torch.no_grad():
            pred = model(rain_img.unsqueeze(0).to(device)).squeeze(0)

        for ax, img, title in zip(axs[row], (rain_img, clean_img, pred), titles):
            ax.imshow(to_display(img))
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
visualize_sample(model, test_a_dataset, device=device, num_samples=25)  # 25 rows -> 125-inch-tall figure

In [None]:
visualize_sample(model, test_b_dataset, device=device, num_samples=5)