In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class TrackNetV2Encoder(nn.Module):
    def __init__(self):
        super(TrackNetV2Encoder, self).__init__()
        # Encoder layers (VGG-16 inspired)
        self.conv1 = nn.Sequential(
            nn.Conv2d(9, 64, 3, padding=1),  # 3 frames * 3 channels = 9 input channels
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # Downsample to H/2, W/2
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # Downsample to H/4, W/4
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # Downsample to H/8, W/8
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # Downsample to H/16, W/16
        )

    def forward(self, x):
        # x: (batch, 9, H, W) - 3 frames stacked
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        conv4_out = self.conv4(conv3_out)
        return conv4_out, [conv1_out, conv2_out, conv3_out]  # Return bottleneck and skip connections

class TrackNetV2Decoder(nn.Module):
    def __init__(self):
        super(TrackNetV2Decoder, self).__init__()
        # Decoder layers with upsampling and skip connections
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.conv_up4 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),  # 256 from up4 + 256 from conv3
            nn.ReLU()
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.conv_up3 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),  # 128 from up3 + 128 from conv2
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.conv_up2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),  # 64 from up2 + 64 from conv1
            nn.ReLU()
        )
        self.final = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 1)  # Output 3 heatmaps (one per frame)
        )

    def forward(self, x, skip_connections):
        # x: (batch, 512, H/16, W/16), skip_connections: [conv1, conv2, conv3]
        x = self.up4(x)
        x = torch.cat([x, skip_connections[2]], dim=1)  # Concat with conv3
        x = self.conv_up4(x)
        
        x = self.up3(x)
        x = torch.cat([x, skip_connections[1]], dim=1)  # Concat with conv2
        x = self.conv_up3(x)
        
        x = self.up2(x)
        x = torch.cat([x, skip_connections[0]], dim=1)  # Concat with conv1
        x = self.conv_up2(x)
        
        x = self.final(x)  # (batch, 3, H, W) - 3 heatmaps
        return x

In [3]:
class TrackNetV4(nn.Module):
    def __init__(self, theta_init=0.5):
        super(TrackNetV4, self).__init__()
        self.encoder = TrackNetV2Encoder()
        self.decoder = TrackNetV2Decoder()
        # Learnable parameter for power normalization
        self.theta = nn.Parameter(torch.tensor(theta_init, dtype=torch.float32))

    def generate_motion_attention_maps(self, frames):
        # frames: (batch, 9, H, W) -> split into 3 frames: (batch, 3, H, W) each
        batch_size, _, h, w = frames.shape
        frame1 = frames[:, 0:3, :, :]
        frame2 = frames[:, 3:6, :, :]
        frame3 = frames[:, 6:9, :, :]
        
        # Convert to grayscale (mean across channels)
        gray1 = frame1.mean(dim=1, keepdim=True)  # (batch, 1, H, W)
        gray2 = frame2.mean(dim=1, keepdim=True)
        gray3 = frame3.mean(dim=1, keepdim=True)
        
        # Compute absolute frame differences
        diff1 = torch.abs(gray2 - gray1)  # (batch, 1, H, W)
        diff2 = torch.abs(gray3 - gray2)
        
        # Power normalization with learnable theta
        motion_maps = torch.cat([diff1, diff2], dim=1)  # (batch, 2, H, W)
        motion_maps = torch.pow(motion_maps, self.theta)
        motion_maps = (motion_maps - motion_maps.min()) / (motion_maps.max() - motion_maps.min() + 1e-6)
        
        return motion_maps  # (batch, 2, H, W) - 2 motion maps

    def forward(self, x):
        # x: (batch, 9, H, W) - 3 frames stacked
        # Step 1: Generate motion attention maps
        motion_maps = self.generate_motion_attention_maps(x)  # (batch, 2, H, W)
        
        # Step 2: Encoder-Decoder to get visual features
        bottleneck, skip_connections = self.encoder(x)
        visual_features = self.decoder(bottleneck, skip_connections)  # (batch, 3, H, W)
        
        # Step 3: Motion-aware fusion
        fused_features = torch.zeros_like(visual_features)
        fused_features[:, 0, :, :] = visual_features[:, 0, :, :]  # First frame unchanged
        fused_features[:, 1, :, :] = visual_features[:, 1, :, :] * motion_maps[:, 0, :, :]  # Second frame
        fused_features[:, 2, :, :] = visual_features[:, 2, :, :] * motion_maps[:, 1, :, :]  # Third frame
        
        # Step 4: Apply Sigmoid to get heatmaps
        heatmaps = torch.sigmoid(fused_features)  # (batch, 3, H, W)
        return heatmaps

# Instantiate the model
model = TrackNetV4().to(device)

In [4]:
def generate_heatmap(center_x, center_y, h, w, sigma=5):
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * sigma ** 2))
    return heatmap / heatmap.max()

def generate_heatmap(center_x, center_y, h, w, sigma=5):
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * sigma ** 2))
    return heatmap / heatmap.max()

class TrackNetDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.frame_folders = sorted([f for f in os.listdir(root_dir) if f.startswith('frame')])

    def __len__(self):
        return len(self.frame_folders)

    def __getitem__(self, idx):
        frame_folder = os.path.join(self.root_dir, self.frame_folders[idx])
        images = []
        heatmaps = []
        
        for i in range(3):  # Load 3 consecutive frames (image-0 to image-2)
            subfolder = os.path.join(frame_folder, f'image-{i}')
            img_path = os.path.join(subfolder, 'image.jpg')
            label_path = os.path.join(subfolder, 'label.txt')
            
            # Load image with fallback for missing files
            try:
                img = Image.open(img_path).convert('RGB')
            except FileNotFoundError:
                print(f"Image file {img_path} not found, using blank image")
                img = Image.fromarray(np.zeros((288, 512, 3), dtype=np.uint8))  # Blank RGB image
            if self.transform:
                img = self.transform(img)
            images.append(img)
            
            # Load and process label
            if os.path.exists(label_path):
                with open(label_path, 'r') as f:
                    label = f.readline().strip().split()
                    if not label:  # Empty file or only whitespace
                        heatmap = np.zeros((288, 512))
                    elif len(label) >= 3:  # Valid YOLO format
                        try:
                            x_center = float(label[1]) * 512  # Denormalize to 512x288
                            y_center = float(label[2]) * 288
                            heatmap = generate_heatmap(x_center, y_center, 288, 512)
                        except (ValueError, IndexError) as e:
                            print(f"Error parsing label file {label_path}: {label}, Error: {e}")
                            heatmap = np.zeros((288, 512))
                    else:
                        print(f"Label file {label_path} has insufficient data: {label}")
                        heatmap = np.zeros((288, 512))
            else:
                print(f"No label file at {label_path}, generating zero heatmap")
                heatmap = np.zeros((288, 512))
            heatmaps.append(torch.tensor(heatmap, dtype=torch.float32))
        
        images = torch.cat(images, dim=0)  # (9, H, W)
        heatmaps = torch.stack(heatmaps, dim=0)  # (3, H, W)
        return images, heatmaps

transform = transforms.Compose([
    transforms.Resize((288, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = TrackNetDataset(root_dir='dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [5]:
def weighted_bce_loss(pred, target, pos_weight=100):
    bce = F.binary_cross_entropy(pred, target, reduction='none')
    weights = target * pos_weight + (1 - target)
    return (bce * weights).mean()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, (images, heatmaps) in enumerate(dataloader):
        images, heatmaps = images.to(device), heatmaps.to(device)  # (batch, 9, H, W), (batch, 3, H, W)
        
        pred_heatmaps = model(images)  # (batch, 3, H, W)
        
        loss = weighted_bce_loss(pred_heatmaps, heatmaps)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')
    
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}')

torch.save(model.state_dict(), 'tracknetv4.pth')

Epoch [1/50], Batch [0/247], Loss: 0.7840
Image file dataset/frame20/image-0/image.jpg not found, using blank image
No label file at dataset/frame20/image-0/label.txt, generating zero heatmap
Image file dataset/frame20/image-1/image.jpg not found, using blank image
No label file at dataset/frame20/image-1/label.txt, generating zero heatmap
Image file dataset/frame20/image-2/image.jpg not found, using blank image
No label file at dataset/frame20/image-2/label.txt, generating zero heatmap
Epoch [1/50], Batch [10/247], Loss: 0.6421
Epoch [1/50], Batch [20/247], Loss: 0.5427
Epoch [1/50], Batch [30/247], Loss: 0.5408
Epoch [1/50], Batch [40/247], Loss: 0.5513
Epoch [1/50], Batch [50/247], Loss: 0.5721
Epoch [1/50], Batch [60/247], Loss: 0.5024
Epoch [1/50], Batch [70/247], Loss: 0.4577
Epoch [1/50], Batch [80/247], Loss: 0.4441
Epoch [1/50], Batch [90/247], Loss: 0.4971
Epoch [1/50], Batch [100/247], Loss: 0.4680
Epoch [1/50], Batch [110/247], Loss: 0.4799
Epoch [1/50], Batch [120/247], Lo

../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [34,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [36,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [37,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [39,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [43,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [44,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:94: operator(): block: [347,0,0], thread: [47,0,0] Assertion `input_val >= zero && inpu

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
model.eval()
with torch.no_grad():
    sample_images, sample_heatmaps = next(iter(dataloader))
    sample_images = sample_images.to(device)
    pred_heatmaps = model(sample_images)
    
    for i in range(3):
        plt.subplot(2, 3, i+1)
        plt.imshow(sample_images[0, i*3:(i+1)*3].mean(dim=0).cpu(), cmap='gray')
        plt.title(f'Frame {i+1}')
        plt.subplot(2, 3, i+4)
        plt.imshow(pred_heatmaps[0, i].cpu(), cmap='hot')
        plt.title(f'Heatmap {i+1}')
    plt.show()