# VQ Embedding Analysis

Pretrained VQ-AENB-Conditional encoder를 사용하여 전체 adata에 VQ code 및 embedding을 추출하고 분석합니다.

## 출력물
- `adata.obs['vq_code']`: 각 세포의 codebook index
- `adata.obsm['X_vq']`: 각 세포의 quantized latent embedding
- `adata.uns['codebook']`: codebook embedding matrix
- `adata.uns['codebook_stats']`: whole data 기준 code별 통계 (DataFrame → dict; subset별 통계는 CSV로 별도 저장)

## 1. Setup & Load

In [None]:
import sys
from pathlib import Path

# Make the project root importable so `src.*` modules resolve.
# NOTE(review): assumes the notebook runs from a direct child directory of the
# project root (PROJECT_ROOT = parent of the current working directory) — confirm.
PROJECT_ROOT = Path.cwd().parent
_project_root_str = str(PROJECT_ROOT)
# Drop earlier copies first so re-running this cell leaves exactly one entry,
# placed at the front of the module search path.
sys.path[:] = [p for p in sys.path if p != _project_root_str]
sys.path.insert(0, _project_root_str)

import json
import numpy as np
import pandas as pd
import torch
import scanpy as sc
from tqdm.auto import tqdm

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# ============================================================
# Configuration - edit these paths before running
# ============================================================

# --- Input / output locations ---
ADATA_PATH = "/home/bmi-user/workspace/data/HSvsCD/data/Whole_SCP_PCD_Skin_805k_6k.h5ad"
ENCODER_PATH = "/home/bmi-user/workspace/data/HSvsCD/scMILDQ_Cond/results/pretrained/vq_aenb_conditional_whole.pth"
STUDY_MAPPING_PATH = "/home/bmi-user/workspace/data/HSvsCD/scMILDQ_Cond/results/pretrained/study_mapping.json"
OUTPUT_DIR = Path("/home/bmi-user/workspace/data/HSvsCD/scMILDQ_Cond/results/vq_analysis")

# --- Inference ---
BATCH_SIZE = 4_096   # cells per encoder batch
DEVICE = "cuda:0"    # set to "cpu" to run without a GPU

# --- adata.obs column names ---
STUDY_COL = "study"
STATUS_COL = "Status"
ORGAN_COL = "Organ"
DISEASE_COL = "disease_numeric"
SAMPLE_COL = "sample"

# --- Study subsets for per-subset codebook statistics ---
# Each value lists the studies (values of adata.obs[STUDY_COL]) belonging to
# the subset; None marks the whole dataset.
SUBSETS = dict(
    whole=None,
    skin3=['GSE175990', 'GSE220116'],
    scp1884=['SCP1884'],
    skin_all=['GSE154775', 'GSE155850', 'GSE175990', 'GSE212721', 'GSE220116'],
    colon_all=['SCP1884', 'GSE225199', 'GSE260842', 'GSE277387', 'GSE114374', 'GSE116222'],
)

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Read the AnnData object from ADATA_PATH, then report its size and obs columns.
print(f"Loading adata from: {ADATA_PATH}")
adata = sc.read_h5ad(ADATA_PATH)
_obs_columns = adata.obs.columns.tolist()
print(f"Shape: {adata.n_obs} cells × {adata.n_vars} genes")
print(f"Columns: {_obs_columns}")

In [None]:
# Load study mapping
# The JSON maps study id (a string key) -> study name. It is inverted so the
# study names found in adata.obs can be translated into the integer study ids
# passed to the conditional encoder.
with open(STUDY_MAPPING_PATH, 'r', encoding='utf-8') as f:
    id_to_name = json.load(f)
name_to_id = {v: int(k) for k, v in id_to_name.items()}
# Inverting the mapping silently drops ids when two of them share a name.
if len(name_to_id) != len(id_to_name):
    raise ValueError(
        f"Study mapping {STUDY_MAPPING_PATH} contains duplicate study names; "
        "cannot build a unique name -> id mapping"
    )
print(f"Study mapping: {name_to_id}")

In [None]:
# Load pretrained encoder
from src.models.autoencoder import VQ_AENB_Conditional

# Use the configured device only when CUDA is actually available.
device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

checkpoint = torch.load(ENCODER_PATH, map_location=device)
model_config = checkpoint.get('config', {})
print(f"Model config: {model_config}")

# Fail early with a readable message, instead of a bare KeyError inside the
# constructor call, when the checkpoint lacks (part of) its config.
_missing_cfg = [
    key for key in ('input_dim', 'latent_dim', 'hidden_layers', 'n_studies')
    if key not in model_config
]
if _missing_cfg:
    raise KeyError(
        f"Checkpoint config at {ENCODER_PATH} is missing required keys: {_missing_cfg}"
    )

encoder = VQ_AENB_Conditional(
    input_dim=model_config['input_dim'],
    latent_dim=model_config['latent_dim'],
    device=device,
    hidden_layers=model_config['hidden_layers'],
    n_studies=model_config['n_studies'],
    study_emb_dim=model_config.get('study_emb_dim', 16),
    num_codes=model_config.get('num_codes', 1024),
)
encoder.load_state_dict(checkpoint['model_state_dict'])
encoder.to(device)
encoder.eval()

# Use the same 'num_codes' default as the constructor above; indexing the
# config directly raised KeyError for checkpoints that omit the key.
print(f"Encoder loaded: {model_config.get('num_codes', 1024)} codes, {model_config['latent_dim']} latent dim")

## 2. Cell-level VQ Extraction

In [None]:
def _map_study_ids(study_names: pd.Series, name_to_id: dict) -> np.ndarray:
    """Translate per-cell study names into integer study ids; raise on unknown studies."""
    # Cast to object first so a categorical column maps to a plain Series whose
    # unmapped entries are NaN (rather than going through Categorical.map).
    mapped = study_names.astype(object).map(name_to_id)
    unmapped = mapped.isna()
    if unmapped.any():
        missing = sorted({str(s) for s in study_names[unmapped]})
        raise ValueError(
            f"{int(unmapped.sum())} cells belong to studies missing from name_to_id: {missing}"
        )
    return mapped.to_numpy(dtype=np.int64)


def extract_vq_embeddings(
    adata,
    encoder,
    name_to_id: dict,
    study_col: str = "study",
    batch_size: int = 4096,
    device: torch.device = None
):
    """
    Add VQ codes and quantized embeddings to ``adata`` in place.

    Cells are processed in contiguous row batches of ``batch_size``. For each
    batch, the encoder receives the expression rows plus the integer study id
    of every cell.

    Args:
        adata: AnnData object whose ``X`` holds the encoder input (a dense
            array, or any matrix whose row slices expose ``toarray``).
        encoder: Pretrained VQ-AENB-Conditional model. Must expose
            ``latent_dim``, ``get_codebook_indices(x, study_ids)`` and
            ``features(x, study_ids)``.
        name_to_id: Mapping from study name (values of ``adata.obs[study_col]``)
            to the integer study id fed to the encoder.
        study_col: Column of ``adata.obs`` holding each cell's study name.
        batch_size: Number of cells per forward pass.
        device: Device for the input tensors. Defaults to the device of the
            encoder's first parameter.

    Adds:
        adata.obs['vq_code']: codebook index per cell (int32)
        adata.obsm['X_vq']: quantized embedding per cell
            (float32, shape n_obs x encoder.latent_dim)

    Raises:
        ValueError: if any study in ``adata.obs[study_col]`` has no entry in
            ``name_to_id``.
    """
    if device is None:
        device = next(encoder.parameters()).device

    # Validate before touching the encoder. An unmapped study would otherwise
    # become NaN and then be cast to an arbitrary integer id.
    study_ids = _map_study_ids(adata.obs[study_col], name_to_id)

    n_cells = adata.n_obs
    latent_dim = encoder.latent_dim

    all_codes = np.zeros(n_cells, dtype=np.int32)
    all_embeddings = np.zeros((n_cells, latent_dim), dtype=np.float32)

    encoder.eval()
    with torch.no_grad():
        for start_idx in tqdm(range(0, n_cells, batch_size), desc="Extracting VQ embeddings"):
            end_idx = min(start_idx + batch_size, n_cells)

            # Densify the slice itself, so containers whose slices are sparse
            # (even if the container lacks `toarray`) are handled too.
            rows = adata.X[start_idx:end_idx]
            if hasattr(rows, 'toarray'):
                rows = rows.toarray()
            batch_x = torch.tensor(rows, dtype=torch.float32, device=device)

            batch_study_ids = torch.tensor(
                study_ids[start_idx:end_idx],
                dtype=torch.long, device=device
            )

            # Codebook index of each cell
            codes = encoder.get_codebook_indices(batch_x, batch_study_ids)
            all_codes[start_idx:end_idx] = codes.cpu().numpy()

            # Quantized embedding of each cell
            embeddings = encoder.features(batch_x, batch_study_ids)
            all_embeddings[start_idx:end_idx] = embeddings.cpu().numpy()

    adata.obs['vq_code'] = all_codes
    adata.obsm['X_vq'] = all_embeddings

    print("Added to adata:")
    print(f"  - adata.obs['vq_code']: {all_codes.shape}, unique codes: {len(np.unique(all_codes))}")
    print(f"  - adata.obsm['X_vq']: {all_embeddings.shape}")

In [None]:
# Run the encoder over every cell. This writes adata.obs['vq_code'] and
# adata.obsm['X_vq'] in place.
extract_vq_embeddings(
    adata,
    encoder,
    name_to_id,
    study_col=STUDY_COL,
    batch_size=BATCH_SIZE,
    device=device,
)

In [None]:
# Save codebook to adata.uns
codebook = encoder.quantizer.get_codebook().cpu().numpy()
adata.uns['codebook'] = codebook
print(f"Codebook shape: {codebook.shape}")

## 3. Codebook-level Statistics

In [None]:
def compute_codebook_stats(
    adata,
    num_codes: int,
    status_col: str = "Status",
    organ_col: str = "Organ",
    study_col: str = "study",
    sample_col: str = "sample",
    subset_mask: np.ndarray = None,
    subset_name: str = "whole"
) -> pd.DataFrame:
    """
    Compute per-code statistics for every entry of the VQ codebook.

    One row is produced per code index in ``range(num_codes)``, including
    codes that no cell uses. Such codes keep ``n_cells == 0``, zero ratios and
    NaN organ-specific disease ratios.

    Args:
        adata: AnnData with an integer 'vq_code' column in ``obs``.
        num_codes: Total number of codes in the codebook.
        status_col: obs column with each cell's status label.
        organ_col: obs column with each cell's organ label.
        study_col: obs column with each cell's study name.
        sample_col: obs column with each cell's sample id.
        subset_mask: Boolean row mask selecting a subset of cells
            (None = use all cells).
        subset_name: Value written to the 'subset' column of the result.

    Returns:
        DataFrame with columns ``code_idx``, ``n_cells``, ``n_samples`` and
        ``is_single_sample``. It also has ``status_<label>``,
        ``organ_<label>`` and ``study_<label>`` columns holding the fraction of
        the code's cells with that label. The remaining columns are
        ``disease_ratio_skin`` (HS / (HS + ctrl_skin)), ``disease_ratio_colon``
        (CD / (CD + ctrl_colon)), ``disease_ratio_overall`` ((CD + HS) / n_cells)
        and ``subset``. Cells whose code lies outside ``[0, num_codes)`` are
        not counted.
    """
    obs = adata.obs if subset_mask is None else adata.obs[subset_mask]

    # One slot per code so unused codes still get a row.
    stats = {
        'code_idx': list(range(num_codes)),
        'n_cells': [0] * num_codes,
        'n_samples': [0] * num_codes,             # number of distinct samples using the code
        'is_single_sample': [False] * num_codes,  # True when all of the code's cells come from one sample
    }

    # Label sets for the ratio columns, taken from the (subset) obs
    statuses = obs[status_col].unique()
    organs = obs[organ_col].unique()
    studies = obs[study_col].unique()

    for status in statuses:
        stats[f'status_{status}'] = [0.0] * num_codes
    for organ in organs:
        stats[f'organ_{organ}'] = [0.0] * num_codes
    for study in studies:
        stats[f'study_{study}'] = [0.0] * num_codes

    # Organ-specific ratios stay NaN when a code has none of the relevant labels
    stats['disease_ratio_skin'] = [np.nan] * num_codes   # HS / (HS + ctrl_skin)
    stats['disease_ratio_colon'] = [np.nan] * num_codes  # CD / (CD + ctrl_colon)
    stats['disease_ratio_overall'] = [0.0] * num_codes   # (CD + HS) / total

    # observed=True stops a categorical 'vq_code' column from yielding empty
    # groups for unused categories.
    for code_idx, group in obs.groupby('vq_code', observed=True):
        # Only count codes in [0, num_codes). A negative value (e.g. a -1
        # sentinel) would otherwise index the lists from the end and silently
        # overwrite the stats of the last codes.
        if not 0 <= code_idx < num_codes:
            continue

        n = len(group)
        if n == 0:
            continue
        stats['n_cells'][code_idx] = n

        # Sample diversity
        unique_samples = group[sample_col].nunique()
        stats['n_samples'][code_idx] = unique_samples
        stats['is_single_sample'][code_idx] = (unique_samples == 1)

        # Status ratios
        status_counts = group[status_col].value_counts()
        for status in statuses:
            stats[f'status_{status}'][code_idx] = status_counts.get(status, 0) / n

        # Organ ratios
        organ_counts = group[organ_col].value_counts()
        for organ in organs:
            stats[f'organ_{organ}'][code_idx] = organ_counts.get(organ, 0) / n

        # Study ratios
        study_counts = group[study_col].value_counts()
        for study in studies:
            stats[f'study_{study}'][code_idx] = study_counts.get(study, 0) / n

        # Skin: HS / (HS + ctrl_skin)
        n_hs = status_counts.get('HS', 0)
        n_ctrl_skin = status_counts.get('ctrl_skin', 0)
        if n_hs + n_ctrl_skin > 0:
            stats['disease_ratio_skin'][code_idx] = n_hs / (n_hs + n_ctrl_skin)

        # Colon: CD / (CD + ctrl_colon)
        n_cd = status_counts.get('CD', 0)
        n_ctrl_colon = status_counts.get('ctrl_colon', 0)
        if n_cd + n_ctrl_colon > 0:
            stats['disease_ratio_colon'][code_idx] = n_cd / (n_cd + n_ctrl_colon)

        # Overall: (CD + HS) / total
        stats['disease_ratio_overall'][code_idx] = (n_cd + n_hs) / n

    df = pd.DataFrame(stats)
    df['subset'] = subset_name

    return df

In [None]:
# Per-code statistics over all cells (no subset mask).
num_codes = model_config.get('num_codes', 1024)

codebook_stats_whole = compute_codebook_stats(
    adata=adata,
    num_codes=num_codes,
    status_col=STATUS_COL, organ_col=ORGAN_COL,
    study_col=STUDY_COL, sample_col=SAMPLE_COL,
    subset_mask=None, subset_name="whole",
)

_n_active_codes = codebook_stats_whole['n_cells'].gt(0).sum()
_n_single_sample_codes = codebook_stats_whole['is_single_sample'].sum()

print(f"Codebook stats shape: {codebook_stats_whole.shape}")
print(f"Active codes (n_cells > 0): {_n_active_codes}")
print(f"Single-sample codes: {_n_single_sample_codes}")
codebook_stats_whole.head(10)

In [None]:
# Codebook usage summary; averages are taken over active codes (n_cells > 0) only.
active_codes = codebook_stats_whole.loc[codebook_stats_whole['n_cells'] > 0]
_single = active_codes['is_single_sample']
_cells = active_codes['n_cells']
_samples = active_codes['n_samples']

print(f"\n=== Codebook Usage Summary ===")
print(f"Total codes: {num_codes}")
print(f"Active codes: {len(active_codes)} ({len(active_codes)/num_codes*100:.1f}%)")
print(f"Single-sample codes: {_single.sum()} ({_single.mean()*100:.1f}%)")

print(f"\nCells per code:")
print(f"  Mean: {_cells.mean():.1f}")
print(f"  Median: {_cells.median():.1f}")
print(f"  Max: {_cells.max()}")

print(f"\nSamples per code:")
print(f"  Mean: {_samples.mean():.1f}")
print(f"  Median: {_samples.median():.1f}")
print(f"  Max: {_samples.max()}")

## 4. Subset Analysis

In [None]:
# Compute per-code stats for each named study subset and stack them beneath
# the whole-data table; rows are told apart by the 'subset' column.
all_subset_stats = [codebook_stats_whole]

for subset_name, study_list in SUBSETS.items():
    # The whole dataset is already covered by codebook_stats_whole.
    if study_list is None or subset_name == 'whole':
        continue

    print(f"\nProcessing subset: {subset_name}")

    mask = adata.obs[STUDY_COL].isin(study_list)
    n_cells = mask.sum()
    print(f"  Cells: {n_cells}")

    if not n_cells:
        print(f"  Skipping (no cells)")
        continue

    subset_stats = compute_codebook_stats(
        adata=adata,
        num_codes=num_codes,
        status_col=STATUS_COL,
        organ_col=ORGAN_COL,
        study_col=STUDY_COL,
        sample_col=SAMPLE_COL,
        subset_mask=mask,
        subset_name=subset_name,
    )
    active = subset_stats['n_cells'].gt(0).sum()
    print(f"  Active codes: {active}")
    all_subset_stats.append(subset_stats)

# Stack every subset's table into one frame
codebook_stats_all = pd.concat(all_subset_stats, ignore_index=True)
print(f"\nCombined stats shape: {codebook_stats_all.shape}")

In [None]:
# Per-subset totals: cells, active codes (n_cells > 0), and active codes whose
# cells all come from a single sample.
summary = (
    codebook_stats_all
    .assign(_active=codebook_stats_all['n_cells'] > 0)
    .assign(_single_active=lambda d: d['is_single_sample'] & d['_active'])
    .groupby('subset')
    .agg(
        n_cells=('n_cells', 'sum'),
        active_codes=('_active', 'sum'),
        single_sample_codes=('_single_active', 'sum'),
    )
)

print("=== Subset Summary ===")
summary

## 5. Save Results

In [None]:
# h5ad cannot hold a DataFrame in .uns, so store the whole-data stats as a
# plain {column: list_of_values} dict instead.
_stats_as_dict = codebook_stats_whole.to_dict('list')
adata.uns['codebook_stats'] = _stats_as_dict
print("Saved codebook_stats to adata.uns")

In [None]:
# Persist the annotated AnnData (VQ codes, embeddings, codebook, stats).
output_adata_path = OUTPUT_DIR.joinpath("adata_with_vq.h5ad")
print(f"Saving adata to: {output_adata_path}")
adata.write_h5ad(output_adata_path)
print("Done!")

In [None]:
# CSV exports: one table with every subset, plus the whole-data table alone.
output_stats_path = OUTPUT_DIR.joinpath("codebook_stats_all_subsets.csv")
output_stats_whole_path = OUTPUT_DIR.joinpath("codebook_stats_whole.csv")

codebook_stats_all.to_csv(output_stats_path, index=False)
print(f"Saved codebook stats to: {output_stats_path}")

codebook_stats_whole.to_csv(output_stats_whole_path, index=False)
print(f"Saved whole stats to: {output_stats_whole_path}")

In [None]:
# Dump the raw codebook matrix as a .npy file.
output_codebook_path = OUTPUT_DIR.joinpath("codebook_embeddings.npy")
np.save(output_codebook_path, codebook)
print(f"Saved codebook embeddings to: {output_codebook_path}")

## 6. Quick Sanity Check

In [None]:
# Read the written file back and confirm the VQ outputs survived the round trip.
print("Verifying saved adata...")
adata_reload = sc.read_h5ad(output_adata_path)
_obs, _obsm, _uns = adata_reload.obs, adata_reload.obsm, adata_reload.uns

print(f"\nadata.obs columns: {list(_obs.columns)}")
print(f"adata.obsm keys: {list(_obsm.keys())}")
print(f"adata.uns keys: {list(_uns.keys())}")

print(f"\nvq_code unique values: {_obs['vq_code'].nunique()}")
print(f"X_vq shape: {_obsm['X_vq'].shape}")
print(f"codebook shape: {_uns['codebook'].shape}")

In [None]:
# Rebuild the stats DataFrame from its {column: values} dict form in .uns.
codebook_stats_reloaded = pd.DataFrame.from_dict(adata_reload.uns['codebook_stats'])
print(f"Reloaded codebook_stats shape: {codebook_stats_reloaded.shape}")
codebook_stats_reloaded.head()

---

## Summary

### 저장된 파일
1. `adata_with_vq.h5ad`: VQ code, embedding, codebook, stats가 추가된 adata
2. `codebook_stats_all_subsets.csv`: 모든 subset의 codebook 통계
3. `codebook_stats_whole.csv`: whole data의 codebook 통계
4. `codebook_embeddings.npy`: codebook embedding matrix

### adata 구조
- `adata.obs['vq_code']`: 각 세포의 VQ code index
- `adata.obsm['X_vq']`: 각 세포의 quantized latent embedding
- `adata.uns['codebook']`: codebook embedding matrix (num_codes × latent_dim)
- `adata.uns['codebook_stats']`: whole data 기준 code별 통계 (dict format; `pd.DataFrame(adata.uns['codebook_stats'])`로 DataFrame 복원 가능)