In [1]:
import numpy as np

# Load the data
# Both CSVs are read as plain comma-separated numeric matrices (one example per row).
# NOTE(review): the paths are absolute and machine-specific; the notebook only runs
# on a machine that has these exact files.
inner_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\inner_training_data_test.csv', delimiter=',')
outer_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\outer_training_data_test.csv', delimiter=',')

print("Inner shape:", inner_data.shape)
print("Outer shape:", outer_data.shape)

# Separate features and labels
# The last six columns of every row are treated as the label vector; everything
# before them is kept as features.
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

print("Inner labels shape:", inner_labels.shape)
print("Outer labels shape:", outer_labels.shape)

# Validate that labels are identical
# np.array_equal is True only when shapes match and every element is equal, i.e.
# when row i of both files carries the same label.
labels_equal = np.array_equal(inner_labels, outer_labels)

print("Labels are identical:" if labels_equal else "⚠️ Labels are NOT identical!")


Inner shape: (5000, 8495)
Outer shape: (5000, 10135)
Inner labels shape: (5000, 6)
Outer labels shape: (5000, 6)
⚠️ Labels are NOT identical!


In [2]:
import numpy as np

# A row "matches" only when all six label values agree between the two files
row_equal = (inner_labels == outer_labels).all(axis=1)

# Positions of the rows that disagree
mismatched_indices = np.flatnonzero(~row_equal)

print(f"Number of mismatched rows: {len(mismatched_indices)}")
print("First 10 indices with mismatched labels:", mismatched_indices[:10])

# Print the first five disagreeing label pairs side by side for manual inspection
for idx in mismatched_indices[:5]:
    print(f"\nRow {idx}:")
    print("Inner:", inner_labels[idx])
    print("Outer:", outer_labels[idx])


Number of mismatched rows: 4844
First 10 indices with mismatched labels: [13 16 30 33 37 40 42 45 47 48]

Row 13:
Inner: [ 233.3549     42.308315 -140.34695     0.          0.          0.      ]
Outer: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]

Row 16:
Inner: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]
Outer: [ 89.508736 252.85379   84.34045    0.         0.         0.      ]

Row 30:
Inner: [   0.          0.          0.       -231.473     -84.782135 -147.94    ]
Outer: [   0.          0.          0.       -273.9889    -22.72473    53.701527]

Row 33:
Inner: [   0.          0.          0.       -273.9889    -22.72473    53.701527]
Outer: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]

Row 37:
Inner: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]
Outer: [-243.86166  -77.32867 -135.17099    0.         0.         0.     ]


In [3]:
import numpy as np

# Assuming data is already loaded:
# inner_data, outer_data

# Split features/labels: the last six columns are the labels, the rest are features
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

# The "7th last column" is -7; it is used as the join key between the two files
inner_keys = inner_data[:, -7]
outer_keys = outer_data[:, -7]

# Two-pointer merge over the key columns: record the row indices where the keys
# agree and advance whichever side currently holds the smaller key.
# NOTE: this walk only pairs rows correctly when both key columns are in
# ascending order; the label comparison below is the safety net for that.
matched_inner_rows = []
matched_outer_rows = []

i, j = 0, 0
while i < len(inner_keys) and j < len(outer_keys):
    if inner_keys[i] == outer_keys[j]:
        # Keys match → keep both
        matched_inner_rows.append(i)
        matched_outer_rows.append(j)
        i += 1
        j += 1
    elif inner_keys[i] < outer_keys[j]:
        # Advance inner pointer
        i += 1
    else:
        # Advance outer pointer
        j += 1

# Select the matched rows in one indexing operation. Unlike building arrays from
# Python lists of rows, this keeps the 2-D shape even when nothing matched.
matched_inner_rows = np.asarray(matched_inner_rows, dtype=np.intp)
matched_outer_rows = np.asarray(matched_outer_rows, dtype=np.intp)

filtered_inner_features = inner_features[matched_inner_rows]
filtered_outer_features = outer_features[matched_outer_rows]
filtered_inner_labels = inner_labels[matched_inner_rows]
filtered_outer_labels = outer_labels[matched_outer_rows]

print("Filtered inner features shape:", filtered_inner_features.shape)
print("Filtered outer features shape:", filtered_outer_features.shape)
print("Filtered inner labels shape:", filtered_inner_labels.shape)
print("Filtered outer labels shape:", filtered_outer_labels.shape)

# Final validation: check labels after alignment
labels_equal = np.array_equal(filtered_inner_labels, filtered_outer_labels)
print("Labels are identical after filtering?", labels_equal)

# Later cells read `labels` unconditionally, so stop here instead of leaving it
# undefined (or stale from an earlier run) when the alignment did not work.
if not labels_equal:
    raise ValueError(
        "Inner and outer labels still differ after key alignment; "
        "check that the key columns are sorted and refer to the same events."
    )
labels = filtered_inner_labels


Filtered inner features shape: (4316, 8489)
Filtered outer features shape: (4316, 10129)
Filtered inner labels shape: (4316, 6)
Filtered outer labels shape: (4316, 6)
Labels are identical after filtering? True


In [4]:
# Remove first and last column from feature arrays
# The last feature column is the key column used for alignment (features are
# data[:, :-6], so features[:, -1] is data[:, -7]). What the first column holds
# is not visible from this notebook.
# NOTE(review): this reassigns the same names, so the cell is NOT idempotent —
# running it twice strips two more columns and breaks the later reshape to
# (207, 41) / (247, 41). Re-run the alignment cell first if this cell must be re-run.
filtered_inner_features = filtered_inner_features[:, 1:-1]
filtered_outer_features = filtered_outer_features[:, 1:-1]

print("Inner features shape after dropping cols:", filtered_inner_features.shape)
print("Outer features shape after dropping cols:", filtered_outer_features.shape)


Inner features shape after dropping cols: (4316, 8487)
Outer features shape after dropping cols: (4316, 10127)


In [5]:
import numpy as np
from collections import Counter

# Lengths of consecutive stretches of rows that share exactly the same label
run_lengths = []
current_len = 1

# Walk over neighbouring row pairs; a change of label closes the current run
for previous_label, this_label in zip(labels[:-1], labels[1:]):
    if not np.array_equal(this_label, previous_label):
        run_lengths.append(current_len)
        current_len = 1
    else:
        current_len += 1

# The run that is still open when the data ends has to be recorded as well
run_lengths.append(current_len)

# Histogram: run length -> how many runs have that length
run_length_counts = Counter(run_lengths)

print("Contiguous label run stats:")
for length in sorted(run_length_counts):
    count = run_length_counts[length]
    print(f"Length {length}: {count} times")

print("\nTotal runs found:", len(run_lengths))


Contiguous label run stats:
Length 1: 36 times
Length 2: 318 times
Length 3: 833 times
Length 4: 144 times
Length 5: 53 times
Length 6: 33 times
Length 7: 9 times
Length 8: 2 times
Length 9: 3 times

Total runs found: 1431


In [6]:
import numpy as np

def make_contiguous_examples(inner_features, outer_features, labels, run_length=3):
    """
    Build sliding-window examples from stretches of rows that share one label.

    Rows are grouped into maximal blocks of consecutive, identical labels. Every
    block with at least `run_length` rows contributes one example per window
    position (stride 1); shorter blocks contribute nothing.

    Args:
        inner_features: row-indexable array, sliced as inner_features[a:b].
        outer_features: row-indexable array, sliced as outer_features[a:b].
        labels: row-indexable labels; rows are compared with np.array_equal.
        run_length: number of consecutive rows per window.

    Returns:
        List of (inner_window, outer_window, label) tuples, where the label is
        the first row of the block the window was cut from.
    """
    total_rows = len(labels)
    examples = []
    block_start = 0

    while block_start < total_rows:
        block_label = labels[block_start]

        # Extend the block while the following rows carry the same label
        block_end = block_start + 1
        while block_end < total_rows and np.array_equal(labels[block_end], block_label):
            block_end += 1

        # The range is empty whenever the block is shorter than run_length
        for first in range(block_start, block_end - run_length + 1):
            last = first + run_length
            examples.append((inner_features[first:last],
                             outer_features[first:last],
                             block_label))

        block_start = block_end

    return examples

# Cut 3-row windows out of the aligned, column-trimmed features; `labels` is the
# label array produced by the alignment cell.
data = make_contiguous_examples(inner_features=filtered_inner_features,
                                outer_features=filtered_outer_features,
                                labels=labels,
                                run_length=3)

# Quick structural sanity check of the first example
print(f"len(data) = {len(data)} examples made")
print(f"len(data[0]) = {len(data[0])} which are the inner, outer, and labels")
print(f"Inner shape: data[0][0].shape = {data[0][0].shape}")
print(f"Outer shape: data[0][1].shape = {data[0][1].shape}")
print(f"Label shape: data[0][2].shape = {data[0][2].shape}")

len(data) = 1490 examples made
len(data[0]) = 3 which are the inner, outer, and labels
Inner shape: data[0][0].shape = (3, 8487)
Outer shape: data[0][1].shape = (3, 10127)
Label shape: data[0][2].shape = (6,)


In [7]:
def reshape_examples(data, inner_shape=(207, 41), outer_shape=(247, 41)):
    """
    Unflatten the per-row feature vectors of every example into 2-D grids.

    Args:
        data: iterable of (inner, outer, label) where inner/outer are arrays of
            shape (frames, n_features).
        inner_shape: grid shape each inner row is reshaped to. Defaults to
            (207, 41), i.e. 8487 features per row.
        outer_shape: grid shape each outer row is reshaped to. Defaults to
            (247, 41), i.e. 10127 features per row.

    Returns:
        List of (inner_grid, outer_grid, label) with grids of shape
        (frames, *inner_shape) and (frames, *outer_shape). Labels are passed
        through untouched.

    Raises:
        ValueError: (from NumPy) if a row length does not match the grid size.
    """
    reshaped = []
    for inner, outer, label in data:
        # Keep the frame axis, unflatten only the feature axis
        inner_reshaped = inner.reshape(inner.shape[0], *inner_shape)
        outer_reshaped = outer.reshape(outer.shape[0], *outer_shape)
        reshaped.append((inner_reshaped, outer_reshaped, label))
    return reshaped

# Unflatten every windowed example into (frames, rows, cols) grids
reshaped_data = reshape_examples(data)

# Before/after shapes of the first example: first line of each pair is the inner
# array (index 0 of the tuple), second line is the outer array (index 1).
# NOTE(review): every print repeats the very same shape three times — this looks
# like a copy-paste leftover; confirm whether three different examples were meant.
print("Original:", data[0][0].shape, data[0][0].shape, data[0][0].shape)
print("Original:", data[0][1].shape, data[0][1].shape, data[0][1].shape)
print("\nReshaped:", reshaped_data[0][0].shape, reshaped_data[0][0].shape, reshaped_data[0][0].shape)
print("Reshaped:", reshaped_data[0][1].shape, reshaped_data[0][1].shape, reshaped_data[0][1].shape)


Original: (3, 8487) (3, 8487) (3, 8487)
Original: (3, 10127) (3, 10127) (3, 10127)

Reshaped: (3, 207, 41) (3, 207, 41) (3, 207, 41)
Reshaped: (3, 247, 41) (3, 247, 41) (3, 247, 41)


In [8]:
import numpy as np
from scipy.ndimage import zoom

def resize_features(arr, target_len):
    """
    Resample a 3-D array along its middle axis to `target_len` entries.

    Args:
        arr: array of shape (frames, length, channels).
        target_len: desired size of the `length` axis.

    Returns:
        Array of shape (frames, target_len, channels), obtained with
        scipy.ndimage.zoom using order-1 (linear) interpolation. The first and
        last axes use a zoom factor of 1 and are left unchanged.
    """
    # Unpacking doubles as a check that the input really is 3-D
    _, length, _ = arr.shape
    per_axis_zoom = (1, target_len / length, 1)
    return zoom(arr, per_axis_zoom, order=1)

def resize_dataset(data, target_len=207):
    """
    Bring the outer array of every example to `target_len` along its length axis.

    Args:
        data: iterable of (inner, outer, label) triples.
        target_len: length the outer array is resampled to via resize_features.

    Returns:
        List of (inner, outer_resized, label). The inner array is passed through
        untouched (same object) on the assumption that it already has
        `target_len` entries; labels are passed through as well.
    """
    return [
        (inner, resize_features(outer, target_len), label)
        for inner, outer, label in data
    ]

# Apply to your dataset
# Outer grids are resampled from their original length to 207 so they line up
# with the inner grids.
resized_data = resize_dataset(reshaped_data, target_len=207)

# Bare expressions: in a notebook only the value of the LAST one is displayed,
# so the first four lines below produce no visible output.
len(resized_data)
len(resized_data[0])
resized_data[0][0].shape
resized_data[0][1].shape
resized_data[0][2].shape


(6,)

In [9]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_inner_outer_slider(inner, outer):
    """
    Visualize inner and outer arrays side-by-side with a slider to control timestep.

    Args:
        inner: np array of shape (timesteps, rows, cols), e.g. (3, 207, 41).
        outer: np array with the same number of timesteps, e.g. (3, 207, 41).

    The figure is shown via fig.show(); nothing is returned. The colour range of
    each panel is fixed to that array's global min/max so that frames are
    comparable while sliding.
    """
    # Colour limits do not depend on the timestep, so compute them once
    inner_min, inner_max = inner.min(), inner.max()
    outer_min, outer_max = outer.min(), outer.max()

    # One animation frame per timestep; trace 0 = inner, trace 1 = outer
    frames = []
    for t in range(inner.shape[0]):
        frames.append(go.Frame(
            data=[
                go.Heatmap(z=inner[t], colorscale='Cividis', colorbar=dict(title='Inner'), zmin=inner_min, zmax=inner_max),
                go.Heatmap(z=outer[t], colorscale='Cividis', colorbar=dict(title='Outer'), zmin=outer_min, zmax=outer_max)
            ],
            name=str(t)
        ))

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Inner", "Outer"),
                        horizontal_spacing=0.1,
                        specs=[[{"type": "heatmap"}, {"type": "heatmap"}]])

    # Initial state of the figure = timestep 0. The left panel must get the
    # INNER trace (index 0); using index 1 here would show the outer heatmap
    # in both panels until the slider is moved.
    fig.add_trace(frames[0].data[0], row=1, col=1)  # inner heatmap (first trace of frame)
    fig.add_trace(frames[0].data[1], row=1, col=2)  # outer heatmap (second trace of frame)

    # Slider steps: each step jumps to the frame whose name is the timestep index
    steps = []
    for i in range(len(frames)):
        steps.append(dict(
            method="animate",
            args=[[str(i)],
                  dict(mode="immediate", frame=dict(duration=500, redraw=True), transition=dict(duration=0))],
            label=str(i)
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        title="Inner and Outer Feature Maps Over Time",
        sliders=sliders,
        height=600,
        width=900
    )

    fig.frames = frames

    fig.show()

# Index of the windowed example to visualise (arbitrary pick; any valid index
# into resized_data works)
example = 15

# Element 0 of an example is the inner stack, element 1 the resized outer stack
plot_inner_outer_slider(resized_data[example][0],
                        resized_data[example][1])

In [10]:
import torch

# Prepare lists for inner and outer tensors for all elements
inner_tensors = []
outer_tensors = []
labels_tensors = []

# The loop variable is deliberately NOT called `data`: that name holds the list
# of windowed examples from an earlier cell, and overwriting it with a single
# tuple would break any re-run of the reshape cell.
for sample in resized_data:
    inner_tensors.append(torch.tensor(sample[0]))  # (timesteps, height, width)
    outer_tensors.append(torch.tensor(sample[1]))  # (timesteps, height, width)
    labels_tensors.append(torch.tensor(sample[2]))  # (6,)

# Stack each list into a tensor along dimension 0 (num_examples)
inner_tensor = torch.stack(inner_tensors, dim=0)  # (num_examples, timesteps, height, width)
outer_tensor = torch.stack(outer_tensors, dim=0)  # (num_examples, timesteps, height, width)
labels_tensor = torch.stack(labels_tensors, dim=0)  # (num_examples, 6)

# Stack inner and outer along a new dim=1, which acts as the channel axis
combined_tensor = torch.stack((inner_tensor, outer_tensor), dim=1)  # (num_examples, 2, timesteps, height, width)

print(combined_tensor.shape)
print(labels_tensor.shape)


torch.Size([1490, 2, 3, 207, 41])
torch.Size([1490, 6])


In [11]:
import torch

# Set your split ratio
train_ratio = 0.8
num_samples = combined_tensor.shape[0]
train_size = int(num_samples * train_ratio)
test_size = num_samples - train_size

# Shuffle sample indices with a dedicated, seeded generator so that the split is
# identical on every Restart & Run All and does not depend on (or disturb) the
# global torch RNG state.
# NOTE(review): the examples are overlapping sliding windows, so windows cut
# from the same label run can land on both sides of this random split.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
indices = torch.randperm(num_samples, generator=split_generator)
train_idx = indices[:train_size]
test_idx = indices[train_size:]

# Index into tensors
train_data = combined_tensor[train_idx]
train_labels = labels_tensor[train_idx]

test_data = combined_tensor[test_idx]
test_labels = labels_tensor[test_idx]

In [12]:
# Report the shapes of both splits: data is (N, 2, timesteps, H, W), labels (N, 6)
print(f"Train tensor: {train_data.shape}, Labels:{train_labels.shape}")
print(f"Test tensor: {test_data.shape}, Labels:{test_labels.shape}")

Train tensor: torch.Size([1192, 2, 3, 207, 41]), Labels:torch.Size([1192, 6])
Test tensor: torch.Size([298, 2, 3, 207, 41]), Labels:torch.Size([298, 6])


In [13]:
import torch

def filter_examples_with_zero_in_label(data_tensor, labels_tensor):
    """
    Drop every example whose label row contains at least one exact zero.

    Args:
        data_tensor: tensor whose first axis indexes examples.
        labels_tensor: 2-D tensor (examples, label_dim) aligned with data_tensor.

    Returns:
        (filtered_data, filtered_labels) holding only the examples whose label
        has no zero entry, in their original order.
    """
    keep_mask = ~(labels_tensor == 0).any(dim=1)
    return data_tensor[keep_mask], labels_tensor[keep_mask]

# --- Apply the Function ---
# Keep only examples whose six label values are all non-zero. The same names are
# reassigned, but filtering twice gives the same result, so re-running is safe.
# NOTE(review): per the recorded output this keeps 154 of 1192 training and 42 of
# 298 test examples — confirm that discarding this much data is intended.
train_data, train_labels = filter_examples_with_zero_in_label(train_data, train_labels)
test_data, test_labels = filter_examples_with_zero_in_label(test_data, test_labels)

print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)
print(test_labels.shape)


torch.Size([154, 2, 3, 207, 41])
torch.Size([154, 6])
torch.Size([42, 2, 3, 207, 41])
torch.Size([42, 6])


In [84]:
# Modified version with positional features added to the input.
# Takes attention_petnet.py and places the global attention layer after layer 1
# for higher-dimensional G.A. Employs point-wise separable convolutions instead of
# full-fat classical convolutions for an 8-10x reduction in parameter count.
# Also introduces a residual connection after the global attention, and layer norm
# for the global attention (apparently batch norm works better for convolutions,
# and layer norm works better for Transformers... we'll see about that).
# Uses 3 fully connected layers with dropout for a stronger regression head.
# Modified to use windowed attention with 4x4 non-overlapping windows.
# ADDED: Positional features (height and width indices) to input channels

import torch
import torch.nn as nn

###############################################################################
# Positional Feature Generator
###############################################################################
def add_positional_features(x, normalize=True):
    """
    Append two coordinate channels (row index, column index) to a 5-D tensor.

    Args:
        x: input tensor of shape (batch, channels, D, H, W)
        normalize: if True, scale the indices to [0, 1] by dividing by H-1 and
            W-1 respectively (an axis of size 1 is left at index 0)

    Returns:
        tensor of shape (batch, channels + 2, D, H, W); the original channels
        come first, then the height-index channel, then the width-index channel.
        The index channels are float32 and live on x's device.
    """
    batch_size, _, depth, height, width = x.shape

    # 1-D coordinate vectors, one entry per row / column
    row_coords = torch.arange(height, dtype=torch.float32, device=x.device)
    col_coords = torch.arange(width, dtype=torch.float32, device=x.device)

    if normalize:
        # Guard against division by zero for singleton axes
        if height > 1:
            row_coords = row_coords / (height - 1)
        if width > 1:
            col_coords = col_coords / (width - 1)

    # Broadcast each vector over all other axes to a single-channel volume
    plane_shape = (batch_size, 1, depth, height, width)
    row_plane = row_coords.view(1, 1, 1, height, 1).expand(plane_shape)
    col_plane = col_coords.view(1, 1, 1, 1, width).expand(plane_shape)

    return torch.cat([x, row_plane, col_plane], dim=1)



# petnet_cyl3d.py
from typing import Tuple, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------------------------------------
# Utilities
# -----------------------------------------------------------
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a branch during training.

    In training mode each sample in the batch is kept with probability
    1 - drop_prob and rescaled by 1 / (1 - drop_prob); in eval mode, or with
    drop_prob == 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1.0 - self.drop_prob
        # One random draw per sample, broadcast over all remaining axes
        per_sample_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        gate = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device)
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0
        gate = (gate + keep_prob).floor_()
        return x * (gate / keep_prob)


def make_norm(norm: Literal["group", "batch", "instance"], num_channels: int, groups: int = 8):
    """Build a 3-D normalization layer for `num_channels` channels.

    "batch" gives BatchNorm3d, "group" gives GroupNorm, and any other value
    gives an affine InstanceNorm3d. For GroupNorm the group count starts at
    min(groups, num_channels) and is decreased until it divides num_channels.
    """
    if norm == "batch":
        return nn.BatchNorm3d(num_channels)
    if norm != "group":
        return nn.InstanceNorm3d(num_channels, affine=True)

    # GroupNorm needs a group count that divides the channel count evenly
    num_groups = max(1, min(groups, num_channels))
    while num_groups > 1 and num_channels % num_groups != 0:
        num_groups -= 1
    return nn.GroupNorm(num_groups, num_channels)


# -----------------------------------------------------------
# Circular-padded Conv3d (wrap-around on W only)
# -----------------------------------------------------------
class Conv3dCircW(nn.Module):
    """
    3-D convolution that wraps around on the last axis (W = circumference).

    The W axis is padded circularly by hand before the convolution, while the
    T and H axes get ordinary zero padding from nn.Conv3d itself. Padding is
    (k - 1) // 2 per axis, which preserves size for kernel sizes 1 and 3 at
    stride 1.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=(1, 1, 1), bias=False):
        super().__init__()
        # Accept either one int for all axes or an explicit (kT, kH, kW) triple
        if isinstance(kernel_size, int):
            k_t, k_h, k_w = kernel_size, kernel_size, kernel_size
        else:
            k_t, k_h, k_w = kernel_size

        # W is padded in forward(); the conv itself must not pad W again
        self.pad_w = (k_w - 1) // 2
        zero_padding = ((k_t - 1) // 2, (k_h - 1) // 2, 0)
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=(k_t, k_h, k_w),
                              stride=stride, padding=zero_padding, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pad_w <= 0:
            return self.conv(x)
        # F.pad lists pads from the last axis backwards: (Wl, Wr, Hl, Hr, Dl, Dr)
        wrapped = F.pad(x, (self.pad_w, self.pad_w, 0, 0, 0, 0), mode="circular")
        return self.conv(wrapped)


# -----------------------------------------------------------
# Enhanced Residual Block with more regularization
# -----------------------------------------------------------
class ResidualBlock3D(nn.Module):
    """
    Residual block built from two circular-W 3x3x3 convolutions.

    Main path: conv1 (carries the stride) -> norm -> GELU -> spatial dropout ->
    conv2 -> norm. The result goes through DropPath, is added to the (possibly
    projected) input and passed through a final GELU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of conv1 and of the shortcut projection.
        norm: normalization type forwarded to make_norm.
        drop_path: stochastic-depth probability for the main path (0 disables).
        spatial_dropout: Dropout3d probability between the convolutions (0 disables).
        groups_for_gn: group count hint forwarded to make_norm.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: Tuple[int, int, int] = (1, 2, 2),
            norm: Literal["group", "batch", "instance"] = "group",
            drop_path: float = 0.0,
            spatial_dropout: float = 0.1,
            groups_for_gn: int = 8,
    ):
        super().__init__()
        # NOTE: the order in which sub-modules are created fixes parameter order
        # and the RNG draws used for their initialisation; keep it stable.
        self.conv1 = Conv3dCircW(in_channels, out_channels, kernel_size=3, stride=stride, bias=False)
        self.n1 = make_norm(norm, out_channels, groups_for_gn)
        self.act = nn.GELU()

        # Channel-wise (spatial) dropout between the two convolutions
        self.spatial_dropout = nn.Dropout3d(spatial_dropout) if spatial_dropout > 0 else nn.Identity()

        self.conv2 = Conv3dCircW(out_channels, out_channels, kernel_size=3, stride=(1, 1, 1), bias=False)
        self.n2 = make_norm(norm, out_channels, groups_for_gn)

        # A 1x1x1 projection is needed whenever the main path changes the
        # resolution or the channel count. The comparison is against a tuple,
        # so a stride passed as a list always takes the projection branch.
        self.shortcut = None
        if stride != (1, 1, 1) or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                make_norm(norm, out_channels, groups_for_gn),
            )
        self.drop_path = DropPath(drop_prob=drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.act(self.n1(self.conv1(x)))
        out = self.spatial_dropout(out)  # no-op (Identity) when spatial_dropout == 0
        out = self.n2(self.conv2(out))

        if self.shortcut is not None:
            identity = self.shortcut(identity)
        # Only the residual branch is subject to stochastic depth
        out = self.act(identity + self.drop_path(out))
        return out


# -----------------------------------------------------------
# Improved PetNetCyl3D with better regularization
# -----------------------------------------------------------
class PetNetCyl3D(nn.Module):
    """
    Residual 3-D CNN with circular padding on W and two 3-value regression heads.

    Pipeline: stem conv -> `num_layers` ResidualBlock3D stages (channels double
    per stage, stride (1, 2, 2)) -> global average pool -> shared MLP -> two
    heads whose outputs are concatenated to a (B, 6) tensor.

    Args:
        in_channels: channels of the input tensor.
        base_channels: channels after the stem; stage i outputs base * 2**(i+1).
        norm: normalization type forwarded to make_norm.
        groups_for_gn: group count hint for GroupNorm.
        dropout3d: Dropout3d probability at the end of the stem (0 disables).
        spatial_dropout: Dropout3d probability inside every residual block.
        drop_path_rate: stochastic-depth rate of the last stage; earlier stages
            get linearly smaller rates starting at 0.
        fc_dropout: dropout probability in the shared MLP.
        num_layers: number of residual stages.
    """

    def __init__(
            self,
            in_channels: int = 4,
            base_channels: int = 8,  # Reduced from 16 to 8
            norm: Literal["group", "batch", "instance"] = "group",
            groups_for_gn: int = 4,  # Reduced groups
            dropout3d: float = 0.15,  # Increased dropout
            spatial_dropout: float = 0.1,  # New spatial dropout
            drop_path_rate: float = 0.15,  # Increased drop path
            fc_dropout: float = 0.6,  # Increased FC dropout
            num_layers: int = 4,  # Reduced from 5 to 4 layers
    ):
        super().__init__()
        print(f"Loading PetNetCyl3D (regularized) - base_channels={base_channels}, layers={num_layers}")

        self.act = nn.GELU()

        # Simpler stem
        self.stem = nn.Sequential(
            Conv3dCircW(in_channels, base_channels, kernel_size=3, stride=(1, 1, 1), bias=False),
            make_norm(norm, base_channels, groups_for_gn),
            nn.GELU(),
            nn.Dropout3d(dropout3d) if dropout3d > 0 else nn.Identity(),
        )

        # Channel progression doubles per stage, e.g. 8 -> 16 -> 32 -> 64 -> 128
        chs = [base_channels * (2 ** i) for i in range(num_layers + 1)]
        strides = [(1, 2, 2)] * num_layers
        # Linearly increasing drop-path schedule from 0 to drop_path_rate.
        # max(1, ...) keeps a single-stage network from dividing by zero; its
        # only stage then gets rate 0.
        dpr_denominator = max(1, num_layers - 1)
        dprs = [drop_path_rate * i / dpr_denominator for i in range(num_layers)]

        # Create layers dynamically
        layers = []
        for i in range(num_layers):
            layers.append(ResidualBlock3D(
                chs[i], chs[i + 1],
                stride=strides[i],
                norm=norm,
                drop_path=dprs[i],
                spatial_dropout=spatial_dropout,
                groups_for_gn=groups_for_gn
            ))

        self.backbone = nn.Sequential(*layers)

        # Global average pool over (T,H,W)
        self.gap = nn.AdaptiveAvgPool3d((1, 1, 1))

        # Simplified and more regularized head
        hidden_dim = max(64, chs[-1] // 2)  # Adaptive hidden dimension

        self.fc_shared = nn.Sequential(
            nn.Linear(chs[-1], hidden_dim, bias=True),
            nn.GELU(),
            nn.Dropout(fc_dropout),
            nn.Linear(hidden_dim, hidden_dim // 2, bias=True),  # Additional layer for better representation
            nn.GELU(),
            nn.Dropout(fc_dropout * 0.5),  # Reduced dropout for final layer
        )

        # Separate heads for inner and outer endpoints
        head_input_dim = hidden_dim // 2
        self.head_inner = nn.Sequential(
            nn.Linear(head_input_dim, 16, bias=True),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(16, 3, bias=True)
        )

        self.head_outer = nn.Sequential(
            nn.Linear(head_input_dim, 16, bias=True),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(16, 3, bias=True)
        )

        self._init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, T, H, W) with C equal to the `in_channels` given at
           construction (default 4); W is the unrolled circumference (we
           circular-pad it).
        returns: (B, 6) = [cosφ1, sinφ1, z1, cosφ2, sinφ2, z2]; the first two
           values of each triple are normalised to unit length, the third is
           passed through unchanged.
        """
        x = self.stem(x)
        x = self.backbone(x)
        x = self.gap(x).flatten(1)
        x = self.fc_shared(x)

        e1 = self.head_inner(x)  # (B,3)
        e2 = self.head_outer(x)  # (B,3)

        # Encourage valid cosine/sine by soft-normalizing
        def normalize_cos_sin(v: torch.Tensor) -> torch.Tensor:
            cos_sin = v[..., :2]
            z = v[..., 2:3]
            # Squash first, then project onto the unit circle; the clamp avoids
            # a division by zero when both components are (almost) zero
            cos_sin = torch.tanh(cos_sin * 0.5)  # Reduced scaling for stability
            norm = torch.clamp(torch.linalg.norm(cos_sin, dim=-1, keepdim=True), min=1e-6)
            cos_sin = cos_sin / norm
            return torch.cat([cos_sin, z], dim=-1)

        e1 = normalize_cos_sin(e1)
        e2 = normalize_cos_sin(e2)
        out = torch.cat([e1, e2], dim=-1)  # (B,6)
        return out

    def _init_weights(self):
        """Kaiming-normal (fan_out) init for conv/linear weights; norms to identity."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu", mode="fan_out")
                if getattr(m, "bias", None) is not None and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d)):
                # Affine parameters may be absent (e.g. affine=False), hence the checks
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)


# -----------------------------------------------------------
# Alternative: Even more compact version for small datasets
# -----------------------------------------------------------
class PetNetCyl3DCompact(nn.Module):
    """Ultra-compact version for very small datasets.

    Three circular-W conv stages (channels base -> 2*base -> 4*base, the last
    two with stride (1, 2, 2)), global average pooling, and a two-layer MLP that
    outputs 6 values. Unlike PetNetCyl3D, the output is returned raw (no
    cos/sin normalisation).

    Args:
        in_channels: channels of the input tensor.
        base_channels: channels after the first conv; must be even because the
            first GroupNorm uses 2 groups.
        dropout_rate: dropout probability in the MLP head.
    """

    def __init__(
            self,
            in_channels: int = 4,
            base_channels: int = 6,  # Very small base
            dropout_rate: float = 0.7,  # Aggressive dropout
    ):
        super().__init__()
        print("Loading PetNetCyl3D Compact (minimal overfitting)")

        # Dropout3d probability grows with depth: 0.2 -> 0.3 -> 0.4
        self.backbone = nn.Sequential(
            Conv3dCircW(in_channels, base_channels, 3, (1, 1, 1)),
            nn.GroupNorm(2, base_channels),
            nn.GELU(),
            nn.Dropout3d(0.2),

            Conv3dCircW(base_channels, base_channels * 2, 3, (1, 2, 2)),
            nn.GroupNorm(2, base_channels * 2),
            nn.GELU(),
            nn.Dropout3d(0.3),

            Conv3dCircW(base_channels * 2, base_channels * 4, 3, (1, 2, 2)),
            nn.GroupNorm(4, base_channels * 4),
            nn.GELU(),
            nn.Dropout3d(0.4),
        )

        # Collapse (T, H, W) to a single value per channel
        self.gap = nn.AdaptiveAvgPool3d(1)

        self.head = nn.Sequential(
            nn.Linear(base_channels * 4, 32),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(32, 6)
        )

    def forward(self, x):
        # x: (B, in_channels, T, H, W) -> (B, 6)
        x = self.backbone(x)
        x = self.gap(x).flatten(1)
        return self.head(x)
    
    
# -----------------------------------------------------------
# Compact version without GAP (uses full feature map)
# -----------------------------------------------------------
class PetNetCyl3DFullFeatures(nn.Module):
    """Compact version that uses ALL output features (no GAP), TorchScript compatible.

    Three plain (zero-padded) Conv3d stages followed by a fully connected head
    on the flattened feature map. Because the head's input size depends on the
    spatial size, the model only accepts inputs whose (T, H, W) match
    `input_shape`.

    Args:
        in_channels: channels of the tensors passed to forward().
        base_channels: channels after the first conv; must be even because the
            first GroupNorm uses 2 groups.
        input_shape: (C, T, H, W) of one input sample. Only T, H and W are used
            to size the head; the channel count always comes from `in_channels`.
        dropout_rate: dropout probability in the head.
        out_features: number of regression outputs.
    """

    def __init__(
        self,
        in_channels: int = 4,
        base_channels: int = 6,
        input_shape=(2, 3, 207, 41),  # (C, T, H, W) so we can compute flatten_dim
        dropout_rate: float = 0.5,
        out_features: int = 6,
    ):
        super().__init__()
        print("Loading PetNetCyl3D (Full Features TorchScript Compatible)")

        # Backbone with stride-based downsampling in H,W
        self.backbone = nn.Sequential(
            nn.Conv3d(in_channels, base_channels, kernel_size=3, stride=(1, 1, 1), padding=1),
            nn.GroupNorm(2, base_channels),
            nn.GELU(),
            nn.Dropout3d(0.2),

            nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, stride=(1, 2, 2), padding=1),
            nn.GroupNorm(2, base_channels * 2),
            nn.GELU(),
            nn.Dropout3d(0.3),

            nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, stride=(1, 2, 2), padding=1),
            nn.GroupNorm(4, base_channels * 4),
            nn.GELU(),
            nn.Dropout3d(0.4),
        )

        # Precompute flatten_dim from input shape
        flatten_dim = self._compute_flatten_dim(in_channels, input_shape)
        print(f"[Init] Flattened feature size = {flatten_dim}")

        # Fully connected head
        self.head = nn.Sequential(
            nn.Linear(flatten_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, out_features)
        )

    def _compute_flatten_dim(self, in_channels, input_shape):
        """Compute feature map flatten size at init time (TorchScript safe).

        A zero tensor with `in_channels` channels and the (T, H, W) taken from
        `input_shape` is pushed through the backbone once. The channel entry of
        `input_shape` is ignored on purpose: the first conv is built for
        `in_channels`, so a dummy with any other channel count would make the
        constructor fail (as it did with the default arguments).
        """
        _, T, H, W = input_shape
        dummy = torch.zeros(1, in_channels, T, H, W)
        with torch.no_grad():
            out = self.backbone(dummy)
        return out.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, in_channels, T, H, W) -> (B, out_features)
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        return self.head(x)


# -----------------------------------------------------------
# Compact, full features, attention last layer
# -----------------------------------------------------------
class PetNetCyl3DAttentionFull(nn.Module):
    """Attention version that keeps ALL attended features (TorchScript-compatible).

    Three plain Conv3d stages produce a feature map whose T*H*W positions are
    treated as tokens. One global dot-product self-attention pass mixes the
    tokens, and the flattened result (tokens * channels) feeds a fully
    connected head. The head size depends on the spatial size, so inputs must
    match the (T, H, W) of `input_shape`.

    Args:
        in_channels: channels of the tensors passed to forward().
        base_channels: channels after the first conv; must be even because the
            first GroupNorm uses 2 groups.
        input_shape: (C, T, H, W) of one input sample. Only T, H and W are used;
            the channel count always comes from `in_channels`.
        dropout_rate: dropout probability in the head.
        out_features: number of regression outputs.
        attn_heads: only affects the score scaling, (feature_dim // attn_heads) ** -0.5;
            the attention itself is computed as a single head.
    """

    def __init__(
        self,
        in_channels: int = 4,
        base_channels: int = 6,
        input_shape=(2, 3, 207, 41),  # (C, T, H, W)
        dropout_rate: float = 0.5,
        out_features: int = 6,
        attn_heads: int = 1,
    ):
        super().__init__()
        print("Loading PetNetCyl3D (Attention Full Feature Map, TorchScript Compatible)")

        # Backbone with stride-based downsampling
        self.backbone = nn.Sequential(
            nn.Conv3d(in_channels, base_channels, kernel_size=3, stride=(1, 1, 1), padding=1),
            nn.GroupNorm(2, base_channels),
            nn.GELU(),
            nn.Dropout3d(0.2),

            nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, stride=(1, 2, 2), padding=1),
            nn.GroupNorm(2, base_channels * 2),
            nn.GELU(),
            nn.Dropout3d(0.3),

            nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, stride=(1, 2, 2), padding=1),
            nn.GroupNorm(4, base_channels * 4),
            nn.GELU(),
            nn.Dropout3d(0.4),
        )

        # Precompute shape
        self.flatten_dim, self.feature_dim, self.num_tokens = self._compute_flatten_stats(in_channels, input_shape)
        print(f"[Init] Flatten: {self.flatten_dim}, Tokens={self.num_tokens}, Feature_dim={self.feature_dim}")

        # Attention projections
        self.q_proj = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.k_proj = nn.Linear(self.feature_dim, self.feature_dim, bias=False)
        self.v_proj = nn.Linear(self.feature_dim, self.feature_dim, bias=False)

        self.attn_heads = attn_heads
        self.scale = (self.feature_dim // attn_heads) ** -0.5

        # Final head takes ALL tokens (N * C)
        self.head = nn.Sequential(
            nn.Linear(self.num_tokens * self.feature_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, out_features)
        )

    def _compute_flatten_stats(self, in_channels, input_shape):
        """Compute feature dims and token count (TorchScript safe).

        The dummy input uses `in_channels` channels (what the first conv
        expects) and only the (T, H, W) of `input_shape`; using the channel
        entry of `input_shape` made the constructor fail whenever it differed
        from `in_channels`, which is the case for the default arguments.

        Returns:
            (flatten_dim, feature_dim, num_tokens) of the backbone output.
        """
        _, T, H, W = input_shape
        dummy = torch.zeros(1, in_channels, T, H, W)
        with torch.no_grad():
            out = self.backbone(dummy)  # (1, C', T', H', W')
        _, C_out, T_out, H_out, W_out = out.shape
        num_tokens = T_out * H_out * W_out
        flatten_dim = num_tokens * C_out
        return flatten_dim, C_out, num_tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Backbone
        x = self.backbone(x)  # (B, C, T, H, W)
        B, C, T, H, W = x.shape
        N = T * H * W

        # Flatten tokens; reshape (not view) so a non-contiguous backbone output
        # (e.g. channels-last memory format) does not raise
        x = x.reshape(B, C, N).transpose(1, 2)  # (B, N, C)

        # Attention
        Q = self.q_proj(x)  # (B, N, C)
        K = self.k_proj(x)  # (B, N, C)
        V = self.v_proj(x)  # (B, N, C)

        # Every token attends to every other token: memory grows with N^2
        attn_scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale  # (B, N, N)
        attn_weights = torch.softmax(attn_scores, dim=-1)            # (B, N, N)
        attended = torch.bmm(attn_weights, V)                       # (B, N, C)

        # Keep ALL features -> flatten (B, N*C)
        flat = attended.reshape(B, N * C)

        # Fully connected head
        return self.head(flat)


# -----------------------------------------------------------
# Depthwise Separable Conv3D implementation
# -----------------------------------------------------------
class DepthwiseSeparableConv3d(nn.Module):
    """
    3-D depthwise-separable convolution.

    A per-channel ("depthwise") k x k x k convolution does the spatial
    filtering, then a 1 x 1 x 1 ("pointwise") convolution mixes channels.
    Weight count drops from in_ch * out_ch * k^3 for a dense conv to
    in_ch * k^3 + in_ch * out_ch.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the pointwise stage.
        kernel_size, stride, padding: Forwarded to the depthwise stage only.
        bias: Whether the pointwise stage has a bias. The depthwise stage
            never has one.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False):
        super().__init__()

        # Spatial stage: groups == in_channels gives each channel its own filter.
        spatial_cfg = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        self.depthwise = nn.Conv3d(in_channels, in_channels, **spatial_cfg)

        # Mixing stage: 1x1x1 kernel, unit stride, no padding.
        self.pointwise = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage."""
        return self.pointwise(self.depthwise(x))


# -----------------------------------------------------------
# PetNetCyl3D with Depthwise Separable Convolutions
# -----------------------------------------------------------
class PetNetCyl3DDepthwise(nn.Module):
    """
    Three-stage Conv3d regressor whose 2nd and 3rd stages are depthwise separable.

    Pipeline: conv1 (dense) -> conv2 (depthwise separable, stride (1, 2, 2)) ->
    conv3 (depthwise separable, stride (1, 2, 2)), each followed by GroupNorm,
    GELU and Dropout3d; the result is flattened and fed to a two-layer FC head.
    The input is passed to ``conv1`` as-is (no channels are added in ``forward``).

    Args:
        in_channels: Channels of the tensor given to ``forward`` (and to ``conv1``).
        base_channels: Width of stage 1; stages 2 and 3 use 2x and 4x this value.
            Must be divisible by 2 (GroupNorm with 2 groups is applied to it).
        input_shape: ``(C, T, H, W)`` of one sample. Only ``T, H, W`` are used,
            to size the first FC layer; the channel count comes from
            ``in_channels`` so the sizing probe always matches ``conv1``.
        dropout_rate: Dropout probability inside the FC head.
        out_features: Number of regression outputs.
    """

    def __init__(
        self,
        in_channels: int = 4,
        base_channels: int = 6,
        input_shape=(2, 3, 207, 41),  # (C, T, H, W); T, H, W size the FC head
        dropout_rate: float = 0.5,
        out_features: int = 6,
    ):
        super().__init__()
        print("Loading PetNetCyl3D (Depthwise Separable Convolutions)")

        # Stage 1: regular (dense) convolution
        self.conv1 = nn.Conv3d(in_channels, base_channels, kernel_size=3, stride=(1, 1, 1), padding=1)
        self.norm1 = nn.GroupNorm(2, base_channels)
        self.act1 = nn.GELU()
        self.drop1 = nn.Dropout3d(0.2)

        # Stages 2 and 3: depthwise separable, halving H and W each time
        self.conv2 = DepthwiseSeparableConv3d(
            base_channels, base_channels * 2,
            kernel_size=3, stride=(1, 2, 2), padding=1
        )
        self.norm2 = nn.GroupNorm(2, base_channels * 2)
        self.act2 = nn.GELU()
        self.drop2 = nn.Dropout3d(0.3)

        self.conv3 = DepthwiseSeparableConv3d(
            base_channels * 2, base_channels * 4,
            kernel_size=3, stride=(1, 2, 2), padding=1
        )
        self.norm3 = nn.GroupNorm(4, base_channels * 4)
        self.act3 = nn.GELU()
        self.drop3 = nn.Dropout3d(0.4)

        # Size the FC head from a probe pass through the conv stages
        flatten_dim = self._compute_flatten_dim(in_channels, input_shape)
        print(f"[Init] Flattened feature size = {flatten_dim}")

        # Fully connected head
        self.head = nn.Sequential(
            nn.Linear(flatten_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, out_features)
        )

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three conv/norm/activation/dropout stages; returns a 5-D map."""
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.drop1(x)

        x = self.act2(self.norm2(self.conv2(x)))
        x = self.drop2(x)

        x = self.act3(self.norm3(self.conv3(x)))
        x = self.drop3(x)
        return x

    def _compute_flatten_dim(self, in_channels, input_shape):
        """Return the flattened size of the conv output for one sample.

        The probe is built with ``in_channels`` channels (what ``conv1`` was
        constructed for) rather than ``input_shape[0]``; previously a mismatch
        between the two -- which the class defaults have -- made construction
        fail inside ``conv1``.
        """
        _, T, H, W = input_shape
        dummy = torch.zeros(1, in_channels, T, H, W)
        with torch.no_grad():
            x = self._forward_features(dummy)
        return x.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, in_channels, T, H, W)`` to ``(B, out_features)``."""
        x = self._forward_features(x)

        # Flatten everything but the batch axis and regress
        x = torch.flatten(x, 1)
        return self.head(x)


# -----------------------------------------------------------
# Windowed Attention Block (TorchScript Compatible)
# -----------------------------------------------------------
class WindowedAttention3D(nn.Module):
    """
    Multi-head self-attention restricted to 1-D windows of a 3-D feature map.

    The ``(B, C, T, H, W)`` input is cut into non-overlapping runs of
    ``window_size`` consecutive positions along either the width or the height
    axis; attention is computed independently inside every run and the result
    is stitched back to ``(B, C, T, H, W)``.

    If the partitioned axis is not a multiple of ``window_size`` it is
    zero-padded at the end and the padding is cropped after attention.
    NOTE: padded positions are not masked, so inside the last window they take
    part in the softmax as (all-zero) keys/values.

    Args:
        dim: Channel count ``C`` of the feature map; must be divisible by
            ``num_heads``.
        window_size: Number of positions per window (positive).
        num_heads: Number of attention heads.
        window_dim: ``'width'`` or ``'height'`` -- the axis to partition.

    Raises:
        AssertionError: If ``dim`` is not divisible by ``num_heads``.
        ValueError: If ``window_size`` is not positive or ``window_dim`` is not
            one of the two supported values (previously any other string was
            silently treated as ``'height'``).
    """

    def __init__(self, dim: int, window_size: int, num_heads: int = 2, window_dim: str = 'width'):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}"
        if window_size < 1:
            raise ValueError(f"window_size must be a positive integer, got {window_size}")
        if window_dim not in ('width', 'height'):
            raise ValueError(f"window_dim must be 'width' or 'height', got {window_dim!r}")

        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.window_dim = window_dim  # 'width' or 'height'

        # Linear projections
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, T, H, W) feature map; returns a tensor of the same shape.
        """
        B, C, T, H, W = x.shape

        # Tokens-last layout (B, T*H*W, C). flatten() copies when the spatial
        # axes cannot be merged as a view, so non-contiguous inputs are accepted.
        x = x.flatten(2).transpose(1, 2)

        if self.window_dim == 'width':
            x = self._partition_width(x, B, T, H, W)
        else:  # height
            x = self._partition_height(x, B, T, H, W)

        # Attention inside each window
        x = self._windowed_attention(x)

        if self.window_dim == 'width':
            x = self._reverse_partition_width(x, B, T, H, W)
        else:  # height
            x = self._reverse_partition_height(x, B, T, H, W)

        # Back to channels-first (B, C, T, H, W)
        x = x.transpose(1, 2).contiguous().view(B, C, T, H, W)
        return x

    def _pad_amount(self, size: int) -> int:
        """Trailing padding needed to make ``size`` a multiple of ``window_size``."""
        return (self.window_size - size % self.window_size) % self.window_size

    def _partition_width(self, x: torch.Tensor, B: int, T: int, H: int, W: int) -> torch.Tensor:
        """(B, T*H*W, C) -> (B*T*H*num_windows_w, window_size, C), windows along W."""
        x = x.reshape(B, T, H, W, self.dim)

        # F.pad pairs run from the last axis backwards: (C: 0, 0), (W: 0, pad_w)
        pad_w = self._pad_amount(W)
        if pad_w > 0:
            x = F.pad(x, (0, 0, 0, pad_w))
            W = W + pad_w

        # W is the second-to-last axis, so consecutive runs of window_size
        # positions are already adjacent in row-major order.
        num_windows_w = W // self.window_size
        x = x.reshape(B * T * H * num_windows_w, self.window_size, self.dim)
        return x

    def _partition_height(self, x: torch.Tensor, B: int, T: int, H: int, W: int) -> torch.Tensor:
        """(B, T*H*W, C) -> (B*T*W*num_windows_h, window_size, C), windows along H."""
        x = x.reshape(B, T, H, W, self.dim)

        # F.pad pairs: (C: 0, 0), (W: 0, 0), (H: 0, pad_h)
        pad_h = self._pad_amount(H)
        if pad_h > 0:
            x = F.pad(x, (0, 0, 0, 0, 0, pad_h))
            H = H + pad_h

        # Split H into (num_windows_h, window_size), then move W in front of the
        # window axes so each window is one column segment.
        num_windows_h = H // self.window_size
        x = x.reshape(B, T, num_windows_h, self.window_size, W, self.dim)
        x = x.permute(0, 1, 4, 2, 3, 5)  # (B, T, W, num_windows_h, window_size, C)
        x = x.reshape(B * T * W * num_windows_h, self.window_size, self.dim)
        return x

    def _reverse_partition_width(self, x: torch.Tensor, B: int, T: int, H: int, W: int) -> torch.Tensor:
        """Inverse of ``_partition_width``; ``W`` is the original (unpadded) width."""
        pad_w = self._pad_amount(W)
        W_padded = W + pad_w

        # Merge windows back into the padded width axis: (B, T, H, W_padded, C)
        x = x.reshape(B, T, H, W_padded, self.dim)

        # Drop the positions that were added as padding
        if pad_w > 0:
            x = x[:, :, :, :W, :]

        # Flatten spatial dims: (B, T*H*W, C)
        x = x.reshape(B, T * H * W, self.dim)
        return x

    def _reverse_partition_height(self, x: torch.Tensor, B: int, T: int, H: int, W: int) -> torch.Tensor:
        """Inverse of ``_partition_height``; ``H`` is the original (unpadded) height."""
        pad_h = self._pad_amount(H)
        H_padded = H + pad_h
        num_windows_h = H_padded // self.window_size

        # Undo the permute done while partitioning, then merge the window axes
        x = x.reshape(B, T, W, num_windows_h, self.window_size, self.dim)
        x = x.permute(0, 1, 3, 4, 2, 5)  # (B, T, num_windows_h, window_size, W, C)
        x = x.reshape(B, T, H_padded, W, self.dim)

        # Drop the positions that were added as padding
        if pad_h > 0:
            x = x[:, :, :H, :, :]

        # Flatten spatial dims: (B, T*H*W, C)
        x = x.reshape(B, T * H * W, self.dim)
        return x

    def _windowed_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Scaled dot-product multi-head attention within each window.

        x: (num_windows, window_size, C); the output has the same shape.
        """
        num_windows, window_size, _ = x.shape

        # Joint projection, then split into q, k, v and heads
        qkv = self.qkv(x)  # (num_windows, window_size, 3*C)
        qkv = qkv.view(num_windows, window_size, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, num_windows, num_heads, window_size, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # (num_windows, num_heads, window_size, window_size)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of values, heads merged back into the channel axis
        x = torch.matmul(attn, v)  # (num_windows, num_heads, window_size, head_dim)
        x = x.transpose(1, 2).contiguous().view(num_windows, window_size, self.dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# -----------------------------------------------------------
# PetNetCyl3D with Windowed Attention and Multiple FC Heads
# -----------------------------------------------------------
class PetNetCyl3DWindowedAttention(nn.Module):
    """
    Conv3d regressor with two windowed-attention blocks and one FC head per output.

    Pipeline: ``add_positional_features`` -> conv1 -> width-windowed attention ->
    height-windowed attention -> conv2 -> conv3 -> flatten -> two shared FC
    layers -> ``out_features`` independent heads, each emitting one scalar.
    ``forward`` concatenates the head outputs to ``(B, out_features)``.

    Args:
        in_channels: Channel count ``conv1`` is built for, i.e. the channel
            count of the tensor returned by ``add_positional_features``.
            NOTE(review): the defaults (4 here, 2 in ``input_shape``) imply
            that helper appends two channels -- confirm against its definition.
        base_channels: Width after conv1 (and the attention embedding size);
            conv2/conv3 use 2x and 4x this value. Must be divisible by 2 and by
            ``attn_heads``.
        input_shape: ``(C, T, H, W)`` of one raw sample, before positional
            features are added. Used only to size the first shared FC layer.
        dropout_rate: Dropout probability after each shared FC layer.
        out_features: Number of scalar outputs (= number of heads).
        window_size: Window length for both attention blocks.
        attn_heads: Number of attention heads in both attention blocks.
        normalize_positions: Forwarded to ``add_positional_features`` as
            ``normalize``.
    """

    def __init__(
        self,
        in_channels: int = 4,
        base_channels: int = 6,
        input_shape=(2, 3, 207, 41),  # (C, T, H, W) - original channels before positional
        dropout_rate: float = 0.5,
        out_features: int = 6,
        window_size: int = 8,  # Window size for attention
        attn_heads: int = 2,
        normalize_positions: bool = True,
    ):
        super().__init__()
        print("Loading PetNetCyl3D (Windowed Attention, Multiple FC Heads, TorchScript Compatible)")

        self.normalize_positions = normalize_positions
        self.out_features = out_features

        # First convolution; its input is the output of add_positional_features
        self.conv1 = nn.Conv3d(in_channels, base_channels, kernel_size=3, stride=(1, 1, 1), padding=1)
        self.norm1 = nn.GroupNorm(2, base_channels)
        self.act1 = nn.GELU()
        self.drop1 = nn.Dropout3d(0.2)

        # Windowed attention blocks after first conv
        self.width_attention = WindowedAttention3D(
            dim=base_channels,
            window_size=window_size,
            num_heads=attn_heads,
            window_dim='width'
        )

        self.height_attention = WindowedAttention3D(
            dim=base_channels,
            window_size=window_size,
            num_heads=attn_heads,
            window_dim='height'
        )

        # Remaining convolution layers
        self.conv2 = nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, stride=(1, 2, 2), padding=1)
        self.norm2 = nn.GroupNorm(2, base_channels * 2)
        self.act2 = nn.GELU()
        self.drop2 = nn.Dropout3d(0.3)

        self.conv3 = nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, stride=(1, 2, 2), padding=1)
        self.norm3 = nn.GroupNorm(4, base_channels * 4)
        self.act3 = nn.GELU()
        self.drop3 = nn.Dropout3d(0.4)

        # Size the first shared FC layer from a probe pass (positional features included)
        flatten_dim = self._compute_flatten_dim(in_channels, input_shape)
        print(f"[Init] Flattened feature size = {flatten_dim}")

        # Shared feature extraction layers
        self.shared_fc1 = nn.Linear(flatten_dim, 256)
        self.shared_dropout1 = nn.Dropout(dropout_rate)
        self.shared_fc2 = nn.Linear(256, 128)
        self.shared_dropout2 = nn.Dropout(dropout_rate)

        # Individual FC heads for each output channel
        self.fc_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 64),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(64, 32),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(32, 1)
            ) for _ in range(out_features)
        ])

        # Alternative: Simpler individual heads (uncomment to use)
        # self.fc_heads = nn.ModuleList([
        #     nn.Sequential(
        #         nn.Linear(128, 64),
        #         nn.GELU(),
        #         nn.Dropout(0.3),
        #         nn.Linear(64, 1)
        #     ) for _ in range(out_features)
        # ])

    def _conv_features(self, x: torch.Tensor) -> torch.Tensor:
        """Raw input -> 5-D feature map after conv3 (positional features, convs, attention)."""
        x = add_positional_features(x, normalize=self.normalize_positions)

        # First convolution
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.drop1(x)

        # Windowed attention: width-windowed, then height-windowed
        x = self.width_attention(x)
        x = self.height_attention(x)

        # Remaining convolutions
        x = self.act2(self.norm2(self.conv2(x)))
        x = self.drop2(x)

        x = self.act3(self.norm3(self.conv3(x)))
        x = self.drop3(x)
        return x

    def _shared_representation(self, x: torch.Tensor) -> torch.Tensor:
        """Raw input -> (B, 128) representation that every head consumes."""
        x = torch.flatten(self._conv_features(x), 1)

        # self.act1 (a parameter-free GELU) is reused as the FC activation
        x = self.shared_fc1(x)
        x = self.act1(x)
        x = self.shared_dropout1(x)

        x = self.shared_fc2(x)
        x = self.act1(x)
        x = self.shared_dropout2(x)
        return x

    def _compute_flatten_dim(self, in_channels, input_shape):
        """Return the flattened size of the conv3 output for one raw sample.

        ``in_channels`` is accepted for signature compatibility and not read;
        the probe is built from ``input_shape`` and goes through the same
        positional-feature step as ``forward``.
        """
        C, T, H, W = input_shape
        dummy = torch.zeros(1, C, T, H, W)

        with torch.no_grad():
            x = self._conv_features(dummy)

        return x.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a raw ``(B, C, T, H, W)`` batch to ``(B, out_features)``."""
        x = self._shared_representation(x)

        # Apply individual FC heads, each yielding (batch_size, 1)
        outputs = []
        for head in self.fc_heads:
            outputs.append(head(x))

        # Concatenate all head outputs -> (batch_size, out_features)
        return torch.cat(outputs, dim=1)

    def get_head_outputs(self, x: torch.Tensor):
        """
        Get outputs from each head separately (useful for analysis/debugging).

        Runs the same pipeline as ``forward`` but skips the final concatenation.
        Dropout follows the module's current train/eval mode.

        Returns:
            List of tensors, each of shape (batch_size, 1)
        """
        x = self._shared_representation(x)

        head_outputs = []
        for head in self.fc_heads:
            head_outputs.append(head(x))

        return head_outputs

    def get_feature_maps(self, x: torch.Tensor, return_attention_maps: bool = False):
        """
        Extract intermediate feature maps for visualization/analysis.

        Args:
            x: Input tensor
            return_attention_maps: Currently ignored -- no attention weights
                are collected or returned regardless of its value.

        Returns:
            Dictionary of cloned tensors keyed by stage: 'input_with_pos',
            'after_conv1', 'after_width_attention', 'after_height_attention',
            'after_conv2', 'after_conv3', 'flattened', 'after_shared_fc1',
            'after_shared_fc2'. The conv snapshots are taken before the
            stage's dropout; the shared-FC snapshots are taken after it.
        """
        feature_maps = {}

        x = add_positional_features(x, normalize=self.normalize_positions)
        feature_maps['input_with_pos'] = x.clone()

        # After first conv
        x = self.act1(self.norm1(self.conv1(x)))
        feature_maps['after_conv1'] = x.clone()
        x = self.drop1(x)

        # After attention blocks
        x = self.width_attention(x)
        feature_maps['after_width_attention'] = x.clone()
        x = self.height_attention(x)
        feature_maps['after_height_attention'] = x.clone()

        # After remaining convolutions
        x = self.act2(self.norm2(self.conv2(x)))
        feature_maps['after_conv2'] = x.clone()
        x = self.drop2(x)

        x = self.act3(self.norm3(self.conv3(x)))
        feature_maps['after_conv3'] = x.clone()
        x = self.drop3(x)

        # Flattened features
        x = torch.flatten(x, 1)
        feature_maps['flattened'] = x.clone()

        # After shared FC layers
        x = self.shared_fc1(x)
        x = self.act1(x)
        x = self.shared_dropout1(x)
        feature_maps['after_shared_fc1'] = x.clone()

        x = self.shared_fc2(x)
        x = self.act1(x)
        x = self.shared_dropout2(x)
        feature_maps['after_shared_fc2'] = x.clone()

        return feature_maps


# Example usage and comparison
if __name__ == "__main__":
    # Smoke test for PetNetCyl3DWindowedAttention: build the model, report
    # parameter counts, run one no-grad forward pass and print a summary.
    # Everything assigned here is a module-level name (model, device, CLASSES,
    # INPUT_SHAPE, ...) and stays visible to code that runs afterwards.
    import torch
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Test with 8 output classes (cylindrical coordinates)
    CLASSES = 8
    INPUT_SHAPE = (2, 3, 207, 41)  # Original input shape before positional features
    
    model = PetNetCyl3DWindowedAttention(
        in_channels=4,  # Will be 4 after adding positional features
        base_channels=6,
        input_shape=INPUT_SHAPE,
        dropout_rate=0.5,
        out_features=CLASSES,
        window_size=8,
        attn_heads=2,
        normalize_positions=True
    ).to(device)
    
    # Print parameter count
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")
    
    # Count parameters in each component: anything whose parameter name does
    # not contain 'fc_heads' (convs, norms, attention, shared FC) is "shared".
    shared_params = sum(p.numel() for name, p in model.named_parameters() if 'fc_heads' not in name)
    head_params = sum(p.numel() for name, p in model.named_parameters() if 'fc_heads' in name)
    print(f"Shared parameters: {shared_params:,}")
    print(f"FC heads parameters: {head_params:,}")
    print(f"Parameters per head: {head_params // CLASSES:,}")
    
    # Test forward pass
    batch_size = 2
    C, T, H, W = INPUT_SHAPE
    dummy_input = torch.randn(batch_size, C, T, H, W).to(device)
    
    print("\nTesting forward pass...")
    # The model is never switched to eval() here, so dropout is active and the
    # two passes below draw different dropout masks (their values can differ).
    with torch.no_grad():
        output = model(dummy_input)
        print(f"Final output shape: {output.shape}")
        print(f"Output sample:\n{output[0]}")  # Show first sample
        
        # Test individual head outputs
        head_outputs = model.get_head_outputs(dummy_input)
        print(f"\nIndividual head outputs:")
        for i, head_out in enumerate(head_outputs):
            print(f"Head {i+1}: {head_out.shape} -> values: {head_out[0].item():.4f}")
    
    # Printed unconditionally: reaching this line only shows that the calls
    # above did not raise; no output value is checked.
    print("\n✅ Multi-head PetNetCyl3D working correctly!")
    
    # Show the architecture of one FC head
    print(f"\nFC Head Architecture:")
    for i, layer in enumerate(model.fc_heads[0]):
        print(f"  Layer {i+1}: {layer}")
    
    # Architecture summary. The channel counts and layer sizes below are
    # literals, not read from `model`; update them by hand if the constructor
    # arguments above change.
    print(f"\n📊 ARCHITECTURE SUMMARY:")
    print(f"Input: {INPUT_SHAPE} -> with positional: {(4, *INPUT_SHAPE[1:])}")
    print(f"Conv1: {4} -> {6} channels")
    print(f"Windowed attention: width-windowed + height-windowed")
    print(f"Conv2: {6} -> {12} channels, stride=(1,2,2)")
    print(f"Conv3: {12} -> {24} channels, stride=(1,2,2)")
    print(f"Shared FC: flattened -> 256 -> 128")
    print(f"Individual heads: {CLASSES} heads, each 128->64->32->1")
    print(f"Output: {CLASSES} values (specialized for cylindrical coordinates)")
    
    # Benefits summary
    print(f"\n🎯 MULTI-HEAD BENEFITS:")
    print(f"- Specialized learning for each coordinate component")
    print(f"- Better handling of heterogeneous outputs (r, sin, cos, z)")
    print(f"- Independent optimization and regularization")
    print(f"- Improved interpretability and debugging")
    print(f"- Can analyze individual coordinate predictions")

# -----------------------------------------------------------
# Test
# -----------------------------------------------------------
if __name__ == "__main__":
    # Shape and parameter-count comparison of the model variants on one random
    # CPU batch. PetNetCyl3D, PetNetCyl3DCompact, PetNetCyl3DFullFeatures and
    # PetNetCyl3DAttentionFull are expected to be defined earlier in the file.
    # Rebinds the module-level names B, C, T, H, W, model and x.
    B, C, T, H, W = 4, 2, 3, 207, 41

    model = PetNetCyl3D(in_channels=C, base_channels=8, num_layers=3)
    x = torch.randn(B, C, T, H, W)
    y_pred = model(x)
    print(f"Regular model output shape: {y_pred.shape}")

    compact_model = PetNetCyl3DCompact(in_channels=C)
    y_pred_compact = compact_model(x)
    print(f"Compact model output shape: {y_pred_compact.shape}")
    
    full_model = PetNetCyl3DFullFeatures(in_channels=C, base_channels=6, out_features=6)
    y_pred_full = full_model(x)
    print(f"Full model output shape: {y_pred_full.shape}")

    full_attn_model = PetNetCyl3DAttentionFull(in_channels=C, base_channels=6, out_features=6)
    y_pred_full_attm = full_attn_model(x)
    print(f"Full attention model output shape: {y_pred_full_attm.shape}")

    # Standard depthwise model
    depthwise_model = PetNetCyl3DDepthwise(in_channels=C, base_channels=6, out_features=6)
    y_pred_dw = depthwise_model(x)
    print(f"Depthwise model output shape: {y_pred_dw.shape}")

    # Disabled: windowed-attention variant.
    # NOTE(review): PetNetCyl3DWindowedAttention feeds conv1 the output of
    # add_positional_features, and the other demo builds it with in_channels=4
    # for a 2-channel input -- in_channels=C below probably needs to change
    # before this is re-enabled; confirm against add_positional_features.
    # windowed_model = PetNetCyl3DWindowedAttention(
    #     in_channels=C, 
    #     base_channels=6, 
    #     out_features=6,
    #     window_size=8,
    #     attn_heads=2
    # )
    # y_pred_window = windowed_model(x)
    # print(f"Windowed model output shape: {y_pred_window.shape}")



    # Count parameters
    def count_params(model):
        """Return the number of trainable (requires_grad) scalar parameters in `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Regular model parameters: {count_params(model):,}")
    print(f"Compact model parameters: {count_params(compact_model):,}")
    print(f"Full model parameters: {count_params(full_model):,}")
    print(f"Full attention model parameters: {count_params(full_attn_model):,}")
    print(f"DW model parameters: {count_params(depthwise_model):,}")
    # print(f"Windowed model parameters: {count_params(windowed_model):,}")



if __name__ == "__main__":
    # Builds PetNetCyl3DDepthwise, runs one smoke forward pass, then a few
    # optimisation steps on a single random batch while timing the forward pass.
    # All names assigned here (model, device, CLASSES, B, C, T, H, W, ...) are
    # module-level and remain visible to code that runs afterwards.
    import time

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    B = 1
    # B = 128

    C = 2  # Input channels (this model feeds the input to conv1 unchanged)
    T = 3

    H = 207
    # H = 496

    W = 41
    # W = 84

    CLASSES = 8

    # Model instantiation. input_shape is passed explicitly so the FC head is
    # sized for the T/H/W chosen above rather than for the class defaults.
    model = PetNetCyl3DDepthwise(
        in_channels=C,
        base_channels=6,
        input_shape=(C, T, H, W),
        out_features=CLASSES,
    ).to(device)

    # Print parameter count
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")

    # Random batch and target; PetNetCyl3DDepthwise.forward adds no positional
    # features, so the input has exactly C channels.
    dummy_input = torch.randn(B, C, T, H, W).to(device)
    dummy_target = torch.randn(B, CLASSES).to(device)

    print("Testing forward pass...")
    with torch.no_grad():
        model(dummy_input)

    # Dummy training loop to observe loss reduction
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.MSELoss()

    def _wait_for_device():
        """Block until queued CUDA work has finished (no-op on CPU).

        CUDA kernels are launched asynchronously; without this the wall-clock
        readings below would cover only the launch, not the computation.
        """
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    total_time = 0.0
    epochs = 5

    for epoch in range(epochs):
        optimizer.zero_grad()

        _wait_for_device()
        start_time = time.perf_counter()
        output = model(dummy_input)
        _wait_for_device()
        end_time = time.perf_counter()
        forward_time = end_time - start_time
        total_time += forward_time

        loss = criterion(output, dummy_target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Forward Time: {forward_time:.6f}s")

    # Average over all iterations, including the first one
    average_time = total_time / epochs
    print(f"Average Forward Pass Time: {average_time:.6f}s")

Using device: cuda
Loading PetNetCyl3D (Windowed Attention, Multiple FC Heads, TorchScript Compatible)
[Init] Flattened feature size = 41184
Total parameters: 10,670,002
Shared parameters: 10,587,050
FC heads parameters: 82,952
Parameters per head: 10,369

Testing forward pass...
Final output shape: torch.Size([2, 8])
Output sample:
tensor([-0.1629,  0.1664,  0.0263, -0.0554,  0.0808, -0.1654,  0.1036, -0.1459],
       device='cuda:0')

Individual head outputs:
Head 1: torch.Size([2, 1]) -> values: -0.1464
Head 2: torch.Size([2, 1]) -> values: 0.1781
Head 3: torch.Size([2, 1]) -> values: 0.0111
Head 4: torch.Size([2, 1]) -> values: -0.0649
Head 5: torch.Size([2, 1]) -> values: 0.1128
Head 6: torch.Size([2, 1]) -> values: -0.1472
Head 7: torch.Size([2, 1]) -> values: 0.1247
Head 8: torch.Size([2, 1]) -> values: -0.1279

✅ Multi-head PetNetCyl3D working correctly!

FC Head Architecture:
  Layer 1: Linear(in_features=128, out_features=64, bias=True)
  Layer 2: GELU(approximate='none')
  L

In [71]:
# unnormalised cartesian
if CLASSES == 6:
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import matplotlib.pyplot as plt

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BATCH_SIZE = 1
    NUM_EPOCHS = 40
    # Scanner dimensions
    R_INNER_MM = 235.422
    R_OUTER_MM = 278.296
    Z_HALF_MM = 148.0

    print("🔍 UNNORMALIZED CARTESIAN COORDINATE PREDICTION")
    print("=" * 60)
    print("Testing effect of removing output normalization...")
    print("Model outputs raw values (no tanh/sigmoid activation)")
    print("Direct learning of physical coordinates in mm")
    print()

    train_dataset = TensorDataset(train_data, train_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    def unnormalized_cartesian_loss(pred, target):
        """
        Direct Cartesian loss - no scaling needed.
        Model predicts directly in mm coordinates.
        
        pred: (batch_size, 6) -> [x1, y1, z1, x2, y2, z2] in mm (raw output)
        target: (batch_size, 6) -> [x1, y1, z1, x2, y2, z2] in mm
        """
        # No scaling - direct coordinate prediction
        # Permutation invariant loss (same logic as before)
        
        # Calculate both possible assignments
        loss_A = torch.norm(pred[:, 0:3] - target[:, 0:3], dim=1) + \
                 torch.norm(pred[:, 3:6] - target[:, 3:6], dim=1)
        
        loss_B = torch.norm(pred[:, 0:3] - target[:, 3:6], dim=1) + \
                 torch.norm(pred[:, 3:6] - target[:, 0:3], dim=1)
        
        # Take minimum loss (permutation invariant)
        min_loss = torch.minimum(loss_A, loss_B)
        
        return torch.mean(min_loss)

    # Alternative: Add coordinate constraints as regularization
    def constrained_cartesian_loss(pred, target, constraint_weight=0.1):
        """
        Add soft constraints to keep predictions within scanner geometry
        """
        base_loss = unnormalized_cartesian_loss(pred, target)
        
        # Constraint losses
        x_coords = pred[:, [0, 3]]  # X coordinates
        y_coords = pred[:, [1, 4]]  # Y coordinates  
        z_coords = pred[:, [2, 5]]  # Z coordinates
        
        # Radial constraint: sqrt(x^2 + y^2) should be in [R_IN, R_OUT]
        radii = torch.sqrt(x_coords**2 + y_coords**2)
        radius_violation = torch.relu(radii - R_OUTER_MM) + torch.relu(R_INNER_MM - radii)
        radius_penalty = torch.mean(radius_violation)
        
        # Z constraint: |z| should be <= Z_HALF
        z_violation = torch.relu(torch.abs(z_coords) - Z_HALF_MM)
        z_penalty = torch.mean(z_violation)
        
        total_loss = base_loss + constraint_weight * (radius_penalty + z_penalty)
        
        return total_loss, base_loss, radius_penalty, z_penalty

    # Choose loss function
    # `use_constraints` selects the training objective for this run and also
    # switches the banner, the per-epoch log line and the figure layout below.
    use_constraints = False  # False -> unnormalized_cartesian_loss, True -> constrained_cartesian_loss
    
    if use_constraints:
        print("🎯 TRAINING SETUP:")
        print("  Output: Raw coordinates in mm (no normalization)")
        print("  Loss: Permutation-invariant + geometry constraints")
        print("  Constraints: Radius ∈ [235.4, 278.3]mm, |Z| ≤ 148mm")
    else:
        print("🎯 TRAINING SETUP:")
        print("  Output: Raw coordinates in mm (no normalization)")  
        print("  Loss: Pure permutation-invariant Euclidean distance")
        print("  No geometric constraints")
    print()

    # Rebinds the notebook-level `model` to its float32 copy on `device`.
    model = model.float().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Lists to store losses and metrics (one entry appended per epoch)
    train_losses = []
    val_losses = []
    train_distances = []
    val_distances = []
    
    # Additional tracking for constrained loss
    # These two lists only exist when use_constraints is True; every later use
    # is guarded by the same flag.
    if use_constraints:
        train_radius_penalties = []
        train_z_penalties = []

    for epoch in range(NUM_EPOCHS):
        model.train()
        # Running sums of per-batch scalars; divided by the batch count below.
        running_loss = 0.0
        running_distance = 0.0
        running_radius_penalty = 0.0
        running_z_penalty = 0.0
        
        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            
            if use_constraints:
                # `loss` is what gets optimised; `base_loss` is tracked
                # separately as the distance metric, without the penalties.
                loss, base_loss, radius_penalty, z_penalty = constrained_cartesian_loss(outputs, labels)
                running_radius_penalty += radius_penalty.item()
                running_z_penalty += z_penalty.item()
                running_distance += base_loss.item()
            else:
                # Without constraints the loss itself is logged as the distance,
                # so "Loss" and "3D Distance" show the same number.
                loss = unnormalized_cartesian_loss(outputs, labels)
                running_distance += loss.item()
            
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Epoch averages are means over batches (sum of batch values / number of batches).
        avg_train_loss = running_loss / len(train_loader)
        avg_train_distance = running_distance / len(train_loader)
        train_losses.append(avg_train_loss)
        train_distances.append(avg_train_distance)
        
        if use_constraints:
            avg_radius_penalty = running_radius_penalty / len(train_loader)
            avg_z_penalty = running_z_penalty / len(train_loader)
            train_radius_penalties.append(avg_radius_penalty)
            train_z_penalties.append(avg_z_penalty)
        
        # Validation
        # Same loss selection as training, run over test_loader with gradient
        # tracking disabled and no optimizer step.
        model.eval()
        test_loss = 0.0
        test_distance = 0.0
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.float().to(device)
                labels = labels.float().to(device)
                outputs = model(inputs)
                
                if use_constraints:
                    # Penalty terms are not tracked for validation.
                    loss, base_loss, _, _ = constrained_cartesian_loss(outputs, labels)
                    test_distance += base_loss.item()
                else:
                    loss = unnormalized_cartesian_loss(outputs, labels)
                    test_distance += loss.item()
                
                test_loss += loss.item()

        avg_test_loss = test_loss / len(test_loader)
        avg_test_distance = test_distance / len(test_loader)
        val_losses.append(avg_test_loss)
        val_distances.append(avg_test_distance)
        
        # Print results
        print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}]")
        if use_constraints:
            print(f"  TRAINING   - Total Loss: {avg_train_loss:.4f} | 3D Distance: {avg_train_distance:.1f}mm | Radius Penalty: {avg_radius_penalty:.3f} | Z Penalty: {avg_z_penalty:.3f}")
            print(f"  VALIDATION - Total Loss: {avg_test_loss:.4f} | 3D Distance: {avg_test_distance:.1f}mm")
        else:
            print(f"  TRAINING   - Loss: {avg_train_loss:.4f} | 3D Distance: {avg_train_distance:.1f}mm")
            print(f"  VALIDATION - Loss: {avg_test_loss:.4f} | 3D Distance: {avg_test_distance:.1f}mm")
        print()


    # Plotting: a 2x2 grid when constraints are on (two extra penalty panels),
    # otherwise a single row with the loss and distance panels.
    if use_constraints:
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        axes = axes.flatten()
    else:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        axes = axes.flatten()
    
    # The x-axis has NUM_EPOCHS points, so the history lists must hold one
    # entry per epoch (i.e. the loop above ran to completion).
    epochs = range(1, NUM_EPOCHS + 1)
    
    # Total loss plot
    axes[0].plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    axes[0].plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    axes[0].set_title('Total Training Loss (Unnormalized)')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # 3D Distance plot
    axes[1].plot(epochs, train_distances, 'g-', label='Training 3D Distance', linewidth=2)
    axes[1].plot(epochs, val_distances, 'orange', label='Validation 3D Distance', linewidth=2)
    axes[1].set_title('3D Distance Error (mm)')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Distance (mm)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    if use_constraints:
        # Constraint penalties (training only; none are recorded for validation)
        axes[2].plot(epochs, train_radius_penalties, 'm-', label='Radius Penalty', linewidth=2)
        axes[2].set_title('Radius Constraint Violations')
        axes[2].set_xlabel('Epochs')
        axes[2].set_ylabel('Penalty')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)
        
        axes[3].plot(epochs, train_z_penalties, 'c-', label='Z Penalty', linewidth=2)
        axes[3].set_title('Z Constraint Violations')
        axes[3].set_xlabel('Epochs')
        axes[3].set_ylabel('Penalty')
        axes[3].legend()
        axes[3].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

Using device: cuda
🔍 UNNORMALIZED CARTESIAN COORDINATE PREDICTION
Testing effect of removing output normalization...
Model outputs raw values (no tanh/sigmoid activation)
Direct learning of physical coordinates in mm

🎯 TRAINING SETUP:
  Output: Raw coordinates in mm (no normalization)
  Loss: Pure permutation-invariant Euclidean distance
  No geometric constraints

Epoch [ 1/40]
  TRAINING   - Loss: 550.5080 | 3D Distance: 550.5mm
  VALIDATION - Loss: 533.9938 | 3D Distance: 534.0mm

Epoch [ 2/40]
  TRAINING   - Loss: 542.6632 | 3D Distance: 542.7mm
  VALIDATION - Loss: 526.3158 | 3D Distance: 526.3mm



KeyboardInterrupt: 

In [None]:
# normalised cartesian
if CLASSES == 6:
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import matplotlib.pyplot as plt

    # Check device: use the GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BATCH_SIZE = 1   # one sample per optimizer step
    NUM_EPOCHS = 40
    # Scanner dimensions (mm); used as the scale factors in the loss functions below
    R_INNER_MM = 235.422
    R_OUTER_MM = 278.296
    Z_HALF_MM = 148.0

    print("🔍 CARTESIAN COORDINATE PREDICTION ANALYSIS")
    print("=" * 60)
    
    # Use scanner geometry for scaling (no dataset iteration needed)
    print("Using scanner geometry for coordinate scaling...")
    
    # For PET scanner, maximum possible coordinates are at outer detector ring
    # NOTE(review): x_range / y_range / z_range are not referenced again in the
    # visible part of this cell; the losses read R_OUTER_MM / Z_HALF_MM directly.
    x_range = R_OUTER_MM  # Max X coordinate at outer ring
    y_range = R_OUTER_MM  # Max Y coordinate at outer ring  
    z_range = Z_HALF_MM   # Max Z coordinate at detector edge
    
    print(f"X/Y coordinate range: ±{R_OUTER_MM:.1f}mm (scanner outer radius)")
    print(f"Z coordinate range: ±{Z_HALF_MM:.1f}mm (scanner half-height)") 
    print(f"Scaling: tanh[-1,1] → X/Y: ±{R_OUTER_MM:.1f}mm, Z: ±{Z_HALF_MM:.1f}mm")
    print()

    # train_data / train_labels / test_data / test_labels are not created in
    # this cell; they must already exist in the kernel when it runs.
    train_dataset = TensorDataset(train_data, train_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    # Only the training loader shuffles; the test loader keeps dataset order.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    def cartesian_euclidean_loss_with_proper_scaling(pred, target):
        """
        Permutation-invariant endpoint distance between prediction and target.

        Each row of ``pred`` holds two endpoints as [x1, y1, z1, x2, y2, z2] in
        normalised units. The x/y entries are multiplied by ``R_OUTER_MM`` and
        the z entries by ``Z_HALF_MM`` (fixed scanner constants, not ranges
        derived from the data) before being compared with ``target``, which
        uses the same column layout.
        Presumably ``pred`` comes from a tanh head in [-1, 1] and ``target`` is
        in mm -- TODO confirm against the model and the data loader.

        For every row the summed endpoint distance is evaluated for both
        pairings (1->1, 2->2 and 1->2, 2->1) and the smaller one is kept.

        Returns:
            Scalar tensor: mean over rows of the smaller pairing distance.
        """
        def to_scaled(point):
            # point is one endpoint, columns (x, y, z) in normalised units.
            return torch.cat(
                [point[:, 0:2] * R_OUTER_MM, point[:, 2:3] * Z_HALF_MM],
                dim=1,
            )

        def distance(a, b):
            return torch.norm(a - b, dim=1)

        pred_first = to_scaled(pred[:, 0:3])
        pred_second = to_scaled(pred[:, 3:6])
        target_first = target[:, 0:3]
        target_second = target[:, 3:6]

        # Endpoint order in the labels is arbitrary, so score both pairings.
        direct = distance(pred_first, target_first) + distance(pred_second, target_second)
        swapped = distance(pred_first, target_second) + distance(pred_second, target_first)

        return torch.minimum(direct, swapped).mean()

    # Simpler version without permutation invariance for baseline
    def simple_cartesian_loss(pred, target):
        """
        Baseline Cartesian loss with a fixed endpoint order.

        Columns 0, 1, 3, 4 of ``pred`` (x/y) are multiplied by ``R_OUTER_MM``
        and columns 2, 5 (z) by ``Z_HALF_MM``; the result is compared with
        ``target`` column by column, with no permutation handling.

        NOTE(review): the distance is a single Euclidean norm over all columns
        of a row (both endpoints together), not the sum of two separate 3-D
        endpoint distances.

        Returns:
            Scalar tensor: mean over rows of the per-row norm.
        """
        # Work on a copy so the caller's tensor is left untouched.
        pred_mm = pred.clone()
        for columns, factor in (([0, 1, 3, 4], R_OUTER_MM), ([2, 5], Z_HALF_MM)):
            pred_mm[:, columns] = pred_mm[:, columns] * factor

        per_row = torch.norm(pred_mm - target, dim=1)
        return per_row.mean()

    # Choose loss function for comparison
    # criterion = cartesian_euclidean_loss_with_proper_scaling
    criterion = simple_cartesian_loss

    # Derive the banner text from the loss that is actually selected, so the
    # log cannot claim permutation invariance while the plain baseline runs.
    if criterion is cartesian_euclidean_loss_with_proper_scaling:
        loss_description = "Permutation-invariant Euclidean distance"
    else:
        loss_description = "Direct Euclidean distance (no permutation handling)"

    print(f"🎯 TRAINING SETUP:")
    print(f"  Coordinate system: Cartesian (X, Y, Z)")
    print(f"  Scaling: X/Y ∈ [-1,1] → ±{R_OUTER_MM:.1f}mm, Z ∈ [-1,1] → ±{Z_HALF_MM:.1f}mm")
    print(f"  Loss function: {loss_description}")
    print(f"  Expected comparable to cylindrical results")
    print()

    # Rebinds the notebook-level `model` to its float32 copy on `device`.
    model = model.float().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Lists to store losses for plotting (one entry appended per epoch)
    train_losses = []
    val_losses = []
    
    # Enhanced tracking for comparison
    train_distances = []  # Store actual 3D distances
    val_distances = []

    for epoch in range(NUM_EPOCHS):
        model.train()
        # Running sums of per-batch loss values; averaged over batches below.
        running_loss = 0.0
        running_distance = 0.0
        
        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            # Track actual distance (same as loss for Euclidean)
            # Both accumulators receive the same loss.item(), so the "Loss" and
            # "3D Distance" columns always carry the same value in this cell.
            running_distance += loss.item()
        
        avg_train_loss = running_loss / len(train_loader)
        avg_train_distance = running_distance / len(train_loader)
        train_losses.append(avg_train_loss)
        train_distances.append(avg_train_distance)
        
        # Validation pass: same criterion over test_loader, with gradient
        # tracking disabled and no optimizer step.
        model.eval()
        test_loss = 0.0
        test_distance = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.float().to(device)
                labels = labels.float().to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                test_distance += loss.item()

        avg_test_loss = test_loss / len(test_loader)
        avg_test_distance = test_distance / len(test_loader)
        val_losses.append(avg_test_loss)
        val_distances.append(avg_test_distance)
        
        # Print in format comparable to cylindrical results
        print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}]")
        print(f"  TRAINING   - Loss: {avg_train_loss:.4f} | 3D Distance: {avg_train_distance:.1f}mm")
        print(f"  VALIDATION - Loss: {avg_test_loss:.4f} | 3D Distance: {avg_test_distance:.1f}mm")
        print()

    # Enhanced plotting for comparison
    # The x-axis has NUM_EPOCHS points, so the history lists must hold one
    # entry per epoch (i.e. the loop above ran to completion).
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    epochs = range(1, NUM_EPOCHS + 1)
    
    # Loss plot
    axes[0].plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    axes[0].plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    axes[0].set_title('Cartesian Coordinate Training Loss')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # 3D Distance plot (for direct comparison with cylindrical)
    axes[1].plot(epochs, train_distances, 'g-', label='Training 3D Distance', linewidth=2)
    axes[1].plot(epochs, val_distances, 'orange', label='Validation 3D Distance', linewidth=2)
    axes[1].set_title('3D Distance Error (mm)')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Distance (mm)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    

In [79]:
# normalised cylindrical
if CLASSES == 8:
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import matplotlib.pyplot as plt

    # Check device: use the GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BATCH_SIZE = 1   # one sample per optimizer step
    NUM_EPOCHS = 40
    # Scanner dimensions (mm); passed to the loss and used by the metric printer below
    R_INNER_MM = 235.422
    R_OUTER_MM = 278.296
    Z_HALF_MM = 148.0

    # train_data / train_labels / test_data / test_labels are not created in
    # this cell; they must already exist in the kernel when it runs.
    train_dataset = TensorDataset(train_data, train_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    # Only the training loader shuffles; the test loader keeps dataset order.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    class FPoICylLoss(nn.Module):
        """
        Permutation-invariant endpoint loss in normalised cylindrical coordinates.

        Two endpoints are predicted per sample. Their order is arbitrary, so the
        loss is evaluated for both pairings of predicted and target endpoints
        and the cheaper pairing is used, independently for every sample.

        Prediction layout: pred = [r1, s1, c1, z1, r2, s2, c2, z2]
            r     radius, normalised so that 0 -> r_inner and 1 -> r_outer
            s, c  sine / cosine of the azimuthal angle
            z     axial coordinate divided by z_half
        Target layout: target_xyz = [x1,y1,z1, x2,y2,z2] (mm)

        Per endpoint the loss is a weighted sum of SmoothL1 terms on r, sin, cos
        and z. The sin/cos terms are additionally scaled by the target radius
        relative to r_outer (clamped to [0.1, 1.0]).

        Note: the ``prev_assignment`` buffer is registered for a possible
        hysteresis scheme but is not read by any method of this class; no
        hysteresis is applied.

        Args:
            z_weight: weight of the axial term.
            ang_weight: weight of the (radius-scaled) sin + cos terms.
            radius_weight: weight of the radial term.
            reduction: "mean", "sum" or "none" (per-sample values).
            r_inner: radius in mm that maps to normalised r = 0.
            r_outer: radius in mm that maps to normalised r = 1.
            z_half: axial half-length in mm that maps to normalised |z| = 1.

        Raises:
            AssertionError: if ``reduction`` is not one of the three options.
            ValueError: if ``r_outer`` is not strictly greater than ``r_inner``.
        """

        # Per-endpoint quantities that forward() sums over the two endpoints and
        # reports for the selected pairing when return_components=True.
        _COMPONENT_KEYS = (
            'r_loss', 'r_loss_mm', 'sin_loss', 'cos_loss',
            'sin_loss_deg', 'cos_loss_deg',
            'sin_arc_error_mm', 'cos_arc_error_mm',
            'z_loss', 'z_loss_mm', 'weighted_sin_loss', 'weighted_cos_loss',
        )

        def __init__(
                self,
                z_weight: float = 1.0,
                ang_weight: float = 1.0,
                radius_weight: float = 1.0,
                reduction: str = "mean",
                r_inner: float = 235.422,
                r_outer: float = 278.296,
                z_half: float = 148.0,
        ):
            super().__init__()
            assert reduction in ("mean", "sum", "none")

            self.zw = float(z_weight)
            self.aw = float(ang_weight)
            self.rw = float(radius_weight)
            self.reduction = reduction

            self.R_IN = float(r_inner)
            self.R_OUT = float(r_outer)
            self.Z_HALF = float(z_half)

            # Precompute radius range for normalization. It is used as a
            # divisor, so a zero or negative range cannot be accepted.
            self.R_RANGE = self.R_OUT - self.R_IN
            if self.R_RANGE <= 0:
                raise ValueError("r_outer must be strictly greater than r_inner")

            self.smooth_l1 = nn.SmoothL1Loss(reduction="none")

            # Reserved for assignment hysteresis; currently never read.
            self.register_buffer("prev_assignment", torch.tensor(0), persistent=False)

        def _xy_to_cylindrical(self, x1: torch.Tensor, y1: torch.Tensor,
                            x2: torch.Tensor, y2: torch.Tensor) -> tuple:
            """
            Convert two Cartesian (x, y) pairs to (r, sin φ, cos φ) triples.

            Inputs are (B, 1) column tensors. The radius is clamped to a small
            positive epsilon so the divisions below are always defined; a point
            at the origin therefore yields sin = cos = 0.

            Returns:
                (cyl1, cyl2): two (B, 3) tensors laid out as [r, sin, cos].
            """
            eps = 1e-3  # Larger epsilon for numerical stability

            r1 = torch.clamp(torch.hypot(x1, y1), min=eps)
            r2 = torch.clamp(torch.hypot(x2, y2), min=eps)

            # Compute unit angle components
            cos1, sin1 = x1 / r1, y1 / r1
            cos2, sin2 = x2 / r2, y2 / r2

            # Pack as [r, sin, cos] for consistency
            cyl1 = torch.cat([r1, sin1, cos1], dim=-1)  # (B, 3)
            cyl2 = torch.cat([r2, sin2, cos2], dim=-1)  # (B, 3)

            return cyl1, cyl2

        def _cylindrical_to_cartesian(self, rsc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
            """
            Convert normalised (r, sin φ, cos φ) and normalised z to (x, y, z) in mm.

            Args:
                rsc: (B, 3) tensor [r_norm, sin, cos].
                z: (B, 1) tensor of normalised axial positions.

            Returns:
                (B, 3) tensor [x, y, z] in mm.
            """
            r_norm = rsc[:, 0:1]
            sin_phi = rsc[:, 1:2]
            cos_phi = rsc[:, 2:3]

            # r_norm = 0 -> R_IN, r_norm = 1 -> R_OUT (values outside [0, 1] extrapolate)
            r_mm = self.R_IN + r_norm * self.R_RANGE

            # z = ±1 -> ±Z_HALF
            z_mm = z * self.Z_HALF

            # Renormalise (sin, cos) to unit length; the clamp avoids dividing by zero.
            eps = 1e-6
            norm = torch.clamp(torch.sqrt(sin_phi**2 + cos_phi**2), min=eps)
            sin_unit = sin_phi / norm
            cos_unit = cos_phi / norm

            # Convert to Cartesian
            x = r_mm * cos_unit
            y = r_mm * sin_unit

            return torch.cat([x, y, z_mm], dim=-1)  # (B, 3)

        def _endpoint_loss_components(self, pred_rsc: torch.Tensor, pred_z: torch.Tensor,
                                    gt_cyl: torch.Tensor, gt_z: torch.Tensor) -> dict:
            """
            Per-sample loss terms for one predicted endpoint against one target.

            Args:
                pred_rsc: (B, 3) predicted [r_norm, sin, cos].
                pred_z: (B, 1) predicted normalised z.
                gt_cyl: (B, 3) target [r_mm, sin, cos] from _xy_to_cylindrical.
                gt_z: (B, 1) target z in mm.

            Returns:
                Dict of (B,) tensors. ``total_loss`` is the optimisation target;
                the ``*_mm`` / ``*_deg`` entries are rescaled copies of the same
                SmoothL1 values intended for display.
            """
            # Normalise the target radius with the same mapping the prediction
            # uses, clamped to [0, 1] for targets outside [R_IN, R_OUT].
            gt_r_mm = gt_cyl[:, 0:1]
            gt_r_norm = (gt_r_mm - self.R_IN) / self.R_RANGE
            gt_r_norm = torch.clamp(gt_r_norm, 0, 1)

            # Normalize ground truth z to [-1,1]
            gt_z_norm = torch.clamp(gt_z / self.Z_HALF, -1, 1)

            # Extract prediction components
            pred_r = pred_rsc[:, 0:1]
            pred_sin = pred_rsc[:, 1:2]
            pred_cos = pred_rsc[:, 2:3]
            gt_sin = gt_cyl[:, 1:2]
            gt_cos = gt_cyl[:, 2:3]

            # Compute individual component losses (normalized units)
            r_loss_norm = self.smooth_l1(pred_r, gt_r_norm).squeeze(-1)
            sin_loss = self.smooth_l1(pred_sin, gt_sin).squeeze(-1)
            cos_loss = self.smooth_l1(pred_cos, gt_cos).squeeze(-1)
            z_loss_norm = self.smooth_l1(pred_z, gt_z_norm).squeeze(-1)

            # Rescale the normalised losses for display.
            r_loss_mm = r_loss_norm * self.R_RANGE
            z_loss_mm = z_loss_norm * self.Z_HALF

            gt_r_actual = gt_r_mm.squeeze(-1)  # Actual radius in mm

            # NOTE(review): these are SmoothL1 loss values passed through
            # rad2deg / multiplied by the radius, not measured angle errors.
            # Treat them as rough display proxies only.
            sin_loss_deg = torch.rad2deg(sin_loss)
            cos_loss_deg = torch.rad2deg(cos_loss)
            sin_arc_error_mm = gt_r_actual * sin_loss
            cos_arc_error_mm = gt_r_actual * cos_loss

            # Scale the angular terms by the target radius relative to R_OUT so
            # larger radii weigh more; the clamp keeps the weight within [0.1, 1.0].
            radius_weight_factor = (gt_r_actual / self.R_OUT).clamp(0.1, 1.0)
            weighted_sin_loss = sin_loss * radius_weight_factor
            weighted_cos_loss = cos_loss * radius_weight_factor

            # Total weighted loss (using normalized units for optimization)
            total_loss = (self.rw * r_loss_norm +
                         self.aw * (weighted_sin_loss + weighted_cos_loss) +
                         self.zw * z_loss_norm)

            return {
                'r_loss': r_loss_norm,  # For optimization (normalized)
                'r_loss_mm': r_loss_mm,  # For display (mm)
                'sin_loss': sin_loss,    # For optimization (normalized)
                'cos_loss': cos_loss,    # For optimization (normalized)
                'sin_loss_deg': sin_loss_deg,  # For display (degrees)
                'cos_loss_deg': cos_loss_deg,  # For display (degrees)
                'sin_arc_error_mm': sin_arc_error_mm,  # For display (mm)
                'cos_arc_error_mm': cos_arc_error_mm,  # For display (mm)
                'z_loss': z_loss_norm,   # For optimization (normalized)
                'z_loss_mm': z_loss_mm,  # For display (mm)
                'weighted_sin_loss': weighted_sin_loss,
                'weighted_cos_loss': weighted_cos_loss,
                'total_loss': total_loss
            }

        def _endpoint_loss(self, pred_rsc: torch.Tensor, pred_z: torch.Tensor,
                        gt_cyl: torch.Tensor, gt_z: torch.Tensor) -> torch.Tensor:
            """
            Return only the per-sample ``total_loss`` of
            _endpoint_loss_components (kept for backward compatibility).
            """
            components = self._endpoint_loss_components(pred_rsc, pred_z, gt_cyl, gt_z)
            return components['total_loss']

        def forward(self, pred: torch.Tensor, target_xyz: torch.Tensor, return_components: bool = False):
            """
            Compute the cylindrical training loss and a Cartesian distance metric.

            Args:
                pred: (B,8) = [r1, s1, c1, z1, r2, s2, c2, z2]
                target_xyz: (B,6) = [x1,y1,z1, x2,y2,z2]
                return_components: If True, return a detailed loss breakdown.

            Returns:
                ``(cylindrical_loss, cartesian_loss)`` when ``return_components``
                is False. Otherwise a dict with ``cylindrical_loss``,
                ``cartesian_loss``, ``euclidean_distance`` (alias of
                ``cartesian_loss``), every key in ``_COMPONENT_KEYS`` summed over
                the two endpoints for the selected pairing, and
                ``cart_{x,y,z}_loss_mm``. All values follow ``self.reduction``.

            Raises:
                ValueError: if the inputs are not 2-D with 8 and 6 columns.
            """
            if pred.dim() != 2 or target_xyz.dim() != 2:
                raise ValueError("Expected 2D tensors: pred (B,8), target_xyz (B,6)")
            if pred.size(1) != 8 or target_xyz.size(1) != 6:
                raise ValueError("Expected shapes: pred (B,8), target_xyz (B,6)")

            # Ensure contiguous memory layout
            pred = pred.contiguous()
            target_xyz = target_xyz.contiguous()

            # Split target coordinates
            x1, y1, z1 = target_xyz[:, 0:1], target_xyz[:, 1:2], target_xyz[:, 2:3]
            x2, y2, z2 = target_xyz[:, 3:4], target_xyz[:, 4:5], target_xyz[:, 5:6]

            # Convert ground truth to cylindrical coordinates
            gt_cyl1, gt_cyl2 = self._xy_to_cylindrical(x1, y1, x2, y2)

            # Extract prediction components
            pred_rsc1 = pred[:, 0:3]  # [r1, s1, c1]
            pred_z1 = pred[:, 3:4]    # [z1]
            pred_rsc2 = pred[:, 4:7]  # [r2, s2, c2]
            pred_z2 = pred[:, 7:8]    # [z2]

            # Cartesian versions of predictions and targets (mm)
            pred_cart1 = self._cylindrical_to_cartesian(pred_rsc1, pred_z1)
            pred_cart2 = self._cylindrical_to_cartesian(pred_rsc2, pred_z2)
            gt_cart1 = torch.cat([x1, y1, z1], dim=-1)
            gt_cart2 = torch.cat([x2, y2, z2], dim=-1)

            # Pairing A: pred1<->gt1, pred2<->gt2.  Pairing B: pred1<->gt2, pred2<->gt1.
            comp_A1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl1, z1)
            comp_A2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl2, z2)
            comp_B1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl2, z2)
            comp_B2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl1, z1)

            loss_A = comp_A1['total_loss'] + comp_A2['total_loss']
            loss_B = comp_B1['total_loss'] + comp_B2['total_loss']
            cylindrical_loss = torch.minimum(loss_A, loss_B)

            components = {}
            if return_components:
                # Report every component for the pairing that won per sample
                # (ties go to pairing A).
                use_assignment_A = loss_A <= loss_B

                for key in self._COMPONENT_KEYS:
                    components[key] = torch.where(use_assignment_A,
                                                  comp_A1[key] + comp_A2[key],
                                                  comp_B1[key] + comp_B2[key])

                # Per-axis absolute Cartesian errors, summed over both endpoints,
                # reported for the same pairing as the cylindrical components.
                abs_A = (pred_cart1 - gt_cart1).abs() + (pred_cart2 - gt_cart2).abs()
                abs_B = (pred_cart1 - gt_cart2).abs() + (pred_cart2 - gt_cart1).abs()
                for axis, axis_name in enumerate(("x", "y", "z")):
                    components[f'cart_{axis_name}_loss_mm'] = torch.where(
                        use_assignment_A, abs_A[:, axis], abs_B[:, axis])

            # Cartesian distance for both assignments (Euclidean distance).
            # This metric picks its own cheaper pairing, independently of the
            # pairing chosen by the cylindrical loss.
            cart_A = (torch.norm(pred_cart1 - gt_cart1, dim=1) +
                    torch.norm(pred_cart2 - gt_cart2, dim=1))
            cart_B = (torch.norm(pred_cart1 - gt_cart2, dim=1) +
                    torch.norm(pred_cart2 - gt_cart1, dim=1))
            cartesian_loss = torch.minimum(cart_A, cart_B)

            # Apply reduction ("none" leaves per-sample tensors untouched)
            if self.reduction in ("mean", "sum"):
                reduce = torch.mean if self.reduction == "mean" else torch.sum
                cylindrical_loss = reduce(cylindrical_loss)
                cartesian_loss = reduce(cartesian_loss)
                components = {key: reduce(value) for key, value in components.items()}

            if return_components:
                return {
                    'cylindrical_loss': cylindrical_loss,
                    'cartesian_loss': cartesian_loss,
                    'euclidean_distance': cartesian_loss,  # Alias for clarity
                    **components
                }
            return cylindrical_loss, cartesian_loss

    # Build the loss from the scanner constants defined in this cell. z_half is
    # passed explicitly so that editing Z_HALF_MM cannot leave the loss on its
    # own built-in default while the metric printer uses the new value.
    cyl_loss = FPoICylLoss(
        r_inner=R_INNER_MM,
        r_outer=R_OUTER_MM,
        z_half=Z_HALF_MM,
    )

    criterion = cyl_loss 

    # Rebinds the notebook-level `model` to its float32 copy on `device`.
    model = model.float().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Lists to store losses for plotting (one entry appended per epoch)
    train_losses = []
    val_losses = []
    
    # Lists to store component losses for tracking
    train_r_losses = []
    train_angular_losses = []
    train_z_losses = []
    train_euclidean_losses = []
    val_r_losses = []
    val_angular_losses = []
    val_z_losses = []
    val_euclidean_losses = []

    def print_training_metrics_corrected(epoch, num_epochs, loss_dict, prefix=""):
        """
        Print one summary line of loss metrics.

        Args:
            epoch: unused; kept for call compatibility.
            num_epochs: unused; kept for call compatibility.
            loss_dict: mapping whose values expose ``.item()``. Required keys:
                'cylindrical_loss', 'euclidean_distance', 'r_loss_mm',
                'sin_loss_deg', 'cos_loss_deg', 'z_loss_mm'. If
                'sin_arc_error_mm' is present, 'cos_arc_error_mm' is read as
                well and the summed arc error is appended to the θ field.
            prefix: label printed at the start of the line, padded to 10 characters.

        The radial error is also shown as a percentage of
        ``R_OUTER_MM - R_INNER_MM`` and the axial error as a percentage of
        ``2 * Z_HALF_MM``.
        """
        def scalar(key):
            return loss_dict[key].item()

        total_loss = scalar('cylindrical_loss')
        euclidean_dist = scalar('euclidean_distance')
        r_loss_mm = scalar('r_loss_mm')
        angular_loss_deg = scalar('sin_loss_deg') + scalar('cos_loss_deg')
        z_loss_mm = scalar('z_loss_mm')

        detector_thickness = R_OUTER_MM - R_INNER_MM
        r_error_pct = (r_loss_mm / detector_thickness) * 100
        z_error_pct = (z_loss_mm / (2 * Z_HALF_MM)) * 100

        # The θ field gains an arc-length suffix only when arc errors are supplied.
        theta_field = f"θ: {angular_loss_deg:.2f}°"
        if 'sin_arc_error_mm' in loss_dict:
            arc_error_mm = scalar('sin_arc_error_mm') + scalar('cos_arc_error_mm')
            theta_field += f" ({arc_error_mm:.2f}mm arc)"

        print(f"  {prefix:10s} - Loss: {total_loss:.4f} | 3D: {euclidean_dist:.1f}mm | "
              f"R: {r_loss_mm:.2f}mm ({r_error_pct:.1f}%) | {theta_field} | "
              f"Z: {z_loss_mm:.2f}mm ({z_error_pct:.1f}%)")

    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        running_r_loss_mm = 0.0
        running_angular_loss_deg = 0.0
        running_z_loss_mm = 0.0
        running_euclidean_loss = 0.0
        
        # Training phase
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            
            # Get detailed loss components
            loss_dict = criterion(outputs, labels, return_components=True)
            
            loss_dict['cylindrical_loss'].backward()
            optimizer.step()
            
            # Accumulate losses IN PHYSICAL UNITS
            running_loss += loss_dict['cylindrical_loss'].item()
            running_r_loss_mm += loss_dict['r_loss_mm'].item()
            running_angular_loss_deg += (loss_dict['sin_loss_deg'].item() + loss_dict['cos_loss_deg'].item())
            running_z_loss_mm += loss_dict['z_loss_mm'].item()
            running_euclidean_loss += loss_dict['euclidean_distance'].item()
        
        # Calculate average training losses for epoch
        num_train_batches = len(train_loader)
        avg_train_loss = running_loss / num_train_batches
        avg_train_r_loss_mm = running_r_loss_mm / num_train_batches
        avg_train_angular_loss_deg = running_angular_loss_deg / num_train_batches
        avg_train_z_loss_mm = running_z_loss_mm / num_train_batches
        avg_train_euclidean_loss = running_euclidean_loss / num_train_batches
        
        # Store training losses
        train_losses.append(avg_train_loss)
        train_r_losses.append(avg_train_r_loss_mm)
        train_angular_losses.append(avg_train_angular_loss_deg)
        train_z_losses.append(avg_train_z_loss_mm)
        train_euclidean_losses.append(avg_train_euclidean_loss)
        
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_running_r_loss_mm = 0.0
        val_running_angular_loss_deg = 0.0
        val_running_z_loss_mm = 0.0
        val_running_euclidean_loss = 0.0
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.float().to(device)
                labels = labels.float().to(device)
                outputs = model(inputs)
                loss_dict = criterion(outputs, labels, return_components=True)
                
                # Accumulate validation losses IN PHYSICAL UNITS
                val_running_loss += loss_dict['cylindrical_loss'].item()
                val_running_r_loss_mm += loss_dict['r_loss_mm'].item()
                val_running_angular_loss_deg += (loss_dict['sin_loss_deg'].item() + loss_dict['cos_loss_deg'].item())
                val_running_z_loss_mm += loss_dict['z_loss_mm'].item()
                val_running_euclidean_loss += loss_dict['euclidean_distance'].item()

        # Calculate average validation losses for epoch
        num_val_batches = len(test_loader)
        avg_val_loss = val_running_loss / num_val_batches
        avg_val_r_loss_mm = val_running_r_loss_mm / num_val_batches
        avg_val_angular_loss_deg = val_running_angular_loss_deg / num_val_batches
        avg_val_z_loss_mm = val_running_z_loss_mm / num_val_batches
        avg_val_euclidean_loss = val_running_euclidean_loss / num_val_batches
        
        # Store validation losses
        val_losses.append(avg_val_loss)
        val_r_losses.append(avg_val_r_loss_mm)
        val_angular_losses.append(avg_val_angular_loss_deg)
        val_z_losses.append(avg_val_z_loss_mm)
        val_euclidean_losses.append(avg_val_euclidean_loss)
        
        # Create loss dictionaries for proper metric printing
        train_loss_dict = {
            'cylindrical_loss': torch.tensor(avg_train_loss),
            'euclidean_distance': torch.tensor(avg_train_euclidean_loss),
            'r_loss_mm': torch.tensor(avg_train_r_loss_mm),
            'sin_loss_deg': torch.tensor(avg_train_angular_loss_deg / 2),  # Divide by 2 since we add sin+cos
            'cos_loss_deg': torch.tensor(avg_train_angular_loss_deg / 2),
            'z_loss_mm': torch.tensor(avg_train_z_loss_mm)
        }
        
        val_loss_dict = {
            'cylindrical_loss': torch.tensor(avg_val_loss),
            'euclidean_distance': torch.tensor(avg_val_euclidean_loss),
            'r_loss_mm': torch.tensor(avg_val_r_loss_mm),
            'sin_loss_deg': torch.tensor(avg_val_angular_loss_deg / 2),
            'cos_loss_deg': torch.tensor(avg_val_angular_loss_deg / 2),
            'z_loss_mm': torch.tensor(avg_val_z_loss_mm)
        }
        
        # Print epoch results with proper formatting
        print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}]")
        print_training_metrics_corrected(epoch, NUM_EPOCHS, train_loss_dict, "TRAINING")
        print_training_metrics_corrected(epoch, NUM_EPOCHS, val_loss_dict, "VALIDATION")
        print()


Using device: cuda
Epoch [ 1/40]
  TRAINING   - Loss: 1.6558 | 3D: 560.3mm | R: 14.51mm (33.8%) | θ: 60.39° | Z: 47.90mm (16.2%)
  VALIDATION - Loss: 1.3231 | 3D: 568.7mm | R: 4.48mm (10.5%) | θ: 55.33° | Z: 45.48mm (15.4%)

Epoch [ 2/40]
  TRAINING   - Loss: 1.4078 | 3D: 468.7mm | R: 11.52mm (26.9%) | θ: 51.40° | Z: 43.32mm (14.6%)
  VALIDATION - Loss: 1.1977 | 3D: 524.0mm | R: 5.51mm (12.9%) | θ: 48.12° | Z: 40.96mm (13.8%)

Epoch [ 3/40]
  TRAINING   - Loss: 1.2397 | 3D: 410.4mm | R: 11.62mm (27.1%) | θ: 41.53° | Z: 42.03mm (14.2%)
  VALIDATION - Loss: 1.2038 | 3D: 468.2mm | R: 6.81mm (15.9%) | θ: 47.06° | Z: 40.02mm (13.5%)

Epoch [ 4/40]
  TRAINING   - Loss: 1.2136 | 3D: 432.0mm | R: 10.45mm (24.4%) | θ: 42.12° | Z: 40.74mm (13.8%)
  VALIDATION - Loss: 1.2164 | 3D: 476.5mm | R: 8.53mm (19.9%) | θ: 47.16° | Z: 35.73mm (12.1%)

Epoch [ 5/40]
  TRAINING   - Loss: 1.1248 | 3D: 388.6mm | R: 9.71mm (22.6%) | θ: 38.39° | Z: 39.25mm (13.3%)
  VALIDATION - Loss: 1.2061 | 3D: 464.7mm | R: 7

KeyboardInterrupt: 

In [None]:
# unnormalised cylindrical
if CLASSES == 8:
    # Imports are local to this experiment cell. `nn` is imported here as well
    # because the loss classes defined further down subclass `nn.Module`; without
    # it the cell only runs when an earlier cell has already put `nn` in scope.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    import matplotlib.pyplot as plt

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Training hyper-parameters for this experiment
    BATCH_SIZE = 1
    NUM_EPOCHS = 40
    # Scanner dimensions (mm); passed to the loss constructors below
    R_INNER_MM = 235.422
    R_OUTER_MM = 278.296
    Z_HALF_MM = 148.0

    print("🔍 UNNORMALIZED CYLINDRICAL COORDINATE PREDICTION")
    print("=" * 60)
    print("Testing cylindrical coordinates with raw model outputs...")
    print("Model outputs: [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]")
    print("Direct learning of physical coordinates (no tanh/sigmoid)")
    print()

    # NOTE(review): train_data / train_labels / test_data / test_labels come from
    # an earlier cell -- presumably tensors with matching first dimension; confirm.
    train_dataset = TensorDataset(train_data, train_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    # Only the training loader is shuffled; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    class UnnormalizedFPoICylLoss(nn.Module):
        """
        Cylindrical loss for unnormalized model outputs.

        Model prediction layout: [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]
        Target layout: [x1, y1, z1, x2, y2, z2] (mm)

        Key differences from normalized version:
        - Radius predictions are in mm directly (not [0,1])
        - Z predictions are in mm directly (not [-1,1])
        - sin/cos predictions are raw (should be ~[-1,1] but unconstrained)

        The two endpoints are treated as an unordered pair: both pairings of
        predicted and ground-truth endpoints are scored and, per sample, the
        pairing with the smaller loss is kept.

        Args:
            z_weight: multiplier on the z smooth-L1 term.
            ang_weight: multiplier on the (sin + cos) smooth-L1 terms.
            radius_weight: multiplier on the radius smooth-L1 term.
            constraint_weight: multiplier on the soft constraint penalty.
            reduction: "mean", "sum" or "none"; applied to every per-sample
                quantity returned by ``forward``.
            r_inner, r_outer: radius bounds (mm) used by the radius penalty.
            z_half: half-length (mm) used by the z penalty.
        """

        def __init__(
                self,
                z_weight: float = 1.0,
                ang_weight: float = 1.0,
                radius_weight: float = 1.0,
                constraint_weight: float = 0.1,
                reduction: str = "mean",
                r_inner: float = 235.422,
                r_outer: float = 278.296,
                z_half: float = 148.0,
        ):
            super().__init__()
            assert reduction in ("mean", "sum", "none")

            self.zw = float(z_weight)
            self.aw = float(ang_weight)
            self.rw = float(radius_weight)
            self.cw = float(constraint_weight)
            self.reduction = reduction

            self.R_IN = float(r_inner)
            self.R_OUT = float(r_outer)
            self.Z_HALF = float(z_half)
            self.R_RANGE = self.R_OUT - self.R_IN

            # Element-wise smooth-L1; reduction is handled explicitly in forward().
            self.smooth_l1 = nn.SmoothL1Loss(reduction="none")

        def _reduce(self, values: torch.Tensor) -> torch.Tensor:
            """Collapse a per-sample tensor according to ``self.reduction``."""
            if self.reduction == "mean":
                return values.mean()
            if self.reduction == "sum":
                return values.sum()
            return values  # "none": leave per-sample values untouched

        def _xy_to_cylindrical(self, x1: torch.Tensor, y1: torch.Tensor,
                              x2: torch.Tensor, y2: torch.Tensor) -> tuple:
            """
            Convert the Cartesian (x, y) of both endpoints to [r, sin, cos].

            Each argument is used as a (B, 1) column by ``forward``; the two
            returned tensors are (B, 3) with the radius in the input units.
            """
            eps = 1e-3

            # Clamp the radius away from zero so the divisions below stay finite.
            r1 = torch.clamp(torch.hypot(x1, y1), min=eps)
            r2 = torch.clamp(torch.hypot(x2, y2), min=eps)

            cos1, sin1 = x1 / r1, y1 / r1
            cos2, sin2 = x2 / r2, y2 / r2

            # Return in mm and raw sin/cos (not normalized)
            cyl1 = torch.cat([r1, sin1, cos1], dim=-1)  # (B, 3) - r in mm
            cyl2 = torch.cat([r2, sin2, cos2], dim=-1)  # (B, 3) - r in mm

            return cyl1, cyl2

        def _cylindrical_to_cartesian(self, rsc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
            """
            Convert unnormalized cylindrical to Cartesian.

            rsc: [r_mm, sin_phi, cos_phi] - radius in mm, sin/cos raw
            z: [z_mm] - z coordinate in mm

            Returns a (B, 3) tensor [x, y, z].
            """
            r_mm = rsc[:, 0:1]  # Radius in mm (raw prediction)
            sin_phi = rsc[:, 1:2]  # Raw sin prediction
            cos_phi = rsc[:, 2:3]  # Raw cos prediction
            z_mm = z  # Z in mm (raw prediction)

            # Normalize sin/cos to unit length (handle unconstrained predictions).
            # The clamp is applied to the squared norm *before* the sqrt: the
            # forward value equals clamp(sqrt(.), min=eps), but sqrt is never
            # evaluated at 0, so its gradient cannot become NaN.
            eps = 1e-6
            norm = torch.sqrt(torch.clamp(sin_phi**2 + cos_phi**2, min=eps * eps))
            sin_unit = sin_phi / norm
            cos_unit = cos_phi / norm

            # Convert to Cartesian
            x = r_mm * cos_unit
            y = r_mm * sin_unit

            return torch.cat([x, y, z_mm], dim=-1)  # (B, 3)

        def _constraint_penalties(self, pred: torch.Tensor) -> dict:
            """
            Compute constraint violation penalties.

            pred: [r1, sin1, cos1, z1, r2, sin2, cos2, z2]

            Every returned value is a scalar tensor (mean over the batch).
            """
            r1 = pred[:, 0]
            sin1, cos1 = pred[:, 1], pred[:, 2]
            z1 = pred[:, 3]
            r2 = pred[:, 4]
            sin2, cos2 = pred[:, 5], pred[:, 6]
            z2 = pred[:, 7]

            # Radius constraints: should be in [R_IN, R_OUT]
            r1_violation = torch.relu(r1 - self.R_OUT) + torch.relu(self.R_IN - r1)
            r2_violation = torch.relu(r2 - self.R_OUT) + torch.relu(self.R_IN - r2)
            radius_penalty = torch.mean(r1_violation + r2_violation)

            # Z constraints: should be in [-Z_HALF, Z_HALF]
            z1_violation = torch.relu(torch.abs(z1) - self.Z_HALF)
            z2_violation = torch.relu(torch.abs(z2) - self.Z_HALF)
            z_penalty = torch.mean(z1_violation + z2_violation)

            # Angular constraints: sin^2 + cos^2 should be ~1
            norm1_violation = torch.abs(sin1**2 + cos1**2 - 1.0)
            norm2_violation = torch.abs(sin2**2 + cos2**2 - 1.0)
            angular_penalty = torch.mean(norm1_violation + norm2_violation)

            return {
                'radius_penalty': radius_penalty,
                'z_penalty': z_penalty,
                'angular_penalty': angular_penalty,
                'total_penalty': radius_penalty + z_penalty + angular_penalty
            }

        def _endpoint_loss_components(self, pred_rsc: torch.Tensor, pred_z: torch.Tensor,
                                    gt_cyl: torch.Tensor, gt_z: torch.Tensor) -> dict:
            """
            Score one predicted endpoint against one ground-truth endpoint.

            pred_rsc / gt_cyl are (B, 3) [r, sin, cos]; pred_z / gt_z are (B, 1).
            Every value in the returned dict is a per-sample (B,) tensor.
            """

            # Extract components
            pred_r_mm = pred_rsc[:, 0:1]  # Radius in mm
            pred_sin = pred_rsc[:, 1:2]   # Raw sin
            pred_cos = pred_rsc[:, 2:3]   # Raw cos
            pred_z_mm = pred_z            # Z in mm

            gt_r_mm = gt_cyl[:, 0:1]      # Ground truth radius in mm
            gt_sin = gt_cyl[:, 1:2]       # Ground truth sin
            gt_cos = gt_cyl[:, 2:3]       # Ground truth cos
            gt_z_mm = gt_z                # Ground truth Z in mm

            # Element-wise smooth-L1 on each coordinate
            r_loss_mm = self.smooth_l1(pred_r_mm, gt_r_mm).squeeze(-1)
            sin_loss = self.smooth_l1(pred_sin, gt_sin).squeeze(-1)
            cos_loss = self.smooth_l1(pred_cos, gt_cos).squeeze(-1)
            z_loss_mm = self.smooth_l1(pred_z_mm, gt_z_mm).squeeze(-1)

            # rad2deg applied to the smooth-L1 values for reporting.
            # NOTE(review): these are loss magnitudes on sin/cos values, not
            # true angular errors -- interpret the "deg" figures with care.
            sin_loss_deg = torch.rad2deg(sin_loss)
            cos_loss_deg = torch.rad2deg(cos_loss)

            # Same loss values scaled by the ground-truth radius
            gt_r_actual = gt_r_mm.squeeze(-1)
            sin_arc_error_mm = gt_r_actual * sin_loss
            cos_arc_error_mm = gt_r_actual * cos_loss

            # Weighted total loss
            total_loss = (self.rw * r_loss_mm +
                         self.aw * (sin_loss + cos_loss) +
                         self.zw * z_loss_mm)

            return {
                'r_loss_mm': r_loss_mm,
                'sin_loss': sin_loss,
                'cos_loss': cos_loss,
                'sin_loss_deg': sin_loss_deg,
                'cos_loss_deg': cos_loss_deg,
                'sin_arc_error_mm': sin_arc_error_mm,
                'cos_arc_error_mm': cos_arc_error_mm,
                'z_loss_mm': z_loss_mm,
                'total_loss': total_loss
            }

        def forward(self, pred: torch.Tensor, target_xyz: torch.Tensor, return_components: bool = False):
            """
            Forward pass for unnormalized cylindrical predictions.

            pred: (B,8) = [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]
            target_xyz: (B,6) = [x1,y1,z1, x2,y2,z2] (mm)

            Returns:
                ``(total_loss, cartesian_loss)`` when ``return_components`` is
                False; otherwise a dict with 'total_loss', 'cylindrical_loss',
                'euclidean_distance', the per-component losses of the chosen
                pairing, and the constraint penalties as Python floats.
                Tensor entries are reduced according to ``self.reduction``.

            Raises:
                ValueError: if either input is not 2-D or has the wrong width.
            """
            if pred.dim() != 2 or target_xyz.dim() != 2:
                raise ValueError("Expected 2D tensors")
            if pred.size(1) != 8 or target_xyz.size(1) != 6:
                raise ValueError("Expected shapes: pred (B,8), target_xyz (B,6)")

            pred = pred.contiguous()
            target_xyz = target_xyz.contiguous()

            # Split target coordinates (kept as (B, 1) columns)
            x1, y1, z1 = target_xyz[:, 0:1], target_xyz[:, 1:2], target_xyz[:, 2:3]
            x2, y2, z2 = target_xyz[:, 3:4], target_xyz[:, 4:5], target_xyz[:, 5:6]

            # Convert ground truth to cylindrical
            gt_cyl1, gt_cyl2 = self._xy_to_cylindrical(x1, y1, x2, y2)

            # Extract prediction components (raw values)
            pred_rsc1 = pred[:, 0:3]  # [r1_mm, sin1, cos1]
            pred_z1 = pred[:, 3:4]    # [z1_mm]
            pred_rsc2 = pred[:, 4:7]  # [r2_mm, sin2, cos2]
            pred_z2 = pred[:, 7:8]    # [z2_mm]

            # Compute constraint penalties
            penalties = self._constraint_penalties(pred)

            # Assignment A pairs pred1->gt1 / pred2->gt2; assignment B swaps the targets.
            comp_A1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl1, z1)
            comp_A2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl2, z2)
            comp_B1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl2, z2)
            comp_B2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl1, z1)

            loss_A = comp_A1['total_loss'] + comp_A2['total_loss']
            loss_B = comp_B1['total_loss'] + comp_B2['total_loss']

            # Per sample, keep the cheaper pairing and add the (batch-mean) penalty.
            cylindrical_loss = torch.minimum(loss_A, loss_B)
            total_loss = cylindrical_loss + self.cw * penalties['total_penalty']

            components = {}
            if return_components:
                # Report the components of whichever pairing won for each sample.
                use_assignment_A = loss_A <= loss_B
                component_keys = ['r_loss_mm', 'sin_loss', 'cos_loss', 'sin_loss_deg', 'cos_loss_deg',
                                'sin_arc_error_mm', 'cos_arc_error_mm', 'z_loss_mm']
                for key in component_keys:
                    components[key] = torch.where(use_assignment_A,
                                                comp_A1[key] + comp_A2[key],
                                                comp_B1[key] + comp_B2[key])

            # Cartesian distance (for comparison), with its own best pairing
            pred_cart1 = self._cylindrical_to_cartesian(pred_rsc1, pred_z1)
            pred_cart2 = self._cylindrical_to_cartesian(pred_rsc2, pred_z2)
            gt_cart1 = torch.cat([x1, y1, z1], dim=-1)
            gt_cart2 = torch.cat([x2, y2, z2], dim=-1)

            cart_A = (torch.norm(pred_cart1 - gt_cart1, dim=1) +
                     torch.norm(pred_cart2 - gt_cart2, dim=1))
            cart_B = (torch.norm(pred_cart1 - gt_cart2, dim=1) +
                     torch.norm(pred_cart2 - gt_cart1, dim=1))
            cartesian_loss = torch.minimum(cart_A, cart_B)

            # Apply reduction ("sum" was previously accepted but silently ignored)
            total_loss = self._reduce(total_loss)
            cylindrical_loss = self._reduce(cylindrical_loss)
            cartesian_loss = self._reduce(cartesian_loss)
            components = {key: self._reduce(value) for key, value in components.items()}

            if return_components:
                return {
                    'total_loss': total_loss,
                    'cylindrical_loss': cylindrical_loss,
                    'euclidean_distance': cartesian_loss,
                    **components,
                    # Penalties are exported as plain floats for logging.
                    **{k: v.item() if hasattr(v, 'item') else v for k, v in penalties.items()}
                }
            else:
                return total_loss, cartesian_loss

    # Build the first (uncorrected) unnormalized loss.
    # NOTE: `criterion` is rebound to CorrectedUnnormalizedFPoICylLoss further
    # down in this cell before the training loop runs, so this instance is
    # never the one used for training.
    criterion = UnnormalizedFPoICylLoss(
        ang_weight=2.0,         # angular terms count double
        constraint_weight=0.1,  # soft-constraint penalty scale
        r_inner=R_INNER_MM,
        r_outer=R_OUTER_MM,
        z_half=Z_HALF_MM,
    )

    class CorrectedUnnormalizedFPoICylLoss(nn.Module):
        """
        Corrected cylindrical loss for unnormalized model outputs.

        Model prediction layout: [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]
        Target layout: [x1, y1, z1, x2, y2, z2] (mm)

        Fixes:
        - Proper angular distance calculation
        - Consistent unit handling
        - Simplified constraint penalties
        - Robust sin/cos normalization

        The two endpoints are treated as an unordered pair: both pairings of
        predicted and ground-truth endpoints are scored and, per sample, the
        pairing with the smaller loss is kept.

        Args:
            z_weight: multiplier on the z smooth-L1 term.
            ang_weight: multiplier on the angular arc-length term.
            radius_weight: multiplier on the radius smooth-L1 term.
            constraint_weight: multiplier on the soft constraint penalty.
            reduction: "mean", "sum" or "none"; applied to every per-sample
                quantity returned by ``forward``.
            r_inner, r_outer: radius bounds (mm) used by the radius penalty.
            z_half: half-length (mm) used by the z penalty.
        """

        def __init__(
                self,
                z_weight: float = 1.0,
                ang_weight: float = 1.0,
                radius_weight: float = 1.0,
                constraint_weight: float = 0.1,
                reduction: str = "mean",
                r_inner: float = 235.422,
                r_outer: float = 278.296,
                z_half: float = 148.0,
        ):
            super().__init__()
            assert reduction in ("mean", "sum", "none")

            self.zw = float(z_weight)
            self.aw = float(ang_weight)
            self.rw = float(radius_weight)
            self.cw = float(constraint_weight)
            self.reduction = reduction

            self.R_IN = float(r_inner)
            self.R_OUT = float(r_outer)
            self.Z_HALF = float(z_half)
            self.R_RANGE = self.R_OUT - self.R_IN

            # Element-wise losses; reduction is handled explicitly in forward().
            self.smooth_l1 = nn.SmoothL1Loss(reduction="none")
            self.mse = nn.MSELoss(reduction="none")  # not used by any method of this class

        def _reduce(self, values: torch.Tensor) -> torch.Tensor:
            """Collapse a per-sample tensor according to ``self.reduction``."""
            if self.reduction == "mean":
                return values.mean()
            if self.reduction == "sum":
                return values.sum()
            return values  # "none": leave per-sample values untouched

        def _xy_to_cylindrical(self, x1: torch.Tensor, y1: torch.Tensor,
                            x2: torch.Tensor, y2: torch.Tensor) -> tuple:
            """
            Convert the Cartesian (x, y) of both endpoints to [r, sin, cos].

            Each argument is used as a (B, 1) column by ``forward``; the two
            returned tensors are (B, 3) with the radius in the input units.
            """
            eps = 1e-6

            # Compute radii, clamped away from zero so the divisions stay finite
            r1 = torch.clamp(torch.hypot(x1, y1), min=eps)
            r2 = torch.clamp(torch.hypot(x2, y2), min=eps)

            # Compute normalized sin/cos
            cos1, sin1 = x1 / r1, y1 / r1
            cos2, sin2 = x2 / r2, y2 / r2

            # Return cylindrical coordinates: [r_mm, sin_phi, cos_phi]
            cyl1 = torch.cat([r1, sin1, cos1], dim=-1)
            cyl2 = torch.cat([r2, sin2, cos2], dim=-1)

            return cyl1, cyl2

        def _normalize_sincos(self, sin_val: torch.Tensor, cos_val: torch.Tensor) -> tuple:
            """
            Normalize sin/cos values to unit circle.

            The squared norm is clamped *before* the square root. The forward
            value is the same as clamp(sqrt(.), min=eps), but sqrt is never
            evaluated at exactly 0, where its derivative is infinite and
            back-propagation would otherwise yield NaN for a (0, 0) prediction.
            """
            eps = 1e-8
            norm = torch.sqrt(torch.clamp(sin_val**2 + cos_val**2, min=eps * eps))
            return sin_val / norm, cos_val / norm

        def _angular_distance_loss(self, pred_sin: torch.Tensor, pred_cos: torch.Tensor,
                                gt_sin: torch.Tensor, gt_cos: torch.Tensor) -> torch.Tensor:
            """
            Compute angular distance loss using dot product method.
            More robust than separate sin/cos losses.

            Returns the angle (radians) between the normalized predicted
            direction and the ground-truth direction, element-wise.
            """
            # Normalize predictions to unit circle
            pred_sin_norm, pred_cos_norm = self._normalize_sincos(pred_sin, pred_cos)

            # Compute dot product (cosine of angle difference)
            cos_diff = pred_sin_norm * gt_sin + pred_cos_norm * gt_cos
            # Keep the argument strictly inside (-1, 1): acos has an unbounded
            # derivative at the end points.
            cos_diff = torch.clamp(cos_diff, -1.0 + 1e-7, 1.0 - 1e-7)

            # Angular difference in radians
            angular_error = torch.acos(cos_diff)

            return angular_error

        def _cylindrical_to_cartesian(self, rsc: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
            """
            Convert cylindrical to Cartesian coordinates.

            rsc: [r_mm, sin_phi, cos_phi]
            z: [z_mm]

            Returns a (B, 3) tensor [x, y, z].
            """
            r_mm = rsc[:, 0:1]
            sin_phi = rsc[:, 1:2]
            cos_phi = rsc[:, 2:3]
            z_mm = z

            # Normalize sin/cos (element-wise, so the (B, 1) columns can be passed directly)
            sin_unit, cos_unit = self._normalize_sincos(sin_phi, cos_phi)

            # Convert to Cartesian
            x = r_mm * cos_unit
            y = r_mm * sin_unit

            return torch.cat([x, y, z_mm], dim=-1)

        def _compute_constraint_penalties(self, pred: torch.Tensor) -> dict:
            """
            Compute soft constraint violation penalties.

            pred: [r1, sin1, cos1, z1, r2, sin2, cos2, z2]

            Every returned value is a scalar tensor (mean over the batch).
            """
            r1, sin1, cos1, z1 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
            r2, sin2, cos2, z2 = pred[:, 4], pred[:, 5], pred[:, 6], pred[:, 7]

            # Radius constraints: penalize values outside [R_IN, R_OUT]
            r1_penalty = torch.relu(r1 - self.R_OUT) + torch.relu(self.R_IN - r1)
            r2_penalty = torch.relu(r2 - self.R_OUT) + torch.relu(self.R_IN - r2)
            radius_penalty = torch.mean(r1_penalty + r2_penalty)

            # Z constraints: penalize values outside [-Z_HALF, Z_HALF]
            z1_penalty = torch.relu(torch.abs(z1) - self.Z_HALF)
            z2_penalty = torch.relu(torch.abs(z2) - self.Z_HALF)
            z_penalty = torch.mean(z1_penalty + z2_penalty)

            # Angular normalization penalty: encourage sin²+cos²≈1
            norm1_penalty = torch.abs(sin1**2 + cos1**2 - 1.0)
            norm2_penalty = torch.abs(sin2**2 + cos2**2 - 1.0)
            angular_penalty = torch.mean(norm1_penalty + norm2_penalty)

            return {
                'radius_penalty': radius_penalty,
                'z_penalty': z_penalty,
                'angular_penalty': angular_penalty,
                'total_penalty': radius_penalty + z_penalty + angular_penalty
            }

        def _endpoint_loss_components(self, pred_rsc: torch.Tensor, pred_z: torch.Tensor,
                                    gt_cyl: torch.Tensor, gt_z: torch.Tensor) -> dict:
            """
            Compute loss components for one endpoint assignment.

            Uses angular distance instead of separate sin/cos losses.

            pred_rsc / gt_cyl are (B, 3) [r, sin, cos]; pred_z / gt_z are (B, 1).
            Every value in the returned dict is a per-sample (B,) tensor.
            """
            # Extract components
            pred_r_mm = pred_rsc[:, 0:1]
            pred_sin = pred_rsc[:, 1:2]
            pred_cos = pred_rsc[:, 2:3]
            pred_z_mm = pred_z

            gt_r_mm = gt_cyl[:, 0:1]
            gt_sin = gt_cyl[:, 1:2]
            gt_cos = gt_cyl[:, 2:3]
            gt_z_mm = gt_z

            # Radius loss (smooth-L1 on mm values)
            r_loss_mm = self.smooth_l1(pred_r_mm, gt_r_mm).squeeze(-1)

            # Angular loss (in radians, converted to arc length in mm)
            angular_loss_rad = self._angular_distance_loss(
                pred_sin.squeeze(-1), pred_cos.squeeze(-1),
                gt_sin.squeeze(-1), gt_cos.squeeze(-1)
            )

            # Convert angular error to arc length error (mm) using ground truth radius
            gt_r_actual = gt_r_mm.squeeze(-1)
            angular_loss_mm = gt_r_actual * angular_loss_rad

            # Z loss (smooth-L1 on mm values)
            z_loss_mm = self.smooth_l1(pred_z_mm, gt_z_mm).squeeze(-1)

            # Weighted total loss
            total_loss = (self.rw * r_loss_mm +
                        self.aw * angular_loss_mm +
                        self.zw * z_loss_mm)

            return {
                'r_loss_mm': r_loss_mm,
                'angular_loss_rad': angular_loss_rad,
                'angular_loss_mm': angular_loss_mm,
                'z_loss_mm': z_loss_mm,
                'total_loss': total_loss
            }

        def forward(self, pred: torch.Tensor, target_xyz: torch.Tensor, return_components: bool = False):
            """
            Forward pass for corrected cylindrical loss.

            pred: (B,8) = [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]
            target_xyz: (B,6) = [x1,y1,z1, x2,y2,z2] (mm)

            Returns:
                ``(total_loss, cartesian_distance)`` when ``return_components``
                is False; otherwise a dict with 'total_loss',
                'cylindrical_loss', 'euclidean_distance', the per-component
                losses of the chosen pairing, and the constraint penalties as
                Python floats. Tensor entries are reduced according to
                ``self.reduction``.

            Raises:
                ValueError: if either input is not 2-D or has the wrong width.
            """
            if pred.dim() != 2 or target_xyz.dim() != 2:
                raise ValueError("Expected 2D tensors")
            if pred.size(1) != 8 or target_xyz.size(1) != 6:
                raise ValueError("Expected shapes: pred (B,8), target_xyz (B,6)")

            pred = pred.contiguous()
            target_xyz = target_xyz.contiguous()

            # Split target coordinates (kept as (B, 1) columns)
            x1, y1, z1 = target_xyz[:, 0:1], target_xyz[:, 1:2], target_xyz[:, 2:3]
            x2, y2, z2 = target_xyz[:, 3:4], target_xyz[:, 4:5], target_xyz[:, 5:6]

            # Convert ground truth to cylindrical
            gt_cyl1, gt_cyl2 = self._xy_to_cylindrical(x1, y1, x2, y2)

            # Extract prediction components
            pred_rsc1 = pred[:, 0:3]  # [r1_mm, sin1, cos1]
            pred_z1 = pred[:, 3:4]    # [z1_mm]
            pred_rsc2 = pred[:, 4:7]  # [r2_mm, sin2, cos2]
            pred_z2 = pred[:, 7:8]    # [z2_mm]

            # Compute constraint penalties
            penalties = self._compute_constraint_penalties(pred)

            # Try both endpoint assignments and pick the better one
            # Assignment A: pred1->gt1, pred2->gt2
            comp_A1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl1, z1)
            comp_A2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl2, z2)
            loss_A = comp_A1['total_loss'] + comp_A2['total_loss']

            # Assignment B: pred1->gt2, pred2->gt1
            comp_B1 = self._endpoint_loss_components(pred_rsc1, pred_z1, gt_cyl2, z2)
            comp_B2 = self._endpoint_loss_components(pred_rsc2, pred_z2, gt_cyl1, z1)
            loss_B = comp_B1['total_loss'] + comp_B2['total_loss']

            # Choose better assignment
            use_assignment_A = loss_A <= loss_B
            cylindrical_loss = torch.minimum(loss_A, loss_B)

            # Add constraint penalties (batch-mean scalar, broadcast over samples)
            total_loss = cylindrical_loss + self.cw * penalties['total_penalty']

            # Compute Cartesian distance for reference, with its own best pairing
            pred_cart1 = self._cylindrical_to_cartesian(pred_rsc1, pred_z1)
            pred_cart2 = self._cylindrical_to_cartesian(pred_rsc2, pred_z2)
            gt_cart1 = torch.cat([x1, y1, z1], dim=-1)
            gt_cart2 = torch.cat([x2, y2, z2], dim=-1)

            cart_dist_A = (torch.norm(pred_cart1 - gt_cart1, dim=1) +
                        torch.norm(pred_cart2 - gt_cart2, dim=1))
            cart_dist_B = (torch.norm(pred_cart1 - gt_cart2, dim=1) +
                        torch.norm(pred_cart2 - gt_cart1, dim=1))
            cartesian_distance = torch.minimum(cart_dist_A, cart_dist_B)

            # Apply reduction ("sum" was previously accepted but silently ignored)
            total_loss = self._reduce(total_loss)
            cylindrical_loss = self._reduce(cylindrical_loss)
            cartesian_distance = self._reduce(cartesian_distance)

            if return_components:
                # Aggregate components from chosen assignment
                components = {}
                component_keys = ['r_loss_mm', 'angular_loss_rad', 'angular_loss_mm', 'z_loss_mm']

                for key in component_keys:
                    comp_choice = torch.where(use_assignment_A,
                                            comp_A1[key] + comp_A2[key],
                                            comp_B1[key] + comp_B2[key])
                    components[key] = self._reduce(comp_choice)

                # Add penalty components (exported as plain floats for logging)
                penalty_components = {k: v.item() if hasattr(v, 'item') else v
                                    for k, v in penalties.items()}

                return {
                    'total_loss': total_loss,
                    'cylindrical_loss': cylindrical_loss,
                    'euclidean_distance': cartesian_distance,
                    **components,
                    **penalty_components
                }
            else:
                return total_loss, cartesian_distance


    # Rebind `criterion` to the corrected loss; this is the instance the
    # training loop below optimises.
    criterion = CorrectedUnnormalizedFPoICylLoss(
        radius_weight=1.0,
        ang_weight=2.0,         # angular term weighted twice as heavily as r and z
        z_weight=1.0,
        constraint_weight=0.1,  # scale of the soft constraint penalties
        reduction="mean",
        r_inner=R_INNER_MM,
        r_outer=R_OUTER_MM,
        z_half=Z_HALF_MM,
    )

    # Banner: one print call, one line per argument, trailing blank line.
    print(
        "🎯 TRAINING SETUP:",
        "  Coordinate system: Cylindrical (unnormalized)",
        "  Model outputs: [r_mm, sin, cos, z_mm, r_mm, sin, cos, z_mm]",
        "  Loss: Direct physical coordinate loss + soft constraints",
        "  Constraints: Radius bounds, Z bounds, sin²+cos²=1",
        "",
        sep="\n",
    )

    # `model` is expected to exist already (defined in an earlier cell).
    model = model.float().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Per-epoch history
    train_losses, val_losses = [], []
    train_distances, val_distances = [], []
    train_radius_penalties, train_z_penalties, train_angular_penalties = [], [], []

    # Batch counts used as the denominators of the per-epoch averages.
    num_train_batches = len(train_loader)
    num_val_batches = len(test_loader)

    for epoch in range(NUM_EPOCHS):
        # ---------------- optimisation pass ----------------
        model.train()
        train_sums = {
            'total_loss': 0.0,
            'euclidean_distance': 0.0,
            'radius_penalty': 0.0,
            'z_penalty': 0.0,
            'angular_penalty': 0.0,
        }

        for inputs, labels in train_loader:
            inputs, labels = inputs.float().to(device), labels.float().to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss_dict = criterion(outputs, labels, return_components=True)
            loss_dict['total_loss'].backward()
            optimizer.step()

            # Loss and distance are tensors; the penalties are accumulated
            # exactly as the criterion returns them.
            train_sums['total_loss'] += loss_dict['total_loss'].item()
            train_sums['euclidean_distance'] += loss_dict['euclidean_distance'].item()
            for penalty_key in ('radius_penalty', 'z_penalty', 'angular_penalty'):
                train_sums[penalty_key] += loss_dict[penalty_key]

        # Per-batch means for the epoch
        avg_train_loss = train_sums['total_loss'] / num_train_batches
        avg_train_distance = train_sums['euclidean_distance'] / num_train_batches
        avg_r_penalty = train_sums['radius_penalty'] / num_train_batches
        avg_z_penalty = train_sums['z_penalty'] / num_train_batches
        avg_ang_penalty = train_sums['angular_penalty'] / num_train_batches

        train_losses.append(avg_train_loss)
        train_distances.append(avg_train_distance)
        train_radius_penalties.append(avg_r_penalty)
        train_z_penalties.append(avg_z_penalty)
        train_angular_penalties.append(avg_ang_penalty)

        # ---------------- evaluation pass (test_loader) ----------------
        model.eval()
        val_sums = {'total_loss': 0.0, 'euclidean_distance': 0.0}

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.float().to(device), labels.float().to(device)
                outputs = model(inputs)
                loss_dict = criterion(outputs, labels, return_components=True)

                val_sums['total_loss'] += loss_dict['total_loss'].item()
                val_sums['euclidean_distance'] += loss_dict['euclidean_distance'].item()

        avg_test_loss = val_sums['total_loss'] / num_val_batches
        avg_test_distance = val_sums['euclidean_distance'] / num_val_batches
        val_losses.append(avg_test_loss)
        val_distances.append(avg_test_distance)

        # ---------------- epoch report ----------------
        print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}]")
        print(f"  TRAINING   - Total: {avg_train_loss:.4f} | 3D: {avg_train_distance:.1f}mm | R: {avg_r_penalty:.3f} | Z: {avg_z_penalty:.3f} | Ang: {avg_ang_penalty:.3f}")
        print(f"  VALIDATION - Total: {avg_test_loss:.4f} | 3D: {avg_test_distance:.1f}mm")
        print()


Using device: cuda
🔍 UNNORMALIZED CYLINDRICAL COORDINATE PREDICTION
Testing cylindrical coordinates with raw model outputs...
Model outputs: [r1_mm, sin1, cos1, z1_mm, r2_mm, sin2, cos2, z2_mm]
Direct learning of physical coordinates (no tanh/sigmoid)

🎯 TRAINING SETUP:
  Coordinate system: Cylindrical (unnormalized)
  Model outputs: [r_mm, sin, cos, z_mm, r_mm, sin, cos, z_mm]
  Loss: Direct physical coordinate loss + soft constraints
  Constraints: Radius bounds, Z bounds, sin²+cos²=1

Epoch [ 1/40]
  TRAINING   - Total: 2046.8021 | 3D: 555.5mm | R: 471.466 | Z: 0.000 | Ang: 2.554
  VALIDATION - Total: 1871.8144 | 3D: 554.2mm

Epoch [ 2/40]
  TRAINING   - Total: 1774.0863 | 3D: 555.3mm | R: 470.483 | Z: 0.000 | Ang: 5.239
  VALIDATION - Total: 1721.5726 | 3D: 553.8mm

Epoch [ 3/40]
  TRAINING   - Total: 1515.3095 | 3D: 554.4mm | R: 469.366 | Z: 0.000 | Ang: 8.025
  VALIDATION - Total: 1782.1525 | 3D: 553.2mm

Epoch [ 4/40]
  TRAINING   - Total: 1380.9817 | 3D: 553.2mm | R: 467.804 | 