In [None]:
# Mount Google Drive so the /content/drive/MyDrive/... paths used later in this
# notebook (model checkpoint, test images, output folder) are reachable.
import google.colab.drive as drive
drive.mount('/content/drive')

==========================================

DATASET (Gaussian-noise image pairs)

==========================================

In [None]:
%%writefile dataset.py
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as TF
from pathlib import Path
import random
import torch
import torchvision.transforms as T

class DatasetGauss(Dataset):
    """Paired noisy/clean random-patch dataset for Gaussian denoising.

    Expects ``root_dir/<noise_level>/`` (noisy inputs) and ``root_dir/target/``
    (clean references). Files are paired by identical file name; files without
    a counterpart (and sub-directories) are skipped.

    Every image contributes ``patches_per_image`` samples. Each sample is an
    independent random crop of the whole image (the position does not depend on
    the sample's index within the image), optionally followed by a random flip /
    90-degree rotation applied identically to the noisy and clean patch.

    Args:
        root_dir: dataset root directory.
        noise_level: name of the sub-directory holding the noisy images.
        patch_size: side length of the square patch.
        patches_per_image: number of samples drawn from each image.
        augment: enable random flips and rotations.
    """

    def __init__(self, root_dir, noise_level="noisy35", patch_size=128, patches_per_image=10, augment=True):
        self.root_dir = Path(root_dir)
        self.clean_dir = self.root_dir / "target"
        self.noise_dir = self.root_dir / noise_level

        self.patch_size = patch_size
        self.patches_per_image = patches_per_image
        self.augment = augment

        # Sorted so the index -> file mapping is reproducible across runs and
        # machines (Path.glob yields entries in arbitrary filesystem order).
        self.image_pairs = []
        for noisy_path in sorted(self.noise_dir.glob("*")):
            clean_path = self.clean_dir / noisy_path.name
            if noisy_path.is_file() and clean_path.is_file():
                self.image_pairs.append((noisy_path, clean_path))

        self.to_tensor = T.ToTensor()

    def __len__(self):
        # Total number of samples = number of images x patches per image.
        return len(self.image_pairs) * self.patches_per_image

    def _load_rgb(self, path):
        """Read ``path`` as a (3, H, W) float tensor in [0, 1] and close the file."""
        # convert() returns a fully loaded copy, so it stays valid after the
        # context manager closes the underlying file handle.
        with Image.open(path) as img:
            return self.to_tensor(img.convert("RGB"))

    @staticmethod
    def _augment_pair(noisy_patch, clean_patch):
        """Apply the same random flips and k*90-degree rotation to two (C, H, W) patches."""
        if random.random() < 0.5:  # horizontal flip
            noisy_patch, clean_patch = noisy_patch.flip(-1), clean_patch.flip(-1)
        if random.random() < 0.5:  # vertical flip
            noisy_patch, clean_patch = noisy_patch.flip(-2), clean_patch.flip(-2)
        # torch.rot90 is an exact index permutation (no affine resampling grid,
        # unlike TF.rotate). Patches are square, so H and W stay the same.
        k = random.choice((0, 1, 2, 3))
        if k:
            noisy_patch = torch.rot90(noisy_patch, k, dims=(1, 2))
            clean_patch = torch.rot90(clean_patch, k, dims=(1, 2))
        return noisy_patch, clean_patch

    def __getitem__(self, idx):
        """Return ``(noisy_patch, clean_patch, noisy_file_name)`` for sample ``idx``.

        Patches have shape (3, patch_size, patch_size).

        Raises:
            ValueError: if the noisy and clean images differ in size, or the
                image is smaller than ``patch_size`` in either dimension.
        """
        # Consecutive indices, in blocks of ``patches_per_image``, map to the same image.
        noisy_path, clean_path = self.image_pairs[idx // self.patches_per_image]

        noisy = self._load_rgb(noisy_path)
        clean = self._load_rgb(clean_path)
        if noisy.shape != clean.shape:
            # Cropping the same window from differently sized images would
            # silently produce misaligned training pairs.
            raise ValueError(
                f"Noisy/clean size mismatch for {noisy_path.name}: "
                f"{tuple(noisy.shape)} vs {tuple(clean.shape)}"
            )

        _, H, W = noisy.shape
        ps = self.patch_size
        if H < ps or W < ps:
            raise ValueError(f"Image too small to crop: {noisy_path.name}")

        # The same random window for both images keeps them pixel-aligned.
        top = random.randint(0, H - ps)
        left = random.randint(0, W - ps)
        noisy_patch = noisy[:, top:top + ps, left:left + ps]
        clean_patch = clean[:, top:top + ps, left:left + ps]

        if self.augment:
            noisy_patch, clean_patch = self._augment_pair(noisy_patch, clean_patch)

        return noisy_patch, clean_patch, noisy_path.name

================DRANet MODEL================

In [None]:
%%writefile DRANet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    """Learned 2x upsampling whose output is padded/cropped to a reference size.

    Args:
        nc: number of channels (same for input and output).
        bias: whether the transposed convolution has a bias term.
    """

    def __init__(self, nc, bias):
        super().__init__()
        # kernel 2 / stride 2 exactly doubles H and W.
        self.up = nn.ConvTranspose2d(in_channels=nc, out_channels=nc, kernel_size=2, stride=2, bias=bias)

    def forward(self, x1, x):
        """Upsample ``x1`` and match its spatial size to that of ``x``.

        The size difference is split between the two sides (the extra pixel goes
        to the bottom/right). Negative amounts make ``F.pad`` crop instead.
        """
        upsampled = self.up(x1)
        dh = x.size()[2] - upsampled.size()[2]
        dw = x.size()[3] - upsampled.size()[3]
        top = dh // 2
        left = dw // 2
        return F.pad(upsampled, [left, dw - left, top, dh - top])


## Spatial Attention
class Basic(nn.Module):
    """Single-group Conv2d followed by ReLU (used by SAB for its attention map).

    Args:
        in_planes: input channels.
        out_planes: output channels (also exposed as ``out_channels``).
        kernel_size: convolution kernel size.
        padding: convolution padding.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_planes, out_planes, kernel_size, padding=0, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=padding, groups=1, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class ChannelPool(nn.Module):
    """Stack the per-pixel channel maximum and channel mean: (N, C, H, W) -> (N, 2, H, W)."""

    def forward(self, x):
        max_map = x.max(dim=1, keepdim=True).values
        mean_map = x.mean(dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)


class SAB(nn.Module):
    """Spatial attention: scale every pixel of ``x`` by a gate built from its channel statistics.

    ``ChannelPool`` stacks the per-pixel channel max and mean (2 channels). A 5x5
    conv + ReLU (``Basic``, padding 2 so H and W are kept) reduces them to one
    channel, and a sigmoid turns that into the gate. Because of the ReLU the
    gate is >= 0.5, so features are attenuated by at most half. The (N, 1, H, W)
    gate broadcasts over channels, so the output has the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()
        ksize = 5
        self.compress = ChannelPool()
        self.spatial = Basic(2, 1, ksize, padding=(ksize - 1) // 2, bias=False)

    def forward(self, x):
        pooled = self.compress(x)
        gate = torch.sigmoid(self.spatial(pooled))
        return x * gate


## Channel Attention Layer
class CAB(nn.Module):
    """Channel attention: rescale each channel by a sigmoid gate computed from its global average.

    Squeeze: global average pooling to (N, nc, 1, 1). Excite: 1x1 conv down to
    ``nc // reduction`` channels, ReLU, 1x1 conv back to ``nc`` channels, then sigmoid.

    Args:
        nc: number of channels.
        reduction: bottleneck reduction ratio.
        bias: whether the 1x1 convolutions have bias terms.
    """

    def __init__(self, nc, reduction=8, bias=False):
        super().__init__()
        hidden = nc // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(nc, hidden, kernel_size=1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, nc, kernel_size=1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_gate = self.conv_du(self.avg_pool(x))
        return x * channel_gate


class RAB(nn.Module):
    """Residual block with spatial attention (presumably DRANet's "residual attention block").

    One conv-ReLU-conv unit (``self.res``, 3x3 convs with padding 1) is applied
    four times with nested residual connections; the same module, and therefore
    the same weights, is reused at every step. The result goes through ``SAB``
    and a final residual connection to the input.

    The residual additions require ``in_channels == out_channels``. Spatial size
    is preserved.
    """

    def __init__(self, in_channels=64, out_channels=64, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.res = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
        )
        self.sab = SAB()

    def forward(self, x):
        h1 = x + self.res(x)
        h2 = h1 + self.res(h1)
        h3 = h2 + self.res(h2)
        skip = h1 + h3  # long skip from the first residual step
        h4 = skip + self.res(skip)
        attended = self.sab(x + h4)
        return x + attended


class HDRAB(nn.Module):
    """Dilated residual block with channel attention (presumably DRANet's "hybrid dilated residual attention block").

    A chain of 3x3 convolutions with dilations 1, 2, 3, 4, 3, 2, 1 (conv1,
    conv2, conv3, conv4, conv3_1, conv2_1, conv1_1), then conv_tail
    (dilation 1). The chain has residual additions and mirrored skip
    connections, followed by channel attention (CAB) and a block-level residual
    connection to the input.

    Every conv uses padding == dilation, so H and W are preserved. The residual
    additions require in_channels == out_channels.
    """
    def __init__(self, in_channels=64, out_channels=64, bias=True):
        super(HDRAB, self).__init__()
        kernel_size = 3
        reduction = 8  # bottleneck ratio passed to CAB

        self.cab = CAB(in_channels, reduction, bias)

        # padding == dilation keeps H and W unchanged for a 3x3 kernel.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=1, dilation=1, bias=bias)
        # In-place: overwrites conv1's output tensor (see forward()).
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=2, dilation=2, bias=bias)

        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=3, dilation=3, bias=bias)
        # In-place: overwrites conv3's output tensor (see forward()).
        self.relu3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=4, dilation=4, bias=bias)

        self.conv3_1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=3, dilation=3, bias=bias)
        self.relu3_1 = nn.ReLU(inplace=True)

        self.conv2_1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=2, dilation=2, bias=bias)

        self.conv1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=1, dilation=1, bias=bias)
        self.relu1_1 = nn.ReLU(inplace=True)

        self.conv_tail = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=1, dilation=1, bias=bias)

    def forward(self, y):
        """Map ``y`` to a tensor of the same shape; see the class docstring for the topology."""
        # Increasing dilation: 1 -> 2 and 3 -> 4, each pair closed by a residual add.
        y1 = self.conv1(y)
        # NOTE: relu1 is in-place, so after this line y1 (not just y1_1) holds
        # ReLU(conv1(y)). The skip `y7_1+y1` below therefore adds the activated
        # features. Making this ReLU non-in-place would change the network's function.
        y1_1 = self.relu1(y1)
        y2 = self.conv2(y1_1)
        y2_1 = y2 + y

        y3 = self.conv3(y2_1)
        # NOTE: also in-place, so y3 becomes ReLU(conv3(y2_1)); that activated
        # tensor is what the skip `y5_1+y3` below uses.
        y3_1 = self.relu3(y3)
        y4 = self.conv4(y3_1)
        y4_1 = y4 + y2_1

        # Decreasing dilation: 3 -> 2 -> 1, with skips mirrored from the first half.
        y5 = self.conv3_1(y4_1)
        y5_1 = self.relu3_1(y5)
        y6 = self.conv2_1(y5_1+y3)
        y6_1 = y6 + y4_1

        y7 = self.conv1_1(y6_1+y2_1)
        y7_1 = self.relu1_1(y7)
        y8 = self.conv_tail(y7_1+y1)
        y8_1 = y8 + y6_1

        # Channel attention, then a residual connection to the block input.
        y9 = self.cab(y8_1)
        y9_1 = y + y9

        return y9_1


class DRANet(nn.Module):
    """Dual-branch residual attention denoising network.

    Both branches start from the same ``conv_head`` features of the input ``x``
    (N, in_nc, H, W).

    * RAB branch: a U-shaped path (two stride-2 ``down`` convs, two ``up``
      steps with additive skips) built from one shared ``RAB`` applied five times.
    * HDRAB branch: one shared ``HDRAB`` applied five times at full resolution,
      with additive skip connections.

    Each branch is projected back to ``out_nc`` channels by the shared
    ``conv_tail`` and subtracted from ``x``. The two results are concatenated,
    fused by ``dual_tail``, and that fused map is subtracted from ``x`` again.

    ``self.rab``, ``self.hdrab``, ``self.down``, ``self.up``, ``self.conv_head``
    and ``self.conv_tail`` are each a single module used at several points, so
    their weights are shared. Attribute names and creation order are kept so
    existing state dicts load unchanged.

    Args:
        in_nc: input channels.
        out_nc: output channels. Must equal ``in_nc`` for the subtractions from
            ``x`` and the ``dual_tail`` input to line up.
        nc: feature width.
        bias: whether convolutions use bias terms.
    """

    def __init__(self, in_nc=3, out_nc=3, nc=128, bias=True):
        super(DRANet, self).__init__()
        kernel_size = 3

        self.conv_head = nn.Conv2d(in_nc, nc, kernel_size=kernel_size, padding=1, bias=bias)

        self.rab = RAB(nc, nc, bias)

        self.hdrab = HDRAB(nc, nc, bias)

        self.conv_tail = nn.Conv2d(nc, out_nc, kernel_size=kernel_size, padding=1, bias=bias)

        self.dual_tail = nn.Conv2d(2*out_nc, out_nc, kernel_size=kernel_size, padding=1, bias=bias)

        # Stride-2 conv halves the spatial size, e.g. 128x128 -> 64x64.
        self.down = nn.Conv2d(nc, nc, kernel_size=2, stride=2, bias=bias)

        self.up = Up(nc, bias)

    def forward(self, x):
        # Shared stem, computed once and used by both branches. Nothing below
        # modifies x1 in place, so reusing it is identical to recomputing it.
        x1 = self.conv_head(x)

        # --- RAB branch: full -> 1/2 -> 1/4 -> 1/2 -> full resolution ---
        x2 = self.rab(x1)
        x2_1 = self.down(x2)
        x3 = self.rab(x2_1)
        x3_1 = self.down(x3)
        x4 = self.rab(x3_1)
        x4_1 = self.up(x4, x3)  # upsample, then pad/crop to x3's H and W
        x5 = self.rab(x4_1 + x3)
        x5_1 = self.up(x5, x2)
        x6 = self.rab(x5_1 + x2)
        x7 = self.conv_tail(x6 + x1)
        X = x - x7

        # --- HDRAB branch: full resolution throughout ---
        y1 = x1
        y2 = self.hdrab(y1)
        y3 = self.hdrab(y2)
        y4 = self.hdrab(y3)
        y5 = self.hdrab(y4 + y3)
        y6 = self.hdrab(y5 + y2)
        y7 = self.conv_tail(y6 + y1)
        Y = x - y7

        # --- Fuse the two branch estimates ---
        z1 = torch.cat([X, Y], dim=1)
        z = self.dual_tail(z1)
        Z = x - z

        return Z

================Create the test script================

In [None]:
%%writefile test.py
import os
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from DRANet import DRANet
import argparse

# ==== Preprocessing function ====
def load_image(image_path):
    """Load ``image_path`` as an RGB float tensor with a leading batch dimension.

    Returns:
        tuple: ``(tensor, size)``. ``tensor`` has shape (1, 3, H, W) with values
        in [0, 1]; ``size`` is the PIL ``(width, height)`` of the image.
    """
    # The context manager closes the file handle. convert() returns a fully
    # loaded copy that remains valid after the file is closed.
    with Image.open(image_path) as image:
        rgb = image.convert("RGB")
    tensor = TF.to_tensor(rgb).unsqueeze(0)  # Shape: (1, C, H, W)
    return tensor, rgb.size

def save_image(tensor, save_path, size=None):
    img = tensor.squeeze(0).clamp(0, 1).cpu()
    img_pil = TF.to_pil_image(img)
    if size:
        img_pil = img_pil.resize(size, Image.BICUBIC)
    img_pil.save(save_path)

# ==== PSNR calculation ====
def calc_psnr(output, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB between ``output`` and ``target``, with peak value ``max_val``.

    Returns ``float('inf')`` when the tensors are identical (zero MSE);
    otherwise a 0-dim tensor.
    """
    mse = (output - target).pow(2).mean()
    if not mse:
        return float('inf')
    rmse = mse.sqrt()
    return 20 * torch.log10(max_val / rmse)

# ==== MAIN TEST FUNCTION ====
def main(args):
    """Denoise every .png/.jpg/.jpeg image in ``args.input_dir``.

    ``args`` must provide ``model_path`` (a DRANet state_dict file),
    ``input_dir`` and ``output_dir``. Denoised images are written to
    ``output_dir`` under their original file names.

    If ``args`` also has a ``clean_dir`` attribute (optional), each output is
    compared with the same-named image in that directory, and the per-image
    and mean PSNR are printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = DRANet().to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    clean_dir = getattr(args, "clean_dir", None)

    # Sorted for a reproducible processing/printing order (os.listdir order is arbitrary).
    image_paths = [os.path.join(args.input_dir, f) for f in sorted(os.listdir(args.input_dir))
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    psnrs = []
    for path in image_paths:
        filename = os.path.basename(path)
        noisy_tensor, size = load_image(path)
        noisy_tensor = noisy_tensor.to(device)

        with torch.no_grad():
            output_tensor = model(noisy_tensor)

        save_path = os.path.join(args.output_dir, filename)
        save_image(output_tensor, save_path, size=size)
        print(f"Saved: {save_path}")

        if not clean_dir:
            continue
        clean_path = os.path.join(clean_dir, filename)
        if not os.path.isfile(clean_path):
            print(f"  No reference image for PSNR: {clean_path}")
            continue
        clean_tensor, _ = load_image(clean_path)
        if clean_tensor.shape != output_tensor.shape:
            print(f"  Skipping PSNR, size mismatch: {tuple(clean_tensor.shape)} vs {tuple(output_tensor.shape)}")
            continue
        psnr = float(calc_psnr(output_tensor.clamp(0, 1), clean_tensor.to(device)))
        psnrs.append(psnr)
        print(f"  PSNR: {psnr:.2f} dB")

    if psnrs:
        print(f"Mean PSNR over {len(psnrs)} images: {sum(psnrs) / len(psnrs):.2f} dB")

if __name__ == "__main__":
    # Command-line interface (argparse was imported but unused). The defaults
    # are the original Colab paths, so `python test.py` with no arguments
    # behaves exactly as before.
    parser = argparse.ArgumentParser(description="Denoise a directory of images with a trained DRANet model.")
    parser.add_argument("--model_path", default='/content/drive/MyDrive/DRANet/Model/model_dranet_guass.pth',
                        help="Path to the trained model (state_dict)")
    # NOTE(review): the default input_dir is the *Clean* test set. If the intent
    # is to denoise noisy inputs, this should probably point at a noisy folder — confirm.
    parser.add_argument("--input_dir", default='/content/drive/MyDrive/DRANet/Data/Test/Clean',
                        help="Directory of input images to denoise")
    parser.add_argument("--output_dir", default='/content/drive/MyDrive/DRANet/test_output',
                        help="Directory to save denoised images")

    args = parser.parse_args()
    main(args)