In [None]:
# NOTE(review): the "In [None]:" line above is residue from a notebook export and
# is not valid Python; run this file as Colab cells rather than as a plain script.
# ============================================================
# LASER TRIANGULATION POSE ESTIMATION - TRAINING (Google Colab)
# ============================================================

# Cell 1: Setup and Installation
# "!" is an IPython shell escape, so this line only works inside a notebook.
# NOTE(review): scipy and transforms3d are not imported by any cell visible in
# this notebook — confirm they are needed (e.g. by the pasted model code).
!pip install h5py pyyaml scipy transforms3d tqdm

# Cell 2: Mount Google Drive (if the dataset is stored there)
# Makes Drive contents available under /content/drive. The data paths in the
# Cell 7 config are relative, so point them at /content/drive/... when the
# dataset lives on Drive.
from google.colab import drive
drive.mount('/content/drive')

# Cell 3: Clone Repository or Upload Files
# Option A: Clone from GitHub
# !git clone https://github.com/YOUR_USERNAME/laser_pose_estimation.git
# %cd laser_pose_estimation

# Option B: Upload files manually
# Use Colab's file upload or copy from Drive

# Cell 4: Imports
import os
import warnings
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Cell 5: Dataset Class
class PoseDataset(Dataset):
    """Dataset for pose estimation backed by a single HDF5 file.

    Each item reads one row from the ``point_clouds``, ``poses`` and ``labels``
    datasets of the HDF5 file at ``h5_path``.

    Args:
        h5_path: Path to the HDF5 file.
        split: Name of the index array to load from ``splits_path``
            (e.g. ``'train'`` or ``'val'``).
        splits_path: Optional ``.npz`` file mapping split names to index
            arrays. If it is ``None``, or names a file that does not exist,
            every sample in the HDF5 file is used. In the "does not exist"
            case a warning is emitted, because every split then contains the
            same samples.

    Raises:
        KeyError: if ``splits_path`` exists but has no array named ``split``.
    """

    def __init__(self, h5_path, split='train', splits_path=None):
        self.h5_path = h5_path
        # HDF5 handle opened lazily, once per process (see _file()).
        self._h5 = None
        self._h5_pid = None

        if splits_path and Path(splits_path).exists():
            # The context manager closes the NpzFile. Indexing it reads the
            # array into memory, so self.indices stays valid afterwards.
            with np.load(splits_path) as splits:
                if split not in splits.files:
                    raise KeyError(
                        f"split {split!r} not found in {splits_path}; "
                        f"available splits: {splits.files}"
                    )
                self.indices = splits[split]
        else:
            if splits_path:
                # A configured splits file is missing: falling back silently
                # would make train and val identical (data leakage).
                warnings.warn(
                    f"splits file {splits_path!r} not found; the {split!r} "
                    "split will use ALL samples in the HDF5 file",
                    stacklevel=2,
                )
            # Use all data
            with h5py.File(h5_path, 'r') as f:
                self.indices = np.arange(len(f['point_clouds']))

        print(f"{split} set: {len(self.indices)} samples")

    def _file(self):
        """Return a read-only HDF5 handle owned by the current process.

        The file is opened once per process instead of once per sample. The
        PID check makes a forked DataLoader worker open its own handle rather
        than reuse one inherited from the parent process.
        """
        pid = os.getpid()
        if self._h5 is None or self._h5_pid != pid:
            self._h5 = h5py.File(self.h5_path, 'r')
            self._h5_pid = pid
        return self._h5

    def __getstate__(self):
        # An open h5py handle cannot be pickled. Drop it so the dataset can be
        # sent to spawn-started workers, which reopen the file lazily.
        state = self.__dict__.copy()
        state['_h5'] = None
        state['_h5_pid'] = None
        return state

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Keys: ``'point_cloud'`` and ``'pose'`` (float32), ``'label'``
        (int64, shape ``(1,)``).
        """
        real_idx = self.indices[idx]
        f = self._file()

        pc = f['point_clouds'][real_idx]  # (2048, 3) per the original notes
        pose = f['poses'][real_idx]  # (7,) per the original notes
        label = f['labels'][real_idx]

        return {
            'point_cloud': torch.as_tensor(pc, dtype=torch.float32),
            'pose': torch.as_tensor(pose, dtype=torch.float32),
            'label': torch.tensor([label], dtype=torch.long),
        }

# Cell 6: Define the model (paste the contents of pointnet2_model.py here, or import from it)
# ... (PointNet2PoseEstimation and PoseLoss classes must be defined before Cell 9 runs)

# Cell 7: Training Configuration
config = {
    'data': {
        # Relative paths resolve against the notebook's working directory.
        'h5_path': 'data/synthetic/dataset.h5',
        'splits_path': 'data/synthetic/splits.npz',
        'batch_size': 32,
        'num_workers': 2  # DataLoader worker processes
    },
    'model': {
        # NOTE(review): not read by any visible cell; the model is built with
        # its defaults in Cell 9.
        'input_channels': 3
    },
    'training': {
        'lr': 0.001,
        'weight_decay': 0.0001,
        'epochs': 200,  # also the cosine-annealing period (T_max)
        'early_stopping': 20,  # patience: epochs without val-loss improvement
        'pos_loss_weight': 1.0,
        'rot_loss_weight': 0.1
    },
    'paths': {
        'checkpoint_dir': 'checkpoints',
        'log_dir': 'logs'
    }
}

# Create output directories. parents=True allows nested locations (e.g. a
# folder on the mounted Drive) whose parent folders do not exist yet.
Path(config['paths']['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
Path(config['paths']['log_dir']).mkdir(parents=True, exist_ok=True)

# Cell 8: Initialize Data Loaders
train_dataset = PoseDataset(
    config['data']['h5_path'],
    split='train',
    splits_path=config['data']['splits_path']
)

val_dataset = PoseDataset(
    config['data']['h5_path'],
    split='val',
    splits_path=config['data']['splits_path']
)

# Options shared by both loaders:
# - pin_memory: page-locked host buffers speed up the .to(device) copies in
#   the training loop when a GPU is present.
# - persistent_workers: keep worker processes (and their open files) alive
#   across epochs instead of re-spawning them every epoch. PyTorch only
#   accepts this when num_workers > 0.
_loader_kwargs = dict(
    batch_size=config['data']['batch_size'],
    num_workers=config['data']['num_workers'],
    pin_memory=torch.cuda.is_available(),
    persistent_workers=config['data']['num_workers'] > 0,
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Cell 9: Initialize Model
# Train on the GPU when the runtime has one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# NOTE(review): config['model']['input_channels'] is not passed to the model;
# confirm PointNet2PoseEstimation's default input size matches the data.
model = PointNet2PoseEstimation()
model = model.to(device)

# Weighted position + rotation loss; weights come from the training config.
criterion = PoseLoss(
    pos_weight=config['training']['pos_loss_weight'],
    rot_weight=config['training']['rot_loss_weight'],
)

# AdamW applies decoupled weight decay to all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=config['training']['weight_decay'],
    lr=config['training']['lr'],
)

# One cosine cycle over the full epoch budget; stepped once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config['training']['epochs'],
)

# Cell 10: Training Loop
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network mapping a batch of point clouds to predicted poses.
        loader: Iterable of dicts holding ``'point_cloud'`` and ``'pose'``
            tensors.
        criterion: Callable ``(pred_pose, gt_pose) -> (loss, pos_loss,
            rot_loss)`` returning scalar tensors.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that batch tensors are moved to.

    Returns:
        Tuple ``(loss, pos_loss, rot_loss)`` averaged over samples: each
        batch's value is weighted by its batch size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_pos_loss = 0.0
    total_rot_loss = 0.0
    n_samples = 0

    for batch in tqdm(loader, desc="Training"):
        pc = batch['point_cloud'].to(device)
        gt_pose = batch['pose'].to(device)

        optimizer.zero_grad()

        pred_pose = model(pc)
        loss, pos_loss, rot_loss = criterion(pred_pose, gt_pose)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the epoch
        # mean. Assumes criterion returns per-batch means — TODO confirm
        # for PoseLoss.
        batch_size = pc.size(0)
        total_loss += loss.item() * batch_size
        total_pos_loss += pos_loss.item() * batch_size
        total_rot_loss += rot_loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return (
        total_loss / n_samples,
        total_pos_loss / n_samples,
        total_rot_loss / n_samples,
    )

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Args:
        model: Network mapping a batch of point clouds to predicted poses.
        loader: Iterable of dicts holding ``'point_cloud'`` and ``'pose'``
            tensors.
        criterion: Callable ``(pred_pose, gt_pose) -> (loss, pos_loss,
            rot_loss)`` returning scalar tensors.
        device: Device that batch tensors are moved to.

    Returns:
        Tuple ``(loss, pos_loss, rot_loss)`` averaged over samples: each
        batch's value is weighted by its batch size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_pos_loss = 0.0
    total_rot_loss = 0.0
    n_samples = 0

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            pc = batch['point_cloud'].to(device)
            gt_pose = batch['pose'].to(device)

            pred_pose = model(pc)
            loss, pos_loss, rot_loss = criterion(pred_pose, gt_pose)

            # Sample-weighted accumulation. Assumes criterion returns
            # per-batch means — TODO confirm for PoseLoss.
            batch_size = pc.size(0)
            total_loss += loss.item() * batch_size
            total_pos_loss += pos_loss.item() * batch_size
            total_rot_loss += rot_loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("validate: loader yielded no batches")
    return (
        total_loss / n_samples,
        total_pos_loss / n_samples,
        total_rot_loss / n_samples,
    )

# Cell 11: Main Training Loop
# Per-epoch metrics, consumed by the plotting cell.
history = {
    'train_loss': [],
    'val_loss': [],
    'train_pos_loss': [],
    'val_pos_loss': [],
    'train_rot_loss': [],
    'val_rot_loss': []
}

best_val_loss = float('inf')
patience_counter = 0  # epochs since the last val-loss improvement
best_ckpt_path = Path(config['paths']['checkpoint_dir']) / 'best_model.pth'

for epoch in range(config['training']['epochs']):
    print(f"\nEpoch {epoch+1}/{config['training']['epochs']}")

    # Train
    train_loss, train_pos, train_rot = train_epoch(
        model, train_loader, criterion, optimizer, device
    )

    # Validate
    val_loss, val_pos, val_rot = validate(
        model, val_loader, criterion, device
    )

    # Update scheduler (once per epoch, matching T_max=epochs)
    scheduler.step()

    # Log
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_pos_loss'].append(train_pos)
    history['val_pos_loss'].append(val_pos)
    history['train_rot_loss'].append(train_rot)
    history['val_rot_loss'].append(val_rot)

    print(f"Train Loss: {train_loss:.4f} (Pos: {train_pos:.4f}, Rot: {train_rot:.4f})")
    print(f"Val Loss: {val_loss:.4f} (Pos: {val_pos:.4f}, Rot: {val_rot:.4f})")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Without the scheduler state, a resumed run would build a fresh
            # CosineAnnealingLR and restart from the initial learning rate.
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
            'config': config,
        }, best_ckpt_path)
        print("✓ Saved best model")
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= config['training']['early_stopping']:
        print(f"\nEarly stopping after {epoch+1} epochs")
        break

# NOTE(review): `model` now holds the LAST epoch's weights, not the best ones;
# reload best_ckpt_path before evaluating if the best weights are wanted.

# Cell 12: Plot Training Curves
# One panel per loss component, each showing train vs. validation curves.
# Entries: (history key suffix, axis label / panel title).
_loss_panels = (
    ('loss', 'Total Loss'),
    ('pos_loss', 'Position Loss'),
    ('rot_loss', 'Rotation Loss'),
)

plt.figure(figsize=(15, 5))

for _panel_no, (_suffix, _panel_name) in enumerate(_loss_panels, start=1):
    plt.subplot(1, 3, _panel_no)
    plt.plot(history['train_' + _suffix], label='Train')
    plt.plot(history['val_' + _suffix], label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(_panel_name)
    plt.legend()
    plt.title(_panel_name)

plt.tight_layout()
plt.savefig(f"{config['paths']['log_dir']}/training_curves.png")
plt.show()

print("\n✅ Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")