In [1]:
import numpy as np

# Load the data
inner_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\inner_training_data_test.csv', delimiter=',')
outer_data = np.loadtxt(r'C:\Users\h\Desktop\PET_Recons\Data\ML_Data\outer_training_data_test.csv', delimiter=',')

print("Inner shape:", inner_data.shape)
print("Outer shape:", outer_data.shape)

# Separate features and labels
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

print("Inner labels shape:", inner_labels.shape)
print("Outer labels shape:", outer_labels.shape)

# Validate that labels are identical
labels_equal = np.array_equal(inner_labels, outer_labels)

print("Labels are identical:" if labels_equal else "⚠️ Labels are NOT identical!")


Inner shape: (5000, 8495)
Outer shape: (5000, 10135)
Inner labels shape: (5000, 6)
Outer labels shape: (5000, 6)
⚠️ Labels are NOT identical!


In [2]:
import numpy as np

# Compare labels row-wise
row_equal = np.all(inner_labels == outer_labels, axis=1)

# Get indices of mismatched rows
mismatched_indices = np.where(~row_equal)[0]

print(f"Number of mismatched rows: {len(mismatched_indices)}")
print("First 10 indices with mismatched labels:", mismatched_indices[:10])

# If you want to inspect the actual labels for a few mismatches:
for idx in mismatched_indices[:5]:  # show first 5 mismatches
    print(f"\nRow {idx}:")
    print("Inner:", inner_labels[idx])
    print("Outer:", outer_labels[idx])


Number of mismatched rows: 4844
First 10 indices with mismatched labels: [13 16 30 33 37 40 42 45 47 48]

Row 13:
Inner: [ 233.3549     42.308315 -140.34695     0.          0.          0.      ]
Outer: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]

Row 16:
Inner: [   0.         0.         0.      -199.61256 -141.07738 -142.0763 ]
Outer: [ 89.508736 252.85379   84.34045    0.         0.         0.      ]

Row 30:
Inner: [   0.          0.          0.       -231.473     -84.782135 -147.94    ]
Outer: [   0.          0.          0.       -273.9889    -22.72473    53.701527]

Row 33:
Inner: [   0.          0.          0.       -273.9889    -22.72473    53.701527]
Outer: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]

Row 37:
Inner: [270.10522    -0.7890214 -31.816393    0.          0.          0.       ]
Outer: [-243.86166  -77.32867 -135.17099    0.         0.         0.     ]


In [3]:
import numpy as np

# Assuming data is already loaded:
# inner_data, outer_data

# Split features/labels
inner_features, inner_labels = inner_data[:, :-6], inner_data[:, -6:]
outer_features, outer_labels = outer_data[:, :-6], outer_data[:, -6:]

# The "7th last column" is -7
inner_keys = inner_data[:, -7]
outer_keys = outer_data[:, -7]

# Storage for filtered results
filtered_inner_features = []
filtered_outer_features = []
filtered_inner_labels = []
filtered_outer_labels = []

i, j = 0, 0
while i < len(inner_keys) and j < len(outer_keys):
    if inner_keys[i] == outer_keys[j]:
        # Keys match → keep both
        filtered_inner_features.append(inner_features[i])
        filtered_outer_features.append(outer_features[j])
        filtered_inner_labels.append(inner_labels[i])
        filtered_outer_labels.append(outer_labels[j])
        i += 1
        j += 1
    elif inner_keys[i] < outer_keys[j]:
        # Advance inner pointer
        i += 1
    else:
        # Advance outer pointer
        j += 1

# Convert lists back into numpy arrays
filtered_inner_features = np.array(filtered_inner_features)
filtered_outer_features = np.array(filtered_outer_features)
filtered_inner_labels = np.array(filtered_inner_labels)
filtered_outer_labels = np.array(filtered_outer_labels)

print("Filtered inner features shape:", filtered_inner_features.shape)
print("Filtered outer features shape:", filtered_outer_features.shape)
print("Filtered inner labels shape:", filtered_inner_labels.shape)
print("Filtered outer labels shape:", filtered_outer_labels.shape)

# Final validation: check labels after alignment
labels_equal = np.array_equal(filtered_inner_labels, filtered_outer_labels)
print("Labels are identical after filtering?", labels_equal)

if labels_equal is True:
    labels = filtered_inner_labels


Filtered inner features shape: (4316, 8489)
Filtered outer features shape: (4316, 10129)
Filtered inner labels shape: (4316, 6)
Filtered outer labels shape: (4316, 6)
Labels are identical after filtering? True


In [4]:
# Remove first and last column from feature arrays
filtered_inner_features = filtered_inner_features[:, 1:-1]
filtered_outer_features = filtered_outer_features[:, 1:-1]

print("Inner features shape after dropping cols:", filtered_inner_features.shape)
print("Outer features shape after dropping cols:", filtered_outer_features.shape)


Inner features shape after dropping cols: (4316, 8487)
Outer features shape after dropping cols: (4316, 10127)


In [5]:
import numpy as np
from collections import Counter

run_lengths = []
current_len = 1

for i in range(1, len(labels)):
    if np.array_equal(labels[i], labels[i-1]):
        current_len += 1
    else:
        run_lengths.append(current_len)
        current_len = 1

# Append last run
run_lengths.append(current_len)

# Count occurrences of each run length
run_length_counts = Counter(run_lengths)

print("Contiguous label run stats:")
for length, count in sorted(run_length_counts.items()):
    print(f"Length {length}: {count} times")

print("\nTotal runs found:", len(run_lengths))


Contiguous label run stats:
Length 1: 36 times
Length 2: 318 times
Length 3: 833 times
Length 4: 144 times
Length 5: 53 times
Length 6: 33 times
Length 7: 9 times
Length 8: 2 times
Length 9: 3 times

Total runs found: 1431


In [6]:
import numpy as np

def make_contiguous_examples(inner_features, outer_features, labels, run_length=3):
    """
    Create tuples of (inner_window, outer_window, label) for contiguous sequences of 'run_length'.
    Labels must remain identical across the window.
    """
    examples = []
    n = len(labels)
    start = 0
    
    while start < n:
        # Find length of current contiguous block
        curr_label = labels[start]
        end = start + 1
        while end < n and np.array_equal(labels[end], curr_label):
            end += 1
        block_len = end - start
        
        # If block is long enough, generate sliding windows of size run_length
        if block_len >= run_length:
            for i in range(start, end - run_length + 1):
                inner_window = inner_features[i:i+run_length]
                outer_window = outer_features[i:i+run_length]
                # label is same across whole block, so we just take one
                examples.append((inner_window, outer_window, curr_label))
        
        # Move to next block
        start = end
    
    return examples

data = make_contiguous_examples(inner_features=filtered_inner_features,
                                outer_features=filtered_outer_features,
                                labels=labels,
                                run_length=3)

print(f"len(data) = {len(data)} examples made")
print(f"len(data[0]) = {len(data[0])} which are the inner, outer, and labels")
print(f"Inner shape: data[0][0].shape = {data[0][0].shape}")
print(f"Outer shape: data[0][1].shape = {data[0][1].shape}")
print(f"Label shape: data[0][2].shape = {data[0][2].shape}")

len(data) = 1490 examples made
len(data[0]) = 3 which are the inner, outer, and labels
Inner shape: data[0][0].shape = (3, 8487)
Outer shape: data[0][1].shape = (3, 10127)
Label shape: data[0][2].shape = (6,)


In [7]:
def reshape_examples(data):
    reshaped = []
    for inner, outer, label in data:
        # reshape each contiguous example
        inner_reshaped = inner.reshape(inner.shape[0], 207, 41)
        outer_reshaped = outer.reshape(outer.shape[0], 247, 41)
        reshaped.append((inner_reshaped, outer_reshaped, label))
    return reshaped

reshaped_data = reshape_examples(data)

print("Original:", data[0][0].shape, data[0][0].shape, data[0][0].shape)
print("Original:", data[0][1].shape, data[0][1].shape, data[0][1].shape)
print("\nReshaped:", reshaped_data[0][0].shape, reshaped_data[0][0].shape, reshaped_data[0][0].shape)
print("Reshaped:", reshaped_data[0][1].shape, reshaped_data[0][1].shape, reshaped_data[0][1].shape)


Original: (3, 8487) (3, 8487) (3, 8487)
Original: (3, 10127) (3, 10127) (3, 10127)

Reshaped: (3, 207, 41) (3, 207, 41) (3, 207, 41)
Reshaped: (3, 247, 41) (3, 247, 41) (3, 247, 41)


In [8]:
import numpy as np
from scipy.ndimage import zoom

def resize_features(arr, target_len):
    """
    Resizes arr from shape (frames, length, channels) 
    to (frames, target_len, channels).
    Uses linear interpolation along the 'length' axis.
    """
    frames, length, channels = arr.shape
    zoom_factor = target_len / length
    # zoom only along axis=1 (the "length" dimension), keep others unchanged
    resized = zoom(arr, (1, zoom_factor, 1), order=1)
    return resized

def resize_dataset(data, target_len=207):
    """
    Applies resize_features to the entire dataset of (inner, outer, label) triples.
    - Inner is kept as-is (already target_len).
    - Outer is resized to target_len.
    """
    resized_data = []
    for inner, outer, label in data:
        inner_resized = inner  # already target_len
        outer_resized = resize_features(outer, target_len)
        resized_data.append((inner_resized, outer_resized, label))
    return resized_data

# Apply to your dataset
resized_data = resize_dataset(reshaped_data, target_len=207)

len(resized_data)
len(resized_data[0])
resized_data[0][0].shape
resized_data[0][1].shape
resized_data[0][2].shape


(6,)

In [9]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_inner_outer_slider(inner, outer):
    """
    Visualize inner and outer arrays side-by-side with a slider to control timestep.
    inner: np array of shape (3, 207, 41)
    outer: np array of shape (3, 207, 41)
    """
    frames = []
    for t in range(inner.shape[0]):
        frames.append(go.Frame(
            data=[
                go.Heatmap(z=inner[t], colorscale='Cividis', colorbar=dict(title='Inner'), zmin=inner.min(), zmax=inner.max()),
                go.Heatmap(z=outer[t], colorscale='Cividis', colorbar=dict(title='Outer'), zmin=outer.min(), zmax=outer.max())
            ],
            name=str(t)
        ))

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Inner", "Outer"),
                        horizontal_spacing=0.1,
                        specs=[[{"type": "heatmap"}, {"type": "heatmap"}]])

    fig.add_trace(frames[0].data[1], row=1, col=1)  # inner heatmap (first trace of frame)
    fig.add_trace(frames[0].data[1], row=1, col=2)  # outer heatmap (second trace of frame)

    # Slider steps
    steps = []
    for i in range(len(frames)):
        steps.append(dict(
            method="animate",
            args=[[str(i)], 
                  dict(mode="immediate", frame=dict(duration=500, redraw=True), transition=dict(duration=0))],
            label=str(i)
        ))

    sliders = [dict(
        active=0,
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        title="Inner and Outer Feature Maps Over Time",
        sliders=sliders,
        height=600,
        width=900
    )

    fig.frames = frames

    fig.show()

example = 15

plot_inner_outer_slider(resized_data[example][0],
                        resized_data[example][1])

In [10]:
import torch

# Prepare lists for inner and outer tensors for all elements
inner_tensors = []
outer_tensors = []
labels_tensors = []

for data in resized_data:
    inner_tensors.append(torch.tensor(data[0]))  # (timesteps, height, width)
    outer_tensors.append(torch.tensor(data[1]))  # (timesteps, height, width)
    labels_tensors.append(torch.tensor(data[2]))  # (label)

# Stack each list into a tensor along dimension 0 (num_tensors)
inner_tensor = torch.stack(inner_tensors, dim=0)  # (num_tensors, timesteps, height, width)
outer_tensor = torch.stack(outer_tensors, dim=0)  # (num_tensors, timesteps, height, width)
labels_tensor = torch.stack(labels_tensors, dim=0)  # (num_tensors, timesteps, height, width)

# Stack inner and outer along new dim=2 (channels)
combined_tensor = torch.stack((inner_tensor, outer_tensor), dim=1)  # (num_tensors, timesteps, 2, height, width)

print(combined_tensor.shape)
print(labels_tensor.shape)


torch.Size([1490, 2, 3, 207, 41])
torch.Size([1490, 6])


In [11]:
import torch

# Set your split ratio
train_ratio = 0.8
num_samples = combined_tensor.shape[0]
train_size = int(num_samples * train_ratio)
test_size = num_samples - train_size

# Optionally shuffle indices
indices = torch.randperm(num_samples)   # shuffle all sample indices
train_idx = indices[:train_size]
test_idx = indices[train_size:]

# Index into tensors
train_data = combined_tensor[train_idx]
train_labels = labels_tensor[train_idx]

test_data = combined_tensor[test_idx]
test_labels = labels_tensor[test_idx]

In [12]:
print(f"Train tensor: {train_data.shape}, Labels:{train_labels.shape}")
print(f"Test tensor: {test_data.shape}, Labels:{test_labels.shape}")

Train tensor: torch.Size([1192, 2, 3, 207, 41]), Labels:torch.Size([1192, 6])
Test tensor: torch.Size([298, 2, 3, 207, 41]), Labels:torch.Size([298, 6])


In [13]:
import torch

def filter_examples_with_zero_in_label(data_tensor, labels_tensor):
    has_zeros = (labels_tensor == 0)

    has_any_zero_in_row = torch.any(has_zeros, dim=1)
    indices_to_keep = ~has_any_zero_in_row
    filtered_data = data_tensor[indices_to_keep]
    filtered_labels = labels_tensor[indices_to_keep]

    return filtered_data, filtered_labels

# --- Apply the Function ---
train_data, train_labels = filter_examples_with_zero_in_label(train_data, train_labels)
test_data, test_labels = filter_examples_with_zero_in_label(test_data, test_labels)

print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)
print(test_labels.shape)


torch.Size([152, 2, 3, 207, 41])
torch.Size([152, 6])
torch.Size([44, 2, 3, 207, 41])
torch.Size([44, 6])


In [18]:
# takes attention_petnet.py and places the global attention layer after layer 1
# for higher dimensional G.A. Employs point-wise seperation convolutions over
# full fat classical convolutions for an 8-10x reduction in paramater count. 
# Also introduces a residual connection after the global attention, and layer norm
# for the global attention (Apparently batch norm works better for Convolutions, 
# and layer norm works better for Transformers...we'll see about that)

import torch
import torch.nn as nn

###############################################################################
# Global Attention Module
###############################################################################
class GlobalAttention3D(nn.Module):
    """
    TorchScript-compatible global attention module for 3D feature maps.
    Uses manual attention implementation instead of nn.MultiheadAttention.
    """
    
    def __init__(self, in_channels=64, embed_dim=128, output_dim=64, num_heads=2):
        super(GlobalAttention3D, self).__init__()
        
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.output_dim = output_dim
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        # Map from input channels to embedding dimension
        self.channel_proj = nn.Linear(in_channels, embed_dim)
        
        # Manual attention projections (instead of nn.MultiheadAttention)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # Map from embedding dimension to output channels
        self.output_proj = nn.Linear(embed_dim, output_dim)
        
        # Learnable positional encoding will be initialized based on input size
        self.pos_encoding = None
        self.scale = self.head_dim ** -0.5
        self.layernorm = nn.LayerNorm(self.embed_dim)

        
    def _init_positional_encoding(self, D, H, W):
        """Initialize learnable positional encoding for DxHxW spatial positions"""
        num_positions = D * H * W
        self.pos_encoding = nn.Parameter(torch.randn(1, num_positions, self.embed_dim))
        nn.init.normal_(self.pos_encoding, std=0.02)
        
    def forward(self, x):
        batch_size, channels, D, H, W = x.shape
        seq_len = D * H * W

        if self.pos_encoding is None:
            self._init_positional_encoding(D, H, W)
            self.pos_encoding = self.pos_encoding.to(x.device)

        x_flat = x.permute(0, 2, 3, 4, 1).contiguous().view(batch_size, seq_len, channels)
        x_proj = self.channel_proj(x_flat)  # (batch, seq_len, embed_dim)
        residual = x_proj

        x_flat = x_proj + self.pos_encoding

        q = self.q_proj(x_flat)
        k = self.k_proj(x_flat)
        v = self.v_proj(x_flat)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        # Residual + LayerNorm
        attn_output = attn_output + residual  # Residual
        attn_output = self.layernorm(attn_output)   # LayerNorm

        output = self.output_proj(attn_output)  # (batch, seq_len, output_dim)
        output = output.permute(0, 2, 1).contiguous().view(batch_size, self.output_dim, D, H, W)
        return output


###############################################################################
# 3D Residual Block
###############################################################################
class ResidualBlock3D(nn.Module):
    """
    Residual block with 3D depthwise + pointwise convolution.
    Uses circular padding on width, constant elsewhere.
    Includes post skip-connection processing layer.
    """
    def __init__(self, in_channels, out_channels, stride=(1, 2, 2)):
        super(ResidualBlock3D, self).__init__()

        # Depthwise 3D convolution (groups=in_channels)
        self.dw_conv1 = nn.Conv3d(
            in_channels, in_channels,
            kernel_size=3, stride=stride, padding=0,
            groups=in_channels, bias=False
        )
        self.pw_conv1 = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm3d(out_channels)

        # Depthwise 3D convolution (stride=1)
        self.dw_conv2 = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=0,
            groups=out_channels, bias=False
        )
        self.pw_conv2 = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm3d(out_channels)

        # Post skip-connection processing layer (depthwise-pointwise)
        self.dw_conv_post = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=0,
            groups=out_channels, bias=False
        )
        self.pw_conv_post = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn_post = nn.BatchNorm3d(out_channels)

        self.activation = nn.GELU()

        # Shortcut for matching dimensions
        self.shortcut = None
        if stride != (1, 1, 1) or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def _apply_circular_padding(self, x):
        x = torch.nn.functional.pad(x, (1, 1, 0, 0, 0, 0), mode='circular')
        x = torch.nn.functional.pad(x, (0, 0, 1, 1, 1, 1), mode='constant', value=0)
        return x

    def forward(self, x):
        # Main path
        out = self._apply_circular_padding(x)
        out = self.dw_conv1(out)
        out = self.pw_conv1(out)
        out = self.bn1(out)
        out = self.activation(out)

        out = self._apply_circular_padding(out)
        out = self.dw_conv2(out)
        out = self.pw_conv2(out)
        out = self.bn2(out)

        # Shortcut
        if self.shortcut is not None:
            x = self.shortcut(x)

        out += x

        out = self._apply_circular_padding(out)
        out = self.dw_conv_post(out)
        out = self.pw_conv_post(out)
        out = self.bn_post(out)
        out = self.activation(out)

        return out



###############################################################################
# PetNetImproved3D with Early Global Attention
###############################################################################
class PetNetImproved3D(nn.Module):
    def __init__(self, num_classes=6):
        print("Loading PetnetImproved3D Model with Early Global Attention...")
        super(PetNetImproved3D, self).__init__()

        self.conv_in = nn.Sequential(
            nn.Conv3d(2, 2, kernel_size=3, stride=1, padding=1, groups=2, bias=False),    # depthwise
            nn.Conv3d(2, 16, kernel_size=1, stride=1, padding=0, bias=False)              # pointwise
        )
        self.bn_in = nn.BatchNorm3d(16)
        self.activation = nn.GELU()

        self.layer1 = ResidualBlock3D(16, 32, stride=(1, 2, 2))      # downsample H,W

        # Global attention after layer1 (32 channels)
        self.global_attention = GlobalAttention3D(
            in_channels=32,
            embed_dim=128,
            output_dim=32,     # must match "out_channels"
            num_heads=8
        )

        self.layer2 = ResidualBlock3D(32, 64, stride=(1, 2, 2))
        self.layer3 = ResidualBlock3D(64, 128, stride=(1, 2, 2))
        self.layer4 = ResidualBlock3D(128, 256, stride=(1, 2, 2))
        self.layer5 = ResidualBlock3D(256, 512, stride=(1, 2, 2))

        self.dropout = nn.Dropout(0.3)

        fc_in_features = self._compute_fc_input_size()
        self.fc1 = nn.Linear(fc_in_features, 1024, bias=True)
        self.fc2 = nn.Linear(1024, num_classes, bias=True)

        self._initialize_weights()

    def _compute_fc_input_size(self, C=2, T=3, H=207, W=41):
        with torch.no_grad():
            dummy = torch.zeros(1, C, T, H, W)
            out = self.conv_in(dummy)
            out = self.bn_in(out)
            out = self.activation(out)
            out = self.layer1(out)
            # Global attention after layer1
            out = self.global_attention(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.layer5(out)
            out = out.view(out.size(0), -1)
            return out.shape[1]

    def forward(self, x, debug=False):
        if debug: print(f"{x.shape} Input shape")
        x = self.conv_in(x)
        if debug: print(f"{x.shape} After conv_in")
        x = self.bn_in(x)
        if debug: print(f"{x.shape} After bn_input")
        x = self.activation(x)
        if debug: print(f"{x.shape} After activation")

        x = self.layer1(x)
        if debug: print(f"{x.shape} After layer 1")

        # Move attention here
        x = self.global_attention(x)
        if debug: print(f"{x.shape} After global attention (now after layer 1)")

        x = self.layer2(x)
        if debug: print(f"{x.shape} After layer 2")
        x = self.layer3(x)
        if debug: print(f"{x.shape} After layer 3")
        x = self.layer4(x)
        if debug: print(f"{x.shape} After layer 4")
        x = self.layer5(x)
        if debug: print(f"{x.shape} After layer 5")

        x = x.view(x.size(0), -1)  # Flatten all features (B, all_channels)
        if debug: print(f"{x.shape} After flattening all channels/voxels")
        x = self.fc1(x)
        if debug: print(f"{x.shape} After fc layer 1")
        x = self.activation(x)
        x = self.dropout(x)
        if debug: print(f"{x.shape} After activation and dropout")
        x = self.fc2(x)
        if debug: print(f"{x.shape} After fc layer 2 (output)")

        return x

    def _initialize_weights(self):
        """
        Kaiming (He) Initialization for Conv3d and Linear layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)




if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    B = 1
    # B = 128

    C = 2
    T = 3 

    H = 207
    # H = 496

    W = 41
    # W = 84
    
    CLASSES = 6

    # Model instantiation
    model = PetNetImproved3D(num_classes=CLASSES).to(device)
    
    # Print parameter count
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {param_count:,}")
    
    # Test a dummy pass
    dummy_input = torch.randn(B, C, T, H, W).to(device)
    dummy_target = torch.randn(B, CLASSES).to(device)
    model.forward(dummy_input, debug=True)

    # Dummy training loop to observe loss reduction
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.MSELoss()

    import time

    total_time = 0.0
    epochs = 5

    for epoch in range(epochs):
        optimizer.zero_grad()
        start_time = time.time()
        output = model(dummy_input)
        end_time = time.time()
        forward_time = end_time - start_time
        total_time += forward_time

        loss = criterion(output, dummy_target)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Forward Time: {forward_time:.6f}s")

    average_time = total_time / epochs
    print(f"Average Forward Pass Time: {average_time:.6f}s")

Using device: cuda
Loading PetnetImproved3D Model with Early Global Attention...
Total parameters: 24,062,764
torch.Size([1, 2, 3, 207, 41]) Input shape
torch.Size([1, 16, 3, 207, 41]) After conv_in
torch.Size([1, 16, 3, 207, 41]) After bn_input
torch.Size([1, 16, 3, 207, 41]) After activation
torch.Size([1, 32, 3, 104, 21]) After layer 1
torch.Size([1, 32, 3, 104, 21]) After global attention (now after layer 1)
torch.Size([1, 64, 3, 52, 11]) After layer 2
torch.Size([1, 128, 3, 26, 6]) After layer 3
torch.Size([1, 256, 3, 13, 3]) After layer 4
torch.Size([1, 512, 3, 7, 2]) After layer 5
torch.Size([1, 21504]) After flattening all channels/voxels
torch.Size([1, 1024]) After fc layer 1
torch.Size([1, 1024]) After activation and dropout
torch.Size([1, 6]) After fc layer 2 (output)
Epoch 1, Loss: 4.8781, Forward Time: 0.006000s
Epoch 2, Loss: 9.6512, Forward Time: 0.004997s
Epoch 3, Loss: 4.8002, Forward Time: 0.004997s
Epoch 4, Loss: 1.5426, Forward Time: 0.005998s
Epoch 5, Loss: 2.2470,

In [19]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

BATCH_SIZE = 1
NUM_EPOCHS = 40

train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Placeholder for your PyTorch model
model = model.float().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Custom Euclidean distance loss function
def euclidean_distance_loss(pred, target):
    return torch.mean(torch.norm(pred - target, dim=1))

criterion = euclidean_distance_loss

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in train_loader:
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        print(loss.item())
        optimizer.step()
        running_loss += loss.item()
    
    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}")
    
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

    avg_test_loss = test_loss / len(test_loader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}")


Using device: cuda
392.87811279296875
388.1553955078125
391.2451171875
390.71649169921875
387.8299255371094
380.07855224609375
397.1094665527344
403.48541259765625
383.3761291503906
397.1708679199219
397.7418212890625
413.7375793457031
380.2054443359375
386.1660461425781
394.78875732421875
382.75341796875
372.4220886230469
372.4580993652344
398.74090576171875
381.4772644042969
387.0618896484375
397.1957092285156
384.0345764160156
395.29425048828125
391.630615234375
377.2464294433594
395.1893005371094
377.087158203125
385.6849365234375
396.0188903808594
371.4977722167969
388.5263977050781
391.9915771484375
400.9273681640625
422.7112121582031
395.7725830078125
354.96051025390625
422.9128723144531
385.2550964355469
421.4732666015625
408.0708923339844
366.5907897949219
354.9255065917969
402.3260803222656
366.0848083496094
371.2934875488281
357.21942138671875
387.1688232421875
392.9109191894531
367.9459533691406
414.4554443359375
342.4571838378906
353.006103515625
351.49725341796875
370.566

KeyboardInterrupt: 