<a href="https://colab.research.google.com/github/ParthivRB/Deeptrack_Colab/blob/main/DeepTrack_Cloud_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================================
# 🎯 DeepTrack2 Cloud Training System - Complete Notebook
# ============================================================================
# One-stop solution for training particle tracking models in the cloud
# Version: 1.0.0 | Compatible with DeepTrack MPT Studio
# ============================================================================

# Startup banner shown at the top of the cell output.
print("=" * 70)
print("üöÄ DEEPTRACK CLOUD TRAINER")
print("=" * 70)
print("\nInitializing training environment.. .\n")

# -----------------------------------------------------------------------------
# STEP 1: Install Dependencies
# -----------------------------------------------------------------------------
print("üì¶ Installing dependencies...")
import subprocess
import sys

# Packages installed with pip before the import section further down runs.
packages = [
    'deeptrack', 'deeplay', 'torch', 'torchvision',
    'tqdm', 'ipywidgets', 'matplotlib', 'scikit-image',
    'pandas', 'scipy'
]

# One quiet pip invocation per package, run through the interpreter that is
# executing this notebook (sys.executable). check_call raises
# subprocess.CalledProcessError as soon as one install exits non-zero.
for package in packages:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])

print("‚úÖ All dependencies installed!\n")

# -----------------------------------------------------------------------------
# STEP 2: Import Libraries
# -----------------------------------------------------------------------------
print("üìö Loading libraries...")

import os
import json
import warnings
from pathlib import Path
from datetime import datetime
import shutil

import numpy as np
import pandas as pd
from scipy. ndimage import label, center_of_mass
from skimage import io as skio
from tqdm. auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import deeplay as dl
import deeptrack as dt
import importlib.metadata

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')
# Fix the global PyTorch and NumPy seeds.
torch.manual_seed(42)
np.random.seed(42)

# Module-level device used by the Trainer below: CUDA when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"‚úÖ Libraries loaded!")
print(f"üñ•Ô∏è  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"   PyTorch:  {torch.__version__}")
# Look up the installed deeptrack distribution version; dt_version is read
# again later when the exporter writes its compatibility metadata.
try:
    dt_version = importlib.metadata.version('deeptrack')
except importlib.metadata.PackageNotFoundError:
    dt_version = "Unknown (Package Not Found)"
print(f"   DeepTrack: {dt_version}\n")

# -----------------------------------------------------------------------------
# STEP 3: Mount Google Drive
# -----------------------------------------------------------------------------
print("üìÇ Mounting Google Drive...")
# google.colab is imported here, directly before its only use.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Folder layout used by the rest of the notebook, all under one Drive folder.
BASE_PATH = Path('/content/drive/MyDrive/DeepTrack_Studio')
DATA_PATH = BASE_PATH / 'training_data'
MODEL_PATH = BASE_PATH / 'models'
LOG_PATH = BASE_PATH / 'logs'

# Create the folders if they are missing; existing folders are left untouched.
for path in [BASE_PATH, DATA_PATH, MODEL_PATH, LOG_PATH]:
    path.mkdir(parents=True, exist_ok=True)

# Sub-folders scanned by TrainingDataLoader for videos and their annotation CSVs.
(DATA_PATH / 'videos').mkdir(exist_ok=True)
(DATA_PATH / 'annotations').mkdir(exist_ok=True)

print(f"‚úÖ Google Drive ready!")
print(f"   Base:  {BASE_PATH}\n")

# -----------------------------------------------------------------------------
# STEP 4: Data Loader
# -----------------------------------------------------------------------------
class TrainingDataLoader:
    """Locate, load and preview training videos and their particle annotations.

    Videos are read from ``<data_path>/videos`` and annotations from
    ``<data_path>/annotations/<video stem>_particles.csv``. The annotation
    columns used by this class are ``frame``, ``x`` and ``y``.
    """

    def __init__(self, data_path):
        """Record the data folders; nothing is read from disk until scan_data()."""
        self.data_path = Path(data_path)
        self.videos_path = self.data_path / 'videos'
        self.annotations_path = self.data_path / 'annotations'
        self.video_files = []
        self.annotation_files = {}

    def scan_data(self):
        """Index the available videos and annotation files.

        Fills ``self.video_files`` (sorted, so the order does not depend on the
        file system's listing order) and ``self.annotation_files`` (video stem
        -> CSV path), printing a summary along the way.

        Returns:
            True when at least one video was found, otherwise False.
        """
        print("üîç Scanning for training data...")
        video_extensions = ['.tif', '.tiff', '.png', '.jpg']
        found = []
        for ext in video_extensions:
            found.extend(self.videos_path.glob(f"*{ext}"))
        # glob yields entries in arbitrary order; sorting keeps runs repeatable.
        self.video_files = sorted(found)

        if not self.video_files:
            print(f"‚ùå No videos found!")
            print(f"\nüìù UPLOAD INSTRUCTIONS:")
            print(f"   1. Open Google Drive in new tab")
            print(f"   2. Go to: {self.videos_path}")
            print(f"   3. Upload your . tif video files")
            print(f"   4. Return here and re-run this cell")
            return False

        print(f"‚úÖ Found {len(self.video_files)} video(s)")

        self.annotation_files = {}
        for video_path in self.video_files:
            annotation_path = self.annotations_path / f"{video_path.stem}_particles.csv"
            if annotation_path.exists():
                self.annotation_files[video_path.stem] = annotation_path

        print(f"üìã Found {len(self.annotation_files)} annotation(s)")
        for i, video_path in enumerate(self.video_files, 1):
            status = "‚úÖ" if video_path.stem in self.annotation_files else "‚ö†Ô∏è (no annotation)"
            print(f"   {i}. {video_path.name} {status}")

        return True

    def load_video(self, video_path):
        """Read a video file and return it as a float32 array scaled by its maximum.

        A 2-D image becomes a single-frame stack; for a 4-D array index 0 of the
        second axis is kept. The result is always float32; it is divided by its
        maximum unless that maximum is not positive.
        """
        video = skio.imread(str(video_path))
        if video.ndim == 2:
            video = video[np.newaxis, ...]
        elif video.ndim == 4:
            # NOTE(review): this assumes a (frames, channels, H, W) layout; a
            # (frames, H, W, channels) file would be sliced wrongly - confirm.
            video = video[:, 0, :, :]

        # Convert first so the returned dtype never depends on the pixel values.
        video = video.astype(np.float32)
        peak = video.max()
        if peak > 0:
            video /= peak
        return video

    def load_annotations(self, video_stem):
        """Return the annotation table for ``video_stem``, or None if none was indexed."""
        if video_stem not in self.annotation_files:
            return None
        return pd.read_csv(self.annotation_files[video_stem])

    def create_ground_truth_masks(self, annotations, shape, radius=3):
        """Rasterise annotated particle positions into binary disk masks.

        Args:
            annotations: table with ``frame``, ``x`` and ``y`` columns.
            shape: ``(num_frames, height, width)`` of the mask stack.
            radius: disk radius in pixels around each (x, y) position.

        Returns:
            float32 array of ``shape`` with 1.0 inside every disk and 0.0
            elsewhere. Rows whose frame is outside ``range(num_frames)`` are
            ignored.
        """
        num_frames, height, width = shape
        masks = np.zeros(shape, dtype=np.float32)
        yy, xx = np.ogrid[:height, :width]
        radius_sq = radius ** 2

        # Group once instead of re-filtering the whole table for every frame.
        by_frame = {key: group for key, group in annotations.groupby('frame')}
        for frame_idx in range(num_frames):
            group = by_frame.get(frame_idx)
            if group is None:
                continue
            for x, y in zip(group['x'], group['y']):
                distance = (xx - x) ** 2 + (yy - y) ** 2
                masks[frame_idx][distance <= radius_sq] = 1.0
        return masks

    def preview_data(self, video_idx=0, frame_idx=0):
        """Show one frame next to the same frame with its annotations overlaid."""
        if not self.video_files:
            return
        video_path = self.video_files[video_idx]
        video = self.load_video(video_path)
        annotations = self.load_annotations(video_path.stem)

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        axes[0].imshow(video[frame_idx], cmap='gray')
        axes[0].set_title(f"{video_path.name} - Frame {frame_idx}")
        axes[0].axis('off')

        axes[1].imshow(video[frame_idx], cmap='gray')
        if annotations is not None:
            frame_particles = annotations[annotations['frame'] == frame_idx]
            if not frame_particles.empty:
                axes[1].scatter(frame_particles['x'], frame_particles['y'],
                              c='red', s=50, marker='o', facecolors='none', linewidths=2)
            axes[1].set_title(f"Annotations ({len(frame_particles)} particles)")
        else:
            axes[1].set_title("No annotations")
        axes[1].axis('off')
        plt.tight_layout()
        plt.show()

# Module-level loader instance; the training cell and the exporter read it.
data_loader = TrainingDataLoader(DATA_PATH)
# True when scan_data() found at least one video under DATA_PATH / 'videos'.
data_available = data_loader.scan_data()

# Only preview when there is something to show.
if data_available:
    print("\nüì∏ Preview:")
    data_loader.preview_data(video_idx=0, frame_idx=0)

# -----------------------------------------------------------------------------
# STEP 5: Configuration
# -----------------------------------------------------------------------------
class TrainingConfig:
    """Interactive training settings backed by ipywidgets controls."""

    def __init__(self):
        """Create one widget per setting, keyed by setting name in ``self.widgets``."""
        self.widgets = {}
        self.widgets['model_name'] = widgets.Text(value='particle_detector', description='Model Name:')
        self.widgets['architecture'] = widgets.Dropdown(options=['UNet'], value='UNet', description='Architecture:')
        self.widgets['unet_channels'] = widgets.Text(value='16,32,64', description='Channels:')
        self.widgets['epochs'] = widgets.IntSlider(value=30, min=10, max=100, description='Epochs:')
        self.widgets['batch_size'] = widgets.Dropdown(options=[2,4,8], value=4, description='Batch Size:')
        self.widgets['learning_rate'] = widgets.FloatLogSlider(value=1e-4, base=10, min=-6, max=-2, description='Learning Rate:')
        self.widgets['validation_split'] = widgets.FloatSlider(value=0.2, min=0.1, max=0.4, description='Val Split:')
        self.widgets['augmentation'] = widgets.Checkbox(value=True, description='Augmentation')
        self.widgets['particle_radius'] = widgets.IntSlider(value=3, min=1, max=10, description='Particle Radius:')

    def display(self):
        """Render a heading followed by all widgets stacked vertically."""
        display(HTML("<h3>‚öôÔ∏è Training Configuration</h3>"))
        display(widgets.VBox([
            self.widgets['model_name'],
            self.widgets['architecture'],
            self.widgets['unet_channels'],
            self.widgets['epochs'],
            self.widgets['batch_size'],
            self.widgets['learning_rate'],
            self.widgets['validation_split'],
            self.widgets['augmentation'],
            self.widgets['particle_radius']
        ]))

    def get_config(self):
        """Snapshot the current widget values as a nested configuration dict.

        Returns:
            dict with ``model``, ``training``, ``augmentation`` and ``data``
            sections. ``model['unet_channels']`` is the comma-separated text
            field parsed into a list of ints; blank entries (for example from
            a trailing comma) are skipped, so an empty field yields ``[]``.

        Raises:
            ValueError: if a channel entry is not an integer.
        """
        raw_channels = self.widgets['unet_channels'].value
        tokens = [part.strip() for part in raw_channels.split(',')]
        try:
            unet_channels = [int(token) for token in tokens if token]
        except ValueError as err:
            raise ValueError(
                f"Channels must be comma-separated integers, got {raw_channels!r}"
            ) from err

        return {
            'model': {
                'name': self.widgets['model_name'].value,
                'architecture': self.widgets['architecture'].value.lower(),
                'unet_channels': unet_channels
            },
            'training': {
                'epochs': self.widgets['epochs'].value,
                'batch_size': self.widgets['batch_size'].value,
                'learning_rate': self.widgets['learning_rate'].value,
                'validation_split': self.widgets['validation_split'].value
            },
            'augmentation': {
                'enabled': self.widgets['augmentation'].value,
                'flip_lr': True,
                'flip_ud': True,
                'rotate': True,
                'brightness': True
            },
            'data': {
                'particle_radius': self.widgets['particle_radius'].value
            }
        }

# Module-level settings panel; the training cell reads it through get_config().
config_manager = TrainingConfig()
config_manager.display()

# -----------------------------------------------------------------------------
# STEP 6: Dataset
# -----------------------------------------------------------------------------
class ParticleDataset(Dataset):
    """Frame/mask pairs with optional random augmentation.

    Args:
        frames: indexable collection of 2-D frames (NumPy arrays).
        masks: indexable collection of 2-D masks aligned with ``frames``.
        augmentation_config: optional dict. Augmentation runs only when
            ``enabled`` is truthy; ``flip_lr``, ``flip_ud``, ``rotate`` and
            ``brightness`` switch the individual transforms on.
    """

    def __init__(self, frames, masks, augmentation_config=None):
        self.frames = frames
        self.masks = masks
        self.aug_config = augmentation_config or {}

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        """Return ``(frame, mask)`` as float tensors with a leading channel axis.

        Every enabled transform is applied with probability 0.5. Geometric
        transforms are applied identically to frame and mask; the brightness
        scaling touches the frame only.
        """
        # Copy so augmentation never writes into the stored arrays.
        frame = self.frames[idx].copy()
        mask = self.masks[idx].copy()

        if self.aug_config.get('enabled', False):
            if self.aug_config.get('flip_lr') and np.random.rand() > 0.5:
                frame = np.fliplr(frame)
                mask = np.fliplr(mask)
            if self.aug_config.get('flip_ud') and np.random.rand() > 0.5:
                frame = np.flipud(frame)
                mask = np.flipud(mask)
            if self.aug_config.get('rotate') and np.random.rand() > 0.5:
                if frame.shape[0] == frame.shape[1]:
                    k = np.random.randint(1, 4)
                else:
                    # A quarter turn would swap height and width, and samples
                    # of different shapes cannot be stacked into one batch.
                    k = 2
                frame = np.rot90(frame, k)
                mask = np.rot90(mask, k)
            if self.aug_config.get('brightness') and np.random.rand() > 0.5:
                frame = np.clip(frame * np.random.uniform(0.8, 1.2), 0, 1)

        # Flips and rotations return negative-stride views, which
        # torch.from_numpy rejects; materialise them as contiguous arrays.
        frame = torch.from_numpy(np.ascontiguousarray(frame)).float().unsqueeze(0)
        mask = torch.from_numpy(np.ascontiguousarray(mask)).float().unsqueeze(0)
        return frame, mask

def _seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker from that worker's PyTorch seed."""
    # Without this, forked workers inherit the same NumPy RNG state and would
    # draw identical augmentation decisions.
    np.random.seed(torch.initial_seed() % 2**32)


def prepare_datasets(data_loader, config):
    """Load every video, build its masks and split the frames into train/val loaders.

    Args:
        data_loader: object exposing ``video_files``, ``load_video``,
            ``load_annotations`` and ``create_ground_truth_masks``.
        config: dict with ``data.particle_radius``,
            ``training.validation_split``, ``training.batch_size`` and the
            ``augmentation`` section passed to the training dataset.

    Returns:
        ``(train_loader, val_loader)``. Only the training dataset receives the
        augmentation settings.

    Raises:
        ValueError: if there are no videos, the videos do not share one frame
            size, or fewer than two frames are available in total.
    """
    print("\nüì¶ Preparing datasets...")
    if not data_loader.video_files:
        raise ValueError("No videos to train on; upload videos and scan the data folder first.")

    all_frames, all_masks = [], []

    for video_path in tqdm(data_loader.video_files, desc="Loading"):
        video = data_loader.load_video(video_path)
        annotations = data_loader.load_annotations(video_path.stem)

        if annotations is not None:
            masks = data_loader.create_ground_truth_masks(annotations, video.shape, config['data']['particle_radius'])
        else:
            # Videos without annotations contribute all-background masks.
            masks = np.zeros(video.shape, dtype=np.float32)

        all_frames.append(video)
        all_masks.append(masks)

    frame_sizes = {video.shape[1:] for video in all_frames}
    if len(frame_sizes) > 1:
        raise ValueError(f"All videos must share one frame size, found: {sorted(frame_sizes)}")

    all_frames = np.concatenate(all_frames, axis=0)
    all_masks = np.concatenate(all_masks, axis=0)

    n_total = len(all_frames)
    if n_total < 2:
        raise ValueError("At least two frames are required to build a train/validation split.")

    val_split = config['training']['validation_split']
    # Keep at least one frame on each side: an empty loader makes the
    # per-batch averages in training/validation divide by zero.
    n_val = min(max(int(n_total * val_split), 1), n_total - 1)
    indices = np.random.permutation(n_total)

    train_dataset = ParticleDataset(all_frames[indices[n_val:]], all_masks[indices[n_val:]], config['augmentation'])
    val_dataset = ParticleDataset(all_frames[indices[:n_val]], all_masks[indices[:n_val]])

    batch_size = config['training']['batch_size']
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                              worker_init_fn=_seed_worker)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"‚úÖ Train:  {len(train_dataset)}, Val: {len(val_dataset)}")
    return train_loader, val_loader

# -----------------------------------------------------------------------------
# STEP 7: Model
# -----------------------------------------------------------------------------
def create_model(config):
    """Build the 2-D U-Net described by ``config['model']``.

    Args:
        config: dict whose ``model`` section provides ``architecture`` (used
            for the log line only) and ``unet_channels`` (list of ints).

    Returns:
        The ``dl.UNet2d`` instance with one input and one output channel.
    """
    print(f"\nüèóÔ∏è Building {config['model']['architecture'].upper()} model...")

    # An empty channel list falls back to the same default the configuration
    # widget starts with (16, 32, 64) instead of being passed through as-is.
    unet_channels_list = list(config['model']['unet_channels']) or [16, 32, 64]

    # deeplay's UNet2d takes the per-level channel list directly; there is no
    # separate depth argument.
    # NOTE(review): deeplay modules are usually materialised with
    # .build()/.create() before their parameters exist - confirm the count
    # printed below is non-zero.
    model = dl.UNet2d(
        in_channels=1,
        out_channels=1,
        channels=unet_channels_list
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f"‚úÖ Model created!  Parameters: {total_params:,}")
    return model

# -----------------------------------------------------------------------------
# STEP 8:  Trainer
# -----------------------------------------------------------------------------
class Trainer:
    """Train and validate a segmentation model, checkpointing into ``save_dir``.

    Uses Adam with the configured learning rate and ``BCEWithLogitsLoss``, so
    the model is expected to output logits. Per-epoch losses and IoU values
    are collected in ``self.history``.
    """

    def __init__(self, model, config, save_dir):
        """Move the model to the module-level ``device`` and create ``save_dir``."""
        self.model = model.to(device)
        self.config = config
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config['training']['learning_rate'])
        self.criterion = nn.BCEWithLogitsLoss()
        self.history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': []}
        self.best_val_loss = float('inf')

    def _calculate_iou(self, pred, target, threshold=0.5):
        """Return the IoU of ``pred`` and ``target`` after thresholding both.

        Computed over the whole tensor at once; returns 1.0 when both
        thresholded tensors are empty.
        """
        pred_binary = (pred > threshold).float()
        target_binary = (target > threshold).float()
        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection
        return (intersection / union).item() if union > 0 else 1.0

    def train_epoch(self, train_loader):
        """Run one optimisation pass; return ``(mean loss, mean IoU)`` over batches."""
        self.model.train()
        total_loss = 0
        total_iou = 0

        pbar = tqdm(train_loader, desc="Training")
        for frames, masks in pbar:
            frames, masks = frames.to(device), masks.to(device)

            self.optimizer.zero_grad()
            outputs = self.model(frames)
            loss = self.criterion(outputs, masks)
            loss.backward()
            self.optimizer.step()

            # The loss works on logits; IoU needs probabilities.
            with torch.no_grad():
                iou = self._calculate_iou(torch.sigmoid(outputs), masks)

            total_loss += loss.item()
            total_iou += iou
            pbar.set_postfix({'loss': f"{loss.item():.4f}", 'iou': f"{iou:.4f}"})

        return total_loss / len(train_loader), total_iou / len(train_loader)

    def validate(self, val_loader):
        """Evaluate without gradients; return ``(mean loss, mean IoU)`` over batches."""
        self.model.eval()
        total_loss = 0
        total_iou = 0

        with torch.no_grad():
            for frames, masks in tqdm(val_loader, desc="Validation"):
                frames, masks = frames.to(device), masks.to(device)
                outputs = self.model(frames)
                loss = self.criterion(outputs, masks)
                iou = self._calculate_iou(torch.sigmoid(outputs), masks)
                total_loss += loss.item()
                total_iou += iou

        return total_loss / len(val_loader), total_iou / len(val_loader)

    def save_checkpoint(self, epoch, is_best=False):
        """Write a periodic checkpoint every 10th epoch and ``best_model.pth`` when ``is_best``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
            'config': self.config
        }

        if epoch % 10 == 0:
            torch.save(checkpoint, self.save_dir / f"checkpoint_epoch{epoch}.pth")

        if is_best:
            torch.save(checkpoint, self.save_dir / "best_model.pth")
            print(f"   üíæ Best model saved (loss: {self.best_val_loss:.4f})")

    def plot_progress(self):
        """Plot loss and IoU curves, save them as ``training_progress.png`` and show them."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        panels = (
            (axes[0], 'train_loss', 'val_loss', 'Loss'),
            (axes[1], 'train_iou', 'val_iou', 'IoU'),
        )
        for ax, train_key, val_key, label in panels:
            ax.plot(self.history[train_key], label='Train', marker='o')
            ax.plot(self.history[val_key], label='Val', marker='s')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(label)
            ax.set_title(label)
            ax.legend()
            ax.grid(True)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_progress.png', dpi=150)
        plt.show()
        # This method runs repeatedly; release the figure so they do not pile up.
        plt.close(fig)

    def train(self, train_loader, val_loader, epochs):
        """Run the full training loop.

        Each epoch trains, validates, records the metrics, updates
        ``best_val_loss`` and checkpoints; every 5th epoch the cell output is
        cleared and the progress plot redrawn.

        Raises:
            ValueError: if either loader has no batches (checked up front,
                because the per-epoch averages divide by the batch count).
        """
        for loader_name, loader in (('training', train_loader), ('validation', val_loader)):
            if len(loader) == 0:
                raise ValueError(f"The {loader_name} loader has no batches; add more frames before training.")

        print(f"\nüöÄ Starting training for {epochs} epochs.. .\n")

        for epoch in range(1, epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{epochs}")
            print(f"{'='*60}")

            train_loss, train_iou = self.train_epoch(train_loader)
            val_loss, val_iou = self.validate(val_loader)

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_iou'].append(train_iou)
            self.history['val_iou'].append(val_iou)

            print(f"\nüìä Summary:")
            print(f"   Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}")
            print(f"   Val   - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}")

            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            self.save_checkpoint(epoch, is_best)

            if epoch % 5 == 0:
                clear_output(wait=True)
                self.plot_progress()

        print(f"\n‚úÖ Training complete! Best val loss: {self.best_val_loss:.4f}")
        self.plot_progress()

# -----------------------------------------------------------------------------
# STEP 9: Model Exporter
# -----------------------------------------------------------------------------
class ModelExporter:
    """Write a trained model plus its metadata into a versioned folder."""

    def __init__(self, model_path):
        """Remember the root folder under which version folders are created."""
        self.model_path = Path(model_path)

    def generate_version(self):
        """Return a version label built from the current local time (``v_YYYYMMDD_HHMMSS``)."""
        return f"v_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def export_model(self, trainer, config, num_videos=None, deeptrack_version=None):
        """Export weights, metadata, config, progress plot and a model card.

        Args:
            trainer: object exposing ``model``, ``history``, ``best_val_loss``
                and ``save_dir``; ``history`` must hold at least one epoch.
            config: configuration dict as produced by the settings panel.
            num_videos: number of training videos to record. When omitted it
                is read from the module-level ``data_loader``.
            deeptrack_version: DeepTrack version string to record. When
                omitted the module-level ``dt_version`` is used.

        Returns:
            Path of the folder the export was written to.
        """
        if num_videos is None:
            num_videos = len(data_loader.video_files)
        if deeptrack_version is None:
            deeptrack_version = dt_version

        version = self.generate_version()
        export_dir = self.model_path / version
        export_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nüì¶ Exporting model:  {version}")

        # Save weights
        torch.save(trainer.model.state_dict(), export_dir / "weights.pth")

        # Create metadata
        history = trainer.history
        metadata = {
            "model_name": config['model']['name'],
            "version": version,
            "created_at": datetime.now().isoformat(),
            "architecture": {
                "type": config['model']['architecture'],
                # NOTE(review): fixed value, not derived from the training frames - confirm.
                "input_shape": [1, 512, 512],
                "unet_channels": config['model']['unet_channels'],
                "out_channels": 1
            },
            "training": config['training'],
            "performance": {
                "final_train_loss": history['train_loss'][-1],
                "final_val_loss": history['val_loss'][-1],
                "best_val_loss": trainer.best_val_loss,
                "final_train_iou": history['train_iou'][-1],
                "final_val_iou": history['val_iou'][-1],
                "best_val_iou": max(history['val_iou'])
            },
            "data_info": {
                "num_videos": num_videos,
                "augmentation": config['augmentation']['enabled']
            },
            "compatibility": {
                "deeptrack_version": deeptrack_version,
                "torch_version": torch.__version__,
                "python_version": f"{sys.version_info.major}.{sys.version_info.minor}"
            }
        }

        with open(export_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        # Save config
        with open(export_dir / "config.json", 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # Copy plot; a missing plot must not abort an otherwise complete export.
        plot_path = trainer.save_dir / 'training_progress.png'
        if plot_path.exists():
            shutil.copy(plot_path, export_dir / 'training_progress.png')

        # Create model card
        card = f"""# Model:  {config['model']['name']}

**Version:** {version}
**Created:** {metadata['created_at']}
**Architecture:** {metadata['architecture']['type'].upper()}

## Performance

| Metric | Train | Validation | Best |
|--------|-------|------------|------|
| Loss | {metadata['performance']['final_train_loss']:.4f} | {metadata['performance']['final_val_loss']:.4f} | {metadata['performance']['best_val_loss']:.4f} |
| IoU | {metadata['performance']['final_train_iou']:.4f} | {metadata['performance']['final_val_iou']:.4f} | {metadata['performance']['best_val_iou']:.4f} |

## Architecture

- Type: {metadata['architecture']['type'].upper()}
- Channels: {metadata['architecture']['unet_channels']}
- Input:  {metadata['architecture']['input_shape']}

## Training

- Epochs: {config['training']['epochs']}
- Batch Size: {config['training']['batch_size']}
- Learning Rate:  {config['training']['learning_rate']}

## Usage

Download this model to your local DeepTrack MPT Studio app!

```python
from src.engines.ai_engine import DeepTrackEngine
engine = DeepTrackEngine()
engine.load_model("weights.pth", "metadata.json")
```
"""

        with open(export_dir / "MODEL_CARD.md", 'w', encoding='utf-8') as f:
            f.write(card)

        print(f"‚úÖ Model exported to {export_dir}")
        return export_dir


# At the very end of CELL 1:
# Closing banner of the setup cell with the remaining manual steps.
print("\n" + "=" * 70)
print("‚úÖ ALL COMPONENTS LOADED SUCCESSFULLY!")
print("=" * 70)
print("\nüìã NEXT STEPS:")
print("   1. Ensure videos are uploaded to Google Drive")
print("   2. Configure training parameters above")
print("   3. Run the training cell below")
print("=" * 70)

In [None]:
# ============================================================================
# EXECUTE TRAINING PIPELINE
# ============================================================================
# Run this cell to start training!
# ============================================================================

# Fail fast when the setup cell found no videos: everything below needs
# at least one video to build the datasets from.
if not data_loader.video_files:
    raise RuntimeError(
        "No training videos found. Upload videos to the 'videos' folder and "
        "re-run the setup cell before starting training."
    )

# Get configuration from widgets
config = config_manager.get_config()

print("=" * 70)
print("üöÄ STARTING TRAINING PIPELINE")
print("=" * 70)
print(f"\n‚öôÔ∏è  Configuration:")
print(f"   Model: {config['model']['name']}")
print(f"   Architecture: {config['model']['architecture'].upper()}")
print(f"   Epochs: {config['training']['epochs']}")
print(f"   Batch Size: {config['training']['batch_size']}")
print(f"   Learning Rate: {config['training']['learning_rate']}")
print("=" * 70)

# Prepare datasets
train_loader, val_loader = prepare_datasets(data_loader, config)

# Create model
model = create_model(config)

# Create trainer; checkpoints and the progress plot go to LOG_PATH / 'current_training'
trainer = Trainer(model, config, LOG_PATH / 'current_training')

# Start training
trainer.train(train_loader, val_loader, config['training']['epochs'])

# Export model into a new version folder under MODEL_PATH
exporter = ModelExporter(MODEL_PATH)
export_dir = exporter.export_model(trainer, config)

print("\n" + "=" * 70)
print("üéâ TRAINING COMPLETE!")
print("=" * 70)
print(f"\nüì¶ Model saved to: {export_dir}")
print(f"\nüì• Download your trained model from Google Drive!")
print("=" * 70)