# CKAN-SpecNet Training Demo

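This notebook demonstrates how to train CKAN-SpecNet to classify functional groups from mid-infrared (MIR) spectra. For a chosen target group, the model predicts how many instances of that group a sample contains, binned as 0, 1, 2, or ≥3.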

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import polars as pl
import torch
from sklearn.metrics import classification_report, confusion_matrix

from ckan_specnet import (
    CKANSpecNet,
    stratified_split,
    create_dataloaders,
    Trainer,
    set_seed,
    get_device,
    count_parameters,
    SMARTS_PATTERNS,
)
from ckan_specnet.utils import plot_confusion_matrix

In [None]:
# Configuration
# Dataset column to predict; also the key used to look up its SMARTS pattern.
TARGET = "alcohols"
# Group counts are binned into four classes: 0, 1, 2 and ">=3".
NUM_CLASSES = 4
# Highest class index; label counts are clipped to this value when loading.
UPPER_BOUND = NUM_CLASSES - 1
RANDOM_STATE = 42
BATCH_SIZE = 1024
EPOCHS = 500

# Seed the random number generators (project helper) before any stochastic
# step, then choose the compute device.
set_seed(RANDOM_STATE)
device = get_device()

print("Device:", device)
print("Target:", TARGET)
print("SMARTS:", SMARTS_PATTERNS[TARGET])

In [None]:
# Load data
# Keep only rows whose "components" column equals 1 (presumably pure,
# single-component samples -- confirm against the dataset schema).
df = pl.read_parquet("../data/data.parquet").filter(pl.col("components") == 1)

# Build the feature matrix from the list-valued "X" column:
#   * slice(1) drops the first element of every list
#     (NOTE(review): assumes element 0 is not a spectral point -- confirm);
#   * [:, 76:-78] trims 76 leading and 78 trailing points
#     (NOTE(review): presumably noisy spectral edges -- confirm).
# Cropping before the float32 cast avoids converting discarded columns and
# gives a contiguous array; the float64 intermediate is not kept alive.
X = np.vstack(df["X"].list.slice(1).to_list())[:, 76:-78].astype(np.float32)

# Labels: per-sample counts of the target group, clipped so that every count
# >= UPPER_BOUND falls into the last class.
Y = df[TARGET].clip(upper_bound=UPPER_BOUND).to_numpy().astype(np.int64)

print(f"Dataset: {len(X)} samples, {X.shape[1]} features")
# minlength guarantees one entry per class, even when the highest classes
# have no samples (plain bincount would silently shorten the output).
print(f"Class distribution: {np.bincount(Y, minlength=NUM_CLASSES)}")

In [None]:
# Split data
# Project helper: returns train/test features and labels plus per-class
# weights (20% of samples held out for testing, presumably stratified by Y).
(X_train, X_test, Y_train, Y_test, class_weights) = stratified_split(
    X,
    Y,
    NUM_CLASSES,
    test_size=0.2,
    random_state=RANDOM_STATE,
)
print("Train: {}, Test: {}".format(len(X_train), len(X_test)))
print("Class weights: {}".format(class_weights.round(3)))

In [None]:
# Create dataloaders
# Wrap the train/test arrays into loaders using the configured batch size.
train_loader, test_loader = create_dataloaders(
    X_train,
    Y_train,
    X_test,
    Y_test,
    batch_size=BATCH_SIZE,
)

In [None]:
# Initialize model
# All architecture hyper-parameters in one place so they are easy to log or
# tweak; they are passed to CKANSpecNet unchanged.
model_kwargs = dict(
    input_size=X_train.shape[1],  # number of features after cropping
    num_classes=NUM_CLASSES,
    # Per-stage convolution settings (four stages, assumed to be parallel lists).
    conv_channels=[32, 64, 128, 256],
    conv_kernels=[15, 13, 11, 5],
    pool_sizes=[3, 2, None, None],
    eca_positions=[2, 3],
    adaptive_pool_size=64,
    fc_hidden=1024,
    kan_hidden=256,
    dropout_fc=0.7,
    dropout_head=0.3,
)
model = CKANSpecNet(**model_kwargs)

n_params = count_parameters(model)
print(f"Model parameters: {n_params:,}")
print(model)

In [None]:
# Initialize trainer
# Bundle the model, data loaders, device and class weights (from the split
# helper) together with the optimisation hyper-parameters.
trainer = Trainer(
    model=model,
    device=device,
    train_loader=train_loader,
    test_loader=test_loader,
    num_classes=NUM_CLASSES,
    class_weights=class_weights,
    # Optimisation hyper-parameters.
    learning_rate=1e-3,
    weight_decay=0.05,
)

In [None]:
# Train
# Runs up to EPOCHS epochs with patience=50 (presumably early stopping on the
# test metric), mixup augmentation and a smoothness penalty; returns the best
# accuracy reached.
best_acc = trainer.train(
    epochs=EPOCHS,
    patience=50,
    # Mixup: alpha=1.0, applied with probability 0.5.
    use_mixup=True,
    mixup_alpha=1.0,
    mixup_p=0.5,
    # Smoothness regulariser weighted by lambda_smooth.
    use_smoothness=True,
    lambda_smooth=0.01,
    verbose=True,
)
print(f"\nBest accuracy: {best_acc:.2f}%")

In [None]:
# Evaluate
# trainer.evaluate() yields a metrics dict plus true labels, predicted labels
# and predicted probabilities (meaning inferred from the names -- confirm).
metrics, y_true, y_pred, y_probs = trainer.evaluate()

print("\nMetrics:")
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

# Pass the full label set so every class gets a row even if it is absent from
# both y_true and y_pred, and report 0.0 (instead of emitting
# UndefinedMetricWarning) for classes the model never predicts.
print("\nClassification Report:")
print(
    classification_report(
        y_true,
        y_pred,
        labels=list(range(NUM_CLASSES)),
        digits=4,
        zero_division=0,
    )
)

In [None]:
# Confusion matrix
# Fix the label set to all NUM_CLASSES classes. Without `labels`,
# confusion_matrix only includes classes seen in y_true or y_pred, so a
# missing class would shrink the matrix and misalign it with class_names.
label_ids = list(range(NUM_CLASSES))
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
# Name each bin by the group count it represents. The last bin collects every
# count >= UPPER_BOUND, because labels are clipped to UPPER_BOUND at load time.
class_names = [str(i) for i in range(UPPER_BOUND)] + [f">={UPPER_BOUND}"]
plot_confusion_matrix(cm, class_names, title=f"{TARGET} Confusion Matrix")

In [None]:
# Save model
# Create the output directory first, in case trainer.save does not create
# missing parent directories itself (not visible here); a no-op if it exists.
from pathlib import Path

results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)
# Pass a plain string path, matching the type used by the original call.
trainer.save(str(results_dir / f"{TARGET}_model.pth"))
print("Model saved!")