### Discrete CoT Model, MNNS Task

In [None]:
import torch
import random
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel
from torch.optim import AdamW
import torch.nn as nn
import numpy as np

# Seed every random number generator this notebook imports so that data
# splits, shuffling and weight initialisation are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)  # NumPy is imported above; seed it too so "everything" really is seeded
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # for multi-GPU
# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def tokenize_line(line, token2id_map):
    """Map a whitespace-separated line of tokens to their vocabulary ids.

    Surrounding whitespace is ignored. A token that is not a key of
    ``token2id_map`` raises ``KeyError``.
    """
    # str.split() with no separator already discards leading/trailing blanks.
    return list(map(token2id_map.__getitem__, line.split()))

class MathExpressionDataset(Dataset):
    """Fixed-length causal-LM dataset over already-tokenized sequences.

    Each item is truncated or right-padded with ``token2id["<PAD>"]`` to
    exactly ``max_len`` ids and returned as a dict of ``torch.long`` tensors:

    * ``input_ids``      -- the padded/truncated token ids
    * ``attention_mask`` -- 1 for real tokens, 0 for padding
    * ``labels``         -- a copy of ``input_ids`` (padding positions included)

    Args:
        tokenized_samples: sequence of token-id lists, one per example.
        max_len: length every returned sequence is forced to.
        token2id: vocabulary mapping; only ``"<PAD>"`` is read, and only
            when an example actually needs padding.
    """

    def __init__(self, tokenized_samples, max_len, token2id):
        self.samples = tokenized_samples
        self.max_len = max_len
        self.token2id = token2id

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Work on a copy: padding in place would permanently lengthen the
        # caller's stored sample the first time it is fetched.
        token_ids = list(self.samples[idx][:self.max_len])

        real_len = len(token_ids)
        pad_len = self.max_len - real_len
        attention_mask = [1] * real_len + [0] * pad_len
        if pad_len > 0:
            token_ids.extend([self.token2id["<PAD>"]] * pad_len)

        input_ids = torch.tensor(token_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        # For a causal LM the labels are the inputs themselves; the model
        # performs the one-position shift internally.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
        }

In [None]:
import itertools

def generate_vocab():
    """Build the fixed token vocabulary used by both models.

    Layout (ids follow list order): six special tokens, digit tokens
    ``D0``..``D9``, then partial-sum tokens ``S-36``..``S36``.

    Returns:
        ``(vocab, token2id, id2token)`` -- the ordered token list and the
        two lookup dictionaries derived from it.
    """
    vocab = ["<PAD>", "<BOS>", "<EOS>", "->", "+", "-"]
    vocab.extend(f"D{digit}" for digit in range(10))
    vocab.extend(f"S{total}" for total in range(-36, 37))

    # Every token is unique, so the two mappings are exact inverses.
    id2token = dict(enumerate(vocab))
    token2id = {token: idx for idx, token in id2token.items()}

    return vocab, token2id, id2token

def generate_text_dataset(digit_range=range(1, 6), seq_length=4):
    """Enumerate every digit sequence and write its best signed-sum trace.

    For each sequence of ``seq_length`` digits drawn from ``digit_range``,
    all ``2**seq_length`` +/- sign assignments are tried. The assignment
    whose final running sum is the smallest non-negative value wins; on a
    tie the first one in ``itertools.product(["+", "-"], ...)`` order is
    kept. The result is rendered as

        ``<BOS> D.. D.. -> S.. S.. <EOS>``

    where the ``S`` tokens are the running sums after each digit.

    Returns:
        list of such text lines, one per sequence that has at least one
        sign assignment with a non-negative final sum.
    """
    lines = []

    for digits in itertools.product(digit_range, repeat=seq_length):
        # (final_sum, running_sums) for every non-negative outcome, kept in
        # enumeration order so that min() resolves ties to the earliest one.
        feasible = []
        for signs in itertools.product((1, -1), repeat=seq_length):
            running = list(
                itertools.accumulate(sign * digit for sign, digit in zip(signs, digits))
            )
            final_sum = running[-1] if running else 0
            if final_sum >= 0:
                feasible.append((final_sum, running))

        if not feasible:
            continue

        _, best_running = min(feasible, key=lambda outcome: outcome[0])

        tokens = ["<BOS>"]
        tokens.extend(f"D{digit}" for digit in digits)
        tokens.append("->")
        tokens.extend(f"S{total}" for total in best_running)
        tokens.append("<EOS>")
        lines.append(" ".join(tokens))

    return lines

In [None]:
def encode_prompt(digit_seq, token2id):
    """Return the token ids of the prompt ``<BOS> D.. D.. ->`` for the given digits."""
    prompt_ids = [token2id["<BOS>"]]
    prompt_ids.extend(token2id[f"D{digit}"] for digit in digit_seq)
    prompt_ids.append(token2id["->"])
    return prompt_ids

def decode_tokens(token_ids, id2token):
    """Translate a sequence of token ids back into their token strings."""
    decoded = []
    for token_id in token_ids:
        decoded.append(id2token[token_id])
    return decoded

In [None]:
def get_token_loss(outputs, labels, seq_length):
    """Per-position cross-entropy of the partial-sum tokens, summed over the batch.

    Logits at position ``t`` are scored against the label at ``t + 1``
    (the usual causal-LM shift). From the resulting ``(batch, seq - 1)``
    loss matrix the columns ``seq_length + 1 .. 2 * seq_length`` are kept
    and added up along the batch dimension.

    Args:
        outputs: object exposing ``.logits`` indexed as (batch, seq, vocab).
        labels: integer tensor indexed as (batch, seq).
        seq_length: number of digits (and of partial sums) per example.

    Returns:
        NumPy array with one summed loss per selected column.
    """
    vocab_dim = outputs.logits.size(-1)
    predictions = outputs.logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()

    criterion = nn.CrossEntropyLoss(reduction="none")
    flat_losses = criterion(predictions.view(-1, vocab_dim), targets.view(-1))
    token_losses = flat_losses.view(targets.size())

    first_col = seq_length + 1
    last_col = 2 * seq_length + 1
    batch_totals = token_losses[:, first_col:last_col].sum(dim=0)
    return batch_totals.cpu().detach().numpy()

In [None]:
from collections import defaultdict

def permutation_train_val_split(dataset_text, train_ratio=0.8):
    """Split text lines into train/val so digit permutations stay together.

    Lines are bucketed by the sorted tuple of their digits (the ``D``
    tokens between ``<BOS>`` and ``->``). The bucket keys are shuffled with
    the global ``random`` module, the first ``int(train_ratio * n_buckets)``
    buckets become the training set and the rest the validation set.

    Returns:
        ``(train_lines, val_lines)``
    """
    buckets = defaultdict(list)
    for line in dataset_text:
        pieces = line.split()
        digit_part = pieces[1:pieces.index("->")]  # drop <BOS>, stop before "->"
        signature = tuple(sorted(int(token[1:]) for token in digit_part))
        buckets[signature].append(line)

    shuffled_keys = list(buckets)
    random.shuffle(shuffled_keys)

    cut = int(train_ratio * len(shuffled_keys))
    train_lines = [line for key in shuffled_keys[:cut] for line in buckets[key]]
    val_lines = [line for key in shuffled_keys[cut:] for line in buckets[key]]

    return train_lines, val_lines

In [None]:
def parse_line(line):
    """Recover the digits and partial sums from one dataset text line.

    Expected layout: ``<BOS> D.. D.. -> S.. S.. <EOS>``. The first token
    and the last token are discarded; the numeric suffix of every token
    before ``->`` becomes a digit and of every token after it a partial sum.

    Returns:
        ``(digits, partial_sums)`` as two lists of ints.
    """
    pieces = line.split()
    arrow_at = pieces.index("->")

    def numeric_suffix(token):
        # "D5" -> 5, "S-3" -> -3: drop the one-letter prefix.
        return int(token[1:])

    digits = [numeric_suffix(token) for token in pieces[1:arrow_at]]
    partial_sums = [numeric_suffix(token) for token in pieces[arrow_at + 1:-1]]
    return digits, partial_sums

def generate_partial_sums_step_by_step(
    model,
    digits,
    token2id,
    id2token,
    device,
    max_sums=4,
    do_sample=False,
    temperature=1.0
):
    """Decode ``max_sums`` tokens after the prompt, one ``generate`` call each.

    The prompt is ``<BOS> D.. D.. ->``. After every call the grown sequence
    is fed back in, so each new token is conditioned on all previous ones.

    Args:
        model: object with a ``generate`` method accepting the keyword
            arguments used below and returning the extended id tensor.
        digits: iterable of ints forming the prompt.
        token2id / id2token: vocabulary lookups.
        device: where the prompt tensor is placed.
        max_sums: number of tokens to decode.
        do_sample, temperature: forwarded unchanged to ``model.generate``
            (decoding is greedy only when ``do_sample`` is false).

    Returns:
        list of length ``max_sums``; the int value for a generated ``S<n>``
        token, or ``None`` when any other token was produced.
    """
    prompt_ids = [token2id["<BOS>"]]
    prompt_ids += [token2id[f"D{digit}"] for digit in digits]
    prompt_ids.append(token2id["->"])
    sequence = torch.tensor([prompt_ids], dtype=torch.long).to(device)

    predicted_sums = []
    for _ in range(max_sums):
        # Ask for exactly one more token; the returned tensor already
        # contains the prompt plus everything generated so far.
        sequence = model.generate(
            input_ids=sequence,
            max_new_tokens=1,
            do_sample=do_sample,
            pad_token_id=token2id["<PAD>"],
            temperature=temperature
        )
        newest_token = id2token[sequence[0, -1].item()]
        predicted_sums.append(
            int(newest_token[1:]) if newest_token.startswith("S") else None
        )

    return predicted_sums


def evaluate_model(model, test_dataset_text, token2id, id2token, device):
    """Final-answer accuracy of step-by-step decoding on a set of text lines.

    For each line the digits and ground-truth partial sums are parsed, the
    model decodes as many tokens as there are ground-truth sums, and the
    example counts as correct when the *last* decoded value equals the last
    ground-truth sum. Intermediate sums are deliberately not checked, so
    only the final token is compared.

    Args:
        model: passed through to ``generate_partial_sums_step_by_step``.
        test_dataset_text: sized collection of dataset lines.
        token2id / id2token: vocabulary lookups.
        device: device the prompts are placed on.

    Returns:
        ``(accuracy, correct_count)``; ``(0.0, 0)`` for an empty dataset.
    """
    total = len(test_dataset_text)
    if total == 0:
        # Nothing to score: report zero instead of dividing by zero.
        return 0.0, 0

    correct_count = 0
    for line in test_dataset_text:
        digits, gt_sums = parse_line(line)
        ground_truth_final = gt_sums[-1]

        predicted_sums = generate_partial_sums_step_by_step(
            model, digits, token2id, id2token, device, max_sums=len(gt_sums)
        )

        # A non-"S" final token decodes to None and therefore never matches.
        if predicted_sums[-1] == ground_truth_final:
            correct_count += 1

    return correct_count / total, correct_count

In [None]:
# Experiment configuration shared by the cells below.
SEQ_LENGTH = 4                      # digits (and partial sums) per example
MAX_SEQ_LEN = 2 * SEQ_LENGTH + 3    # <BOS> + digits + "->" + sums + <EOS>
EMBEDDING_DIM = 24
DIGIT_RANGE = range(1, 10)          # 1..9 digits
BATCH_SIZE = 16
SPLIT_METHOD = "random_permutation"
NUM_EPOCHS = 1000
OUTPUT_DIR = "moss-test"
NUM_LAYERS = 2
NUM_HEADS = 2

# One-line summary of the run configuration.
print(", ".join([
    f"Max Seq Len: {MAX_SEQ_LEN}",
    f"Embedding Dim: {EMBEDDING_DIM}",
    f"Digit Range: {DIGIT_RANGE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Seq Length: {SEQ_LENGTH}",
    f"Split Method: {SPLIT_METHOD}",
    f"Num Epochs: {NUM_EPOCHS}",
    f"Output Dir: {OUTPUT_DIR}",
    f"Num Layers: {NUM_LAYERS}",
    f"Num Heads: {NUM_HEADS}",
]))

#### Train for 1000 epochs on the MNNS task; we'll compare final validation accuracies.

In [None]:
# Setup for the discrete (token-level) model: build the vocabulary and the
# full text dataset, split it, and create model / loaders / optimizer.
# NOTE: statement order matters for reproducibility -- the split and the two
# shuffles below all draw from the global `random` state seeded earlier.
vocab, token2id, id2token = generate_vocab()
dataset_text = generate_text_dataset(DIGIT_RANGE, SEQ_LENGTH)
vocab_size = len(vocab)

print(f"Vocab size = {vocab_size}")
print(f"Number of valid sequences in dataset: {len(dataset_text)}")
print("Sample line:", dataset_text[0])

# Group-level split: all permutations of the same digits land on one side.
train_lines, val_lines = permutation_train_val_split(dataset_text, train_ratio=0.8)
train_data = [tokenize_line(line, token2id) for line in train_lines]
val_data = [tokenize_line(line, token2id) for line in val_lines]
# Only the tokenized copies are shuffled; val_lines keeps its order and is
# what evaluate_model() consumes in the epoch loop.
random.shuffle(train_data)
random.shuffle(val_data)

# Create the model configuration (sized from the notebook-level constants)
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=MAX_SEQ_LEN,
    n_embd=EMBEDDING_DIM,
    n_layer=NUM_LAYERS,
    n_head=NUM_HEADS
)

# Instantiate the model from the config
model = GPT2LMHeadModel(config)

# Wrap the tokenized samples; every item is padded/truncated to MAX_SEQ_LEN
train_dataset = MathExpressionDataset(train_data, max_len=MAX_SEQ_LEN, token2id=token2id)
val_dataset = MathExpressionDataset(val_data, max_len=MAX_SEQ_LEN, token2id=token2id)

# Create DataLoader for training (reshuffled each epoch) and validation (fixed order)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = AdamW(model.parameters(), lr=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Per-epoch histories filled by the training loop that follows.
train_losses = []
train_losses_tokens = []
val_losses = []
val_losses_tokens = []
val_accuracies = []
# Epoch loop for the discrete model: one optimisation pass over the training
# loader, one loss pass over the validation loader, then generation-based
# accuracy on the validation text lines.
for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0.
    train_loss_tokens = np.zeros((SEQ_LENGTH, ))
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss

        # Per-position loss of the partial-sum tokens, summed over the batch.
        train_loss_tokens += get_token_loss(outputs, labels, SEQ_LENGTH)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    # train_loss_tokens is a sum over every training example, so the mean is
    # obtained by dividing by the example count. Dividing by
    # len(train_loader) * BATCH_SIZE over-counts whenever the last batch is
    # smaller than BATCH_SIZE.
    train_loss_tokens /= len(train_dataset)
    train_losses.append(avg_train_loss)
    train_losses_tokens.append(train_loss_tokens.tolist())

    # ---- validation pass (teacher-forced loss) ----
    model.eval()
    val_loss = 0.0
    val_loss_tokens = np.zeros((SEQ_LENGTH, ))
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            val_loss += outputs.loss.item()

            # Per-position validation loss, summed over the batch.
            val_loss_tokens += get_token_loss(outputs, labels, SEQ_LENGTH)

    avg_val_loss = val_loss / len(val_loader)
    # Same normalisation as for training: mean over validation examples.
    val_loss_tokens /= len(val_dataset)
    val_losses_tokens.append(val_loss_tokens.tolist())
    val_losses.append(avg_val_loss)

    # ---- validation accuracy via step-by-step generation ----
    val_accuracy, _ = evaluate_model(model, val_lines, token2id, id2token, device)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch + 1} | "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val Accuracy: {val_accuracy:.2%} | "
          f"Val Loss Tokens: " + "-".join(f"{tok_loss:.4f}" for tok_loss in val_loss_tokens) + " | "
          f"Train Loss Tokens: " + "-".join(f"{tok_loss:.4f}" for tok_loss in train_loss_tokens) + " | ")

# Epoch axis and a run-descriptive file stem; presumably consumed by
# plotting/saving code that is not part of this section.
epochs_range = range(1, NUM_EPOCHS+1)
file_name = "_".join([
    f"Digit{DIGIT_RANGE.start}-{DIGIT_RANGE.stop}",
    f"Seq{SEQ_LENGTH}",
    f"Emb{EMBEDDING_DIM}",
    f"Split{SPLIT_METHOD}",
    f"Batch{BATCH_SIZE}",
    f"Epochs{NUM_EPOCHS}",
])

### CoT2 Model, MNNS Task
We will compare the final performance of this model against the discrete model.

In [None]:
from torch.utils.data import Dataset

class FourDigitsSoftDataset(Dataset):
    """Digit prompts paired with per-step distributions over partial sums.

    For every sequence of ``seq_length`` digits drawn from ``digit_range``:

    1. Start from the single partial sum 0.
    2. At step ``i`` (0-based) each reachable sum branches into ``+digit``
       and ``-digit``, giving ``2**(i+1)`` sums weighted ``1 / 2**(i+1)``
       each. The step's target is a vector of length ``len(token2id)`` in
       which the entry of token ``S<sum>`` accumulates that weight; sums
       with no ``S`` token in the vocabulary are silently left out.
    3. The hard label is ``S<m>`` where ``m`` is the smallest non-negative
       final sum. Sequences with no such sum, or whose label token is not
       in the vocabulary, are skipped.

    Each stored example is a dict with ``prompt_ids`` (``<BOS>`` followed by
    the digit tokens, omitting any token missing from the vocabulary),
    ``dist_steps`` (list of the per-step vectors) and ``final_hard_label``.
    """

    def __init__(self, token2id, digit_range=range(1, 6), seq_length=4):
        super().__init__()
        self.token2id = token2id
        self.vocab_size = len(token2id)
        self.examples = []

        for digits in itertools.product(digit_range, repeat=seq_length):
            example = self._build_example(digits)
            if example is not None:
                self.examples.append(example)

    def _build_example(self, digits):
        """Return the example dict for one digit tuple, or None if it is skipped."""
        reachable = [0]
        dist_steps = []

        for digit in digits:
            # Branch order (+digit before -digit, per existing sum) is kept so
            # the floating-point accumulation order below never changes.
            reachable = [total + delta for total in reachable for delta in (digit, -digit)]

            weight = 1.0 / len(reachable)
            dist_vec = torch.zeros(self.vocab_size)
            for total in reachable:
                slot = self.token2id.get(f"S{total}")
                if slot is not None:
                    dist_vec[slot] += weight
            dist_steps.append(dist_vec)

        best_sum = min((total for total in reachable if total >= 0), default=None)
        if best_sum is None:
            return None

        final_label_id = self.token2id.get(f"S{best_sum}")
        if final_label_id is None:
            return None

        prompt_tokens = ["<BOS>", *(f"D{digit}" for digit in digits)]
        prompt_ids = [self.token2id[tok] for tok in prompt_tokens if tok in self.token2id]

        return {
            "prompt_ids": torch.tensor(prompt_ids, dtype=torch.long),
            "dist_steps": dist_steps,
            "final_hard_label": final_label_id,
        }

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(batch):
    """Batch soft-dataset examples, right-padding prompts to a common length.

    Prompts are padded with id 0 and stacked into a ``(B, max_len)`` long
    tensor; the attention mask has the same shape, holding 1.0 over real
    tokens and 0.0 over padding. ``dist_steps`` and the hard labels are
    passed through as plain Python lists, one entry per example.
    """
    width = max(len(ex["prompt_ids"]) for ex in batch)
    positions = torch.arange(width)

    padded_rows = []
    mask_rows = []
    for ex in batch:
        ids = ex["prompt_ids"]
        filler = torch.zeros(width - len(ids), dtype=torch.long)
        padded_rows.append(torch.cat([ids, filler]))
        mask_rows.append((positions < len(ids)).to(torch.get_default_dtype()))

    return {
        "prompt_ids": torch.stack(padded_rows, dim=0),      # (B, max_len)
        "attention_mask": torch.stack(mask_rows, dim=0),    # (B, max_len)
        "dist_steps": [ex["dist_steps"] for ex in batch],   # list of lists
        "final_labels": [ex["final_hard_label"] for ex in batch],
    }

In [None]:
from torch.utils.data import Subset
from collections import defaultdict

def permutation_train_val_split_continuous(dataset, id2token, seq_length, train_ratio=0.8):
    """Group-level train/val split that keeps digit permutations together.

    Every example is keyed by the sorted tuple of the digits found in its
    prompt (positions ``1 .. seq_length``, i.e. right after ``<BOS>``).
    The keys are shuffled with the global ``random`` module; the first
    ``int(train_ratio * n_groups)`` groups form the training subset and the
    remaining groups the validation subset.

    Args:
      dataset: indexable dataset whose items are dicts holding a 1-D
               ``"prompt_ids"`` tensor.
      id2token: mapping from token id to token string, e.g. ``"D5"``.
      seq_length: number of digit tokens in each prompt.
      train_ratio: fraction of groups assigned to train.

    Returns:
      ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.
    """
    buckets = defaultdict(list)
    for idx in range(len(dataset)):
        digit_ids = dataset[idx]["prompt_ids"][1: 1 + seq_length].tolist()
        # "D5" -> 5; sorting yields the canonical form, e.g. (1, 4, 5, 5).
        signature = tuple(sorted(int(id2token[token_id][1:]) for token_id in digit_ids))
        buckets[signature].append(idx)

    shuffled_keys = list(buckets)
    random.shuffle(shuffled_keys)

    cut = int(train_ratio * len(shuffled_keys))
    train_indices = [idx for key in shuffled_keys[:cut] for idx in buckets[key]]
    val_indices = [idx for key in shuffled_keys[cut:] for idx in buckets[key]]

    return Subset(dataset, train_indices), Subset(dataset, val_indices)

In [None]:
import torch.nn.functional as F

def cross_entropy_distribution_batch(logits, target_dist):
    """Per-example divergence between a target distribution and softmax(logits).

    Computes ``-sum_k p_k log q_k + sum_k p_k log(p_k + EPS)`` along the last
    dimension, where ``p`` is ``target_dist`` and ``q = softmax(logits)``:
    the cross-entropy minus the (EPS-smoothed) entropy of the target, so a
    perfect match scores approximately zero.

    logits:       (B, vocab_size)
    target_dist:  (B, vocab_size), on the same device as logits
    Returns:      (B,) one scalar per example.
    """
    EPS = 1e-8
    log_q = F.log_softmax(logits, dim=-1)          # (B, vocab_size)
    log_p = torch.log(target_dist + EPS)           # EPS keeps log(0) finite
    cross_entropy = -(target_dist * log_q).sum(dim=-1)
    negative_entropy = (target_dist * log_p).sum(dim=-1)
    return cross_entropy + negative_entropy

In [None]:
def train_soft_steps_batch(
    model,
    data_loader,
    optimizer,
    device,
    loss_steps_to_include,
    supervision_type,
    seq_length,
    num_soft_steps
):
    """Run one training epoch of the continuous ("soft step") model.

    For every batch the discrete prompt is encoded once, then
    ``num_soft_steps`` continuous steps follow. At each step the last-position
    logits are scored against the teacher distribution for that step, and a
    single embedding -- a probability-weighted mix of the rows of the model's
    token-embedding matrix -- is fed back as the next input on top of the
    cached keys/values. After the soft steps the last-position logits are
    scored against the hard final label with ordinary cross-entropy.

    Args:
        model: causal LM exposing ``transformer.wte.weight`` and accepting
            ``input_ids`` / ``inputs_embeds`` / ``past_key_values``.
        data_loader: yields dicts with ``prompt_ids``, ``attention_mask``,
            ``dist_steps`` (per example: list of per-step vectors) and
            ``final_labels``.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        loss_steps_to_include: container of step names that contribute to the
            optimised loss: ``"0"``, ``"1"``, ... for soft steps and ``"h"``
            for the final hard label.
        supervision_type: ``"soft_teacher"`` mixes embeddings with the model's
            own softmax; any other value (``"hard_teacher"``) mixes them with
            the teacher distribution.
        seq_length: length of the returned per-step loss array.
        num_soft_steps: number of continuous steps before the hard label.

    Returns:
        ``(avg_loss, train_loss_tokens)``: the mean optimised loss per batch
        and a float32 array of length ``seq_length`` with the mean per-batch
        loss of each step -- soft step ``i`` at index ``i``, the final hard
        loss at index ``num_soft_steps``. Steps that do not fit into
        ``seq_length`` entries are not reported. All steps are tracked here
        regardless of ``loss_steps_to_include``.
    """
    model.train()
    total_loss = 0.0
    steps = 0

    ce_loss_fn = nn.CrossEntropyLoss()

    # Per-step loss tracker; previously allocated but never filled, so the
    # function always returned zeros for it.
    train_loss_tokens = np.zeros(seq_length, dtype=np.float32)
    tracked = min(seq_length, num_soft_steps + 1)
    embedding_matrix = model.transformer.wte.weight  # (vocab_size, n_embd)

    for batch in data_loader:
        prompt_ids = batch["prompt_ids"].to(device)          # (B, max_len)
        attention_mask = batch["attention_mask"].to(device)  # (B, max_len)
        final_labels = torch.tensor(batch["final_labels"], device=device)  # (B,)

        # batch["dist_steps"] is a list (one per example) of lists (one per
        # step) of vectors; stack into a single (B, T, vocab_size) tensor.
        # This requires every example to have the same number of steps T.
        dist_steps = torch.stack(
            [
                torch.stack(
                    [torch.as_tensor(d, dtype=torch.float, device=device) for d in ex_dists],
                    dim=0,
                )
                for ex_dists in batch["dist_steps"]
            ],
            dim=0,
        )

        batch_loss = torch.zeros((), device=device)

        # Encode the whole prompt once and keep the key/value cache.
        outputs = model(input_ids=prompt_ids, attention_mask=attention_mask, use_cache=True)
        past_key_values = outputs.past_key_values

        # Detached per-step losses for reporting only.
        partial_losses = torch.zeros(num_soft_steps + 1, device=device)

        for step_idx in range(num_soft_steps):
            last_logits = outputs.logits[:, -1, :]  # (B, vocab_size)

            # Divergence from the teacher distribution at this step.
            dist_vec = dist_steps[:, step_idx, :]    # (B, vocab_size)
            step_loss = cross_entropy_distribution_batch(last_logits, dist_vec).mean()
            partial_losses[step_idx] = step_loss.detach()

            if str(step_idx) in loss_steps_to_include:
                batch_loss += step_loss

            # Continuous "token": a convex mix of embedding rows.
            if supervision_type == "soft_teacher":
                token_dist_vec = F.softmax(last_logits, dim=-1)  # (B, vocab_size)
                e_soft = token_dist_vec @ embedding_matrix       # (B, n_embd)
            else:
                # "hard_teacher": teacher-force with the target distribution.
                e_soft = dist_vec @ embedding_matrix             # (B, n_embd)

            # Feed that embedding as the next position for all B examples.
            outputs = model(
                inputs_embeds=e_soft.unsqueeze(1),  # (B, 1, n_embd)
                past_key_values=past_key_values,
                use_cache=True
            )
            past_key_values = outputs.past_key_values

        # Final "hard" step => cross-entropy with the final label.
        last_logits = outputs.logits[:, -1, :]  # (B, vocab_size)
        final_loss = ce_loss_fn(last_logits, final_labels)
        partial_losses[-1] = final_loss.detach()

        if "h" in loss_steps_to_include:
            batch_loss += final_loss

        # Backprop once for the entire batch.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        train_loss_tokens[:tracked] += partial_losses[:tracked].cpu().numpy()
        steps += 1

    avg_loss = total_loss / max(1, steps)
    return avg_loss, train_loss_tokens / max(1, steps)

In [None]:
def cross_entropy_distribution(logits, target_dist):
    """Scalar divergence between one target distribution and softmax(logits).

    logits: shape (vocab_size,)
    target_dist: shape (vocab_size,)
    Returns the cross-entropy ``-sum_k p(k) log softmax(logits)[k]`` plus
    ``sum_k p(k) log(p(k) + EPS)``, i.e. the cross-entropy with the
    (EPS-smoothed) entropy of the target removed.
    """
    EPS = 10**(-8)
    log_q = F.log_softmax(logits, dim=-1)
    log_p = torch.log(target_dist + EPS)  # EPS keeps log(0) finite
    cross_entropy = -(target_dist * log_q).sum()
    negative_entropy = (target_dist * log_p).sum()
    return cross_entropy + negative_entropy

@torch.no_grad()
def eval_soft_steps_acc(model, data_loader, device, seq_length):
    """Evaluate the continuous model: per-step losses and final-label accuracy.

    Examples are processed one at a time. After the discrete prompt, every
    teacher distribution except the last is used as a soft step: the
    last-position logits are scored against it, and the model's *own* softmax
    (not the teacher distribution) is mixed into an embedding that is fed
    back as the next input. The logits after the last soft step are scored
    against the hard label and their argmax is the prediction.

    Args:
        model: causal LM exposing ``transformer.wte.weight``.
        data_loader: yields dicts with ``prompt_ids``, ``attention_mask``,
            ``dist_steps`` and ``final_labels``.
        device: device the tensors are moved to.
        seq_length: length of the per-step loss array; soft step ``i`` is
            stored at index ``i`` and the hard-label loss at the last index.

    Returns:
        ``(avg_loss, accuracy, val_loss_tokens)``; losses are averaged within
        each batch and then across batches. An empty loader yields
        ``(0.0, 0.0, zeros)`` instead of a division by zero.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    steps = 0

    ce_loss_fn = nn.CrossEntropyLoss()
    embedding_matrix = model.transformer.wte.weight
    val_loss_tokens = np.zeros((seq_length,))

    def get_last_logits(o):
        return o.logits[:, -1, :]

    for batch in data_loader:
        prompt_ids = batch["prompt_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        dist_steps_list = batch["dist_steps"]
        final_labels = batch["final_labels"]
        batch_size = prompt_ids.size(0)

        # Plain Python float: no gradients are tracked under torch.no_grad().
        batch_loss = 0.0
        val_loss_tokens_batch = np.zeros((seq_length,))

        for i in range(batch_size):
            # Forward the prompt in discrete form.
            pi = prompt_ids[i].unsqueeze(0)
            am = attention_mask[i].unsqueeze(0)

            outputs = model(pi, attention_mask=am, use_cache=True)
            pkv = outputs.past_key_values

            item_loss = 0.0

            # All teacher distributions but the last act as soft steps; the
            # last position is reserved for the hard-label prediction.
            for count, dist_vec in enumerate(dist_steps_list[i][:-1]):
                last_logits = get_last_logits(outputs).squeeze(0)  # (vocab_size,)
                step_loss = cross_entropy_distribution(last_logits, dist_vec.to(device))
                item_loss += step_loss.item()
                val_loss_tokens_batch[count] += step_loss.item()

                # Soft embedding from the model's own predicted distribution:
                # e_soft = sum_v softmax(logits)[v] * embedding_matrix[v]
                token_dist_vec = F.softmax(last_logits, dim=-1)
                e_soft = torch.matmul(token_dist_vec, embedding_matrix)  # (n_embd,)
                outputs = model(
                    inputs_embeds=e_soft.unsqueeze(0).unsqueeze(1),  # (1,1,n_embd)
                    past_key_values=pkv,
                    use_cache=True
                )
                pkv = outputs.past_key_values

            # Final step => cross-entropy to the hard label plus argmax prediction.
            last_logits = get_last_logits(outputs).squeeze(0)  # (vocab_size,)
            final_label_id = final_labels[i]

            final_loss = ce_loss_fn(last_logits.unsqueeze(0), torch.tensor([final_label_id], device=device))
            item_loss += final_loss.item()
            val_loss_tokens_batch[-1] += final_loss.item()

            predicted_id = last_logits.argmax(dim=-1).item()
            if predicted_id == final_label_id:
                total_correct += 1

            batch_loss += item_loss

        # Average over the batch, then accumulate across batches.
        batch_loss /= batch_size
        val_loss_tokens_batch /= batch_size
        val_loss_tokens += val_loss_tokens_batch
        total_loss += batch_loss
        total_samples += batch_size
        steps += 1

    # Guard the denominators the same way the training function does.
    n_batches = max(1, steps)
    avg_loss = total_loss / n_batches
    val_loss_tokens /= n_batches
    accuracy = total_correct / max(1, total_samples)
    return avg_loss, accuracy, val_loss_tokens

#### Train for 1000 epochs on the MNNS task; we'll compare final validation accuracies.

In [None]:
# Which steps contribute to the optimised loss, and how soft inputs are built.
LOSS_STEPS = ["0", "1", "2", "h"]  # 3 soft steps + "h" for final hard label
SUPERVISION = "hard_teacher"

# One-line summary of the run configuration.
print(", ".join([
    f"Max Seq Len: {MAX_SEQ_LEN}",
    f"Embedding Dim: {EMBEDDING_DIM}",
    f"Digit Range: {DIGIT_RANGE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Seq Length: {SEQ_LENGTH}",
    f"Loss Steps: {LOSS_STEPS}",
    f"Split Method: {SPLIT_METHOD}",
    f"Num Epochs: {NUM_EPOCHS}",
    f"Output Dir: {OUTPUT_DIR}",
    f"Supervision: {SUPERVISION}",
    f"Num Layers: {NUM_LAYERS}",
    f"Num Heads: {NUM_HEADS}",
]))

# Setup for the continuous (CoT2) model. This rebinds the notebook-level
# names `model`, `optimizer`, `train_loader`, `val_loader` and the history
# lists that the discrete run used.
# NOTE(review): no RNG is re-seeded in this cell, so the weight init and the
# group split below depend on how much randomness earlier cells consumed --
# confirm whether an identical split to the discrete run is intended.
# Build vocab & model
vocab, token2id, id2token = generate_vocab()
vocab_size = len(vocab)
print("Vocab size =", vocab_size)

# Same architecture hyperparameters as the discrete model.
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=MAX_SEQ_LEN,
    n_embd=EMBEDDING_DIM,
    n_layer=NUM_LAYERS,
    n_head=NUM_HEADS
)
model = GPT2LMHeadModel(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Create the dataset => SEQ_LENGTH digits as prompt + one teacher
# distribution per digit + final hard label
dataset = FourDigitsSoftDataset(token2id=token2id, digit_range=DIGIT_RANGE, seq_length=SEQ_LENGTH)

# Group-level split: permutations of the same digits stay on one side.
train_data, val_data = permutation_train_val_split_continuous(
    dataset=dataset,
    id2token=id2token,
    seq_length=SEQ_LENGTH,
    train_ratio=0.8
)

# collate_fn pads prompts and keeps dist_steps / labels as Python lists.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Per-epoch histories filled by the training loop that follows.
train_losses = []
val_losses = []
val_accuracies = []
train_losses_tokens = []
val_losses_tokens = []
# Epoch loop for the continuous model: one training pass, one evaluation
# pass, then record and report the metrics.
for epoch in range(NUM_EPOCHS):
    train_loss, train_loss_tokens = train_soft_steps_batch(
        model=model,
        data_loader=train_loader,
        optimizer=optimizer,
        device=device,
        loss_steps_to_include=LOSS_STEPS,
        supervision_type=SUPERVISION,
        seq_length=SEQ_LENGTH,
        num_soft_steps=SEQ_LENGTH - 1,
    )
    val_loss, val_acc, val_loss_tokens = eval_soft_steps_acc(
        model=model,
        data_loader=val_loader,
        device=device,
        seq_length=SEQ_LENGTH,
    )

    # Scalar histories first, then the per-step breakdowns.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    train_losses_tokens.append(train_loss_tokens)
    val_losses_tokens.append(val_loss_tokens)

    print(
        f"Epoch {epoch + 1} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Val Accuracy: {val_acc:.2%} | "
        f"Val Loss Tokens: "
        + "-".join(f"{tok_loss:.4f}" for tok_loss in val_loss_tokens)
        + " | "
    )

# Epoch axis and a run-descriptive file stem for plots / saved artifacts.
epochs_range = range(1, NUM_EPOCHS + 1)
file_name = (
    f"Digit{DIGIT_RANGE.start}-{DIGIT_RANGE.stop}" + "_seq" + str(SEQ_LENGTH) + "_emb" +
    str(EMBEDDING_DIM) + "_steps" + "".join(LOSS_STEPS) + "_split" + SPLIT_METHOD +
    "_batch" + str(BATCH_SIZE) + "_epochs" + str(NUM_EPOCHS)
)

# Save the model weights. torch.save does not create missing directories, so
# make sure the target directory exists first; otherwise the run would fail
# only after all training epochs have completed.
import os

# NOTE(review): OUTPUT_DIR is concatenated straight onto the file name with no
# separator ("models/<OUTPUT_DIR>continuous_model_..."); confirm whether a
# sub-directory or a "_" was intended. The path is left unchanged here.
save_path = f"models/{OUTPUT_DIR}continuous_model_{file_name}.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)