In [1]:
import os

import numpy as np
import torch
from torch_geometric.data import Data

from crsm import CRSM
from gat_stack import GATStack
from ico_blocks import IcoUNet
from icosahedron_gen import IcosahedralGrid
from integration import ScalingAndSquaring
from losses import composite_loss
from s2c_head import S2CHead
from spatial_transformer import SpatialTransformer
from volumetric_surface_parameterization import VolumetricSphericalParameterization

# Debug aid: makes CUDA kernel launches synchronous so device-side errors are reported
# at the call that caused them. It slows every GPU operation; remove once stable.
# NOTE(review): CUDA reads this when its context is created, so to be reliable it should
# be set before the first CUDA call (ideally before `import torch`).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Single device switch used by later cells; falls back to CPU when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Icosphere with two subdivision levels (162 vertices per the output below).
ico_subdivisions = 2
grid = IcosahedralGrid(subdivisions=ico_subdivisions)

for label, grid_tensor in (("Vertices:", grid.vertices), ("Edges:", grid.edge_index)):
    print(label, grid_tensor.shape)
# Optional: grid.visualize(save_path="ico.png")

Vertices: torch.Size([162, 3])
Edges: torch.Size([2, 960])


In [3]:
# Root of the preprocessed OASIS scans. Override with the OASIS_DIR environment
# variable instead of editing a user-specific absolute path.
OASIS_DIR = os.environ.get(
    "OASIS_DIR", "/shared/scratch/0/home/v_nishchay_nilabh/oasis_data/scans"
)
FIXED_SUBJECT = "OASIS_OAS1_0406_MR1"   # template
MOVING_SUBJECT = "OASIS_OAS1_0021_MR1"


def load_foreground_mask(subject_id):
    """Load ``<OASIS_DIR>/<subject_id>/seg4_onehot.npy`` and reduce it to a 0/1 mask.

    The stack is arg-maxed over axis 0 to get one label per voxel, and every voxel
    whose label is non-zero becomes 1 (``long`` dtype).

    NOTE(review): assumes axis 0 is the class axis and label 0 is background.
    """
    onehot = torch.from_numpy(
        np.load(os.path.join(OASIS_DIR, subject_id, "seg4_onehot.npy"))
    )
    labels = torch.argmax(onehot, dim=0)
    return (labels > 0).long()


fixed_volume = load_foreground_mask(FIXED_SUBJECT)
print("fixed_volume shape:", fixed_volume.shape)
moving_volume = load_foreground_mask(MOVING_SUBJECT)
print("moving_volume shape:", moving_volume.shape)

vsp = VolumetricSphericalParameterization()

print("\nProcessing moving volume...")
data_moving = vsp(moving_volume, grid.vertices, grid.edge_index)
print(f"Moving features shape (data_moving.x): {data_moving.x.shape}")

print("Processing fixed volume...")
data_fixed = vsp(fixed_volume, grid.vertices, grid.edge_index)
print(f"Fixed features shape (data_fixed.x): {data_fixed.x.shape}")

# Per-vertex features are concatenated as [moving | fixed] along the channel axis.
# Both graphs come from the same grid, so the moving graph's edge_index is reused;
# the asserts make that assumption explicit.
assert data_moving.x.shape == data_fixed.x.shape
assert torch.equal(data_moving.edge_index, data_fixed.edge_index)
combined_features = torch.cat([data_moving.x, data_fixed.x], dim=1)
data = Data(x=combined_features, edge_index=data_moving.edge_index)

print("\n--- Final Combined Data for Network ---")
print(f"Shape of combined features (data.x): {data.x.shape}")
print(f"Shape of edge index (data.edge_index): {data.edge_index.shape}")

fixed_volume shape: torch.Size([256, 256, 256])
moving_volume shape: torch.Size([256, 256, 256])

Processing moving volume...
Moving features shape (data_moving.x): torch.Size([162, 11])
Processing fixed volume...
Fixed features shape (data_fixed.x): torch.Size([162, 11])

--- Final Combined Data for Network ---
Shape of combined features (data.x): torch.Size([162, 22])
Shape of edge index (data.edge_index): torch.Size([2, 960])


In [11]:
# Feature matrix size: (num_vertices, moving channels + fixed channels).
data.x.size()

torch.Size([162, 22])

In [12]:
# CRSM over every concatenated channel. The conical-depth channels are 0, 5, 10 from the
# moving half and 11, 16, 21 from the fixed half (each half has 11 channels above).
n_feat = data.x.size(1)
conical_depth_indices = [0, 5, 10, 11, 16, 21]
assert max(conical_depth_indices) < n_feat, "conical indices exceed feature width"

crsm = CRSM(
    radial_channel_indices=range(n_feat),
    conical_depth_indices=conical_depth_indices,
    aggregation='mean',
    mlp_hidden=(32, 16),
    mlp_out_dim=8,
).to(device)

data_crsm = crsm(data)
print("Input shape:", data.x.shape)
print("Output shape:", data_crsm.x.shape)

in_dim 28
hidden_dims (32, 16)
out_dim 8
Input shape: torch.Size([162, 22])
Output shape: torch.Size([162, 8])


In [5]:
gat = GATStack(
    in_channels=data_crsm.x.size(1),
    hidden_channels=16,
    out_channels=16,
    num_layers=3,
    heads=4,
    dropout=0.1,
).to(device)
print(gat)

x_gat = gat(data_crsm.x, data_crsm.edge_index)

# Attach the GAT features to a copy of the graph. Assigning to data_crsm itself would
# overwrite its CRSM features, and re-running this cell would then feed GAT output
# into a freshly built GATStack sized for the CRSM width.
data_gat = data_crsm.clone()
data_gat.x = x_gat
print("After GAT stack - shape:", data_gat.x.shape)

GATStack(
  (layers): ModuleList(
    (0): GATBlock(
      (gat): GATConv(8, 16, heads=4)
      (residual): Linear(in_features=8, out_features=64, bias=True)
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (act): LeakyReLU(negative_slope=0.2)
    )
    (1): GATBlock(
      (gat): GATConv(64, 16, heads=4)
      (residual): Linear(in_features=64, out_features=64, bias=True)
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (act): LeakyReLU(negative_slope=0.2)
    )
    (2): GATBlock(
      (gat): GATConv(64, 16, heads=1)
      (residual): Linear(in_features=64, out_features=16, bias=True)
      (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (act): LeakyReLU(negative_slope=0.2)
    )
  )
)
After GAT stack - shape: torch.Size([162, 16])


In [6]:
def dict_to_pool_tensor(pool_dict, device="cuda"):
    """Pack a coarse->fine pooling map into a dense, ``-1``-padded index tensor.

    Args:
        pool_dict: mapping ``coarse_index -> iterable of fine vertex indices``.
            Keys are used directly as row indices; rows for indices missing from
            the dict stay all ``-1``.
        device: device the returned tensor is placed on.

    Returns:
        ``LongTensor`` of shape ``(max_key + 1, longest_list_len)``. Row ``c`` holds
        the fine indices of coarse vertex ``c`` followed by ``-1`` padding.
        An empty ``pool_dict`` yields a ``(0, 0)`` tensor.
    """
    if not pool_dict:
        return torch.empty((0, 0), dtype=torch.long, device=device)

    # Materialise each child list once so generic iterables are consumed only once.
    children = {c: list(fine_list) for c, fine_list in pool_dict.items()}
    num_coarse = max(children) + 1
    width = max(len(fine_list) for fine_list in children.values())

    # Build the padded table on the host and copy it to `device` in one transfer,
    # instead of one small host->device copy per row.
    rows = [[-1] * width for _ in range(num_coarse)]
    for c, fine_list in children.items():
        rows[c][:len(fine_list)] = fine_list
    return torch.tensor(rows, dtype=torch.long, device=device)

def dict_to_up_tensor(pool_dict, device="cuda"):
    """Invert a coarse->fine pooling map into a fine->coarse parent lookup.

    Args:
        pool_dict: mapping ``coarse_index -> iterable of fine vertex indices``.
        device: device the returned tensor is placed on.

    Returns:
        ``LongTensor`` of shape ``(max_fine_index + 1,)``. Entry ``f`` is the coarse
        vertex that lists ``f``, or ``-1`` if none does. If a fine index is listed
        under several coarse vertices, the one that comes last in ``pool_dict``
        iteration order wins. An empty map, or one whose lists are all empty,
        yields a tensor of shape ``(0,)``.
    """
    # Flatten to (fine, coarse) pairs in dict order so "last writer wins" is explicit.
    pairs = [(f, c) for c, fine_list in pool_dict.items() for f in fine_list]
    num_fine = max((f for f, _ in pairs), default=-1) + 1

    # Fill on the host and transfer once; per-element writes into a device tensor
    # would issue one tiny copy per fine vertex.
    up = [-1] * num_fine
    for f, c in pairs:
        up[f] = c
    return torch.tensor(up, dtype=torch.long, device=device)

In [7]:
# grid.pool_maps holds the coarse level first: [0] maps 42 -> 12 vertices and
# [1] maps 162 -> 42. IcoUNet expects them finest-first. Build a reordered local
# list instead of reassigning grid.pool_maps, so re-running this cell cannot flip
# the order back.
pool_dicts_fine_to_coarse = [grid.pool_maps[1], grid.pool_maps[0]]

pool_maps = [dict_to_pool_tensor(m, device=device) for m in pool_dicts_fine_to_coarse]
up_maps = [dict_to_up_tensor(m, device=device) for m in pool_dicts_fine_to_coarse]

# The finest up-map must cover every grid vertex; this fails fast on a wrong order.
assert up_maps[0].numel() == grid.vertices.shape[0], "pool maps are not ordered finest-first"

model = IcoUNet(
    in_ch=data_gat.x.size(1),
    channels=[32, 64],
    pool_maps=pool_maps,
    up_maps=up_maps,
).to(device)

x_unet_in = data_gat.x.unsqueeze(0).permute(0, 2, 1).to(device)  # [N, F] -> [1, F, N]
print(f"Input shape to model: {x_unet_in.shape}")

# NOTE(review): cuDNN is disabled for this forward pass; the reason is not recorded
# here — confirm it is still required.
with torch.backends.cudnn.flags(enabled=False):
    out = model(x_unet_in).permute(0, 2, 1)  # -> [1, N, C]

print(out.shape)

Input shape to model: torch.Size([1, 16, 162])
min idx: -1 max idx: 161 x.shape[-1]= 162
min idx: -1 max idx: 41 x.shape[-1]= 42
torch.Size([1, 162, 16])


In [8]:
# in_channels follows the U-Net output width (last axis of `out`, [1, N, C]).
head = S2CHead(in_channels=out.size(-1)).to(device)

out_size = (16, 16, 16)  # D, H, W of the dense field produced by the head
u = head(out, grid.vertices, out_size)

print("Output shape:", u.shape)
assert u.shape == (*out_size, 3), f"unexpected S2CHead output shape {tuple(u.shape)}"

Output shape: torch.Size([16, 16, 16, 3])


In [9]:
# `interpolate` expects [N, C, D, H, W], so move the vector axis to the channel slot.
u_permuted = u.permute(3, 0, 1, 2).unsqueeze(0)  # [D, H, W, 3] -> [1, 3, D, H, W]

# Target size taken from the scans rather than hard-coded (256^3 for these volumes).
full_size = tuple(fixed_volume.shape)

# NOTE(review): trilinear interpolation resizes the grid but leaves vector values
# unchanged. If S2CHead outputs displacements in voxel units of the coarse grid, they
# should be multiplied by the per-axis factor full_size / u.shape[:3]; if they are in
# normalised [-1, 1] coordinates no rescaling is needed. Confirm which convention applies.
u_full_res = torch.nn.functional.interpolate(
    u_permuted, size=full_size, mode='trilinear', align_corners=False
)

# Drop the batch axis and move the vector components back last.
u_full_res = u_full_res.squeeze(0).permute(1, 2, 3, 0)  # -> [D, H, W, 3]

print(f"Final full-resolution deformation field shape: {u_full_res.shape}")

Final full-resolution deformation field shape: torch.Size([256, 256, 256, 3])


In [15]:
# Integrate u_full_res with ScalingAndSquaring. With return_displacement=True it returns
# both the map and the displacement field (each [D, H, W, 3] per the output below).
integrator = ScalingAndSquaring(max_scale=0.5)
integrator = integrator.to(device)

integration_result = integrator(u_full_res, return_displacement=True)
phi_map, disp = integration_result
print('returned map and displacement shapes:', phi_map.shape, disp.shape)

returned map and displacement shapes: torch.Size([256, 256, 256, 3]) torch.Size([256, 256, 256, 3])


In [None]:
# Warp the moving mask with the displacement field. The STN grid size is taken from
# the field itself so the two cannot disagree.
stn = SpatialTransformer(size=tuple(disp.shape[:3]), device=device)

disp_batch = disp.permute(3, 0, 1, 2).unsqueeze(0)  # [D, H, W, 3] -> [1, 3, D, H, W]
moving_volume_batch = moving_volume.float().unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]

print(f"Displacement field shape: {disp_batch.shape}")
print(f"Moving scan shape for STN: {moving_volume_batch.shape}")
warped_scan = stn(moving=moving_volume_batch, flow=disp_batch)
assert warped_scan.shape == moving_volume_batch.shape

print("\nSuccessfully warped scan!")
print(f"Final warped scan shape: {warped_scan.shape}")

Displacement field shape: torch.Size([1, 3, 256, 256, 256])
Moving scan shape for STN: torch.Size([1, 1, 256, 256, 256])

Successfully warped scan!
Final warped scan shape: torch.Size([1, 1, 256, 256, 256])


In [25]:
# Composite loss between the fixed mask and the warped moving mask ([D, H, W] each),
# with the displacement field as the third argument (presumably for a smoothness
# term — NOTE(review): confirm against `losses.composite_loss`).
loss = composite_loss(
    fixed_volume.float().to(device),
    warped_scan[0, 0].to(device),
    disp_batch.to(device),
)
loss

tensor(1.4046, device='cuda:0', grad_fn=<AddBackward0>)
