In [15]:
import os
import glob
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from google.colab import drive
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import sys

In [16]:
# Mount Google Drive (Colab) at /content/drive. Both the dataset root
# (base_path, under MyDrive/DATASET) and the saved model checkpoint
# (under MyDrive) are accessed through this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
# Define dataset paths and experiment configuration
base_path = '/content/drive/MyDrive/DATASET'
splits = {
    'train': os.path.join(base_path, 'Train'),
    'validate': os.path.join(base_path, 'Validate'),
    'test': os.path.join(base_path, 'Test'),
}
sigma_levels = [10]  # Only σ = 10

# Seed torch's global RNG (all devices) early so that weight initialisation,
# DataLoader shuffling and the random crops/flips are reproducible across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [18]:
# Fail fast: every split folder and its noisy sub-folder for the selected
# sigma must exist before any dataset is built.
noisy_subdir = f'noisy_sigma{sigma_levels[0]}'
for split_name, split_dir in splits.items():
    required_dirs = (
        (split_dir, "Folder"),
        (os.path.join(split_dir, noisy_subdir), "Noisy folder"),
    )
    for required_path, label in required_dirs:
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"{label} not found: {required_path}")
    print(f"Found {split_name} folder: {split_dir}")

Found train folder: /content/drive/MyDrive/DATASET/Train
Found validate folder: /content/drive/MyDrive/DATASET/Validate
Found test folder: /content/drive/MyDrive/DATASET/Test


In [19]:
# PairedNoisyCleanDataset for RGB images
class PairedNoisyCleanDataset(Dataset):
    """Pairs noisy images with their clean counterparts by filename.

    Layout expected under ``base_dir``::

        <Split>/                     clean images (any depth)
        <Split>/noisy_sigma<sigma>/  noisy images, same basenames as the clean ones

    Args:
        base_dir: Root folder that contains the split folders.
        split: Split name; capitalised to form the folder name
            (``'train'`` -> ``'Train'``, ``'validate'`` -> ``'Validate'``).
        transform: Optional callable applied to both images of a pair. The
            global torch RNG state is snapshotted before the noisy image is
            transformed and restored before the clean one, so any transform
            whose randomness comes from torch's default CPU generator (e.g.
            torchvision's RandomCrop / RandomHorizontalFlip) makes identical
            draws for both images and the pair stays aligned.
        sigma: Noise level selecting the ``noisy_sigma<sigma>`` folder
            (default 10).

    Raises:
        FileNotFoundError: If the noisy folder does not exist.
        ValueError: If no noisy/clean pairs are found.
    """

    # Lower-case file extensions treated as images.
    _IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, base_dir, split, transform=None, sigma=10):
        self.transform = transform
        split_dir = os.path.join(base_dir, split.capitalize())
        noisy_dir = os.path.join(split_dir, f'noisy_sigma{sigma}')

        # Collect clean images anywhere under the split folder. Pruning `dirs`
        # in place removes every noisy_sigma* folder *and its sub-folders* from
        # the walk (including folders for other sigmas); otherwise their files
        # could shadow clean images that share a basename. Sorting makes the
        # traversal order deterministic.
        self.clean_images = []
        for root, dirs, files in os.walk(split_dir):
            dirs[:] = sorted(d for d in dirs if not d.startswith('noisy_sigma'))
            self.clean_images.extend(
                os.path.join(root, f) for f in sorted(files)
                if f.lower().endswith(self._IMG_EXTENSIONS)
            )

        # Collect noisy images (sorted for a reproducible sample order).
        self.noisy_images = sorted(
            os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)
            if f.lower().endswith(self._IMG_EXTENSIONS)
        )

        # Index clean images by basename. Basenames should be unique; on a
        # clash the last path found is used and a warning is printed.
        clean_by_name = {}
        for clean_path in self.clean_images:
            name = os.path.basename(clean_path)
            if name in clean_by_name:
                print(f"Warning: duplicate clean filename {name}; using {clean_path}")
            clean_by_name[name] = clean_path

        # Match pairs by filename.
        self.matched_clean = []
        self.matched_noisy = []
        for noisy_path in self.noisy_images:
            fname = os.path.basename(noisy_path)
            if fname in clean_by_name:
                self.matched_clean.append(clean_by_name[fname])
                self.matched_noisy.append(noisy_path)
            else:
                print(f"Warning: No matching clean image for noisy file {fname}")

        print(f"[{split}] Found {len(self.matched_clean)} paired samples")
        if len(self.matched_clean) == 0:
            raise ValueError(f"No matched pairs found in {split} dataset!")

    def __len__(self):
        return len(self.matched_clean)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` for pair ``idx``, both transformed identically."""
        noisy_img = self._load_rgb(self.matched_noisy[idx])
        clean_img = self._load_rgb(self.matched_clean[idx])

        if self.transform:
            # Replay the same random draws for both images without reseeding
            # (and thus without resetting) the global RNG.
            rng_state = torch.get_rng_state()
            noisy_img = self.transform(noisy_img)
            torch.set_rng_state(rng_state)
            clean_img = self.transform(clean_img)

        return noisy_img, clean_img

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')


In [20]:
# Define transforms
# Training: random 64x64 patch plus a random horizontal flip (p = 0.5), then
# PIL image -> float tensor [3, 64, 64] with values in [0, 1].
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(size=64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

# Validation / test: no augmentation, full-size images converted to tensors.
val_test_transform = transforms.ToTensor()

In [21]:
# Create datasets and dataloaders.
# NOTE(review): earlier comments expected 7,999 / 1,002 / 1,004 samples, but
# the recorded run paired 7,998 / 1,001 / 1,003 — TODO confirm whether one
# image per split is genuinely unpaired.
batch_size = 16

train_dataset = PairedNoisyCleanDataset(base_path, 'train', train_transform)
val_dataset, test_dataset = (
    PairedNoisyCleanDataset(base_path, split_name, val_test_transform)
    for split_name in ('validate', 'test')
)

# Only the training loader shuffles; validation/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_ds, batch_size=batch_size, shuffle=False)
    for eval_ds in (val_dataset, test_dataset)
)

for label, sized in (("Train", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{label} dataset size: {len(sized)}")

[train] Found 7998 paired samples
[validate] Found 1001 paired samples
[test] Found 1003 paired samples
Train dataset size: 7998
Validation dataset size: 1001
Test dataset size: 1003


In [22]:
# Debug: Inspect batch structure of the first training batch only.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print("Batch type:", type(first_batch))
    print("Batch length:", len(first_batch))
    for position, tensor in enumerate(first_batch):
        print(f"Item {position} shape: {tensor.shape}")

Batch type: <class 'list'>
Batch length: 2
Item 0 shape: torch.Size([16, 3, 64, 64])
Item 1 shape: torch.Size([16, 3, 64, 64])


In [23]:
# Bi-MSAAE Model
class MultiScaleEncoder(nn.Module):
    """Two parallel depthwise-separable branches (3x3 and 5x5) over an RGB input.

    Each branch is: depthwise conv (groups=3) -> 1x1 pointwise conv to 32
    channels -> ReLU. Padding preserves the spatial size, and the two branch
    outputs are concatenated on the channel axis, so (B, 3, H, W) becomes
    (B, 64, H, W).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and construction order define the state_dict keys
        # and the parameter-initialisation order; keep them stable.
        self.branch3x3 = self._separable_branch(kernel_size=3)
        self.branch5x5 = self._separable_branch(kernel_size=5)

    @staticmethod
    def _separable_branch(kernel_size):
        """Depthwise k x k conv -> pointwise 1x1 conv (3 -> 32) -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=kernel_size, padding=kernel_size // 2, groups=3),
            nn.Conv2d(3, 32, kernel_size=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        branches = (self.branch3x3, self.branch5x5)
        return torch.cat([branch(x) for branch in branches], dim=1)

class NoiseGateModule(nn.Module):
    """Spatial gate: scales all channels by one learned per-pixel mask in [0, 1].

    A 1x1 conv collapses ``in_channels`` to a single map, a sigmoid turns it
    into the mask, and the mask is broadcast over every input channel.
    """

    def __init__(self, in_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv(x))  # (B, 1, H, W), broadcast over channels
        return gate * x

class ChannelAttentionBlock(nn.Module):
    """Channel attention from global average- and max-pooled descriptors.

    The two pooled descriptors, each (B, C, 1, 1), are stacked to
    (B, 2C, 1, 1), mixed back to C channels by a 1x1 conv and squashed with a
    sigmoid; the resulting per-channel weights rescale the input.
    """

    def __init__(self, in_channels=64):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.conv = nn.Conv2d(2 * in_channels, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        descriptor = torch.cat((self.avg_pool(x), self.max_pool(x)), dim=1)  # (B, 2C, 1, 1)
        channel_weights = self.sigmoid(self.conv(descriptor))  # (B, C, 1, 1)
        return x * channel_weights

class DualHeadDecoder(nn.Module):
    """Blend a 3x3 "structural" head and a 5x5 "texture" head into an image.

    Both heads are convolutions from ``in_channels`` to ``out_channels`` with
    size-preserving padding; the output is
    ``alpha * structural + (1 - alpha) * texture``.

    Args:
        in_channels: Channels of the incoming feature map (default 64).
        alpha: Blend weight of the structural head, in [0, 1] (default 0.6).
        out_channels: Channels of the produced image (default 3, RGB).

    Raises:
        ValueError: If ``alpha`` is outside [0, 1].
    """

    def __init__(self, in_channels=64, alpha=0.6, out_channels=3):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.structural_head = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.texture_head = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        # Fixed, non-learned float: it is not part of state_dict, so saved
        # checkpoints do not record which alpha they were trained with.
        self.alpha = alpha

    def forward(self, x):
        structural = self.structural_head(x)  # (B, out_channels, H, W)
        texture = self.texture_head(x)        # (B, out_channels, H, W)
        return self.alpha * structural + (1 - self.alpha) * texture

class BiMSAAE(nn.Module):
    """Bi-MSAAE denoiser.

    Pipeline: multi-scale encoder -> noise gate -> channel attention, plus a
    1x1-conv skip path from the encoder features around gate+attention, then
    the dual-head decoder and a final sigmoid so the output lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes both parameter-initialisation order and the
        # state_dict layout of saved checkpoints; keep it unchanged.
        self.encoder = MultiScaleEncoder()
        self.noise_gate = NoiseGateModule()
        self.attention = ChannelAttentionBlock()
        self.decoder = DualHeadDecoder()
        self.skip_conv = nn.Conv2d(64, 64, kernel_size=1)  # 1x1 projection for the skip path

    def forward(self, x):
        features = self.encoder(x)  # [batch, 64, H, W]
        refined = self.attention(self.noise_gate(features))
        fused = refined + self.skip_conv(features)
        return torch.sigmoid(self.decoder(fused))  # [batch, 3, H, W]

In [24]:
# Animated Epoch Logger
# ANSI escape codes for the console report (GREEN is not used by the logger below).
GREEN = "\033[92m"
YELLOW = "\033[93m"
CYAN = "\033[96m"
RESET = "\033[0m"


def animated_epoch_update(epoch, num_epochs, train_loss, val_loss, psnr, ssim, epoch_time):
    """Print a progress bar for ``epoch`` (0-based) followed by its metrics.

    The 30-character bar is written after a carriage return and filled in
    proportion to ``(epoch + 1) / num_epochs``. After a 0.1 s pause the losses
    (4 d.p.), PSNR (2 d.p., dB), SSIM (4 d.p.) and epoch duration (2 d.p.,
    seconds) are printed, then a 50-dash separator line.
    """
    width = 30
    done = int(round(width * (epoch + 1) / num_epochs))
    progress = '=' * done + '-' * (width - done)
    sys.stdout.write(f"\r{CYAN}Epoch {epoch+1}/{num_epochs} [{progress}]{RESET}")
    sys.stdout.flush()
    time.sleep(0.1)
    report_lines = [
        f"\n  🏋️  {YELLOW}Train Loss{RESET} : {train_loss:.4f}",
        f"  📉  {YELLOW}Val Loss  {RESET} : {val_loss:.4f}",
        f"  📷  {YELLOW}PSNR      {RESET} : {psnr:.2f} dB",
        f"  🔍  {YELLOW}SSIM      {RESET} : {ssim:.4f}",
        f"  ⏱️  {YELLOW}Time/Epoch{RESET}: {epoch_time:.2f} sec",
        "-" * 50,
    ]
    print("\n".join(report_lines))


In [25]:
# Training Function
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss per sample.

    Each batch loss is weighted by its batch size, so the result is a
    per-sample mean assuming ``criterion`` averages over the batch (as the
    default ``nn.MSELoss`` does). Returns 0.0 if the loader yields nothing.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for noisy, clean in loader:
        noisy, clean = noisy.to(device), clean.to(device)
        optimizer.zero_grad()
        loss = criterion(model(noisy), clean)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * noisy.size(0)
        seen += noisy.size(0)
    return loss_sum / seen if seen else 0.0


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns ``(mean_loss, mean_psnr, mean_ssim)`` averaged over samples
    (all 0.0 if the loader yields nothing). PSNR is computed per image from
    the MSE over all channels, the usual definition for colour images; SSIM
    is the mean of the per-channel SSIMs. Both use ``data_range=1.0``, i.e.
    assume pixel values in [0, 1].
    """
    model.eval()
    loss_sum = psnr_sum = ssim_sum = 0.0
    seen = 0
    with torch.no_grad():
        for noisy, clean in loader:
            noisy, clean = noisy.to(device), clean.to(device)
            outputs = model(noisy)
            loss_sum += criterion(outputs, clean).item() * noisy.size(0)
            seen += noisy.size(0)
            for ref, est in zip(clean.cpu().numpy(), outputs.cpu().numpy()):
                # ref / est are single images laid out as (C, H, W).
                psnr_sum += peak_signal_noise_ratio(ref, est, data_range=1.0)
                n_channels = ref.shape[0]
                ssim_sum += sum(
                    structural_similarity(ref[c], est[c], data_range=1.0)
                    for c in range(n_channels)
                ) / n_channels
    if not seen:
        return 0.0, 0.0, 0.0
    return loss_sum / seen, psnr_sum / seen, ssim_sum / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each loader must yield ``(noisy, clean)`` batches. After every epoch the
    metrics are reported through ``animated_epoch_update``.

    Args:
        model: The network to train; moved to ``device`` first.
        train_loader: Loader for the training pairs.
        val_loader: Loader for the validation pairs.
        criterion: Loss function ``criterion(output, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.

    Returns:
        list[dict]: One entry per epoch with keys ``'train_loss'``,
        ``'val_loss'``, ``'psnr'``, ``'ssim'`` and ``'epoch_time'``.
    """
    model.to(device)
    history = []
    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_psnr, val_ssim = _evaluate(model, val_loader, criterion, device)
        epoch_time = time.time() - start_time

        animated_epoch_update(epoch, num_epochs, train_loss, val_loss, val_psnr, val_ssim, epoch_time)
        history.append({
            'train_loss': train_loss,
            'val_loss': val_loss,
            'psnr': val_psnr,
            'ssim': val_ssim,
            'epoch_time': epoch_time,
        })
    return history

In [28]:
# Initialize model, loss, optimizer (GPU when available).
num_epochs = 1
criterion = nn.MSELoss()  # pixel-wise mean squared error
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BiMSAAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train the model, then persist its learned weights to Google Drive.
model_path = '/content/drive/MyDrive/bi_msaae_model_sigma10_rgb.pth'
train_model(
    model,
    train_loader,
    val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
)
torch.save(model.state_dict(), model_path)
print(f"Training complete. Model saved to {model_path}")