In [1]:
import kagglehub

import os

# Root folder of the brain-tumor MRI dataset. Later cells read the "Training"
# and "Testing" sub-folders below it, each holding one directory per label.
# Set the BRAIN_TUMOR_MRI_DATASET environment variable to run the notebook
# against another copy of the data; otherwise the original local path is used.
path = os.environ.get(
    "BRAIN_TUMOR_MRI_DATASET",
    "D:\\Dropbox\\UMA Augusta\\PhD\\Research Thesis\\brain_tumor_mri_dataset",
)

In [2]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import hashlib
from typing import Tuple, Dict, List
import multiprocessing

# Select the 'spawn' start method for any processes created through
# `multiprocessing`; force=True replaces a start method that was already set.
# The DataLoaders built in this cell use num_workers=0, so they do not start
# worker processes themselves.
if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)

class WatermarkMRIDataset(Dataset):
    """Dataset yielding ``(image, watermark)`` tensor pairs.

    Each item is an image loaded from disk and resized to a square RGB tensor,
    paired with a deterministic watermark derived from that image's hex digest.

    Args:
        dataframe: Table with a ``filepath`` column (image paths) and an
            ``image_hash`` column (hex-encoded digests, one per image).
        image_size: Side length in pixels the images are resized to.
        watermark_size: Side length in pixels of the generated watermark.
    """

    # Layout of the raw watermark grid before it is resized to ``watermark_size``.
    WATERMARK_CHANNELS = 16
    WATERMARK_BASE_SIZE = 32

    def __init__(self, dataframe: pd.DataFrame, image_size: int = 512, watermark_size: int = 64):
        self.filepaths = dataframe['filepath'].values
        self.image_hashes = dataframe['image_hash'].values
        self.image_size = image_size
        self.watermark_size = watermark_size

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.filepaths)

    def generate_watermark(self, image_hash: str) -> torch.Tensor:
        """Expand a hex digest into a ``(16, watermark_size, watermark_size)`` tensor.

        The digest bytes are scaled to ``[0, 1]``, repeated until they fill a
        ``16 x 32 x 32`` grid and then bilinearly resized to the watermark size.

        Args:
            image_hash: Hex string (for example a SHA-256 hex digest).

        Returns:
            Float32 watermark tensor; the same hash always gives the same tensor.

        Raises:
            ValueError: If ``image_hash`` is empty or not valid hex.
        """
        hash_bytes = bytes.fromhex(image_hash)
        if not hash_bytes:
            # Without this guard the tiling below divides by zero.
            raise ValueError("image_hash must contain at least one byte of hex data")

        hash_array = np.frombuffer(hash_bytes, dtype=np.uint8).astype(np.float32) / 255.0

        channels = self.WATERMARK_CHANNELS
        base = self.WATERMARK_BASE_SIZE
        required_size = channels * base * base
        if hash_array.size < required_size:
            # Repeat the digest until there are enough values to fill the grid.
            reps = required_size // hash_array.size + 1
            hash_array = np.tile(hash_array, reps)

        watermark = hash_array[:required_size].reshape(channels, base, base)
        watermark_tensor = torch.from_numpy(watermark)

        # F.interpolate expects a batch dimension, so add one and drop it again.
        watermark_tensor = F.interpolate(
            watermark_tensor.unsqueeze(0),
            size=(self.watermark_size, self.watermark_size),
            mode='bilinear',
            align_corners=False
        )
        return watermark_tensor.squeeze(0)

    def process_image(self, image_path: str) -> torch.Tensor:
        """Load an image and return it as a float ``(3, H, W)`` tensor in ``[0, 1]``.

        Non-RGB images are converted to RGB and every image is resized to
        ``image_size x image_size`` with bilinear resampling.
        """
        # The context manager releases the file handle even if decoding fails,
        # matching how the analysis code in this notebook opens images.
        with Image.open(image_path) as source:
            rgb = source if source.mode == 'RGB' else source.convert('RGB')
            resized = rgb.resize((self.image_size, self.image_size), Image.BILINEAR)
            pixels = np.array(resized)

        img_tensor = torch.from_numpy(pixels).float().div(255.0)
        return img_tensor.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, watermark)`` pair for row ``idx``."""
        img_tensor = self.process_image(self.filepaths[idx])
        watermark_tensor = self.generate_watermark(self.image_hashes[idx])
        return img_tensor, watermark_tensor

class MRIDatasetPreprocessor:
    """Index, analyse and serve the MRI dataset used for watermark training.

    The dataset root must contain ``Training`` and ``Testing`` folders, each
    with one sub-folder per class label holding ``.png``/``.jpg``/``.jpeg``
    images. Creating an instance also creates the ``processed_data``,
    ``watermarks`` and ``watermarked_images`` folders in the current working
    directory.

    Attributes:
        device: CUDA device when available, otherwise CPU.
        train_df, test_df: Per-image tables built by
            ``create_dataset_dataframes`` (``None`` until then).
        image_stats, watermark_stats: Analysis results keyed by ``'train'``
            and ``'test'``.
    """

    def __init__(self, base_path: str):
        """Resolve dataset paths and create the output folders.

        Args:
            base_path: Root folder of the dataset.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Input locations
        self.base_path = Path(base_path)
        self.train_path = self.base_path / 'Training'
        self.test_path = self.base_path / 'Testing'

        # Output locations (relative to the current working directory)
        self.processed_dir = Path('processed_data')
        self.watermarks_dir = Path('watermarks')
        self.watermarked_dir = Path('watermarked_images')

        for dir_path in (self.processed_dir, self.watermarks_dir, self.watermarked_dir):
            dir_path.mkdir(exist_ok=True)

        self.train_df = None
        self.test_df = None
        self.image_stats = {}
        self.watermark_stats = {}

    def create_dataloaders(self, batch_size: int = 32) -> Tuple[DataLoader, DataLoader]:
        """Wrap the train and test tables in ``DataLoader`` objects.

        Args:
            batch_size: Number of ``(image, watermark)`` pairs per batch.

        Returns:
            ``(train_loader, test_loader)``; only the training loader shuffles.

        Raises:
            RuntimeError: If ``create_dataset_dataframes`` has not been called.
        """
        if self.train_df is None or self.test_df is None:
            raise RuntimeError(
                "Dataset dataframes are not available; "
                "call create_dataset_dataframes() first."
            )

        # Pinned host memory is only requested when a GPU is available.
        pin_memory = torch.cuda.is_available()

        train_loader = DataLoader(
            WatermarkMRIDataset(self.train_df),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=pin_memory
        )

        test_loader = DataLoader(
            WatermarkMRIDataset(self.test_df),
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=pin_memory
        )

        return train_loader, test_loader

    def pad_to_multiple_32(self, image: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Pad a ``(C, H, W)`` image on the bottom/right to multiples of 32.

        Args:
            image: Tensor whose last two dimensions are height and width.

        Returns:
            ``(padded_image, (pad_h, pad_w))`` where ``pad_h`` / ``pad_w`` are
            the rows / columns that were added, so the padding can be removed.
        """
        _, h, w = image.shape
        pad_h = (32 - h % 32) % 32
        pad_w = (32 - w % 32) % 32

        # Reflection padding requires every pad to be smaller than the padded
        # dimension; images that are too small repeat their border instead of
        # raising.
        mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'

        # F.pad consumes pairs starting from the last dimension:
        # (left, right, top, bottom).
        padding = (0, pad_w, 0, pad_h)
        padded_image = F.pad(image, padding, mode=mode)

        return padded_image, (pad_h, pad_w)

    def analyze_image_dimensions(self, dataframe: pd.DataFrame) -> Dict:
        """Collect width/height, file-size and aspect-ratio statistics.

        Args:
            dataframe: Table with a ``filepath`` column.

        Returns:
            Dict with the unique ``(width, height)`` pairs, min/max/mean width
            and height, min/max/mean file size in KB, and the per-image aspect
            ratios together with their min/max/mean.

        Raises:
            ValueError: If ``dataframe`` has no rows.
        """
        if len(dataframe) == 0:
            raise ValueError("Cannot analyze image dimensions of an empty dataframe.")

        dimensions = []
        sizes_kb = []
        aspect_ratios = []

        for filepath in tqdm(dataframe['filepath'], desc="Analyzing images"):
            with Image.open(filepath) as img:
                w, h = img.size
                dimensions.append((w, h))
                sizes_kb.append(os.path.getsize(filepath) / 1024)
                aspect_ratios.append(w / h)

        dimensions = np.array(dimensions)
        sizes_kb = np.array(sizes_kb)
        widths = dimensions[:, 0]
        heights = dimensions[:, 1]

        stats = {
            'unique_dimensions': np.unique(dimensions, axis=0),
            'min_width': widths.min(),
            'max_width': widths.max(),
            'min_height': heights.min(),
            'max_height': heights.max(),
            'mean_width': widths.mean(),
            'mean_height': heights.mean(),
            'min_size_kb': sizes_kb.min(),
            'max_size_kb': sizes_kb.max(),
            'mean_size_kb': sizes_kb.mean(),
            'aspect_ratios': aspect_ratios,
            'min_aspect_ratio': min(aspect_ratios),
            'max_aspect_ratio': max(aspect_ratios),
            'mean_aspect_ratio': np.mean(aspect_ratios)
        }

        return stats

    def analyze_for_watermarking(self, dataframe: pd.DataFrame) -> Dict:
        """Derive watermark sizing statistics from the image dimensions.

        For every image the "optimal" watermark size is one eighth of its
        shorter side, but never less than 32 pixels.

        Args:
            dataframe: Table with a ``filepath`` column.

        Returns:
            Dict with the smallest short side, the largest long side, the
            per-image aspect ratios and optimal sizes, and the mean and median
            optimal size.
        """
        watermark_stats = {
            'min_dimension': float('inf'),
            'max_dimension': 0,
            'aspect_ratios': [],
            'optimal_watermark_sizes': []
        }

        for filepath in tqdm(dataframe['filepath'], desc="Analyzing for watermarking"):
            with Image.open(filepath) as img:
                w, h = img.size
                min_dim = min(w, h)
                max_dim = max(w, h)
                aspect_ratio = w / h

                watermark_stats['min_dimension'] = min(watermark_stats['min_dimension'], min_dim)
                watermark_stats['max_dimension'] = max(watermark_stats['max_dimension'], max_dim)
                watermark_stats['aspect_ratios'].append(aspect_ratio)

                # One eighth of the short side, floored at 32 pixels.
                optimal_size = max(32, min_dim // 8)
                watermark_stats['optimal_watermark_sizes'].append(optimal_size)

        watermark_stats['mean_optimal_size'] = np.mean(watermark_stats['optimal_watermark_sizes'])
        watermark_stats['median_optimal_size'] = np.median(watermark_stats['optimal_watermark_sizes'])

        return watermark_stats

    def create_dataset_dataframes(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Scan the dataset folders and build the train/test tables.

        Each table has the columns ``filepath``, ``label`` (name of the
        containing folder), ``dimensions`` (``(width, height)``) and
        ``image_hash`` (SHA-256 hex digest of the decoded pixel data). The
        tables are stored on the instance and both analyses are run on them.

        Returns:
            ``(train_df, test_df)``.
        """

        def process_directory(base_path: Path) -> pd.DataFrame:
            """Build one table from a folder containing one sub-folder per label."""
            filepaths = []
            labels = []
            dimensions = []
            image_hashes = []

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # row order (and the saved CSV files) reproducible between runs.
            for fold in sorted(os.listdir(base_path)):
                fold_path = base_path / fold
                if not fold_path.is_dir():
                    continue

                for file in sorted(os.listdir(fold_path)):
                    if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        continue

                    filepath = str(fold_path / file)
                    with Image.open(filepath) as img:
                        dimensions.append(img.size)
                        # Hash the decoded pixels rather than the file bytes.
                        img_array = np.array(img)
                        image_hashes.append(hashlib.sha256(img_array.tobytes()).hexdigest())

                    filepaths.append(filepath)
                    labels.append(fold)

            return pd.DataFrame({
                'filepath': filepaths,
                'label': labels,
                'dimensions': dimensions,
                'image_hash': image_hashes
            })

        print("Processing training set...")
        self.train_df = process_directory(self.train_path)
        print("Processing testing set...")
        self.test_df = process_directory(self.test_path)

        print("\nAnalyzing training set dimensions...")
        self.image_stats['train'] = self.analyze_image_dimensions(self.train_df)
        print("\nAnalyzing testing set dimensions...")
        self.image_stats['test'] = self.analyze_image_dimensions(self.test_df)

        print("\nAnalyzing training set for watermarking...")
        self.watermark_stats['train'] = self.analyze_for_watermarking(self.train_df)
        print("\nAnalyzing testing set for watermarking...")
        self.watermark_stats['test'] = self.analyze_for_watermarking(self.test_df)

        return self.train_df, self.test_df

    def save_dataset_info(self):
        """Write the tables and statistics as CSV files into ``processed_data``.

        Produces ``train_info.csv``, ``test_info.csv``, ``image_statistics.csv``
        and ``watermark_statistics.csv``. The statistics files have one
        ``train`` and one ``test`` column with one row per statistic.
        """
        self.train_df.to_csv(self.processed_dir / 'train_info.csv', index=False)
        self.test_df.to_csv(self.processed_dir / 'test_info.csv', index=False)

        stats_df = pd.DataFrame({
            'train': self.image_stats['train'],
            'test': self.image_stats['test']
        })
        stats_df.to_csv(self.processed_dir / 'image_statistics.csv')

        watermark_stats_df = pd.DataFrame({
            'train': self.watermark_stats['train'],
            'test': self.watermark_stats['test']
        })
        watermark_stats_df.to_csv(self.processed_dir / 'watermark_statistics.csv')

    def get_dataset_statistics(self):
        """Print a human-readable summary of the collected statistics."""
        print("\n=== Dataset Statistics Summary ===")
        print(f"Total training images: {len(self.train_df)}")
        print(f"Total testing images: {len(self.test_df)}")

        for dataset_type in ['train', 'test']:
            stats = self.image_stats[dataset_type]
            wm_stats = self.watermark_stats[dataset_type]

            print(f"\n{dataset_type.capitalize()} Set Summary:")
            print("Dimensions:")
            print(f"  Min Width: {stats['min_width']}, Max Width: {stats['max_width']}")
            print(f"  Min Height: {stats['min_height']}, Max Height: {stats['max_height']}")
            print(f"  Mean Width: {stats['mean_width']:.2f}, Mean Height: {stats['mean_height']:.2f}")

            print("\nFile Sizes:")
            print(f"  Min: {stats['min_size_kb']:.2f} KB")
            print(f"  Max: {stats['max_size_kb']:.2f} KB")
            print(f"  Mean: {stats['mean_size_kb']:.2f} KB")

            print("\nWatermarking Info:")
            print(f"  Min Dimension: {wm_stats['min_dimension']}")
            print(f"  Max Dimension: {wm_stats['max_dimension']}")
            print(f"  Mean Optimal Watermark Size: {wm_stats['mean_optimal_size']:.2f}")
            print(f"  Median Optimal Watermark Size: {wm_stats['median_optimal_size']:.2f}")

def test_preprocessor(dataset_path: str, check_batch: bool = False):
    """Run the preprocessing pipeline end to end as a smoke test.

    Builds the dataset tables, creates the data loaders (batch size 3), saves
    the CSV summaries and prints the collected statistics.

    Args:
        dataset_path: Root folder of the dataset.
        check_batch: When True, additionally load the first training batch,
            move it to the preprocessor's device and print its shapes and
            value ranges.

    Returns:
        ``(preprocessor, train_loader, test_loader)``.

    Raises:
        Exception: Any error raised by the pipeline is reported on stdout and
            re-raised unchanged.
    """
    try:
        print("Initializing preprocessor...")
        preprocessor = MRIDatasetPreprocessor(dataset_path)

        print("\nCreating and analyzing datasets...")
        # The tables are kept on the preprocessor; the return value is not needed here.
        preprocessor.create_dataset_dataframes()

        print("\nCreating dataloaders...")
        train_loader, test_loader = preprocessor.create_dataloaders(batch_size=3)

        print("\nSaving dataset information...")
        preprocessor.save_dataset_info()

        print("\nDisplaying dataset statistics...")
        preprocessor.get_dataset_statistics()

        if check_batch:
            print("\nTesting batch loading...")
            for images, watermarks in train_loader:
                images = images.to(preprocessor.device)
                watermarks = watermarks.to(preprocessor.device)
                print(f"Image batch shape: {images.shape}")
                print(f"Watermark batch shape: {watermarks.shape}")
                print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
                print(f"Watermark value range: [{watermarks.min():.3f}, {watermarks.max():.3f}]")
                break  # Only inspect the first batch

        return preprocessor, train_loader, test_loader

    except Exception as e:
        print(f"Error during preprocessing: {str(e)}")
        raise

# Run the preprocessing smoke test when this cell is executed.
if __name__ == "__main__":
    # `path` is the dataset root assigned in the first cell of the notebook.
    dataset_path = path
    # The results are bound at module level so that later cells can use them.
    preprocessor, train_loader, test_loader = test_preprocessor(dataset_path)

Initializing preprocessor...
Using device: cuda

Creating and analyzing datasets...
Processing training set...
Processing testing set...

Analyzing training set dimensions...


Analyzing images: 100%|██████████| 5712/5712 [00:11<00:00, 484.42it/s]



Analyzing testing set dimensions...


Analyzing images: 100%|██████████| 1311/1311 [00:02<00:00, 498.86it/s]



Analyzing training set for watermarking...


Analyzing for watermarking: 100%|██████████| 5712/5712 [00:11<00:00, 493.64it/s]



Analyzing testing set for watermarking...


Analyzing for watermarking: 100%|██████████| 1311/1311 [00:02<00:00, 498.65it/s]



Creating dataloaders...

Saving dataset information...

Displaying dataset statistics...

=== Dataset Statistics Summary ===
Total training images: 5712
Total testing images: 1311

Train Set Summary:
Dimensions:
  Min Width: 150, Max Width: 1920
  Min Height: 168, Max Height: 1446
  Mean Width: 451.56, Mean Height: 453.88

File Sizes:
  Min: 3.39 KB
  Max: 710.85 KB
  Mean: 22.64 KB

Watermarking Info:
  Min Dimension: 150
  Max Dimension: 1920
  Mean Optimal Watermark Size: 56.96
  Median Optimal Watermark Size: 64.00

Test Set Summary:
Dimensions:
  Min Width: 150, Max Width: 1149
  Min Height: 168, Max Height: 1019
  Mean Width: 421.18, Mean Height: 424.22

File Sizes:
  Min: 4.58 KB
  Max: 118.71 KB
  Mean: 19.48 KB

Watermarking Info:
  Min Dimension: 150
  Max Dimension: 1149
  Mean Optimal Watermark Size: 53.42
  Median Optimal Watermark Size: 64.00


In [3]:
import torch
import torch.nn as nn
import torchsummary
import torch.nn.functional as F
from typing import Tuple

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Embedding size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, attn_drop: float = 0.):
        super().__init__()
        # Fail at construction time; otherwise the mismatch only surfaces as a
        # reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)`` and return the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention weights: (B, heads, N, N), normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis: (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class MaxViTBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a two-layer MLP.

    Both sub-layers normalise their input with ``LayerNorm`` and add their
    result back onto it (residual connection).

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop: Dropout probability inside the MLP.
        attn_drop: Dropout probability on the attention weights.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4., qkv_bias: bool = False,
                 drop: float = 0., attn_drop: float = 0.):
        super().__init__()
        # Attention sub-layer
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)

        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(dim)
        hidden_width = int(dim * mlp_ratio)
        expand = nn.Linear(dim, hidden_width)
        contract = nn.Linear(hidden_width, dim)
        self.mlp = nn.Sequential(expand, nn.GELU(), nn.Dropout(drop), contract, nn.Dropout(drop))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform tokens ``x`` of shape ``(B, N, dim)``; the shape is preserved."""
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))

class MaxViT(nn.Module):
    """Transformer bottleneck operating on a fixed 32x32 feature map.

    The feature map is projected to ``embed_dim`` channels, flattened into
    1024 tokens, given a learned positional embedding, passed through
    ``depth`` transformer blocks (each attending over all tokens) and
    projected back to ``in_channels``.
    """

    def __init__(self,
                 in_channels: int = 512,
                 embed_dim: int = 768,
                 depth: int = 16,
                 num_heads: int = 8,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.):
        """Build the bottleneck.

        Args:
            in_channels: Channels of the incoming (and outgoing) feature map.
            embed_dim: Token embedding size used inside the transformer.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
            qkv_bias: Whether q/k/v projections have a bias.
            drop_rate: Dropout probability inside the MLPs.
            attn_drop_rate: Dropout probability on the attention weights.
        """
        super().__init__()

        # 1x1 convolution into the embedding width
        self.conv_in = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

        # One learned position vector per token of the 32x32 grid
        self.feature_size = 32 * 32
        self.pos_embed = nn.Parameter(torch.zeros(1, self.feature_size, embed_dim))

        # Stack of identical transformer blocks
        block_options = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
        )
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(MaxViTBlock(**block_options))

        # 1x1 convolution back to the input channel count
        self.conv_out = nn.Conv2d(embed_dim, in_channels, kernel_size=1)

        nn.init.trunc_normal_(self.pos_embed, std=.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine a ``(B, in_channels, 32, 32)`` feature map; the shape is preserved."""
        batch, _, height, width = x.shape

        side = int(self.feature_size ** 0.5)
        assert height == side and width == side, \
            f"Input spatial dimensions must be {side}x{side}, got {height}x{width}"

        # (B, C, H, W) -> (B, embed_dim, H, W) -> (B, H*W, embed_dim)
        tokens = self.conv_in(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens)

        # (B, H*W, embed_dim) -> (B, embed_dim, H, W) -> (B, C, H, W)
        feature_map = tokens.transpose(1, 2).reshape(batch, -1, height, width)
        return self.conv_out(feature_map)

class WatermarkProcessor(nn.Module):
    """Lift a watermark to 512 channels with two 1x1 convolutions (GELU in between).

    Args:
        in_channels: Number of channels of the incoming watermark.
    """

    def __init__(self, in_channels: int = 16):
        super().__init__()
        widen = nn.Conv2d(in_channels, 256, kernel_size=1)
        project = nn.Conv2d(256, 512, kernel_size=1)
        self.process = nn.Sequential(widen, nn.GELU(), project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, in_channels, H, W)`` to ``(B, 512, H, W)``."""
        features = self.process(x)
        return features

def test_maxvit():
    """Smoke-test ``MaxViT``: print and save its summary, then run one forward pass.

    The layer summary is printed and also written to
    ``maxvit_model_summary.txt`` in the current working directory.

    Returns:
        The instantiated ``MaxViT`` model (on the selected device).
    """
    import contextlib
    import io

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    x = torch.randn(2, 512, 32, 32).to(device)
    print(f"Created input tensor with shape: {x.shape}")

    model = MaxViT(in_channels=512, embed_dim=768, depth=16).to(device)
    print("Initialized MaxViT model")

    # Build the summary once and reuse the text for both the notebook output
    # and the file. redirect_stdout restores whatever stream was active before
    # (in a notebook that is not sys.__stdout__), even if summary() raises.
    summary_buffer = io.StringIO()
    with contextlib.redirect_stdout(summary_buffer):
        torchsummary.summary(model, input_size=(512, 32, 32))
    summary_text = summary_buffer.getvalue()
    print(summary_text, end="")

    with open("maxvit_model_summary.txt", "w") as f:
        f.write(summary_text)

    print("\nPerforming forward pass...")
    # Gradients are not needed for a shape check; skipping them avoids keeping
    # the activations of all blocks in memory.
    with torch.no_grad():
        output = model(x)

    print("\nSummary:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Running on: {device}")

    if x.shape == output.shape:
        print("OK: Input and output shapes match")
    else:
        print("FAIL: Input and output shapes differ!")

    return model

# Run the MaxViT smoke test when this cell is executed; the returned model is
# not kept.
if __name__ == "__main__":
    test_maxvit()


Using device: cuda
Created input tensor with shape: torch.Size([2, 512, 32, 32])
Initialized MaxViT model
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 32, 32]         393,984
         LayerNorm-2            [-1, 1024, 768]           1,536
            Linear-3           [-1, 1024, 2304]       1,769,472
           Dropout-4        [-1, 8, 1024, 1024]               0
            Linear-5            [-1, 1024, 768]         590,592
MultiHeadAttention-6            [-1, 1024, 768]               0
         LayerNorm-7            [-1, 1024, 768]           1,536
            Linear-8           [-1, 1024, 3072]       2,362,368
              GELU-9           [-1, 1024, 3072]               0
          Dropout-10           [-1, 1024, 3072]               0
           Linear-11            [-1, 1024, 768]       2,360,064
          Dropout-12            [-1, 1024, 768]              

In [4]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple

# NOTE: this cell defines MultiHeadAttention again; once it runs, this
# definition replaces the one from the earlier MaxViT cell.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Embedding size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, attn_drop: float = 0.):
        super().__init__()
        # Fail at construction time; otherwise the mismatch only surfaces as a
        # reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)`` and return the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention weights: (B, heads, N, N), normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis: (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class MaxViTBlock(nn.Module):
    """Transformer block with residual self-attention and residual MLP.

    Each sub-layer first applies ``LayerNorm`` to its input and then adds its
    output back to that input.

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop: Dropout probability inside the MLP.
        attn_drop: Dropout probability on the attention weights.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4., qkv_bias: bool = False,
                 drop: float = 0., attn_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
        )
        self.norm2 = nn.LayerNorm(dim)

        inner_dim = int(dim * mlp_ratio)
        mlp_layers = [
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(inner_dim, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform tokens ``x`` of shape ``(B, N, dim)``; the shape is preserved."""
        attended = self.attn(self.norm1(x))
        hidden = x + attended
        fed_forward = self.mlp(self.norm2(hidden))
        return hidden + fed_forward

class MaxViT(nn.Module):
    """Transformer bottleneck operating on a fixed 32x32 feature map.

    The feature map is projected to ``embed_dim`` channels, flattened into
    1024 tokens, given a learned positional embedding, passed through
    ``depth`` transformer blocks (each attending over all tokens) and
    projected back to ``in_channels``.

    Args:
        in_channels: Channels of the incoming (and outgoing) feature map.
        embed_dim: Token embedding size used inside the transformer.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: Whether q/k/v projections have a bias.
        drop_rate: Dropout probability inside the MLPs.
        attn_drop_rate: Dropout probability on the attention weights.
    """

    def __init__(self, in_channels: int = 512, embed_dim: int = 768, depth: int = 16,
                 num_heads: int = 8, mlp_ratio: float = 4., qkv_bias: bool = False,
                 drop_rate: float = 0., attn_drop_rate: float = 0.):
        super().__init__()

        # 1x1 convolution into the embedding width
        self.conv_in = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

        # One learned position vector per token of the 32x32 grid
        self.feature_size = 32 * 32
        self.pos_embed = nn.Parameter(torch.zeros(1, self.feature_size, embed_dim))

        # Transformer blocks
        self.blocks = nn.ModuleList([
            MaxViTBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])

        # 1x1 convolution back to the input channel count
        self.conv_out = nn.Conv2d(embed_dim, in_channels, kernel_size=1)

        nn.init.trunc_normal_(self.pos_embed, std=.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine a ``(B, in_channels, 32, 32)`` feature map; the shape is preserved.

        Raises:
            AssertionError: If the spatial size is not exactly 32x32.
        """
        B, C, H, W = x.shape

        # Checking only H * W would also accept e.g. 16x64, whose tokens would
        # be paired with positional embeddings learned for a 32x32 layout.
        side = int(self.feature_size ** 0.5)
        assert H == side and W == side, \
            f"Input feature map size {H}x{W} doesn't match expected size {side}x{side}"

        # (B, C, H, W) -> (B, embed_dim, H, W)
        x = self.conv_in(x)

        # (B, embed_dim, H, W) -> (B, H*W, embed_dim)
        x = x.flatten(2).transpose(1, 2)

        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        # (B, H*W, embed_dim) -> (B, embed_dim, H, W)
        x = x.transpose(1, 2).reshape(B, -1, H, W)

        # (B, embed_dim, H, W) -> (B, C, H, W)
        x = self.conv_out(x)

        return x

class EncoderBlock(nn.Module):
    """Encoder stage: two 3x3 conv + BatchNorm + GELU layers, then 2x2 max pooling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        same_size_conv = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_size_conv)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_size_conv)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled, skip)``: the half-resolution output and the pre-pool features."""
        features = self.conv1(x)
        features = self.gelu(self.bn1(features))
        features = self.conv2(features)
        features = self.gelu(self.bn2(features))
        return self.pool(features), features

class DecoderBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two conv layers.

    Args:
        in_channels: Channels of the incoming feature map (halved while upsampling).
        skip_channels: Channels of the skip connection that is concatenated.
        out_channels: Channels produced by both 3x3 convolutions.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        upsampled_channels = in_channels // 2

        # Doubles the spatial size while halving the channel count
        self.up = nn.ConvTranspose2d(in_channels, upsampled_channels, kernel_size=2, stride=2)

        self.conv1 = nn.Conv2d(upsampled_channels + skip_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, align it with ``skip`` and fuse both into ``out_channels`` channels."""
        upsampled = self.up(x)

        # Pad so the upsampled map has the skip connection's spatial size.
        if upsampled.shape[2:] != skip.shape[2:]:
            gap_h = skip.size(2) - upsampled.size(2)
            gap_w = skip.size(3) - upsampled.size(3)
            left = gap_w // 2
            top = gap_h // 2
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features come first along the channel axis.
        merged = torch.cat([skip, upsampled], dim=1)
        merged = self.gelu(self.bn1(self.conv1(merged)))
        return self.gelu(self.bn2(self.conv2(merged)))


class AdaptiveWeightBlender(nn.Module):
    """Blend image and watermark features with learned per-element weights.

    A 1x1 convolution followed by a sigmoid turns the concatenated inputs into
    weights in ``(0, 1)``; the output is
    ``image * weights + watermark * (1 - weights)``.

    Args:
        channels: Channel count of each of the two inputs.
    """

    def __init__(self, channels: int):
        super().__init__()
        gate = nn.Conv2d(channels * 2, channels, kernel_size=1)
        self.weight_conv = nn.Sequential(gate, nn.Sigmoid())

    def forward(self, image: torch.Tensor, watermark: torch.Tensor) -> torch.Tensor:
        """Blend two tensors of identical shape; raises ``AssertionError`` on a mismatch."""
        assert image.shape == watermark.shape, f"Shape mismatch: image {image.shape}, watermark {watermark.shape}"

        weights = self.weight_conv(torch.cat([image, watermark], dim=1))
        return image * weights + watermark * (1 - weights)

class UNetGenerator(nn.Module):
    """U-Net style generator that embeds a watermark at the bottleneck.

    The image is encoded by four ``EncoderBlock`` stages (3 -> 64 -> 128 ->
    256 -> 512 channels, halving the resolution each time). At the bottleneck
    the watermark is lifted to 512 channels, blended with the image features
    and refined by ``maxvit_model``. Three ``DecoderBlock`` stages then
    reconstruct a 3-channel image, which is resized to the input resolution.

    Args:
        maxvit_model: Bottleneck module. It receives the output of a 1x1
            convolution with 768 channels and its output is fed to a 1x1
            convolution expecting 768 channels.

    NOTE(review): only three decoder stages exist and the skip connection of
    the first encoder stage is never used, so the decoder output has half the
    input resolution and is brought to full size by bilinear interpolation.
    Confirm whether a fourth decoder stage was intended.
    """

    def __init__(self, maxvit_model: nn.Module):
        super().__init__()

        # Encoder pathway (sizes for a 512x512 input)
        self.enc1 = EncoderBlock(3, 64)      # 512x512 -> 256x256
        self.enc2 = EncoderBlock(64, 128)    # 256x256 -> 128x128
        self.enc3 = EncoderBlock(128, 256)   # 128x128 -> 64x64
        self.enc4 = EncoderBlock(256, 512)   # 64x64 -> 32x32

        # Watermark processor (16 -> 512 channels)
        self.watermark_processor = nn.Sequential(
            nn.Conv2d(16, 256, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(256, 512, kernel_size=1)
        )

        self.blender = AdaptiveWeightBlender(512)

        # 1x1 convolutions adapting between the 512-channel bottleneck and the
        # 768-channel interface of the refinement module
        self.pre_maxvit_conv = nn.Conv2d(512, 768, kernel_size=1)
        self.maxvit = maxvit_model
        self.post_maxvit_conv = nn.Conv2d(768, 512, kernel_size=1)

        # Decoder pathway
        self.dec1 = DecoderBlock(in_channels=512, skip_channels=512, out_channels=256)  # 32x32 -> 64x64
        self.dec2 = DecoderBlock(in_channels=256, skip_channels=256, out_channels=128)  # 64x64 -> 128x128
        self.dec3 = DecoderBlock(in_channels=128, skip_channels=128, out_channels=64)   # 128x128 -> 256x256

        # Final output
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the four encoder stages.

        Returns:
            ``(features, skips)`` where ``skips[i]`` is the pre-pooling output
            of encoder stage ``i + 1``.
        """
        skip_connections = []

        x, skip1 = self.enc1(x)
        skip_connections.append(skip1)

        x, skip2 = self.enc2(x)
        skip_connections.append(skip2)

        x, skip3 = self.enc3(x)
        skip_connections.append(skip3)

        x, skip4 = self.enc4(x)
        skip_connections.append(skip4)

        return x, skip_connections

    def bottleneck(self, image_features: torch.Tensor, watermark: torch.Tensor) -> torch.Tensor:
        """Blend the watermark into the encoded image features and refine the result."""
        # Process watermark (16 -> 512 channels)
        watermark_features = self.watermark_processor(watermark)

        # Match the spatial size of the image features if necessary
        if watermark_features.shape[2:] != image_features.shape[2:]:
            watermark_features = F.interpolate(
                watermark_features,
                size=image_features.shape[2:],
                mode='bilinear',
                align_corners=False
            )

        blended = self.blender(image_features, watermark_features)

        maxvit_input = self.pre_maxvit_conv(blended)
        maxvit_output = self.maxvit(maxvit_input)
        refined = self.post_maxvit_conv(maxvit_output)

        return refined

    def decode(self, x: torch.Tensor, skip_connections: List[torch.Tensor]) -> torch.Tensor:
        """Run the three decoder stages and map the result to 3 channels in ``(0, 1)``."""
        # skip_connections[0] (first encoder stage) is intentionally not consumed here.
        x = self.dec1(x, skip_connections[3])
        x = self.dec2(x, skip_connections[2])
        x = self.dec3(x, skip_connections[1])

        x = self.final_conv(x)
        x = self.sigmoid(x)

        return x

    def forward(self, image: torch.Tensor, watermark: torch.Tensor) -> torch.Tensor:
        """Generate the watermarked image.

        Args:
            image: Batch of images, ``(B, 3, H, W)``.
            watermark: Batch of watermarks, ``(B, 16, h, w)``.

        Returns:
            Tensor with the same spatial size as ``image`` and 3 channels.
        """
        encoded, skip_connections = self.encode(image)
        bottleneck_features = self.bottleneck(encoded, watermark)
        decoded = self.decode(bottleneck_features, skip_connections)  # half the input resolution
        # Resize to the input resolution instead of a fixed 512x512 so the
        # output always lines up with the image it was generated from.
        output = F.interpolate(decoded, size=image.shape[-2:], mode='bilinear', align_corners=False)
        return output


def test_generator():
    """Smoke-test ``UNetGenerator`` on one random batch.

    Builds the generator around a 768-channel ``MaxViT``, prints its layer
    summary, writes the same summary to ``unet_generator_summary.txt`` and
    runs a single forward pass without gradients.

    Returns:
        The instantiated generator (on the selected device).
    """
    import contextlib
    import io

    # Imported here so this cell does not depend on an earlier cell's import.
    import torchsummary

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    maxvit = MaxViT(
        in_channels=768,
        embed_dim=768,
        depth=16
    ).to(device)

    generator = UNetGenerator(maxvit).to(device)

    batch_size = 3

    image = torch.randn(batch_size, 3, 512, 512).to(device)
    watermark = torch.randn(batch_size, 16, 64, 64).to(device)

    # Build the summary once and reuse the text for both the notebook output
    # and the file. redirect_stdout restores whatever stream was active before
    # (in a notebook that is not sys.__stdout__), even if summary() raises.
    summary_buffer = io.StringIO()
    with contextlib.redirect_stdout(summary_buffer):
        torchsummary.summary(generator, [(3, 512, 512), (16, 64, 64)])
    summary_text = summary_buffer.getvalue()
    print(summary_text, end="")

    with open("unet_generator_summary.txt", "w") as f:
        f.write(summary_text)

    print("\nModel summary saved to unet_generator_summary.txt")

    print("\nStarting forward pass...")
    with torch.no_grad():
        output = generator(image, watermark)

    print("\nFinal shapes:")
    print(f"Image input shape: {image.shape}")
    print(f"Watermark input shape: {watermark.shape}")
    print(f"Output shape: {output.shape}")

    return generator

# Run the generator smoke test when this cell is executed; the returned
# generator is not kept.
if __name__ == "__main__":
    test_generator()

In [5]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of real/fake probabilities.

    Four stride-2 convolutions with 4x4 kernels reduce a 3-channel image to a
    single-channel map; the final sigmoid maps every entry into ``(0, 1)``.
    For a 512x512 input the map is 31x31.
    """

    def __init__(self):
        super().__init__()

        # Negative slope shared by all LeakyReLU activations
        self.leaky_slope = 0.2

        def downsample(channels_in: int, channels_out: int, normalize: bool) -> nn.Sequential:
            """Stride-2 conv (halves the resolution), optional BatchNorm, LeakyReLU."""
            layers = [nn.Conv2d(channels_in, channels_out, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(channels_out))
            layers.append(nn.LeakyReLU(self.leaky_slope, inplace=True))
            return nn.Sequential(*layers)

        # 512x512x3 -> 256x256x64 (no normalisation on the first stage)
        self.conv1 = downsample(3, 64, normalize=False)
        # 256x256x64 -> 128x128x128
        self.conv2 = downsample(64, 128, normalize=True)
        # 128x128x128 -> 64x64x256
        self.conv3 = downsample(128, 256, normalize=True)

        # 64x64x256 -> 31x31x1; unpadded conv followed by a sigmoid so every
        # entry is a real/fake probability
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images ``(B, 3, H, W)`` to a probability map ``(B, 1, h, w)``."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return x

def test_discriminator():
    """Smoke-test the Discriminator.

    Prints the torchsummary report, writes the same report to
    ``discriminator_model_summary.txt``, runs one forward pass on a random
    batch of three 3x512x512 images and checks the outputs lie in [0, 1].

    Returns:
        The constructed Discriminator, already moved to the selected device.

    Raises:
        AssertionError: if any output value falls outside [0, 1].
    """
    # Local import, matching the function-scope import style used in this file.
    from contextlib import redirect_stdout

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    discriminator = Discriminator().to(device)
    torchsummary.summary(discriminator, input_size=(3, 512, 512))

    # Save to file. redirect_stdout puts back whichever stream was active
    # before (in a notebook that is the kernel's stream, not sys.__stdout__),
    # and it does so even if summary() raises.
    with open("discriminator_model_summary.txt", "w") as f, redirect_stdout(f):
        torchsummary.summary(discriminator, input_size=(3, 512, 512))

    batch_size = 3
    test_input = torch.randn(batch_size, 3, 512, 512).to(device)
    print(f"Input shape: {test_input.shape}")

    # Forward pass (no gradients needed for a shape/range check)
    with torch.no_grad():
        output = discriminator(test_input)

    print(f"Output shape: {output.shape}")
    print(f"Output value range: [{output.min():.3f}, {output.max():.3f}]")

    assert output.min() >= 0 and output.max() <= 1, "Output values should be in [0,1] range"

    return discriminator

# Run the discriminator smoke test when this cell is executed as the main module.
if __name__ == "__main__":
    test_discriminator()

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from typing import Tuple, Dict
import time
from tqdm import tqdm

class WatermarkTrainer:
    """Adversarial trainer for a watermarking generator and its discriminator.

    Each batch performs one discriminator step followed by one generator
    step. The generator loss is ``mse(fake, real) + 0.1 * adversarial``.
    Per-epoch averages are appended to ``self.train_history``.
    """

    def __init__(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        train_loader: DataLoader,
        test_loader: DataLoader,
        device: torch.device,
        learning_rate: float = 0.0002,
        beta1: float = 0.5,
        beta2: float = 0.999
    ):
        """
        Args:
            generator: module called as ``generator(images, watermark)``.
            discriminator: module called as ``discriminator(images)``; its
                outputs are fed to ``nn.BCELoss`` and so must lie in [0, 1].
            train_loader: yields ``(images, watermark)`` batches for training.
            test_loader: stored on the instance; not used by this class.
            device: device both models and every batch are moved to.
            learning_rate: Adam learning rate for both optimizers.
            beta1: first Adam beta for both optimizers.
            beta2: second Adam beta for both optimizers.
        """
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

        # Initialize optimizers
        self.g_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2)
        )
        self.d_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2)
        )

        # Loss functions
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()

        # tracking variables
        self.current_epoch = 0
        self.train_history = {
            'g_loss': [],
            'd_loss': [],
            'mse_loss': [],
            'adv_loss': []
        }

    def train_discriminator(
        self,
        real_images: torch.Tensor,
        watermark: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Run one discriminator update on a batch.

        Args:
            real_images: batch of real images, already on ``self.device``.
            watermark: matching batch of watermarks, already on ``self.device``.

        Returns:
            ``(d_loss, stats)`` where ``d_loss`` is the mean of the real and
            fake BCE losses and ``stats`` holds the keys ``'d_real_loss'``,
            ``'d_fake_loss'`` and ``'d_total_loss'`` as Python floats.
        """
        self.d_optimizer.zero_grad()

        # Train on real images. Labels are built from the discriminator
        # output itself, so they always match its shape, dtype and device
        # instead of assuming a fixed 31x31 map.
        real_output = self.discriminator(real_images)
        d_real_loss = self.bce_loss(real_output, torch.ones_like(real_output))

        # Train on fake (watermarked) images. The generator runs under
        # no_grad, so its output carries no graph and needs no detach().
        with torch.no_grad():
            fake_images = self.generator(real_images, watermark)
        fake_output = self.discriminator(fake_images)
        d_fake_loss = self.bce_loss(fake_output, torch.zeros_like(fake_output))

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        self.d_optimizer.step()

        return d_loss, {
            'd_real_loss': d_real_loss.item(),
            'd_fake_loss': d_fake_loss.item(),
            'd_total_loss': d_loss.item()
        }

    def train_generator(
        self,
        real_images: torch.Tensor,
        watermark: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Run one generator update on a batch.

        Args:
            real_images: batch of real images, already on ``self.device``.
            watermark: matching batch of watermarks, already on ``self.device``.

        Returns:
            ``(g_loss, stats)`` where ``g_loss = mse + 0.1 * adversarial`` and
            ``stats`` holds the keys ``'mse_loss'``, ``'adv_loss'`` and
            ``'g_total_loss'`` as Python floats.
        """
        self.g_optimizer.zero_grad()

        fake_images = self.generator(real_images, watermark)

        # MSE Loss between real and generated images
        mse_loss = self.mse_loss(fake_images, real_images)

        # Adversarial loss: the generator wants its output classified as
        # real. The target is shaped from the discriminator output.
        fake_output = self.discriminator(fake_images)
        adv_loss = self.bce_loss(fake_output, torch.ones_like(fake_output))

        # Total generator loss (weighted sum). backward() also leaves
        # gradients on the discriminator parameters; they are cleared by
        # d_optimizer.zero_grad() at the start of the next discriminator step.
        g_loss = mse_loss + 0.1 * adv_loss
        g_loss.backward()
        self.g_optimizer.step()

        return g_loss, {
            'mse_loss': mse_loss.item(),
            'adv_loss': adv_loss.item(),
            'g_total_loss': g_loss.item()
        }

    def train_epoch(self) -> Dict[str, float]:
        """Train for one pass over ``self.train_loader``.

        Returns:
            Per-batch averages for ``'g_loss'``, ``'d_loss'``, ``'mse_loss'``
            and ``'adv_loss'``. The same values are appended to
            ``self.train_history`` and ``self.current_epoch`` is incremented.
        """
        self.generator.train()
        self.discriminator.train()

        epoch_losses = {
            'g_loss': 0.0,
            'd_loss': 0.0,
            'mse_loss': 0.0,
            'adv_loss': 0.0
        }

        num_batches = len(self.train_loader)
        progress_bar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch+1}')

        for real_images, watermark in progress_bar:
            real_images = real_images.to(self.device)
            watermark = watermark.to(self.device)

            d_loss, d_losses = self.train_discriminator(real_images, watermark)
            g_loss, g_losses = self.train_generator(real_images, watermark)

            # Convert once and reuse for both the progress bar and the totals
            d_value = d_loss.item()
            g_value = g_loss.item()

            progress_bar.set_postfix({
                'D_loss': f"{d_value:.4f}",
                'G_loss': f"{g_value:.4f}"
            })

            # Accumulate losses
            epoch_losses['d_loss'] += d_value
            epoch_losses['g_loss'] += g_value
            epoch_losses['mse_loss'] += g_losses['mse_loss']
            epoch_losses['adv_loss'] += g_losses['adv_loss']

        # Calculate averages
        for key in epoch_losses:
            epoch_losses[key] /= num_batches
            self.train_history[key].append(epoch_losses[key])

        self.current_epoch += 1
        return epoch_losses

    def train(self, num_epochs: int, save_path: str = None) -> Dict[str, list]:
        """Train the model for multiple epochs.

        Args:
            num_epochs: number of epochs to run in this call.
            save_path: directory for checkpoints; when given, a checkpoint
                is written after every 5th epoch of this call.

        Returns:
            ``self.train_history`` (per-epoch average losses).
        """
        print(f"Starting training for {num_epochs} epochs...")
        start_time = time.time()

        for epoch in range(num_epochs):
            epoch_losses = self.train_epoch()

            # Print epoch summary
            print(f"\nEpoch {self.current_epoch} Summary:")
            for key, value in epoch_losses.items():
                print(f"{key}: {value:.4f}")

            # Save every 5 epochs. The checkpoint is labelled with the
            # absolute epoch count so a run resumed via load_checkpoint
            # neither overwrites earlier files nor stores a stale 'epoch'.
            if save_path and (epoch + 1) % 5 == 0:
                self.save_checkpoint(save_path, self.current_epoch)

        total_time = time.time() - start_time
        print(f"\nTraining completed in {total_time/60:.2f} minutes")
        return self.train_history

    def save_checkpoint(self, save_path: str, epoch: int):
        """Write models, optimizers and history to ``checkpoint_epoch_{epoch}.pt``.

        Args:
            save_path: target directory; created if it does not exist.
            epoch: epoch number stored in the checkpoint and in the file name.
        """
        import os

        # Create the directory up front so a missing folder cannot fail the
        # save after several epochs of training.
        os.makedirs(save_path, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'g_optimizer_state_dict': self.g_optimizer.state_dict(),
            'd_optimizer_state_dict': self.d_optimizer.state_dict(),
            'train_history': self.train_history
        }, os.path.join(save_path, f"checkpoint_epoch_{epoch}.pt"))

    def load_checkpoint(self, checkpoint_path: str):
        """Restore models, optimizers, epoch counter and history from a file.

        Args:
            checkpoint_path: file previously written by ``save_checkpoint``.
        """
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # restored on whichever device this trainer runs on.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        self.d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.train_history = checkpoint['train_history']

def create_watermark_dataloaders(preprocessor: MRIDatasetPreprocessor, batch_size: int = 32) -> Tuple[DataLoader, DataLoader]:
    """Build the train and test DataLoaders for watermark training.

    Wraps ``preprocessor.train_df`` and ``preprocessor.test_df`` in
    ``WatermarkMRIDataset`` instances. Only the training loader shuffles;
    both load in the main process (``num_workers=0``) and pin memory when
    CUDA is available.

    Returns:
        ``(train_loader, test_loader)``
    """
    train_split = WatermarkMRIDataset(preprocessor.train_df)
    test_split = WatermarkMRIDataset(preprocessor.test_df)

    # Settings shared by both loaders; they differ only in shuffling.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_split, shuffle=True, **shared_options)
    test_loader = DataLoader(test_split, shuffle=False, **shared_options)
    return train_loader, test_loader




def setup_training(dataset_path: str, batch_size: int = 3):
    """Build the data pipeline, models and trainer for watermark training.

    Args:
        dataset_path: dataset location passed to ``MRIDatasetPreprocessor``.
        batch_size: batch size used for both the train and test loaders.

    Returns:
        ``(trainer, preprocessor)``: the configured ``WatermarkTrainer`` and
        the ``MRIDatasetPreprocessor`` it was built from.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize preprocessor and create datasets. The returned frames are
    # not needed here: the loaders read preprocessor.train_df / test_df,
    # which this call presumably populates (confirm against
    # MRIDatasetPreprocessor).
    preprocessor = MRIDatasetPreprocessor(dataset_path)
    preprocessor.create_dataset_dataframes()

    # Create dataloaders
    train_loader, test_loader = create_watermark_dataloaders(preprocessor, batch_size)
    # train_loader, test_loader = create_sampled_dataloaders(preprocessor, train_samples=50, test_samples=20, batch_size=batch_size)

    # Create models
    maxvit = MaxViT(
        in_channels=768,
        embed_dim=768,
        depth=16
    ).to(device)

    generator = UNetGenerator(maxvit).to(device)
    discriminator = Discriminator().to(device)

    # Initialize trainer
    trainer = WatermarkTrainer(
        generator=generator,
        discriminator=discriminator,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device
    )

    return trainer, preprocessor

# Entry point for the training run: builds the trainer from the dataset
# location in `path` (assigned in the first cell) and trains for 10 epochs.
if __name__ == "__main__":
    trainer, preprocessor = setup_training(path)

    # Train; trainer.train writes a checkpoint into save_path every 5 epochs.
    history = trainer.train(
        num_epochs=10,
        save_path="D:\\Dropbox\\UMA Augusta\\PhD\\Research Thesis\\20250227_watermarked_checkpoints"
    )

Analyzing images: 100%|██████████| 5712/5712 [00:11<00:00, 491.65it/s]
Analyzing images: 100%|██████████| 1311/1311 [00:02<00:00, 491.92it/s]
Analyzing for watermarking: 100%|██████████| 5712/5712 [00:12<00:00, 460.58it/s]
Analyzing for watermarking: 100%|██████████| 1311/1311 [00:02<00:00, 473.83it/s]
Epoch 1: 100%|██████████| 1904/1904 [18:27<00:00,  1.72it/s, D_loss=0.1697, G_loss=0.2228]
Epoch 2: 100%|██████████| 1904/1904 [26:48<00:00,  1.18it/s, D_loss=0.0015, G_loss=0.6904] 
Epoch 3: 100%|██████████| 1904/1904 [35:20<00:00,  1.11s/it, D_loss=0.0007, G_loss=0.7645]
Epoch 4: 100%|██████████| 1904/1904 [31:08<00:00,  1.02it/s, D_loss=0.0001, G_loss=0.9581]
Epoch 5: 100%|██████████| 1904/1904 [34:35<00:00,  1.09s/it, D_loss=0.0007, G_loss=0.8292]
Epoch 6: 100%|██████████| 1904/1904 [34:40<00:00,  1.09s/it, D_loss=0.0002, G_loss=0.9338]
Epoch 7: 100%|██████████| 1904/1904 [37:01<00:00,  1.17s/it, D_loss=0.0001, G_loss=1.0569]
Epoch 8: 100%|██████████| 1904/1904 [33:28<00:00,  1.05s/i

In [7]:
# PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True