In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import re
import math

In [2]:
class PositionalEncoding2D(nn.Module):
    """Adds a learned projection of normalized pixel coordinates to a feature map.

    A two-channel grid holding the horizontal and vertical position of every
    pixel, each scaled to [-1, 1], is mapped to ``embed_dim`` channels by a 1x1
    convolution and summed with the input. ``embed_dim`` therefore has to be
    compatible with the channel count of the tensor passed to ``forward``.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        # 1x1 convolution: 2 coordinate channels -> embed_dim channels.
        self.proj = nn.Conv2d(2, embed_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected coordinate grid (same shape as ``x``)."""
        batch, _, height, width = x.shape

        rows = torch.linspace(-1.0, 1.0, steps=height, device=x.device)
        cols = torch.linspace(-1.0, 1.0, steps=width, device=x.device)

        # Broadcast each 1-D ramp over the batch and the other spatial axis.
        row_plane = rows.reshape(1, 1, height, 1).expand(batch, 1, height, width)
        col_plane = cols.reshape(1, 1, 1, width).expand(batch, 1, height, width)

        # Channel 0 carries the horizontal coordinate, channel 1 the vertical one.
        coords = torch.cat((col_plane, row_plane), dim=1)
        return x + self.proj(coords)

class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    Both convolutions use ``padding=1`` and no bias, so the spatial size is
    preserved and only the channel count changes (``in_ch`` -> ``out_ch``).
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Layer positions inside the Sequential determine the state_dict keys.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        features = self.double_conv(x)
        return features

class UpConv(nn.Module):
    """Learned 2x spatial upsampling through a stride-2 transposed convolution."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # kernel_size == stride == 2: every input pixel produces one 2x2 output tile.
        self.up = nn.ConvTranspose2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        """Return ``x`` upsampled to twice its height and width."""
        upsampled = self.up(x)
        return upsampled

class AttentionBlock(nn.Module):
    """Additive attention gate that rescales skip features with a gating signal.

    Args:
        F_g: channel count of the gating tensor ``g``.
        F_l: channel count of the skip tensor ``x``.
        F_int: channel count of the intermediate space both are projected into.
    """

    def __init__(self, F_g: int, F_l: int, F_int: int):
        super().__init__()

        def projection(channels_in: int) -> nn.Sequential:
            # 1x1 convolution into the shared intermediate space, then batch-norm.
            return nn.Sequential(
                nn.Conv2d(channels_in, F_int, kernel_size=1, bias=True),
                nn.BatchNorm2d(F_int),
            )

        self.W_g = projection(F_g)
        self.W_x = projection(F_l)
        # Collapses the combined features to a single-channel mask in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the attention mask computed from ``x`` and ``g``."""
        gate_features = self.W_g(g)
        skip_features = self.W_x(x)
        combined = self.relu(gate_features + skip_features)
        # The one-channel mask broadcasts across all channels of x.
        return x * self.psi(combined)

class AttentionUNet(nn.Module):
    """U-Net style encoder/decoder with attention-gated skip connections.

    The encoder has four levels separated by three 2x max-pool steps; the
    decoder has three stages, each of which upsamples, gates the matching
    encoder feature map with an ``AttentionBlock``, concatenates and fuses with
    a ``ConvBlock``. Every level's features pass through a
    ``PositionalEncoding2D``. The head is a 1x1 convolution followed by a
    sigmoid, so outputs lie in (0, 1).

    Args:
        in_channels: channel count of the input image tensor.
        out_channels: channel count of the output tensor.
        enc_channels: four channel widths, ordered from the full-resolution
            level to the deepest level.

    Note:
        The input height and width presumably need to be divisible by 8 so that
        the three upsampled maps line up with their skip connections.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3, enc_channels: tuple = (128, 64, 32, 16)):
        super().__init__()
        assert len(enc_channels) == 4
        c0, c1, c2, c3 = enc_channels

        # Encoder (submodules are created in the same order as they are used).
        self.in_conv = ConvBlock(in_channels, c0)
        self.pos0 = PositionalEncoding2D(c0)
        self.pool = nn.MaxPool2d(2)

        self.enc1 = ConvBlock(c0, c1)
        self.pos1 = PositionalEncoding2D(c1)

        self.enc2 = ConvBlock(c1, c2)
        self.pos2 = PositionalEncoding2D(c2)

        self.enc3 = ConvBlock(c2, c3)
        self.pos3 = PositionalEncoding2D(c3)

        # Decoder stage 3: deepest level -> level 2.
        self.up_stage3 = UpConv(c3, c2)
        self.att_stage3 = AttentionBlock(F_g=c2, F_l=c2, F_int=c2 // 2)
        self.conv_stage3 = ConvBlock(2 * c2, c2)
        self.pos_d3 = PositionalEncoding2D(c2)

        # Decoder stage 2: level 2 -> level 1.
        self.up_stage2 = UpConv(c2, c1)
        self.att_stage2 = AttentionBlock(F_g=c1, F_l=c1, F_int=c1 // 2)
        self.conv_stage2 = ConvBlock(2 * c1, c1)
        self.pos_d2 = PositionalEncoding2D(c1)

        # Decoder stage 1: level 1 -> full resolution.
        self.up_stage1 = UpConv(c1, c0)
        self.att_stage1 = AttentionBlock(F_g=c0, F_l=c0, F_int=c0 // 2)
        self.conv_stage1 = ConvBlock(2 * c0, c0)
        self.pos_d1 = PositionalEncoding2D(c0)

        self.out_conv = nn.Sequential(
            nn.Conv2d(c0, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _decode_stage(deep, skip, up, gate, fuse, pos):
        """Upsample ``deep``, gate ``skip`` with it, then concatenate, fuse and encode."""
        upsampled = up(deep)
        gated_skip = gate(skip, upsampled)
        merged = torch.cat([upsampled, gated_skip], dim=1)
        return pos(fuse(merged))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full encoder/decoder and return the sigmoid-activated output."""
        # Encoder: keep every level's features for the skip connections.
        features = self.pos0(self.in_conv(x))
        skips = [features]
        encoder_levels = (
            (self.enc1, self.pos1),
            (self.enc2, self.pos2),
            (self.enc3, self.pos3),
        )
        for conv, pos in encoder_levels:
            features = pos(conv(self.pool(features)))
            skips.append(features)
        x0, x1, x2, x3 = skips

        # Decoder: walk back up, consuming skips from deepest to shallowest.
        d3 = self._decode_stage(x3, x2, self.up_stage3, self.att_stage3, self.conv_stage3, self.pos_d3)
        d2 = self._decode_stage(d3, x1, self.up_stage2, self.att_stage2, self.conv_stage2, self.pos_d2)
        d1 = self._decode_stage(d2, x0, self.up_stage1, self.att_stage1, self.conv_stage1, self.pos_d1)

        return self.out_conv(d1)

In [3]:
class AugraphyPatchDataset(Dataset):
    """Patch-level dataset of (augmented, clean) image pairs.

    Every image file found in ``aug_dirs`` is paired with the clean image in
    ``input_dir`` whose lower-cased base name is the longest one that either
    equals the augmented base name or is followed by ``_`` in it. Each pair is
    then tiled into overlapping fixed-size patches, and one dataset item is one
    patch position of one pair.

    Args:
        input_dir: directory holding the clean images.
        aug_dirs: iterable of directories holding augmented images.
        patch_size: ``(width, height)`` of every patch, in pixels.
        overlap: number of pixels neighbouring patches share along each axis.
        transform: callable applied separately to the augmented and the clean
            PIL patch; ``transforms.ToTensor()`` is used when omitted.
        num_images_to_use: if a positive integer, only the first that many
            matched pairs are kept; ``None`` or a non-positive value keeps all.

    Raises:
        ValueError: if ``overlap`` is not smaller than both patch dimensions.

    Notes:
        * Directory listings are sorted so the pair order (and therefore the
          subset picked by ``num_images_to_use``) is reproducible;
          ``os.listdir`` alone returns entries in arbitrary order.
        * The patch grid is derived from the augmented image's size only; the
          clean image is assumed to have the same size (TODO confirm for the
          data in use).
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, input_dir, aug_dirs, patch_size=(256, 256), overlap=96, transform=None, num_images_to_use=None):
        self.patch_w, self.patch_h = patch_size
        self.overlap = overlap
        self.stride_w = self.patch_w - self.overlap
        self.stride_h = self.patch_h - self.overlap
        if self.stride_w <= 0 or self.stride_h <= 0:
            raise ValueError("Stride must be positive. Patch size must be greater than overlap.")

        self.transform = transform if transform else transforms.ToTensor()

        # Lower-cased base name -> path of the clean image.
        self.clean_image_map = {
            os.path.splitext(f)[0].lower(): os.path.join(input_dir, f)
            for f in sorted(os.listdir(input_dir))
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        }

        all_aug_paths = []
        for aug_dir in aug_dirs:
            for aug_file_name in sorted(os.listdir(aug_dir)):
                if not aug_file_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                aug_base_name = os.path.splitext(aug_file_name)[0].lower()
                parent = self._match_clean_base(aug_base_name, self.clean_image_map)
                if parent:
                    all_aug_paths.append(
                        (os.path.join(aug_dir, aug_file_name), self.clean_image_map[parent])
                    )

        if num_images_to_use is not None and num_images_to_use > 0:
            self.data_pairs = all_aug_paths[:num_images_to_use]
        else:
            self.data_pairs = all_aug_paths

        # One (pair index, left, top) entry per patch, in row-major order per image.
        self.patch_coords = []
        for img_idx, (aug_img_path, _) in enumerate(self.data_pairs):
            # Only the header is needed for the size; the file is closed right away.
            with Image.open(aug_img_path) as img:
                img_width, img_height = img.size

            x_starts = self._patch_starts(img_width, self.patch_w, self.stride_w)
            y_starts = self._patch_starts(img_height, self.patch_h, self.stride_h)
            for y_start in y_starts:
                for x_start in x_starts:
                    self.patch_coords.append((img_idx, x_start, y_start))

    @staticmethod
    def _match_clean_base(aug_base_name, clean_bases):
        """Return the longest clean base that equals ``aug_base_name`` or precedes a ``_`` in it."""
        candidates = (
            base for base in clean_bases
            if aug_base_name == base or aug_base_name.startswith(base + '_')
        )
        return max(candidates, key=len, default=None)

    @staticmethod
    def _patch_starts(extent, patch, stride):
        """Return the patch start offsets that cover ``extent`` pixels along one axis.

        An axis no longer than the patch gets a single patch at offset 0.
        Otherwise patches are placed every ``stride`` pixels and the last one is
        shifted back so that it ends exactly at the image border.
        """
        if extent <= patch:
            return [0]
        count = math.ceil((extent - patch + stride) / stride)
        last_start = extent - patch
        return [min(index * stride, last_start) for index in range(count)]

    def __len__(self):
        """Total number of patches over all image pairs."""
        return len(self.patch_coords)

    def __getitem__(self, idx):
        """Return the ``(augmented_patch, clean_patch)`` pair for patch ``idx``."""
        img_idx, x_start, y_start = self.patch_coords[idx]
        aug_img_path, clean_img_path = self.data_pairs[img_idx]
        box = (x_start, y_start, x_start + self.patch_w, y_start + self.patch_h)

        # Context managers release the file handles as soon as the pixels have
        # been read; convert() returns an in-memory image that outlives the file.
        with Image.open(aug_img_path) as aug_file:
            aug_patch = aug_file.convert("RGB").crop(box)
        with Image.open(clean_img_path) as clean_file:
            clean_patch = clean_file.convert("RGB").crop(box)

        if self.transform:
            aug_patch = self.transform(aug_patch)
            clean_patch = self.transform(clean_patch)

        return aug_patch, clean_patch

In [4]:
def _window_starts(extent, patch, stride):
    """Return the window start offsets that cover ``extent`` pixels along one axis."""
    if extent <= patch:
        return [0]
    count = math.ceil((extent - patch + stride) / stride)
    # The final window is shifted back so that it ends exactly at the border.
    return [min(index * stride, extent - patch) for index in range(count)]


def predict_full_image_sliding_window(model, image_path, patch_size, overlap, transform, device):
    """Run ``model`` over a whole image with overlapping patches and average the overlaps.

    Args:
        model: module called as ``model(batch)`` on a ``(1, C, patch_h, patch_w)``
            tensor; it is switched to eval mode and left in that mode.
        image_path: path of the image to process; it is read as RGB.
        patch_size: ``(width, height)`` of the sliding window, in pixels.
        overlap: number of pixels neighbouring windows share along each axis.
        transform: callable turning a PIL patch into a ``(C, H, W)`` tensor.
        device: device on which the model is evaluated and results accumulated.

    Returns:
        A CPU float tensor of shape ``(3, image_height, image_width)`` holding
        the per-pixel mean of all patch predictions that covered that pixel.

    Raises:
        ValueError: if ``overlap`` is not smaller than both patch dimensions.
    """
    patch_w, patch_h = patch_size
    stride_w = patch_w - overlap
    stride_h = patch_h - overlap
    # Checked before any other work: a non-positive stride would otherwise
    # divide by zero or produce a nonsensical window count.
    if stride_w <= 0 or stride_h <= 0:
        raise ValueError("Stride must be positive. Patch size must be greater than overlap.")

    model.eval()
    # convert() yields an in-memory copy, so the file handle can be released here.
    with Image.open(image_path) as source_image:
        input_image = source_image.convert("RGB")
    img_width, img_height = input_image.size

    x_starts = _window_starts(img_width, patch_w, stride_w)
    y_starts = _window_starts(img_height, patch_h, stride_h)

    output_full_image = torch.zeros((1, 3, img_height, img_width), device=device, dtype=torch.float32)
    overlap_counts = torch.zeros((1, 1, img_height, img_width), device=device, dtype=torch.float32)

    with torch.no_grad():
        for y_start in y_starts:
            for x_start in x_starts:
                # Portion of the window that lies inside the image. It is smaller
                # than the patch only when the image itself is smaller than the
                # patch along that axis.
                valid_w = min(patch_w, img_width - x_start)
                valid_h = min(patch_h, img_height - y_start)

                # The crop box always has the full patch size, so the model input
                # shape is constant (PIL is expected to pad the part of the box
                # that falls outside the image).
                input_patch_pil = input_image.crop((x_start, y_start, x_start + patch_w, y_start + patch_h))
                input_patch_tensor = transform(input_patch_pil).unsqueeze(0).to(device)

                predicted_patch_tensor = model(input_patch_tensor)

                # Accumulate only the in-image part of the prediction; adding the
                # full patch to a border-clipped slice would raise a shape error.
                rows = slice(y_start, y_start + valid_h)
                cols = slice(x_start, x_start + valid_w)
                output_full_image[:, :, rows, cols] += predicted_patch_tensor[:, :, :valid_h, :valid_w]
                overlap_counts[:, :, rows, cols] += 1

    # Every pixel is covered at least once; the clamp only guards the division.
    output_full_image = output_full_image / torch.clamp(overlap_counts, min=1)
    return output_full_image.squeeze(0).cpu()

In [5]:
import random

In [16]:
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    input_dir = "/kaggle/input/augraphy-dataset/augraphy/input"
    aug_dirs = [
        "/kaggle/input/augraphy-dataset/augraphy/augraphy_output_1",
        "/kaggle/input/augraphy-dataset/augraphy/augraphy_output_2"
    ]
    PATCH_SIZE = (512, 512)
    OVERLAP_SIZE = 96
    NUM_IMAGES_FOR_TRAINING = -1  # non-positive: the dataset keeps every matched pair
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 70
    batch_size = 8
    learning_rate = 1e-4
    use_data_parallel = torch.cuda.device_count() > 1
    # Checkpoint to resume from (read-only input) and where new ones are written.
    checkpoint_path = "/kaggle/input/checkpoints-unet1/denoising_unet_patch_checkpoint.pth"
    output_checkpoint_dir = "/kaggle/working/checkpoints"
    output_checkpoint_path = os.path.join(output_checkpoint_dir, "denoising_unet_patch_checkpoint.pth")
    start_epoch = 0

    # ------------------------------------------------------------------
    # Model, loss, optimizer
    # ------------------------------------------------------------------
    model = AttentionUNet(in_channels=3, out_channels=3)
    if use_data_parallel:
        model = nn.DataParallel(model)
    model = model.to(device)
    # The unwrapped network: its state_dict keys never carry the "module."
    # prefix that nn.DataParallel adds, which makes checkpoint loading
    # independent of how many GPUs the saving session had.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # ------------------------------------------------------------------
    # Resume from checkpoint (best effort)
    # ------------------------------------------------------------------
    if os.path.exists(checkpoint_path):
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'patch_size' in checkpoint and checkpoint['patch_size'] == PATCH_SIZE:
                saved_weights = checkpoint['model_state_dict']
                # A checkpoint written from a DataParallel-wrapped model has every
                # key prefixed with "module."; strip it so the weights load into
                # the unwrapped network in either configuration.
                if saved_weights and all(key.startswith("module.") for key in saved_weights):
                    saved_weights = {
                        key[len("module."):]: value for key, value in saved_weights.items()
                    }
                base_model.load_state_dict(saved_weights)
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                print(f"Resuming training from Epoch {start_epoch}")
            else:
                print(f"Checkpoint patch size {checkpoint.get('patch_size')} does not match current patch size {PATCH_SIZE}.")
                print("Starting training from scratch (Epoch 0).")
        except Exception as e:
            print(f"Error loading checkpoint: {e}. Starting training from scratch (Epoch 0).")
            start_epoch = 0
    else:
        print(f"No checkpoint found at {checkpoint_path}. Starting training from scratch (Epoch 0).")

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    dataset = AugraphyPatchDataset(
        input_dir=input_dir,
        aug_dirs=aug_dirs,
        patch_size=PATCH_SIZE,
        overlap=OVERLAP_SIZE,
        transform=transform,
        num_images_to_use=NUM_IMAGES_FOR_TRAINING
    )

    if len(dataset) == 0:
        print("Dataset is empty. Cannot proceed with training. Please check paths and image pairing.")
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=os.cpu_count() // 2 if os.cpu_count() else 0,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=(device.type == "cuda"),
        )
        os.makedirs(output_checkpoint_dir, exist_ok=True)

        # --------------------------------------------------------------
        # Training
        # --------------------------------------------------------------
        print("\nStarting training...")
        if start_epoch >= num_epochs:
            # The loop below runs zero times, so no new checkpoint is written.
            print(f"Checkpoint already covers all {num_epochs} epochs; no further training will run.")
        for epoch in range(start_epoch, num_epochs):
            model.train()
            epoch_loss = 0.0
            for batch_idx, (aug_patch, clean_patch) in enumerate(loader):
                aug_patch, clean_patch = aug_patch.to(device), clean_patch.to(device)
                optimizer.zero_grad()
                output_patch = model(aug_patch)
                loss = criterion(output_patch, clean_patch)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                if batch_idx % 50 == 0:
                    print(f"  Epoch {epoch+1}, Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")
            avg_epoch_loss = epoch_loss / len(loader)
            print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_epoch_loss:.4f}")
            print(f"Saving checkpoint for Epoch {epoch+1}...")

            # 'epoch' stores the zero-based index of the epoch just completed;
            # resuming continues from epoch + 1.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
                'patch_size': PATCH_SIZE,
            }, output_checkpoint_path)
            print("Checkpoint saved.")
        print("\nTraining finished.")

        # --------------------------------------------------------------
        # Full-image inference on one randomly chosen augmented image
        # --------------------------------------------------------------
        # The dataset is non-empty in this branch, so at least one pair exists.
        test_aug_image_path, _ = random.choice(dataset.data_pairs)

        output_denoised_image_dir = "/kaggle/working/denoised_output"
        os.makedirs(output_denoised_image_dir, exist_ok=True)

        output_file_name = f"denoised_{os.path.basename(test_aug_image_path)}"
        output_denoised_image_path = os.path.join(output_denoised_image_dir, output_file_name)

        print(f"\nPerforming full image inference on a randomly selected image: {test_aug_image_path}")

        predicted_full_tensor = predict_full_image_sliding_window(
            model=model,
            image_path=test_aug_image_path,
            patch_size=PATCH_SIZE,
            overlap=OVERLAP_SIZE,
            transform=transform,
            device=device
        )

        predicted_pil = transforms.ToPILImage()(predicted_full_tensor.clamp(0, 1))
        predicted_pil.save(output_denoised_image_path)
        print(f"Denoised image saved to: {output_denoised_image_path}")

Resuming training from Epoch 70

Starting training...

Training finished.

Performing full image inference on a randomly selected image: /kaggle/input/augraphy-dataset/augraphy/augraphy_output_2/image_25_augmented.png
Denoised image saved to: /kaggle/working/denoised_output/denoised_image_25_augmented.png


In [12]:
import shutil

# Copy the checkpoint written by the training cell to the top of the working
# directory. The source file only exists if the training loop saved at least
# one checkpoint in this session.
checkpoint_source = "/kaggle/working/checkpoints/denoising_unet_patch_checkpoint.pth"
checkpoint_destination = "/kaggle/working/denoising_unet_patch_checkpoint.pth"
shutil.copy(checkpoint_source, checkpoint_destination)

'/kaggle/working/denoising_unet_patch_checkpoint.pth'