In [15]:
from google.colab import drive
drive.mount('/content/drive')

import sys

# Make the project's modules (e.g. `DRANet`, `dataset`) importable from Drive.
# Skip entries that are already present so re-running this cell does not keep
# growing sys.path with duplicates.
for _project_dir in ('/content/drive/MyDrive/DRANet/Model',
                     '/content/drive/MyDrive/DRANet'):
    if _project_dir not in sys.path:
        sys.path.append(_project_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


==========================================

Load the Gaussian-noise dataset (DatasetGauss)

==========================================

In [None]:
from dataset import DatasetGauss
from torch.utils.data import DataLoader

# Build the Gaussian-noise training dataset and its loader.
# NOTE(review): root_dir here points at .../Data/Train/CBSD68, while the training
# cell below uses .../Data/Train. Confirm which directory DatasetGauss expects.
# The training cell rebuilds `dataset` and `train_loader` itself, so this cell
# only acts as a quick check that the data pipeline loads.
dataset_kwargs = dict(
    root_dir="/content/drive/MyDrive/DRANet/Data/Train/CBSD68",
    noise_level="noisy25",
    patch_size=128,
    patches_per_image=10,
    augment=True,
)
dataset = DatasetGauss(**dataset_kwargs)

train_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

============TRAIN MODEL============


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from DRANet import DRANet
from dataset import DatasetGauss

# ==== PSNR CALCULATION FUNCTION ====
def calc_psnr(output, target, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two tensors.

    Args:
        output: predicted tensor.
        target: reference tensor with the same shape as ``output``.
        max_val: peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        A 0-d tensor holding the PSNR in dB, ``inf`` when the inputs are
        identical. The result is always a tensor, so callers can use
        ``.item()`` uniformly. Previously the identical-input case returned a
        plain float, and ``.item()`` on it raised AttributeError.
    """
    mse = F.mse_loss(output, target)
    if mse == 0:
        # Same dtype/device as the normal path.
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(max_val / torch.sqrt(mse))

# ==== CONFIGURATION ====
data_root = "/content/drive/MyDrive/DRANet/Data/Train"
noise_level = "noisy25"
checkpoint_path = "/content/drive/MyDrive/DRANet/Model/checkpoint_noise25.pth"
best_model_path = "/content/drive/MyDrive/DRANet/Model/model_dranet_guass.pth"

batch_size = 4
epochs = 200
lr = 1e-4
lr_step_size = 50   # StepLR: decay the learning rate every `lr_step_size` epochs
lr_gamma = 0.5      # StepLR: multiplicative decay factor
patch_size = 128
patches_per_image = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== LOAD DATASET ====
dataset = DatasetGauss(
    root_dir=data_root,
    noise_level=noise_level,
    patch_size=patch_size,
    patches_per_image=patches_per_image,
    augment=True
)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# Fail fast with a clear message instead of a ZeroDivisionError after the first epoch.
if len(train_loader) == 0:
    raise RuntimeError(f"No training samples found in {data_root!r} for noise level {noise_level!r}.")

# ==== INITIALIZE MODEL ====
model = DRANet().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
criterion = nn.MSELoss()

# ==== LOAD CHECKPOINT (IF AVAILABLE) ====
start_epoch = 0
best_psnr = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The optimizer state also restores each param group's current learning rate.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_psnr = checkpoint['best_psnr']
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # Older checkpoints stored no scheduler state and were written *before*
        # that epoch's scheduler.step(). Replay the missed StepLR step so LR
        # decays land on the same epochs as an uninterrupted run.
        scheduler.last_epoch = start_epoch
        if start_epoch % lr_step_size == 0:
            for group in optimizer.param_groups:
                group['lr'] *= lr_gamma
    print(f"Resumed from epoch {start_epoch}, best PSNR: {best_psnr:.2f} dB")

# ==== TRAINING LOOP ====
for epoch in range(start_epoch, epochs):
    model.train()
    train_loss = 0.0

    for noisy, clean, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        noisy, clean = noisy.to(device), clean.to(device)
        output = model(noisy)
        loss = criterion(output, clean)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    avg_loss = train_loss / len(train_loader)
    print(f"\nEpoch {epoch+1}: Avg Train Loss = {avg_loss:.6f}")

    # ==== PSNR EVALUATION ====
    # NOTE(review): this scores the *training* loader (built with augment=True),
    # so it is not a held-out metric and "best model" selection may overfit.
    # Consider a separate validation split.
    model.eval()
    total_psnr = 0.0
    with torch.no_grad():
        for noisy, clean, _ in train_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            # Clamp to the valid [0, 1] range, matching how the test cell scores outputs.
            output = model(noisy).clamp(0, 1)
            # float() accepts both a 0-d tensor and a plain Python float (e.g. inf).
            total_psnr += float(calc_psnr(output, clean))

    avg_psnr = total_psnr / len(train_loader)
    print(f"Avg PSNR: {avg_psnr:.2f} dB")

    # ==== SAVE BEST MODEL ====
    if avg_psnr > best_psnr:
        best_psnr = avg_psnr
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.\n")

    # ==== STEP SCHEDULER ====
    # Step *before* checkpointing so the saved optimizer LR and scheduler state
    # both describe the start of the next epoch (start_epoch on resume).
    scheduler.step()

    # ==== SAVE CHECKPOINT ====
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_psnr': best_psnr
    }, checkpoint_path)

Create noisy test images (additive Gaussian noise, sigma = 25)

In [None]:
import os
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms.functional as TF

# ===== PARAMETERS =====
clean_dir = Path('/content/drive/MyDrive/DRANet/Data/Test/Clean')   # input folder with clean images
noisy_dir = Path('/content/drive/MyDrive/DRANet/Data/Test/Noise')   # output folder for noisy images
sigma = 25   # Gaussian noise std-dev on the 0-255 intensity scale
seed = 0     # fixed seed so the noisy test set can be regenerated exactly
noisy_dir.mkdir(parents=True, exist_ok=True)  # create the output folder if needed

generator = torch.Generator().manual_seed(seed)

# ===== ADD NOISE =====
# Sorted so files are processed (and consume random numbers) in a fixed order,
# which keeps the seeded output reproducible across filesystems.
for clean_path in sorted(clean_dir.glob("*.*")):
    try:
        # Load clean image; the context manager closes the underlying file handle.
        with Image.open(clean_path) as src:
            clean_img = src.convert("RGB")
        clean_tensor = TF.to_tensor(clean_img)

        # Additive white Gaussian noise, clipped back to the valid [0, 1] range.
        noise = torch.randn(clean_tensor.shape, generator=generator) * (sigma / 255.0)
        noisy_tensor = torch.clamp(clean_tensor + noise, 0.0, 1.0)

        # Save noisy image under the same file name so later cells find matching
        # clean/noisy pairs. Default JPEG settings (quality 75, chroma subsampling)
        # would alter the synthetic noise, so use the least lossy JPEG options.
        # A lossless format such as PNG would be better still.
        noisy_img = TF.to_pil_image(noisy_tensor)
        save_path = noisy_dir / clean_path.name
        save_kwargs = {}
        if save_path.suffix.lower() in {'.jpg', '.jpeg'}:
            save_kwargs = {'quality': 100, 'subsampling': 0}
        noisy_img.save(save_path, **save_kwargs)

        print(f"Noisy image created: {save_path.name}")
    except Exception as e:
        # Best-effort: skip unreadable / non-image files but report them.
        print(f"Error with image {clean_path.name}: {e}")

===============TEST===============

Option 1: Denoise all images in a folder

In [None]:
# Batch-denoise by running the project's test.py (its inputs/outputs are defined in that script).
# NOTE: %cd changes the kernel's working directory for every later cell, not just this one.
%cd /content/drive/MyDrive/DRANet/Model
!python test.py

Option 2: Denoise a single image and show PSNR and SSIM

In [None]:
import torch
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as compare_ssim
import numpy as np
from DRANet import DRANet

# ==== CONFIGURATION ====
model_path = "/content/drive/MyDrive/DRANet/Model/model_dranet_guass.pth"
noisy_path = "/content/drive/MyDrive/DRANet/Data/Test/Noise/huou.jpg"
clean_path = "/content/drive/MyDrive/DRANet/Data/Test/Clean/huou.jpg"  # falsy (None / "") -> no reference metrics

# ==== LOAD IMAGE ====
noisy_img = Image.open(noisy_path).convert('RGB')
if clean_path:
    clean_img = Image.open(clean_path).convert('RGB')
else:
    clean_img = None

# Convert to (1, C, H, W) float tensors; unsqueeze(0) adds the batch dimension.
noisy_tensor = TF.to_tensor(noisy_img).unsqueeze(0)
clean_tensor = None if not clean_img else TF.to_tensor(clean_img).unsqueeze(0)

# ==== LOAD MODEL ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DRANet().to(device)
weights = torch.load(model_path, map_location=device)
model.load_state_dict(weights)
model.eval()

# ==== FILTER ====
# Inference only: no autograd graph; clamp to [0, 1] and bring the result back to CPU.
with torch.no_grad():
    denoised = model(noisy_tensor.to(device))
    output_tensor = denoised.clamp(0, 1).cpu()

# ==== PSNR & SSIM ====
def calc_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two image tensors.

    Uses the same signature and formula as the ``calc_psnr`` in the training
    cell. Redefining it here shadows that one, so the signatures are kept
    compatible (``max_val`` defaults to 1.0 for images in [0, 1]).

    Returns:
        A 0-d tensor with the PSNR in dB, ``inf`` for identical inputs.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        # Identical images: report +inf explicitly rather than via 1/sqrt(0).
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(max_val / torch.sqrt(mse))

def calc_ssim(img1, img2):
    """Structural similarity between two image tensors with values in [0, 1].

    Each tensor has a leading size-1 batch dimension dropped (squeeze(0)), then
    C x H x W is reordered to H x W x C, as skimage expects channels last.
    """
    first, second = (t.squeeze(0).permute(1, 2, 0).numpy() for t in (img1, img2))
    return compare_ssim(first, second, channel_axis=2, data_range=1.0)

# ==== COMPUTE METRICS (only when a clean reference is available) ====
if clean_tensor is not None:
    psnr = calc_psnr(output_tensor, clean_tensor)
    ssim = calc_ssim(output_tensor, clean_tensor)
    noisy_psnr = calc_psnr(noisy_tensor, clean_tensor)  # baseline before denoising
else:
    psnr, ssim, noisy_psnr = None, None, None

# ==== SHOW ====
fig, (ax_noisy, ax_denoised) = plt.subplots(1, 2, figsize=(12, 5))

ax_noisy.imshow(noisy_img)
ax_noisy.set_title("Noisy" if noisy_psnr is None else f"Noisy\nPSNR: {noisy_psnr:.2f} dB")
ax_noisy.axis('off')

ax_denoised.imshow(TF.to_pil_image(output_tensor.squeeze(0)))
# Test for presence, not truthiness: a tensor PSNR of exactly 0 dB is falsy and
# would otherwise hide the metrics.
title = f"Denoised\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.3f}" if psnr is not None else "Denoised"
ax_denoised.set_title(title)
ax_denoised.axis('off')

fig.tight_layout()
plt.show()
