<h1 align="center" style="font-family: 'Braniella', cursive; font-style: italic; font-size: 96px;">
  Continuous Diffusion Language Model
</h1>

## Step-by-Step Guide: Training a Diffusion-Based Model on HellaSwag

---

### 1. **Imports and Setup**

Start by importing all necessary Python packages:

* `torch`, `nn`, `optim`: for deep learning.
* `datasets`: to load the HellaSwag dataset.
* `Counter`: to build a vocabulary.
* `pad_sequence`, `DataLoader`, `Dataset`: for batching and token handling.
* `math`, `os`: for numerical ops and file handling.

---

### 2. **Positional Embedding with Sinusoidal Encoding**

We define a `SinusoidalPosEmb` class that maps a diffusion timestep to a fixed (non-learned) sinusoidal embedding vector, using the same sine/cosine construction transformers use for token positions. This tells the model "where in the diffusion process" it is.
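To make the construction concrete, here is a minimal pure-Python sketch of the same sine/cosine layout for a single timestep. The helper name `sinusoidal_embedding` and the small `dim=8` are illustrative assumptions; the notebook's own class works on batched tensors.

```python
import math

def sinusoidal_embedding(timestep, dim=8):
    """Illustrative helper: sin/cos features for one integer timestep."""
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-i * scale) for i in range(half_dim)]
    angles = [timestep * f for f in freqs]
    # First half holds the sines, second half the cosines.
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

emb_t0 = sinusoidal_embedding(0)
emb_t10 = sinusoidal_embedding(10)
print(len(emb_t0), emb_t0)
```

At timestep 0 every sine is 0 and every cosine is 1; later timesteps rotate each pair at a different rate, which is what lets a downstream layer tell timesteps apart.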

---

### 3. **Diffusion Schedule**

A `DiffusionSchedule` class is defined to:

* Generate a sequence of `beta` values linearly spaced over time.
* Derive `alpha` and `alpha_bar` values for the forward diffusion.
* Implement `q_sample()`, which adds Gaussian noise to input embeddings according to the timestep — simulating the forward diffusion process.
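The schedule described above can be sketched without tensors. This is a scalar illustration under the same defaults the notebook uses (1000 steps, `beta` from 1e-4 to 0.02); the function names `linear_schedule` and `q_sample_scalar` are hypothetical.

```python
import math

def linear_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas and the cumulative product of (1 - beta)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    alpha_bar, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bar.append(running)
    return betas, alpha_bar

def q_sample_scalar(x0, noise, a_bar):
    """Forward diffusion for one scalar: sqrt(a_bar)*x0 + sqrt(1-a_bar)*noise."""
    return math.sqrt(a_bar) * x0 + math.sqrt(1.0 - a_bar) * noise

betas, alpha_bar = linear_schedule()
# With the noise term set to 0 we see only how much of the signal survives.
early = q_sample_scalar(1.0, 0.0, alpha_bar[0])
late = q_sample_scalar(1.0, 0.0, alpha_bar[-1])
print(early, late)
```

`alpha_bar` shrinks monotonically, so the signal is almost intact at the first timestep and almost entirely replaced by noise at the last one.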

---

### 4. **Dataset Preparation (HellaSwag)**

A custom dataset class `HellaSwagDataset` is defined to:

* Tokenize both the context and each of the four multiple-choice endings.
* Convert words into token IDs using a vocabulary.
* Return all four options (tokenized) and the correct label for each example.
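As a small illustration of the tokenization step, the sketch below applies the same lowercase-and-split rule to a toy vocabulary. The vocabulary and sentences are made up for the example.

```python
def tokenize(text, vocab):
    """Whitespace tokenization with an [UNK] fallback, as in the dataset class."""
    return [vocab.get(word, vocab["[UNK]"]) for word in text.lower().split()]

# Toy vocabulary (assumption for illustration only).
toy_vocab = {"[PAD]": 0, "[UNK]": 1, "a": 2, "man": 3, "is": 4, "sitting": 5}

context = "A man is sitting"
endings = ["on a roof.", "is a man"]
# Each choice is the context concatenated with one ending.
choices = [tokenize(context + " " + end, toy_vocab) for end in endings]
print(choices)
```

Because the split is on whitespace only, punctuation stays attached to words: `roof.` and `roof` would be two different vocabulary entries, and here `roof.` falls back to `[UNK]`.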

---

### 5. **Collate Function for Padding**

The `collate_fn()` function ensures:

* Each batch is grouped by the answer position (i.e., all A choices together, all B choices together, etc.).
* Sequences are padded so they can be stacked into uniform tensors.
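The grouping-then-padding idea can be shown with plain lists. This sketch mirrors the `zip(*...)` transposition used by `collate_fn()`, but replaces `pad_sequence` with a hypothetical list-based helper so it runs without PyTorch.

```python
def pad_to_longest(sequences, pad_id=0):
    """Right-pad every sequence to the length of the longest one."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (longest - len(seq)) for seq in sequences]

def group_and_pad(batch, pad_id=0):
    """batch: list of (four token lists, label). Returns one padded group per choice slot."""
    all_choices = [choices for choices, _ in batch]
    labels = [label for _, label in batch]
    # Transpose: group all first choices together, all second choices together, ...
    per_position = list(zip(*all_choices))
    return [pad_to_longest(list(group), pad_id) for group in per_position], labels

batch = [
    ([[2, 3], [2, 3, 4], [2], [2, 5, 6, 7]], 1),
    ([[8, 9, 10], [8], [8, 9], [8, 9]], 3),
]
padded, labels = group_and_pad(batch)
print(padded, labels)
```

Each choice slot is padded independently, so the four groups in one batch can end up with different sequence lengths.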

---

### 6. **The Model: ContinuousDiffusionModel**

This model:

* Embeds tokens using `nn.Embedding`.
* Projects sinusoidal timestep embeddings and adds them to token embeddings.
* Feeds the result into an `LSTM`.
* Outputs predicted denoised embeddings using a final linear layer.

---

### 7. **Training Logic**

Inside the `train_model()` function:

#### a. **Load HellaSwag Dataset**

* Load 50,000 examples from the training set.
* Load 100 examples from the validation set.

#### b. **Build Vocabulary**

* Go through all text in the dataset (context and endings).
* Count word frequency.
* Assign each word a unique token ID.
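A toy version of the vocabulary-building steps, using the same `Counter` approach on three made-up sentences. The function name `build_vocab` is an illustrative wrapper, not something the notebook defines.

```python
from collections import Counter

def build_vocab(texts):
    """Count lowercase whitespace tokens and assign IDs by descending frequency."""
    counter = Counter()
    for text in texts:
        counter.update(text.lower().split())
    # IDs start at 2 so that 0 and 1 stay free for the special tokens.
    vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.most_common())}
    vocab["[PAD]"] = 0
    vocab["[UNK]"] = 1
    return vocab

toy_texts = ["The cat sat", "the cat ran", "the dog"]
vocab = build_vocab(toy_texts)
print(vocab)
```

The most frequent word receives the smallest regular ID, and every distinct token gets its own entry because no frequency cutoff is applied.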

#### c. **Create DataLoaders**

* Use the `HellaSwagDataset` and `collate_fn()` to batch and pad sequences.
* Shuffle training data.

#### d. **Initialize Components**

* Instantiate the model, optimizer (`Adam`), loss function (`MSE`), and diffusion schedule.

---

### 8. **Training Loop (50 Epochs)**

For each epoch:

#### a. **Training Phase**

* For each batch:

  * Sample random diffusion timesteps.
  * For each answer option:

    * Embed the tokens.
    * Add noise (simulate forward process).
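    * Predict the noise using the model. In the code below, the model is called on the clean token IDs; the noised embedding (`noisy_emb`) is computed but not passed to it.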
    * Compute loss between predicted and real noise (MSE).
    * Collect scores for all options.
  * Pick the option with the lowest noise-prediction MSE (the one the model finds easiest to denoise) and compare it with the label to compute accuracy.
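For comparison, here is a hedged sketch of one training step in which the denoiser actually receives the noised embeddings. It is not the notebook's model: the class `NoisyInputDenoiser`, its simplified scalar timestep conditioning, and the tiny dimensions are all assumptions made to keep the example short.

```python
import torch
import torch.nn as nn

class NoisyInputDenoiser(nn.Module):
    """Hypothetical variant that consumes noisy embeddings instead of token IDs."""
    def __init__(self, embed_dim=8, hidden_dim=16):
        super().__init__()
        self.t_proj = nn.Linear(1, embed_dim)  # simplified timestep conditioning (assumption)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, embed_dim)

    def forward(self, noisy_emb, t, num_timesteps=1000):
        t_feat = (t.float() / num_timesteps).view(-1, 1)   # [batch, 1]
        h = noisy_emb + self.t_proj(t_feat).unsqueeze(1)   # broadcast over the sequence
        out, _ = self.lstm(h)
        return self.fc_out(out)

torch.manual_seed(0)
embedding = nn.Embedding(20, 8, padding_idx=0)
tokens = torch.randint(1, 20, (4, 5))        # toy batch: 4 sequences of 5 tokens
t = torch.randint(0, 1000, (4,))
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

emb = embedding(tokens)
noise = torch.randn_like(emb)
a_bar = alpha_bar[t].view(-1, 1, 1)
noisy_emb = a_bar.sqrt() * emb + (1 - a_bar).sqrt() * noise  # forward diffusion

model = NoisyInputDenoiser()
noise_pred = model(noisy_emb, t)             # the model sees the noised input
loss = nn.functional.mse_loss(noise_pred, noise)
loss.backward()
print(noise_pred.shape, loss.item())
```

In this arrangement the input carries information about the sampled noise, so the prediction target is learnable in principle. Whether it improves HellaSwag accuracy is not something this sketch can show.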

#### b. **Validation Phase**

* Repeat the above steps, but without gradient updates.

#### c. **Print Metrics**

* Log training and validation loss and accuracy for each epoch.

---

### 9. **Saving the Model**

At the end of training:

* Save the model's weights as `final_model.pth` under `./saved_models/`.

---

### 10. **Run Training**

When the script is run (`__main__`), the `train_model()` function is called, kicking off the full training process.

---

## Summary of Key Features

* This project applies **diffusion principles** to language modeling.
* It **compares multiple-choice answers** by scoring each option with its noise-prediction error and picking the option with the lowest MSE.
* It adds **Gaussian noise** to learned **token embeddings** and uses an **LSTM** to predict that noise.
* Accuracy is the fraction of examples for which the chosen ending matches the labeled one among the four.

---

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from collections import Counter
import os
import math

# Positional Embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timestep):
        device = timestep.device
        half_dim = self.dim // 2
        emb_scale = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb_scale)
        emb = timestep[:, None].float() * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

# Diffusion Schedule
class DiffusionSchedule:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.timesteps = timesteps
        self.beta = torch.linspace(beta_start, beta_end, timesteps)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
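        # Move the schedule to the target device before indexing; a CPU tensor
        # cannot be indexed with CUDA timesteps.
        a_bar = self.alpha_bar.to(x_start.device)[t.to(x_start.device)].unsqueeze(-1).unsqueeze(-1)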
        return torch.sqrt(a_bar) * x_start + torch.sqrt(1 - a_bar) * noise

# Dataset
class HellaSwagDataset(Dataset):
    def __init__(self, hf_dataset_split, vocab):
        self.data = hf_dataset_split
        self.vocab = vocab

    def tokenize(self, text):
        return [self.vocab.get(word, self.vocab['[UNK]']) for word in text.lower().split()]

    def __getitem__(self, idx):
        item = self.data[idx]
        context = item['ctx']
        endings = item['endings']
        label = int(item['label'])
        tokenized_choices = [torch.tensor(self.tokenize(context + " " + end), dtype=torch.long) for end in endings]
        return tokenized_choices, label

    def __len__(self):
        return len(self.data)

# Collate Function
def collate_fn(batch):
    all_choices, labels = [], []
    for choices, label in batch:
        all_choices.append(choices)
        labels.append(label)

    choices_per_pos = list(zip(*all_choices))
    padded_choices = [pad_sequence(choice_tensors, batch_first=True, padding_value=0)
                      for choice_tensors in choices_per_pos]

    return padded_choices, torch.tensor(labels, dtype=torch.long)

# Model
class ContinuousDiffusionModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, timestep_emb_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.timestep_emb = SinusoidalPosEmb(timestep_emb_dim)
        self.timestep_proj = nn.Linear(timestep_emb_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x_tokens, t):
        emb = self.embedding(x_tokens)
        t_emb = self.timestep_emb(t)
        t_proj = self.timestep_proj(t_emb).unsqueeze(1)
        emb = emb + t_proj
        lstm_out, _ = self.lstm(emb)
        return self.fc_out(lstm_out)

# Training
def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_size = 50000
    hf_dataset_train = load_dataset("hellaswag", split=f"train[:{dataset_size}]", cache_dir="./cache", trust_remote_code=True)
    hf_dataset_val = load_dataset("hellaswag", split="validation[:100]", cache_dir="./cache", trust_remote_code=True)

    counter = Counter()
    for item in hf_dataset_train:
        counter.update(item['ctx'].lower().split())
        for end in item['endings']:
            counter.update(end.lower().split())

    vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.most_common())}
    vocab['[PAD]'] = 0
    vocab['[UNK]'] = 1

    train_dataset = HellaSwagDataset(hf_dataset_train, vocab)
    val_dataset = HellaSwagDataset(hf_dataset_val, vocab)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

    model = ContinuousDiffusionModel(len(vocab), 128, 256).to(device)
    diffusion = DiffusionSchedule()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(50):
        model.train()
        total_loss, correct, samples = 0.0, 0, 0

        for padded_choices, labels in train_loader:
            batch_size = labels.size(0)
            labels = labels.to(device)
            optimizer.zero_grad()
            t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device)

            losses, scores = [], []
            for choice in padded_choices:
                choice = choice.to(device)
                emb = model.embedding(choice)
                noise = torch.randn_like(emb)
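                # NOTE: noisy_emb is computed but never fed to the model; the call below
                # sees only the clean token IDs, so it has no access to `noise`.
                noisy_emb = diffusion.q_sample(emb, t, noise)
                noise_pred = model(choice, t)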
                loss = criterion(noise_pred, noise)
                losses.append(loss)
                scores.append(-((noise_pred - noise) ** 2).mean(dim=[1, 2]))

            total = torch.stack(losses).sum()
            total.backward()
            optimizer.step()

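            # `total` is the sum of four batch-mean MSEs, so total_loss / samples
            # (printed below) is scaled by roughly 4 / batch_size, not a per-sample mean.
            total_loss += total.item()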
            pred = torch.stack(scores, dim=1).argmax(dim=1)
            correct += (pred == labels).sum().item()
            samples += batch_size

        print(f"Epoch {epoch+1} (Train): Train Loss = {total_loss/samples:.4f}, Accuracy = {correct/samples:.4f}")

        model.eval()
        val_loss, val_correct, val_samples = 0.0, 0, 0
        with torch.no_grad():
            for padded_choices, labels in val_loader:
                batch_size = labels.size(0)
                labels = labels.to(device)
                t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device)

                losses, scores = [], []
                for choice in padded_choices:
                    choice = choice.to(device)
                    emb = model.embedding(choice)
                    noise = torch.randn_like(emb)
                    # NOTE: as in training, noisy_emb is computed but not passed to the model.
                    noisy_emb = diffusion.q_sample(emb, t, noise)
                    noise_pred = model(choice, t)
                    loss = criterion(noise_pred, noise)
                    losses.append(loss)
                    scores.append(-((noise_pred - noise) ** 2).mean(dim=[1, 2]))

                if losses:
                    val_loss += torch.stack(losses).sum().item()
                    pred = torch.stack(scores, dim=1).argmax(dim=1)
                    val_correct += (pred == labels).sum().item()
                    val_samples += batch_size

        print(f"Epoch {epoch+1} (Test): Val Loss = {val_loss/val_samples:.4f}, Accuracy = {val_correct/val_samples:.4f}")

    # Save the final trained model
    save_path = "./saved_models/final_model.pth"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"\n✅ Model saved to {save_path}")

# Entry point
if __name__ == "__main__":
    train_model()


Epoch 1 (Train): Train Loss = 0.2501, Accuracy = 0.2477
Epoch 1 (Test): Val Loss = 0.2802, Accuracy = 0.2400
Epoch 2 (Train): Train Loss = 0.2501, Accuracy = 0.2504
Epoch 2 (Test): Val Loss = 0.2797, Accuracy = 0.3200
Epoch 3 (Train): Train Loss = 0.2501, Accuracy = 0.2472
Epoch 3 (Test): Val Loss = 0.2796, Accuracy = 0.2200
Epoch 4 (Train): Train Loss = 0.2501, Accuracy = 0.2522
Epoch 4 (Test): Val Loss = 0.2797, Accuracy = 0.2800
Epoch 5 (Train): Train Loss = 0.2501, Accuracy = 0.2476
Epoch 5 (Test): Val Loss = 0.2803, Accuracy = 0.2400
Epoch 6 (Train): Train Loss = 0.2501, Accuracy = 0.2542
Epoch 6 (Test): Val Loss = 0.2799, Accuracy = 0.3200
Epoch 7 (Train): Train Loss = 0.2501, Accuracy = 0.2505
Epoch 7 (Test): Val Loss = 0.2802, Accuracy = 0.2700
Epoch 8 (Train): Train Loss = 0.2501, Accuracy = 0.2487
Epoch 8 (Test): Val Loss = 0.2805, Accuracy = 0.3000
Epoch 9 (Train): Train Loss = 0.2501, Accuracy = 0.2481
Epoch 9 (Test): Val Loss = 0.2800, Accuracy = 0.2100
Epoch 10 (Train): T
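
One way to read these flat numbers, offered as an interpretation rather than a verified diagnosis: a predictor with no access to the sampled noise can do no better than output zeros, which gives an MSE of about 1 against standard normal noise. The loop sums four such per-choice losses per batch and divides by the number of samples, so full batches of 16 yield about 4/16 = 0.25. The sketch below checks both pieces of arithmetic; the helper names are hypothetical.

```python
import math
import random

def zero_predictor_mse(num_values=200_000, seed=0):
    """MSE of always predicting 0 for standard normal noise (Monte Carlo estimate)."""
    rng = random.Random(seed)
    return sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(num_values)) / num_values

def logged_loss(num_samples, batch_size=16, num_choices=4, per_choice_mse=1.0):
    """Value the loop would print if every per-choice batch loss equals per_choice_mse."""
    num_batches = math.ceil(num_samples / batch_size)
    return num_batches * num_choices * per_choice_mse / num_samples

baseline = zero_predictor_mse()
val_estimate = logged_loss(num_samples=100)  # 100 validation examples -> 7 batches
print(baseline, val_estimate)
```

For the 100 validation examples this gives 7 × 4 / 100 = 0.28, close to the logged validation loss, and the accuracies near 0.25 match chance for four options.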

---

### Step-by-Step Guide: Saving Vocabulary from the HellaSwag Dataset

---

**1. Load Training Data**
You start by loading a portion of the HellaSwag training dataset (e.g., 50,000 examples). This ensures you're using the same data that the model was trained on, which is important for consistency.

---

**2. Count Word Frequencies**
Next, for each example:

* You extract the main context text and the list of possible endings.
* You lowercase all words and split the text into individual tokens.
* You use a counter to tally how often each word appears across the dataset.

---

**3. Create the Vocabulary Dictionary**
Once all words are counted:

* You assign an index to each word, starting from 2 (to leave room for special tokens).
* You explicitly add two special tokens:

  * `[PAD]` with index `0` for padding.
  * `[UNK]` with index `1` for unknown words not in the vocab.

---

**4. Save the Vocabulary File**
You then:

* Create a directory if it doesn't exist (`./saved_models`).
* Save the vocabulary dictionary as a file (`vocab.pkl`) using Python's `pickle` module.

---

**5. Confirm Successful Save**
Finally, you print a message showing how many entries are in your vocab and confirming that it was saved to the correct file path.

---

### Purpose

This script prepares and preserves the vocabulary used during training so you can **reuse the exact same word-to-index mapping** later when evaluating the model or generating text.

---

In [63]:
import os
import pickle
from collections import Counter
from datasets import load_dataset

# Load the same data split you used for training
dataset_size = 50000
hf_dataset_train = load_dataset("hellaswag", split=f"train[:{dataset_size}]", cache_dir="./cache", trust_remote_code=True)

# Build vocab from training text
counter = Counter()
for item in hf_dataset_train:
    counter.update(item['ctx'].lower().split())
    for end in item['endings']:
        counter.update(end.lower().split())

# Create vocab dictionary
vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.most_common())}
vocab['[PAD]'] = 0
vocab['[UNK]'] = 1

# Save to file
os.makedirs("./saved_models", exist_ok=True)
with open("./saved_models/vocab.pkl", "wb") as f:
    pickle.dump(vocab, f)

print(f"✅ Vocab saved with {len(vocab)} entries to ./saved_models/vocab.pkl")


✅ Vocab saved with 106630 entries to ./saved_models/vocab.pkl


---

### Step-by-Step Guide: Loading a Saved Vocabulary

---

**1. Open the Saved File**
The script begins by locating and opening the file named `vocab.pkl` that was previously saved in the `./saved_models/` directory.

---

**2. Load the Vocabulary into Memory**
Using Python’s `pickle` module (which handles object serialization), the file is read and the saved vocabulary dictionary is loaded back into the program.

---

**3. Result**
After running this script, you'll have a `vocab` dictionary in memory that maps words to their corresponding integer IDs — exactly as they were during training.
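A quick round-trip check can confirm that saving and loading preserve the mapping. This sketch uses a toy dictionary and a temporary directory instead of `./saved_models/`, both chosen only for illustration.

```python
import os
import pickle
import tempfile

# Toy vocabulary standing in for the real one (assumption).
vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "vocab.pkl")
    with open(path, "wb") as f:
        pickle.dump(vocab, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == vocab)
```

Since unpickling can execute arbitrary code, `pickle.load` should only be used on files from a trusted source.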

---

### Purpose

This allows you to **reuse the same vocabulary mapping** during model evaluation, inference, or text preprocessing — ensuring consistency with what the model was trained on.

---

In [69]:
import pickle

with open("./saved_models/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)


---

### Step-by-Step: Loading a Trained Model with Vocabulary

---

#### 1. **Define the Sinusoidal Positional Embedding**

* A custom neural network layer is created to generate embeddings based on a diffusion timestep.
* The same sinusoidal construction is used for token positions in transformers; here it tells the model how far along the diffusion process the input is, not where a token appears in the sequence.

---

#### 2. **Build the Model Architecture**

* A class called `ContinuousDiffusionModel` is defined:

  * It uses a word embedding layer to turn token IDs into vectors.
  * It adds a learned projection of the timestep embedding to every token embedding.
  * Then it passes everything through an LSTM (a type of RNN).
  * Finally, it uses a linear layer to map the output to the same dimensional space as the embeddings.

---

#### 3. **Rebuild the Vocabulary**

* The HellaSwag dataset (subset of 50,000 examples) is loaded.
* A `Counter` goes through all words in the context (`ctx`) and endings to count word frequencies.
* A vocabulary dictionary is created:

  * Most common words get unique IDs starting from 2.
  * Special tokens `[PAD]` and `[UNK]` are added with IDs 0 and 1.

---

#### 4. **Initialize the Model**

* A new instance of the model is created using the size of the reconstructed vocabulary.
* The model is moved to GPU if available, otherwise it uses CPU.

---

#### 5. **Load the Trained Weights**

* The script attempts to load pre-trained weights from `./saved_models/final_model.pth`.
* If successful, the model is put into evaluation mode (inference-ready).
* If it fails (e.g., due to vocabulary mismatch), an error message is printed.
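A vocabulary mismatch can be caught before loading by comparing the embedding table's row count with the vocabulary size. The sketch below works on a plain mapping from parameter name to shape so it runs without a checkpoint; `check_vocab_matches` is a hypothetical helper, while `embedding.weight` is the state-dict key that the model's `self.embedding` layer produces.

```python
def check_vocab_matches(param_shapes, vocab, key="embedding.weight"):
    """Raise if the checkpoint's embedding rows differ from the vocabulary size."""
    rows = param_shapes[key][0]
    if rows != len(vocab):
        raise ValueError(
            f"checkpoint has {rows} embedding rows, vocab has {len(vocab)} entries"
        )
    return rows

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2}
ok_shapes = {"embedding.weight": (3, 128)}
bad_shapes = {"embedding.weight": (5, 128)}

matched_rows = check_vocab_matches(ok_shapes, toy_vocab)
try:
    check_vocab_matches(bad_shapes, toy_vocab)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
print(matched_rows, mismatch_detected)
```

With a real checkpoint, the shapes could be collected as `{name: tuple(tensor.shape) for name, tensor in state_dict.items()}` before calling the helper.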

---

### Result:

By the end of this script:

* You have a fully defined model architecture.
* The vocabulary used during training has been rebuilt.
* The trained model weights are loaded, and the model is ready for use.

---

In [70]:
import torch
import torch.nn as nn
import math
from collections import Counter
from datasets import load_dataset

# Sinusoidal Positional Embedding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timestep):
        device = timestep.device
        half_dim = self.dim // 2
        emb_scale = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb_scale)
        emb = timestep[:, None].float() * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

# Model Definition
class ContinuousDiffusionModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, timestep_emb_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.timestep_emb = SinusoidalPosEmb(timestep_emb_dim)
        self.timestep_proj = nn.Linear(timestep_emb_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x_tokens, t):
        emb = self.embedding(x_tokens)
        t_emb = self.timestep_emb(t)
        t_proj = self.timestep_proj(t_emb).unsqueeze(1)
        emb = emb + t_proj
        lstm_out, _ = self.lstm(emb)
        return self.fc_out(lstm_out)

# Vocab Reconstruction
hf_dataset_train = load_dataset("hellaswag", split="train[:50000]", cache_dir="./cache", trust_remote_code=True)

counter = Counter()
for item in hf_dataset_train:
    counter.update(item['ctx'].lower().split())
    for end in item['endings']:
        counter.update(end.lower().split())

vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.most_common())}
vocab['[PAD]'] = 0
vocab['[UNK]'] = 1

# Load the Model Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ContinuousDiffusionModel(len(vocab), 128, 256).to(device)

try:
    model.load_state_dict(torch.load("./saved_models/final_model.pth", map_location=device))
    model.eval()
    print("✅ Model loaded and ready to use.")
except Exception as e:
    print("❌ Failed to load model:", e)


✅ Model loaded and ready to use.


---

### Step-by-Step: Evaluating a Trained Diffusion-Based Model on HellaSwag

---

#### 1. **Import Required Libraries**

The script imports PyTorch, dataset handling (`datasets`), `pickle` (for loading the saved vocabulary), and other essentials for model evaluation.

---

#### 2. **Redefine the Necessary Classes**

These classes and the `collate_fn` helper are reused from training, so they are redefined here to match the structure used when the model was trained:

* **`SinusoidalPosEmb`**: Creates a positional embedding based on the timestep.
* **`ContinuousDiffusionModel`**: Neural model with embeddings, an LSTM layer, and a linear output layer.
* **`DiffusionSchedule`**: Defines how noise is added over time using a beta schedule.
* **`HellaSwagDataset`**: Prepares the dataset for use with your model (including tokenization).
* **`collate_fn`**: Handles padding and batching during loading.

---

#### 3. **Load the Dataset**

* Loads 100 validation samples from the HellaSwag dataset using Hugging Face’s `load_dataset` function.
* This dataset will be used to test the model’s multiple-choice prediction ability.

---

#### 4. **Load the Vocabulary**

* Reads the previously saved vocabulary (`vocab.pkl`) from disk.
* This ensures token IDs match what the model was trained with.

---

#### 5. **Rebuild the Model**

* Reconstructs the trained model with the correct vocabulary size and architecture.
* Loads pre-trained weights from `final_model.pth`.

---

#### 6. **Prepare the Dataset and DataLoader**

* The validation set is processed into a custom PyTorch `Dataset`.
* A `DataLoader` is created to iterate over the examples one at a time (`batch_size=1`).

---

#### 7. **Run the Evaluation Loop**

* For each question in the validation set:

  * All four answer choices are embedded.
  * Noise is added using the diffusion schedule.
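  * The model predicts the noise for each choice. As in training, it is called on the clean token IDs; the noised embedding is computed but not passed to it.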
  * The mean squared error (MSE) between predicted and actual noise is calculated.
  * The choice with the lowest MSE is selected as the model’s prediction.
* The model's prediction is compared to the ground-truth label.
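The selection rule can be illustrated with the scores printed for Q1 in the output further down, written here as MSE values. The helper `pick_choice` and the margin it reports are illustrative additions, not part of the evaluation script.

```python
def pick_choice(mse_per_choice):
    """Return the index with the lowest MSE and the gap to the runner-up score."""
    scores = [-mse for mse in mse_per_choice]          # higher score = lower error
    best = max(range(len(scores)), key=scores.__getitem__)
    ranked = sorted(scores, reverse=True)
    margin = ranked[0] - ranked[1]
    return best, margin

# MSE values corresponding to the Q1 scores (-1.01, -1.05, -1.02, -0.98).
q1_mse = [1.01, 1.05, 1.02, 0.98]
pred, margin = pick_choice(q1_mse)
print(pred, round(margin, 2))
```

The small margin shows how close the four scores are. Because each example uses a single random timestep and a single noise draw, a gap this small could plausibly change from run to run.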

---

#### 8. **Print the Results**

* For each question:

  * The context and all answer choices are printed.
  * The model’s prediction is shown, along with the correct answer.
* At the end:

  * The script prints the total accuracy over the 100 validation examples.

---

### Result

You get a detailed evaluation report showing:

* Which questions the model got right or wrong.
* The score for each option (the negative MSE), which shows how close the decision was.
* The overall accuracy across all tested samples.

---

In [74]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from collections import Counter
import math
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pickle

# Re-define or import your classes here:
# If you saved your training script as `your_diffusion_script.py`, you can do:
# from your_diffusion_script import ContinuousDiffusionModel, DiffusionSchedule, HellaSwagDataset, collate_fn


# Redefine SinusoidalPosEmb (needed by model)
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timestep):
        device = timestep.device
        half_dim = self.dim // 2
        emb_scale = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb_scale)
        emb = timestep[:, None].float() * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

# Redefine ContinuousDiffusionModel
class ContinuousDiffusionModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, timestep_emb_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.timestep_emb = SinusoidalPosEmb(timestep_emb_dim)
        self.timestep_proj = nn.Linear(timestep_emb_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x_tokens, t):
        emb = self.embedding(x_tokens)
        t_emb = self.timestep_emb(t)
        t_proj = self.timestep_proj(t_emb).unsqueeze(1)
        emb = emb + t_proj
        lstm_out, _ = self.lstm(emb)
        return self.fc_out(lstm_out)

# Redefine DiffusionSchedule
class DiffusionSchedule:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.timesteps = timesteps
        self.beta = torch.linspace(beta_start, beta_end, timesteps)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
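        # Move the schedule to the target device before indexing; a CPU tensor
        # cannot be indexed with CUDA timesteps.
        a_bar = self.alpha_bar.to(x_start.device)[t.to(x_start.device)].unsqueeze(-1).unsqueeze(-1)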
        return torch.sqrt(a_bar) * x_start + torch.sqrt(1 - a_bar) * noise

# Redefine Dataset and collate_fn

class HellaSwagDataset(Dataset):
    def __init__(self, hf_dataset_split, vocab):
        self.data = hf_dataset_split
        self.vocab = vocab

    def tokenize(self, text):
        return [self.vocab.get(word, self.vocab['[UNK]']) for word in text.lower().split()]

    def __getitem__(self, idx):
        item = self.data[idx]
        context = item['ctx']
        endings = item['endings']
        label = int(item['label'])
        tokenized_choices = [torch.tensor(self.tokenize(context + " " + end), dtype=torch.long) for end in endings]
        return tokenized_choices, label

    def __len__(self):
        return len(self.data)

def collate_fn(batch):
    all_choices = []
    labels = []
    for choices, label in batch:
        all_choices.append(choices)
        labels.append(label)

    choices_per_pos = list(zip(*all_choices))
    padded_choices = [pad_sequence(choice_tensors, batch_first=True, padding_value=0)
                      for choice_tensors in choices_per_pos]

    return padded_choices, torch.tensor(labels, dtype=torch.long)

# Load dataset, vocab, model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hf_val = load_dataset("hellaswag", split="validation[:100]", cache_dir="./cache", trust_remote_code=True)
with open('./saved_models/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

model = ContinuousDiffusionModel(len(vocab), 128, 256).to(device)
model.load_state_dict(torch.load("./saved_models/final_model.pth", map_location=device))
model.eval()

diffusion = DiffusionSchedule(timesteps=1000)  # Use the same timesteps as in your training

val_dataset = HellaSwagDataset(hf_val, vocab)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Evaluation loop
total, correct = 0, 0

for padded_choices, labels in val_loader:
    labels = labels.to(device)
    t = torch.randint(0, diffusion.timesteps, (1,), device=device)

    scores = []
    for choice_seq in padded_choices:
        choice_seq = choice_seq.to(device)
        emb = model.embedding(choice_seq)
        noise = torch.randn_like(emb)
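        # NOTE: noisy_emb is computed but not passed to the model; the prediction
        # below is made from the clean token IDs only.
        noisy_emb = diffusion.q_sample(emb, t, noise)
        noise_pred = model(choice_seq, t)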
        mse = F.mse_loss(noise_pred, noise, reduction='none').mean(dim=[1, 2])
        scores.append(-mse)

    scores = torch.stack(scores, dim=1)  # [batch, num_choices]
    pred = scores.argmax(dim=1)
    total += 1
    correct += (pred == labels).item()

    ctx = hf_val[int(total) - 1]['ctx']
    endings = hf_val[int(total) - 1]['endings']
    print(f"\nQ{total}: {ctx}")
    for idx, end in enumerate(endings):
        mark = "✅" if idx == labels.item() else ""
        print(f"  [{idx}] {end} (score={scores[0, idx].item():.2f}) {mark}")
    print(f"→ Predicted: {pred.item()} | Ground Truth: {labels.item()} {'✅' if pred==labels else '❌'}")

accuracy = correct / total * 100
print(f"\nFinal Accuracy on {total} samples: {correct}/{total} = {accuracy:.2f}%")



Q1: A man is sitting on a roof. he
  [0] is using wrap to wrap a pair of skis. (score=-1.01) 
  [1] is ripping level tiles off. (score=-1.05) 
  [2] is holding a rubik's cube. (score=-1.02) 
  [3] starts pulling up roofing on a roof. (score=-0.98) ✅
→ Predicted: 3 | Ground Truth: 3 ✅

Q2: A lady walks to a barbell. She bends down and grabs the pole. the lady
  [0] swings and lands in her arms. (score=-0.97) 
  [1] pulls the barbell forward. (score=-0.99) 
  [2] pulls a rope attached to the barbell. (score=-0.98) 
  [3] stands and lifts the weight over her head. (score=-1.00) ✅
→ Predicted: 0 | Ground Truth: 3 ❌

Q3: Two women in a child are shown in a canoe while a man pulls the canoe while standing in the water, with other individuals visible in the background. the child and a different man
  [0] are then shown paddling down a river in a boat while a woman talks. (score=-1.01) 
  [1] are driving the canoe, they go down the river flowing side to side. (score=-1.02) 
  [2] sit in a can
