# CNN Fusion Model

In [None]:
    # take 3-5 heatmaps of a single frame
    # tile each into 16*9 tiles and label each tile with a single index, running left to right along each row and row by row from top to bottom
    # convert into 3d matrix (each depth is a heatmap)
    # size of input matrix (16*pixWidthPerTile)*(9*pixHeightPerTile)*(3|4|5)
    # assume 4k consumer/customer market standard: 4k(3840x1920)
    # also take the head direction center and the corresponding tile's index for data per frame as input
    # we should be taking the 3d matrix of heatmaps and the tile index where the user is looking per frame as our data
    # predict the tile that the user will look at based on the heatmaps (weighted by each heatmap then fused) to find the tile with highest saliency
    # check the predicted tile index vs the actual inputted tile index for user view
    # loss/error function will be the difference in tile index (using the standard loss func) to optimize by minimizing the loss
    # do this per frame and get an accuracy as well as runtime calculation

In [None]:
!pip install torch torchvision tqdm tabulate --quiet

import time, os, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from tabulate import tabulate
from tqdm import tqdm
import gc

# Pick the compute device; when CUDA is present, first release any cached
# GPU allocations that may be left over from earlier work in this process.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    device = "cuda"
else:
    device = "cpu"

print("Using device:", device)
if device == "cuda":
    # Report which GPU (index 0) is in use and its total memory in GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")

# Spatial configuration: frames are processed at 25% of 4K (3840x1920)
# along each axis.
FRAME_WIDTH, FRAME_HEIGHT = 960, 480

# Tile grid laid over each frame.
TILES_X, TILES_Y = 16, 9
NUM_TILES = TILES_X * TILES_Y          # 144 tiles
TILE_WIDTH = FRAME_WIDTH // TILES_X    # 60 pixels
TILE_HEIGHT = FRAME_HEIGHT // TILES_Y  # 53 pixels (480 / 9, rounded down)

# Number of saliency heatmaps stacked per frame.
NUM_HEATMAPS = 9  # 7 audio + 2 video

# Temporal configuration.
FPS = 60                # frames per second (full framerate)
VIDEO_DURATION = 3      # seconds
FRAME_SAMPLE_RATE = 5   # keep every 5th frame
TOTAL_FRAMES = FPS * VIDEO_DURATION                   # 180 frames in 3 seconds
FRAMES_PER_VIDEO = TOTAL_FRAMES // FRAME_SAMPLE_RATE  # 36 sampled frames

class SaliencyTileDataset(Dataset):
    """Map-style dataset pairing a per-frame heatmap stack with its tile index.

    Item ``i`` is ``(heatmaps[i], tile_indices[i])`` returned as a float32
    tensor and a scalar int64 (``torch.long``) tensor respectively.
    """

    def __init__(self, heatmaps, tile_indices):
        """
        Args:
            heatmaps: array-like indexed by frame along axis 0, e.g. shape
                (N_frames, NUM_HEATMAPS, H, W). Stored as float32.
            tile_indices: array-like of shape (N_frames,) holding the target
                tile index of each frame. Stored as int64.

        Raises:
            ValueError: if the two inputs disagree on the number of frames.

        Note:
            ``np.asarray`` only copies when a dtype conversion is required, so
            an input that is already a float32 ndarray is kept as-is (no
            second copy of a potentially very large array). The dataset then
            shares memory with the caller's array.
        """
        self.heatmaps = np.asarray(heatmaps, dtype=np.float32)
        self.tile_indices = np.asarray(tile_indices, dtype=np.int64)

        if len(self.heatmaps) != len(self.tile_indices):
            raise ValueError(
                f"heatmaps has {len(self.heatmaps)} frames but tile_indices "
                f"has {len(self.tile_indices)} entries"
            )

    def __len__(self):
        """Return the number of frames."""
        return len(self.tile_indices)

    def __getitem__(self, idx):
        """Return ``(heatmap_stack, tile_index)`` for frame ``idx`` as tensors."""
        # from_numpy wraps the existing buffer instead of copying it.
        heatmap = torch.from_numpy(self.heatmaps[idx])
        tile_idx = torch.tensor(self.tile_indices[idx], dtype=torch.long)
        return heatmap, tile_idx


class HeatmapFusionCNN(nn.Module):
    """Small CNN that fuses a stack of saliency heatmaps into per-tile logits.

    Pipeline: 1x1 fusion conv -> three strided conv/batch-norm/ReLU stages ->
    adaptive average pool to 4x4 -> two-layer classifier with dropout.
    The output has one logit per tile (``num_tiles`` columns).
    """

    def __init__(self, num_heatmaps=9, num_tiles=144, dropout=0.3):
        super().__init__()

        # A 1x1 convolution mixes the input heatmaps per pixel into 8 channels.
        self.fusion = nn.Conv2d(num_heatmaps, 8, kernel_size=1)

        # Downsampling feature extractor (strides 4, 2, 2).
        self.conv1 = nn.Conv2d(8, 32, kernel_size=5, stride=4, padding=2)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Fixed 4x4 spatial output so the classifier input size does not
        # depend on the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))

        # Classification head.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(256, num_tiles)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a batch of heatmap stacks (N, num_heatmaps, H, W) to tile logits."""
        features = self.relu(self.fusion(x))

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            features = self.relu(norm(conv(features)))

        pooled = self.pool(features)
        flat = pooled.view(pooled.size(0), -1)

        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)


def tile_coords_to_index(x, y, tiles_x=TILES_X):
    """Return the row-major linear index of the tile at column ``x``, row ``y``.

    ``tiles_x`` is the number of tiles per row.
    """
    row_start = y * tiles_x
    return row_start + x


def tile_index_to_coords(idx, tiles_x=TILES_X):
    """Split a row-major linear tile index into ``(x, y)`` grid coordinates.

    ``x`` is the column (remainder) and ``y`` the row (quotient) for a grid
    with ``tiles_x`` tiles per row.
    """
    return idx % tiles_x, idx // tiles_x


def tile_distance(pred_idx, true_idx, tiles_x=TILES_X):
    """Return the Euclidean distance, in tile units, between two tile indices.

    Both indices are converted to ``(x, y)`` grid coordinates first.
    """
    # NOTE(review): this is a plain planar distance; no horizontal wrap-around
    # is applied — confirm whether that is intended for 360-degree frames.
    pred_col, pred_row = tile_index_to_coords(pred_idx, tiles_x)
    true_col, true_row = tile_index_to_coords(true_idx, tiles_x)
    d_col = pred_col - true_col
    d_row = pred_row - true_row
    return np.sqrt(d_col**2 + d_row**2)


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train; switched to training mode.
        dataloader: iterable yielding ``(heatmaps, tile_indices)`` batches.
        criterion: loss callable. It is assumed to return the *mean* loss of
            the batch (the default for ``nn.CrossEntropyLoss``).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for heatmaps, tile_indices in dataloader:
        heatmaps = heatmaps.to(device)
        tile_indices = tile_indices.to(device)

        optimizer.zero_grad()
        outputs = model(heatmaps)
        loss = criterion(outputs, tile_indices)
        loss.backward()
        optimizer.step()

        batch_size = tile_indices.size(0)
        # Weight each batch's mean loss by its size so that a smaller final
        # batch does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == tile_indices).sum().item()

    if total == 0:
        raise ValueError("train_epoch received a dataloader with no samples")

    return loss_sum / total, 100. * correct / total


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: iterable yielding ``(heatmaps, tile_indices)`` batches.
        criterion: loss callable. It is assumed to return the *mean* loss of
            the batch (the default for ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent, mean_tile_distance)`` over all
        samples, where the distance is computed by ``tile_distance`` between
        the predicted and the true tile.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    tile_distances = []

    with torch.no_grad():
        for heatmaps, tile_indices in dataloader:
            heatmaps = heatmaps.to(device)
            tile_indices = tile_indices.to(device)

            outputs = model(heatmaps)
            loss = criterion(outputs, tile_indices)
            predicted = outputs.argmax(dim=1)

            batch_size = tile_indices.size(0)
            # Weight each batch's mean loss by its size so that a smaller
            # final batch does not skew the reported (and scheduler-driving)
            # validation loss.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (predicted == tile_indices).sum().item()

            tile_distances.extend(
                tile_distance(pred, true)
                for pred, true in zip(predicted.cpu().numpy(),
                                      tile_indices.cpu().numpy())
            )

    if total == 0:
        raise ValueError("validate received a dataloader with no samples")

    avg_distance = np.mean(tile_distances)
    return loss_sum / total, 100. * correct / total, avg_distance


def generate_synthetic_video_data(num_videos=1):
    """
    Generate synthetic video data at 60 FPS, sampling every 5th frame

    For each video a point of attention moves along a straight line between
    two random positions (with Gaussian jitter). For each sampled frame,
    NUM_HEATMAPS Gaussian heatmaps are drawn around that point, each with its
    own random spread and a small random offset. The label of a frame is the
    tile containing the maximum of the equally-weighted sum of its heatmaps.

    Args:
        num_videos: number of independent videos to generate.

    Returns:
        heatmaps_4d: float32 array of shape
            (num_videos, FRAMES_PER_VIDEO, NUM_HEATMAPS, H, W)
            (five axes, despite the historical name)
        tile_indices_1d: int64 array of shape (num_videos, FRAMES_PER_VIDEO)
            (two axes, despite the historical name)
    """
    print(f"\n{'='*60}")
    print(f"Generating Synthetic Video Data")
    print(f"{'='*60}\n")
    print(f"Videos: {num_videos}")
    print(f"Duration: {VIDEO_DURATION}s @ {FPS} FPS = {TOTAL_FRAMES} total frames")
    print(f"Sampling: Every {FRAME_SAMPLE_RATE}th frame = {FRAMES_PER_VIDEO} sampled frames")
    print(f"Heatmap resolution: {FRAME_WIDTH}x{FRAME_HEIGHT} (25% of 4K)")
    print(f"Total sampled frames: {num_videos * FRAMES_PER_VIDEO}\n")

    # Heatmap storage: (videos, sampled_frames, heatmaps, height, width)
    heatmaps_4d = np.zeros((num_videos, FRAMES_PER_VIDEO, NUM_HEATMAPS,
                            FRAME_HEIGHT, FRAME_WIDTH), dtype=np.float32)

    # Tile index storage: (videos, sampled_frames)
    tile_indices_1d = np.zeros((num_videos, FRAMES_PER_VIDEO), dtype=np.int64)

    # The pixel coordinate grids are identical for every heatmap, so build
    # them once instead of once per heatmap.
    grid_y, grid_x = np.ogrid[:FRAME_HEIGHT, :FRAME_WIDTH]

    # Largest per-heatmap centre offset, in pixels, on each axis. Computed as
    # positive values and negated explicitly: `-TILE_HEIGHT//2` would parse
    # as `(-TILE_HEIGHT)//2` and round away from zero for odd tile sizes,
    # making the offset range asymmetric.
    max_offset_x = TILE_WIDTH // 2
    max_offset_y = TILE_HEIGHT // 2

    for video_idx in range(num_videos):
        print(f"Generating video {video_idx + 1}/{num_videos}...")

        # Start and end points of the attention trajectory, kept at least
        # two tiles away from the frame border.
        start_x = np.random.randint(TILE_WIDTH * 2, FRAME_WIDTH - TILE_WIDTH * 2)
        start_y = np.random.randint(TILE_HEIGHT * 2, FRAME_HEIGHT - TILE_HEIGHT * 2)

        end_x = np.random.randint(TILE_WIDTH * 2, FRAME_WIDTH - TILE_WIDTH * 2)
        end_y = np.random.randint(TILE_HEIGHT * 2, FRAME_HEIGHT - TILE_HEIGHT * 2)

        # Linear trajectory over ALL frames (before sampling)
        trajectory_x = np.linspace(start_x, end_x, TOTAL_FRAMES)
        trajectory_y = np.linspace(start_y, end_y, TOTAL_FRAMES)

        # Add jitter with a standard deviation of a quarter tile
        jitter_x = np.random.randn(TOTAL_FRAMES) * (TILE_WIDTH / 4)
        jitter_y = np.random.randn(TOTAL_FRAMES) * (TILE_HEIGHT / 4)

        trajectory_x = np.clip(trajectory_x + jitter_x, 0, FRAME_WIDTH - 1)
        trajectory_y = np.clip(trajectory_y + jitter_y, 0, FRAME_HEIGHT - 1)

        # Iterate over exactly FRAMES_PER_VIDEO samples so the write index can
        # never run past the preallocated arrays, even when TOTAL_FRAMES is
        # not a multiple of FRAME_SAMPLE_RATE.
        for sampled_frame_idx in range(FRAMES_PER_VIDEO):
            full_frame_idx = sampled_frame_idx * FRAME_SAMPLE_RATE

            # Current center of attention at this frame
            cx = int(trajectory_x[full_frame_idx])
            cy = int(trajectory_y[full_frame_idx])

            fused_saliency = np.zeros((FRAME_HEIGHT, FRAME_WIDTH), dtype=np.float32)

            for heatmap_idx in range(NUM_HEATMAPS):
                # Random sigma for variety in saliency spread
                sigma = FRAME_WIDTH / (6 + np.random.rand() * 4)  # Varies between /6 and /10

                # Small random offset for each heatmap, symmetric around 0
                # (randint's upper bound is exclusive, hence the +1)
                offset_x = np.random.randint(-max_offset_x, max_offset_x + 1)
                offset_y = np.random.randint(-max_offset_y, max_offset_y + 1)

                center_x = np.clip(cx + offset_x, 0, FRAME_WIDTH - 1)
                center_y = np.clip(cy + offset_y, 0, FRAME_HEIGHT - 1)

                # Gaussian saliency map centred on (center_x, center_y)
                heatmap = np.exp(-((grid_x - center_x)**2 + (grid_y - center_y)**2)
                                 / (2 * sigma**2))

                heatmaps_4d[video_idx, sampled_frame_idx, heatmap_idx] = heatmap

                # Accumulate saliency (equal weight, let model learn optimal fusion)
                fused_saliency += heatmap

            # Label: the tile containing the fused-saliency maximum. The min()
            # clamps pixels that fall past the last full tile (FRAME_HEIGHT is
            # not an exact multiple of TILE_HEIGHT).
            max_y, max_x = np.unravel_index(fused_saliency.argmax(), fused_saliency.shape)
            tile_x = min(max_x // TILE_WIDTH, TILES_X - 1)
            tile_y = min(max_y // TILE_HEIGHT, TILES_Y - 1)
            tile_indices_1d[video_idx, sampled_frame_idx] = tile_coords_to_index(tile_x, tile_y)

    print(f"\nData generation complete!")
    print(f"  Heatmaps 4D shape: {heatmaps_4d.shape}")
    print(f"  Tile indices shape: {tile_indices_1d.shape}")
    print(f"  Memory usage: {heatmaps_4d.nbytes / 1e6:.1f} MB\n")

    return heatmaps_4d, tile_indices_1d


def flatten_video_data(heatmaps_4d, tile_indices_1d):
    """
    Flatten video data from (videos, frames, ...) to (total_frames, ...)

    The per-frame dimensions are taken from ``heatmaps_4d`` itself, so the
    function works for any heatmap count or resolution rather than only for
    the module-level configuration.

    Args:
        heatmaps_4d: shape (num_videos, frames_per_video, num_heatmaps, H, W)
        tile_indices_1d: shape (num_videos, frames_per_video)

    Returns:
        heatmaps: shape (total_frames, num_heatmaps, H, W)
        tile_indices: shape (total_frames,)

    Raises:
        ValueError: if ``heatmaps_4d`` cannot be reshaped to ``total_frames``
            leading entries (raised by ``reshape``).
    """
    num_videos, frames_per_video = tile_indices_1d.shape
    total_frames = num_videos * frames_per_video

    # Merge the (videos, frames) axes; keep every trailing axis unchanged.
    per_frame_shape = tuple(heatmaps_4d.shape[2:])
    heatmaps = heatmaps_4d.reshape((total_frames,) + per_frame_shape)

    # Flatten tile indices: (videos, frames) -> (total_frames,)
    tile_indices = tile_indices_1d.reshape(total_frames)

    return heatmaps, tile_indices


# Main execution
if __name__ == "__main__":
    # End-to-end demo: generate synthetic data, train the fusion CNN, evaluate
    # the best checkpoint on the held-out split, time single-frame inference,
    # and print a sample of the per-frame tile labels.
    print(f"\n{'='*60}")
    print("360° Video Saliency-Based Tile Predictor")
    print(f"{'='*60}\n")

    print(f"Configuration:")
    print(f"  Frame size: {FRAME_WIDTH}x{FRAME_HEIGHT} (25% of 4K)")
    print(f"  Grid: {TILES_X}x{TILES_Y} = {NUM_TILES} tiles")
    print(f"  Tile size: {TILE_WIDTH}x{TILE_HEIGHT}")
    print(f"  Heatmaps per frame: {NUM_HEATMAPS}")
    print(f"  Video: {VIDEO_DURATION}s @ {FPS}fps = {TOTAL_FRAMES} total frames")
    print(f"  Sampling: Every {FRAME_SAMPLE_RATE}th frame = {FRAMES_PER_VIDEO} sampled frames\n")

    # Generate synthetic video data - just 1 video for memory efficiency
    num_videos = 1  # 1 video of 3 seconds
    heatmaps_4d, tile_indices_1d = generate_synthetic_video_data(num_videos=num_videos)

    # The summary below is derived from the configuration so it cannot drift
    # from the actual array shapes when a constant is changed.
    video_label = "video" if num_videos == 1 else "videos"
    print(f"{'='*60}")
    print("Data Structure:")
    print(f"{'='*60}")
    print(f"4D Heatmaps Matrix: {heatmaps_4d.shape}")
    print(f"  [videos, sampled_frames, heatmaps, height, width]")
    print(f"  [{num_videos}, {FRAMES_PER_VIDEO}, {NUM_HEATMAPS}, {FRAME_HEIGHT}, {FRAME_WIDTH}]")
    print(f"  - {num_videos} {video_label}")
    print(f"  - {FRAMES_PER_VIDEO} sampled frames (every {FRAME_SAMPLE_RATE}th frame "
          f"from {TOTAL_FRAMES} total @ {FPS} FPS)")
    print(f"  - {NUM_HEATMAPS} heatmaps per frame (7 audio + 2 video)")
    print(f"  - {FRAME_HEIGHT}x{FRAME_WIDTH} resolution (25% of 4K)")
    print(f"\n1D Tile Index Array: {tile_indices_1d.shape}")
    print(f"  [videos, sampled_frames]")
    print(f"  [{num_videos}, {FRAMES_PER_VIDEO}]")
    print(f"  - 1 user's viewing data")
    print(f"  - {FRAMES_PER_VIDEO} tile indices (sampled every {FRAME_SAMPLE_RATE}th frame, "
          f"range 0-{NUM_TILES - 1})")
    print(f"  - Each index represents which tile user is looking at")
    print(f"{'='*60}\n")

    # Flatten data for training
    heatmaps_flat, tile_indices_flat = flatten_video_data(heatmaps_4d, tile_indices_1d)

    print(f"Flattened for Training:")
    print(f"  Heatmaps: {heatmaps_flat.shape}")
    print(f"  Tile indices: {tile_indices_flat.shape}\n")

    # Create dataset
    dataset = SaliencyTileDataset(heatmaps_flat, tile_indices_flat)

    # 70 / 15 / 15 split; the test split takes whatever rounding leaves over.
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

    print(f"Dataset splits:")
    print(f"  Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}\n")

    # Initialize model
    model = HeatmapFusionCNN(num_heatmaps=NUM_HEATMAPS, num_tiles=NUM_TILES).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    # Training
    num_epochs = 15
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting at 0 with a strict '>' would leave no checkpoint
    # (or a stale file from an earlier run) whenever validation accuracy
    # never rises above 0%, breaking the load before test evaluation.
    best_val_acc = float("-inf")
    results = []

    print("Training started...\n")
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, avg_tile_dist = validate(model, val_loader, criterion, device)

        scheduler.step(val_loss)
        epoch_time = time.time() - epoch_start

        results.append([
            epoch + 1,
            f"{train_loss:.4f}",
            f"{train_acc:.1f}%",
            f"{val_loss:.4f}",
            f"{val_acc:.1f}%",
            f"{avg_tile_dist:.2f}",
            f"{epoch_time:.1f}s"
        ])

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train {train_acc:.1f}% | Val {val_acc:.1f}% | "
              f"Dist {avg_tile_dist:.2f} | {epoch_time:.1f}s")

        # Clear cache periodically
        if epoch % 3 == 0 and device == "cuda":
            torch.cuda.empty_cache()

    total_time = time.time() - start_time

    # Test evaluation on the best checkpoint (by validation accuracy)
    print("\nEvaluating on test set...")
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    test_loss, test_acc, test_tile_dist = validate(model, test_loader, criterion, device)

    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}\n")

    print(f"Best Val Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Test Avg Tile Distance: {test_tile_dist:.2f} tiles")
    print(f"Total Training Time: {total_time:.1f}s ({total_time/60:.1f} min)\n")

    # Training history table
    headers = ['Epoch', 'Train Loss', 'Train Acc', 'Val Loss', 'Val Acc', 'Tile Dist', 'Time']
    print(tabulate(results, headers=headers, tablefmt='simple'))

    # Inference speed test
    print(f"\n{'='*60}")
    print("Inference Speed Test")
    print(f"{'='*60}\n")

    model.eval()
    test_input = torch.randn(1, NUM_HEATMAPS, FRAME_HEIGHT, FRAME_WIDTH).to(device)

    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(test_input)

    # Make sure the warmup work has finished before timing starts
    if device == "cuda":
        torch.cuda.synchronize()

    # perf_counter is the high-resolution monotonic clock intended for
    # measuring short intervals such as a single forward pass.
    inference_times = []
    with torch.no_grad():
        for _ in range(100):
            start = time.perf_counter()
            output = model(test_input)
            if device == "cuda":
                torch.cuda.synchronize()
            inference_times.append(time.perf_counter() - start)

    avg_inf_ms = np.mean(inference_times) * 1000
    fps = 1000 / avg_inf_ms

    print(f"Average inference: {avg_inf_ms:.2f}ms")
    print(f"Throughput: {fps:.1f} FPS")
    print(f"Can process: {fps/FPS:.1f}x realtime")

    pred_tile = output.argmax(1).item()
    x, y = tile_index_to_coords(pred_tile)
    print(f"Sample prediction: Tile {pred_tile} at ({x}, {y})")

    print(f"\n{'='*60}")
    print("Model saved as 'best_model.pth'")
    print(f"{'='*60}\n")

    # Sample video analysis - show a subset of the sampled frames of video 0
    print(f"{'='*60}")
    print(f"User Viewing Data - Video 0 ({VIDEO_DURATION} seconds, every {FRAME_SAMPLE_RATE}th frame)")
    print(f"{'='*60}")
    print(f"Sample# | Frame# | Time(s) | Tile Index | Tile Grid (x,y)")
    print(f"{'-'*60}")
    display_step = max(1, FRAMES_PER_VIDEO//12)  # Show ~12 samples
    for i in range(0, FRAMES_PER_VIDEO, display_step):
        tile_idx = tile_indices_1d[0, i]
        tx, ty = tile_index_to_coords(tile_idx)
        actual_frame = i * FRAME_SAMPLE_RATE
        time_sec = actual_frame / FPS
        print(f"{i:7d} | {actual_frame:6d} | {time_sec:6.2f}  | {tile_idx:10d} | ({tx:2d}, {ty:2d})")
    print(f"{'='*60}\n")

    # Memory cleanup
    del model, train_loader, val_loader, test_loader
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
        print(f"GPU memory freed: {torch.cuda.memory_allocated()/1e9:.2f} GB in use")

Using device: cpu

360° Video Saliency-Based Tile Predictor

Configuration:
  Frame size: 960x480 (25% of 4K)
  Grid: 16x9 = 144 tiles
  Tile size: 60x53
  Heatmaps per frame: 9
  Video: 3s @ 60fps = 180 total frames
  Sampling: Every 5th frame = 36 sampled frames


Generating Synthetic Video Data

Videos: 1
Duration: 3s @ 60 FPS = 180 total frames
Sampling: Every 5th frame = 36 sampled frames
Heatmap resolution: 960x480 (25% of 4K)
Total sampled frames: 36

Generating video 1/1...

Data generation complete!
  Heatmaps 4D shape: (1, 36, 9, 480, 960)
  Tile indices shape: (1, 36)
  Memory usage: 597.2 MB

Data Structure:
4D Heatmaps Matrix: (1, 36, 9, 480, 960)
  [videos, sampled_frames, heatmaps, height, width]
  [1, 36, 9, 480, 960]
  - 1 video
  - 36 sampled frames (every 5th frame from 180 total @ 60 FPS)
  - 9 heatmaps per frame (7 audio + 2 video)
  - 480x960 resolution (25% of 4K)

1D Tile Index Array: (1, 36)
  [videos, sampled_frames]
  [1, 36]
  - 1 user's viewing data
  - 36 