# COFNet Training on SkyWatch Dataset

**Continuous-scale Object Field Network** - A novel architecture combining:
- Mamba-SSM backbone for efficient sequence modeling
- Continuous Scale Field (CSF) for scale-equivariant features
- Diffusion-based box refinement
- Scale-Diffusion Self-Supervision (SDSS)

This notebook trains COFNet on the SkyWatch dataset (planes, wildlife, and meteors — class label `meteorite` — in night-sky images).

## 1. Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Clone COFNet repository
!git clone https://github.com/Deibler/COFNet.git
%cd COFNet

In [None]:
# Install dependencies
!pip install torch torchvision --upgrade
!pip install mamba-ssm causal-conv1d
!pip install huggingface_hub wandb
!pip install pycocotools
!pip install einops timm

## 2. Download SkyWatch Dataset from HuggingFace

In [None]:
from huggingface_hub import hf_hub_download, snapshot_download
import zipfile
import os

# Fetch the processed SkyWatch dataset into ./data.
print("Downloading SkyWatch dataset from HuggingFace...")

# First choice: mirror the whole dataset repository into ./data.
try:
    dataset_path = snapshot_download(
        repo_id="Deibler/skywatch-dataset",
        repo_type="dataset",
        local_dir="./data",
    )
except Exception as e:
    # Any failure of the snapshot download triggers the single-archive fallback.
    print(f"Error downloading: {e}")
    print("Trying alternative download...")

    # Fallback: fetch the zipped dataset file from the same repository.
    zip_path = hf_hub_download(
        repo_id="Deibler/skywatch-dataset",
        filename="skywatch_processed.zip",
        repo_type="dataset",
    )

    # Unpack the archive next to where the snapshot would have landed.
    os.makedirs("./data", exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall("./data")
    print("Dataset extracted to ./data")
else:
    print(f"Dataset downloaded to: {dataset_path}")

In [None]:
# Verify dataset structure
import os

def count_files(path):
    """Return the number of regular files directly inside ``path``.

    The count is non-recursive: sub-directories are ignored. Returns 0 when
    ``path`` does not exist or is not a directory, so callers can probe
    optional dataset splits without guarding the call.
    """
    # isdir (rather than exists) also covers the case where ``path`` points at
    # a regular file, which would otherwise make the listing raise.
    if not os.path.isdir(path):
        return 0
    with os.scandir(path) as entries:
        return sum(1 for entry in entries if entry.is_file())

# Candidate dataset roots, tried in this order.
possible_paths = [
    "./data/processed",
    "./data",
    "./data/skywatch",
]

# The dataset root is the first candidate that contains train/images;
# it stays None when no candidate matches.
DATA_ROOT = next(
    (
        root
        for root in possible_paths
        if os.path.exists(os.path.join(root, "train", "images"))
    ),
    None,
)

if DATA_ROOT:
    print(f"Dataset root: {DATA_ROOT}")
    print(f"Train images: {count_files(os.path.join(DATA_ROOT, 'train', 'images'))}")
    print(f"Valid images: {count_files(os.path.join(DATA_ROOT, 'valid', 'images'))}")
    print(f"Test images: {count_files(os.path.join(DATA_ROOT, 'test', 'images'))}")
else:
    print("Dataset structure:")
    !find ./data -type f | head -20

## 3. Import COFNet Components

In [None]:
import sys
sys.path.insert(0, './src')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from models.cofnet import COFNet
from data.coco_dataset import COCODataset, collate_fn
from training.ssl import ScaleContrastiveLearning, CrossScaleReconstruction

# Report the PyTorch build and, when CUDA is usable, the name of GPU 0.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 4. Configuration

In [None]:
# Training Configuration
# Single dictionary read by the dataset, model, training and logging cells.
CONFIG = dict(
    # --- Data ---
    data_root=DATA_ROOT or './data/processed',  # default used when no root was detected
    image_size=(512, 512),  # Resize images
    batch_size=4,  # Adjust based on GPU memory
    num_workers=2,

    # --- Model ---
    num_classes=3,  # Plane, WildLife, meteorite
    backbone_dims=[64, 128, 256, 512],  # Feature dimensions
    csf_dim=128,  # CSF output dimension
    num_queries=100,  # Detection queries
    diffusion_steps_train=100,
    diffusion_steps_infer=10,

    # --- Training ---
    epochs=50,
    lr=1e-4,
    weight_decay=0.01,
    use_ssl=True,  # Enable self-supervised learning
    ssl_weight=0.1,  # Weight for SSL losses

    # --- Logging ---
    log_interval=10,
    save_interval=5,
    use_wandb=False,  # Set to True to enable W&B logging
)

# Echo every setting so the run configuration is visible in the cell output.
print("Configuration:")
for name, value in CONFIG.items():
    print(f"  {name}: {value}")

## 5. Load Dataset

In [None]:
from pathlib import Path

data_root = Path(CONFIG['data_root'])


def _make_dataset(split):
    """Build the COCODataset for one split directory ('train' or 'valid')."""
    return COCODataset(
        img_folder=str(data_root / split / "images"),
        ann_file=str(data_root / f"{split}_coco.json"),
        image_size=CONFIG['image_size'],
    )


# Create datasets
train_dataset = _make_dataset("train")
val_dataset = _make_dataset("valid")

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Loader settings shared by both splits; only shuffling differs.
_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    collate_fn=collate_fn,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

# Create data loaders
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Visualize a sample
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Take the first training sample and move channels last for imshow.
sample = train_dataset[0]
img = sample['image'].permute(1, 2, 0).numpy()
boxes = sample['boxes']
labels = sample['labels']

# Display name and outline colour per class index.
CLASS_NAMES = ['Plane', 'WildLife', 'meteorite']
COLORS = ['red', 'green', 'blue']

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

H, W = img.shape[:2]
for gt_box, gt_label in zip(boxes, labels):
    # Boxes are read as normalized [cx, cy, w, h]; scale them to pixel
    # coordinates and shift from the centre to the top-left corner.
    cx, cy, w, h = gt_box.numpy()
    x = (cx - w/2) * W
    y = (cy - h/2) * H
    colour = COLORS[gt_label]

    outline = patches.Rectangle(
        (x, y), w * W, h * H,
        linewidth=2,
        edgecolor=colour,
        facecolor='none'
    )
    ax.add_patch(outline)
    # Class name just above the box corner.
    ax.text(x, y-5, CLASS_NAMES[gt_label], color=colour, fontsize=12)

ax.set_title('Sample Training Image')
plt.show()

## 6. Create Model

In [None]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create COFNet model: every constructor argument is read from CONFIG
# under the key of the same name.
_model_keys = (
    'num_classes',
    'backbone_dims',
    'csf_dim',
    'num_queries',
    'diffusion_steps_train',
    'diffusion_steps_infer',
)
model = COFNet(**{key: CONFIG[key] for key in _model_keys}).to(device)

# Count parameters (all, and the subset that receives gradients).
_param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
num_params = sum(size for size, _ in _param_info)
num_trainable = sum(size for size, trainable in _param_info if trainable)
print(f"Total parameters: {num_params:,}")
print(f"Trainable parameters: {num_trainable:,}")

In [None]:
# Create SSL modules (optional but recommended)
# ssl_modules stays None when CONFIG['use_ssl'] is off; otherwise it is the
# pair (ScaleContrastiveLearning, CrossScaleReconstruction) on the training device.
ssl_modules = None
if CONFIG['use_ssl']:
    print("Creating SSL modules...")
    # Both modules are built with the same feature width and scale count.
    _ssl_kwargs = {'feature_dim': CONFIG['csf_dim'], 'num_scales': 8}
    scl = ScaleContrastiveLearning(**_ssl_kwargs).to(device)
    csr = CrossScaleReconstruction(**_ssl_kwargs).to(device)

    ssl_modules = (scl, csr)
    print("SSL modules created: ScaleContrastiveLearning, CrossScaleReconstruction")

## 7. Training Functions

In [None]:
def compute_detection_losses(outputs, targets, device):
    """Compute classification and box-regression losses with nearest-box matching.

    Every ground-truth box is matched to the predicted box with the smallest
    L1 distance in the same image. Matching is greedy and independent per
    ground-truth box, so several ground-truth boxes may share one prediction
    and predictions that are never matched receive no loss.

    Args:
        outputs: Mapping with 'pred_boxes' and 'pred_logits', both indexed as
            [batch, query, ...].
        targets: Sequence with one dict per image holding 'boxes' and 'labels'.
        device: Device the ground-truth tensors are moved to before comparison.

    Returns:
        Tuple ``(cls_loss, box_loss)`` of scalar tensors, each averaged over
        the number of ground-truth boxes in the batch. When the batch has no
        ground-truth boxes both are zero but still attached to the autograd
        graph of the predictions, so calling ``backward()`` on them is safe.
    """
    pred_boxes = outputs['pred_boxes']
    pred_logits = outputs['pred_logits']

    # Zero-valued accumulators derived from the predictions: they keep the
    # result differentiable even if no image in the batch has a target.
    total_cls_loss = (pred_logits * 0.0).sum()
    total_box_loss = (pred_boxes * 0.0).sum()
    num_targets = 0

    for b in range(pred_boxes.shape[0]):
        gt_boxes = targets[b]['boxes'].to(device)
        gt_labels = targets[b]['labels'].to(device)

        if len(gt_boxes) == 0:
            continue

        pred_b = pred_boxes[b]      # [num_queries, 4]
        logits_b = pred_logits[b]   # [num_queries, num_classes]

        # Pairwise L1 distance between every GT box and every prediction,
        # then the closest prediction per GT box.
        dists = (pred_b.unsqueeze(0) - gt_boxes.unsqueeze(1)).abs().sum(dim=-1)
        matched = dists.argmin(dim=1)

        # Box loss: plain L1 (mean over the 4 coordinates, summed over GT boxes).
        # No IoU-based term is included.
        box_loss = nn.functional.l1_loss(
            pred_b[matched], gt_boxes, reduction='none'
        ).mean(dim=-1).sum()
        total_box_loss = total_box_loss + box_loss

        # Classification loss of the matched queries, summed over GT boxes.
        cls_loss = nn.functional.cross_entropy(
            logits_b[matched], gt_labels, reduction='sum'
        )
        total_cls_loss = total_cls_loss + cls_loss
        num_targets += len(gt_boxes)

    # Normalize by number of targets
    if num_targets > 0:
        total_cls_loss = total_cls_loss / num_targets
        total_box_loss = total_box_loss / num_targets

    return total_cls_loss, total_box_loss


def train_epoch(model, loader, optimizer, ssl_modules, device, epoch, config):
    """Run one optimization pass over ``loader`` and return mean losses.

    Per batch the total loss is the model's 'loss_diffusion' output (zero if
    absent) plus the classification loss, five times the box loss, and, on
    every other batch when ``ssl_modules`` is given, the weighted SSL loss.
    Gradients of the model are clipped to norm 1.0 before each optimizer step.

    Args:
        model: Detection model, called as ``model(images, targets)``.
        loader: Iterable of batches with 'images' and 'targets'.
        optimizer: Optimizer stepped once per batch.
        ssl_modules: ``None`` or a ``(scl, csr)`` pair, each called as
            ``module(model, images)`` and returning a dict with 'total'.
        device: Device for the images and the zero placeholders.
        epoch: Accepted for the caller's convenience; not used here.
        config: Dict providing 'ssl_weight' and 'log_interval'.

    Returns:
        Dict with keys 'loss', 'cls_loss', 'box_loss', 'diff_loss' and
        'ssl_loss', each the sum over batches divided by the batch count.
    """
    model.train()

    # Shared zero placeholder for absent diffusion / SSL terms.
    zero = torch.tensor(0.0, device=device)
    sums = dict.fromkeys(
        ('loss', 'cls_loss', 'box_loss', 'diff_loss', 'ssl_loss'), 0.0
    )
    steps_done = 0

    for step, batch in enumerate(loader, start=1):
        images, targets = batch['images'].to(device), batch['targets']

        optimizer.zero_grad()

        # Forward pass and detection losses.
        outputs = model(images, targets)
        diffusion_loss = outputs.get('loss_diffusion', zero)
        cls_loss, box_loss = compute_detection_losses(outputs, targets, device)

        # SSL runs on odd 1-based steps (i.e. every other batch, starting
        # with the first) to save compute.
        if ssl_modules is not None and step % 2 == 1:
            scl, csr = ssl_modules
            scl_loss = scl(model, images)['total']
            csr_loss = csr(model, images)['total']
            ssl_loss = (scl_loss + csr_loss) * config['ssl_weight']
        else:
            ssl_loss = zero

        loss = diffusion_loss + cls_loss + 5.0 * box_loss + ssl_loss

        # Backward, clip, update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Accumulate scalar values for the epoch summary.
        step_values = {
            'loss': loss,
            'cls_loss': cls_loss,
            'box_loss': box_loss,
            'diff_loss': diffusion_loss,
            'ssl_loss': ssl_loss,
        }
        for key, value in step_values.items():
            sums[key] += value.item()
        steps_done = step

        # Periodic progress line: running mean of the total loss plus the
        # current batch's component losses.
        if step % config['log_interval'] == 0:
            print(f"  [{step}/{len(loader)}] "
                  f"loss={sums['loss'] / steps_done:.4f} "
                  f"(cls={cls_loss.item():.3f}, "
                  f"box={box_loss.item():.3f}, "
                  f"diff={diffusion_loss.item():.3f})")

    denom = max(steps_done, 1)
    return {key: total / denom for key, total in sums.items()}


@torch.no_grad()
def validate(model, loader, device):
    """Evaluate nearest-box error and matched-class accuracy on ``loader``.

    For every ground-truth box the prediction with the smallest L1 box
    distance in the same image is selected; that distance is the box error
    and the argmax class of that prediction is compared with the label.

    Returns:
        Dict with 'box_error' (mean matched L1 distance) and 'cls_accuracy'
        (fraction of matched predictions with the correct class). Both are
        0 when the loader contains no ground-truth boxes.
    """
    model.eval()

    box_error_sum = 0.0
    correct = 0
    matched = 0

    for batch in loader:
        images = batch['images'].to(device)
        outputs = model(images)
        all_boxes = outputs['pred_boxes']
        all_logits = outputs['pred_logits']

        for b, target in enumerate(batch['targets']):
            gt_boxes = target['boxes'].to(device)
            gt_labels = target['labels'].to(device)

            # Images without ground truth simply contribute no iterations.
            for gt_box, gt_label in zip(gt_boxes, gt_labels):
                l1 = (all_boxes[b] - gt_box.unsqueeze(0)).abs().sum(dim=-1)
                # One reduction yields both the distance and the query index.
                best_dist, best_idx = l1.min(dim=0)

                box_error_sum += best_dist.item()
                correct += int(all_logits[b, best_idx].argmax() == gt_label)
                matched += 1

    denom = max(matched, 1)
    return {
        'box_error': box_error_sum / denom,
        'cls_accuracy': correct / denom,
    }

## 8. Training Loop

In [None]:
import time
from pathlib import Path

# Create output directory
output_dir = Path('./output')
output_dir.mkdir(exist_ok=True)

# Optimizer
# The SSL modules are separate modules whose state is saved with the final
# checkpoint, so their parameters are optimized together with the model's.
# Parameters are de-duplicated by identity in case any are shared.
_param_sources = [model, *(ssl_modules or ())]
_unique_params = {
    id(p): p for source in _param_sources for p in source.parameters()
}
optimizer = optim.AdamW(
    list(_unique_params.values()),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay']
)

# Learning rate scheduler: cosine decay over the whole run, stepped once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'],
    eta_min=1e-6
)

# Training history: one entry per epoch in each list.
history = {
    'train_loss': [],
    'val_box_error': [],
    'val_cls_accuracy': [],
    'lr': [],
}

# Lowest validation box error seen so far; decides when the best model is saved.
best_val_error = float('inf')

print("\n" + "="*60)
print("Starting COFNet Training")
print("="*60)
print(f"Epochs: {CONFIG['epochs']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Learning rate: {CONFIG['lr']}")
print(f"SSL enabled: {CONFIG['use_ssl']}")
print("="*60 + "\n")

In [None]:
# Main training loop
def _core_state(epoch):
    """Return the checkpoint fields shared by best-model and periodic snapshots."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }


for epoch in range(CONFIG['epochs']):
    epoch_no = epoch + 1  # 1-based number used for display and file names
    start_time = time.time()

    print(f"\nEpoch {epoch_no}/{CONFIG['epochs']}")
    print("-" * 40)

    # One pass over the training set, then evaluation on the validation set.
    train_metrics = train_epoch(
        model, train_loader, optimizer, ssl_modules, device, epoch, CONFIG
    )
    val_metrics = validate(model, val_loader, device)

    # Advance the schedule; the rate reported below is the one for the next epoch.
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']

    elapsed = time.time() - start_time

    # Epoch summary, assembled line by line and printed in one go.
    report = [
        f"\n  Train Loss: {train_metrics['loss']:.4f}",
        f"    - Classification: {train_metrics['cls_loss']:.4f}",
        f"    - Box Regression: {train_metrics['box_loss']:.4f}",
        f"    - Diffusion: {train_metrics['diff_loss']:.4f}",
    ]
    if CONFIG['use_ssl']:
        report.append(f"    - SSL: {train_metrics['ssl_loss']:.4f}")
    report += [
        f"  Val Box Error: {val_metrics['box_error']:.4f}",
        f"  Val Cls Accuracy: {val_metrics['cls_accuracy']:.2%}",
        f"  Learning Rate: {current_lr:.6f}",
        f"  Time: {elapsed:.1f}s",
    ]
    print("\n".join(report))

    # Record this epoch in the history lists.
    epoch_record = {
        'train_loss': train_metrics['loss'],
        'val_box_error': val_metrics['box_error'],
        'val_cls_accuracy': val_metrics['cls_accuracy'],
        'lr': current_lr,
    }
    for key, value in epoch_record.items():
        history[key].append(value)

    # Keep the weights with the lowest validation box error so far.
    if val_metrics['box_error'] < best_val_error:
        best_val_error = val_metrics['box_error']
        torch.save(
            {**_core_state(epoch), 'val_box_error': best_val_error, 'config': CONFIG},
            output_dir / 'best_model.pth',
        )
        print(f"  -> Saved best model (box_error: {best_val_error:.4f})")

    # Full snapshot (including scheduler and history) every save_interval epochs.
    if epoch_no % CONFIG['save_interval'] == 0:
        torch.save(
            {
                **_core_state(epoch),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'config': CONFIG,
            },
            output_dir / f'checkpoint_epoch_{epoch_no}.pth',
        )

print("\n" + "="*60)
print("Training Complete!")
print(f"Best validation box error: {best_val_error:.4f}")
print("="*60)

## 9. Training Visualization

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# One panel per tracked series: (history key, panel title, y-axis label),
# laid out row by row.
_panels = [
    ('train_loss', 'Training Loss', 'Loss'),
    ('val_box_error', 'Validation Box Error', 'Box Error'),
    ('val_cls_accuracy', 'Validation Classification Accuracy', 'Accuracy'),
    ('lr', 'Learning Rate', 'LR'),
]

for panel_ax, (series_key, title, y_label) in zip(axes.flat, _panels):
    panel_ax.plot(history[series_key])
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.grid(True)

# Save the figure next to the checkpoints, then display it.
plt.tight_layout()
plt.savefig(output_dir / 'training_curves.png', dpi=150)
plt.show()

## 10. Inference & Visualization

In [None]:
# Load best model
# map_location places the saved tensors on the current device, so a checkpoint
# written on a GPU can also be loaded in a CPU-only session.
checkpoint = torch.load(output_dir / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# The checkpoint stores the 0-based loop index; report it 1-based to match
# the "Epoch i/N" lines printed during training.
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")

In [None]:
def _draw_cxcywh_box(ax, box, img_w, img_h, color, **rect_kwargs):
    """Draw one normalized [cx, cy, w, h] box on ``ax``; return its top-left pixel corner."""
    cx, cy, w, h = box
    x = (cx - w/2) * img_w
    y = (cy - h/2) * img_h
    ax.add_patch(patches.Rectangle(
        (x, y), w * img_w, h * img_h,
        linewidth=2,
        edgecolor=color,
        facecolor='none',
        **rect_kwargs
    ))
    return x, y


@torch.no_grad()
def visualize_predictions(model, dataset, device, num_samples=4, conf_threshold=0.3):
    """Plot predictions (solid) and ground truth (dashed) for random samples.

    Args:
        model: Detection model, called as ``model(image_batch)``; the caller
            is expected to have put it in eval mode.
        dataset: Indexable dataset whose items hold 'image', 'boxes', 'labels'.
        device: Device the image is moved to for the forward pass.
        num_samples: Number of random samples to show. The subplot grid is
            sized to fit; fewer are shown if the dataset is smaller.
        conf_threshold: Minimum softmax score for a prediction to be drawn.

    Side effects:
        Saves the figure to ``output_dir / 'predictions.png'`` and shows it.
    """
    CLASS_NAMES = ['Plane', 'WildLife', 'meteorite']
    COLORS = ['red', 'green', 'blue']

    indices = torch.randperm(len(dataset))[:num_samples]
    n_shown = len(indices)

    # Two columns (one for a single sample) and as many rows as needed, at
    # 8x8 inches per panel; the default of 4 samples gives a 16x16 2x2 grid.
    n_cols = 2 if n_shown > 1 else 1
    n_rows = max(1, (n_shown + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 8 * n_rows), squeeze=False
    )
    axes = axes.flatten()
    # Hide every panel's frame up front so unused panels stay blank.
    for panel in axes:
        panel.axis('off')

    for ax_idx, sample_idx in enumerate(indices):
        panel = axes[ax_idx]
        sample = dataset[sample_idx.item()]
        img = sample['image'].unsqueeze(0).to(device)

        # Get predictions
        outputs = model(img)
        pred_boxes = outputs['pred_boxes'][0]  # [num_queries, 4]
        pred_logits = outputs['pred_logits'][0]  # [num_queries, num_classes]

        # Get confidences
        # NOTE(review): if pred_logits has exactly len(CLASS_NAMES) entries
        # (no background class), the top softmax score is always >= 1/3, so
        # the default threshold of 0.3 cannot reject any query. Confirm how
        # the model encodes "no object".
        pred_probs = torch.softmax(pred_logits, dim=-1)
        pred_scores, pred_labels = pred_probs.max(dim=-1)

        # Filter by confidence
        mask = pred_scores > conf_threshold

        # Plot image
        img_np = sample['image'].permute(1, 2, 0).numpy()
        panel.imshow(img_np)

        H, W = img_np.shape[:2]

        # Plot predictions (solid outline with class name and score)
        for box, label, score in zip(pred_boxes[mask], pred_labels[mask], pred_scores[mask]):
            label = label.cpu().item()
            score = score.cpu().item()
            x, y = _draw_cxcywh_box(panel, box.cpu().numpy(), W, H, COLORS[label])
            panel.text(
                x, y-5,
                f"{CLASS_NAMES[label]}: {score:.2f}",
                color=COLORS[label],
                fontsize=10,
                fontweight='bold'
            )

        # Plot ground truth (dashed)
        for box, label in zip(sample['boxes'], sample['labels']):
            _draw_cxcywh_box(
                panel, box.numpy(), W, H, COLORS[label.item()], linestyle='--'
            )

        panel.set_title(f'Sample {sample_idx.item()}')

    plt.suptitle('Predictions (solid) vs Ground Truth (dashed)', fontsize=14)
    plt.tight_layout()
    plt.savefig(output_dir / 'predictions.png', dpi=150)
    plt.show()

visualize_predictions(model, val_dataset, device)

## 11. Save Final Model

In [None]:
# Save final model with all components
final_path = output_dir / 'cofnet_skywatch_final.pth'

# Weights currently held by the model, plus the run configuration and history.
final_checkpoint = dict(
    model_state_dict=model.state_dict(),
    config=CONFIG,
    history=history,
    best_val_error=best_val_error,
    num_epochs=CONFIG['epochs'],
)

# Include the SSL modules' state when they were used.
if ssl_modules is not None:
    scl, csr = ssl_modules
    final_checkpoint.update(
        scl_state_dict=scl.state_dict(),
        csr_state_dict=csr.state_dict(),
    )

torch.save(final_checkpoint, final_path)
print(f"Final model saved to: {final_path}")

In [None]:
# Download the trained model (for Colab)
try:
    from google.colab import files
except ImportError:
    # Outside Colab there is nothing to download through the browser; the
    # artifacts are already on local disk.
    files = None
    print(f"google.colab is unavailable; artifacts remain in {output_dir}")

if files is not None:
    for artifact in (
        'cofnet_skywatch_final.pth',
        'training_curves.png',
        'predictions.png',
    ):
        artifact_path = output_dir / artifact
        # Skip artifacts whose producing cell was not run instead of failing.
        if artifact_path.exists():
            files.download(str(artifact_path))
        else:
            print(f"Skipping missing artifact: {artifact_path}")

## Summary

This notebook trained COFNet on the SkyWatch dataset for detecting:
- **Planes** - Aircraft in night sky images
- **Wildlife** - Birds and other animals
- **Meteorites** - Rare meteor streaks

### Key Components:
1. **Mamba-SSM Backbone** - Efficient state-space modeling for visual features
2. **Continuous Scale Field** - Scale-equivariant feature representations
3. **Diffusion Box Refiner** - Iterative refinement of bounding boxes
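4. **SDSS** (Scale-Diffusion Self-Supervision) - Self-supervised objectives (here `ScaleContrastiveLearning` and `CrossScaleReconstruction`) whose weighted loss is added to the detection loss during training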

### Next Steps:
- Fine-tune hyperparameters (learning rate, batch size, epochs)
- Experiment with different backbone sizes
- Add data augmentation
- Use full SDSS pretraining before supervised training