# GNN Timetabling — Full Notebook

This notebook trains a **simple Graph Neural Network (GNN)** in PyTorch to predict timetables from JSON instances stored in `../dataset/instance_*.json` and writes predictions back in the same JSON-like timetable format.

**High-level flow**:
1. Load JSON instances
2. Convert each instance into per-(class,day,slot) examples and build a global vocabulary of feasible `(subject,teacher)` options
3. Train a supervised model to classify each slot into one of the vocab options
4. Decode predictions back into a timetable JSON and save to `/mnt/data/predicted`

> The model is a practical starter template. Replace or extend the GNN and decoding logic for better performance and constraint handling.


In [1]:

# Basic imports and environment checks
import os, glob, json, math, random, copy
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
from collections import defaultdict, Counter
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pprint
print('Python and PyTorch ready. torch:', torch.__version__)


Python and PyTorch ready. torch: 2.7.1+cu118


## Configuration
Set paths, device, and hyperparameters here. Adjust `DATABASE_DIR` to point at your JSON files (the path is relative to the notebook).

In [2]:
# --- Configuration ---
DATABASE_DIR = '../dataset'  # Where instance_001.json ... live
PREDICT_DIR = '/mnt/data/predicted'
os.makedirs(PREDICT_DIR, exist_ok=True)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
EPOCHS = 12
EMBED_DIM = 128
LR = 1e-3
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
print('Device:', DEVICE)

Device: cuda


## 1) Load instances
Load JSON instances from `DATABASE_DIR` and show a summary of the first one. The notebook expects files named like `instance_001.json`. If none are found, add example JSON files there and re-run.
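
For orientation, here is a minimal sketch of what one instance might look like. The field names are inferred from the keys the notebook's code reads (`input`, `output`, `class_subjects`, `teacher_subjects`, `subject_sessions_per_week`); the values are invented, and real files may carry additional fields.

```python
import json

# Hypothetical toy instance: field names mirror those read by the notebook's
# code, but every value here is invented for illustration.
toy_instance = {
    "input": {
        "num_classes": 1,
        "num_days": 2,
        "slots_per_day": 2,
        "num_subjects": 2,
        "num_teachers": 2,
        # JSON object keys are always strings, so ids arrive as "0", "1", ...
        "class_subjects": {"0": [0, 1]},
        "teacher_subjects": {"0": [0], "1": [0, 1]},
        # Tuple-like string keys mapped to a weekly session count
        "subject_sessions_per_week": {"(0, 0)": 2, "(0, 1)": 2},
    },
    # output[class][day][slot] = [subject, teacher]
    "output": {"0": [[[0, 0], [1, 1]], [[1, 1], [0, 1]]]},
}

# Round-trip through JSON text, as a file written to DATABASE_DIR would be
text = json.dumps(toy_instance, indent=2)
restored = json.loads(text)
print(sorted(restored.keys()))
print(restored["output"]["0"][0][1])
```

Because JSON keys are strings, the conversion code has to cast class and teacher ids back to `int` and parse the tuple-like keys by hand.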

In [3]:
def load_instances(db_dir):
    paths = sorted(glob.glob(os.path.join(db_dir, 'instance_*.json')))
    instances = []
    for p in paths:
        with open(p, 'r') as f:
            try:
                instances.append(json.load(f))
            except Exception as e:
                print('Failed to load', p, '-', e)
    return instances

instances = load_instances(DATABASE_DIR)
print(f'Found {len(instances)} instances in', DATABASE_DIR)
if instances:
    print('\nExample instance keys:', list(instances[0].keys()))
    print('\nInput summary:')
    # Skip the large per-class/per-teacher mappings to keep the summary short
    skip_keys = {'class_subjects', 'teacher_subjects', 'subject_sessions_per_week'}
    pprint.pprint({k: v for k, v in instances[0]['input'].items() if k not in skip_keys})
else:
    print('No instances found. Place JSON files in', DATABASE_DIR)


Found 200 instances in ../dataset

Example instance keys: ['input', 'output']

Input summary:
{'num_classes': 4,
 'num_days': 6,
 'num_subjects': 10,
 'num_teachers': 7,
 'slots_per_day': 7}


## 2) Convert instances → per-slot examples
We transform each instance into examples for every (class, day, slot). Each example contains the feasible `(subject,teacher)` options and the ground-truth label if present.

In [4]:

def parse_sspw(sspw_raw):
    # Converts keys like '(0, 6)' or '0,6' to tuple keys
    out = {}
    for k,v in sspw_raw.items():
        ks = k
        if isinstance(k, str):
            ks = k.strip()
            if ks.startswith('(') and ks.endswith(')'):
                ks = ks[1:-1]
            parts = [p.strip() for p in ks.split(',')]
            if len(parts) == 2:
                a,b = int(parts[0]), int(parts[1])
                out[(a,b)] = int(v)
            else:
                # fallback, ignore
                pass
        else:
            # unexpected format
            pass
    return out

def instance_to_examples(instance):
    inp = instance['input']
    out = instance.get('output', {})
    num_classes = int(inp['num_classes'])
    num_days = int(inp['num_days'])
    slots_per_day = int(inp['slots_per_day'])
    # ensure keys are ints for class_subjects and teacher_subjects
    class_subjects = {int(k): v for k,v in inp['class_subjects'].items()}
    teacher_subjects = {int(k): v for k,v in inp['teacher_subjects'].items()}
    sspw = parse_sspw(inp.get('subject_sessions_per_week', {}))
    examples = []
    for c in range(num_classes):
        class_grid = out.get(str(c), None)
        for d in range(num_days):
            for s in range(slots_per_day):
                label = None
                if class_grid is not None:
                    try:
                        pair = class_grid[d][s]
                        label = (int(pair[0]), int(pair[1]))
                    except Exception:
                        label = (-1,-1)
                # feasible options: subject in class_subjects[c] and teacher qualified
                options = []
                for subj in class_subjects.get(c, []):
                    for t, quals in teacher_subjects.items():
                        if subj in quals:
                            options.append((int(subj), int(t)))
                if not options:
                    options = [(-1,-1)]
                examples.append({
                    'class': c,
                    'day': d,
                    'slot': s,
                    'options': options,
                    'label': label
                })
    return examples

# quick sanity check (if instances available)
if instances:
    exs = instance_to_examples(instances[0])
    print('Example count for instance 0:', len(exs))
    print('First example:', exs[0])
else:
    print('No instances to convert yet.')


Example count for instance 0: 168
First example: {'class': 0, 'day': 0, 'slot': 0, 'options': [(3, 1), (3, 2), (3, 3), (3, 6), (8, 0), (8, 1), (8, 2), (8, 3), (1, 0), (1, 1), (1, 2), (1, 3), (7, 1), (7, 2), (7, 3), (6, 1), (6, 2), (6, 5), (6, 6), (0, 0), (0, 1), (0, 2), (0, 4), (2, 1), (2, 2)], 'label': (6, 5)}
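
The example count printed above follows directly from the instance sizes: one example per (class, day, slot) triple. A small sketch of that arithmetic, using the sizes shown in the input summary for instance 0:

```python
from itertools import product

# Sizes printed in the input summary for instance 0
num_classes, num_days, slots_per_day = 4, 6, 7

# One example per (class, day, slot) triple, in the same nesting order
# as the loops in instance_to_examples
triples = list(product(range(num_classes), range(num_days), range(slots_per_day)))
print(len(triples))
print(triples[0], triples[-1])
```

Here `len(triples)` is 4 × 6 × 7 = 168, which matches the count reported for instance 0.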


## 3) Build global vocabulary of (subject,teacher) options
We'll create a mapping from observed `(subject,teacher)` pairs to integer class labels the model will predict.
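
As a rough illustration of the idea, the sketch below builds a vocabulary from a handful of made-up pairs. The names `observed`, `toy_vocab`, and `toy_vocab_inv` are hypothetical and exist only for this example.

```python
# Hypothetical (subject, teacher) pairs collected from a few slots
observed = [(3, 1), (0, 2), (3, 1), (0, 0), (1, 2)]

# Deduplicate and sort so the pair -> index mapping is deterministic
pairs = sorted(set(observed))
toy_vocab = {pair: idx for idx, pair in enumerate(pairs)}
toy_vocab_inv = {idx: pair for pair, idx in toy_vocab.items()}

print(toy_vocab)
# Pairs missing from the vocabulary map to -1, mirroring vocab.get(..., -1)
print(toy_vocab.get((9, 9), -1))
```

Sorting before enumerating keeps label indices stable across runs, which matters when a saved model is later decoded with a rebuilt vocabulary.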

In [5]:
def build_option_vocab(all_instances):
    counter = Counter()
    for inst in all_instances:
        for ex in instance_to_examples(inst):
            for opt in ex['options']:
                counter[tuple(opt)] += 1
            if ex['label'] is not None and ex['label'] != (-1,-1):
                counter[tuple(ex['label'])] += 1
    # sort to have deterministic ordering
    opts = sorted(counter.keys(), key=lambda x:(x[0], x[1]))
    vocab = {opt:i for i,opt in enumerate(opts)}
    vocab_inv = {i:opt for opt,i in vocab.items()}
    return vocab, vocab_inv

if instances:
    vocab, vocab_inv = build_option_vocab(instances)
    print('Vocab size:', len(vocab))
    print('Some vocab items:', list(vocab.items())[:10])
else:
    vocab, vocab_inv = {}, {}
    print('Empty vocab (no instances)')


Vocab size: 120
Some vocab items: [((0, 0), 0), ((0, 1), 1), ((0, 2), 2), ((0, 3), 3), ((0, 4), 4), ((0, 5), 5), ((0, 6), 6), ((0, 7), 7), ((0, 8), 8), ((0, 9), 9)]


## 4) Dataset and DataLoader
The Dataset yields per-slot samples. We include metadata in each sample so the model can access instance-level sizes.

In [6]:
class TimetablingDataset(Dataset):
    def __init__(self, instances, vocab):
        self.rows = []
        self.vocab = vocab
        for inst in instances:
            exs = instance_to_examples(inst)
            for ex in exs:
                self.rows.append((inst['input'], ex))
    def __len__(self):
        return len(self.rows)
    def __getitem__(self, idx):
        inp, ex = self.rows[idx]
        sample = {
            'meta': inp,
            'class': ex['class'],
            'day': ex['day'],
            'slot': ex['slot'],
            'options': ex['options'],
            'label': ex['label']
        }
        # map label to vocab idx
        if sample['label'] is None:
            sample['label_idx'] = -1
        else:
            sample['label_idx'] = self.vocab.get(tuple(sample['label']), -1)
        return sample

def collate_fn(batch):
    meta = batch[0]['meta']
    num_teachers = int(meta['num_teachers'])
    class_ids = torch.tensor([b['class'] for b in batch], dtype=torch.long)
    day_ids = torch.tensor([b['day'] for b in batch], dtype=torch.long)
    slot_ids = torch.tensor([b['slot'] for b in batch], dtype=torch.long)
    label_idxs = torch.tensor([b['label_idx'] for b in batch], dtype=torch.long)
    options = [b['options'] for b in batch]
    meta_out = {'num_teachers': num_teachers}
    return {
        'class_ids': class_ids,
        'day_ids': day_ids,
        'slot_ids': slot_ids,
        'label_idxs': label_idxs,
        'options': options,
        'meta': meta_out,
        'raw_batch': batch
    }

# quick dataset test
if instances:
    dataset = TimetablingDataset(instances, vocab)
    loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
    batch = next(iter(loader))
    print('Batch keys:', list(batch.keys()))
    print('class_ids shape:', batch['class_ids'].shape)
else:
    print('No dataset to create.')
    


Batch keys: ['class_ids', 'day_ids', 'slot_ids', 'label_idxs', 'options', 'meta', 'raw_batch']
class_ids shape: torch.Size([8])


## 5) Simple GNN model
A straightforward embedding-based model: it concatenates class, day, and slot embeddings with the mean of the teacher embeddings, passes the result through a small MLP, and outputs logits over the global option vocabulary. Despite the name, it performs no message passing over a graph.

In [7]:

class SimpleGNN(nn.Module):
    def __init__(self, num_classes, num_teachers, num_days, slots_per_day, embed_dim, vocab_size):
        super().__init__()
        self.class_emb = nn.Embedding(max(1,num_classes), embed_dim)
        self.day_emb = nn.Embedding(max(1,num_days), embed_dim)
        self.slot_emb = nn.Embedding(max(1,slots_per_day), embed_dim)
        self.teacher_emb = nn.Embedding(max(1, num_teachers), embed_dim)
        self.msg_fc = nn.Sequential(
            nn.Linear(embed_dim*4, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim//2),
            nn.ReLU(),
            nn.Linear(embed_dim//2, max(1, vocab_size))
        )

    def forward(self, class_ids, day_ids, slot_ids, meta):
        c_emb = self.class_emb(class_ids)
        d_emb = self.day_emb(day_ids)
        s_emb = self.slot_emb(slot_ids)
        # teacher aggregate (mean of teacher embeddings)
        t_count = meta['num_teachers']
        t_idx = torch.arange(t_count, device=class_ids.device)
        t_emb = self.teacher_emb(t_idx)  # [T, D]
        t_mean = t_emb.mean(dim=0, keepdim=True)  # [1, D]
        t_mean_expand = t_mean.expand(c_emb.size(0), -1)
        x = torch.cat([c_emb, d_emb, s_emb, t_mean_expand], dim=1)
        x = self.msg_fc(x)
        logits = self.classifier(x)
        return logits


In [8]:
def dataset_index_checks(dataset, max_class_id, max_day_id, max_slot_id):
    problems = []
    for i, x in enumerate(dataset):
        for name, allowed_max in [('class', max_class_id),
                                  ('day', max_day_id),
                                  ('slot', max_slot_id)]:
            if name not in x:
                problems.append((i, f"missing key {name}"))
                continue
            v = x[name]
            if isinstance(v, int):
                minv = maxv = v
            else:
                v_t = torch.as_tensor(v)
                if v_t.numel() == 0:
                    problems.append((i, f"{name} empty"))
                    continue
                minv, maxv = int(v_t.min().item()), int(v_t.max().item())
            if minv < 0 or maxv > allowed_max:
                problems.append((i, f"{name} indices {minv}-{maxv} out of allowed 0-{allowed_max}"))

        if 'label_idx' in x:
            lblv = int(x['label_idx']) if not isinstance(x['label_idx'], (list, tuple)) else int(x['label_idx'][0])
            if lblv < 0:
                problems.append((i, f"negative label_idx {lblv}"))
    return problems

# Use your earlier printed maxima
probs = dataset_index_checks(dataset, max_class_id=7, max_day_id=5, max_slot_id=7)
print(probs[:50] if probs else "No problems found")


[(2, 'negative label_idx -1'), (5, 'negative label_idx -1'), (7, 'negative label_idx -1'), (10, 'negative label_idx -1'), (17, 'negative label_idx -1'), (22, 'negative label_idx -1'), (27, 'negative label_idx -1'), (43, 'negative label_idx -1'), (44, 'negative label_idx -1'), (46, 'negative label_idx -1'), (47, 'negative label_idx -1'), (48, 'negative label_idx -1'), (49, 'negative label_idx -1'), (50, 'negative label_idx -1'), (51, 'negative label_idx -1'), (52, 'negative label_idx -1'), (53, 'negative label_idx -1'), (56, 'negative label_idx -1'), (59, 'negative label_idx -1'), (60, 'negative label_idx -1'), (61, 'negative label_idx -1'), (64, 'negative label_idx -1'), (65, 'negative label_idx -1'), (67, 'negative label_idx -1'), (68, 'negative label_idx -1'), (69, 'negative label_idx -1'), (72, 'negative label_idx -1'), (73, 'negative label_idx -1'), (79, 'negative label_idx -1'), (82, 'negative label_idx -1'), (83, 'negative label_idx -1'), (84, 'negative label_idx -1'), (85, 'nega

## 6) Training loop
Train the model with CrossEntropy over the global vocab. We ignore unknown labels (`-1`) during loss computation.
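
To make "ignoring unknown labels" concrete, here is a dependency-free sketch of a mean cross-entropy that skips rows labelled `-1`. It is only an illustration of the semantics; in PyTorch the same effect is normally obtained with `nn.CrossEntropyLoss(ignore_index=-1)`. The function name `cross_entropy_ignoring` is hypothetical.

```python
import math

def cross_entropy_ignoring(logits, labels, ignore_index=-1):
    """Mean cross-entropy over rows whose label is not ignore_index."""
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unknown label: contributes nothing to the loss
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[label])
    # No valid labels at all: return NaN so the caller can skip the batch
    return sum(losses) / len(losses) if losses else float("nan")

logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
labels = [0, -1, 1]
loss_all = cross_entropy_ignoring(logits, labels)
loss_valid_only = cross_entropy_ignoring([logits[0], logits[2]], [0, 1])
print(round(loss_all, 4), round(loss_valid_only, 4))
```

The two printed values agree: the row labelled `-1` affects neither the sum nor the denominator. Clamping `-1` to `0` instead would silently train the model to predict vocabulary entry 0 for those slots.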

In [19]:
def train_model(instances, vocab, num_epochs=EPOCHS, batch_size=BATCH_SIZE):
    dataset = TimetablingDataset(instances, vocab)
    # collate_fn is required: each sample carries dicts and variable-length option lists
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # Size the embedding tables from the largest ids seen across all instances
    max_class_id = max(ex['class'] for _, ex in dataset.rows)
    max_day_id   = max(ex['day'] for _, ex in dataset.rows)
    max_slot_id  = max(ex['slot'] for _, ex in dataset.rows)
    num_teachers = max(int(inst['input']['num_teachers']) for inst in instances)
    num_labels   = len(vocab)

    print(f"max_class_id: {max_class_id}")
    print(f"max_day_id: {max_day_id}")
    print(f"max_slot_id: {max_slot_id}")
    print(f"num_labels (output classes): {num_labels}")

    # Keyword names must match SimpleGNN.__init__ exactly
    model = SimpleGNN(
        num_classes=max_class_id + 1,
        num_teachers=num_teachers,
        num_days=max_day_id + 1,
        slots_per_day=max_slot_id + 1,
        embed_dim=EMBED_DIM,
        vocab_size=num_labels,
    ).to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=LR)
    # ignore_index=-1 drops slots whose label is missing or outside the vocabulary
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    for epoch in range(num_epochs):
        model.train()
        total_loss, num_batches = 0.0, 0
        for batch in dataloader:
            labels = batch['label_idxs'].to(DEVICE)
            if not (labels >= 0).any():
                continue  # no usable labels in this batch; the mean loss would be NaN
            class_ids = batch['class_ids'].to(DEVICE)
            day_ids   = batch['day_ids'].to(DEVICE)
            slot_ids  = batch['slot_ids'].to(DEVICE)

            # Forward pass
            logits = model(class_ids, day_ids, slot_ids, batch['meta'])
            loss = criterion(logits, labels)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {total_loss / max(1, num_batches):.4f}")

    return model, dataset, vocab


# Run training if instances exist
if instances:
    model, dataset, vocab = train_model(instances, vocab)
else:
    model, dataset = None, None

max_class_id: 7
max_day_id: 5
max_slot_id: 7
num_labels (output classes): 120

## 7) Inference and export
Predict for each slot and write predictions to JSON files under `/mnt/data/predicted/` with names `predicted_instance_000.json`, etc.

**Note**: this decoding simply picks the most probable `(subject,teacher)` pair from the global vocab for each slot. For stricter feasibility, add masking and a repair step.
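
One possible shape for the masking step is sketched below in plain Python. It assumes the per-slot logits are available as a list and that `options` is the feasible list produced for that slot; `masked_argmax` and the toy vocabulary are hypothetical names for this example, not part of the notebook.

```python
def masked_argmax(logits, options, vocab):
    """Return the best vocabulary index among the feasible options for one slot."""
    feasible = {vocab[o] for o in options if o in vocab}
    if not feasible:
        return -1  # nothing feasible: the caller can emit (-1, -1)
    # Restricting the argmax to feasible indices is equivalent to setting
    # every other logit to -inf before taking the argmax
    return max(feasible, key=lambda i: logits[i])

toy_vocab = {(0, 0): 0, (0, 1): 1, (1, 1): 2}
toy_vocab_inv = {i: p for p, i in toy_vocab.items()}
logits = [5.0, 1.0, 2.0]        # index 0 scores highest overall
options = [(0, 1), (1, 1)]      # but (0, 0) is not feasible for this slot

best = masked_argmax(logits, options, toy_vocab)
print(toy_vocab_inv.get(best, (-1, -1)))
```

The unmasked argmax would pick `(0, 0)`; with the mask the decoder falls back to the best feasible pair, `(1, 1)`. Masking alone does not prevent two classes from choosing the same teacher in the same slot, which is why a repair step is still needed.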

In [12]:

def predict_instance(model, instance, vocab_inv):
    meta = instance['input']
    examples = instance_to_examples(instance)
    device = DEVICE
    model.eval()
    num_classes = int(meta['num_classes'])
    num_days = int(meta['num_days'])
    slots_per_day = int(meta['slots_per_day'])
    output = {str(c): [[[-1,-1] for _ in range(slots_per_day)] for _ in range(num_days)] for c in range(num_classes)}
    with torch.no_grad():
        for ex in examples:
            class_id = torch.tensor([ex['class']], dtype=torch.long, device=device)
            day_id = torch.tensor([ex['day']], dtype=torch.long, device=device)
            slot_id = torch.tensor([ex['slot']], dtype=torch.long, device=device)
            logits = model(class_id, day_id, slot_id, {'num_teachers': meta['num_teachers']})
            if logits.size(1) == 1:
                choice_idx = 0
            else:
                probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
                choice_idx = int(probs.argmax())
            pair = vocab_inv.get(choice_idx, (-1,-1))
            output[str(ex['class'])][ex['day']][ex['slot']] = [int(pair[0]), int(pair[1])]
    return {'input': meta, 'output': output}

# Run prediction for all instances and save
# globals().get avoids a NameError when the training cell did not run successfully
if instances and globals().get('model') is not None:
    os.makedirs(PREDICT_DIR, exist_ok=True)
    for i, inst in enumerate(instances):
        pred = predict_instance(model, inst, vocab_inv)
        out_path = os.path.join(PREDICT_DIR, f'predicted_instance_{i:03d}.json')
        with open(out_path, 'w') as f:
            json.dump(pred, f, indent=2)
    print('Predictions saved to', PREDICT_DIR)
else:
    print('No model or instances available for prediction.')

## 8) Basic evaluation metrics
We compute per-slot accuracy (fraction of slots where the predicted (subject,teacher) equals the ground truth) and average per-instance Hamming similarity (normalized).
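
A worked toy case may help here. The sketch below computes both numbers for a single class grid; `slot_metrics` is a hypothetical helper that assumes gold labels exist for every slot.

```python
def slot_metrics(pred_grid, gold_grid):
    """Return (accuracy, hamming_like) for one class grid indexed [day][slot]."""
    total = correct = 0
    for pred_row, gold_row in zip(pred_grid, gold_grid):
        for pred_cell, gold_cell in zip(pred_row, gold_row):
            total += 1
            correct += tuple(pred_cell) == tuple(gold_cell)
    accuracy = correct / total if total else float("nan")
    hamming_like = 1.0 - (total - correct) / total if total else float("nan")
    return accuracy, hamming_like

# 2 days x 2 slots; the teacher differs in two of the four cells
pred = [[[0, 0], [1, 1]], [[1, 1], [0, 0]]]
gold = [[[0, 0], [1, 0]], [[1, 1], [0, 1]]]
print(slot_metrics(pred, gold))
```

With gold labels for every slot the two metrics coincide (both 0.5 here). In the evaluation cell they differ only when some classes have no gold grid, because the Hamming-like score divides by all predicted slots while accuracy divides only by slots that have gold labels.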

In [None]:

def evaluate_predictions(instances, predict_dir):
    pred_paths = sorted(glob.glob(os.path.join(predict_dir, 'predicted_instance_*.json')))
    if not pred_paths:
        print('No predicted files found in', predict_dir); return
    accuracies = []
    hamming_scores = []
    for i, p in enumerate(pred_paths):
        with open(p,'r') as f:
            pred = json.load(f)
        gold = instances[i].get('output', {})
        # compare per-slot exact matches (excluding missing gold)
        total = 0; correct = 0
        total_slots = 0; diff = 0
        for c_str, grid in pred['output'].items():
            ggrid = gold.get(c_str, None)
            for d,row in enumerate(grid):
                for s,cell in enumerate(row):
                    total_slots += 1
                    pred_pair = tuple(cell)
                    if ggrid is None:
                        # no gold -> skip
                        continue
                    gold_pair = tuple(ggrid[d][s])
                    total += 1
                    if pred_pair == gold_pair:
                        correct += 1
                    else:
                        diff += 1
        acc = correct / total if total>0 else float('nan')
        accuracies.append(acc)
        # simple hamming-like score = 1 - (diff / total_slots)
        ham = 1.0 - (diff / total_slots) if total_slots>0 else float('nan')
        hamming_scores.append(ham)
    print('Per-instance accuracies (mean):', np.nanmean(accuracies), 'std:', np.nanstd(accuracies))
    print('Hamming-like scores (mean):', np.nanmean(hamming_scores), 'std:', np.nanstd(hamming_scores))

if instances and os.path.exists(PREDICT_DIR):
    evaluate_predictions(instances, PREDICT_DIR)
else:
    print('No predictions to evaluate.')


## 9) Next steps and improvements
- **Masking & constraint-aware decoding**: restrict logits to feasible `(subject,teacher)` options per slot before selecting argmax. This prevents infeasible assignments.
- **Post-processing repair**: run a greedy repair to enforce teacher exclusivity and subject-session counts.
- **Use a proper GNN library** (PyTorch Geometric / DGL) to encode relations (class ↔ teacher ↔ subject) and message passing.
- **Structured output models**: pointer networks, sequence models, or combinatorial decoders for better structured predictions.
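
As a starting point for the repair idea, here is a minimal greedy sketch that enforces teacher exclusivity within a single (day, slot). It does not handle subject-session counts, and the function name, argument layout, and tie-breaking order (lowest class id wins) are assumptions made for illustration.

```python
def repair_teacher_clashes(assignments, options_by_class):
    """Greedy repair for one (day, slot): no teacher may serve two classes.

    assignments: {class_id: (subject, teacher)} predicted for this time slot
    options_by_class: {class_id: [(subject, teacher), ...]} feasible pairs
    """
    used_teachers = set()
    repaired = {}
    for class_id in sorted(assignments):
        pair = assignments[class_id]
        if pair[1] in used_teachers:
            # Clash: fall back to the first feasible pair with a free teacher,
            # or leave the slot empty if none exists
            pair = next(
                (o for o in options_by_class.get(class_id, []) if o[1] not in used_teachers),
                (-1, -1),
            )
        if pair != (-1, -1):
            used_teachers.add(pair[1])
        repaired[class_id] = pair
    return repaired

predicted = {0: (3, 1), 1: (3, 1), 2: (0, 1)}
feasible = {0: [(3, 1)], 1: [(3, 1), (3, 2)], 2: [(0, 1)]}
print(repair_teacher_clashes(predicted, feasible))
```

In this toy case class 1 is moved to teacher 2, while class 2 has no alternative and is left empty as `(-1, -1)`. A fuller repair would rank alternatives by model score rather than list order and revisit emptied slots.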

You're ready to run this notebook. Place your JSON instances in `../dataset` (or update `DATABASE_DIR`) and execute the cells in order. The notebook will save predictions under `/mnt/data/predicted/`.