# SwellSight - Multi-Task Model Architecture

This notebook demonstrates the SwellSight multi-task model: a DINOv2 backbone feeding three specialized prediction heads (wave height, wave direction, and breaking type).

---

## 1. Setup

In [None]:
# Standard-library imports.
import sys
from pathlib import Path
import json
import warnings
# Suppress all warnings so the cell outputs below stay short.
warnings.filterwarnings('ignore')

# Put the current working directory at the front of the import path so the
# `src.*` imports below can resolve. This must stay above those imports.
# NOTE(review): presumably the notebook is launched from the repository root — confirm.
sys.path.insert(0, str(Path.cwd()))

# Project modules: the analyzer model, its three heads, the multi-task loss,
# hardware detection and config loading.
from src.swellsight.core.wave_analyzer import DINOv2WaveAnalyzer
from src.swellsight.models.heads import WaveHeightHead, DirectionHead, BreakingTypeHead
from src.swellsight.models.losses import MultiTaskLoss
from src.swellsight.utils.hardware import HardwareManager
from src.swellsight.utils.config import load_config
# Third-party numerical and plotting libraries.
import torch
import numpy as np
import matplotlib.pyplot as plt

print("✅ Modules loaded")

In [None]:
# Read the project configuration and make sure this notebook's output folder
# (`<output_dir>/multi_task_model`) exists before anything is written to it.
config = load_config('config.json')
OUTPUT_DIR = Path(config['paths']['output_dir'])
MODEL_DIR = OUTPUT_DIR.joinpath('multi_task_model')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
print('Output: {}'.format(MODEL_DIR))

## 2. Hardware

In [None]:
# Query the hardware manager once; `device` is reused by the model and tensor
# cells further down.
hw_mgr = HardwareManager()
hw = hw_mgr.hardware_info
device = hw.device_type
print('Device: {}'.format(device))
print('Memory: {:.1f} GB'.format(hw.memory_total_gb))

## 3. Sub-task 8.1: Multi-Task Model with Three Prediction Heads

In [None]:
print('🧠 Sub-task 8.1: Multi-task model architecture')


def _count_params(module, trainable_only=False):
    """Return the number of scalar parameters held by *module*.

    With ``trainable_only=True`` only parameters whose ``requires_grad`` flag
    is set are counted.
    """
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# Build the analyzer on the detected device with `freeze_backbone=True`.
analyzer = DINOv2WaveAnalyzer(
    backbone_model='dinov2_vitl14',
    freeze_backbone=True,
    device=device,
    enable_optimization=False,
)

print('\n✅ Model initialized')
print('Input: 4 channels, {}'.format(analyzer.input_resolution))
print('Backbone: {}'.format(analyzer.backbone_model))
print('Feature dim: {}'.format(analyzer.backbone.get_feature_dim()))

# Per-component parameter counts; these names are read again by the
# visualization and metadata cells.
bp = _count_params(analyzer.backbone)
hp = _count_params(analyzer.height_head)
dp = _count_params(analyzer.direction_head)
brp = _count_params(analyzer.breaking_head)
total = _count_params(analyzer)

print('\nParameters:')
print('  Backbone: {:,}'.format(bp))
print('  Height: {:,}'.format(hp))
print('  Direction: {:,}'.format(dp))
print('  Breaking: {:,}'.format(brp))
print('  Total: {:,}'.format(total))

# Split the total into parameters that receive gradients and those that do not.
trainable = _count_params(analyzer, trainable_only=True)
frozen_count = total - trainable
trainable_share = 100 * trainable / total
frozen_share = 100 * frozen_count / total
print('\nTrainable: {:,} ({:.1f}%)'.format(trainable, trainable_share))
print('Frozen: {:,} ({:.1f}%)'.format(frozen_count, frozen_share))

In [None]:
def _print_classifier_head(title, head, components):
    """Print dimensions, class list and sub-components of a classification head."""
    print(title)
    print('  Input: {}-dim'.format(head.input_dim))
    print('  Hidden: {}-dim'.format(head.hidden_dim))
    print('  Classes: {}'.format(head.num_classes))
    print('  Names: {}'.format(head.class_names))
    print('  Components: {}'.format(components))


print('\n🔍 Prediction Heads:')

# The height head is a regressor, so it has no class list to report.
height_head = analyzer.height_head
print('\n📏 Height Head:')
print('  Input: {}-dim'.format(height_head.input_dim))
print('  Hidden: {}-dim'.format(height_head.hidden_dim))
print('  Output: 1 (height in meters, 0.5-8.0m)')
print('  Components: regressor, confidence, dominance, wave_count')

# The two classification heads share the same report layout.
_print_classifier_head(
    '\n🧭 Direction Head:',
    analyzer.direction_head,
    'classifier, mixed_condition, strength, wave_train',
)
_print_classifier_head(
    '\n💥 Breaking Head:',
    analyzer.breaking_head,
    'classifier, intensity, mixed, clarity, no_breaking',
)

## 4. Sub-task 8.2: Task-Specific Projection Layers

In [None]:
print('🏗️ Sub-task 8.2: Task-specific projections')


def _print_shapes(title, outputs, labelled_keys):
    """Print *title*, then one ``  label: shape`` line per (label, key) pair.

    *outputs* is the dict returned by a prediction head; each referenced entry
    must expose a ``.shape`` attribute.
    """
    print(title)
    for label, key in labelled_keys:
        print(f'  {label}: {outputs[key].shape}')


# Random stand-in batch (2 samples, 4 channels, 518x518), allocated directly on
# the target device instead of on the CPU followed by a copy.
test_input = torch.randn(2, 4, 518, 518, device=device)
print(f'Test input: {test_input.shape}')

# Evaluation mode, and no autograd bookkeeping for the demonstration passes.
analyzer.eval()
with torch.no_grad():
    # Shared backbone features consumed by all three heads.
    features = analyzer.backbone(test_input)
    print(f'\nBackbone output: {features.shape}')

    # Height projections
    h_pred = analyzer.height_head(features)
    _print_shapes('\nHeight Head:', h_pred, [
        ('height_meters', 'height_meters'),
        ('confidence', 'height_confidence'),
        ('dominance', 'dominance_score'),
    ])

    # Direction projections
    d_pred = analyzer.direction_head(features)
    _print_shapes('\nDirection Head:', d_pred, [
        ('logits', 'direction_logits'),
        ('probabilities', 'direction_probabilities'),
        ('predicted', 'direction_predicted'),
    ])

    # Breaking projections
    b_pred = analyzer.breaking_head(features)
    _print_shapes('\nBreaking Head:', b_pred, [
        ('logits', 'breaking_logits'),
        ('probabilities', 'breaking_probabilities'),
        ('predicted', 'breaking_predicted'),
    ])

    # Complete forward pass. Call the module itself rather than `.forward()`
    # so that any registered forward hooks are run as well.
    all_pred = analyzer(test_input)
    print(f'\nComplete forward pass: {len(all_pred)} outputs')
    print(f'Keys: {list(all_pred.keys())[:8]}...')

print('\n✅ Task-specific projections demonstrated')

## 5. Sub-task 8.3: Weighted Loss Computation

In [None]:
print('⚖️ Sub-task 8.3: Weighted loss computation')

# Initialize the loss function with equal starting weights for the three tasks.
loss_fn = MultiTaskLoss(
    height_weight=1.0,
    direction_weight=1.0,
    breaking_weight=1.0,
    adaptive_weighting=True,
)

print('\nMultiTaskLoss initialized:')
print(f'  Height weight: {loss_fn.height_weight}')
print(f'  Direction weight: {loss_fn.direction_weight}')
print(f'  Breaking weight: {loss_fn.breaking_weight}')
print(f'  Adaptive weighting: {loss_fn.adaptive_weighting}')

if loss_fn.adaptive_weighting:
    print(f'  Log vars (learnable): {loss_fn.log_vars.shape}')

# Mock predictions: call the module itself rather than `.forward()` so that
# any registered forward hooks are run as well.
with torch.no_grad():
    predictions = analyzer(test_input)

# Mock targets sized from the actual batch and from the heads' own class
# counts, so they stay valid if either changes. Tensors are created directly
# on the target device.
# NOTE(review): assumes `num_classes` is an integer-like value — confirm.
batch_size = test_input.shape[0]
num_direction_classes = int(analyzer.direction_head.num_classes)
num_breaking_classes = int(analyzer.breaking_head.num_classes)

targets = {
    # Uniform random heights in [0.5, 8.0) metres.
    'height_meters': torch.rand(batch_size, 1, device=device) * 7.5 + 0.5,
    'direction_labels': torch.randint(
        0, num_direction_classes, (batch_size,), device=device
    ),
    'breaking_labels': torch.randint(
        0, num_breaking_classes, (batch_size,), device=device
    ),
}

print('\nTarget shapes:')
print(f'  height_meters: {targets["height_meters"].shape}')
print(f'  direction_labels: {targets["direction_labels"].shape}')
print(f'  breaking_labels: {targets["breaking_labels"].shape}')

# Compute losses; `losses` is read again by the visualization and metadata cells.
losses = loss_fn(predictions, targets)

print('\nComputed losses:')
print(f'  Height loss: {losses["height_loss"].item():.4f}')
print(f'  Direction loss: {losses["direction_loss"].item():.4f}')
print(f'  Breaking loss: {losses["breaking_loss"].item():.4f}')
print(f'  Total loss: {losses["total_loss"].item():.4f}')

if loss_fn.adaptive_weighting:
    print('\nAdaptive weights:')
    print(f'  Height: {losses["height_weight"].item():.4f}')
    print(f'  Direction: {losses["direction_weight"].item():.4f}')
    print(f'  Breaking: {losses["breaking_weight"].item():.4f}')

print('\n✅ Weighted loss computation demonstrated')

## 6. Architecture Visualization

In [None]:
print('📊 Creating architecture visualization')

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: parameter count per component
ax = axes[0, 0]
components = ['Backbone', 'Height\nHead', 'Direction\nHead', 'Breaking\nHead']
params = [bp, hp, dp, brp]
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']
ax.bar(components, params, color=colors, alpha=0.7, edgecolor='black')
ax.set_ylabel('Parameters')
ax.set_title('Parameter Distribution')
ax.grid(axis='y', alpha=0.3)
# Annotate every bar with its exact count.
for i, v in enumerate(params):
    ax.text(i, v, f'{v:,}', ha='center', va='bottom', fontsize=9)

# Plot 2: trainable vs frozen parameters
ax = axes[0, 1]
frozen = total - trainable
sizes = [trainable, frozen]
labels = [f'Trainable\n{trainable:,}', f'Frozen\n{frozen:,}']
ax.pie(sizes, labels=labels, colors=['#2ecc71', '#95a5a6'], autopct='%1.1f%%', startangle=90)
ax.set_title('Trainable vs Frozen Parameters')

# Plot 3: per-task loss values from the loss cell
ax = axes[1, 0]
loss_names = ['Height', 'Direction', 'Breaking']
loss_values = [
    losses['height_loss'].item(),
    losses['direction_loss'].item(),
    losses['breaking_loss'].item(),
]
ax.bar(loss_names, loss_values, color=['#e74c3c', '#2ecc71', '#f39c12'], alpha=0.7, edgecolor='black')
ax.set_ylabel('Loss Value')
ax.set_title('Multi-Task Loss Components')
ax.grid(axis='y', alpha=0.3)

# Plot 4: architecture flow diagram. The feature width and the frozen/trainable
# label are read from the model so the diagram cannot drift from the analyzer
# that was actually built.
ax = axes[1, 1]
feature_dim = analyzer.backbone.get_feature_dim()
backbone_state = 'Frozen' if analyzer.freeze_backbone else 'Trainable'


def _box(facecolor, **extra):
    """Return the rounded text-box style used for the diagram nodes."""
    return dict(boxstyle='round', facecolor=facecolor, **extra)


# (x, y, text, extra keyword arguments for ax.text), drawn top to bottom.
flow_items = [
    (0.5, 0.9, 'Multi-Task Architecture', dict(fontsize=14, weight='bold')),
    (0.5, 0.75, 'Input: RGB + Depth (4 channels)', dict(fontsize=10, bbox=_box('lightblue'))),
    (0.5, 0.6, '↓', dict(fontsize=16)),
    (0.5, 0.5, f'DINOv2 Backbone ({backbone_state})', dict(fontsize=10, bbox=_box('lightgray'))),
    (0.5, 0.35, '↓', dict(fontsize=16)),
    (0.5, 0.25, f'{feature_dim}-dim Features', dict(fontsize=10, bbox=_box('lightyellow'))),
    (0.15, 0.1, 'Height\nHead', dict(fontsize=9, bbox=_box('#e74c3c', alpha=0.5))),
    (0.5, 0.1, 'Direction\nHead', dict(fontsize=9, bbox=_box('#2ecc71', alpha=0.5))),
    (0.85, 0.1, 'Breaking\nHead', dict(fontsize=9, bbox=_box('#f39c12', alpha=0.5))),
]
for x, y, text, text_kwargs in flow_items:
    ax.text(x, y, text, ha='center', **text_kwargs)
ax.axis('off')

plt.tight_layout()
# Build the output path once so the saved file and the printed message agree.
figure_path = MODEL_DIR / 'architecture_visualization.png'
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f'✅ Visualization saved to {figure_path}')

## 7. Summary and Metadata

In [None]:
print('💾 Saving results and metadata')

# Values used by both the JSON metadata and the printed summary, computed
# once so the two cannot disagree.
trainable_pct = 100 * trainable / total
total_loss_value = losses['total_loss'].item()
# Class counts come from the heads themselves rather than hard-coded literals.
# NOTE(review): assumes `num_classes` is an integer-like value — confirm.
direction_classes = int(analyzer.direction_head.num_classes)
breaking_classes = int(analyzer.breaking_head.num_classes)

metadata = {
    'notebook': '10_Multi_Task_Model_Architecture',
    'model': {
        'backbone': analyzer.backbone_model,
        'feature_dim': analyzer.backbone.get_feature_dim(),
        'input_channels': 4,
        'input_resolution': analyzer.input_resolution,
        'frozen_backbone': analyzer.freeze_backbone,
    },
    'heads': {
        'height': {'params': hp, 'output': '1 (regression)'},
        'direction': {
            'params': dp,
            'classes': direction_classes,
            'names': analyzer.direction_head.class_names,
        },
        'breaking': {
            'params': brp,
            'classes': breaking_classes,
            'names': analyzer.breaking_head.class_names,
        },
    },
    'parameters': {
        'total': total,
        'trainable': trainable,
        'frozen': total - trainable,
        'trainable_pct': trainable_pct,
    },
    'loss': {
        'type': 'MultiTaskLoss',
        'adaptive_weighting': loss_fn.adaptive_weighting,
        'height_loss': losses['height_loss'].item(),
        'direction_loss': losses['direction_loss'].item(),
        'breaking_loss': losses['breaking_loss'].item(),
        'total_loss': total_loss_value,
    },
}

# Explicit encoding so the file is written the same way on every platform.
with open(MODEL_DIR / 'model_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print('✅ Metadata saved')

# Report the weighting mode actually configured on the loss instead of
# always claiming "adaptive".
weighting_mode = 'adaptive' if loss_fn.adaptive_weighting else 'fixed'
rule = '=' * 60

print(f'\n{rule}')
print('MULTI-TASK MODEL ARCHITECTURE SUMMARY')
print(rule)
print(f'Model: {analyzer.backbone_model}')
print(f'Total Parameters: {total:,}')
print(f'Trainable: {trainable:,} ({trainable_pct:.1f}%)')
print('\nPrediction Heads:')
print(f'  ✅ Height: {hp:,} params')
print(f'  ✅ Direction: {dp:,} params')
print(f'  ✅ Breaking: {brp:,} params')
print('\nLoss Function:')
print(f'  ✅ Multi-task with {weighting_mode} weighting')
print(f'  ✅ Total loss: {total_loss_value:.4f}')
print('\n✅ All sub-tasks completed:')
print('  ✅ 8.1: Multi-task model demonstrated')
print('  ✅ 8.2: Task-specific projections shown')
print('  ✅ 8.3: Weighted loss computation verified')
print(rule)