In [39]:
import numpy as np

# Load the data
inner_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\inner_training_data_test.csv', delimiter=',')
outer_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\outer_training_data_test.csv', delimiter=',')

print("Inner shape:", inner_data.shape)
print("Outer shape:", outer_data.shape)

# Separate features and labels
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

print("Inner labels shape:", inner_labels.shape)
print("Outer labels shape:", outer_labels.shape)

# Validate that labels are identical
labels_equal = np.array_equal(inner_labels, outer_labels)

print("Labels are identical:" if labels_equal else "⚠️ Labels are NOT identical!")


Inner shape: (5000, 8495)
Outer shape: (5000, 10135)
Inner labels shape: (5000, 6)
Outer labels shape: (5000, 6)
⚠️ Labels are NOT identical!


In [40]:
import numpy as np

# Compare labels row-wise
row_equal = np.all(inner_labels == outer_labels, axis=1)

# Get indices of mismatched rows
mismatched_indices = np.where(~row_equal)[0]

print(f"Number of mismatched rows: {len(mismatched_indices)}")
print("First 10 indices with mismatched labels:", mismatched_indices[:10])

# If you want to inspect the actual labels for a few mismatches:
for idx in mismatched_indices[:5]:  # show first 5 mismatches
    print(f"\nRow {idx}:")
    print("Inner:", inner_labels[idx])
    print("Outer:", outer_labels[idx])


Number of mismatched rows: 4844
First 10 indices with mismatched labels: [13 16 30 33 37 40 42 45 47 48]

Row 13:
Inner: [ 233.3549     42.308315 -140.34695     0.          0.          0.      ]
Outer: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]

Row 16:
Inner: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]
Outer: [ 89.508736 252.85379   84.34045    0.         0.         0.      ]

Row 30:
Inner: [   0.          0.          0.       -231.473     -84.782135 -147.94    ]
Outer: [   0.          0.          0.       -273.9889    -22.72473    53.701527]

Row 33:
Inner: [   0.          0.          0.       -273.9889    -22.72473    53.701527]
Outer: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]

Row 37:
Inner: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]
Outer: [-243.86166  -77.32867 -135.17099    0.         0.         0.     ]


In [41]:
import numpy as np

# Assuming data is already loaded:
# inner_data, outer_data

# Split features/labels
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

# The "7th last column" is -7
inner_keys = inner_data[:, -7]
outer_keys = outer_data[:, -7]

# Storage for filtered results
filtered_inner_features = []
filtered_outer_features = []
filtered_inner_labels = []
filtered_outer_labels = []

i, j = 0, 0
while i < len(inner_keys) and j < len(outer_keys):
    if inner_keys[i] == outer_keys[j]:
        # Keys match → keep both
        filtered_inner_features.append(inner_features[i])
        filtered_outer_features.append(outer_features[j])
        filtered_inner_labels.append(inner_labels[i])
        filtered_outer_labels.append(outer_labels[j])
        i += 1
        j += 1
    elif inner_keys[i] < outer_keys[j]:
        # Advance inner pointer
        i += 1
    else:
        # Advance outer pointer
        j += 1

# Convert lists back into numpy arrays
filtered_inner_features = np.array(filtered_inner_features)
filtered_outer_features = np.array(filtered_outer_features)
filtered_inner_labels = np.array(filtered_inner_labels)
filtered_outer_labels = np.array(filtered_outer_labels)

print("Filtered inner features shape:", filtered_inner_features.shape)
print("Filtered outer features shape:", filtered_outer_features.shape)
print("Filtered inner labels shape:", filtered_inner_labels.shape)
print("Filtered outer labels shape:", filtered_outer_labels.shape)

# Final validation: check labels after alignment
labels_equal = np.array_equal(filtered_inner_labels, filtered_outer_labels)
print("Labels are identical after filtering?", labels_equal)

if labels_equal is True:
    labels = filtered_inner_labels


Filtered inner features shape: (4316, 8489)
Filtered outer features shape: (4316, 10129)
Filtered inner labels shape: (4316, 6)
Filtered outer labels shape: (4316, 6)
Labels are identical after filtering? True


In [42]:
# Remove first and last column from feature arrays
filtered_inner_features = filtered_inner_features[:, 1:-1]
filtered_outer_features = filtered_outer_features[:, 1:-1]

print("Inner features shape after dropping cols:", filtered_inner_features.shape)
print("Outer features shape after dropping cols:", filtered_outer_features.shape)


Inner features shape after dropping cols: (4316, 8487)
Outer features shape after dropping cols: (4316, 10127)


In [43]:
import numpy as np
from collections import Counter

run_lengths = []
current_len = 1

for i in range(1, len(labels)):
    if np.array_equal(labels[i], labels[i-1]):
        current_len += 1
    else:
        run_lengths.append(current_len)
        current_len = 1

# Append last run
run_lengths.append(current_len)

# Count occurrences of each run length
run_length_counts = Counter(run_lengths)

print("Contiguous label run stats:")
for length, count in sorted(run_length_counts.items()):
    print(f"Length {length}: {count} times")

print("\nTotal runs found:", len(run_lengths))


Contiguous label run stats:
Length 1: 36 times
Length 2: 318 times
Length 3: 833 times
Length 4: 144 times
Length 5: 53 times
Length 6: 33 times
Length 7: 9 times
Length 8: 2 times
Length 9: 3 times

Total runs found: 1431


In [44]:
import numpy as np

def make_contiguous_examples(inner_features, outer_features, labels, run_length=3):
    """
    Create tuples of (inner_window, outer_window, label) for contiguous sequences of 'run_length'.
    Labels must remain identical across the window.
    """
    examples = []
    n = len(labels)
    start = 0
    
    while start < n:
        # Find length of current contiguous block
        curr_label = labels[start]
        end = start + 1
        while end < n and np.array_equal(labels[end], curr_label):
            end += 1
        block_len = end - start
        
        # If block is long enough, generate sliding windows of size run_length
        if block_len >= run_length:
            for i in range(start, end - run_length + 1):
                inner_window = inner_features[i:i+run_length]
                outer_window = outer_features[i:i+run_length]
                # label is same across whole block, so we just take one
                examples.append((inner_window, outer_window, curr_label))
        
        # Move to next block
        start = end
    
    return examples

data = make_contiguous_examples(inner_features=filtered_inner_features,
                                outer_features=filtered_outer_features,
                                labels=labels,
                                run_length=3)

print(f"len(data) = {len(data)} examples made")
print(f"len(data[0]) = {len(data[0])} which are the inner, outer, and labels")
print(f"Inner shape: data[0][0].shape = {data[0][0].shape}")
print(f"Outer shape: data[0][1].shape = {data[0][1].shape}")
print(f"Label shape: data[0][2].shape = {data[0][2].shape}")

len(data) = 1490 examples made
len(data[0]) = 3 which are the inner, outer, and labels
Inner shape: data[0][0].shape = (3, 8487)
Outer shape: data[0][1].shape = (3, 10127)
Label shape: data[0][2].shape = (6,)


In [45]:
def reshape_examples(data):
    reshaped = []
    for inner, outer, label in data:
        # reshape each contiguous example
        inner_reshaped = inner.reshape(inner.shape[0], 207, 41)
        outer_reshaped = outer.reshape(outer.shape[0], 247, 41)
        reshaped.append((inner_reshaped, outer_reshaped, label))
    return reshaped

reshaped_data = reshape_examples(data)

print("Original:", data[0][0].shape, data[0][0].shape, data[0][0].shape)
print("Original:", data[0][1].shape, data[0][1].shape, data[0][1].shape)
print("\nReshaped:", reshaped_data[0][0].shape, reshaped_data[0][0].shape, reshaped_data[0][0].shape)
print("Reshaped:", reshaped_data[0][1].shape, reshaped_data[0][1].shape, reshaped_data[0][1].shape)


Original: (3, 8487) (3, 8487) (3, 8487)
Original: (3, 10127) (3, 10127) (3, 10127)

Reshaped: (3, 207, 41) (3, 207, 41) (3, 207, 41)
Reshaped: (3, 247, 41) (3, 247, 41) (3, 247, 41)


In [46]:
import numpy as np
from scipy.ndimage import zoom

def resize_features(arr, target_len):
    """
    Resizes arr from shape (frames, length, channels) 
    to (frames, target_len, channels).
    Uses linear interpolation along the 'length' axis.
    """
    frames, length, channels = arr.shape
    zoom_factor = target_len / length
    # zoom only along axis=1 (the "length" dimension), keep others unchanged
    resized = zoom(arr, (1, zoom_factor, 1), order=1)
    return resized

def resize_dataset(data, target_len=207):
    """
    Applies resize_features to the entire dataset of (inner, outer, label) triples.
    - Inner is kept as-is (already target_len).
    - Outer is resized to target_len.
    """
    resized_data = []
    for inner, outer, label in data:
        inner_resized = inner  # already target_len
        outer_resized = resize_features(outer, target_len)
        resized_data.append((inner_resized, outer_resized, label))
    return resized_data

# Apply to your dataset
resized_data = resize_dataset(reshaped_data, target_len=207)

len(resized_data)
len(resized_data[0])
resized_data[0][0].shape
resized_data[0][1].shape
resized_data[0][2].shape


(6,)

In [47]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_inner_outer_slider(inner, outer):
    """
    Visualize inner and outer arrays side-by-side with a slider to control timestep.
    inner: np array of shape (3, 207, 41)
    outer: np array of shape (3, 207, 41)
    """
    frames = []
    for t in range(inner.shape[0]):
        frames.append(go.Frame(
            data=[
                go.Heatmap(z=inner[t], colorscale='Cividis', colorbar=dict(title='Inner'), zmin=inner.min(), zmax=inner.max()),
                go.Heatmap(z=outer[t], colorscale='Cividis', colorbar=dict(title='Outer'), zmin=outer.min(), zmax=outer.max())
            ],
            name=str(t)
        ))

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Inner", "Outer"),
                        horizontal_spacing=0.1,
                        specs=[[{"type": "heatmap"}, {"type": "heatmap"}]])

    fig.add_trace(frames[0].data[1], row=1, col=1)  # inner heatmap (first trace of frame)
    fig.add_trace(frames[0].data[1], row=1, col=2)  # outer heatmap (second trace of frame)

    # Slider steps
    steps = []
    for i in range(len(frames)):
        steps.append(dict(
            method="animate",
            args=[[str(i)], 
                  dict(mode="immediate", frame=dict(duration=500, redraw=True), transition=dict(duration=0))],
            label=str(i)
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        title="Inner and Outer Feature Maps Over Time",
        sliders=sliders,
        height=600,
        width=900
    )

    fig.frames = frames

    fig.show()

example = 15

plot_inner_outer_slider(resized_data[example][0],
                        resized_data[example][1])

In [48]:
import torch

# Prepare lists for inner and outer tensors for all elements
inner_tensors = []
outer_tensors = []
labels_tensors = []

for data in resized_data:
    inner_tensors.append(torch.tensor(data[0]))  # (timesteps, height, width)
    outer_tensors.append(torch.tensor(data[1]))  # (timesteps, height, width)
    labels_tensors.append(torch.tensor(data[2]))  # (label)

# Stack each list into a tensor along dimension 0 (num_tensors)
inner_tensor = torch.stack(inner_tensors, dim=0)  # (num_tensors, timesteps, height, width)
outer_tensor = torch.stack(outer_tensors, dim=0)  # (num_tensors, timesteps, height, width)
labels_tensor = torch.stack(labels_tensors, dim=0)  # (num_tensors, timesteps, height, width)

# Stack inner and outer along new dim=2 (channels)
combined_tensor = torch.stack((inner_tensor, outer_tensor), dim=1)  # (num_tensors, timesteps, 2, height, width)

print(combined_tensor.shape)
print(labels_tensor.shape)


torch.Size([1490, 2, 3, 207, 41])
torch.Size([1490, 6])


In [49]:
import torch

# Set your split ratio
train_ratio = 0.8
num_samples = combined_tensor.shape[0]
train_size = int(num_samples * train_ratio)
test_size = num_samples - train_size

# Optionally shuffle indices
indices = torch.randperm(num_samples)   # shuffle all sample indices
train_idx = indices[:train_size]
test_idx = indices[train_size:]

# Index into tensors
train_data = combined_tensor[train_idx]
train_labels = labels_tensor[train_idx]

test_data = combined_tensor[test_idx]
test_labels = labels_tensor[test_idx]

In [50]:
print(f"Train tensor: {train_data.shape}, Labels:{train_labels.shape}")
print(f"Test tensor: {test_data.shape}, Labels:{test_labels.shape}")

Train tensor: torch.Size([1192, 2, 3, 207, 41]), Labels:torch.Size([1192, 6])
Test tensor: torch.Size([298, 2, 3, 207, 41]), Labels:torch.Size([298, 6])


In [51]:
import torch

def filter_examples_with_zero_in_label(data_tensor, labels_tensor):
    has_zeros = (labels_tensor == 0)

    has_any_zero_in_row = torch.any(has_zeros, dim=1)
    indices_to_keep = ~has_any_zero_in_row
    filtered_data = data_tensor[indices_to_keep]
    filtered_labels = labels_tensor[indices_to_keep]

    return filtered_data, filtered_labels

# # --- Apply the Function ---
# train_data, train_labels = filter_examples_with_zero_in_label(train_data, train_labels)
# test_data, test_labels = filter_examples_with_zero_in_label(test_data, test_labels)

print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)
print(test_labels.shape)


torch.Size([1192, 2, 3, 207, 41])
torch.Size([1192, 6])
torch.Size([298, 2, 3, 207, 41])
torch.Size([298, 6])


In [None]:
# takes attention_petnet.py and places the global attention layer after layer 1
# for higher dimensional G.A. Employs point-wise seperation convolutions over
# full fat classical convolutions for an 8-10x reduction in paramater count. 
# Also introduces a residual connection after the global attention, and layer norm
# for the global attention (Apparently batch norm works better for Convolutions, 
# and layer norm works better for Transformers...we'll see about that). 
# Uses 3 full connected layers with dropout for a more regression head.   
# Modified to use windowed attention with 4x4 non-overlapping windows.

import torch
import torch.nn as nn

###############################################################################
# Windowed Global Attention Module
###############################################################################
class WindowedGlobalAttention3D(nn.Module):
    """
    Windowed global attention module for 3D feature maps using 4x4 non-overlapping windows.
    Applies attention within each window, significantly reducing computational complexity.
    """
    
    def __init__(self, in_channels=64, embed_dim=128, output_dim=64, num_heads=2, window_size=4):
        super(WindowedGlobalAttention3D, self).__init__()
        
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.output_dim = output_dim
        self.window_size = window_size
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        # Map from input channels to embedding dimension
        self.channel_proj = nn.Linear(in_channels, embed_dim)
        
        # Manual attention projections (instead of nn.MultiheadAttention)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # Map from embedding dimension to output channels
        self.output_proj = nn.Linear(embed_dim, output_dim)
        
        # Learnable positional encoding for window positions
        self.pos_encoding = None
        self.scale = self.head_dim ** -0.5
        self.layernorm = nn.LayerNorm(self.embed_dim)

        
    def _init_positional_encoding(self, window_seq_len):
        """Initialize learnable positional encoding for window sequence length"""
        self.pos_encoding = nn.Parameter(torch.randn(1, window_seq_len, self.embed_dim))
        nn.init.normal_(self.pos_encoding, std=0.02)
        
    def _create_windows(self, x):
        """
        Split the spatial dimensions (H, W) into non-overlapping 4x4 windows.
        Args:
            x: input tensor of shape (batch, channels, D, H, W)
        Returns:
            windows: tensor of shape (batch, num_windows, window_seq_len, channels)
            num_windows_h, num_windows_w: number of windows in H and W dimensions
        """
        batch_size, channels, D, H, W = x.shape
        
        # Calculate number of windows (pad if necessary)
        num_windows_h = (H + self.window_size - 1) // self.window_size
        num_windows_w = (W + self.window_size - 1) // self.window_size
        
        # Pad H and W to be divisible by window_size
        pad_h = num_windows_h * self.window_size - H
        pad_w = num_windows_w * self.window_size - W
        
        if pad_h > 0 or pad_w > 0:
            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)
            H, W = H + pad_h, W + pad_w
        
        # Reshape to create windows: (batch, channels, D, num_windows_h, window_size, num_windows_w, window_size)
        x = x.view(batch_size, channels, D, num_windows_h, self.window_size, num_windows_w, self.window_size)
        
        # Permute and reshape to group windows: (batch, D, num_windows_h, num_windows_w, channels, window_size, window_size)
        x = x.permute(0, 2, 3, 5, 1, 4, 6).contiguous()
        
        # Flatten spatial dimensions within each window and combine D with num_windows
        # Final shape: (batch, D * num_windows_h * num_windows_w, window_size * window_size, channels)
        num_windows = D * num_windows_h * num_windows_w
        window_seq_len = self.window_size * self.window_size
        windows = x.view(batch_size, num_windows, window_seq_len, channels)
        
        return windows, num_windows_h, num_windows_w, (H, W)
    
    def _restore_from_windows(self, windows, batch_size, D, num_windows_h, num_windows_w, padded_size):
        """
        Restore the original spatial structure from windowed representation.
        Args:
            windows: tensor of shape (batch, num_windows, window_seq_len, output_dim)
            Original spatial dimensions and padding info
        Returns:
            x: tensor of shape (batch, output_dim, D, H, W)
        """
        H_padded, W_padded = padded_size
        window_seq_len = self.window_size * self.window_size
        
        # Reshape back to spatial windows: (batch, D, num_windows_h, num_windows_w, output_dim, window_size, window_size)
        windows = windows.view(batch_size, D, num_windows_h, num_windows_w, self.output_dim, self.window_size, self.window_size)
        
        # Permute back: (batch, output_dim, D, num_windows_h, window_size, num_windows_w, window_size)
        windows = windows.permute(0, 4, 1, 2, 5, 3, 6).contiguous()
        
        # Reshape to recover spatial dimensions: (batch, output_dim, D, H_padded, W_padded)
        x = windows.view(batch_size, self.output_dim, D, H_padded, W_padded)
        
        # Remove padding if it was added
        original_H = H_padded - ((num_windows_h * self.window_size) - H_padded)
        original_W = W_padded - ((num_windows_w * self.window_size) - W_padded)
        
        # Note: We need to calculate the original size differently
        # Let's just remove any padding that was added
        if H_padded > original_H or W_padded > original_W:
            # Calculate original dimensions
            orig_H = H_padded - (num_windows_h * self.window_size - H_padded)
            orig_W = W_padded - (num_windows_w * self.window_size - W_padded)
            # This calculation is complex, let's store original dims in forward pass
            pass
        
        return x
        
    def forward(self, x):
        batch_size, channels, D, H, W = x.shape
        orig_H, orig_W = H, W  # Store original dimensions
        
        # Create windows
        windows, num_windows_h, num_windows_w, padded_size = self._create_windows(x)
        num_windows, window_seq_len, _ = windows.shape[1], windows.shape[2], windows.shape[3]
        
        # Initialize positional encoding if needed
        if self.pos_encoding is None or self.pos_encoding.shape[1] != window_seq_len:
            self._init_positional_encoding(window_seq_len)
            self.pos_encoding = self.pos_encoding.to(x.device)

        # Process each window independently
        # Reshape to process all windows in batch: (batch * num_windows, window_seq_len, channels)
        windows_flat = windows.view(batch_size * num_windows, window_seq_len, channels)
        
        # Project to embedding dimension
        x_proj = self.channel_proj(windows_flat)  # (batch * num_windows, window_seq_len, embed_dim)
        residual = x_proj
        
        # Add positional encoding
        x_proj = x_proj + self.pos_encoding
        
        # Apply attention within each window
        q = self.q_proj(x_proj)
        k = self.k_proj(x_proj)
        v = self.v_proj(x_proj)
        
        # Reshape for multi-head attention
        q = q.view(batch_size * num_windows, window_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size * num_windows, window_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size * num_windows, window_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention computation
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        
        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size * num_windows, window_seq_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        
        # Residual connection and layer norm
        attn_output = attn_output + residual
        attn_output = self.layernorm(attn_output)
        
        # Project to output dimension
        output = self.output_proj(attn_output)  # (batch * num_windows, window_seq_len, output_dim)
        
        # Reshape back to window format
        output = output.view(batch_size, num_windows, window_seq_len, self.output_dim)
        
        # Restore spatial structure
        output = self._restore_from_windows(output, batch_size, D, num_windows_h, num_windows_w, padded_size)
        
        # Remove padding to match original input size
        if output.shape[-2] != orig_H or output.shape[-1] != orig_W:
            output = output[:, :, :, :orig_H, :orig_W]
        
        return output


###############################################################################
# 3D Residual Block
###############################################################################
class ResidualBlock3D(nn.Module):
    """
    Residual block with 3D depthwise + pointwise convolution.
    Uses circular padding on width, constant elsewhere.
    Includes post skip-connection processing layer.
    """
    def __init__(self, in_channels, out_channels, stride=(1, 2, 2)):
        super(ResidualBlock3D, self).__init__()

        # Depthwise 3D convolution (groups=in_channels)
        self.dw_conv1 = nn.Conv3d(
            in_channels, in_channels,
            kernel_size=3, stride=stride, padding=0,
            groups=in_channels, bias=False
        )
        self.pw_conv1 = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm3d(out_channels)

        # Depthwise 3D convolution (stride=1)
        self.dw_conv2 = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=0,
            groups=out_channels, bias=False
        )
        self.pw_conv2 = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm3d(out_channels)

        # Post skip-connection processing layer (depthwise-pointwise)
        self.dw_conv_post = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=0,
            groups=out_channels, bias=False
        )
        self.pw_conv_post = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn_post = nn.BatchNorm3d(out_channels)

        self.activation = nn.GELU()

        # Shortcut for matching dimensions
        self.shortcut = None
        if stride != (1, 1, 1) or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def _apply_circular_padding(self, x):
        x = torch.nn.functional.pad(x, (1, 1, 0, 0, 0, 0), mode='circular')
        x = torch.nn.functional.pad(x, (0, 0, 1, 1, 1, 1), mode='constant', value=0)
        return x

    def forward(self, x):
        # Main path
        out = self._apply_circular_padding(x)
        out = self.dw_conv1(out)
        out = self.pw_conv1(out)
        out = self.bn1(out)
        out = self.activation(out)

        out = self._apply_circular_padding(out)
        out = self.dw_conv2(out)
        out = self.pw_conv2(out)
        out = self.bn2(out)

        # Shortcut
        if self.shortcut is not None:
            x = self.shortcut(x)

        out += x

        out = self._apply_circular_padding(out)
        out = self.dw_conv_post(out)
        out = self.pw_conv_post(out)
        out = self.bn_post(out)
        out = self.activation(out)

        return out



###############################################################################
# PetNetImproved3D with Windowed Global Attention
###############################################################################
class PetNetImproved3D(nn.Module):
    def __init__(self, num_classes=6):
        print("Loading PetnetImproved3D Model with Windowed Global Attention...")
        super(PetNetImproved3D, self).__init__()

        self.conv_in = nn.Sequential(
            nn.Conv3d(2, 2, kernel_size=3, stride=1, padding=1, groups=2, bias=False),    # depthwise
            nn.Conv3d(2, 16, kernel_size=1, stride=1, padding=0, bias=False)              # pointwise
        )
        self.bn_in = nn.BatchNorm3d(16)
        self.activation = nn.GELU()

        self.layer1 = ResidualBlock3D(16, 32, stride=(1, 2, 2))      # downsample H,W
        
        # Windowed global attention after layer1 (32 channels) with 4x4 windows
        self.global_attention = WindowedGlobalAttention3D(
            in_channels=32,
            embed_dim=128,
            output_dim=32,
            num_heads=8,
            window_size=4
        )

        self.layer2 = ResidualBlock3D(32, 64, stride=(1, 2, 2))
        self.layer3 = ResidualBlock3D(64, 128, stride=(1, 2, 2))
        self.layer4 = ResidualBlock3D(128, 256, stride=(1, 2, 2))
        self.layer5 = ResidualBlock3D(256, 256, stride=(1, 2, 2))

        self.dropout = nn.Dropout(0.3)

        fc_in_features = self._compute_fc_input_size()
        self.fc1 = nn.Linear(fc_in_features, 1024)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1024, 256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, num_classes)  # final regressive output

        self._initialize_weights()

    def _compute_fc_input_size(self, C=2, T=3, H=207, W=41):
        with torch.no_grad():
            dummy = torch.zeros(1, C, T, H, W)
            out = self.conv_in(dummy)
            out = self.bn_in(out)
            out = self.activation(out)
            out = self.layer1(out)
            out = self.global_attention(out) # Windowed global attention after layer1
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.layer5(out)
            out = out.view(out.size(0), -1)
            return out.shape[1]


    def forward(self, x, debug=False): 
        if debug: print(f"{x.shape} Input shape")
        x = self.conv_in(x)
        if debug: print(f"{x.shape} After conv_in")
        x = self.bn_in(x)
        if debug: print(f"{x.shape} After bn_input")
        x = self.activation(x)
        if debug: print(f"{x.shape} After activation")

        x = self.layer1(x)
        if debug: print(f"{x.shape} After layer 1")

        # Windowed global attention after layer 1
        x = self.global_attention(x)
        if debug: print(f"{x.shape} After windowed global attention (after layer 1)")

        x = self.layer2(x)
        if debug: print(f"{x.shape} After layer 2")
        x = self.layer3(x)
        if debug: print(f"{x.shape} After layer 3")
        x = self.layer4(x)
        if debug: print(f"{x.shape} After layer 4")
        x = self.layer5(x)
        if debug: print(f"{x.shape} After layer 5")

        x = x.view(x.size(0), -1)  # Flatten all features (B, all_channels)
        if debug: print(f"{x.shape} After flattening all channels/voxels")
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        if debug: print(f"{x.shape} After fc layer 1, activation and dropout")
        x = self.fc2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        if debug: print(f"{x.shape} After fc layer 2 activation and dropout")
        x = self.fc3(x) 
        if debug: print(f"{x.shape} After fc layer 3 (output)")

        return x


    def _initialize_weights(self):
        """
        Kaiming (He) Initialization for Conv3d and Linear layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)




if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    B = 1
    # B = 128

    C = 2
    T = 3 

    H = 207
    # H = 496

    W = 41
    # W = 84
    
    CLASSES = 6

    # Model instantiation
    model = PetNetImproved3D(num_classes=CLASSES).to(device)
    
    # Print parameter count
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")
    
    # Test a dummy pass
    dummy_input = torch.randn(B, C, T, H, W).to(device)
    dummy_target = torch.randn(B, CLASSES).to(device)
    model.forward(dummy_input, debug=True)

    # Dummy training loop to observe loss reduction
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.MSELoss()

    import time

    total_time = 0.0
    epochs = 5

    for epoch in range(epochs):
        optimizer.zero_grad()
        start_time = time.time()
        output = model(dummy_input)
        end_time = time.time()
        forward_time = end_time - start_time
        total_time += forward_time

        loss = criterion(output, dummy_target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Forward Time: {forward_time:.6f}s")

    average_time = total_time / epochs
    print(f"Average Forward Pass Time: {average_time:.6f}s")



if __name__ == "__main__":
    import time

    def add_positional_features(x, normalize=True):
        """
        Add height and width positional features to input tensor.
        
        Args:
            x: input tensor of shape (batch, channels, D, H, W)
            normalize: whether to normalize positional features to [0, 1]
        
        Returns:
            tensor of shape (batch, channels + 2, D, H, W) with positional features
        """
        batch_size, channels, D, H, W = x.shape
        device = x.device
        
        # Create height indices (0 to H-1) for each position
        height_indices = torch.arange(H, dtype=torch.float32, device=device)
        height_indices = height_indices.view(1, 1, 1, H, 1).expand(batch_size, 1, D, H, W)
        
        # Create width indices (0 to W-1) for each position  
        width_indices = torch.arange(W, dtype=torch.float32, device=device)
        width_indices = width_indices.view(1, 1, 1, 1, W).expand(batch_size, 1, D, H, W)
        
        if normalize:
            # Normalize to [0, 1] range
            if H > 1:
                height_indices = height_indices / (H - 1)
            if W > 1:
                width_indices = width_indices / (W - 1)
        
        # Concatenate original features with positional features
        x_with_pos = torch.cat([x, height_indices, width_indices], dim=1)
        
        return x_with_pos

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    B = 1
    C = 2
    T = 3 
    
    H = 207
    # H = 496
    
    W = 41
    # W = 84

    # CLASSES = 8
    CLASSES = 6

    POSITIONAL_FEATURES = True
    if POSITIONAL_FEATURES:
        C += 2
    
    X_WEIGHT = 10
    Y_WEIGHT = 10
    Z_WEIGHT = 1

    SINGLES_FILTERED = False # Don't filter the singles; more data is, for this script, worth it

    if SINGLES_FILTERED:
        LOSS_PERMUTATION = False
    else:
        LOSS_PERMUTATION = True # Helps when including singles

    TANH_ACTIVATION = True # Tanh activation is usually good; its effectively normalising the labels
    
    NUM_EPOCHS = 40
    
    R_INNER_MM = 235.422
    R_OUTER_MM = 278.296
    Z_HALF_MM = 148.0

    model = PetNetImproved3D(num_classes=CLASSES, ).to(device)

    # model = PetNetCyl3D(in_channels=C, base_channels=6).to(device)
    # model = PetNetCyl3DCompact(in_channels=C, base_channels=6, out_features=CLASSES).to(device)
    # model = PetNetCyl3DFullFeatures(base_channels=6, out_features=CLASSES, input_shape=(C, T, H, W)).to(device)
    # model = PetNetCyl3DAttentionFull(base_channels=6, out_features=CLASSES, input_shape=(C, T, H, W)).to(device)
    # model = PetNetCyl3DDepthwise(in_channels=C, base_channels=6, out_features=CLASSES).to(device)
    # model = PetNetCyl3DWindowedAttention(
    #     in_channels=C, 
    #     base_channels=6, 
    #     out_features=6,
    #     window_size=8,
    #     attn_heads=2
    # )    

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")
    
    dummy_input = torch.randn(B, C, T, H, W).to(device)
    dummy_target = torch.randn(B, CLASSES).to(device)

    if POSITIONAL_FEATURES:
        inputs = add_positional_features(dummy_input, normalize=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.MSELoss()
    total_time = 0.0
    epochs = 5

    for epoch in range(epochs):
        optimizer.zero_grad()
        start_time = time.time()
        output = model(dummy_input)
        end_time = time.time()
        forward_time = end_time - start_time
        total_time += forward_time

        loss = criterion(output, dummy_target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Forward Time: {forward_time:.6f}s")

    average_time = total_time / epochs
    print(f"Average Forward Pass Time: {average_time:.6f}s")

Using device: cuda
Loading PetnetImproved3D Model with Windowed Global Attention...
Total parameters: 11,933,740
torch.Size([1, 2, 3, 207, 41]) Input shape
torch.Size([1, 16, 3, 207, 41]) After conv_in
torch.Size([1, 16, 3, 207, 41]) After bn_input
torch.Size([1, 16, 3, 207, 41]) After activation
torch.Size([1, 32, 3, 104, 21]) After layer 1
torch.Size([1, 32, 3, 104, 21]) After windowed global attention (after layer 1)
torch.Size([1, 64, 3, 52, 11]) After layer 2
torch.Size([1, 128, 3, 26, 6]) After layer 3
torch.Size([1, 256, 3, 13, 3]) After layer 4
torch.Size([1, 256, 3, 7, 2]) After layer 5
torch.Size([1, 10752]) After flattening all channels/voxels
torch.Size([1, 1024]) After fc layer 1, activation and dropout
torch.Size([1, 256]) After fc layer 2 activation and dropout
torch.Size([1, 6]) After fc layer 3 (output)
Epoch 1, Loss: 3.2390, Forward Time: 0.005000s
Epoch 2, Loss: 1.3882, Forward Time: 0.007001s
Epoch 3, Loss: 0.7558, Forward Time: 0.006001s
Epoch 4, Loss: 5.0673, Forw

RuntimeError: Given groups=2, weight of size [2, 1, 3, 3, 3], expected input[1, 4, 3, 207, 41] to have 2 channels, but got 4 channels instead

In [38]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=B)


class CylPredictionLoss(nn.Module):
    def __init__(self, x_weight=1.0, y_weight=1.0, z_weight=1.0):
        super().__init__()
        self.r_thickness = R_OUTER_MM - R_INNER_MM
        self.x_weight = x_weight
        self.y_weight = y_weight
        self.z_weight = z_weight
        
        # Create weight tensor for broadcasting: [x1, y1, z1, x2, y2, z2]
        self.register_buffer('coord_weights', torch.tensor([
            x_weight, y_weight, z_weight, x_weight, y_weight, z_weight
        ], dtype=torch.float32))
    
    def forward(self, pred, target, return_components=False):
        # pred is (B, 8): [r1, sin1, cos1, z1, r2, sin2, cos2, z2]
        # target is (B, 6): [x1, y1, z1, x2, y2, z2]
        B = pred.shape[0]
        
        # Optionally apply tanh activation
        if TANH_ACTIVATION:
            pred = torch.tanh(pred)

        if SINGLES_FILTERED:
            # Scale radii from [-1,1] to [R_INNER, R_OUTER]
            r1_scaled = R_INNER_MM + (pred[:, 0] + 1) * self.r_thickness / 2
            r2_scaled = R_INNER_MM + (pred[:, 4] + 1) * self.r_thickness / 2
        else:
            # Scale radii from [-1,1] to [0, R_OUTER]
            r1_scaled = (pred[:, 0] + 1) * R_OUTER_MM / 2
            r2_scaled = (pred[:, 4] + 1) * R_OUTER_MM / 2
        
        # Scale z from [-1,1] to [-Z_HALF, Z_HALF]
        z1_scaled = pred[:, 3] * Z_HALF_MM
        z2_scaled = pred[:, 7] * Z_HALF_MM

        # sin and cos are already normalized by tanh activation
        sin1 = pred[:, 1]
        cos1 = pred[:, 2]
        sin2 = pred[:, 5]
        cos2 = pred[:, 6]

        # Convert to cartesian - original assignment
        x1 = r1_scaled * cos1
        y1 = r1_scaled * sin1
        x2 = r2_scaled * cos2
        y2 = r2_scaled * sin2

        # Original assignment: pred_point1 = [x1, y1, z1], pred_point2 = [x2, y2, z2]
        pred_cartesian_original = torch.stack([x1, y1, z1_scaled, x2, y2, z2_scaled], dim=1)
        
        if LOSS_PERMUTATION:
            # Permuted assignment: pred_point1 = [x2, y2, z2], pred_point2 = [x1, y1, z1]
            pred_cartesian_permuted = torch.stack([x2, y2, z2_scaled, x1, y1, z1_scaled], dim=1)
            
            # Calculate differences (unweighted for component analysis)
            diff_original = pred_cartesian_original - target
            diff_permuted = pred_cartesian_permuted - target
            
            # Calculate component losses (unweighted, absolute differences)
            x_loss_original = torch.abs(diff_original[:, [0, 3]]).mean()  # x1, x2
            y_loss_original = torch.abs(diff_original[:, [1, 4]]).mean()  # y1, y2
            z_loss_original = torch.abs(diff_original[:, [2, 5]]).mean()  # z1, z2
            
            x_loss_permuted = torch.abs(diff_permuted[:, [0, 3]]).mean()
            y_loss_permuted = torch.abs(diff_permuted[:, [1, 4]]).mean()
            z_loss_permuted = torch.abs(diff_permuted[:, [2, 5]]).mean()
            
            # Apply coordinate weights for loss calculation
            weighted_diff_original = diff_original * self.coord_weights
            weighted_diff_permuted = diff_permuted * self.coord_weights
            
            # Calculate weighted Euclidean distances
            euclidean_dist_original = torch.norm(weighted_diff_original, dim=1)
            euclidean_dist_permuted = torch.norm(weighted_diff_permuted, dim=1)
            
            # Determine which assignment is better for each sample
            use_original = euclidean_dist_original <= euclidean_dist_permuted
            
            # Take minimum loss for each sample
            min_euclidean_dist = torch.min(euclidean_dist_original, euclidean_dist_permuted)
            loss = min_euclidean_dist.mean()
            
            # Calculate component losses based on the better assignment
            x_loss = torch.where(use_original, x_loss_original, x_loss_permuted).mean()
            y_loss = torch.where(use_original, y_loss_original, y_loss_permuted).mean()
            z_loss = torch.where(use_original, z_loss_original, z_loss_permuted).mean()
            
            # Calculate true 3D Euclidean distance (unweighted)
            # Calculate per-point 3D distances
            point1_dist_orig = torch.norm(diff_original[:, :3], dim=1)  # distance for point 1
            point2_dist_orig = torch.norm(diff_original[:, 3:], dim=1)  # distance for point 2
            point1_dist_perm = torch.norm(diff_permuted[:, :3], dim=1)
            point2_dist_perm = torch.norm(diff_permuted[:, 3:], dim=1)
            
            # Average distance for each assignment
            avg_dist_original = (point1_dist_orig + point2_dist_orig) / 2
            avg_dist_permuted = (point1_dist_perm + point2_dist_perm) / 2
            
            # Choose the better assignment's distance
            euclidean_3d = torch.where(use_original, avg_dist_original, avg_dist_permuted).mean()
            
        else:
            # Original behavior - no permutation
            diff = pred_cartesian_original - target
            
            # Calculate unweighted component losses
            x_loss = torch.abs(diff[:, [0, 3]]).mean()  # x1, x2
            y_loss = torch.abs(diff[:, [1, 4]]).mean()  # y1, y2
            z_loss = torch.abs(diff[:, [2, 5]]).mean()  # z1, z2
            
            # Calculate true 3D Euclidean distance (unweighted)
            point1_dist = torch.norm(diff[:, :3], dim=1)  # distance for point 1
            point2_dist = torch.norm(diff[:, 3:], dim=1)  # distance for point 2
            euclidean_3d = ((point1_dist + point2_dist) / 2).mean()  # average distance per sample, then across batch
            
            weighted_diff = diff * self.coord_weights
            euclidean_dist = torch.norm(weighted_diff, dim=1)
            loss = euclidean_dist.mean()
        
        if return_components:
            return loss, {'x_loss': x_loss.item(), 'y_loss': y_loss.item(), 'z_loss': z_loss.item(), 'euclidean_3d': euclidean_3d.item()}
        return loss


class CartesianPredictionLoss(nn.Module):
    def __init__(self, x_weight=1.0, y_weight=1.0, z_weight=1.0):
        super().__init__()
        self.x_weight = x_weight
        self.y_weight = y_weight
        self.z_weight = z_weight
        
        # Create weight tensor for broadcasting: [x1, y1, z1, x2, y2, z2]
        self.register_buffer('coord_weights', torch.tensor([
            x_weight, y_weight, z_weight, x_weight, y_weight, z_weight
        ], dtype=torch.float32))
    
    def forward(self, pred, target, return_components=False):
        # pred is (B, 6): [x1, y1, z1, x2, y2, z2] - raw network outputs
        # target is (B, 6): [x1, y1, z1, x2, y2, z2] - ground truth coordinates
        B = pred.shape[0]
        
        # Optionally apply tanh activation to normalize predictions
        if TANH_ACTIVATION:
            pred = torch.tanh(pred)
            
            # If using tanh, we need to scale from [-1,1] to the actual coordinate ranges
            pred_scaled = pred.clone()
            
            # Scale x and y coordinates
            pred_scaled[:, [0, 3]] = pred[:, [0, 3]] * R_OUTER_MM  # x1, x2
            pred_scaled[:, [1, 4]] = pred[:, [1, 4]] * R_OUTER_MM  # y1, y2
            
            # Scale z coordinates
            pred_scaled[:, [2, 5]] = pred[:, [2, 5]] * Z_HALF_MM   # z1, z2
        else:
            # Use predictions as-is if no tanh activation
            pred_scaled = pred
        
        if LOSS_PERMUTATION:
            # Original assignment: pred_point1 = [x1, y1, z1], pred_point2 = [x2, y2, z2]
            pred_original = pred_scaled
            
            # Permuted assignment: pred_point1 = [x2, y2, z2], pred_point2 = [x1, y1, z1]
            pred_permuted = torch.stack([
                pred_scaled[:, 3], pred_scaled[:, 4], pred_scaled[:, 5],  # x2, y2, z2
                pred_scaled[:, 0], pred_scaled[:, 1], pred_scaled[:, 2]   # x1, y1, z1
            ], dim=1)
            
            # Calculate differences (unweighted for component analysis)
            diff_original = pred_original - target
            diff_permuted = pred_permuted - target
            
            # Calculate component losses (unweighted, absolute differences)
            x_loss_original = torch.abs(diff_original[:, [0, 3]]).mean()  # x1, x2
            y_loss_original = torch.abs(diff_original[:, [1, 4]]).mean()  # y1, y2
            z_loss_original = torch.abs(diff_original[:, [2, 5]]).mean()  # z1, z2
            
            x_loss_permuted = torch.abs(diff_permuted[:, [0, 3]]).mean()
            y_loss_permuted = torch.abs(diff_permuted[:, [1, 4]]).mean()
            z_loss_permuted = torch.abs(diff_permuted[:, [2, 5]]).mean()
            
            # Apply coordinate weights for loss calculation
            weighted_diff_original = diff_original * self.coord_weights
            weighted_diff_permuted = diff_permuted * self.coord_weights
            
            # Calculate weighted Euclidean distances
            euclidean_dist_original = torch.norm(weighted_diff_original, dim=1)
            euclidean_dist_permuted = torch.norm(weighted_diff_permuted, dim=1)
            
            # Determine which assignment is better for each sample
            use_original = euclidean_dist_original <= euclidean_dist_permuted
            
            # Take minimum loss for each sample (handles point order ambiguity)
            min_euclidean_dist = torch.min(euclidean_dist_original, euclidean_dist_permuted)
            loss = min_euclidean_dist.mean()
            
            # Calculate component losses based on the better assignment
            x_loss = torch.where(use_original, x_loss_original, x_loss_permuted).mean()
            y_loss = torch.where(use_original, y_loss_original, y_loss_permuted).mean()
            z_loss = torch.where(use_original, z_loss_original, z_loss_permuted).mean()
            
            # Calculate true 3D Euclidean distance (unweighted)
            # Calculate per-point 3D distances
            point1_dist_orig = torch.norm(diff_original[:, :3], dim=1)  # distance for point 1
            point2_dist_orig = torch.norm(diff_original[:, 3:], dim=1)  # distance for point 2
            point1_dist_perm = torch.norm(diff_permuted[:, :3], dim=1)
            point2_dist_perm = torch.norm(diff_permuted[:, 3:], dim=1)
            
            # Average distance for each assignment
            avg_dist_original = (point1_dist_orig + point2_dist_orig) / 2
            avg_dist_permuted = (point1_dist_perm + point2_dist_perm) / 2
            
            # Choose the better assignment's distance
            euclidean_3d = torch.where(use_original, avg_dist_original, avg_dist_permuted).mean()
            
        else:
            # Original behavior - no permutation, direct comparison
            diff = pred_scaled - target
            
            # Calculate unweighted component losses
            x_loss = torch.abs(diff[:, [0, 3]]).mean()  # x1, x2
            y_loss = torch.abs(diff[:, [1, 4]]).mean()  # y1, y2
            z_loss = torch.abs(diff[:, [2, 5]]).mean()  # z1, z2
            
            # Calculate true 3D Euclidean distance (unweighted)
            point1_dist = torch.norm(diff[:, :3], dim=1)  # distance for point 1
            point2_dist = torch.norm(diff[:, 3:], dim=1)  # distance for point 2
            euclidean_3d = ((point1_dist + point2_dist) / 2).mean()  # average distance per sample, then across batch
            
            weighted_diff = diff * self.coord_weights
            euclidean_dist = torch.norm(weighted_diff, dim=1)
            loss = euclidean_dist.mean()
        
        if return_components:
            return loss, {'x_loss': x_loss.item(), 'y_loss': y_loss.item(), 'z_loss': z_loss.item(), 'euclidean_3d': euclidean_3d.item()}
        return loss


if CLASSES == 8:
    criterion = CylPredictionLoss(x_weight=X_WEIGHT, y_weight=Y_WEIGHT, z_weight=Z_WEIGHT)
elif CLASSES == 6:
    criterion = CartesianPredictionLoss(x_weight=X_WEIGHT, y_weight=Y_WEIGHT, z_weight=Z_WEIGHT)
else:
    raise ValueError("CLASSES must be 6 or 8")

criterion = criterion.to(device)
model = model.float().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

train_losses = []
val_losses = []

# Store component losses
train_x_losses = []
train_y_losses = []
train_z_losses = []
val_x_losses = []
val_y_losses = []
val_z_losses = []

print(f"Permutation Loss: {LOSS_PERMUTATION}, Tanh Activation: {TANH_ACTIVATION}, Singles Filtered: {SINGLES_FILTERED}")
print(f"Using Loss: {criterion.__class__.__name__} (x_weight={X_WEIGHT}, y_weight={Y_WEIGHT}, z_weight={Z_WEIGHT})")

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    running_x_loss = 0.0
    running_y_loss = 0.0
    running_z_loss = 0.0
    running_euclidean_3d = 0.0
    
    for inputs, labels in train_loader:
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        if POSITIONAL_FEATURES:
            inputs = add_positional_features(inputs, normalize=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss, components = criterion(outputs, labels, return_components=True)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        running_x_loss += components['x_loss']
        running_y_loss += components['y_loss']
        running_z_loss += components['z_loss']
        running_euclidean_3d += components['euclidean_3d']
    
    avg_train_loss = running_loss / len(train_loader)
    avg_train_x_loss = running_x_loss / len(train_loader)
    avg_train_y_loss = running_y_loss / len(train_loader)
    avg_train_z_loss = running_z_loss / len(train_loader)
    avg_train_euclidean_3d = running_euclidean_3d / len(train_loader)
    
    train_losses.append(avg_train_loss)
    train_x_losses.append(avg_train_x_loss)
    train_y_losses.append(avg_train_y_loss)
    train_z_losses.append(avg_train_z_loss)
    
    model.eval()
    test_loss = 0.0
    test_x_loss = 0.0
    test_y_loss = 0.0
    test_z_loss = 0.0
    test_euclidean_3d = 0.0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            if POSITIONAL_FEATURES:
                inputs = add_positional_features(inputs, normalize=True)
            outputs = model(inputs)
            loss, components = criterion(outputs, labels, return_components=True)
            
            test_loss += loss.item()
            test_x_loss += components['x_loss']
            test_y_loss += components['y_loss']
            test_z_loss += components['z_loss']
            test_euclidean_3d += components['euclidean_3d']

    avg_test_loss = test_loss / len(test_loader)
    avg_test_x_loss = test_x_loss / len(test_loader)
    avg_test_y_loss = test_y_loss / len(test_loader)
    avg_test_z_loss = test_z_loss / len(test_loader)
    avg_test_euclidean_3d = test_euclidean_3d / len(test_loader)
    
    val_losses.append(avg_test_loss)
    val_x_losses.append(avg_test_x_loss)
    val_y_losses.append(avg_test_y_loss)
    val_z_losses.append(avg_test_z_loss)
    
    # Now use the true 3D Euclidean distance calculated by the loss function
    print(f"Epoch [{epoch+1:2d}/{NUM_EPOCHS}]")
    print(f"  TRAINING   - Total (weighted): {avg_train_loss:.4f}mm  ||  X: {avg_train_x_loss:.4f}mm  Y: {avg_train_y_loss:.4f}mm  Z: {avg_train_z_loss:.4f}mm  ||  3D Euclidean: {avg_train_euclidean_3d:.4f}mm")
    print(f"  VALIDATION - Total (weighted): {avg_test_loss:.4f}mm  ||  X: {avg_test_x_loss:.4f}mm  Y: {avg_test_y_loss:.4f}mm  Z: {avg_test_z_loss:.4f}mm  ||  3D Euclidean: {avg_test_euclidean_3d:.4f}mm")

Using device: cuda
Permutation Loss: False, Tanh Activation: True, Singles Filtered: True
Using Loss: CartesianPredictionLoss (x_weight=10, y_weight=10, z_weight=1)
Epoch [ 1/40]
  TRAINING   - Total (weighted): 4311.6465mm  ||  X: 184.7821mm  Y: 184.9105mm  Z: 91.6451mm  ||  3D Euclidean: 315.6482mm
  VALIDATION - Total (weighted): 3758.4186mm  ||  X: 168.8166mm  Y: 159.2204mm  Z: 78.3960mm  ||  3D Euclidean: 275.1300mm
Epoch [ 2/40]
  TRAINING   - Total (weighted): 3790.8655mm  ||  X: 164.0039mm  Y: 155.1764mm  Z: 97.1961mm  ||  3D Euclidean: 283.0523mm
  VALIDATION - Total (weighted): 3683.2670mm  ||  X: 156.2642mm  Y: 161.0075mm  Z: 81.4829mm  ||  3D Euclidean: 276.1003mm
Epoch [ 3/40]
  TRAINING   - Total (weighted): 2976.6656mm  ||  X: 130.1121mm  Y: 120.1613mm  Z: 89.9199mm  ||  3D Euclidean: 231.0686mm
  VALIDATION - Total (weighted): 3460.8574mm  ||  X: 152.1357mm  Y: 151.4585mm  Z: 80.1484mm  ||  3D Euclidean: 259.5754mm
Epoch [ 4/40]
  TRAINING   - Total (weighted): 2849.504