In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

##############################################
# 1. Define Basic Building Blocks
##############################################

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages with a residual (skip) connection.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
    Skip path: identity when ``in_channels == out_channels``; otherwise a
    1x1 conv + BN projects the input to ``out_channels`` so the two paths can
    be summed. A ReLU is applied after the sum. All convolutions use stride 1
    and the 3x3 ones use padding=1, so the spatial size is preserved.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are created in this exact order and under these names so
        # that seeded weight initialisation and state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(self.bn1(residual))
        residual = self.bn2(self.conv2(residual))
        # Add the (possibly projected) input back onto the main path.
        residual += self.shortcut(x)
        return self.relu(residual)

class AttentionGate(nn.Module):
    """Additive attention gate (as in Attention U-Net) for a skip connection.

    The gating signal ``g`` (decoder side) and the skip features ``x``
    (encoder side) are each projected to ``F_int`` channels by 1x1 conv + BN.
    The projections are summed and passed through ReLU. They are then reduced
    to a single-channel map by 1x1 conv + BN + sigmoid, and ``x`` is scaled
    element-wise by that map.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip-connection features ``x``.
        F_int: channels of the intermediate projection.

    ``g`` and ``x`` must share the same spatial size: their projections are
    added directly, with no resampling inside the gate.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def conv1x1_bn(c_in, c_out):
            # 1x1 convolution (with bias) followed by batch normalisation.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            )

        self.W_g = conv1x1_bn(F_g, F_int)
        self.W_x = conv1x1_bn(F_l, F_int)
        # Same layout as before: psi.0 = conv, psi.1 = BN, psi.2 = sigmoid.
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)  # one channel, values in (0, 1)
        # Scale the encoder features by the attention coefficients.
        return x * attention

##############################################
# 2. Define the Attention Residual U-Net
##############################################

class AttentionResidualUNet(nn.Module):
    """U-Net with residual conv blocks and attention-gated skip connections.

    Encoder: one ResidualBlock per entry of ``features``. Each block's output is
    kept as a skip connection and then 2x2 max-pooled.
    Bottleneck: a ResidualBlock that doubles ``features[-1]``.
    Decoder (deepest level first), at each level:
      1. a 2x2 stride-2 transposed conv;
      2. an AttentionGate on the matching skip;
      3. channel concatenation;
      4. a ResidualBlock.
    A final 1x1 conv followed by sigmoid produces the output, which lies in
    the [0, 1] range.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        """
        Args:
            in_channels: channels of the input image.
            out_channels: channels of the output map.
            features: encoder widths, shallowest first (one level per entry).
                Any non-empty sequence works; the default is a tuple so no
                mutable default object is shared between instances.

        Raises:
            ValueError: if ``features`` is empty.
        """
        super(AttentionResidualUNet, self).__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one encoder width")

        # Module creation order is unchanged so seeded initialisation matches.
        self.encoder_blocks = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one residual block per level; outputs become skip connections.
        prev_channels = in_channels
        for feature in features:
            self.encoder_blocks.append(ResidualBlock(prev_channels, feature))
            prev_channels = feature

        # Bottleneck: an extra residual block that doubles the channel count.
        self.bottleneck = ResidualBlock(features[-1], features[-1] * 2)

        # Decoder: upsampling layers, attention gates and residual blocks.
        self.upconv_blocks = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.attention_gates = nn.ModuleList()
        decoder_in_channels = features[-1] * 2
        for feature in reversed(features):
            self.upconv_blocks.append(
                nn.ConvTranspose2d(decoder_in_channels, feature, kernel_size=2, stride=2)
            )
            # Gating signal = upsampled decoder features; gated input = encoder skip.
            self.attention_gates.append(
                AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2)
            )
            # Skip and upsampled features are concatenated, doubling the channels.
            self.decoder_blocks.append(ResidualBlock(feature * 2, feature))
            decoder_in_channels = feature

        # Final 1x1 convolution to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial_size(x, reference):
        """Zero-pad (or crop) ``x`` so its last two dims equal those of ``reference``."""
        diff_h = reference.size(-2) - x.size(-2)
        diff_w = reference.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad order for the last two dims is (left, right, top, bottom);
        # negative amounts crop instead of pad.
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        """Run the network and return sigmoid outputs.

        Inputs whose height/width are not divisible by ``2 ** len(features)``
        are supported. MaxPool2d (ceil_mode=False) floors odd sizes, so the
        upsampled decoder tensor can be a pixel smaller than its skip
        connection. It is padded to the skip's size before gating and
        concatenation; otherwise ``torch.cat`` would fail.
        """
        skip_connections = []
        for enc in self.encoder_blocks:
            x = enc(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest level first, matching the decoder

        for upconv, att_gate, decoder, skip_connection in zip(
            self.upconv_blocks, self.attention_gates, self.decoder_blocks, skip_connections
        ):
            x = upconv(x)
            x = self._match_spatial_size(x, skip_connection)
            # Re-weight the encoder features with the attention map.
            skip_connection = att_gate(g=x, x=skip_connection)
            x = torch.cat([skip_connection, x], dim=1)
            x = decoder(x)

        return torch.sigmoid(self.final_conv(x))


In [2]:
!pip install opencv-python-headless torchsummary # For install cv2

Collecting opencv-python-headless
  Using cached opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Using cached opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)
Using cached torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary, opencv-python-headless
Successfully installed opencv-python-headless-4.11.0.86 torchsummary-1.5.1


In [3]:
from torchinfo import summary

# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model on that device and show a layer-by-layer summary
# for a single 1-channel 64x64 input.
model = AttentionResidualUNet(in_channels=1, out_channels=1).to(device)
summary(model, input_size=(1, 1, 64, 64), device=device)


Layer (type:depth-idx)                   Output Shape              Param #
AttentionResidualUNet                    [1, 1, 64, 64]            --
├─ModuleList: 1-7                        --                        (recursive)
│    └─ResidualBlock: 2-1                [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           576
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-3                    [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-4                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 64, 64]           128
│    │    └─Sequential: 3-6              [1, 64, 64, 64]           192
│    │    └─ReLU: 3-7                    [1, 64, 64, 64]           --
├─MaxPool2d: 1-2                         [1, 64, 32, 32]           --
├─ModuleList: 1-7                        --                        (recursive)
│    └─ResidualBlock: 2-2                [1, 128, 32, 32]  

In [2]:
!pip install torchinfo




In [4]:
import os
import time
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp

# Fix random seeds for reproducibility.
seed = 1
torch.manual_seed(seed)  # seeds the CPU and all CUDA device generators
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU (already covered by torch.manual_seed)

# With deterministic algorithms enforced, PyTorch raises RuntimeError on CUDA
# cuBLAS ops unless this workspace setting is present. setdefault keeps any
# value the user exported already.
# NOTE(review): cuBLAS reads this when a handle is created; ideally export it
# before CUDA is first used in the kernel.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.backends.cudnn.benchmark = False  # benchmark mode may select non-deterministic kernels
torch.use_deterministic_algorithms(True)

class PatchDataset(Dataset):
    """Paired (input, ground-truth) patches loaded from two ``.npy`` files.

    ``data_path`` must contain ``original_patches.npy`` and
    ``ground_truth_patches.npy``. Both arrays are loaded fully into memory and
    indexed along their first axis. Each item is a pair of float32 tensors with
    a new leading channel dimension of size 1, divided by 255.

    NOTE(review): the /255 scaling assumes pixel values in 0..255 (e.g. uint8
    patches); confirm against how the .npy files were produced.
    """

    def __init__(self, data_path):
        """
        Args:
            data_path: directory holding the two patch arrays.

        Raises:
            ValueError: if the two arrays contain a different number of patches
                (otherwise indexing would fail late or silently drop targets).
        """
        self.original_patches = np.load(os.path.join(data_path, 'original_patches.npy'))  # network inputs
        self.ground_truth_patches = np.load(os.path.join(data_path, 'ground_truth_patches.npy'))  # targets
        if len(self.original_patches) != len(self.ground_truth_patches):
            raise ValueError(
                f"Patch count mismatch in {data_path!r}: "
                f"{len(self.original_patches)} original vs "
                f"{len(self.ground_truth_patches)} ground-truth patches"
            )
        self.num_samples = len(self.original_patches)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Convert to float tensors, add a channel axis and scale to [0, 1].
        original_patch = torch.tensor(self.original_patches[idx], dtype=torch.float32).unsqueeze(0) / 255.0
        ground_truth_patch = torch.tensor(self.ground_truth_patches[idx], dtype=torch.float32).unsqueeze(0) / 255.0
        return original_patch, ground_truth_patch

# Training set: the pre-extracted patch arrays stored under data/.
train_data_path = 'data/train_patches_64x64x25'
train_dataset = PatchDataset(data_path=train_data_path)

# Worker initialization function (the global `seed` is set in the seeding section above)
def worker_init_fn(worker_id):
    """Seed NumPy inside a DataLoader worker process.

    In every worker, PyTorch already sets the torch seed to
    ``base_seed + worker_id``. ``base_seed`` is drawn from the main-process
    generator, which is seeded above, and the result is readable through
    ``torch.initial_seed()``. Deriving NumPy's seed from it keeps runs
    reproducible and gives each worker, in each epoch, a distinct stream.
    Reseeding with a constant instead would replay identical random numbers
    every epoch.

    Args:
        worker_id: index of the worker. It is not needed here because the
            per-worker offset is already part of ``torch.initial_seed()``.
    """
    np.random.seed(torch.initial_seed() % 2**32)  # NumPy seeds must fit in 32 bits

# Create the training DataLoader (batches follow the stored patch order).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    num_workers=6,
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    # Author's rationale: "need continuous context".
    # NOTE(review): the training loop treats every batch independently, and
    # ordered batches bias BatchNorm statistics; consider shuffle=True.
    shuffle=False,
)

# Model setup: a freshly initialised AttentionResidualUNet on the available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AttentionResidualUNet(in_channels=1, out_channels=1).to(device)

# torch.cuda.amp only works on CUDA. On CPU it would warn and disable itself,
# so disable it explicitly there.
use_amp = device.type == 'cuda'

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training configuration
epochs = 10
log_every = 1000  # print batch loss / GPU memory every N batches
model_save_path = 'model_64x64x25/attresunet_trained_64x64x25.pth'
scaler = amp.GradScaler(enabled=use_amp)

num_batches = len(train_loader)
if num_batches == 0:
    # Fail fast rather than with ZeroDivisionError when averaging the epoch loss.
    raise ValueError("train_loader yields no batches; check train_data_path")

print(f"Total patches: {len(train_dataset)}")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU available: {torch.cuda.get_device_name(0)}")

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    start_time = time.time()

    print(f"Epoch {epoch+1}/{epochs}")
    for batch_idx, (original_patches, ground_truth_patches) in enumerate(train_loader):
        # The loader uses pin_memory=True, so these copies can overlap with compute.
        inputs = original_patches.to(device, non_blocking=True)
        targets = ground_truth_patches.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass with mixed precision (no-op when use_amp is False)
        with amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Scaled backward pass; GradScaler skips the step on inf/NaN gradients
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Batch progress and memory usage
        if batch_idx % log_every == 0:
            print(f"Batch {batch_idx}/{num_batches} - Loss: {batch_loss:.4f}")
            if device.type == 'cuda':
                gpu_memory = torch.cuda.memory_allocated(device) / 1e6  # bytes -> MB
                print(f"    GPU Memory Usage: {gpu_memory:.2f} MB")

    epoch_loss = running_loss / num_batches
    epoch_time = time.time() - start_time

    print(f"Epoch [{epoch+1}/{epochs}] completed in {epoch_time:.2f}s")
    print(f"Average Loss: {epoch_loss:.4f}")
    print("-" * 50)

# Save the trained weights. Create the checkpoint's own parent directory so
# torch.save cannot fail with FileNotFoundError.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model.state_dict(), model_save_path)
print("Model training completed!")

Total patches: 187150
Using device: cuda
GPU available: NVIDIA A10G
Epoch 1/10
Batch 0/2925 - Loss: 0.2639
    GPU Memory Usage: 666.25 MB
Batch 1000/2925 - Loss: 0.0111
    GPU Memory Usage: 667.82 MB
Batch 2000/2925 - Loss: 0.0183
    GPU Memory Usage: 667.82 MB
Epoch [1/10] completed in 233.46s
Average Loss: 0.0228
--------------------------------------------------
Epoch 2/10
Batch 0/2925 - Loss: 0.0210
    GPU Memory Usage: 667.82 MB
Batch 1000/2925 - Loss: 0.0061
    GPU Memory Usage: 667.82 MB
Batch 2000/2925 - Loss: 0.0178
    GPU Memory Usage: 667.82 MB
Epoch [2/10] completed in 218.90s
Average Loss: 0.0165
--------------------------------------------------
Epoch 3/10
Batch 0/2925 - Loss: 0.0113
    GPU Memory Usage: 667.82 MB
Batch 1000/2925 - Loss: 0.0043
    GPU Memory Usage: 667.82 MB
Batch 2000/2925 - Loss: 0.0172
    GPU Memory Usage: 667.82 MB
Epoch [3/10] completed in 194.40s
Average Loss: 0.0151
--------------------------------------------------
Epoch 4/10
Batch 0/2925

In [5]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model. map_location moves tensors saved from a GPU run onto
# the current device, so evaluation also works on a CPU-only machine.
model = AttentionResidualUNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(
    torch.load('model_64x64x25/attresunet_trained_64x64x25.pth', map_location=device)
)
model.eval()  # BatchNorm uses running statistics during evaluation

# Define the dataset and DataLoader
class PatchDataset(Dataset):
    """Paired (input, ground-truth) patches loaded from two ``.npy`` files.

    ``data_path`` must contain ``original_patches.npy`` and
    ``ground_truth_patches.npy``. Both arrays are loaded fully into memory and
    indexed along their first axis. Each item is a pair of float32 tensors with
    a new leading channel dimension of size 1, divided by 255.

    NOTE(review): the /255 scaling assumes pixel values in 0..255 (e.g. uint8
    patches); confirm against how the .npy files were produced.
    """

    def __init__(self, data_path):
        """
        Args:
            data_path: directory holding the two patch arrays.

        Raises:
            ValueError: if the two arrays contain a different number of patches.
        """
        self.original_patches = np.load(os.path.join(data_path, 'original_patches.npy'))
        self.ground_truth_patches = np.load(os.path.join(data_path, 'ground_truth_patches.npy'))
        if len(self.original_patches) != len(self.ground_truth_patches):
            raise ValueError(
                f"Patch count mismatch in {data_path!r}: "
                f"{len(self.original_patches)} original vs "
                f"{len(self.ground_truth_patches)} ground-truth patches"
            )
        self.num_samples = len(self.original_patches)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Convert to float tensors, add a channel axis and scale to [0, 1].
        original_patch = torch.tensor(self.original_patches[idx], dtype=torch.float32).unsqueeze(0) / 255.0
        ground_truth_patch = torch.tensor(self.ground_truth_patches[idx], dtype=torch.float32).unsqueeze(0) / 255.0
        return original_patch, ground_truth_patch

# Helper function for SSIM (using skimage)
def calculate_ssim_custom(output, target):
    """SSIM between two tensors, computed on 8-bit versions of the images.

    Each tensor is squeezed, moved to the CPU, clipped to [0, 1] and
    converted to uint8 with ``img_as_ubyte``. The results are passed to
    skimage's ``structural_similarity`` with ``data_range=255``.

    NOTE(review): after squeeze() this assumes a single 2-D image per call
    (the evaluation loop uses batch_size=1 and 1 channel); confirm for other uses.
    """
    def to_ubyte(tensor):
        # Tensor -> clipped float array in [0, 1] -> uint8 image.
        return img_as_ubyte(np.clip(tensor.squeeze().cpu().numpy(), 0, 1))

    return ssim(to_ubyte(output), to_ubyte(target), data_range=255.0)

##############################################
# Custom MSE and PSNR Functions from Scratch
##############################################

def calculate_mse(output, target):
    """
    Compute the Mean Squared Error between two images.

    Both inputs are converted to float64 before subtracting. For integer images
    (e.g. uint8), subtracting and squaring in the native dtype would wrap
    around and give a wrong MSE.

    Args:
        output (array-like): The output image.
        target (array-like): The ground truth image; must broadcast with ``output``.
    Returns:
        float: The MSE value.
    """
    diff = np.asarray(output, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    return float(np.mean(diff ** 2))

def calculate_psnr_from_scratch(output, target, max_pixel=1.0):
    """
    Peak Signal-to-Noise Ratio, in decibels, between an output image and its target.

    Defined as PSNR = 10 * log10(max_pixel**2 / MSE), where the MSE comes from
    ``calculate_mse``. Identical images (MSE == 0) give ``float('inf')``.

    Args:
        output (np.ndarray): The output image.
        target (np.ndarray): The ground truth image.
        max_pixel (float): Largest possible pixel value (1.0 for images scaled to [0, 1]).
    Returns:
        float: The PSNR value in decibels.
    """
    mse_value = calculate_mse(output, target)
    if mse_value != 0:
        peak_power = max_pixel ** 2
        return 10 * math.log10(peak_power / mse_value)
    return float('inf')

##############################################
# Evaluation on Test Data
##############################################

data_path = 'data/test_patches_64x64x25'
patch_dataset = PatchDataset(data_path)
# One patch per batch, in stored order, so every metric is a per-patch value.
patch_loader = DataLoader(patch_dataset, batch_size=1, shuffle=False)

# Per-patch metrics, plus the squeezed model outputs kept for later reconstruction.
mse_list, psnr_list, ssim_list = [], [], []
reconstructed_patches = []

# Inference only: no autograd graph is needed for any patch.
with torch.no_grad():
    for original_patch, ground_truth_patch in patch_loader:
        original_patch = original_patch.to(device)
        ground_truth_patch = ground_truth_patch.to(device)
        output_patch = model(original_patch)

        # NumPy copies for the MSE / PSNR helpers.
        output_np = output_patch.squeeze().cpu().numpy()
        ground_truth_np = ground_truth_patch.squeeze().cpu().numpy()

        mse_list.append(calculate_mse(output_np, ground_truth_np))
        psnr_list.append(calculate_psnr_from_scratch(output_np, ground_truth_np, max_pixel=1.0))
        ssim_list.append(calculate_ssim_custom(output_patch, ground_truth_patch))
        reconstructed_patches.append(output_np)

# Average each metric over all test patches.
average_mse, average_psnr, average_ssim = (
    np.mean(values) for values in (mse_list, psnr_list, ssim_list)
)

print(f"Average MSE: {average_mse:.4f}")
print(f"Average PSNR: {average_psnr:.4f} dB")
print(f"Average SSIM: {average_ssim:.4f}")


Average MSE: 0.0399
Average PSNR: 21.9758 dB
Average SSIM: 0.8463
