<a href="https://colab.research.google.com/github/AI01010/FoV_ML_spatial_audio_prediction/blob/main/FoV_ML_spatial_audio_prediction_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision joblib pandas tqdm tabulate --quiet

import time, os, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np, pandas as pd
from tabulate import tabulate
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Configuration
# Frame size in pixels (equirectangular / ERP layout, 2:1 aspect ratio).
FRAME_WIDTH, FRAME_HEIGHT = 3840, 1920
# Tile grid: columns x rows.
TILES_X, TILES_Y = 16, 9
NUM_TILES = TILES_X * TILES_Y  # 16 * 9 = 144 tiles
# Floor division: 3840 // 16 = 240 exactly, while 1920 // 9 = 213 leaves
# 3 pixel rows at the bottom of the frame outside the nominal tile grid.
TILE_WIDTH = FRAME_WIDTH // TILES_X
TILE_HEIGHT = FRAME_HEIGHT // TILES_Y
NUM_HEATMAPS = 5  # 3 audio + 2 video

class SaliencyTileDataset(Dataset):
    """Dataset for saliency heatmaps and corresponding tile indices.

    Indexing returns a ``(heatmap, tile_idx)`` pair: the heatmap stack for one
    sample (after the optional transform) and its target tile index.
    """

    def __init__(self, heatmaps, tile_indices, transform=None):
        """
        Args:
            heatmaps: array-like of shape (N, NUM_HEATMAPS, H, W)
            tile_indices: array-like of shape (N,) with tile indices [0-143]
            transform: optional callable applied to each heatmap stack

        Raises:
            ValueError: if ``heatmaps`` and ``tile_indices`` differ in length.
        """
        # torch.as_tensor with an explicit dtype replaces the legacy typed
        # constructors (FloatTensor / LongTensor). It converts only when
        # needed, so an input that already has the target dtype may share
        # memory with the returned tensor instead of being copied.
        self.heatmaps = torch.as_tensor(heatmaps, dtype=torch.float32)
        self.tile_indices = torch.as_tensor(tile_indices, dtype=torch.long)

        # Fail at construction time rather than with an IndexError (or with
        # silently ignored samples) in the middle of an epoch.
        if len(self.heatmaps) != len(self.tile_indices):
            raise ValueError(
                f"heatmaps has {len(self.heatmaps)} samples but "
                f"tile_indices has {len(self.tile_indices)}"
            )

        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per tile index)."""
        return len(self.tile_indices)

    def __getitem__(self, idx):
        """Return ``(heatmap, tile_idx)`` for sample ``idx``."""
        heatmap = self.heatmaps[idx]

        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container module) is still applied.
        if self.transform is not None:
            heatmap = self.transform(heatmap)

        return heatmap, self.tile_indices[idx]


class HeatmapFusionCNN(nn.Module):
    """Convolutional classifier that fuses stacked heatmaps into tile logits.

    The network takes a batch of ``num_heatmaps``-channel images and returns
    one score per tile (``num_tiles`` logits per sample). The spatial size of
    the input is free because the feature map is adaptively pooled to a fixed
    size before the fully connected head.
    """

    def __init__(self, num_heatmaps=5, num_tiles=144, dropout=0.3):
        """
        Args:
            num_heatmaps: number of input channels (heatmaps per sample)
            num_tiles: number of output classes (tiles)
            dropout: dropout probability used after fc1 and fc2
        """
        super().__init__()

        # 1x1 convolution: a learned per-pixel mix of the input heatmaps
        # into 16 fused channels.
        self.heatmap_weights = nn.Conv2d(num_heatmaps, 16, kernel_size=1, padding=0)

        # Stage 1: 16 -> 32 channels; conv and pool each halve the resolution.
        self.conv1 = nn.Conv2d(16, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 3: 64 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 4: 128 -> 256 channels at unchanged resolution (stride 1, no pool).
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Fixed-size pooling so the classifier head works for any input size.
        # NOTE(review): AdaptiveAvgPool2d takes (out_height, out_width), so
        # this produces 6 rows x 3 columns. For a frame that is wider than it
        # is tall, (3, 6) may have been intended -- confirm before changing,
        # since it alters what trained weights mean.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 3))

        # Classifier head: 256 * 6 * 3 pooled features -> 512 -> 256 -> num_tiles.
        self.fc1 = nn.Linear(256 * 6 * 3, 512)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(256, num_tiles)

        # One shared in-place ReLU is reused after every layer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a (batch, num_heatmaps, H, W) tensor to (batch, num_tiles) logits."""
        activate = self.relu

        # Learned fusion of the input heatmaps.
        x = activate(self.heatmap_weights(x))

        # Three conv -> batch-norm -> ReLU -> max-pool stages.
        downsampling_stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in downsampling_stages:
            x = pool(activate(norm(conv(x))))

        # Final conv stage has no pooling of its own.
        x = activate(self.bn4(self.conv4(x)))

        # Pool to the fixed grid, then flatten everything except the batch axis.
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)

        # Two hidden layers with dropout, then raw class scores (no softmax).
        x = self.dropout1(activate(self.fc1(x)))
        x = self.dropout2(activate(self.fc2(x)))
        return self.fc3(x)


def tile_coords_to_index(x, y, tiles_x=TILES_X):
    """Flatten a (column ``x``, row ``y``) tile position into one index.

    Indices run left to right within a row and rows follow one another, so
    each full row above contributes ``tiles_x`` to the result.
    """
    row_start = y * tiles_x
    return row_start + x


def tile_index_to_coords(idx, tiles_x=TILES_X):
    """Split a linear tile index into its ``(x, y)`` grid position.

    ``x`` is the column (remainder of dividing by ``tiles_x``) and ``y`` is
    the row (floor quotient); the pair is returned column first.
    """
    column = idx % tiles_x
    row = idx // tiles_x
    return column, row


def tile_distance(pred_idx, true_idx, tiles_x=TILES_X, tiles_y=TILES_Y):
    """Return the straight-line distance, in tile units, between two tiles.

    Both indices are first converted to (column, row) grid coordinates.
    Columns are compared directly, so the left and right edges of the grid
    are not treated as neighbours. ``tiles_y`` is accepted but is not used
    in the computation.
    """
    pred_col, pred_row = tile_index_to_coords(pred_idx, tiles_x)
    true_col, true_row = tile_index_to_coords(true_idx, tiles_x)

    col_gap = pred_col - true_col
    row_gap = pred_row - true_row
    return np.sqrt(col_gap**2 + row_gap**2)


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean of the per-batch loss values
        and the percentage of samples whose highest-scoring class equals the
        target.
    """
    model.train()
    loss_sum = 0
    num_correct = 0
    num_seen = 0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch_heatmaps, batch_targets in progress:
        batch_heatmaps = batch_heatmaps.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_heatmaps)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

        # Index of the highest-scoring class for every sample in the batch.
        best_class = logits.max(1)[1]
        num_seen += batch_targets.size(0)
        num_correct += (best_class == batch_targets).sum().item()

    # Loss is averaged over batches; accuracy is a percentage over samples.
    mean_batch_loss = loss_sum / len(dataloader)
    accuracy_pct = 100. * num_correct / num_seen
    return mean_batch_loss, accuracy_pct


def validate(model, dataloader, criterion, device):
    """Evaluate the model without updating it.

    Returns:
        A ``(loss, accuracy, avg_distance)`` tuple: the mean of the per-batch
        loss values, the percentage of correctly predicted samples, and the
        mean ``tile_distance`` between predicted and true tiles.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    tile_distances = []

    with torch.no_grad():
        for heatmaps, tile_indices in tqdm(dataloader, desc="Validating", leave=False):
            heatmaps = heatmaps.to(device)
            tile_indices = tile_indices.to(device)

            outputs = model(heatmaps)
            loss = criterion(outputs, tile_indices)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += tile_indices.size(0)
            correct += predicted.eq(tile_indices).sum().item()

            # tile_distance is built from elementwise operators, so one call
            # on the whole batch yields the same per-sample distances as a
            # Python loop over individual predictions, without the
            # per-element call overhead.
            batch_distances = tile_distance(
                predicted.cpu().numpy(), tile_indices.cpu().numpy()
            )
            tile_distances.extend(batch_distances)

    avg_distance = np.mean(tile_distances)
    return total_loss / len(dataloader), 100. * correct / total, avg_distance


def generate_synthetic_data(num_samples=1000):
    """Generate synthetic heatmap data for demonstration.

    Every sample consists of NUM_HEATMAPS full-frame heatmaps, each a
    Gaussian-shaped bump around a random focal point. The label is the index
    of the tile containing the maximum of the per-pixel mean of those
    heatmaps.

    Args:
        num_samples: number of samples to generate

    Returns:
        A ``(heatmaps, tile_indices)`` tuple of numpy arrays with shapes
        (num_samples, NUM_HEATMAPS, FRAME_HEIGHT, FRAME_WIDTH) and
        (num_samples,). The heatmap array is float64, so its size grows as
        num_samples * NUM_HEATMAPS * FRAME_HEIGHT * FRAME_WIDTH * 8 bytes.
    """
    print("Generating synthetic data...")

    # The pixel coordinate grids and the Gaussian denominator are the same
    # for every heatmap, so build them once instead of once per heatmap.
    rows, cols = np.ogrid[:FRAME_HEIGHT, :FRAME_WIDTH]
    two_sigma_sq = 2 * (FRAME_WIDTH / 6) ** 2

    # Write into preallocated outputs; collecting per-sample arrays in a list
    # and converting at the end would briefly hold two copies of the data.
    heatmaps = np.empty((num_samples, NUM_HEATMAPS, FRAME_HEIGHT, FRAME_WIDTH))
    tile_indices = np.empty(num_samples, dtype=np.int64)

    for sample_idx in range(num_samples):
        sample = heatmaps[sample_idx]

        for channel in range(NUM_HEATMAPS):
            # Random focal point (x drawn before y for every heatmap).
            cx = np.random.randint(0, FRAME_WIDTH)
            cy = np.random.randint(0, FRAME_HEIGHT)

            # Gaussian-like saliency. The squared distance is used directly;
            # taking a square root only to square it again is unnecessary.
            squared_distance = (cols - cx) ** 2 + (rows - cy) ** 2
            sample[channel] = np.exp(-squared_distance / two_sigma_sq)

        # Most salient tile: location of the maximum of the averaged heatmaps.
        fused = sample.mean(axis=0)
        max_y, max_x = np.unravel_index(fused.argmax(), fused.shape)

        # Clamp so pixels beyond the last full tile (the frame height is not
        # an exact multiple of TILE_HEIGHT) still map to a valid tile.
        tile_x = min(max_x // TILE_WIDTH, TILES_X - 1)
        tile_y = min(max_y // TILE_HEIGHT, TILES_Y - 1)
        tile_indices[sample_idx] = tile_coords_to_index(tile_x, tile_y)

    return heatmaps, tile_indices


# Main execution: build synthetic data, train, evaluate, then time inference.
if __name__ == "__main__":
    print(f"\n{'='*60}")
    print("360° Video Saliency-Based Tile Predictor")
    print(f"{'='*60}\n")

    print(f"Configuration:")
    print(f"  Frame size: {FRAME_WIDTH}x{FRAME_HEIGHT}")
    print(f"  Grid: {TILES_X}x{TILES_Y} = {NUM_TILES} tiles")
    print(f"  Tile size: {TILE_WIDTH}x{TILE_HEIGHT}")
    print(f"  Number of heatmaps: {NUM_HEATMAPS}")
    print(f"  Device: {device}\n")

    # Generate synthetic data
    # NOTE(review): at full resolution one sample is
    # NUM_HEATMAPS * FRAME_HEIGHT * FRAME_WIDTH float64 values (~295 MB), so
    # 2000 samples need roughly 590 GB before the dataset makes its float32
    # copy. Confirm whether a lower resolution or sample count was intended.
    heatmaps, tile_indices = generate_synthetic_data(num_samples=2000)

    # Create dataset
    dataset = SaliencyTileDataset(heatmaps, tile_indices)

    # Split into train/val/test (70% / 15% / remainder)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Fixed generator seed keeps the split reproducible between runs.
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

    print(f"Dataset splits:")
    print(f"  Training: {len(train_dataset)} samples")
    print(f"  Validation: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples\n")

    # Initialize model
    model = HeatmapFusionCNN(num_heatmaps=NUM_HEATMAPS, num_tiles=NUM_TILES).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    # Halve the learning rate when validation loss stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    # Training
    num_epochs = 20
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting at 0 with a strict ">" would leave no checkpoint
    # (or a stale one from an earlier run) if accuracy stayed at 0%.
    best_val_acc = float("-inf")
    results = []

    print("Starting training...\n")
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, avg_tile_dist = validate(model, val_loader, criterion, device)

        scheduler.step(val_loss)

        epoch_time = time.time() - epoch_start

        results.append({
            'Epoch': epoch + 1,
            'Train Loss': f"{train_loss:.4f}",
            'Train Acc': f"{train_acc:.2f}%",
            'Val Loss': f"{val_loss:.4f}",
            'Val Acc': f"{val_acc:.2f}%",
            'Avg Tile Dist': f"{avg_tile_dist:.2f}",
            'Time': f"{epoch_time:.2f}s"
        })

        # Keep the weights of the best epoch by validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        # Log the first epoch and every fifth one.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"  Avg Tile Distance: {avg_tile_dist:.2f} tiles")
            print(f"  Time: {epoch_time:.2f}s\n")

    total_time = time.time() - start_time

    # Final evaluation on test set, using the best checkpoint rather than
    # the weights from the last epoch.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    test_loss, test_acc, test_tile_dist = validate(model, test_loader, criterion, device)

    # Results summary
    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}\n")

    print("Final Results:")
    print(f"  Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Test Avg Tile Distance: {test_tile_dist:.2f} tiles")
    print(f"  Total Training Time: {total_time/60:.2f} minutes")
    print(f"  Avg Time per Epoch: {total_time/num_epochs:.2f}s\n")

    # Display results table
    print("\nTraining History:")
    print(tabulate(results, headers='keys', tablefmt='grid'))

    # Inference speed test on a single random full-resolution input
    print("\nInference Speed Test:")
    model.eval()
    test_input = torch.randn(1, NUM_HEATMAPS, FRAME_HEIGHT, FRAME_WIDTH).to(device)

    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(test_input)

    # Measure. Synchronize around each timed call so the queued GPU work
    # has finished before the clock is read.
    if device == "cuda":
        torch.cuda.synchronize()
    inference_times = []
    with torch.no_grad():
        for _ in range(100):
            start = time.perf_counter()  # monotonic, high-resolution timer
            output = model(test_input)
            if device == "cuda":
                torch.cuda.synchronize()
            inference_times.append(time.perf_counter() - start)

    avg_inference = np.mean(inference_times) * 1000  # Convert to ms
    fps = 1000 / avg_inference

    # Design notes (intended pipeline):
    # - Take 3-5 heatmaps of a single frame.
    # - Tile each heatmap into 16x9 tiles and label every tile with a single
    #   index, running column by column from left to right within a row and
    #   row by row from top to bottom.
    # - Stack the heatmaps into a 3D matrix (each depth slice is one heatmap)
    #   of size (16 * tile pixel width) x (9 * tile pixel height) x (3 | 4 | 5).
    # - Assume the 4K consumer-market standard frame size: 3840x1920.
    # - Also take the head-direction center and the index of the tile it
    #   falls in as per-frame input data.
    # - The data per frame is therefore the 3D heatmap matrix plus the tile
    #   index where the user is looking.
    # - Predict the tile the user will look at from the heatmaps (each
    #   heatmap weighted, then fused) to find the tile with highest saliency.
    # - Check the predicted tile index against the actual tile index of the
    #   user's view.
    # - The loss/error function is the difference in tile index (using the
    #   standard loss function); training minimizes it.
    # - Do this per frame and report accuracy as well as runtime.

    # The prediction comes from the last timed forward pass.
    predicted_tile = output.argmax(1).item()

    print(f"  Average inference time: {avg_inference:.2f}ms")
    print(f"  Throughput: {fps:.1f} FPS")
    print(f"  Predicted tile index: {predicted_tile}")

    tile_col, tile_row = tile_index_to_coords(predicted_tile)
    print(f"  Predicted tile coordinates: ({tile_col}, {tile_row})\n")

    print("Model saved as 'best_model.pth'")
    print(f"{'='*60}")
