<a href="https://colab.research.google.com/github/ParthivRB/Deeptrack_Colab/blob/main/DeepTrack_Cloud_Training_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
python
# ============================================================================
# DeepTrack2 Cloud Training System - enhanced edition (version 3.0.0)
# ============================================================================
# Colab notebook that trains a UNet particle detector with PyTorch Lightning
# on videos and particle annotations stored in Google Drive.
# Main ingredients:
# - PyTorch Lightning training loop with checkpointing and early stopping
# - F1 / Dice / IoU metrics via torchmetrics
# - Configuration switches for mixed precision and gradient clipping
# - Learning-rate reduction on validation-loss plateaus
# - trackpy installed for downstream post-processing
# ============================================================================

print("\n".join([
    "=" * 80,
    "üöÄ DEEPTRACK CLOUD TRAINER - ENHANCED EDITION",
    "=" * 80,
]))
print("\nInitializing advanced training environment.. .\n")


In [None]:
python
# Install all required packages.
# NOTE: `pip install -q <pkg>` (no `-U`) keeps an already-installed package at
# its current version, e.g. the torch that ships with Colab.
print("üì¶ Installing enhanced dependencies...")
import subprocess
import sys

packages = [
    'deeptrack', 'deeplay', 'torch', 'torchvision',
    'lightning', 'torchmetrics', 'trackpy',
    'tqdm', 'ipywidgets', 'matplotlib', 'scikit-image',
    'pandas', 'scipy', 'numba'
]

# Packages whose pip run failed; used so the final summary is truthful.
failed_packages = []

for package in packages:
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
    except (subprocess.CalledProcessError, OSError) as err:
        # Best effort: keep installing the remaining packages, but show why
        # this one failed instead of silently discarding the error.
        failed_packages.append(package)
        print(f"‚ö†Ô∏è  {package} - continuing ({err})")
    else:
        print(f"‚úÖ {package}")

if failed_packages:
    print(f"\n‚ö†Ô∏è  Failed to install: {', '.join(failed_packages)} - later imports may fail")
else:
    print("\n‚úÖ All dependencies installed!")

In [None]:
python
# Import all required libraries
print("üìö Loading libraries...")

import os
import json
import warnings
from pathlib import Path
from datetime import datetime
import shutil
import hashlib

import numpy as np
import pandas as pd
from scipy.ndimage import label, center_of_mass
from skimage import io as skio
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# PyTorch Lightning
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger

# Advanced Metrics
from torchmetrics import Dice, JaccardIndex
from torchmetrics.classification import BinaryF1Score

import deeplay as dl
import deeptrack as dt

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# Configuration
# NOTE(review): this hides *every* warning for the rest of the session
# (including deprecation warnings from torch/lightning/torchmetrics); narrow
# the filter if something starts misbehaving silently.
warnings.filterwarnings('ignore')

# Seed torch and NumPy global RNGs (Python's `random` is not seeded here).
torch.manual_seed(42)
np.random.seed(42)

# Allow float32 matrix multiplications to use faster reduced-precision
# internals (TF32 / bfloat16) on hardware that supports them. This is NOT
# autocast mixed-precision training; that is a separate training option
# ('mixed_precision' in the training configuration).
torch.set_float32_matmul_precision('medium')

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"‚úÖ Libraries loaded!")
print(f"üñ•Ô∏è  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA:  {torch.version.cuda}")
    print(f"   Mixed Precision:  Enabled ‚ö°")
print(f"   PyTorch:  {torch.__version__}")
print(f"   Lightning: {L.__version__}")
print(f"   DeepTrack: installed\n")

In [None]:
python
# Mount Google Drive so data, models and logs persist across Colab sessions.
print("üìÇ Mounting Google Drive...")

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Everything the notebook stores lives under a single Drive folder.
BASE_PATH = Path('/content/drive/MyDrive/DeepTrack_Studio')
DATA_PATH, MODEL_PATH, LOG_PATH, CACHE_PATH, RESULTS_PATH = (
    BASE_PATH / folder_name
    for folder_name in ('training_data', 'models', 'logs', 'cache', 'results')
)

# Create the folder tree; folders that already exist are left alone.
for path in [BASE_PATH, DATA_PATH, MODEL_PATH, LOG_PATH, CACHE_PATH, RESULTS_PATH]:
    path.mkdir(parents=True, exist_ok=True)

# Upload targets for raw videos and their particle annotation CSVs.
(DATA_PATH / 'videos').mkdir(exist_ok=True)
(DATA_PATH / 'annotations').mkdir(exist_ok=True)

print("\n".join([
    f"‚úÖ Google Drive ready!",
    f"   Base:  {BASE_PATH}",
    f"   Videos: {DATA_PATH / 'videos'}",
    f"   Annotations: {DATA_PATH / 'annotations'}\n",
]))

In [None]:
python
# Training Tracker - keeps track of which videos have been trained
class TrainingTracker:
    """Record which video files have already been trained on.

    Entries in ``trained_videos`` are keyed by file name and store the MD5 of
    the file contents, the training timestamp and the model version. A file
    whose contents changed (different hash) counts as untrained again. The
    mapping is persisted as JSON in ``<cache_path>/training_tracker.json``.
    """

    # Hash in 1 MiB chunks so large TIFF stacks are never read into memory
    # all at once.
    _HASH_CHUNK_SIZE = 1 << 20

    def __init__(self, cache_path):
        self.cache_path = Path(cache_path)
        self.tracker_file = self.cache_path / 'training_tracker.json'
        self.trained_videos = self.load_tracker()

    def load_tracker(self):
        """Return the saved training history, or ``{}`` if none exists yet."""
        if self.tracker_file.exists():
            with open(self.tracker_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def save_tracker(self):
        """Persist ``trained_videos`` to the tracker JSON file."""
        self.cache_path.mkdir(parents=True, exist_ok=True)
        tmp_file = self.tracker_file.with_suffix('.json.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.trained_videos, f, indent=2)
        # Swap the new file in with a single rename so an interrupted save
        # cannot leave a truncated, unparseable tracker file behind.
        os.replace(tmp_file, self.tracker_file)

    def get_file_hash(self, file_path):
        """Return the hex MD5 digest of a file's contents (used for change detection)."""
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def is_video_trained(self, video_path):
        """Return True if this exact file (same name and same contents) was trained on."""
        video_path = Path(video_path)
        entry = self.trained_videos.get(video_path.name)
        if entry is None:
            return False
        return entry.get('hash') == self.get_file_hash(video_path)

    def mark_video_trained(self, video_path, model_version):
        """Record ``video_path`` as trained by ``model_version`` and save immediately."""
        video_path = Path(video_path)
        self.trained_videos[video_path.name] = {
            'hash': self.get_file_hash(video_path),
            'trained_date': datetime.now().isoformat(),
            'model_version': model_version
        }
        self.save_tracker()

    def get_untrained_videos(self, video_files):
        """Return the videos from ``video_files`` that are new or have changed."""
        return [v for v in video_files if not self.is_video_trained(v)]

print("‚úÖ TrainingTracker class defined")

In [None]:
python
# Data Loader - handles loading and preprocessing of training data
class TrainingDataLoader:
    """Discover, load and preprocess training videos and their annotations.

    Expected layout below ``data_path``::

        videos/<stem>.tif|.tiff|.png|.jpg   (suffix matched case-insensitively)
        annotations/<stem>_particles.csv    (columns used: frame, x, y)

    ``x`` is the column and ``y`` the row of a particle, in pixels.
    Preprocessed videos are cached as ``<cache_path>/<stem>_processed.npy``.
    """

    # Accepted video/image suffixes, compared in lower case.
    VIDEO_EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg')
    # Single-image formats: a 3-D array read from one is (height, width, channels).
    _SINGLE_IMAGE_EXTENSIONS = ('.png', '.jpg')

    def __init__(self, data_path, cache_path):
        self.data_path = Path(data_path)
        self.cache_path = Path(cache_path)
        self.videos_path = self.data_path / 'videos'
        self.annotations_path = self.data_path / 'annotations'
        self.video_files = []
        self.annotation_files = {}

    def scan_data(self):
        """Populate ``video_files`` and ``annotation_files`` from disk.

        Returns:
            True if at least one video was found; otherwise prints upload
            instructions and returns False.
        """
        print("üîç Scanning for training data...")

        # Case-insensitive suffix match (so e.g. "clip.TIF" is found) and a
        # sorted result so video indices stay stable between runs.
        if self.videos_path.is_dir():
            self.video_files = sorted(
                p for p in self.videos_path.iterdir()
                if p.is_file() and p.suffix.lower() in self.VIDEO_EXTENSIONS
            )
        else:
            self.video_files = []

        if not self.video_files:
            print(f"‚ùå No videos found!")
            print(f"\nüìù UPLOAD INSTRUCTIONS:")
            print(f"   1. Open Google Drive in new tab")
            print(f"   2. Go to:  {self.videos_path}")
            print(f"   3. Upload your .tif video files")
            print(f"   4. Return here and re-run this cell")
            return False

        print(f"‚úÖ Found {len(self.video_files)} video(s)")

        # Annotations are matched by video stem: <stem>_particles.csv
        self.annotation_files = {}
        for video_path in self.video_files:
            annotation_path = self.annotations_path / f"{video_path.stem}_particles.csv"
            if annotation_path.exists():
                self.annotation_files[video_path.stem] = annotation_path

        print(f"üìã Found {len(self.annotation_files)} annotation(s)")

        for i, video_path in enumerate(self.video_files, 1):
            status = "‚úÖ" if video_path.stem in self.annotation_files else "‚ö†Ô∏è (no annotation)"
            print(f"   {i}. {video_path.name} {status}")

        return True

    def load_video_cached(self, video_path):
        """Load a video as a float32 (frames, height, width) array, using a .npy cache.

        Intensities are divided by the video's maximum (when it is > 0).
        """
        video_path = Path(video_path)
        cache_file = self.cache_path / f"{video_path.stem}_processed.npy"

        # Reuse the cache unless the source file is newer than it, so a video
        # re-uploaded under the same name is not served stale frames.
        if cache_file.exists() and (
            not video_path.exists()
            or cache_file.stat().st_mtime >= video_path.stat().st_mtime
        ):
            return np.load(cache_file)

        video = skio.imread(str(video_path))

        # Bring everything to (frames, height, width).
        if video.ndim == 2:
            video = video[np.newaxis, ...]
        elif video.ndim == 3 and video_path.suffix.lower() in self._SINGLE_IMAGE_EXTENSIONS:
            # A colour PNG/JPEG is (height, width, channels), not a frame
            # stack: collapse it to one greyscale frame.
            if video.shape[-1] >= 3:
                video = video[..., :3].mean(axis=-1)
            else:
                video = video[..., 0]
            video = video[np.newaxis, ...]
        elif video.ndim == 4:
            # NOTE(review): assumes (frames, channels, height, width) and keeps channel 0.
            video = video[:, 0, :, :]

        # Always return float32 (also for an all-zero video) so every cached
        # video has the same dtype; scale by the maximum when there is signal.
        video = video.astype(np.float32)
        peak = video.max()
        if peak > 0:
            video /= peak

        self.cache_path.mkdir(parents=True, exist_ok=True)
        np.save(cache_file, video)
        return video

    def load_annotations(self, video_stem):
        """Return the annotation DataFrame for ``video_stem``, or None if there is none."""
        if video_stem not in self.annotation_files:
            return None
        df = pd.read_csv(self.annotation_files[video_stem])
        return df

    def create_ground_truth_masks(self, annotations, shape, radius=3):
        """Create binary float32 masks with a filled disk around each particle.

        Args:
            annotations: DataFrame with ``frame``, ``x`` (column) and ``y``
                (row) columns. Rows with missing values, a non-integer frame or
                a frame outside ``[0, num_frames)`` are ignored.
            shape: ``(num_frames, height, width)`` of the mask stack.
            radius: Disk radius in pixels; a pixel is set when its squared
                distance to the particle centre is <= ``radius**2``.

        Returns:
            ``np.ndarray`` of ``shape`` with 1.0 inside particle disks, else 0.0.
        """
        num_frames, height, width = shape
        masks = np.zeros(shape, dtype=np.float32)
        if annotations is None or len(annotations) == 0:
            return masks

        points = annotations[['frame', 'x', 'y']].dropna()
        frames = points['frame'].to_numpy(dtype=float)
        # Round to the nearest pixel (pixel i is centred on coordinate i);
        # int() truncation would shift sub-pixel positions towards the origin.
        cols = np.floor(points['x'].to_numpy(dtype=float) + 0.5).astype(np.int64)
        rows = np.floor(points['y'].to_numpy(dtype=float) + 0.5).astype(np.int64)

        on_video = (frames >= 0) & (frames < num_frames) & (frames == np.floor(frames))
        frames = frames[on_video].astype(np.int64)
        rows, cols = rows[on_video], cols[on_video]
        if frames.size == 0:
            return masks

        # Pixel offsets of a filled disk around a centre pixel.
        reach = int(np.ceil(abs(radius)))
        dy, dx = np.mgrid[-reach:reach + 1, -reach:reach + 1]
        in_disk = dy ** 2 + dx ** 2 <= radius ** 2
        dy, dx = dy[in_disk], dx[in_disk]

        # Stamp all disks at once (local windows instead of full-frame
        # distance maps), dropping pixels that fall outside the frame.
        rr = rows[:, None] + dy[None, :]
        cc = cols[:, None] + dx[None, :]
        ff = np.broadcast_to(frames[:, None], rr.shape)
        inside = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
        masks[ff[inside], rr[inside], cc[inside]] = 1.0
        return masks

    def preview_data(self, video_idx=0, frame_idx=0):
        """Show one frame side by side with its annotated particle positions."""
        if not self.video_files:
            return

        video_path = self.video_files[video_idx]
        video = self.load_video_cached(video_path)
        annotations = self.load_annotations(video_path.stem)

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Original frame
        axes[0].imshow(video[frame_idx], cmap='gray')
        axes[0].set_title(f"{video_path.name} - Frame {frame_idx}")
        axes[0].axis('off')

        # Frame with annotations
        axes[1].imshow(video[frame_idx], cmap='gray')
        if annotations is not None:
            frame_particles = annotations[annotations['frame'] == frame_idx]
            if not frame_particles.empty:
                axes[1].scatter(frame_particles['x'], frame_particles['y'],
                                c='red', s=50, marker='o', facecolors='none', linewidths=2)
            axes[1].set_title(f"Annotations ({len(frame_particles)} particles)")
        else:
            axes[1].set_title("No annotations")
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()

print("‚úÖ TrainingDataLoader class defined")

In [None]:
python
# Create the incremental-training tracker and the Drive data loader, scan the
# upload folders, and preview the first video when anything was found.
tracker = TrainingTracker(CACHE_PATH)
data_loader = TrainingDataLoader(DATA_PATH, CACHE_PATH)

data_available = data_loader.scan_data()

if not data_available:
    print("\n‚ö†Ô∏è Please upload your training data and re-run this cell")
else:
    print("\nüì∏ Data Preview:")
    data_loader.preview_data(video_idx=0, frame_idx=0)

In [None]:
python
# Training Configuration - Interactive widgets for hyperparameters
class TrainingConfig:
    """Interactive ipywidgets panel for the training hyperparameters.

    ``display()`` renders the widgets; ``get_config()`` reads their current
    values into the nested dict the rest of the notebook consumes.
    """

    def __init__(self):
        self.widgets = {}

        # Model configuration
        self.widgets['model_name'] = widgets.Text(
            value='particle_detector_v3',
            description='Model Name: ',
            style={'description_width': '150px'}
        )

        self.widgets['architecture'] = widgets.Dropdown(
            options=['UNet'],
            value='UNet',
            description='Architecture:',
            style={'description_width': '150px'}
        )

        self.widgets['unet_channels'] = widgets.Text(
            value='16,32,64',
            description='UNet Channels:',
            style={'description_width': '150px'}
        )

        # Training parameters
        self.widgets['epochs'] = widgets.IntSlider(
            value=30,
            min=10,
            max=100,
            step=5,
            description='Epochs: ',
            style={'description_width': '150px'}
        )

        self.widgets['batch_size'] = widgets.Dropdown(
            options=[2, 4, 8, 16],
            value=8,
            description='Batch Size:',
            style={'description_width': '150px'}
        )

        self.widgets['learning_rate'] = widgets.FloatLogSlider(
            value=1e-4,
            base=10,
            min=-6,
            max=-2,
            description='Learning Rate:',
            style={'description_width': '150px'}
        )

        self.widgets['validation_split'] = widgets.FloatSlider(
            value=0.2,
            min=0.1,
            max=0.4,
            step=0.05,
            description='Val Split:',
            style={'description_width': '150px'}
        )

        # Augmentation
        self.widgets['augmentation'] = widgets.Checkbox(
            value=True,
            description='Enable Augmentation',
            style={'description_width': '150px'}
        )

        self.widgets['particle_radius'] = widgets.IntSlider(
            value=3,
            min=1,
            max=10,
            description='Particle Radius:',
            style={'description_width': '150px'}
        )

        # Advanced options
        self.widgets['incremental_training'] = widgets.Checkbox(
            value=True,
            description='Incremental Training',
            style={'description_width': '150px'}
        )

        self.widgets['mixed_precision'] = widgets.Checkbox(
            value=True,
            description='Mixed Precision ‚ö°',
            style={'description_width': '150px'}
        )

        self.widgets['early_stopping'] = widgets.Checkbox(
            value=True,
            description='Early Stopping',
            style={'description_width': '150px'}
        )

        self.widgets['gradient_clip'] = widgets.FloatSlider(
            value=1.0,
            min=0.1,
            max=5.0,
            step=0.1,
            description='Gradient Clip:',
            style={'description_width': '150px'}
        )

    def display(self):
        """Render the configuration widgets, grouped by section."""
        display(HTML("<h3>‚öôÔ∏è Training Configuration</h3>"))
        display(HTML("<h4>Model Settings</h4>"))
        display(widgets.VBox([
            self.widgets['model_name'],
            self.widgets['architecture'],
            self.widgets['unet_channels'],
        ]))

        display(HTML("<h4>Training Parameters</h4>"))
        display(widgets.VBox([
            self.widgets['epochs'],
            self.widgets['batch_size'],
            self.widgets['learning_rate'],
            self.widgets['validation_split'],
        ]))

        display(HTML("<h4>Data Settings</h4>"))
        display(widgets.VBox([
            self.widgets['augmentation'],
            self.widgets['particle_radius'],
        ]))

        display(HTML("<h4>Advanced Options</h4>"))
        display(widgets.VBox([
            self.widgets['incremental_training'],
            self.widgets['mixed_precision'],
            self.widgets['early_stopping'],
            self.widgets['gradient_clip'],
        ]))

    @staticmethod
    def _parse_channels(text):
        """Parse a comma-separated channel list such as ``"16, 32, 64"``.

        Empty entries (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if an entry is not an integer, or no positive
                channel count remains.
        """
        parts = [part.strip() for part in text.split(',')]
        try:
            channels = [int(part) for part in parts if part]
        except ValueError:
            raise ValueError(
                f"UNet Channels must be comma-separated integers, got {text!r}"
            ) from None
        if not channels or any(c <= 0 for c in channels):
            raise ValueError(
                f"UNet Channels must list at least one positive integer, got {text!r}"
            )
        return channels

    def get_config(self):
        """Return the current widget values as a nested configuration dict.

        Raises:
            ValueError: if the UNet channel list cannot be parsed.
        """
        w = self.widgets
        return {
            'model': {
                'name': w['model_name'].value,
                'architecture': w['architecture'].value.lower(),
                'unet_channels': self._parse_channels(w['unet_channels'].value)
            },
            'training': {
                'epochs': w['epochs'].value,
                'batch_size': w['batch_size'].value,
                'learning_rate': w['learning_rate'].value,
                'validation_split': w['validation_split'].value,
                'incremental': w['incremental_training'].value,
                'mixed_precision': w['mixed_precision'].value,
                'early_stopping': w['early_stopping'].value,
                'gradient_clip': w['gradient_clip'].value
            },
            'augmentation': {
                'enabled': w['augmentation'].value,
                'flip_lr': True,
                'flip_ud': True,
                'rotate': True,
                'brightness': True
            },
            'data': {
                'particle_radius': w['particle_radius'].value
            }
        }

# Create and display configuration.
# Builds the hyperparameter widget panel and renders it in the cell output;
# the chosen values can be read back as a dict via config_manager.get_config().
config_manager = TrainingConfig()
config_manager.display()

In [None]:
python
# PyTorch Dataset for particle detection
class ParticleDataset(Dataset):
    """Frame/mask pairs for particle segmentation, with optional augmentation.

    Args:
        frames: Indexable of 2-D frames; ``frames[idx]`` must support
            ``.copy()`` (e.g. a (N, H, W) NumPy array).
        masks: Matching masks with the same layout as ``frames``.
        augmentation_config: Optional dict with ``enabled`` and per-transform
            flags ``flip_lr``, ``flip_ud``, ``rotate``, ``brightness``. Each
            enabled transform fires with probability 0.5.

    Each item is a ``(frame, mask)`` pair of float32 tensors of shape (1, H, W).
    """

    def __init__(self, frames, masks, augmentation_config=None):
        self.frames = frames
        self.masks = masks
        self.aug_config = augmentation_config or {}

    def __len__(self):
        return len(self.frames)

    def _augment(self, frame, mask):
        """Apply the same random geometric transform to frame and mask."""
        cfg = self.aug_config

        if cfg.get('flip_lr') and np.random.rand() > 0.5:
            frame, mask = np.fliplr(frame), np.fliplr(mask)

        if cfg.get('flip_ud') and np.random.rand() > 0.5:
            frame, mask = np.flipud(frame), np.flipud(mask)

        if cfg.get('rotate') and np.random.rand() > 0.5:
            # A 90/270-degree turn swaps height and width. For non-square
            # frames that produces items of different shapes, which the
            # default DataLoader collate cannot stack, so only 180 is used.
            if frame.shape[-2] == frame.shape[-1]:
                k = np.random.randint(1, 4)
            else:
                k = 2
            frame, mask = np.rot90(frame, k), np.rot90(mask, k)

        # Brightness only affects the image; values are clipped to [0, 1].
        if cfg.get('brightness') and np.random.rand() > 0.5:
            frame = np.clip(frame * np.random.uniform(0.8, 1.2), 0, 1)

        return frame, mask

    def __getitem__(self, idx):
        # Copy so augmentation never mutates the shared source arrays.
        frame = self.frames[idx].copy()
        mask = self.masks[idx].copy()

        if self.aug_config.get('enabled', False):
            frame, mask = self._augment(frame, mask)

        # Flips/rotations return strided views; torch.from_numpy needs
        # contiguous memory.
        frame = torch.from_numpy(np.ascontiguousarray(frame)).float().unsqueeze(0)
        mask = torch.from_numpy(np.ascontiguousarray(mask)).float().unsqueeze(0)

        return frame, mask


def prepare_datasets(data_loader, config, video_files_to_train=None):
    """Build train/validation DataLoaders from the selected videos.

    Args:
        data_loader: Object providing ``video_files``, ``load_video_cached``,
            ``load_annotations`` and ``create_ground_truth_masks``.
        config: Configuration dict with ``data.particle_radius``,
            ``training.validation_split``, ``training.batch_size`` and
            ``augmentation`` entries.
        video_files_to_train: Optional list of video paths; when empty or
            None, every video in ``data_loader.video_files`` is used.

    Returns:
        ``(train_loader, val_loader)``; only the training set is augmented.

    Raises:
        ValueError: if there are no videos to load.
    """
    print("\nüì¶ Preparing datasets...")

    videos_to_process = video_files_to_train if video_files_to_train else data_loader.video_files
    if not videos_to_process:
        raise ValueError("prepare_datasets: no videos to load (data_loader.video_files is empty)")

    all_frames, all_masks = [], []
    for video_path in tqdm(videos_to_process, desc="Loading data"):
        video = data_loader.load_video_cached(video_path)
        annotations = data_loader.load_annotations(video_path.stem)

        if annotations is not None:
            masks = data_loader.create_ground_truth_masks(
                annotations, video.shape, config['data']['particle_radius']
            )
        else:
            # NOTE(review): unannotated videos are trained as all-background,
            # which teaches the model that their particles do not exist.
            masks = np.zeros_like(video)

        all_frames.append(video)
        all_masks.append(masks)

    # NOTE(review): concatenation requires every video to share the same
    # frame height and width.
    all_frames = np.concatenate(all_frames, axis=0)
    all_masks = np.concatenate(all_masks, axis=0)

    # Random frame-level train/validation split.
    # NOTE(review): neighbouring frames of one video are highly correlated, so
    # validation scores from this split are likely optimistic; consider
    # splitting by video or by contiguous frame ranges.
    num_frames = len(all_frames)
    n_val = int(num_frames * config['training']['validation_split'])
    if num_frames >= 2:
        # Keep at least one frame on each side: an empty validation set leaves
        # 'val_loss' undefined for anything monitoring it (checkpointing,
        # early stopping, the LR scheduler), and an empty training set
        # cannot train at all.
        n_val = min(max(n_val, 1), num_frames - 1)
    indices = np.random.permutation(num_frames)
    train_idx, val_idx = indices[n_val:], indices[:n_val]

    train_dataset = ParticleDataset(
        all_frames[train_idx],
        all_masks[train_idx],
        config['augmentation']
    )
    val_dataset = ParticleDataset(
        all_frames[val_idx],
        all_masks[val_idx]
    )

    # Worker processes and pinned memory only pay off when feeding a GPU.
    on_gpu = device.type == 'cuda'
    loader_kwargs = {
        'batch_size': config['training']['batch_size'],
        'num_workers': 2 if on_gpu else 0,
        'pin_memory': on_gpu,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print(f"‚úÖ Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    return train_loader, val_loader

print("‚úÖ Dataset class and preparation function defined")

In [None]:
python
# UNet Model Architecture
class UNet(nn.Module):
    """
    U-Net encoder/decoder producing a per-pixel output map.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (raw values, no activation).
        features: Channel counts of the encoder stages, shallowest first
            (any sequence of ints, default ``(16, 32, 64)``). The bottleneck
            uses ``2 * features[-1]`` channels.

    Raises:
        ValueError: if ``features`` is empty.

    Inputs of any spatial size are accepted; decoder outputs are resized to
    the matching skip connection when pooling rounded a dimension down.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(16, 32, 64)):
        super().__init__()
        # Immutable default plus a private copy: the previous list default was
        # a single object shared by every call.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one double-conv stage per entry in `features`.
        for feature in features:
            self.encoder.append(self._double_conv(in_channels, feature))
            in_channels = feature

        # Decoder: alternating [up-convolution, double-conv] pairs, deepest first.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(self._double_conv(feature * 2, feature))

        self.bottleneck = self._double_conv(features[-1], features[-1] * 2)

        # 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoder path; keep each stage's output for the skip connections.
        skip_connections = []
        for encode in self.encoder:
            x = encode(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder path, consuming skip connections deepest first.
        ups, convs = self.decoder[0::2], self.decoder[1::2]
        for up, conv, skip in zip(ups, convs, reversed(skip_connections)):
            x = up(x)
            # Odd sizes are rounded down by pooling, so the up-sampled map can
            # be one pixel short; match the skip connection's spatial size.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])
            x = conv(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

print("‚úÖ UNet architecture defined")

In [None]:
python
# PyTorch Lightning Module for Training
class ParticleDetector(L.LightningModule):
    """
    Lightning module wrapping a UNet for binary particle segmentation.

    The UNet outputs one logit map per image. Training uses
    ``BCEWithLogitsLoss`` against the masks, and F1 / Dice / IoU are tracked
    separately for the training and validation sets.

    Args:
        config: Configuration dict; reads ``config['model']['unet_channels']``
            and ``config['training']['learning_rate']``.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Build model
        self.model = UNet(
            in_channels=1,
            out_channels=1,
            features=config['model']['unet_channels']
        )

        # Loss function (operates on raw logits)
        self.criterion = nn.BCEWithLogitsLoss()

        # One metric object per stage so training and validation state never mix.
        self.train_f1 = BinaryF1Score()
        self.train_dice = Dice()
        self.train_iou = JaccardIndex(task='binary')

        self.val_f1 = BinaryF1Score()
        self.val_dice = Dice()
        self.val_iou = JaccardIndex(task='binary')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Only accumulate metric state here. Logging the metric *objects*
        # below lets Lightning compute() them over the whole epoch and reset
        # them; averaging per-batch scores is not the epoch-level F1/Dice/IoU.
        probs = torch.sigmoid(logits)
        target = y.int()
        self.train_f1.update(probs, target)
        self.train_dice.update(probs, target)
        self.train_iou.update(probs, target)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_dice', self.train_dice, on_step=False, on_epoch=True)
        self.log('train_iou', self.train_iou, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        probs = torch.sigmoid(logits)
        target = y.int()
        self.val_f1.update(probs, target)
        self.val_dice.update(probs, target)
        self.val_iou.update(probs, target)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_dice', self.val_dice, on_step=False, on_epoch=True)
        self.log('val_iou', self.val_iou, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Adam plus halving the learning rate after 5 epochs without val_loss improvement."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.config['training']['learning_rate']
        )

        # `verbose=` is deprecated on PyTorch LR schedulers, so it is omitted.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }


def create_model(config):
    """Build a fresh, untrained ParticleDetector for ``config``.

    Prints the architecture name and the total parameter count before
    returning the new module.
    """
    arch_label = config['model']['architecture'].upper()
    print(f"\nüèóÔ∏è Building {arch_label} model...")

    detector = ParticleDetector(config)

    n_params = 0
    for param in detector.parameters():
        n_params += param.numel()

    print(f"‚úÖ Model created!")
    print(f"   Total parameters: {n_params: ,}")
    return detector

print("‚úÖ ParticleDetector Lightning module defined")

In [None]:
python
# Model Export Utility
class ModelExporter:
    """Write a trained model's weights plus JSON/Markdown metadata to disk.

    Each export goes to ``<model_path>/<version>/`` and contains
    ``weights.pth``, ``metadata.json``, ``config.json`` and ``MODEL_CARD.md``.
    """

    def __init__(self, model_path):
        self.model_path = Path(model_path)

    def generate_version(self):
        """Return a timestamp version tag such as ``v_20240131_142500``."""
        return f"v_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def export_model(self, model, trainer_obj, config, num_videos=None):
        """
        Export model weights, metadata, config and a Markdown model card.

        Args:
            model: Trained Lightning module whose ``.model`` attribute holds
                the network (e.g. ParticleDetector); that network's state dict
                is saved.
            trainer_obj: Lightning trainer; its ``callback_metrics`` supply the
                validation scores (a missing metric is recorded as 0).
            config: Training configuration dictionary.
            num_videos: Number of videos the model was trained on. When
                omitted, falls back to ``len(data_loader.video_files)`` using
                the notebook-level ``data_loader`` global.

        Returns:
            Path to the exported model directory.
        """
        version = self.generate_version()
        export_dir = self.model_path / version
        export_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nüì¶ Exporting model:  {version}")

        # Save model weights
        torch.save(model.model.state_dict(), export_dir / "weights.pth")
        print("   ‚úÖ Weights saved")

        # Final logged metrics (tensors), keyed by the names used in self.log
        metrics = trainer_obj.callback_metrics

        if num_videos is None:
            num_videos = len(data_loader.video_files)

        metadata = {
            "model_name": config['model']['name'],
            "version": version,
            "created_at": datetime.now().isoformat(),
            "architecture": {
                "type": config['model']['architecture'],
                "unet_channels": config['model']['unet_channels'],
                "out_channels": 1
            },
            "training": config['training'],
            "performance": {
                "val_loss": float(metrics.get('val_loss', 0)),
                "val_f1": float(metrics.get('val_f1', 0)),
                "val_dice": float(metrics.get('val_dice', 0)),
                "val_iou": float(metrics.get('val_iou', 0)),
            },
            "data_info": {
                "num_videos": num_videos,
                "augmentation": config['augmentation']['enabled']
            },
            "compatibility": {
                "deeptrack_version": "installed",
                "torch_version": torch.__version__,
                "lightning_version": L.__version__
            }
        }

        with open(export_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        print("   ‚úÖ Metadata saved")

        with open(export_dir / "config.json", 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print("   ‚úÖ Config saved")

        # Model card. Float metrics use the `:.4f` spec; the previous `:. 4f`
        # was an invalid format specifier and raised ValueError here.
        # NOTE(review): the usage snippet imports UNet from a `model` module,
        # which this export does not write.
        card = f"""# Model: {config['model']['name']}

**Version:** {version}
**Created:** {metadata['created_at']}
**Architecture:** {metadata['architecture']['type'].upper()}

## Performance Metrics

| Metric | Validation |
|--------|------------|
| Loss | {metadata['performance']['val_loss']:.4f} |
| F1 Score | {metadata['performance']['val_f1']:.4f} |
| Dice | {metadata['performance']['val_dice']:.4f} |
| IoU | {metadata['performance']['val_iou']:.4f} |

## Architecture Details

- **Type:** {metadata['architecture']['type'].upper()}
- **Channels:** {metadata['architecture']['unet_channels']}
- **Output Channels:** {metadata['architecture']['out_channels']}

## Training Configuration

- **Epochs:** {config['training']['epochs']}
- **Batch Size:** {config['training']['batch_size']}
- **Learning Rate:** {config['training']['learning_rate']}
- **Mixed Precision:** {config['training']['mixed_precision']}
- **Early Stopping:** {config['training']['early_stopping']}
- **Gradient Clipping:** {config['training']['gradient_clip']}

## Data Information

- **Number of Videos:** {metadata['data_info']['num_videos']}
- **Augmentation:** {'Enabled' if metadata['data_info']['augmentation'] else 'Disabled'}
- **Particle Radius:** {config['data']['particle_radius']} pixels

## Usage

```python
import torch
from model import UNet

# Load model
model = UNet(in_channels=1, out_channels=1, features={metadata['architecture']['unet_channels']})
model.load_state_dict(torch.load("weights.pth"))
model.eval()

# Inference
with torch.no_grad():
    predictions = model(input_tensor)
```

## Compatibility

- **PyTorch:** {metadata['compatibility']['torch_version']}
- **Lightning:** {metadata['compatibility']['lightning_version']}
- **DeepTrack:** {metadata['compatibility']['deeptrack_version']}
"""

        with open(export_dir / "MODEL_CARD.md", 'w', encoding='utf-8') as f:
            f.write(card)
        print("   ‚úÖ Model card created")

        print(f"\n‚úÖ Export complete!")
        print(f"üìÅ Location: {export_dir}")
        print(f"\nüíæ Download from Google Drive:")
        print(f"   Navigate to: {export_dir}")

        return export_dir

print("‚úÖ ModelExporter class defined")

In [None]:
python
# Main Training Function
def train_model(config, data_loader, tracker):
    """
    Main training function with Lightning

    Args:
        config: Training configuration dictionary
        data_loader: Data loader instance
        tracker: Training tracker instance
    """
    print("\n" + "="*80)
    print("üöÄ STARTING TRAINING")
    print("="*80)

    # Determine which videos to train on
    if config['training']['incremental']:
        videos_to_train = tracker.get_untrained_videos(data_loader. video_files)
        if not videos_to_train:
            print("‚úÖ All videos already trained!")
            print("üí° Disable 'Incremental Training' to retrain")
            return None
        print(f"üìπ Training on {len(videos_to_train)} new video(s)")
    else:
        videos_to_train = data_loader.video_files
        print(f"üìπ Training on all {len(videos_to_train)} video(s)")

    # Prepare datasets
    train_loader, val_loader = prepare_datasets(data_loader, config, videos_to_train)

    # Create model
    model = create_model(config)

    # Setup callbacks
    callbacks = []

    # Checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=LOG_PATH / 'checkpoints',
        filename=f"{config['model']['name']}-{{epoch: 02d}}-{{val_loss:. 4f}}",
        monitor='val_loss',
        mode='min',
        save_top_k=3,
        verbose=True
    )
    callbacks.append(checkpoint_callback)

    # Early stopping callback
    if config['training']['early_stopping']:
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            patience=10,
            mode='min',
            verbose=True
        )
        callbacks.append(early_stop_callback)

    # Learning rate monitor
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    callbacks.append(lr_monitor)

    # Setup logger
    logger = CSVLogger(LOG_PATH, name=config['model']['name'])

    # Create trainer
    trainer = L.Trainer(
        max_epochs=config['training']['epochs'],
        callbacks=callbacks,
        logger=logger,
        accelerator='auto',
        devices=1,
        precision='16-mixed' if config['training']['mixed_precision'] else 32,
        gradient_clip_val=config['training']['gradient_clip'],
        log_every_n_steps=10,
        enable_progress_bar=True,
        enable_model_summary=True
    )

    print("\n" + "="*80)
    print("‚ö° TRAINING IN PROGRESS")
    print("="*80)

    # Train!
    trainer.fit(model, train_loader, val_loader)

    print("\n" + "="*80)
    print("‚úÖ TRAINING COMPLETE!")
    print("="*80)

    # Export model
    exporter = ModelExporter(MODEL_PATH)
    export_dir = exporter.export_model(model, trainer, config)

    # Mark videos as trained
    if config['training']['incremental']:
        for video_path in videos_to_train:
            tracker.mark_video_trained(video_path, export_dir. name)

    return model, trainer

print("‚úÖ Training function defined")

In [None]:
python
# ============================================================================
# RUN TRAINING
# ============================================================================
# Execute this cell to start training with the configuration above

# Get configuration from widgets
config = config_manager.get_config()

# Display configuration summary
print("üìã Configuration Summary:")
print(json.dumps(config, indent=2))
print("\n")

# Check if data is available
if not data_available:
    print("‚ùå No training data available!")
    print("Please upload your data and re-run the data scanning cell.")
else:
    # Start training. train_model may return a bare None (instead of a pair)
    # when incremental training finds nothing new; normalize it so the unpack
    # below cannot raise TypeError.
    training_result = train_model(config, data_loader, tracker)
    if training_result is None:
        training_result = (None, None)
    trained_model, trainer = training_result

    if trained_model is not None:
        print("\n" + "="*80)
        print("üéâ SUCCESS!")
        print("="*80)
        print("\nüí° Next steps:")
        print("   1. Download your model from Google Drive")
        print("   2. Review the MODEL_CARD.md for usage instructions")
        print("   3. Use the model for inference in your application")

In [None]:
python
# Visualize training results
def plot_training_history(log_path, model_name):
    """Plot training history from CSV logs"""
    log_file = log_path / model_name / 'version_0' / 'metrics.csv'

    if not log_file.exists():
        print("‚ùå No training logs found")
        return

    # Load metrics
    df = pd.read_csv(log_file)

    # Create plots
    fig, axes = plt. subplots(2, 2, figsize=(15, 10))

    # Loss
    axes[0, 0]. plot(df['epoch'], df['train_loss_epoch'], label='Train Loss', marker='o')
    axes[0, 0].plot(df['epoch'], df['val_loss'], label='Val Loss', marker='s')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0]. grid(True, alpha=0.3)

    # F1 Score
    axes[0, 1].plot(df['epoch'], df['train_f1'], label='Train F1', marker='o')
    axes[0, 1].plot(df['epoch'], df['val_f1'], label='Val F1', marker='s')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('F1 Score')
    axes[0, 1].set_title('F1 Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Dice Coefficient
    axes[1, 0].plot(df['epoch'], df['train_dice'], label='Train Dice', marker='o')
    axes[1, 0].plot(df['epoch'], df['val_dice'], label='Val Dice', marker='s')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Dice Coefficient')
    axes[1, 0].set_title('Dice Coefficient')
    axes[1, 0].legend()
    axes[1, 0]. grid(True, alpha=0.3)

    # IoU
    axes[1, 1].plot(df['epoch'], df['train_iou'], label='Train IoU', marker='o')
    axes[1, 1].plot(df['epoch'], df['val_iou'], label='Val IoU', marker='s')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('IoU')
    axes[1, 1].set_title('Intersection over Union')
    axes[1, 1].legend()
    axes[1, 1]. grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot results if training was successful
if 'trained_model' in locals() and trained_model is not None:
    print("üìä Training History:")
    plot_training_history(LOG_PATH, config['model']['name'])

In [None]:
python
# Test the trained model on a sample frame
def test_model_inference(model, data_loader, frame_idx=0):
    """Run the trained network on one frame and plot input, prediction and ground truth.

    Args:
        model: Trained Lightning module; the wrapped network is ``model.model``.
        data_loader: Provides ``video_files``, ``load_video_cached``,
            ``load_annotations`` and ``create_ground_truth_masks``.
        frame_idx: Index of the frame in the first video to visualize
            (negative indices count from the end, as with NumPy).
    """
    if model is None:
        print("‚ùå No trained model available")
        return

    if not data_loader.video_files:
        print("‚ùå No videos available for inference")
        return

    # Load a sample frame
    video_path = data_loader.video_files[0]
    video = data_loader.load_video_cached(video_path)
    if not -len(video) <= frame_idx < len(video):
        print(f"‚ùå frame_idx {frame_idx} is out of range for a {len(video)}-frame video")
        return
    frame = video[frame_idx]

    network = model.model

    # Put the input on the same device as the weights. After fit() the module
    # may not live on the global `device`, and a mismatch raises RuntimeError.
    first_param = next(network.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    # Convert through NumPy first: torch.from_numpy rejects some integer
    # dtypes (e.g. uint16 on older PyTorch) that microscopy frames often use.
    frame_array = np.asarray(frame, dtype=np.float32)
    input_tensor = torch.from_numpy(frame_array).unsqueeze(0).unsqueeze(0).to(target_device)

    # Inference
    network.eval()
    with torch.no_grad():
        output = network(input_tensor)
        prediction = torch.sigmoid(output).cpu().numpy()[0, 0]

    # Load ground truth if available
    annotations = data_loader.load_annotations(video_path.stem)
    if annotations is not None:
        masks = data_loader.create_ground_truth_masks(
            annotations, video.shape, config['data']['particle_radius']
        )
        ground_truth = masks[frame_idx]
    else:
        ground_truth = None

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(frame, cmap='gray')
    axes[0].set_title('Input Frame')
    axes[0].axis('off')

    axes[1].imshow(prediction, cmap='hot')
    axes[1].set_title('Model Prediction')
    axes[1].axis('off')

    if ground_truth is not None:
        axes[2].imshow(ground_truth, cmap='hot')
        axes[2].set_title('Ground Truth')
    else:
        axes[2].text(0.5, 0.5, 'No Ground Truth', ha='center', va='center')
        axes[2].set_title('Ground Truth')
    axes[2].axis('off')

    fig.tight_layout()
    plt.show()

# Test if model is available
if 'trained_model' in locals() and trained_model is not None:
    print("üß™ Testing model inference:")
    test_model_inference(trained_model, data_loader, frame_idx=0)

In [None]:
python
# Display summary and next steps
# Closing cell: recap the pipeline stages and point at the saved artifacts.
# Sections built with "\n".join print exactly the same lines as one
# print() call per line would.
print("=" * 80)
print("‚úÖ NOTEBOOK COMPLETE")
print("=" * 80)

print("\n".join((
    "\nüìö Summary of what was accomplished:",
    "   ‚úÖ Environment setup with PyTorch Lightning",
    "   ‚úÖ Google Drive mounted and directories created",
    "   ‚úÖ Training data loaded and preprocessed",
    "   ‚úÖ UNet model architecture defined",
    "   ‚úÖ Advanced metrics implemented (F1, Dice, IoU)",
    "   ‚úÖ Model trained with mixed precision",
    "   ‚úÖ Results visualized",
    "   ‚úÖ Model exported with metadata",
)))

# Artifact locations: printed line by line because each one reads a global path.
print("\nüìÇ Your files are saved in Google Drive:")
print(f"   Models: {MODEL_PATH}")
print(f"   Logs: {LOG_PATH}")
print(f"   Results: {RESULTS_PATH}")

print("\n".join((
    "\nüí° Next Steps:",
    "   1. Download your trained model from Google Drive",
    "   2. Review the MODEL_CARD.md for performance metrics",
    "   3. Integrate the model into your application",
    "   4. Run inference on new videos",
    "   5. Fine-tune by adding more training data",
)))

print("\n".join((
    "\nüîÑ To train again:",
    "   - Adjust parameters in the configuration cell",
    "   - Add more videos to the training_data/videos folder",
    "   - Re-run the training cell",
)))

print("\n" + "=" * 80)