# In [None]:
import time, os, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, random_split

from Scripts.finalCode.model import HeatmapFusionCNN, SaliencyTileDataset

# Generate synthetic data for a single video; per the original note this is one
# ~3-second clip, kept to one video to limit memory use.
num_videos = 1
heatmaps_4d, tile_indices_1d = generate_synthetic_video_data(num_videos=num_videos)

# Print a description of the generated arrays. Only the two `.shape` values are
# read from the data; the bracketed shapes and bullet counts are fixed text that
# describes the default configuration.
divider = "=" * 60
print("\n".join([
    divider,
    "Data Structure:",
    divider,
    f"4D Heatmaps Matrix: {heatmaps_4d.shape}",
    "  [videos, sampled_frames, heatmaps, height, width]",
    "  [1, 36, 9, 480, 960]",
    "  - 1 video",
    "  - 36 sampled frames (every 5th frame from 180 total @ 60 FPS)",
    "  - 9 heatmaps per frame (7 audio + 2 video)",
    "  - 480x960 resolution (25% of 4K)",
    f"\n1D Tile Index Array: {tile_indices_1d.shape}",
    "  [videos, sampled_frames]",
    "  [1, 36]",
    "  - 1 user's viewing data",
    "  - 36 tile indices (sampled every 5th frame, range 0-143)",
    "  - Each index represents which tile user is looking at",
    divider + "\n",
]))

# Reshape for training; presumably this merges the video and frame axes into a
# single sample axis -- confirm against flatten_video_data.
heatmaps_flat, tile_indices_flat = flatten_video_data(heatmaps_4d, tile_indices_1d)

print(
    "Flattened for Training:",
    f"  Heatmaps: {heatmaps_flat.shape}",
    f"  Tile indices: {tile_indices_flat.shape}\n",
    sep="\n",
)

# Wrap the flattened arrays in the project's dataset class.
dataset = SaliencyTileDataset(heatmaps_flat, tile_indices_flat)

# Roughly 70 / 15 / 15 split: train and val sizes are floored and test takes the
# remainder, so the three sizes always sum to len(dataset). The seeded generator
# makes the split reproducible across runs.
n_total = len(dataset)
train_size = int(0.7 * n_total)
val_size = int(0.15 * n_total)
test_size = n_total - (train_size + val_size)
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_rng
)

# One loader per split with shared settings; only training data is reshuffled.
loader_kwargs = dict(batch_size=16, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(
    "Dataset splits:",
    f"  Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}\n",
    sep="\n",
)

# Build the fusion CNN on the target device and report its size.
model = HeatmapFusionCNN(num_heatmaps=NUM_HEATMAPS, num_tiles=NUM_TILES).to(device)

total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"Model parameters: {total_params:,}\n")

# Cross-entropy over tile indices, optimised with Adam plus light L2 decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Halve the learning rate once the value passed to scheduler.step() has not
# decreased for more than 3 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

# Training
num_epochs = 15
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With an initial 0 and a strict `>`, a run whose validation
# accuracy never rose above 0% (easy with only a few validation samples) never
# saved best_model.pth, so the later reload either crashed or silently picked up
# a stale file left by an earlier run.
best_val_acc = float("-inf")
results = []  # one formatted row per epoch for the summary table
# Normalise `device` before testing for CUDA: `device == "cuda"` is False for a
# torch.device object or for an indexed string such as "cuda:0".
on_cuda = torch.device(device).type == "cuda"

print("Training started...\n")
# perf_counter is the monotonic, high-resolution clock intended for intervals.
start_time = time.perf_counter()

for epoch in range(num_epochs):
    epoch_start = time.perf_counter()

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, avg_tile_dist = validate(model, val_loader, criterion, device)

    # The plateau scheduler is driven by validation loss.
    scheduler.step(val_loss)
    epoch_time = time.perf_counter() - epoch_start

    results.append([
        epoch + 1,
        f"{train_loss:.4f}",
        f"{train_acc:.1f}%",
        f"{val_loss:.4f}",
        f"{val_acc:.1f}%",
        f"{avg_tile_dist:.2f}",
        f"{epoch_time:.1f}s"
    ])

    # Keep the weights from the epoch with the highest validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train {train_acc:.1f}% | Val {val_acc:.1f}% | "
          f"Dist {avg_tile_dist:.2f} | {epoch_time:.1f}s")

    # Release cached CUDA blocks on epochs 0, 3, 6, ...
    if epoch % 3 == 0 and on_cuda:
        torch.cuda.empty_cache()

total_time = time.perf_counter() - start_time

# Reload the best checkpoint saved during training and score it on the test split.
print("\nEvaluating on test set...")
model.load_state_dict(torch.load('best_model.pth'))
test_loss, test_acc, test_tile_dist = validate(model, test_loader, criterion, device)

banner = "=" * 60
print(f"\n{banner}", "Training Complete!", f"{banner}\n", sep="\n")

print(
    f"Best Val Accuracy: {best_val_acc:.2f}%",
    f"Test Accuracy: {test_acc:.2f}%",
    f"Test Avg Tile Distance: {test_tile_dist:.2f} tiles",
    f"Total Training Time: {total_time:.1f}s ({total_time/60:.1f} min)\n",
    sep="\n",
)

# Per-epoch history collected by the training loop.
headers = ['Epoch', 'Train Loss', 'Train Acc', 'Val Loss', 'Val Acc', 'Tile Dist', 'Time']
print(tabulate(results, headers=headers, tablefmt='simple'))

# Inference speed test
print(f"\n{'='*60}")
print("Inference Speed Test")
print(f"{'='*60}\n")

model.eval()
# Random input: this section measures latency only, so the "sample prediction"
# printed below is not a meaningful forecast.
test_input = torch.randn(1, NUM_HEATMAPS, FRAME_HEIGHT, FRAME_WIDTH).to(device)
# Accept "cuda", "cuda:0" or a torch.device. `device == "cuda"` misses the last
# two, which would skip the synchronize calls and under-report GPU latency.
on_cuda = torch.device(device).type == "cuda"

inference_times = []
with torch.no_grad():
    # Warmup so one-off start-up costs stay out of the timed loop.
    for _ in range(10):
        model(test_input)
    if on_cuda:
        torch.cuda.synchronize()  # drain warmup work before the first timestamp

    for _ in range(100):
        # perf_counter is monotonic and high-resolution; time.time() can tick too
        # coarsely for millisecond-scale calls, and a 0.0 average would make the
        # FPS division below raise ZeroDivisionError.
        start = time.perf_counter()
        output = model(test_input)
        if on_cuda:
            torch.cuda.synchronize()  # CUDA work is asynchronous; wait for it
        inference_times.append(time.perf_counter() - start)

avg_inf_ms = sum(inference_times) / len(inference_times) * 1000
fps = 1000 / avg_inf_ms  # model throughput, not the video frame rate FPS

print(f"Average inference: {avg_inf_ms:.2f}ms")
print(f"Throughput: {fps:.1f} FPS")
print(f"Can process: {fps/FPS:.1f}x realtime")

pred_tile = output.argmax(1).item()
x, y = tile_index_to_coords(pred_tile)
print(f"Sample prediction: Tile {pred_tile} at ({x}, {y})")

print(f"\n{'='*60}")
print("Model saved as 'best_model.pth'")
print(f"{'='*60}\n")

# Sample video analysis: print an evenly spaced subset (roughly 12 rows) of the
# sampled frames for video 0 rather than every frame.
print(f"{'='*60}")
print("User Viewing Data - Video 0 (3 seconds, every 5th frame)")
print(f"{'='*60}")
print("Sample# | Frame# | Time(s) | Tile Index | Tile Grid (x,y)")
print(f"{'-'*60}")
# Step straight to the rows that are shown instead of indexing and converting
# every sample and discarding most of them; the printed rows are unchanged.
row_step = max(1, FRAMES_PER_VIDEO // 12)
for i in range(0, FRAMES_PER_VIDEO, row_step):
    tile_idx = tile_indices_1d[0, i]
    tx, ty = tile_index_to_coords(tile_idx)
    actual_frame = i * FRAME_SAMPLE_RATE  # frame number before subsampling
    time_sec = actual_frame / FPS
    print(f"{i:7d} | {actual_frame:6d} | {time_sec:6.2f}  | {tile_idx:10d} | ({tx:2d}, {ty:2d})")
print(f"{'='*60}\n")

# Memory cleanup: drop every reference that keeps training tensors alive so that
# gc.collect()/empty_cache() can actually release them. Besides the model and
# loaders, the Adam optimizer references the model parameters (and holds their
# optimizer state), the scheduler wraps the optimizer, and test_input/output are
# tensors created on `device` by the inference benchmark.
del model, optimizer, scheduler, test_input, output
del train_loader, val_loader, test_loader
gc.collect()
# torch.device(...) accepts "cuda", "cuda:0" or a torch.device object, whereas
# `device == "cuda"` is False for the latter two.
if torch.device(device).type == "cuda":
    torch.cuda.empty_cache()
    print(f"GPU memory freed: {torch.cuda.memory_allocated()/1e9:.2f} GB in use")