# Data Extraction and Training with DEGIS Package

This notebook demonstrates how to use the DEGIS package to:
1. Extract CLIP embeddings from images
2. Generate color histograms and edge maps
3. Train color disentanglement models

Based on the logic from `main.py` but using the new package structure.


In [None]:
# =============================================================================
# AUTOMATIC SETUP - Run this cell first!
# =============================================================================
# This cell will automatically set up the DEGIS package and dependencies
# You only need to run this once per session

from env_setup import setup_training_environment

# Run the setup for training notebooks
setup_training_environment()


## 1. Setup and Imports


In [None]:
# =============================================================================
# IMPORTS - Run this after the setup cell above
# =============================================================================

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Import DEGIS package components
import degis
from degis.core.embeddings import generate_clip_embeddings, generate_xl_embeddings
from degis.core.features import generate_color_histograms, generate_edge_maps
from degis.core.training import train_color_model, train_edge_model
from degis.data.dataset import UnifiedImageDataset
from degis.config import CSV_PATH, BATCH_SIZE, EMBEDDINGS_TARGET_PATH

print("✅ All imports successful!")
print(f"📊 Using CSV_PATH: {CSV_PATH}")
print(f"📦 Using BATCH_SIZE: {BATCH_SIZE}")
from multiprocessing import cpu_count
import os
import time
import platform
import psutil
import shutil

# Import the DEGIS package
import degis
from degis.data.dataset import UnifiedImageDataset

: 

## 2. System Profiling


In [None]:
def print_system_profile():
    print("=== SYSTEM PROFILE ===")
    print("Python:", platform.python_version())
    print("PyTorch:", torch.__version__)
    print("CPU cores:", psutil.cpu_count(logical=True))
    vm = psutil.virtual_memory()
    print(f"RAM: {vm.total/1e9:.1f} GB, free {vm.available/1e9:.1f} GB")
    
    # Check if /data exists, otherwise check current directory
    if os.path.exists("/data"):
        du = shutil.disk_usage("/data")
        print(f"/data disk: total {du.total/1e9:.1f} GB, free {du.free/1e9:.1f} GB")
    else:
        du = shutil.disk_usage(".")
        print(f"Current disk: total {du.total/1e9:.1f} GB, free {du.free/1e9:.1f} GB")
    
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        i = torch.cuda.current_device()
        print("GPU:", torch.cuda.get_device_name(i))
        print(f"VRAM total: {torch.cuda.get_device_properties(i).total_memory/1e9:.1f} GB")
    print("======================")

print_system_profile()


=== SYSTEM PROFILE ===
Python: 3.12.3
PyTorch: 2.8.0+cu128
CPU cores: 256
RAM: 540.8 GB, free 407.9 GB
/data disk: total 1099.5 GB, free 249.4 GB
CUDA available: True
GPU: NVIDIA GeForce RTX 5090
VRAM total: 33.7 GB


## 3. Load Dataset and Create Data Loader


In [None]:
batch_size = 512
embeddings_size = "xl"

In [16]:
csv_path = "/data/thesis/coco_manifest.csv"
embeddings_path = "/data/thesis/models/hf_xl_coco_embeddings.npy"
colour_hist_path = "/data/thesis/data/adimagenet_color_histograms_hcl_514.npy" 
edge_maps_path = "/data/thesis/data/adimagenet_edge_maps.npy"

In [17]:
# Load the dataset
df = pd.read_csv(csv_path)
print(f"Dataset loaded: {len(df)} images")
print(f"Columns: {df.columns.tolist()}")
print(f"\nFirst few rows:")
print(df.head())

# Create dataset
dataset = UnifiedImageDataset(
    df.rename(columns={"local_path": "file_path"}), 
    mode="file_df",
    size=(224, 224)
)

print(f"\nDataset created with {len(dataset)} samples")

# Create data loader with optimal settings
num_cpu = cpu_count()
num_workers = min(32, max(8, num_cpu // 8))

loader = DataLoader(
    dataset,
    batch_size=4096,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6,
    pin_memory_device="cuda" if torch.cuda.is_available() else None,
)

print(f"\nDataLoader created with {len(loader)} batches")
print(f"Batch size: {batch_size}")
print(f"Number of workers: {num_workers}")


Dataset loaded: 616767 images
Columns: ['split', 'image_id', 'file_name', 'local_path', 'caption']

First few rows:
   split  image_id         file_name  \
0  train    203564  000000203564.jpg   
1  train    322141  000000322141.jpg   
2  train     16977  000000016977.jpg   
3  train    106140  000000106140.jpg   
4  train    106140  000000106140.jpg   

                                          local_path  \
0  /data/thesis/coco/images/train2017/00000020356...   
1  /data/thesis/coco/images/train2017/00000032214...   
2  /data/thesis/coco/images/train2017/00000001697...   
3  /data/thesis/coco/images/train2017/00000010614...   
4  /data/thesis/coco/images/train2017/00000010614...   

                                             caption  
0  A bicycle replica with a clock as the front wh...  
1  A room with blue walls and a white sink and door.  
2  A car that seems to be parked illegally behind...  
3  A large passenger airplane flying through the ...  
4  There is a GOL plane taking 

## 4. Generate Features and Train Model


In [18]:
# Generate CLIP embeddings using the package
print("Generating CLIP embeddings...")
if embeddings_size == "xl":  
    embeddings = degis.generate_xl_embeddings(
        csv_path=csv_path,
        output_path=embeddings_path,
        batch_size=batch_size,
        num_workers=num_workers,
        force_recompute=True
    )
else:
    embeddings = degis.generate_clip_embeddings(
        csv_path=csv_path,
        output_path=embeddings_path,
        batch_size=batch_size,
        num_workers=num_workers,
        force_recompute=True
    )
print(f"✓ Generated embeddings with shape: {embeddings.shape}")

# Generate color histograms
print("\nGenerating color histograms...")
histograms = degis.generate_color_histograms(
    loader=loader,
    hist_path=colour_hist_path,
    hist_bins=8,
    force_recompute=False,
    color_space="hcl"
)
print(f"✓ Generated HCL histograms with shape: {histograms.shape}")

# Generate edge maps
print("\nGenerating edge maps...")
edge_maps = degis.generate_edge_maps(
    loader=loader,
    edge_maps_path=edge_maps_path,
    method="canny",
    force_recompute=False
)
print(f"✓ Generated edge maps with shape: {edge_maps.shape}")

# Train the color disentanglement model
print("\nTraining color disentanglement model...")
results = degis.train_color_model(
    embeddings_path=embeddings_path,
    histograms_path=colour_hist_path,
    hist_kind="hcl514",
    epochs=200,
    batch_size=4096,
    val_batch_size=8192,
    lr=1e-3,
    weight_decay=1e-2,
    blur=0.05,
    lambda_ortho=0.1,
    top_k=100,
    weighting=True,
)

print(f"\n✓ Training complete!")
print(f"Output directory: {results['output_dir']}")


Generating CLIP embeddings...


  with torch.cuda.amp.autocast(dtype=torch.float16):
HF laion/CLIP-ViT-bigG-14-laion2B-39B-b160k batched encode (global): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1205/1205 [1:34:33<00:00,  4.71s/it]


→ Saved embeddings to /data/thesis/models/hf_xl_coco_embeddings.npy (shape=(616767, 1280), dim=1280)
Saved: /data/thesis/models/hf_xl_coco_embeddings.npy, shape: (616767, 1280)
✓ Generated embeddings with shape: (616767, 1280)

Generating color histograms...
### Color Histogram Generation [FAST, HCL] ###
Total images: 616767, Bins: 8, Dimensions: 514
Loaded from /data/thesis/data/adimagenet_color_histograms_hcl_514.npy
✓ Generated HCL histograms with shape: (2080, 514)

Generating edge maps...
[✓] Loaded cached edge maps from /data/thesis/data/adimagenet_edge_maps.npy
✓ Generated edge maps with shape: (2080, 50176)

Training color disentanglement model...
Run dir: /data/degis/runs/color_hcl514_tk100_b4096-20250911-101203
Rest dir: /data/degis/runs/rest_hcl514_tk100_b4096-20250911-101203


AssertionError: Embeddings and histograms must share N