# TrackNet Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TrackNet(nn.Module):
    """Small encoder-decoder CNN that maps stacked frames to a heatmap.

    The encoder applies three 3x3 convolutions, each followed by a 2x2 max
    pool, so the feature map is downsampled by a factor of 8.  The decoder
    mirrors this with three stride-2 transposed convolutions that upsample
    by a factor of 8, so the output heatmap has the same spatial size as
    the input and can be compared pixel-wise against a target heatmap.

    Input:  tensor of shape (N, 9, H, W) -- 3 stacked 3-channel frames.
    Output: tensor of shape (N, 1, H, W) with values in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder
        # 3 consecutive frames stacked (3x3 = 9 channels)
        self.conv1 = nn.Conv2d(9, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # Decoder
        # stride=2 with output_padding=1 exactly doubles H and W:
        #   (H - 1) * 2 - 2 * padding + kernel_size + output_padding = 2 * H
        # Each layer therefore undoes one pooling step.  The kernel size is
        # kept at 3 so the parameter shapes are the same as with stride 1.
        self.deconv1 = nn.ConvTranspose2d(
            256, 128, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.deconv2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        # Output heatmap (single channel)
        self.deconv3 = nn.ConvTranspose2d(
            64, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )

        # Maxpool layer (halves H and W, rounding down)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return a heatmap in [0, 1] with the same H x W as ``x``."""
        in_size = x.shape[-2:]

        # Encoder path
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Decoder path
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = self.deconv3(x)

        # Pooling rounds down, so when H or W is not a multiple of 8 the
        # decoder output is slightly smaller than the input; resize the
        # logits back so the output always lines up with the input.
        if x.shape[-2:] != in_size:
            x = F.interpolate(
                x, size=in_size, mode='bilinear', align_corners=False
            )

        return torch.sigmoid(x)  # Output heatmap between 0 and 1


# Data Loaders

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset
import cv2
import numpy as np

class SquashBallDataset(Dataset):
    """Dataset yielding (stacked frames, heatmap) pairs.

    Args:
        coco_data: Indexable, sized collection of records.  Each record is
            subscripted with the keys ``'frame_1_path'``, ``'frame_2_path'``,
            ``'frame_3_path'`` and ``'heatmap'``.
        transform: Optional callable.  When given, it is applied to the
            stacked image and to the heatmap in two separate calls.
            NOTE(review): a transform with random behaviour would therefore
            treat image and heatmap differently -- confirm only
            deterministic transforms are used.
    """

    # Record keys holding the paths of the three consecutive frames.
    _FRAME_KEYS = ('frame_1_path', 'frame_2_path', 'frame_3_path')

    def __init__(self, coco_data, transform=None):
        self.data = coco_data  # Load COCO formatted annotations
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    @staticmethod
    def _read_frame(path):
        """Read one image with OpenCV, failing loudly if it cannot be read."""
        frame = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the None would be stacked and fail much later
        # with an unrelated error.
        if frame is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return frame

    def __getitem__(self, idx):
        """Return ``(img, heatmap)`` for record ``idx``.

        Raises:
            OSError: If any of the three frame images cannot be read.
        """
        record = self.data[idx]

        # Load 3 consecutive frames
        frames = [self._read_frame(record[key]) for key in self._FRAME_KEYS]

        # Stack the frames along the channel dimension
        # (presumably 3 colour channels each -> 9 channels; confirm inputs)
        img = np.dstack(frames)

        # Load corresponding heatmap (ball location)
        heatmap = record['heatmap']

        if self.transform is not None:
            img = self.transform(img)
            heatmap = self.transform(heatmap)

        return img, heatmap

# Preprocessing pipeline: convert the input to a tensor first, then resize
# it to a fixed 224x224 resolution.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(size=(224, 224))]
)




# Training Loop

In [None]:
import torch.optim as optim

# Initialize model, optimizer, and loss function
model = TrackNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# The training loop calls `criterion`, which was never defined.  The model
# ends in a sigmoid, so binary cross-entropy on probabilities is used here.
# NOTE(review): BCELoss needs target heatmaps in [0, 1] with the same shape
# as the model output -- confirm against the dataset.
criterion = nn.BCELoss()

# Training loop
def train_model(model, data_loader, num_epochs=25, *, optimizer=None,
                criterion=None):
    """Train ``model`` on ``data_loader`` and print the mean loss per epoch.

    Args:
        model: Module to train; it is switched to training mode.
        data_loader: Iterable of ``(imgs, heatmaps)`` tensor batches.  It is
            iterated once per epoch.
        num_epochs: Number of passes over ``data_loader``.
        optimizer: Optimizer to step.  When omitted, a module-level
            ``optimizer`` is used if one exists; otherwise an Adam optimizer
            (lr=0.001) over ``model.parameters()`` is created.
        criterion: Loss callable ``criterion(outputs, heatmaps)``.  When
            omitted, a module-level ``criterion`` is used if one exists;
            otherwise ``nn.BCELoss()`` is created.

    Returns:
        List with the mean batch loss of every epoch (``nan`` for an epoch
        in which ``data_loader`` produced no batches).
    """
    # Resolve collaborators: explicit argument > module-level name > default.
    # The module-level lookup keeps existing `train_model(model, loader)`
    # calls working with whatever optimizer/criterion the notebook defined.
    module_scope = globals()
    if optimizer is None:
        optimizer = module_scope.get('optimizer')
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=0.001)
    if criterion is None:
        criterion = module_scope.get('criterion')
    if criterion is None:
        criterion = nn.BCELoss()

    model.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for imgs, heatmaps in data_loader:
            imgs = imgs.float()
            heatmaps = heatmaps.float()

            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, heatmaps)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually seen, so an empty loader reports
        # nan instead of raising ZeroDivisionError.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')

    return epoch_losses

# Example to train:
# train_model(model, data_loader)
