# CRAIG

We observe how a coreset of the data curated by CRAIG performs when used to train a tiny transformer on sentiment classification, compared with training on the full dataset and on a random subset.

## Benchmarks

- [stanfordnlp/imdb](https://huggingface.co/datasets/stanfordnlp/imdb)
- [stanfordnlp/sst2](https://huggingface.co/datasets/stanfordnlp/sst2)

In [None]:
!git clone https://github.com/SamhithKakarla/Data-Selection-for-LLM-Training.git

import sys
sys.path.append('/content/Data-Selection-for-LLM-Training')

In [None]:
from datasets import load_dataset

# Load the two benchmark datasets by their Hugging Face Hub identifiers.
# sst2_dataset supplies the 'train' and 'validation' splits used by every
# training cell below.
sst2_dataset = load_dataset("stanfordnlp/sst2")
# NOTE(review): imdb_dataset is loaded here but not referenced by any other
# cell visible in this notebook.
imdb_dataset = load_dataset("stanfordnlp/imdb")

In [None]:
import os
import math
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset

from model import TinyGPT
from train import make_dataset, collate_fn, compute_accuracy

## Train on Full Dataset

In [None]:
# Hyperparameters for the full-dataset baseline: tokenizer name, checkpoint
# directory, sequence length, batch size, model width/depth, and optimiser
# settings. Every later statement in this experiment reads from this dict.
config = dict(
    tokenizer='gpt2',
    output_dir='./tiny_gpt_runs',
    max_len=64,
    batch_size=32,
    d_model=128,
    n_layers=4,
    n_heads=4,
    epochs=15,
    lr=3e-4,
)

# Checkpoints are written here once per epoch; create the directory up front.
os.makedirs(config['output_dir'], exist_ok=True)

# Train on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Data
def extract(d):
    """Split an iterable of examples into parallel text and label lists.

    Args:
        d: Iterable of mappings, each with a 'sentence' and a 'label' entry.

    Returns:
        A ``(sentences, labels)`` tuple of two lists in iteration order.
    """
    sentences, labels = [], []
    # Single pass: each row is materialised only once, and one-shot iterables
    # (generators, streaming splits) are not exhausted before the labels are read.
    for ex in d:
        sentences.append(ex['sentence'])
        labels.append(ex['label'])
    return sentences, labels

# Raw sentences and labels for the SST-2 train / validation splits.
train_texts, train_labels = extract(sst2_dataset['train'])
val_texts, val_labels = extract(sst2_dataset['validation'])

# Both loaders share the batch size and collate function; only the training
# loader reshuffles its examples.
loader_kwargs = {'batch_size': config['batch_size'], 'collate_fn': collate_fn}

train_loader = DataLoader(
    make_dataset(tokenizer, train_texts, train_labels, config['max_len']),
    shuffle=True,
    **loader_kwargs
)
val_loader = DataLoader(
    make_dataset(tokenizer, val_texts, val_labels, config['max_len']),
    shuffle=False,
    **loader_kwargs
)

# Size of the classifier head: one more than the largest label seen in either
# split (assumes labels are non-negative integer class ids — TODO confirm).
num_classes = 1 + max(train_labels + val_labels)

# Model
model = TinyGPT(
    num_classes=num_classes,
    vocab_size=tokenizer.vocab_size,
    max_len=config['max_len'],
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    n_heads=config['n_heads']
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

param_count = sum(p.numel() for p in model.parameters())
print("Params:", param_count)

# Training loop. One metrics dict per epoch is collected in full_epoches so the
# plotting cells can read the curves later.
BATCH_KEYS = ('input_ids', 'attention_mask', 'labels')
full_epoches = []

for epoch in range(1, config['epochs'] + 1):
    # ---- optimisation pass over the training loader ----
    model.train()
    train_loss = train_acc = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

        logits = model(input_ids, attention_mask=mask)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running sums of the per-batch loss and accuracy; averaged below.
        train_loss += loss.item()
        train_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # ---- evaluation pass, gradients disabled ----
    model.eval()
    val_loss = val_acc = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

            logits = model(input_ids, attention_mask=mask)
            val_loss += criterion(logits, labels).item()
            val_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # Per-batch averages (each batch weighs the same, whatever its size).
    epoch_metrics = {
        "train_loss": train_loss / len(train_loader),
        "train_acc": train_acc / len(train_loader),
        "val_loss": val_loss / len(val_loader),
        "val_acc": val_acc / len(val_loader),
    }

    print(
        f"Epoch {epoch}: "
        f"train_loss={epoch_metrics['train_loss']:.4f}, "
        f"train_acc={epoch_metrics['train_acc']:.4f}, "
        f"val_loss={epoch_metrics['val_loss']:.4f}, "
        f"val_acc={epoch_metrics['val_acc']:.4f}"
    )

    full_epoches.append(epoch_metrics)

    # Keep a checkpoint of the weights after every epoch.
    torch.save(model.state_dict(), f"{config['output_dir']}/epoch{epoch}.pt")

print("Done.")

## Train Random Selection

In [None]:
# Hyperparameters for the random-subset baseline. They match the full-data run
# except for the checkpoint directory.
config = {
    'tokenizer': 'gpt2',
    # Dedicated sub-directory: every experiment saves files named
    # epoch{N}.pt, so sharing one directory would silently overwrite the
    # checkpoints written by the other runs.
    'output_dir': './tiny_gpt_runs/random',
    'max_len': 64,
    'batch_size': 32,
    'd_model': 128,
    'n_layers': 4,
    'n_heads': 4,
    'epochs': 15,
    'lr': 3e-4,
}

# makedirs creates the intermediate ./tiny_gpt_runs directory too if needed.
os.makedirs(config['output_dir'], exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

# Data
def extract(d):
    """Split an iterable of examples into parallel text and label lists.

    Args:
        d: Iterable of mappings, each with a 'sentence' and a 'label' entry.

    Returns:
        A ``(sentences, labels)`` tuple of two lists in iteration order.
    """
    sentences, labels = [], []
    # Single pass: each row is materialised only once, and one-shot iterables
    # (generators, streaming splits) are not exhausted before the labels are read.
    for ex in d:
        sentences.append(ex['sentence'])
        labels.append(ex['label'])
    return sentences, labels

# Raw sentences and labels for the SST-2 train / validation splits.
train_texts, train_labels = extract(sst2_dataset['train'])
val_texts, val_labels = extract(sst2_dataset['validation'])

full_dataset = make_dataset(tokenizer, train_texts, train_labels, config['max_len'])

# ---- Random subset selection ----
# Shuffle all example indices and keep the first 30% of them. No seed is set
# here, so a different subset is drawn on every execution.
n_total = len(full_dataset)
budget = int(0.30 * n_total)   # keep 30%
indices = torch.randperm(n_total)[:budget].tolist()

subset_dataset = torch.utils.data.Subset(full_dataset, indices)

# Both loaders share the batch size and collate function; training draws from
# the random subset, validation from the full validation split.
loader_kwargs = {'batch_size': config['batch_size'], 'collate_fn': collate_fn}

train_loader = DataLoader(subset_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(
    make_dataset(tokenizer, val_texts, val_labels, config['max_len']),
    shuffle=False,
    **loader_kwargs
)

# Size of the classifier head: one more than the largest label seen in either
# split (assumes labels are non-negative integer class ids — TODO confirm).
num_classes = 1 + max(train_labels + val_labels)

# Model
model = TinyGPT(
    num_classes=num_classes,
    vocab_size=tokenizer.vocab_size,
    max_len=config['max_len'],
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    n_heads=config['n_heads']
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

param_count = sum(p.numel() for p in model.parameters())
print("Params:", param_count)
print(f"Random subset size: {len(subset_dataset)} / {len(full_dataset)}")

# Training loop. One metrics dict per epoch is collected in random_epoches so
# the plotting cells can read the curves later.
BATCH_KEYS = ('input_ids', 'attention_mask', 'labels')
random_epoches = []

for epoch in range(1, config['epochs'] + 1):
    # ---- optimisation pass over the random subset ----
    model.train()
    train_loss = train_acc = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

        logits = model(input_ids, attention_mask=mask)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running sums of the per-batch loss and accuracy; averaged below.
        train_loss += loss.item()
        train_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # ---- evaluation pass, gradients disabled ----
    model.eval()
    val_loss = val_acc = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

            logits = model(input_ids, attention_mask=mask)
            val_loss += criterion(logits, labels).item()
            val_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # Per-batch averages (each batch weighs the same, whatever its size).
    epoch_metrics = {
        "train_loss": train_loss / len(train_loader),
        "train_acc": train_acc / len(train_loader),
        "val_loss": val_loss / len(val_loader),
        "val_acc": val_acc / len(val_loader),
    }

    print(
        f"Epoch {epoch}: "
        f"train_loss={epoch_metrics['train_loss']:.4f}, "
        f"train_acc={epoch_metrics['train_acc']:.4f}, "
        f"val_loss={epoch_metrics['val_loss']:.4f}, "
        f"val_acc={epoch_metrics['val_acc']:.4f}"
    )

    random_epoches.append(epoch_metrics)

    # Keep a checkpoint of the weights after every epoch.
    torch.save(model.state_dict(), f"{config['output_dir']}/epoch{epoch}.pt")

print("Done.")

## Train CRAIG

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def compute_representations(model, dataloader, device):
    """
    Collect one representation vector per example for coreset selection.

    Lightweight CRAIG approximation: the model's output logits stand in for
    the per-example embeddings. A richer representation can be substituted
    later without changing the callers.

    The model is switched to eval mode and left there; gradients are disabled
    while the batches are processed. Returns a single CPU tensor holding the
    per-batch outputs concatenated along dim 0, in dataloader order
    (presumably [N, num_classes] — depends on the model's output shape).
    """
    model.eval()

    with torch.no_grad():
        # Each batch's logits are moved to the CPU straight away so that only
        # one batch of outputs lives on the device at a time.
        per_batch = [
            model(
                batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).detach().cpu()
            for batch in tqdm(dataloader, desc="Computing representations")
        ]

    return torch.cat(per_batch, dim=0)


def kcenter_coreset(reps, budget, generator=None):
    """
    K-center greedy selection in embedding space.

    Starting from a random point, repeatedly adds the point whose distance to
    its nearest already-selected center is largest.

    Args:
        reps: [N, D] tensor of per-example representations.
        budget: number of points to select. Values above N are clamped to N;
            values <= 0 select nothing.
        generator: optional ``torch.Generator`` used to draw the starting
            point, for reproducible selections. ``None`` uses the global RNG.

    Returns:
        List of distinct row indices into ``reps`` (Python ints), in
        selection order, of length ``min(budget, N)`` (empty when that is
        <= 0).

    Raises:
        ValueError: if ``reps`` is not 2-dimensional.
    """
    if reps.dim() != 2:
        raise ValueError(f"reps must be a 2-D [N, D] tensor, got shape {tuple(reps.shape)}")

    reps = reps.to(torch.float32)
    N = reps.shape[0]

    # There are only N distinct points to choose from.
    budget = min(budget, N)
    if budget <= 0:
        return []

    # start with a random point
    first = torch.randint(0, N, (1,), generator=generator).item()
    selected = [first]

    # distances to nearest selected center
    dist = torch.cdist(reps, reps[first:first + 1]).squeeze(1)  # [N]
    # Already-selected points are masked with -inf so argmax can never return
    # them again. Without the mask, ties at distance 0 (duplicate embeddings)
    # make argmax re-pick an existing center and the result has duplicates.
    dist[first] = float("-inf")

    for _ in range(1, budget):
        # pick the unselected point farthest from any selected center
        idx = torch.argmax(dist).item()
        selected.append(idx)

        # update distance: min(current_dist, dist_to_new_center);
        # -inf entries of earlier centers survive the minimum.
        new_dist = torch.cdist(reps, reps[idx:idx + 1]).squeeze(1)
        dist = torch.minimum(dist, new_dist)
        dist[idx] = float("-inf")

    return selected


In [None]:
# Hyperparameters for the CRAIG (k-center coreset) run. They match the
# full-data run except for the checkpoint directory.
config = {
    'tokenizer': 'gpt2',
    # Dedicated sub-directory: every experiment saves files named
    # epoch{N}.pt, so sharing one directory would silently overwrite the
    # checkpoints written by the other runs.
    'output_dir': './tiny_gpt_runs/craig',
    'max_len': 64,
    'batch_size': 32,
    'd_model': 128,
    'n_layers': 4,
    'n_heads': 4,
    'epochs': 15,
    'lr': 3e-4,
}

# makedirs creates the intermediate ./tiny_gpt_runs directory too if needed.
os.makedirs(config['output_dir'], exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

# Data
def extract(d):
    """Split an iterable of examples into parallel text and label lists.

    Args:
        d: Iterable of mappings, each with a 'sentence' and a 'label' entry.

    Returns:
        A ``(sentences, labels)`` tuple of two lists in iteration order.
    """
    sentences, labels = [], []
    # Single pass: each row is materialised only once, and one-shot iterables
    # (generators, streaming splits) are not exhausted before the labels are read.
    for ex in d:
        sentences.append(ex['sentence'])
        labels.append(ex['label'])
    return sentences, labels

# Raw sentences and labels for the SST-2 train / validation splits.
train_texts, train_labels = extract(sst2_dataset['train'])
val_texts, val_labels = extract(sst2_dataset['validation'])

full_dataset = make_dataset(tokenizer, train_texts, train_labels, config['max_len'])

# All loaders share the batch size and collate function.
loader_kwargs = {'batch_size': config['batch_size'], 'collate_fn': collate_fn}

# Unshuffled pass over the full training set, so that row i of the
# representation matrix corresponds to example i of full_dataset.
rep_loader = DataLoader(full_dataset, shuffle=False, **loader_kwargs)

# Validation loader over the untouched validation split.
val_loader = DataLoader(
    make_dataset(tokenizer, val_texts, val_labels, config['max_len']),
    shuffle=False,
    **loader_kwargs
)

# Size of the classifier head: one more than the largest label seen in either
# split (assumes labels are non-negative integer class ids — TODO confirm).
num_classes = 1 + max(train_labels + val_labels)

# Model
model = TinyGPT(
    num_classes=num_classes,
    vocab_size=tokenizer.vocab_size,
    max_len=config['max_len'],
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    n_heads=config['n_heads']
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

param_count = sum(p.numel() for p in model.parameters())
print("Params:", param_count)

# Lightweight CRAIG (k-center coreset).
# NOTE(review): the representations come from the model exactly as it was
# constructed above — no training step has run yet at this point.
print("Computing representations for lightweight CRAIG...")
reps = compute_representations(model, rep_loader, device)

# NOTE(review): this keeps 20%, whereas the random-subset cell keeps 30% —
# confirm the budgets match before comparing the two curves.
budget = int(0.20 * len(full_dataset))   # keep 20% of SST-2
print(f"Selecting {budget} samples with k-center coreset...")

subset_indices = kcenter_coreset(reps, budget)
subset_dataset = torch.utils.data.Subset(full_dataset, subset_indices)

# Training draws (shuffled) from the selected coreset only.
train_loader = DataLoader(subset_dataset, shuffle=True, **loader_kwargs)

print(f"Subset size: {len(subset_dataset)} / {len(full_dataset)}")

# Training loop. One metrics dict per epoch is collected in craig_epoches so
# the plotting cells can read the curves later.
BATCH_KEYS = ('input_ids', 'attention_mask', 'labels')
craig_epoches = []

for epoch in range(1, config['epochs'] + 1):
    # ---- optimisation pass over the coreset ----
    model.train()
    train_loss = train_acc = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

        logits = model(input_ids, attention_mask=mask)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running sums of the per-batch loss and accuracy; averaged below.
        train_loss += loss.item()
        train_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # ---- evaluation pass, gradients disabled ----
    model.eval()
    val_loss = val_acc = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids, mask, labels = (batch[key].to(device) for key in BATCH_KEYS)

            logits = model(input_ids, attention_mask=mask)
            val_loss += criterion(logits, labels).item()
            val_acc += compute_accuracy(logits.argmax(dim=-1), labels)

    # Per-batch averages (each batch weighs the same, whatever its size).
    epoch_metrics = {
        "train_loss": train_loss / len(train_loader),
        "train_acc": train_acc / len(train_loader),
        "val_loss": val_loss / len(val_loader),
        "val_acc": val_acc / len(val_loader),
    }

    print(f"Epoch {epoch}: "
          f"train_loss={epoch_metrics['train_loss']:.4f}, "
          f"train_acc={epoch_metrics['train_acc']:.4f}, "
          f"val_loss={epoch_metrics['val_loss']:.4f}, "
          f"val_acc={epoch_metrics['val_acc']:.4f}")

    craig_epoches.append(epoch_metrics)

    # Keep a checkpoint of the weights after every epoch.
    torch.save(model.state_dict(), f"{config['output_dir']}/epoch{epoch}.pt")

print("Done.")

## Visualize Results

In [None]:
import matplotlib.pyplot as plt

# Per-epoch accuracy curves recorded by the full-data and CRAIG training cells.
full_train_acc = [entry["train_acc"] for entry in full_epoches]
full_val_acc = [entry["val_acc"] for entry in full_epoches]
craig_train_acc = [entry["train_acc"] for entry in craig_epoches]
craig_val_acc = [entry["val_acc"] for entry in craig_epoches]

# The x axis is sized from the full-data run, so both runs must have recorded
# the same number of epochs for the CRAIG curves to plot.
epochs = list(range(1, len(full_epoches) + 1))

plt.figure(figsize=(7, 4))

# Draw the four curves in a fixed order so colours and legend entries are stable.
curves = (
    (full_train_acc, 'Full Train Accuracy'),
    (full_val_acc, 'Full Val Accuracy'),
    (craig_train_acc, 'CRAIG Train Accuracy'),
    (craig_val_acc, 'CRAIG Val Accuracy'),
)
for series, curve_label in curves:
    plt.plot(epochs, series, marker='o', label=curve_label)

# One tick per epoch; the y range is fixed to [0.4, 1.0].
plt.xticks(epochs)
plt.ylim(0.4, 1.0)

plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Train vs Validation Accuracy")
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.show()

## Visualize Random vs CRAIG

In [None]:
# Alias the histories from the most recent random-subset and CRAIG runs as the
# 30% results used by the comparison plot.
# NOTE(review): the CRAIG cell in this notebook sets its budget to 20% of the
# training set while the random cell uses 30% — confirm craig_epoches really
# came from a 30% run before labelling it as such here.
random_30_epoches = random_epoches
craig_30_epoches = craig_epoches
# NOTE(review): random_5_epoches ... random_25_epoches are not assigned anywhere
# in the visible notebook; presumably they were stored by hand after re-running
# the random cell with other budgets. On a fresh top-to-bottom run these prints
# raise NameError unless those names are defined first.
print(random_5_epoches)
print(random_10_epoches)
print(random_15_epoches)
print(random_20_epoches)
print(random_25_epoches)
print(random_30_epoches)

In [None]:
import matplotlib.pyplot as plt

# Subset sizes (percent of the training set) compared in the grid, one per panel.
Xs = [5, 10, 15, 20, 25, 30]

# Map each subset percentage to the per-epoch history recorded for it.
# NOTE(review): apart from the two *_30_epoches aliases, none of these history
# variables is assigned in the visible notebook — presumably they were kept
# from earlier manual re-runs with other budgets; verify they exist.
random_sets = dict(zip(Xs, (
    random_5_epoches,
    random_10_epoches,
    random_15_epoches,
    random_20_epoches,
    random_25_epoches,
    random_30_epoches,
)))

craig_sets = dict(zip(Xs, (
    craig_5_epoches,
    craig_10_epoches,
    craig_15_epoches,
    craig_20_epoches,
    craig_25_epoches,
    craig_30_epoches,
)))

# 2 x 3 grid of panels, flattened so panels pair up with Xs in order.
fig, grid = plt.subplots(2, 3, figsize=(14, 8))
axes = grid.flatten()

for ax, X in zip(axes, Xs):
    # Validation-accuracy curves for this subset size.
    r_vals = [e["val_acc"] for e in random_sets[X]]
    c_vals = [e["val_acc"] for e in craig_sets[X]]

    # The x axis is sized from the random run; the CRAIG run must have the
    # same number of epochs to be plotted against it.
    epochs = list(range(1, len(r_vals) + 1))

    # Random first, then CRAIG, so line colours are consistent across panels.
    for series, curve_label in ((r_vals, f"Random {X}%"), (c_vals, f"CRAIG {X}%")):
        ax.plot(epochs, series, marker='o', label=curve_label, linewidth=2)

    # Panel formatting; the y range is fixed to [0.5, 0.8] for every panel.
    ax.set_title(f"{X}% Subset")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Val Accuracy")
    ax.set_ylim(0.5, 0.8)
    ax.set_xticks(epochs)
    ax.grid(True, linestyle='--', alpha=0.4)
    ax.legend()

plt.tight_layout()
plt.show()
