# Data Extraction and Training with DEGIS Package

This notebook demonstrates how to use the DEGIS package to:
1. Extract CLIP embeddings from images
2. Generate color histograms and edge maps
3. Train color disentanglement models

The pipeline follows the same logic as `main.py`, but is built on the new `degis` package structure.


## 1. Setup and Imports


In [None]:
# Install the package in development mode if needed
# !pip install -e .

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
import os
import time
import platform
import psutil
import shutil

# Import the DEGIS package
import degis
from degis.data.dataset import UnifiedImageDataset
from degis.config import CSV_PATH, BATCH_SIZE, HF_XL_EMBEDDINGS_TARGET_PATH, COLOR_HIST_PATH_HCL_514, EDGE_MAPS_PATH


## 2. System Profiling


In [None]:
def print_system_profile():
    """Print a hardware/software summary of the machine running this kernel.

    Lines printed, in order: Python version, PyTorch version, logical CPU
    count, total/available RAM, disk usage (of ``/data`` when that path
    exists, otherwise of the current directory), whether CUDA is available,
    and -- only when it is -- the current GPU's name and total VRAM.
    All sizes are in decimal gigabytes (bytes / 1e9).
    """
    print("=== SYSTEM PROFILE ===")
    print("Python:", platform.python_version())
    print("PyTorch:", torch.__version__)
    print("CPU cores:", psutil.cpu_count(logical=True))

    mem = psutil.virtual_memory()
    print(f"RAM: {mem.total/1e9:.1f} GB, free {mem.available/1e9:.1f} GB")

    # Report the /data volume when that path exists, else the volume holding the CWD.
    if not os.path.exists("/data"):
        usage = shutil.disk_usage(".")
        print(f"Current disk: total {usage.total/1e9:.1f} GB, free {usage.free/1e9:.1f} GB")
    else:
        usage = shutil.disk_usage("/data")
        print(f"/data disk: total {usage.total/1e9:.1f} GB, free {usage.free/1e9:.1f} GB")

    cuda_ok = torch.cuda.is_available()
    print("CUDA available:", cuda_ok)
    if cuda_ok:
        dev = torch.cuda.current_device()
        print("GPU:", torch.cuda.get_device_name(dev))
        vram_bytes = torch.cuda.get_device_properties(dev).total_memory
        print(f"VRAM total: {vram_bytes/1e9:.1f} GB")
    print("======================")

# Print the profile once so the run's hardware/software context is recorded in the cell output.
print_system_profile()


## 3. Load Dataset and Create Data Loader


In [None]:
# Load the image index. Its `local_path` column is renamed to `file_path`
# before the frame is handed to UnifiedImageDataset in "file_df" mode.
df = pd.read_csv(CSV_PATH)
print(f"Dataset loaded: {len(df)} images")
print(f"Columns: {df.columns.tolist()}")
print("\nFirst few rows:")
print(df.head())

dataset = UnifiedImageDataset(
    df.rename(columns={"local_path": "file_path"}),
    mode="file_df",
    size=(224, 224),
)
print(f"\nDataset created with {len(dataset)} samples")

# Worker heuristic: one worker per 8 CPUs, clamped to [8, 32], but never more
# workers than there are CPUs, so small machines are not oversubscribed.
# cpu_count() >= 1, so num_workers >= 1 (required by persistent_workers and
# prefetch_factor below).
num_cpu = cpu_count()
num_workers = min(32, num_cpu, max(8, num_cpu // 8))

use_cuda = torch.cuda.is_available()

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    # Keep batches in dataset order; presumably the feature generators below
    # rely on this to stay aligned with the rows of `df` -- confirm in degis.
    shuffle=False,
    num_workers=num_workers,
    # Pinned host memory only helps host->GPU copies, so enable it only when
    # CUDA is present. pin_memory_device is left at its default ("") which pins
    # for CUDA; passing None instead raises TypeError on PyTorch versions that
    # call len() on that argument.
    pin_memory=use_cuda,
    persistent_workers=True,  # the loader is iterated twice below; keep workers alive
    prefetch_factor=6,
)

print(f"\nDataLoader created with {len(loader)} batches")
print(f"Batch size: {BATCH_SIZE}")
print(f"Number of workers: {num_workers}")


## 4. Generate Features and Train Model


In [None]:
# Single switch for every generator below. True recomputes each artefact from
# scratch; set to False to skip outputs that already exist on disk.
FORCE_RECOMPUTE = True

# Generate CLIP embeddings. This generator is given CSV_PATH directly rather
# than the `loader` built in the previous cell.
print("Generating CLIP embeddings...")
embeddings = degis.generate_xl_embeddings(
    csv_path=CSV_PATH,
    output_path=HF_XL_EMBEDDINGS_TARGET_PATH,
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
    force_recompute=FORCE_RECOMPUTE,
)
print(f"✓ Generated embeddings with shape: {embeddings.shape}")

# Generate color histograms from the shared loader.
print("\nGenerating color histograms...")
histograms = degis.generate_hcl_histograms(
    loader=loader,
    output_path=COLOR_HIST_PATH_HCL_514,
    bins=8,
    force_recompute=FORCE_RECOMPUTE,
)
print(f"✓ Generated HCL histograms with shape: {histograms.shape}")

# Generate edge maps from the same loader.
print("\nGenerating edge maps...")
edge_maps = degis.generate_edge_maps(
    loader=loader,
    output_path=EDGE_MAPS_PATH,
    method="canny",
    force_recompute=FORCE_RECOMPUTE,
)
print(f"✓ Generated edge maps with shape: {edge_maps.shape}")

# Hyper-parameters for the color disentanglement model, gathered in one place
# so they can be tweaked and recorded together.
# NOTE(review): the meaning of blur / lambda_ortho / top_k / weighting is
# defined by degis.train_color_model -- see its docstring.
TRAIN_KWARGS = {
    "hist_kind": "hcl514",
    "epochs": 200,
    "batch_size": 128,
    "val_batch_size": 256,
    "lr": 1e-3,
    "weight_decay": 1e-2,
    "blur": 0.05,
    "lambda_ortho": 0.1,
    "top_k": None,
    "weighting": False,
}

# Train the color disentanglement model on the embeddings and histograms written above.
print("\nTraining color disentanglement model...")
results = degis.train_color_model(
    embeddings_path=HF_XL_EMBEDDINGS_TARGET_PATH,
    histograms_path=COLOR_HIST_PATH_HCL_514,
    **TRAIN_KWARGS,
)

print("\n✓ Training complete!")
print(f"Output directory: {results['output_dir']}")
