<a href="https://colab.research.google.com/github/ParthivRB/Deeptrack_Colab/blob/main/DeepTrack_Cloud_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================================
# 🎯 DeepTrack2 Cloud Training System - OPTIMIZED VERSION
# ============================================================================
# Version: 2.0.0 | Incremental Training + Performance Optimizations
# ============================================================================

print("=" * 70)
print("üöÄ DEEPTRACK CLOUD TRAINER")
print("=" * 70)
print("\nInitializing training environment.. .\n")

# -----------------------------------------------------------------------------
# STEP 1: Install Dependencies
# -----------------------------------------------------------------------------
print("üì¶ Installing dependencies...")
import subprocess
import sys

packages = [
    'deeptrack', 'deeplay', 'torch', 'torchvision',
    'tqdm', 'ipywidgets', 'matplotlib', 'scikit-image',
    'pandas', 'scipy'
]

# A single pip invocation lets the resolver consider all requirements together
# (instead of each install possibly re-pinning a dependency of a previous one)
# and pays pip's start-up/resolution cost once rather than once per package.
# check_call raises CalledProcessError if the install fails.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *packages])

print("‚úÖ All dependencies installed!\n")

# -----------------------------------------------------------------------------
# STEP 2: Import Libraries
# -----------------------------------------------------------------------------
print("üìö Loading libraries...")

# Standard library.
import os
import json
import warnings
from pathlib import Path
from datetime import datetime
import shutil
import hashlib

# Scientific stack / IO.
import numpy as np
import pandas as pd
from scipy. ndimage import label, center_of_mass
from skimage import io as skio
from tqdm. auto import tqdm

# Deep learning.
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import deeplay as dl
import deeptrack as dt

# Plotting and notebook UI.
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# Silences every warning for the rest of the session (including deprecation
# warnings from torch/deeptrack), so problems may go unreported.
warnings.filterwarnings('ignore')
# Fixed seeds for torch and NumPy's global RNG; the augmentation and the
# train/val split draw from np.random, so this makes runs more repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Global compute device; later code moves models/batches to it by name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"‚úÖ Libraries loaded!")
print(f"üñ•Ô∏è  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"   PyTorch:  {torch.__version__}")
print(f"   DeepTrack: installed\n")

# -----------------------------------------------------------------------------
# STEP 3: Mount Google Drive
# -----------------------------------------------------------------------------
print("üìÇ Mounting Google Drive...")
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Everything the trainer reads or writes lives under one folder on the
# mounted Drive: inputs, logs/checkpoints, exported models and caches.
BASE_PATH = Path('/content/drive/MyDrive/DeepTrack_Studio')
DATA_PATH, MODEL_PATH = BASE_PATH / 'training_data', BASE_PATH / 'models'
LOG_PATH, CACHE_PATH = BASE_PATH / 'logs', BASE_PATH / 'cache'

# Create the folder tree up front (idempotent thanks to exist_ok).
for path in (BASE_PATH, DATA_PATH, MODEL_PATH, LOG_PATH, CACHE_PATH):
    path.mkdir(parents=True, exist_ok=True)
for _subfolder in ('videos', 'annotations'):
    (DATA_PATH / _subfolder).mkdir(exist_ok=True)

print(f"‚úÖ Google Drive ready!")
print(f"   Base:  {BASE_PATH}\n")

# -----------------------------------------------------------------------------
# STEP 4: Training Tracker (NEW - for incremental training)
# -----------------------------------------------------------------------------
class TrainingTracker:
    """Persist which video files have already been used for training.

    State lives in ``<cache_path>/training_tracker.json`` as a mapping from a
    video's file name to ``{'hash', 'trained_date', 'model_version'}``. A video
    counts as trained only while the MD5 of its current content still matches
    the recorded hash, so edited/re-uploaded files are picked up again.
    """

    # Hash files in 1 MiB chunks so multi-GB videos never sit in memory whole.
    _HASH_CHUNK_SIZE = 1 << 20

    def __init__(self, cache_path):
        self.cache_path = Path(cache_path)
        self.tracker_file = self.cache_path / 'training_tracker.json'
        self.trained_videos = self.load_tracker()

    def load_tracker(self):
        """Return the saved ``{video_name: info}`` mapping.

        Returns ``{}`` when the file does not exist or cannot be parsed.
        """
        if not self.tracker_file.exists():
            return {}
        try:
            with open(self.tracker_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # A corrupt tracker (e.g. from an interrupted write) must not make
            # the notebook unusable; every video is then treated as untrained.
            print(f"Could not read {self.tracker_file} ({err}); starting with an empty training record.")
            return {}
        return data if isinstance(data, dict) else {}

    def save_tracker(self):
        """Write the mapping to disk via a temporary file and a rename.

        ``Path.replace`` swaps the file in one step, so a crash mid-write leaves
        the previous tracker intact rather than a truncated JSON file (on
        filesystems where rename is atomic).
        """
        self.cache_path.mkdir(parents=True, exist_ok=True)
        tmp_file = self.tracker_file.parent / (self.tracker_file.name + '.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.trained_videos, f, indent=2)
        tmp_file.replace(self.tracker_file)

    def get_file_hash(self, file_path):
        """Return the hex MD5 digest of the file's bytes (change detection, not security)."""
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def is_video_trained(self, video_path):
        """True if ``video_path`` was recorded as trained and its content is unchanged."""
        video_name = video_path.name
        if video_name not in self.trained_videos:
            return False

        current_hash = self.get_file_hash(video_path)
        return self.trained_videos[video_name].get('hash') == current_hash

    def mark_video_trained(self, video_path, model_version):
        """Record ``video_path`` (by name and content hash) as trained into ``model_version`` and persist."""
        video_name = video_path.name
        self.trained_videos[video_name] = {
            'hash': self.get_file_hash(video_path),
            'trained_date': datetime.now().isoformat(),
            'model_version': model_version
        }
        self.save_tracker()

    def get_untrained_videos(self, video_files):
        """Return the subset of ``video_files`` that are new or changed since training."""
        return [v for v in video_files if not self.is_video_trained(v)]

# -----------------------------------------------------------------------------
# STEP 5: Data Loader (OPTIMIZED with caching)
# -----------------------------------------------------------------------------
class TrainingDataLoader:
    """Discover training videos/annotations and turn them into arrays.

    Expected layout under ``data_path``: ``videos/`` holding image stacks and
    ``annotations/<video stem>_particles.csv`` files with ``frame``, ``x`` and
    ``y`` columns. Processed (normalised) videos are cached as ``.npy`` files
    in ``cache_path``.
    """

    # Compared case-insensitively against each file's suffix.
    VIDEO_EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')

    def __init__(self, data_path, cache_path):
        self.data_path = Path(data_path)
        self.cache_path = Path(cache_path)
        self.videos_path = self.data_path / 'videos'
        self.annotations_path = self.data_path / 'annotations'
        self.video_files = []
        self.annotation_files = {}

    def scan_data(self):
        """Populate ``video_files``/``annotation_files``; return True if any video was found."""
        print("üîç Scanning for training data...")
        # Case-insensitive suffix match (so ``.TIF`` is found too) and a sorted
        # order, so ``video_idx`` and dataset ordering are reproducible.
        if self.videos_path.is_dir():
            self.video_files = sorted(
                p for p in self.videos_path.iterdir()
                if p.is_file() and p.suffix.lower() in self.VIDEO_EXTENSIONS
            )
        else:
            self.video_files = []

        if not self.video_files:
            print(f"‚ùå No videos found!")
            print(f"\nüìù UPLOAD INSTRUCTIONS:")
            print(f"   1. Open Google Drive in new tab")
            print(f"   2. Go to:  {self.videos_path}")
            print(f"   3. Upload your . tif video files")
            print(f"   4. Return here and re-run this cell")
            return False

        print(f"‚úÖ Found {len(self.video_files)} video(s)")

        self.annotation_files = {}
        for video_path in self.video_files:
            annotation_path = self.annotations_path / f"{video_path.stem}_particles.csv"
            if annotation_path.exists():
                self.annotation_files[video_path.stem] = annotation_path

        print(f"üìã Found {len(self.annotation_files)} annotation(s)")
        for i, video_path in enumerate(self.video_files, 1):
            status = "‚úÖ" if video_path.stem in self.annotation_files else "‚ö†Ô∏è (no annotation)"
            print(f"   {i}. {video_path.name} {status}")

        return True

    def load_video_cached(self, video_path):
        """Load a video as a float32 frame stack scaled so its maximum is 1.

        The ``.npy`` cache entry is reused only while it is at least as new as
        the source file; a source modified after the cache was written is
        re-processed instead of being served stale frames.
        """
        video_path = Path(video_path)
        cache_file = self.cache_path / f"{video_path.stem}_processed.npy"

        if cache_file.exists() and cache_file.stat().st_mtime >= video_path.stat().st_mtime:
            return np.load(cache_file)

        video = skio.imread(str(video_path))
        if (video.ndim == 3 and video.shape[-1] in (3, 4)
                and video_path.suffix.lower() in ('.png', '.jpg', '.jpeg')):
            # A single colour image (H, W, C): average RGB into one grey frame
            # instead of misreading it as H frames of size W x C.
            video = video[..., :3].mean(axis=-1)
        if video.ndim == 2:
            video = video[np.newaxis, ...]
        elif video.ndim == 4:
            # NOTE(review): assumes (frames, channels, H, W) and keeps channel 0
            # — confirm for RGB TIFF stacks stored as (frames, H, W, channels).
            video = video[:, 0, :, :]
        # Always float32 so cached arrays have a consistent dtype.
        video = video.astype(np.float32)
        peak = video.max()
        if peak > 0:
            video = video / peak

        self.cache_path.mkdir(parents=True, exist_ok=True)
        np.save(cache_file, video)
        return video

    def load_annotations(self, video_stem):
        """Return the annotation DataFrame for ``video_stem``, or None if it has none."""
        if video_stem not in self.annotation_files:
            return None
        df = pd.read_csv(self.annotation_files[video_stem])
        return df

    def create_ground_truth_masks(self, annotations, shape, radius=3):
        """Rasterise particle annotations into binary disc masks.

        Args:
            annotations: DataFrame with ``frame``, ``x`` (column) and ``y`` (row)
                columns in pixel units.
            shape: ``(num_frames, height, width)`` of the mask stack.
            radius: Disc radius in pixels around each (rounded) particle centre.

        Returns:
            float32 array of ``shape`` holding 1.0 inside particle discs and 0.0
            elsewhere. Rows whose frame is not an integer in ``[0, num_frames)``
            or whose coordinates are not finite are ignored.
        """
        num_frames, height, width = shape
        masks = np.zeros(shape, dtype=np.float32)
        if len(annotations) == 0:
            return masks

        frames = annotations['frame'].to_numpy(dtype=float)
        xs = annotations['x'].to_numpy(dtype=float)
        ys = annotations['y'].to_numpy(dtype=float)
        valid = (
            np.isfinite(frames) & np.isfinite(xs) & np.isfinite(ys)
            & (frames >= 0) & (frames < num_frames) & (frames == np.floor(frames))
        )
        r2 = radius ** 2
        # Round to the nearest pixel; truncating would bias every disc
        # towards the origin by up to one pixel.
        centres = zip(frames[valid].astype(int),
                      np.rint(xs[valid]).astype(int),
                      np.rint(ys[valid]).astype(int))
        for f, cx, cy in centres:
            # Only the disc's bounding box can change, so work on that patch
            # instead of a full-frame distance map per particle.
            y0, y1 = max(cy - radius, 0), min(cy + radius + 1, height)
            x0, x1 = max(cx - radius, 0), min(cx + radius + 1, width)
            if y0 >= y1 or x0 >= x1:
                continue  # disc lies entirely outside the frame
            yy = np.arange(y0, y1)[:, np.newaxis]
            xx = np.arange(x0, x1)[np.newaxis, :]
            patch = masks[f, y0:y1, x0:x1]  # view into masks
            patch[(xx - cx) ** 2 + (yy - cy) ** 2 <= r2] = 1.0
        return masks

    def preview_data(self, video_idx=0, frame_idx=0):
        """Show one frame next to the same frame with its annotated particles circled."""
        if not self.video_files:
            return
        video_path = self.video_files[video_idx]
        video = self.load_video_cached(video_path)
        annotations = self.load_annotations(video_path.stem)

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        axes[0].imshow(video[frame_idx], cmap='gray')
        axes[0].set_title(f"{video_path.name} - Frame {frame_idx}")
        axes[0].axis('off')

        axes[1].imshow(video[frame_idx], cmap='gray')
        if annotations is not None:
            frame_particles = annotations[annotations['frame'] == frame_idx]
            if not frame_particles.empty:
                axes[1].scatter(frame_particles['x'], frame_particles['y'],
                                c='red', s=50, marker='o', facecolors='none', linewidths=2)
            axes[1].set_title(f"Annotations ({len(frame_particles)} particles)")
        else:
            axes[1].set_title("No annotations")
        axes[1].axis('off')
        plt.tight_layout()
        plt.show()

# Initialize tracker and data loader.
# `tracker` records which videos were already trained (persisted in CACHE_PATH);
# `data_loader` scans DATA_PATH for videos and their annotation CSVs.
tracker = TrainingTracker(CACHE_PATH)
data_loader = TrainingDataLoader(DATA_PATH, CACHE_PATH)
# True when at least one video file was found.
data_available = data_loader.scan_data()

if data_available:
    print("\nüì∏ Preview:")
    # First frame of the first scanned video, with its annotations if present.
    data_loader.preview_data(video_idx=0, frame_idx=0)

# -----------------------------------------------------------------------------
# STEP 6: Configuration
# -----------------------------------------------------------------------------
class TrainingConfig:
    """ipywidgets form whose current values are turned into the training config dict."""

    def __init__(self):
        self.widgets = {}
        self.widgets['model_name'] = widgets.Text(value='particle_detector', description='Model Name: ')
        self.widgets['architecture'] = widgets.Dropdown(options=['UNet'], value='UNet', description='Architecture:')
        self.widgets['unet_channels'] = widgets.Text(value='16,32,64', description='Channels:')
        self.widgets['epochs'] = widgets.IntSlider(value=30, min=10, max=100, description='Epochs:')
        self.widgets['batch_size'] = widgets.Dropdown(options=[2, 4, 8, 16], value=8, description='Batch Size:')
        self.widgets['learning_rate'] = widgets.FloatLogSlider(value=1e-4, base=10, min=-6, max=-2, description='Learning Rate:')
        self.widgets['validation_split'] = widgets.FloatSlider(value=0.2, min=0.1, max=0.4, description='Val Split:')
        self.widgets['augmentation'] = widgets.Checkbox(value=True, description='Augmentation')
        self.widgets['particle_radius'] = widgets.IntSlider(value=3, min=1, max=10, description='Particle Radius:')
        self.widgets['incremental_training'] = widgets.Checkbox(value=True, description='Incremental Training')

    def display(self):
        """Render the form (the call below resolves to IPython's module-level ``display``)."""
        display(HTML("<h3>‚öôÔ∏è Training Configuration</h3>"))
        display(widgets.VBox([
            self.widgets['model_name'],
            self.widgets['architecture'],
            self.widgets['unet_channels'],
            self.widgets['epochs'],
            self.widgets['batch_size'],
            self.widgets['learning_rate'],
            self.widgets['validation_split'],
            self.widgets['augmentation'],
            self.widgets['particle_radius'],
            self.widgets['incremental_training']
        ]))

    @staticmethod
    def _parse_channels(text):
        """Parse ``"16, 32,64,"`` into ``[16, 32, 64]``; blank entries are ignored.

        Raises:
            ValueError: If an entry is not an integer, or the list is empty or
                contains a non-positive width.
        """
        parts = [part.strip() for part in text.split(',')]
        try:
            channels = [int(part) for part in parts if part]
        except ValueError:
            raise ValueError(
                f"Invalid UNet channels {text!r}: expected comma-separated integers, e.g. '16,32,64'"
            ) from None
        if not channels or any(c <= 0 for c in channels):
            raise ValueError(
                f"Invalid UNet channels {text!r}: need at least one positive integer, e.g. '16,32,64'"
            )
        return channels

    def get_config(self):
        """Return the current widget values as the nested config dict used by training.

        Raises:
            ValueError: If the channels text is not a list of positive integers.
        """
        return {
            'model': {
                'name': self.widgets['model_name'].value,
                'architecture': self.widgets['architecture'].value.lower(),
                'unet_channels': self._parse_channels(self.widgets['unet_channels'].value)
            },
            'training': {
                'epochs': self.widgets['epochs'].value,
                'batch_size': self.widgets['batch_size'].value,
                'learning_rate': self.widgets['learning_rate'].value,
                'validation_split': self.widgets['validation_split'].value,
                'incremental': self.widgets['incremental_training'].value
            },
            'augmentation': {
                'enabled': self.widgets['augmentation'].value,
                'flip_lr': True,
                'flip_ud': True,
                'rotate': True,
                'brightness': True
            },
            'data': {
                'particle_radius': self.widgets['particle_radius'].value
            }
        }

# Build and render the configuration form; the pipeline cell reads the chosen
# values later through `config_manager.get_config()`.
config_manager = TrainingConfig()
config_manager.display()

# -----------------------------------------------------------------------------
# STEP 7: Dataset (OPTIMIZED - fixed stride issue)
# -----------------------------------------------------------------------------
class ParticleDataset(Dataset):
    """In-memory (frame, mask) pairs with optional random augmentation.

    Args:
        frames: Indexable collection of 2-D frames (e.g. an ``(N, H, W)`` array);
            each item is returned as a ``(1, H, W)`` float tensor.
        masks: Collection of the same length holding each frame's target mask.
        augmentation_config: Optional dict. When ``enabled`` is truthy, the
            ``flip_lr``, ``flip_ud``, ``rotate`` and ``brightness`` flags each
            switch on one augmentation applied with probability 0.5.
    """

    def __init__(self, frames, masks, augmentation_config=None):
        self.frames = frames
        self.masks = masks
        self.aug_config = augmentation_config or {}

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx].copy()
        mask = self.masks[idx].copy()

        if self.aug_config.get('enabled', False):
            if self.aug_config.get('flip_lr') and np.random.rand() > 0.5:
                frame = np.fliplr(frame)
                mask = np.fliplr(mask)
            if self.aug_config.get('flip_ud') and np.random.rand() > 0.5:
                frame = np.flipud(frame)
                mask = np.flipud(mask)
            if self.aug_config.get('rotate') and np.random.rand() > 0.5:
                k = np.random.randint(1, 4)
                # A 90/270 degree turn swaps height and width. On non-square
                # frames that would give samples of different shapes and break
                # batch collation, so use the shape-preserving 180 degree turn.
                if frame.shape[0] != frame.shape[1] and k % 2 == 1:
                    k = 2
                frame = np.rot90(frame, k)
                mask = np.rot90(mask, k)
            if self.aug_config.get('brightness') and np.random.rand() > 0.5:
                # Intensity jitter applies to the image only; masks are labels.
                frame = np.clip(frame * np.random.uniform(0.8, 1.2), 0, 1)

        # Flips/rotations return views with negative strides, which
        # torch.from_numpy rejects; copy into contiguous buffers first.
        frame = np.ascontiguousarray(frame)
        mask = np.ascontiguousarray(mask)

        frame = torch.from_numpy(frame).float().unsqueeze(0)
        mask = torch.from_numpy(mask).float().unsqueeze(0)
        return frame, mask

def prepare_datasets(data_loader, config, video_files_to_train=None):
    """Build shuffled train/validation DataLoaders from videos and annotations.

    Args:
        data_loader: A ``TrainingDataLoader`` whose ``video_files`` were scanned.
        config: Config dict providing ``data.particle_radius``,
            ``training.validation_split``, ``training.batch_size`` and the
            ``augmentation`` settings (applied to the training split only).
        video_files_to_train: Optional subset of videos (incremental training);
            when falsy, every scanned video is used.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If there are no videos to load.
    """
    print("\nüì¶ Preparing datasets...")
    all_frames, all_masks = [], []

    videos_to_process = video_files_to_train if video_files_to_train else data_loader.video_files
    if not videos_to_process:
        raise ValueError("No videos available to build a dataset; scan the training data first.")

    for video_path in tqdm(videos_to_process, desc="Loading"):
        video = data_loader.load_video_cached(video_path)
        annotations = data_loader.load_annotations(video_path.stem)

        if annotations is not None:
            masks = data_loader.create_ground_truth_masks(annotations, video.shape, config['data']['particle_radius'])
        else:
            # NOTE(review): unannotated videos are used as all-background
            # targets; if they do contain particles this teaches the model to
            # miss them — consider skipping them instead.
            masks = np.zeros_like(video)

        all_frames.append(video)
        all_masks.append(masks)

    # NOTE(review): concatenation requires every video to share the same frame
    # height/width; mixed resolutions raise here.
    all_frames = np.concatenate(all_frames, axis=0)
    all_masks = np.concatenate(all_masks, axis=0)

    total = len(all_frames)
    n_val = int(total * config['training']['validation_split'])
    # Keep at least one validation frame whenever a training frame remains, so
    # the validation loader is never empty (its loss would be undefined).
    if total >= 2:
        n_val = max(n_val, 1)
    indices = np.random.permutation(total)
    train_idx, val_idx = indices[n_val:], indices[:n_val]

    train_dataset = ParticleDataset(all_frames[train_idx], all_masks[train_idx], config['augmentation'])
    val_dataset = ParticleDataset(all_frames[val_idx], all_masks[val_idx])

    # Worker processes and pinned host memory only pay off when batches are
    # copied to a GPU; on CPU pinning is useless, so enable both only for CUDA.
    use_cuda = device.type == 'cuda'
    num_workers = 2 if use_cuda else 0
    batch_size = config['training']['batch_size']
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers, pin_memory=use_cuda)

    print(f"‚úÖ Train:  {len(train_dataset)}, Val: {len(val_dataset)}")
    return train_loader, val_loader

# -----------------------------------------------------------------------------
# STEP 8: Model (OPTIMIZED PyTorch UNet)
# -----------------------------------------------------------------------------
def create_model(config):
    """Build the UNet described by ``config['model']``.

    Args:
        config: Dict with ``config['model']['architecture']`` (used in the log
            message) and ``config['model']['unet_channels']`` (encoder widths,
            shallowest first).

    Returns:
        An ``nn.Module`` mapping ``(N, 1, H, W)`` inputs to ``(N, 1, H, W)``
        logits (no sigmoid applied).

    Raises:
        ValueError: If ``unet_channels`` is empty or the model has no parameters.
    """
    print(f"\nüèóÔ∏è Building {config['model']['architecture'].upper()} model...")

    features = list(config['model']['unet_channels'])
    if not features:
        raise ValueError("config['model']['unet_channels'] must list at least one channel width.")

    class UNet(nn.Module):
        # Submodule names/order (encoder, decoder, bottleneck, final_conv) are
        # part of the saved state_dict keys; keep them stable so existing
        # checkpoints still load.
        def __init__(self, in_channels=1, out_channels=1, features=(16, 32, 64)):
            super().__init__()
            self.encoder = nn.ModuleList()
            self.decoder = nn.ModuleList()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

            # Encoder: two 3x3 conv-BN-ReLU layers per level.
            for feature in features:
                self.encoder.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, feature, kernel_size=3, padding=1),
                        nn.BatchNorm2d(feature),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(feature, feature, kernel_size=3, padding=1),
                        nn.BatchNorm2d(feature),
                        nn.ReLU(inplace=True)
                    )
                )
                in_channels = feature

            # Decoder: alternating [upsample, double-conv] entries, deepest first.
            for feature in reversed(features):
                self.decoder.append(
                    nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
                )
                self.decoder.append(
                    nn.Sequential(
                        nn.Conv2d(feature * 2, feature, kernel_size=3, padding=1),
                        nn.BatchNorm2d(feature),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(feature, feature, kernel_size=3, padding=1),
                        nn.BatchNorm2d(feature),
                        nn.ReLU(inplace=True)
                    )
                )

            self.bottleneck = nn.Sequential(
                nn.Conv2d(features[-1], features[-1] * 2, kernel_size=3, padding=1),
                nn.BatchNorm2d(features[-1] * 2),
                nn.ReLU(inplace=True),
                nn.Conv2d(features[-1] * 2, features[-1] * 2, kernel_size=3, padding=1),
                nn.BatchNorm2d(features[-1] * 2),
                nn.ReLU(inplace=True)
            )

            self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        def forward(self, x):
            skip_connections = []

            for encode in self.encoder:
                x = encode(x)
                skip_connections.append(x)
                x = self.pool(x)

            x = self.bottleneck(x)
            skip_connections = skip_connections[::-1]

            for idx in range(0, len(self.decoder), 2):
                x = self.decoder[idx](x)
                skip_connection = skip_connections[idx // 2]

                # Odd spatial sizes lose a pixel when pooled; resize the
                # upsampled map back to the skip's size before concatenating.
                if x.shape != skip_connection.shape:
                    x = nn.functional.interpolate(x, size=skip_connection.shape[2:])

                concat_skip = torch.cat((skip_connection, x), dim=1)
                x = self.decoder[idx + 1](concat_skip)

            return self.final_conv(x)

    model = UNet(
        in_channels=1,
        out_channels=1,
        features=features
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"‚úÖ Model created!  Parameters: {total_params:,}")

    if total_params == 0:
        raise ValueError("Model has 0 parameters!  Check UNet initialization.")

    return model

# -----------------------------------------------------------------------------
# STEP 9: Trainer (OPTIMIZED + Incremental Training Support)
# -----------------------------------------------------------------------------
class Trainer:
    """Train a single-channel segmentation model and checkpoint its progress.

    Uses Adam, ``BCEWithLogitsLoss``, batch-averaged IoU and a
    ``ReduceLROnPlateau`` scheduler. Files written to ``save_dir``:
    ``latest_model.pth`` every epoch, ``checkpoint_epoch{N}.pth`` every 10th
    epoch, ``best_model.pth`` whenever the monitored loss improves, and
    ``training_progress.png``.

    Args:
        model: The ``nn.Module`` to train; it is moved to the global ``device``.
        config: Config dict; ``config['training']['learning_rate']`` seeds Adam
            and the whole dict is stored in every checkpoint.
        save_dir: Directory for checkpoints and plots (created if missing).
    """

    def __init__(self, model, config, save_dir):
        self.model = model.to(device)
        self.config = config
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
        # `verbose=` is deprecated for LR schedulers in recent PyTorch releases
        # and may be rejected by newer ones, so it is not passed.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', patience=5)
        self.criterion = nn.BCEWithLogitsLoss()
        self.history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': []}
        self.best_val_loss = float('inf')
        self.start_epoch = 0

    def _calculate_iou(self, pred, target, threshold=0.5):
        """Return IoU of ``pred`` vs ``target`` after thresholding both (1.0 if both are empty)."""
        pred_binary = (pred > threshold).float()
        target_binary = (target > threshold).float()
        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection
        return (intersection / union).item() if union > 0 else 1.0

    def train_epoch(self, train_loader):
        """Run one optimisation pass; return ``(mean batch loss, mean batch IoU)``.

        Returns ``(nan, nan)`` for an empty loader instead of dividing by zero.
        """
        self.model.train()
        if len(train_loader) == 0:
            return float('nan'), float('nan')
        total_loss = 0.0
        total_iou = 0.0

        pbar = tqdm(train_loader, desc="Training")
        for frames, masks in pbar:
            frames, masks = frames.to(device), masks.to(device)

            self.optimizer.zero_grad()
            outputs = self.model(frames)
            loss = self.criterion(outputs, masks)
            loss.backward()
            self.optimizer.step()

            with torch.no_grad():
                # The model outputs logits; IoU is measured on probabilities.
                iou = self._calculate_iou(torch.sigmoid(outputs), masks)

            total_loss += loss.item()
            total_iou += iou
            pbar.set_postfix({'loss': f"{loss.item():.4f}", 'iou': f"{iou:.4f}"})

        return total_loss / len(train_loader), total_iou / len(train_loader)

    def validate(self, val_loader):
        """Evaluate without gradients; return ``(mean batch loss, mean batch IoU)``.

        Returns ``(nan, nan)`` for an empty loader instead of dividing by zero.
        """
        self.model.eval()
        if len(val_loader) == 0:
            return float('nan'), float('nan')
        total_loss = 0.0
        total_iou = 0.0

        with torch.no_grad():
            for frames, masks in tqdm(val_loader, desc="Validation"):
                frames, masks = frames.to(device), masks.to(device)
                outputs = self.model(frames)
                loss = self.criterion(outputs, masks)
                iou = self._calculate_iou(torch.sigmoid(outputs), masks)
                total_loss += loss.item()
                total_iou += iou

        return total_loss / len(val_loader), total_iou / len(val_loader)

    def load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer/scheduler state and history for incremental training.

        Returns True if the checkpoint existed and was loaded, False if missing.
        Restoring the optimizer also restores its saved learning rate, so the
        current config's ``learning_rate`` does not apply to a resumed run.
        """
        if not Path(checkpoint_path).exists():
            return False
        print(f"üì• Loading existing model from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load checkpoints
        # this notebook produced itself.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Checkpoints from earlier versions carry no scheduler state.
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.history = checkpoint.get('history', self.history)
        self.start_epoch = checkpoint.get('epoch', 0)
        # Skip NaN entries (epochs without validation data); an empty history
        # leaves the best loss at +inf rather than crashing min().
        finite_losses = [v for v in self.history.get('val_loss', []) if not np.isnan(v)]
        self.best_val_loss = min(finite_losses, default=float('inf'))
        print(f"‚úÖ Loaded model from epoch {self.start_epoch}")
        return True

    def save_checkpoint(self, epoch, is_best=False):
        """Save the latest checkpoint, a periodic one every 10 epochs, and the best one."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'history': self.history,
            'config': self.config
        }

        # Always save latest
        torch.save(checkpoint, self.save_dir / "latest_model.pth")

        if epoch % 10 == 0:
            torch.save(checkpoint, self.save_dir / f"checkpoint_epoch{epoch}.pth")

        if is_best:
            torch.save(checkpoint, self.save_dir / "best_model.pth")
            print(f"   üíæ Best model saved (loss: {self.best_val_loss:.4f})")

    def plot_progress(self):
        """Plot loss/IoU curves, save them to ``training_progress.png`` and display them."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        axes[0].plot(self.history['train_loss'], label='Train', marker='o')
        axes[0].plot(self.history['val_loss'], label='Val', marker='s')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Loss')
        axes[0].legend()
        axes[0].grid(True)

        axes[1].plot(self.history['train_iou'], label='Train', marker='o')
        axes[1].plot(self.history['val_iou'], label='Val', marker='s')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('IoU')
        axes[1].set_title('IoU')
        axes[1].legend()
        axes[1].grid(True)

        plt.tight_layout()
        # The exporter copies this exact file name, so it must not drift.
        plt.savefig(self.save_dir / 'training_progress.png', dpi=150)
        plt.show()
        # Release the figure; repeated plotting otherwise accumulates figures.
        plt.close(fig)

    def train(self, train_loader, val_loader, epochs):
        """Train for ``epochs`` more epochs, continuing after ``self.start_epoch``.

        Validation loss drives the LR scheduler and best-model selection; when
        it is NaN (empty validation loader) the training loss is used instead.
        """
        print(f"\nüöÄ Starting training for {epochs} epochs.. .\n")

        last_epoch = self.start_epoch + epochs
        for epoch in range(self.start_epoch + 1, last_epoch + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{last_epoch}")
            print(f"{'='*60}")

            train_loss, train_iou = self.train_epoch(train_loader)
            val_loss, val_iou = self.validate(val_loader)

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_iou'].append(train_iou)
            self.history['val_iou'].append(val_iou)

            # Format specs must be `.4f` — a space after the dot is invalid and
            # raises ValueError at runtime.
            print(f"\nüìä Summary:")
            print(f"   Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}")
            print(f"   Val   - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}")

            monitored_loss = train_loss if np.isnan(val_loss) else val_loss
            self.scheduler.step(monitored_loss)

            is_best = monitored_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = monitored_loss

            self.save_checkpoint(epoch, is_best)

            if epoch % 5 == 0:
                clear_output(wait=True)
                self.plot_progress()

        print(f"\n‚úÖ Training complete! Best val loss: {self.best_val_loss:.4f}")
        self.plot_progress()

# -----------------------------------------------------------------------------
# STEP 10: Model Exporter
# -----------------------------------------------------------------------------
class ModelExporter:
    """Write a trained model and its metadata into a timestamped version folder.

    Each export creates ``<model_path>/<version>/`` containing ``weights.pth``
    (the model ``state_dict``), ``metadata.json``, ``config.json``,
    ``model_card.md`` and, when the trainer produced one,
    ``training_progress.png``.
    """

    def __init__(self, model_path):
        self.model_path = Path(model_path)

    def generate_version(self):
        """Return a version tag from the current local time, e.g. ``v_20240131_235959``."""
        return f"v_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def export_model(self, trainer, config, num_videos=None):
        """Export ``trainer.model`` with metadata and return the export directory.

        Args:
            trainer: Object exposing ``model``, ``history`` (non-empty per-epoch
                lists ``train_loss``/``val_loss``/``train_iou``/``val_iou``),
                ``best_val_loss`` and ``save_dir``.
            config: The training config dict; saved verbatim to ``config.json``.
            num_videos: Video count recorded in the metadata. Defaults to
                ``len(data_loader.video_files)`` from the notebook's global loader.

        Returns:
            Path of the new ``<model_path>/<version>`` directory; its name is the
            version string.
        """
        version = self.generate_version()
        export_dir = self.model_path / version
        export_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nüì¶ Exporting model:  {version}")

        torch.save(trainer.model.state_dict(), export_dir / "weights.pth")

        if num_videos is None:
            num_videos = len(data_loader.video_files)

        history = trainer.history
        metadata = {
            "model_name": config['model']['name'],
            "version": version,
            "created_at": datetime.now().isoformat(),
            "architecture": {
                "type": config['model']['architecture'],
                # NOTE(review): nominal input size, not derived from the data.
                "input_shape": [1, 512, 512],
                "unet_channels": config['model']['unet_channels'],
                "out_channels": 1
            },
            "training": config['training'],
            "performance": {
                "final_train_loss": history['train_loss'][-1],
                "final_val_loss": history['val_loss'][-1],
                "best_val_loss": trainer.best_val_loss,
                "final_train_iou": history['train_iou'][-1],
                "final_val_iou": history['val_iou'][-1],
                # nanmax ignores epochs without validation data (NaN entries).
                "best_val_iou": float(np.nanmax(history['val_iou']))
            },
            "data_info": {
                "num_videos": num_videos,
                "augmentation": config['augmentation']['enabled']
            },
            "compatibility": {
                "deeptrack_version": "installed",
                "torch_version": torch.__version__,
                "python_version": f"{sys.version_info.major}.{sys.version_info.minor}"
            }
        }

        with open(export_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        with open(export_dir / "config.json", 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # The progress plot is optional: a missing file must not abort the export.
        progress_png = Path(trainer.save_dir) / 'training_progress.png'
        if progress_png.exists():
            shutil.copy(progress_png, export_dir / 'training_progress.png')

        perf = metadata['performance']
        card = f"""# Model:  {config['model']['name']}

**Version:** {version}
**Created:** {metadata['created_at']}
**Architecture:** {metadata['architecture']['type'].upper()}

## Performance

| Metric | Train | Validation | Best |
|--------|-------|------------|------|
| Loss | {perf['final_train_loss']:.4f} | {perf['final_val_loss']:.4f} | {perf['best_val_loss']:.4f} |
| IoU | {perf['final_train_iou']:.4f} | {perf['final_val_iou']:.4f} | {perf['best_val_iou']:.4f} |

## Architecture

- Type: {metadata['architecture']['type'].upper()}
- Channels: {metadata['architecture']['unet_channels']}
- Input:  {metadata['architecture']['input_shape']}

## Training

- Epochs: {config['training']['epochs']}
- Batch Size: {config['training']['batch_size']}
- Learning Rate: {config['training']['learning_rate']}

## Usage

Download this model to your local DeepTrack MPT Studio app!

```python
from src.engines.ai_engine import DeepTrackEngine
engine = DeepTrackEngine()
engine.load_model("weights.pth", "metadata.json")
```
"""

        with open(export_dir / "model_card.md", 'w', encoding='utf-8') as f:
            f.write(card)

        print(f"‚úÖ Model exported to: {export_dir}")
        print(f"\nüì• DOWNLOAD INSTRUCTIONS:")
        print(f"   1. Open Google Drive")
        print(f"   2. Navigate to: {export_dir}")
        print(f"   3. Download the entire '{version}' folder")
        print(f"   4. Place it in your local app's 'models' directory")

        return export_dir

In [None]:
# ============================================================================
# üéØ EXECUTE SMART TRAINING PIPELINE
# ============================================================================
# Incremental training:  Only trains on new/untrained videos!
# ============================================================================

config = config_manager.get_config()

print("=" * 70)
print("üöÄ STARTING SMART TRAINING PIPELINE")
print("=" * 70)
print(f"\n‚öôÔ∏è  Configuration:")
print(f"   Model: {config['model']['name']}")
print(f"   Architecture: {config['model']['architecture'].upper()}")
print(f"   Epochs: {config['training']['epochs']}")
print(f"   Batch Size: {config['training']['batch_size']}")
print(f"   Learning Rate: {config['training']['learning_rate']}")
print(f"   Incremental Training: {config['training']['incremental']}")
print("=" * 70)

# Checkpoint that incremental runs continue from (written by Trainer every epoch).
latest_model = LOG_PATH / 'current_training' / 'latest_model.pth'

# Determine which videos to train
videos_to_train = None
should_load_existing = False

if config['training']['incremental']:
    untrained_videos = tracker.get_untrained_videos(data_loader.video_files)

    if not untrained_videos:
        print("\n‚úÖ All videos already trained!")
        print("üí° Set 'Incremental Training' to False to retrain all videos.")
        videos_to_train = []
    else:
        print(f"\nüÜï Found {len(untrained_videos)} new/untrained video(s):")
        for v in untrained_videos:
            print(f"   - {v.name}")

        videos_to_train = untrained_videos

        # Continue from the previous model if one exists.
        if latest_model.exists():
            should_load_existing = True
            print(f"\nüì• Existing model found - will continue training from checkpoint")
        else:
            print(f"\n‚ÑπÔ∏è  No existing model found - starting fresh training")
else:
    print("\nüîÑ Full training mode - training on all videos")
    videos_to_train = data_loader.video_files
    should_load_existing = False

# Exit if nothing to train
if not videos_to_train:
    print("\n‚úã Nothing to train.  Exiting.")
else:
    # Prepare datasets
    train_loader, val_loader = prepare_datasets(data_loader, config, videos_to_train)

    # Create model
    model = create_model(config)
    trainer = Trainer(model, config, LOG_PATH / 'current_training')

    # Load existing model if incremental training is enabled AND model exists
    if should_load_existing:
        try:
            loaded = trainer.load_checkpoint(latest_model)
        except Exception as e:
            # Top-level boundary: report and fall back. A failed load may have
            # partially overwritten weights/optimizer state, so rebuild both to
            # make "fresh" actually fresh.
            print(f"‚ö†Ô∏è  Could not load existing model: {e}")
            print(f"   Starting fresh training instead...")
            model = create_model(config)
            trainer = Trainer(model, config, LOG_PATH / 'current_training')
        else:
            if loaded:
                print(f"‚úÖ Successfully loaded existing model for incremental training")
            else:
                print(f"   Starting fresh training instead...")

    # Start training
    print(f"\n{'='*70}")
    print(f"üéØ Training on {len(videos_to_train)} video(s)")
    print(f"{'='*70}\n")

    trainer.train(train_loader, val_loader, config['training']['epochs'])

    # Export model
    exporter = ModelExporter(MODEL_PATH)
    export_dir = exporter.export_model(trainer, config)

    # Mark videos as trained only after a successful export, tagged with the
    # exported folder's version (its name) so tracker and models folder agree.
    # Full retrains mark too: their checkpoint now contains these videos, so
    # the next incremental run should not train on them again.
    version = export_dir.name
    for video_path in videos_to_train:
        tracker.mark_video_trained(video_path, version)
    print(f"\n‚úÖ Marked {len(videos_to_train)} video(s) as trained")

    print("\n" + "=" * 70)
    print("üéâ TRAINING COMPLETE!")
    print("=" * 70)
    print(f"\nüì¶ Model saved to: {export_dir}")
    print(f"\nüì• Download your trained model from Google Drive!")
    print(f"\nüí° Next time you upload new videos, only those will be trained!")
    print("=" * 70)