# PointNet for Radar Data
Adapted for 5-D radar points with per-point features `[x, y, z, v, snr]`.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np

In [None]:
def load_radar_lines(file_path):
    """Read a JSON-Lines radar file, one frame per non-blank line.

    Lines that are not valid JSON are skipped (best effort), but they are
    counted and reported so that corrupted input does not go unnoticed.

    Args:
        file_path: Path to the text file to read (decoded as UTF-8, the
            encoding required for JSON).

    Returns:
        list: The successfully decoded frames, in file order.
    """
    valid_frames = []
    skipped = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                valid_frames.append(json.loads(line))
            except json.JSONDecodeError:
                # Keep loading the rest of the file, but remember the failure.
                skipped += 1
    message = f"Loaded {len(valid_frames)} frames."
    if skipped:
        message += f" Skipped {skipped} malformed line(s)."
    print(message)
    return valid_frames

In [None]:
def process_frame(frame, num_points=50, rng=None):
    """Convert one radar frame into a fixed-size ``(num_points, 5)`` array.

    Features are read from the frame keys ``x, y, z, v, snr`` (one column
    each). A feature key that is missing, null or empty is filled with zeros;
    if no feature has any values the result is all zeros.

    Frames with at least ``num_points`` points are randomly sub-sampled
    without replacement. Frames with fewer points keep every real point once
    and are padded with randomly chosen duplicates, so no measured point is
    discarded.

    Args:
        frame: Mapping of feature name to a sequence of per-point values
            (a scalar is treated as a single point).
        num_points: Number of rows in the returned array.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling;
            defaults to the global ``numpy.random`` state.

    Returns:
        numpy.ndarray of shape ``(num_points, 5)`` with dtype float64.

    Raises:
        ValueError: If the non-empty features have different lengths.
    """
    keys = ('x', 'y', 'z', 'v', 'snr')

    def _column(key):
        # Normalise one feature to a flat float array ([] if absent/null).
        value = frame.get(key)
        if value is None:
            value = []
        return np.asarray(value, dtype=float).ravel()

    columns = [_column(k) for k in keys]
    lengths = {len(c) for c in columns if len(c)}
    if not lengths:
        return np.zeros((num_points, len(keys)))
    if len(lengths) > 1:
        detail = {k: len(c) for k, c in zip(keys, columns)}
        raise ValueError(f"Inconsistent feature lengths in frame: {detail}")

    n = lengths.pop()
    points = np.stack([c if len(c) else np.zeros(n) for c in columns], axis=1)

    sampler = np.random if rng is None else rng
    if n >= num_points:
        idx = sampler.choice(n, num_points, replace=False)
    else:
        # Keep every real point, then pad with random duplicates.
        pad = sampler.choice(n, num_points - n, replace=True)
        idx = np.concatenate([np.arange(n), pad])
    return points[idx]  # (num_points, 5)

In [None]:
class RadarDataset(Dataset):
    """Map-style dataset yielding fixed-size radar point clouds.

    Each item is ``(points, label)``: ``points`` is a float32 tensor laid out
    channels-first ``(features, num_points)`` as ``nn.Conv1d`` expects, and
    ``label`` is a dummy ``0``.
    """

    def __init__(self, frames, num_points=50):
        self.frames = frames
        self.num_points = num_points

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        sampled = process_frame(self.frames[idx], self.num_points)
        tensor = torch.from_numpy(sampled).float()
        # (num_points, features) -> (features, num_points) for Conv1d.
        return tensor.t(), 0

In [None]:
class TNet(nn.Module):
    """Transformation network (T-Net) predicting a k x k alignment matrix.

    Input is ``(batch, k, num_points)`` (``conv1`` has ``k`` input channels).
    Output is ``(batch, k, k)``, computed as a learned residual added to the
    identity matrix, as in the original PointNet. Without the identity term
    the transform is an arbitrary matrix produced by ``fc3``, which would
    scramble the input features.
    """
    def __init__(self, k=5):
        super().__init__()
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.k = k

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Global max pool over the point axis -> (batch, 1024).
        x = torch.max(x, 2)[0]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x).view(-1, self.k, self.k)
        # Residual around identity; eye(k) broadcasts over the batch.
        identity = torch.eye(self.k, device=x.device, dtype=x.dtype)
        return x + identity

In [None]:
class PointNetBackbone(nn.Module):
    """Global feature extractor for point clouds.

    Aligns the input with a T-Net, applies shared per-point layers (1x1
    convolutions), then max-pools over points.

    Input is ``(batch, input_dim, num_points)``; output is ``(batch, 1024)``.
    """
    def __init__(self, input_dim=5):
        super().__init__()
        self.tnet = TNet(k=input_dim)
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

    def forward(self, x):
        transform = self.tnet(x)
        # (B, C, N) -> (B, N, C), multiply by (B, C, C), back to (B, C, N).
        aligned = torch.bmm(x.transpose(2, 1), transform).transpose(2, 1)

        features = F.relu(self.conv1(aligned))
        features = F.relu(self.conv2(features))
        features = self.conv3(features)

        # Max over the point axis is a symmetric function of the points.
        pooled, _ = torch.max(features, 2, keepdim=True)
        return pooled.view(-1, 1024)

In [None]:
class PointNetClassifier(nn.Module):
    """PointNet classification head on top of ``PointNetBackbone``.

    Args:
        num_classes: Number of output classes.
        input_dim: Number of per-point features passed to the backbone
            (defaults to 5, i.e. ``[x, y, z, v, snr]``).

    ``forward(x)`` takes ``(batch, input_dim, num_points)`` and returns
    log-probabilities of shape ``(batch, num_classes)``.
    """
    def __init__(self, num_classes=2, input_dim=5):
        super().__init__()
        self.feat = PointNetBackbone(input_dim=input_dim)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.feat(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Log-probabilities, suitable for nn.NLLLoss.
        return F.log_softmax(self.fc3(x), dim=1)

In [None]:
# Execution and testing block: load data, build the model and run a single
# forward pass as a smoke test.

# 1. Load Data
frames = load_radar_lines('data/first.txt')

# 2. Dataset & Loader
dataset = RadarDataset(frames, num_points=50)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 3. Initialize Model
model = PointNetClassifier(num_classes=2)
print("Model Initialized.")

# 4. Dry Run (Forward Pass)
if len(dataset) > 0:
    sample_batch, _ = next(iter(dataloader))
    # (B, 5, 50); B is 32 unless fewer frames were loaded.
    print(f"Input shape: {sample_batch.shape}")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(sample_batch)
    print(f"Output shape: {output.shape}")  # (B, 2)
    print("Dry run successful!")
else:
    print("No frames loaded; skipping dry run.")