<p align="center">
  <img src="docs/training_sequence_diagram.png" alt="training diagram" width="1000"/>
</p>

# Data Extraction and Training with DEGIS Package

This notebook demonstrates how to use the DEGIS package to:
1. Extract CLIP embeddings from images
2. Generate color histograms and edge maps
3. Train color disentanglement models

In [None]:
# =============================================================================
# ENVIRONMENT CHECK - Run this cell first!
# =============================================================================
# Verifies that the interpreter belongs to the degis environment and that the
# `degis` package can be imported.
# Make sure you've run ./setup_server_fixed.sh first!

import sys
import os

# Check if we're in the right environment.
# This is a substring match on the interpreter path, so it only recognises
# environments whose path contains "degis-env"; a miss is reported as a warning.
env_active = 'degis-env' in sys.executable
if env_active:
    print("DEGIS environment is active")
    print(f"Python: {sys.executable}")
else:
    print("Warning: DEGIS environment not detected")
    print("Please run: ./setup_server_fixed.sh")
    print("Then activate: source degis-env/bin/activate")

# Check if DEGIS package is available
try:
    import degis
except ImportError:
    degis_available = False
    print("DEGIS package not found")
    print("   Please run: ./setup_server_fixed.sh")
else:
    degis_available = True
    print("DEGIS package is available")

# Only announce readiness when the package actually imported; otherwise the
# cell would end with a success message directly under the failure message.
if degis_available:
    print("\n Ready to start training!")
else:
    print("\n Not ready: fix the problems reported above before running the next cells.")


## 1. Setup and Imports


In [None]:
# =============================================================================
# IMPORTS - Run this after the setup cell above
# =============================================================================

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
import os
import time
import platform
import psutil
import shutil

# Import DEGIS package components
import degis
from degis import generate_clip_embeddings, generate_xl_embeddings
from degis import generate_color_histograms, generate_edge_maps
from degis.training import train_color_model
from degis.data.dataset import UnifiedImageDataset
from degis.shared.config import CSV_PATH, BATCH_SIZE, EMBEDDINGS_TARGET_PATH

## 2. System Profiling


In [None]:
def print_system_profile():
    """Print a short report of the software and hardware behind this kernel.

    Reports the Python and PyTorch versions, the logical CPU count, total and
    available RAM, disk usage for ``/data`` (or for the current directory when
    ``/data`` does not exist) and, when CUDA is available, the name and total
    memory of the current GPU. Sizes are shown in decimal gigabytes
    (bytes / 1e9). Returns None; the report goes to stdout only.
    """
    bytes_per_gb = 1e9

    print("=== SYSTEM PROFILE ===")
    print("Python:", platform.python_version())
    print("PyTorch:", torch.__version__)
    print("CPU cores:", psutil.cpu_count(logical=True))

    memory = psutil.virtual_memory()
    print(f"RAM: {memory.total/bytes_per_gb:.1f} GB, free {memory.available/bytes_per_gb:.1f} GB")

    # Report the /data location when present; otherwise fall back to the
    # filesystem holding the current working directory.
    if not os.path.exists("/data"):
        usage = shutil.disk_usage(".")
        print(f"Current disk: total {usage.total/bytes_per_gb:.1f} GB, free {usage.free/bytes_per_gb:.1f} GB")
    else:
        usage = shutil.disk_usage("/data")
        print(f"/data disk: total {usage.total/bytes_per_gb:.1f} GB, free {usage.free/bytes_per_gb:.1f} GB")

    has_cuda = torch.cuda.is_available()
    print("CUDA available:", has_cuda)
    if has_cuda:
        device_index = torch.cuda.current_device()
        print("GPU:", torch.cuda.get_device_name(device_index))
        device_props = torch.cuda.get_device_properties(device_index)
        print(f"VRAM total: {device_props.total_memory/bytes_per_gb:.1f} GB")
    print("======================")

print_system_profile()


## 3. Load Dataset and Create Data Loader


In [None]:
# Manifest read by the dataset cell and by the embedding generators.
csv_path = "/data/thesis/adimagenet_manifest.csv"

# Feature files passed to the generation and training cells. All three sit in
# the same runs directory and differ only in the feature name.
embeddings_path, colour_hist_path, edge_maps_path = (
    f"/data/degis/runs/adimagenet_{feature}.npy"
    for feature in ("embeddings", "color_histograms", "edge_maps")
)

In [None]:
# Load the dataset manifest
df = pd.read_csv(csv_path)
print(f"Dataset loaded: {len(df)} images")
print(f"Columns: {df.columns.tolist()}")

# Create dataset
dataset = UnifiedImageDataset(
    df,
    # file_df → uses local file_path column
    # url_df  → uses image_url column (cloud storage)
    mode="file_df",
    size=(224, 224)
)

print(f"\nDataset created with {len(dataset)} samples")

# Worker count: one worker per 8 CPU cores, clamped to the range [8, 32].
num_cpu = cpu_count()
num_workers = min(32, max(8, num_cpu // 8))

# Pinned host memory is only useful when batches are copied to a GPU, so it is
# enabled only when CUDA is available (this also avoids PyTorch's warning about
# pinning without an accelerator).
use_cuda = torch.cuda.is_available()

# NOTE(review): batch_size=4096 with prefetch_factor=6 keeps many large batches
# in flight per worker; presumably sized for a large-memory server — confirm
# before running on a smaller machine.
loader = DataLoader(
    dataset,
    batch_size=4096,
    shuffle=False,  # keep manifest order for the generated feature arrays
    num_workers=num_workers,
    pin_memory=use_cuda,
    persistent_workers=True,  # valid because num_workers is always >= 8
    prefetch_factor=6,
    # pin_memory_device is a string option whose default is ""; passing None
    # fails on CPU-only machines in PyTorch releases that call len() on it.
    pin_memory_device="cuda" if use_cuda else "",
)

print(f"\nDataLoader created")

## 4. Select the target encoder/diffusion family

In [None]:
# Select the target encoder/diffusion family for extraction and training.
# Other modules adapt automatically by detecting the embedding size.
#
# Supported values are "xl" and "base". To use the base family, replace the
# assignment below with:
#   embeddings_size = "base"  # 1024-D embeddings (CLIP-ViT-H/14, Stable Diffusion 1.5)
embeddings_size = "xl"   # 1280-D embeddings (CLIP-ViT-bigG/14, Stable Diffusion XL)

## 5. Generate Features and Train Model


In [None]:
# Generate CLIP embeddings using the package.
# Each supported embeddings_size maps to the extractor for that encoder family.
# An unknown value is rejected up front so a typo cannot silently fall through
# to the base extractor and write its output to embeddings_path.
embedding_generators = {
    "xl": generate_xl_embeddings,
    "base": generate_clip_embeddings,
}
if embeddings_size not in embedding_generators:
    raise ValueError(
        f"Unknown embeddings_size {embeddings_size!r}; "
        f"expected one of {sorted(embedding_generators)}"
    )

print("Generating CLIP embeddings...")
# NOTE(review): force_recompute=False presumably reuses an existing file at
# output_path — confirm. If so, switching embeddings_size needs a different
# embeddings_path or force_recompute=True to avoid reusing the other family's file.
embeddings = embedding_generators[embeddings_size](
    csv_path=csv_path,
    output_path=embeddings_path,
    batch_size=512,
    num_workers=num_workers,
    force_recompute=False
)
print(f"✓ Generated embeddings with shape: {embeddings.shape}")

# Generate color histograms
print("\nGenerating color histograms...")
histograms = generate_color_histograms(
    loader=loader,
    hist_path=colour_hist_path,
    hist_bins=8,
    force_recompute=False,
    color_space="hcl"
)
print(f"✓ Generated HCL histograms with shape: {histograms.shape}")

# Generate edge maps
print("\nGenerating edge maps...")
edge_maps = generate_edge_maps(
    loader=loader,
    edge_maps_path=edge_maps_path,
    method="canny",
    force_recompute=False
)
print(f"✓ Generated edge maps with shape: {edge_maps.shape}")

# Train the color disentanglement model on the files written above.
# The function is called through the name imported from degis.training in the
# import cell, so this cell does not depend on the top-level package
# re-exporting it.
# NOTE(review): hist_kind="hcl514" presumably has to match the histogram layout
# produced with hist_bins=8 / color_space="hcl" — verify against the package.
print("\nTraining color disentanglement model...")
results = train_color_model(
    embeddings_path=embeddings_path,
    histograms_path=colour_hist_path,
    hist_kind="hcl514",
    epochs=2,
    batch_size=4096,
    val_batch_size=8192,
    lr=1e-3,
    weight_decay=1e-2,
    blur=0.05,
    lambda_ortho=0.1,
    top_k=100,
    weighting=True,
)

print(f"\n✓ Training complete!")
print(f"Output directory: {results['output_dir']}")
