In [None]:
import kagglehub

# Download the latest version of the MRA dataset through kagglehub and report
# the local directory it was stored in.
# NOTE(review): `path` is only printed here; CONFIG["image_folder"] is a
# separately hard-coded Kaggle input path -- confirm both point at the same data.
path = kagglehub.dataset_download("shymaaelbana/vessel-mamba-mra")

print("Path to dataset files:", path)

In [None]:
import torch
from torch import nn
from math import sqrt
import numpy as np
import tqdm
from collections import OrderedDict
import scipy.io
import pickle
import nibabel as nib
import os
import getpass
from sklearn.preprocessing import MinMaxScaler


# SCRIPT CONFIGURATION
# All user-configurable parameters are centralized here.

In [None]:
'''
These are the arguments you will most often need to adjust:
image_folder: path to the folder that contains the input images
num_iters: number of training iterations per image (one random batch per iteration)
start_image: 1-based index of the first image to process
end_image: 1-based index of the last image to process (inclusive)
train_batch_size: number of voxels sampled per iteration; choose the best trade-off between training speed and PSNR
num_layers: number of sine layers in the network
layer_size: number of hidden units in each layer
'''
# Central configuration shared by training and reconstruction. The execution
# cell copies this dict for every image, prefixes the filename entries with the
# image name, and adds the per-image 'image' and 'logdir' keys.
CONFIG = {
    # --- I/O Configuration ---
    "image_folder": '/kaggle/input/vessel-mamba-mra/MambaVesselNet_dataset_MRA_multiclass/imagesTr/',
    "base_results_dir": f"./results_fast_{getpass.getuser()}",
    "reconstruction_dir": "./reconstructed_images",

    # --- File Selection (1-based index) ---
    "start_image": 1, # 1-based index of the first image to process
    "end_image": 28, # 1-based index of the last image to process (inclusive)

    # --- Training Hyperparameters ---
    "learning_rate": 1e-4,
    "num_iters": 5000,
    # Number of voxels sampled per training iteration. Set a specific integer
    # (e.g., 131072 * 8, which the author reports works well on 2x T4; an A100
    # can most probably take more), or None to auto-calculate from free VRAM
    # (auto-calculation is not recommended by the author, at least on T4).
    "train_batch_size": 131072 * 8,
    "seed": 42,
    "use_amp": True,  # Use Automatic Mixed Precision (read by training and reconstruction)
    "log_frequency": 10,  # Log PSNR and loss every N iterations

    # --- Model (SIREN) Architecture ---
    "num_layers": 5,    # Number of sine layers; a linear output layer is added on top
    "layer_size": 512,  # Hidden dimension size
    "w0": 30.0,         # w0 for subsequent SIREN layers
    "w0_initial": 30.0, # w0 for the first SIREN layer

    # --- Reconstruction Parameters ---
    # Set to a specific integer or None for auto-calculation.
    "inference_batch_size": None,

    # --- Automatic Batch Size Calculator (Advanced) ---
    # Only consulted when "train_batch_size" is None.
    "mem_usage_factor": 0.70,        # Use 70% of available VRAM for training batches
    "min_calc_batch_size": 16384,    # Lower clamp; applied after the voxel_divisor cap
    "max_calc_batch_size": 4 * 1024 * 1024,
    "voxel_divisor": 4,              # Caps the batch size at total_voxels / N (before the lower clamp)
    "non_cuda_batch_size": 65536,    # Fallback batch size for CPU training

    # --- Automatic Inference Batch Size Calculator (Advanced) ---
    # Only consulted when "inference_batch_size" is None.
    "inference_mem_factor": 0.60,    # Use 60% of VRAM for inference batches
    "max_inference_batch_size": 1 * 1024 * 1024,
    "non_cuda_inference_batch_size": 65536, # Fallback for CPU inference

    # --- Filename Configuration ---
    # These are the BASE filenames. The script will automatically PREPEND the
    # original image name to them (e.g., 'image_01_best_model.pt').
    "model_filename": 'best_model.pt',
    "scaler_filename": 'scaler.pkl',
    "header_filename": 'nifti_header.pkl',
    "metadata_filename": 'metadata.pkl',
    "coords_filename": 'input_coords.mat',
}


# UTILITY: FAST BATCH SIZE CALCULATOR


In [None]:
def calculate_optimal_batch_size(model, device, voxel_count, config):
    """Estimate a power-of-two training batch size from the free GPU memory.

    Args:
        model: Network exposing ``model.net`` (a sequence of layers whose first
            element has a ``.linear`` attribute); used to size activations.
        device: Device the model will train on (``torch.device``, string, or
            ``None`` for "the current CUDA device").
        voxel_count: Total number of voxels (samples) in the image.
        config: Dict providing ``non_cuda_batch_size``, ``mem_usage_factor``,
            ``voxel_divisor``, ``min_calc_batch_size`` and ``max_calc_batch_size``.

    Returns:
        int: ``config['non_cuda_batch_size']`` when training will not run on
        CUDA, otherwise the largest power of two not exceeding the clamped
        memory-based estimate.
    """
    target = torch.device(device) if device is not None else None
    # Respect the requested device: a CPU run must not be sized from GPU memory.
    if not torch.cuda.is_available() or (target is not None and target.type != 'cuda'):
        return config['non_cuda_batch_size']

    # Query the GPU that was asked for when an explicit index is given;
    # otherwise fall back to the current CUDA device.
    query_device = target if (target is not None and target.index is not None) else None
    free_mem, total_mem = torch.cuda.mem_get_info(query_device)
    model_params = sum(p.numel() * p.element_size() for p in model.parameters())
    hidden_size = model.net[0].linear.out_features
    num_layers = len(model.net)
    bytes_per_sample = 4 * hidden_size * (2 * num_layers + 3)  # Heuristic
    # Never let a tiny amount of free memory produce a negative budget.
    available_for_batch = max(0, free_mem - model_params) * config['mem_usage_factor']

    max_batch = int(available_for_batch / bytes_per_sample)
    max_batch = min(max_batch, voxel_count // config['voxel_divisor'])
    max_batch = max(config['min_calc_batch_size'], min(max_batch, config['max_calc_batch_size']))
    max_batch = max(1, int(max_batch))

    # This calculates the batch size for just one GPU.
    # Largest power of two <= max_batch, computed exactly with integer math.
    batch_size = 1 << (max_batch.bit_length() - 1)

    print(f"\n{'='*70}\nQUICK BATCH SIZE CALCULATION\n{'='*70}")
    print("Detected 1 GPU.")  # Hardcoded to 1 GPU for clarity.
    print(f"GPU Memory: {free_mem/1e9:.1f} GB free / {total_mem/1e9:.1f} GB total")
    print(f"Calculated optimal batch size: {batch_size:,}")
    print(f"Memory per sample (estimated): {bytes_per_sample:,} bytes")
    print(f"Expected batch memory usage: {batch_size * bytes_per_sample / 1e9:.2f} GB")
    print(f"{'='*70}\n")

    return batch_size

# MODEL: SIREN


In [None]:
class OptimizedSirenLayer(nn.Module):
    """One SIREN layer: a linear map followed by ``sin(w0 * .)``.

    Weights and bias are drawn uniformly from ``[-bound, bound]`` where
    ``bound`` is ``1 / dim_in`` for the first layer and
    ``sqrt(c / dim_in) / w0`` for every other layer.
    """

    def __init__(self, dim_in, dim_out, w0=30., c=6., is_first=False, use_bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.w0 = w0

        # Half-width of the symmetric uniform initialisation interval.
        if is_first:
            bound = 1 / dim_in
        else:
            bound = sqrt(c / dim_in) / w0

        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)
        nn.init.uniform_(self.linear.weight, -bound, bound)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -bound, bound)

    def forward(self, x):
        """Return ``sin(w0 * linear(x))``."""
        pre_activation = self.linear(x)
        return torch.sin(self.w0 * pre_activation)


class OptimizedSiren(nn.Module):
    """SIREN network: ``num_layers`` sine layers followed by a linear head.

    The first sine layer maps ``dim_in -> dim_hidden`` with frequency
    ``w0_initial``; the remaining ones map ``dim_hidden -> dim_hidden`` with
    frequency ``w0``. The head is a plain ``nn.Linear(dim_hidden, dim_out)``
    initialised uniformly in ``[-sqrt(6 / dim_hidden) / w0, +...]``.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0=30.,
                 w0_initial=30., use_bias=True):
        super().__init__()

        # depth == 0 is the input layer; every later layer is hidden-to-hidden.
        self.net = nn.Sequential(*[
            OptimizedSirenLayer(
                dim_in=dim_hidden if depth else dim_in,
                dim_out=dim_hidden,
                w0=w0 if depth else w0_initial,
                use_bias=use_bias,
                is_first=(depth == 0),
            )
            for depth in range(num_layers)
        ])

        # Linear (non-sine) output head.
        self.last_layer = nn.Linear(dim_hidden, dim_out, bias=use_bias)
        head_bound = sqrt(6. / dim_hidden) / w0
        nn.init.uniform_(self.last_layer.weight, -head_bound, head_bound)
        if use_bias:
            nn.init.uniform_(self.last_layer.bias, -head_bound, head_bound)

    def forward(self, x):
        """Run the sine stack, then the linear head."""
        hidden = self.net(x)
        return self.last_layer(hidden)


# DATA UTILITIES


In [None]:

def to_coordinates_and_features_3D(img_tensor):
    """Flatten a (C, D, H, W) volume into per-voxel coordinates and features.

    Returns:
        tuple: ``(coordinates, features)`` where ``coordinates`` has shape
        ``(D*H*W, 3)`` with each axis spanning ``[-1, 1]`` and ``features`` has
        shape ``(D*H*W, C)``; both are ordered with W varying fastest.
    """
    num_channels = img_tensor.shape[0]
    spatial_dims = img_tensor.shape[1:4]
    # Fail the same way as a 4-way unpack would for non-4D input.
    C, D, H, W = img_tensor.shape

    # One normalised axis per spatial dimension, combined into a dense grid.
    axes = [torch.linspace(-1.0, 1.0, size) for size in spatial_dims]
    grid = torch.meshgrid(*axes, indexing='ij')
    coordinates = torch.stack(list(grid), dim=-1).reshape(-1, 3)

    # Channels-last so that each row holds all channels of one voxel.
    channels_last = img_tensor.permute(1, 2, 3, 0)
    features = channels_last.reshape(-1, num_channels)
    return coordinates, features

# CORE: FAST TRAINER


In [None]:
class FastTrainer:
    """Fits a model to (coordinate, feature) pairs with random mini-batches.

    Attributes:
        optimizer: AdamW over the model parameters (fused only on CUDA).
        best_vals: Best logged ``psnr`` and the ``loss`` at that point.
        logs: History of logged ``psnr`` and ``loss`` values.
        best_model: CPU copy of the state dict that produced the best PSNR.
        last_psnr: Most recently logged PSNR (0.0 before the first log).
        scaler: ``GradScaler`` when CUDA is available, else ``None``.
    """

    def __init__(self, model, lr):
        params = list(model.parameters())

        # Only request the fused kernel when this torch build has the keyword
        # AND every parameter lives on CUDA. Passing `fused=` unconditionally
        # raises TypeError on builds without it, and fused=True with CPU
        # parameters fails on builds whose fused AdamW is CUDA-only.
        kw_defaults = getattr(torch.optim.AdamW.__init__, '__kwdefaults__', None) or {}
        use_fused = (
            'fused' in kw_defaults
            and len(params) > 0
            and all(p.is_cuda for p in params)
        )
        optim_kwargs = {'fused': True} if use_fused else {}
        self.optimizer = torch.optim.AdamW(params, lr=lr, **optim_kwargs)

        self.model = model
        self.loss_func = nn.MSELoss()
        # Start at -inf so the first finite PSNR (even a negative one) is
        # recorded and a checkpoint is always captured.
        self.best_vals = {'psnr': float('-inf'), 'loss': float('inf')}
        self.logs = {'psnr': [], 'loss': []}
        self.best_model = OrderedDict()

        self.last_psnr = 0.0

        self.scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None

    def _log_and_checkpoint(self, pred, batch_feats, loss_value):
        """Record PSNR/loss for this batch and snapshot the weights if it is the best so far.

        Must be called BEFORE the optimizer step so the snapshot holds exactly
        the weights that produced ``pred``.
        """
        with torch.no_grad():
            mse = (pred.detach() - batch_feats).pow(2).mean()
            # PSNR with a peak value of 1.0: 20*log10(1) - 10*log10(mse).
            psnr_val = -10 * torch.log10(mse)

        self.last_psnr = psnr_val.item()
        self.logs['psnr'].append(self.last_psnr)
        self.logs['loss'].append(loss_value)

        if self.last_psnr > self.best_vals['psnr']:
            self.best_vals['psnr'] = self.last_psnr
            self.best_vals['loss'] = loss_value
            for k, v in self.model.state_dict().items():
                self.best_model[k] = v.detach().clone().cpu()

    def train(self, coordinates, features, num_iters, batch_size, log_frequency, use_amp):
        """Run ``num_iters`` optimisation steps on random batches.

        Args:
            coordinates: ``(N, d)`` input tensor, already on the training device.
            features: ``(N, C)`` target tensor on the same device.
            num_iters: Number of optimisation steps.
            batch_size: Samples drawn (with replacement) per step.
            log_frequency: Log PSNR/loss every this many steps (values < 1
                are treated as 1).
            use_amp: Enable CUDA mixed precision when a GradScaler is available.
        """
        N = coordinates.shape[0]
        device = coordinates.device
        amp_enabled = use_amp and self.scaler is not None
        log_frequency = max(1, int(log_frequency))  # avoid modulo-by-zero

        print(f"\n{'='*70}\nTRAINING START\n{'='*70}")
        print(f"Batch size: {batch_size:,} ({batch_size/N*100:.2f}% of {N:,} total voxels)")
        print(f"Iterations: {num_iters:,}")
        print(f"Mixed precision (AMP): {amp_enabled}")
        print(f"{'='*70}\n")

        # Reused index buffer: refilled in place every step to avoid reallocations.
        idx_buffer = torch.empty(batch_size, dtype=torch.long, device=device)

        self.model.train()
        with tqdm.trange(num_iters, ncols=120, postfix={'loss': 'N/A', 'psnr': 'N/A', 'best': 'N/A'}) as pbar:
            for i in pbar:
                idx_buffer.random_(0, N)
                batch_coords = coordinates[idx_buffer]
                batch_feats = features[idx_buffer]

                self.optimizer.zero_grad(set_to_none=True)

                if amp_enabled:
                    with torch.amp.autocast('cuda'):
                        pred = self.model(batch_coords)
                        loss = self.loss_func(pred, batch_feats)
                else:
                    pred = self.model(batch_coords)
                    loss = self.loss_func(pred, batch_feats)
                loss_value = loss.item()

                # Log/checkpoint before stepping: the weights are still the
                # ones that generated `pred`, so metric and snapshot agree.
                if i % log_frequency == 0:
                    self._log_and_checkpoint(pred, batch_feats, loss_value)

                if amp_enabled:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                pbar.set_postfix(
                    loss=f"{loss_value:.6f}",
                    psnr=f"{self.last_psnr:.2f}",
                    best=f"{self.best_vals['psnr']:.2f}")

# MAIN WORKFLOW: TRAINING


In [None]:
def train_siren_optimized(config):
    """Fit a SIREN to one NIfTI image and save everything needed to rebuild it.

    Args:
        config: Dict with the CONFIG keys plus ``'image'`` (path of the NIfTI
            file) and ``'logdir'`` (output directory for this image).

    Returns:
        The state dict (CPU tensors) that was written to
        ``config['model_filename']``.

    Raises:
        ValueError: If the image is neither 3D nor 4D.

    Side effects:
        Writes the model weights, the fitted MinMaxScaler, the NIfTI header,
        a metadata pickle and the input coordinates (.mat) into ``logdir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(config['seed'])

    logdir = config['logdir']
    os.makedirs(logdir, exist_ok=True)  # race-free replacement for exists()+makedirs()

    print("\n" + "="*70 + f"\nFAST SIREN TRAINING: {os.path.basename(config['image'])}\n" + "="*70)
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load the volume and bring it to channels-first (C, D, H, W).
    img_nii = nib.load(config['image'])
    img_data = img_nii.get_fdata().astype(np.float32)
    if img_data.ndim == 3:
        img_data = img_data[np.newaxis, ...]
    elif img_data.ndim == 4:
        img_data = np.transpose(img_data, (3, 0, 1, 2))
    else:
        raise ValueError(
            f"Expected a 3D or 4D image, got {img_data.ndim}D data in '{config['image']}'."
        )
    img_tensor = torch.from_numpy(img_data)
    C, D, H, W = img_tensor.shape
    coordinates, features = to_coordinates_and_features_3D(img_tensor)

    # Scale every channel to [0, 1]; the fitted scaler is saved so that
    # reconstruction can map predictions back to the original intensities.
    scaler = MinMaxScaler(feature_range=(0, 1))
    features = torch.from_numpy(scaler.fit_transform(features.numpy()).astype(np.float32))
    coordinates, features = coordinates.to(device), features.to(device)

    model = OptimizedSiren(
        dim_in=3, dim_hidden=config['layer_size'], dim_out=C,
        num_layers=config['num_layers'], w0_initial=config['w0_initial'], w0=config['w0']
    )

    # Single-device training: no multi-GPU wrapper is applied.
    model.to(device)

    batch_size = config['train_batch_size']
    if batch_size is None:
        batch_size = calculate_optimal_batch_size(model, device, coordinates.shape[0], config)

    trainer = FastTrainer(model, lr=config['learning_rate'])
    trainer.train(coordinates, features, config['num_iters'], batch_size, config['log_frequency'], config['use_amp'])

    print(f"\n{'='*70}\n✅ TRAINING COMPLETE! Best PSNR: {trainer.best_vals['psnr']:.2f} dB\n{'='*70}\n")

    # If the trainer never captured a "best" snapshot, fall back to the current
    # weights so an empty state dict (unloadable at reconstruction) is not saved.
    best_state = trainer.best_model
    if not best_state:
        best_state = OrderedDict(
            (k, v.detach().clone().cpu()) for k, v in model.state_dict().items()
        )

    print("💾 Saving model and metadata...")
    torch.save(best_state, os.path.join(logdir, config['model_filename']))
    with open(os.path.join(logdir, config['scaler_filename']), 'wb') as f:
        pickle.dump(scaler, f)
    with open(os.path.join(logdir, config['header_filename']), 'wb') as f:
        pickle.dump(img_nii.header.copy(), f)

    metadata = {
        'img_shape': (C, D, H, W),
        'model_config': {'dim_in': 3, 'dim_hidden': config['layer_size'], 'dim_out': C, 'num_layers': config['num_layers'], 'w0': config['w0'], 'w0_initial': config['w0_initial']},
        'best_psnr': trainer.best_vals['psnr'], 'best_loss': trainer.best_vals['loss']
    }
    with open(os.path.join(logdir, config['metadata_filename']), 'wb') as f:
        pickle.dump(metadata, f)
    scipy.io.savemat(os.path.join(logdir, config['coords_filename']), {'input_coordinates': coordinates.cpu().numpy()})

    print(f"✓ All files saved to: {logdir}\n")
    return best_state


# MAIN WORKFLOW: RECONSTRUCTION


In [None]:

def reconstruct_from_siren(model_dir, output_path, config):
    """Rebuild a NIfTI image from the files written by ``train_siren_optimized``.

    Args:
        model_dir: Directory holding the model, scaler, header, metadata and
            coordinate files named by the ``*_filename`` entries of ``config``.
        output_path: Destination path of the reconstructed NIfTI file.
        config: Dict with the filename keys, ``use_amp``,
            ``inference_batch_size`` and the inference batch-size settings.

    Returns:
        numpy.ndarray: The reconstructed volume, shaped (D, H, W) for a
        single-channel image or (D, H, W, C) otherwise.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("\n" + "="*70 + "\nFAST RECONSTRUCTION\n" + "="*70)

    # NOTE: pickle.load can execute arbitrary code; only use model
    # directories produced by a trusted training run.
    with open(os.path.join(model_dir, config['metadata_filename']), 'rb') as f:
        metadata = pickle.load(f)
    C, D, H, W = metadata['img_shape']
    with open(os.path.join(model_dir, config['scaler_filename']), 'rb') as f:
        scaler = pickle.load(f)
    with open(os.path.join(model_dir, config['header_filename']), 'rb') as f:
        header = pickle.load(f)
    coords_data = scipy.io.loadmat(os.path.join(model_dir, config['coords_filename']))
    coordinates = torch.from_numpy(coords_data['input_coordinates']).to(device)

    model = OptimizedSiren(**metadata['model_config'])
    model.load_state_dict(torch.load(os.path.join(model_dir, config['model_filename']), map_location=device))

    model.to(device).eval()
    print(f"✓ Model loaded from: {model_dir}")

    batch_size = config['inference_batch_size']
    if batch_size is None:
        if torch.cuda.is_available():
            free_mem = torch.cuda.mem_get_info()[0]
            est_mem_per_item = 4 * metadata['model_config']['dim_hidden']
            batch_size = min(config['max_inference_batch_size'], int(free_mem * config['inference_mem_factor'] / est_mem_per_item))
            # Guard against a zero estimate, then round down to a power of two
            # with exact integer arithmetic.
            batch_size = max(1, int(batch_size))
            batch_size = 1 << (batch_size.bit_length() - 1)
        else:
            batch_size = config['non_cuda_inference_batch_size']

    N = coordinates.shape[0]
    num_batches = (N + batch_size - 1) // batch_size
    print(f"\nInference batch size: {batch_size:,}\nTotal batches: {num_batches:,}\n")

    # CUDA autocast is only meaningful when inference actually runs on CUDA.
    use_autocast = bool(config['use_amp']) and device.type == 'cuda'

    predictions = []
    with torch.no_grad():
        for i in tqdm.trange(num_batches, desc="Reconstructing", ncols=100):
            coord_batch = coordinates[i * batch_size: (i + 1) * batch_size]
            with torch.amp.autocast('cuda', enabled=use_autocast):
                pred_batch = model(coord_batch)
            # Under autocast the output is float16. Promote to float32 before
            # un-scaling, otherwise the inverse transform runs in half
            # precision and can lose accuracy or overflow.
            predictions.append(pred_batch.float().cpu())

    predicted_features = scaler.inverse_transform(torch.cat(predictions, dim=0).numpy())
    # (N, C) -> (C, D, H, W) -> (D, H, W, C), mirroring the training layout.
    img_decompressed = predicted_features.T.reshape(C, D, H, W)
    img_decompressed = np.transpose(img_decompressed, (1, 2, 3, 0))
    if C == 1:
        img_decompressed = np.squeeze(img_decompressed, axis=-1)

    nib.save(nib.nifti1.Nifti1Image(img_decompressed, None, header=header), output_path)
    print(f"\n✓ Saved reconstruction to: {output_path}\n")
    return img_decompressed

# SCRIPT EXECUTION


In [None]:

def nifti_stem(filename):
    """Return ``filename`` without its trailing ``.nii.gz`` or ``.nii`` extension.

    Only the NIfTI extension is removed, so dotted names stay distinct
    (``'sub.01.nii'`` -> ``'sub.01'``). Other names fall back to dropping the
    last extension.
    """
    for ext in ('.nii.gz', '.nii'):
        if filename.endswith(ext):
            return filename[:-len(ext)]
    return os.path.splitext(filename)[0]


if __name__ == "__main__":
    os.makedirs(CONFIG['base_results_dir'], exist_ok=True)
    os.makedirs(CONFIG['reconstruction_dir'], exist_ok=True)

    # Collect the NIfTI files; a missing folder is treated like an empty one.
    try:
        all_files = sorted(f for f in os.listdir(CONFIG['image_folder']) if f.endswith(('.nii', '.nii.gz')))
    except FileNotFoundError:
        all_files = []
    if not all_files:
        print(f"Error: No '.nii' or '.nii.gz' files found in '{CONFIG['image_folder']}'. Please check the path.")
        # SystemExit stops the run with a failure status without shutting
        # down an interactive kernel the way exit() does.
        raise SystemExit(1)

    # start_image < 1 would become a negative slice index and silently select
    # from the end of the list, so reject invalid ranges up front.
    if CONFIG['start_image'] < 1 or CONFIG['end_image'] < CONFIG['start_image']:
        print(f"Error: Invalid image range {CONFIG['start_image']}..{CONFIG['end_image']} (indices are 1-based and inclusive).")
        raise SystemExit(1)

    start_idx = CONFIG['start_image'] - 1
    end_idx = CONFIG['end_image']
    files_to_process = all_files[start_idx:end_idx]

    print(f"Found {len(all_files)} total images. Processing {len(files_to_process)} images from index {CONFIG['start_image']} to {CONFIG['end_image']}.")
    print("-" * 50)

    for i, filename in enumerate(files_to_process):
        print(f"\n\n{'='*80}")
        print(f"PROCESSING IMAGE {i+1}/{len(files_to_process)} (Overall index: {start_idx + i + 1}): {filename}")
        print(f"{'='*80}\n")

        # Per-image copy so the shared CONFIG is never mutated.
        current_config = CONFIG.copy()

        image_name_no_ext = nifti_stem(filename)

        # Prefix every artifact filename with the image name.
        filename_keys_to_prefix = [
            "model_filename", "scaler_filename", "header_filename",
            "metadata_filename", "coords_filename"
        ]
        for key in filename_keys_to_prefix:
            original_value = current_config[key]
            current_config[key] = f"{image_name_no_ext}_{original_value}"

        current_config['image'] = os.path.join(CONFIG['image_folder'], filename)
        current_config['logdir'] = os.path.join(CONFIG['base_results_dir'], image_name_no_ext)
        output_path = os.path.join(CONFIG['reconstruction_dir'], filename)

        print(f"--- Step 1: Training on {filename} ---")
        train_siren_optimized(current_config)

        print(f"--- Step 2: Reconstructing {filename} ---")
        reconstruct_from_siren(
            model_dir=current_config['logdir'],
            output_path=output_path,
            config=current_config
        )

    print("\n✅ ALL SELECTED IMAGES PROCESSED SUCCESSFULLY!")