<a href="https://colab.research.google.com/github/KelvinM9187/Deraining_Dehazing/blob/main/agent_derain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import cv2
import os
from tqdm import tqdm
from google.colab import drive
from PIL import Image

In [2]:
# Mount Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Model Definition
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

In [4]:
class FeatureFusion(nn.Module):
    def __init__(self, channels):
        super(FeatureFusion, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels*2, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        x2_resized = F.interpolate(x2, size=x1.shape[2:], mode='bilinear', align_corners=False)
        fused = torch.cat([x1, x2_resized], dim=1)
        return self.conv(fused)

In [5]:
class DehazeDerainNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=32):
        super(DehazeDerainNet, self).__init__()

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True)
        )

        self.enc2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels, base_channels*2, 3, padding=1),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(inplace=True)
        )

        self.enc3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels*2, base_channels*4, 3, padding=1),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels*4, base_channels*4, 3, padding=1),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(inplace=True)
        )

        # Bottleneck with attention
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels*4, base_channels*8, 3, padding=1),
            nn.BatchNorm2d(base_channels*8),
            nn.ReLU(inplace=True),
            ChannelAttention(base_channels*8),
            nn.Conv2d(base_channels*8, base_channels*8, 3, padding=1),
            nn.BatchNorm2d(base_channels*8),
            nn.ReLU(inplace=True)
        )

        # Decoder with feature fusion
        self.up1 = nn.ConvTranspose2d(base_channels*8, base_channels*4, 2, stride=2)
        self.dec1 = FeatureFusion(base_channels*4)

        self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 2, stride=2)
        self.dec2 = FeatureFusion(base_channels*2)

        self.up3 = nn.ConvTranspose2d(base_channels*2, base_channels, 2, stride=2)
        self.dec3 = FeatureFusion(base_channels)

        # Output
        self.out_conv = nn.Conv2d(base_channels, out_channels, 1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)

        # Bottleneck
        b = self.bottleneck(e3)

        # Decoder
        d1 = self.up1(b)
        d1 = self.dec1(d1, e3)

        d2 = self.up2(d1)
        d2 = self.dec2(d2, e2)

        d3 = self.up3(d2)
        d3 = self.dec3(d3, e1)

        return torch.sigmoid(self.out_conv(d3))


In [6]:
## Custom Dataset Class for your specific structure
class CombinedHazeRainDataset(Dataset):
    def __init__(self, rain_path, no_rain_path, haze_path, no_haze_path, transform=None, patch_size=128, mode='train'):
        self.transform = transform
        self.patch_size = patch_size
        self.mode = mode

        # Load Rain100 data
        self.rain_images = sorted([os.path.join(rain_path, f) for f in os.listdir(rain_path) if f.endswith(('.jpg', '.png', '.jpeg'))])
        self.no_rain_images = sorted([os.path.join(no_rain_path, f) for f in os.listdir(no_rain_path) if f.endswith(('.jpg', '.png', '.jpeg'))])

        # Load RESIDE data
        self.haze_images = sorted([os.path.join(haze_path, f) for f in os.listdir(haze_path) if f.endswith(('.jpg', '.png', '.jpeg'))])
        self.no_haze_images = sorted([os.path.join(no_haze_path, f) for f in os.listdir(no_haze_path) if f.endswith(('.jpg', '.png', '.jpeg'))])

        # Combine all samples
        self.degraded_images = self.rain_images + self.haze_images
        self.clean_images = self.no_rain_images + self.no_haze_images

        # For validation, let's take 20% of the data
        if mode == 'val':
            self.degraded_images = self.degraded_images[:int(0.2*len(self.degraded_images))]
            self.clean_images = self.clean_images[:int(0.2*len(self.clean_images))]
        elif mode == 'train':
            self.degraded_images = self.degraded_images[int(0.2*len(self.degraded_images)):]
            self.clean_images = self.clean_images[int(0.2*len(self.clean_images)):]

    def __len__(self):
        return min(len(self.degraded_images), len(self.clean_images))

    def __getitem__(self, idx):
        degraded_img = Image.open(self.degraded_images[idx]).convert('RGB')
        clean_img = Image.open(self.clean_images[idx]).convert('RGB')

        # Convert to numpy arrays
        degraded_img = np.array(degraded_img)
        clean_img = np.array(clean_img)

        # Random crop to patch_size
        h, w = degraded_img.shape[:2]
        if h > self.patch_size and w > self.patch_size:
            top = np.random.randint(0, h - self.patch_size)
            left = np.random.randint(0, w - self.patch_size)
            degraded_img = degraded_img[top:top+self.patch_size, left:left+self.patch_size]
            clean_img = clean_img[top:top+self.patch_size, left:left+self.patch_size]

        if self.transform:
            degraded_img = self.transform(degraded_img)
            clean_img = self.transform(clean_img)
        else:
            degraded_img = transforms.ToTensor()(degraded_img)
            clean_img = transforms.ToTensor()(clean_img)

        return degraded_img, clean_img


In [7]:
## Training Function
def train_model(model, train_loader, val_loader, epochs=50, lr=1e-4, device='cuda'):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for degraded, clean in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            degraded = degraded.to(device)
            clean = clean.to(device)

            optimizer.zero_grad()
            outputs = model(degraded)
            loss = criterion(outputs, clean)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * degraded.size(0)

        train_loss /= len(train_loader.dataset)


        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for degraded, clean in val_loader:
                degraded = degraded.to(device)
                clean = clean.to(device)
                outputs = model(degraded)
                val_loss += criterion(outputs, clean).item() * degraded.size(0)

        val_loss /= len(val_loader.dataset)
        scheduler.step(val_loss)

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')

          # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), '/content/drive/MyDrive/best_model.pth')

    return model

In [11]:
## Main Execution
if __name__ == '__main__':
    # Parameters
    batch_size = 8  # Reduced for Colab memory
    patch_size = 128
    base_channels = 32  # Reduced from original paper for Colab compatibility
    epochs = 10  # Start with fewer epochs

    # Define paths to your data in Google Drive
    rain_path = '/content/drive/MyDrive/Datasets/Rain100/Rain100L/rain'
    no_rain_path = '/content/drive/MyDrive/Datasets/Rain100/Rain100L/norain'
    haze_path = '/content/drive/MyDrive/Datasets/RESIDE/SOTS/outdoor/hazy'
    no_haze_path = '/content/drive/MyDrive/Datasets/RESIDE/SOTS/outdoor/gt'

    # Transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create datasets
    train_dataset = CombinedHazeRainDataset(
        rain_path=rain_path,
        no_rain_path=no_rain_path,
        haze_path=haze_path,
        no_haze_path=no_haze_path,
        transform=transform,
        patch_size=patch_size,
        mode='train'
    )

    val_dataset = CombinedHazeRainDataset(
        rain_path=rain_path,
        no_rain_path=no_rain_path,
        haze_path=haze_path,
        no_haze_path=no_haze_path,
        transform=transform,
        patch_size=patch_size,
        mode='val'
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


      # Initialize model
    model = DehazeDerainNet(base_channels=base_channels)

    # Check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Train
    trained_model = train_model(
        model,
        train_loader,
        val_loader,
        epochs=epochs,
        lr=1e-4,
        device=device
    )

    print("Training completed! Best model saved to Google Drive.")


Using device: cuda


Epoch 1/10: 100%|██████████| 60/60 [03:32<00:00,  3.55s/it]


Epoch 1: Train Loss: 0.547689, Val Loss: 0.535294


Epoch 2/10: 100%|██████████| 60/60 [00:10<00:00,  5.66it/s]


Epoch 2: Train Loss: 0.477021, Val Loss: 0.520684


Epoch 3/10: 100%|██████████| 60/60 [00:10<00:00,  5.70it/s]


Epoch 3: Train Loss: 0.443040, Val Loss: 0.481254


Epoch 4/10: 100%|██████████| 60/60 [00:10<00:00,  5.67it/s]


Epoch 4: Train Loss: 0.430678, Val Loss: 0.438056


Epoch 5/10: 100%|██████████| 60/60 [00:09<00:00,  6.26it/s]


Epoch 5: Train Loss: 0.390097, Val Loss: 0.465636


Epoch 6/10: 100%|██████████| 60/60 [00:09<00:00,  6.59it/s]


Epoch 6: Train Loss: 0.386475, Val Loss: 0.440647


Epoch 7/10: 100%|██████████| 60/60 [00:09<00:00,  6.20it/s]


Epoch 7: Train Loss: 0.390119, Val Loss: 0.439704


Epoch 8/10: 100%|██████████| 60/60 [00:10<00:00,  5.71it/s]


Epoch 8: Train Loss: 0.372243, Val Loss: 0.462875


Epoch 9/10: 100%|██████████| 60/60 [00:10<00:00,  5.68it/s]


Epoch 9: Train Loss: 0.378464, Val Loss: 0.457064


Epoch 10/10: 100%|██████████| 60/60 [00:10<00:00,  5.74it/s]


Epoch 10: Train Loss: 0.344939, Val Loss: 0.418690
Training completed! Best model saved to Google Drive.
