# Dataset

In [1]:
import os
import json
from torch.utils.data import Dataset

class BoolQDataset(Dataset):
    """BoolQ split loaded eagerly into memory from a JSON Lines file.

    Every item is a dict with the keys ``idx``, ``question``, ``passage`` and
    ``label``. ``label`` is read from the record's ``answer`` field; any
    field missing from a record is stored as ``None``.
    """

    def __init__(self, base_dir, split):
        """
        Args:
            base_dir (str): Path to the base folder containing dataset splits.
            split (str): Dataset split to use ('train', 'test', or 'dev').

        Raises:
            FileNotFoundError: If ``<base_dir>/<split>.jsonl`` does not exist.
        """
        self.data_path = os.path.join(base_dir, f"{split}.jsonl")
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Dataset split file not found: {self.data_path}")
        self.data = self._load_data()

    def _load_data(self):
        """Parse the JSONL file and return the list of sample dicts.

        Blank lines (e.g. a trailing newline at the end of the file) are
        skipped instead of crashing the JSON parser. ``idx`` is the position
        of the record in the returned list, so ``data[i]['idx'] == i``.
        """
        data = []
        # Explicit encoding: do not depend on the platform's locale default.
        with open(self.data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                sample = json.loads(line)
                data.append({
                    'idx': len(data),
                    'question': sample.get('question'),
                    'passage': sample.get('passage'),
                    'label': sample.get('answer'),
                })
        return data

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict, or a list of sample dicts for a slice.

        Raises:
            IndexError: If an int index is negative or past the end.
                Negative indexing is deliberately not supported.
        """
        if isinstance(idx, slice):
            return [self._get_item(i) for i in range(*idx.indices(len(self)))]

        if isinstance(idx, int) and not 0 <= idx < len(self.data):
            raise IndexError(f"Index {idx} is out of range.")

        return self._get_item(idx)

    def _get_item(self, idx):
        """Return the stored sample at position ``idx`` (it already carries 'idx')."""
        return self.data[idx]

import random

class DatasetWrapper:
    """Select and hold a dataset implementation chosen by a string tag."""

    def __init__(self, dataset_tag, base_dir, split):
        """
        Args:
            dataset_tag (str): Which dataset to build: "boolq" or "gsm8k".
            base_dir (str): Folder passed through to the dataset class.
            split (str): Split name passed through to the dataset class.

        Raises:
            ValueError: If ``dataset_tag`` is not one of the supported tags.
        """
        if dataset_tag == "boolq":
            self.dataset = BoolQDataset(
                base_dir=base_dir,
                split=split
            )

        elif dataset_tag == "gsm8k":
            # NOTE(review): GSM8KDataset is not defined in the visible part of
            # this file; it must be in scope for this branch to work.
            self.dataset = GSM8KDataset(
                base_dir=base_dir,
                split=split
            )

        else:
            raise ValueError(f"Unsupported dataset_tag: {dataset_tag}")

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def get_dataset(self):
        """Return the wrapped dataset object."""
        return self.dataset

    def get_random_samples(self, num_samples, seed=None):
        """Get a list of random samples from the dataset.

        Args:
            num_samples (int): How many samples to draw; capped at the
                dataset size. Sampling is without replacement.
            seed: Optional seed for reproducible draws. A seeded call uses a
                private generator, so it does not reseed the global
                ``random`` module as a side effect. The samples drawn for a
                given seed are the same as with ``random.seed(seed)``.

        Returns:
            list: The sampled dataset items.
        """
        # Unseeded calls keep drawing from the shared module-level generator.
        rng = random if seed is None else random.Random(seed)

        total = len(self.dataset)
        num_samples = min(num_samples, total)
        indices = rng.sample(range(total), num_samples)
        return [self.dataset[i] for i in indices]


# Model

In [2]:
import torch
from transformers import pipeline

# Short aliases for the Hugging Face model ids used in the experiments.
models = {
    "qwen2.5_1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
    "llama3.2_3b": "meta-llama/Llama-3.2-3B-Instruct",
}


class ModelWrapper:
    """Thin wrapper around a ``transformers`` text-generation pipeline."""

    def __init__(self, model_name, max_new_tokens=20, temperature=0.001):
        """
        Args:
            model_name (str): Model id or path handed to ``pipeline``.
            max_new_tokens (int): Generation length cap passed on every
                ``generate`` call.
            temperature (float): Sampling temperature passed on every
                ``generate`` call.
                NOTE(review): temperature presumably only matters when
                sampling is enabled in the model's generation config —
                confirm for each model used.
        """
        # Hyperparams
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        self.model_name = model_name
        self.pipe = pipeline(
            task="text-generation",
            model=model_name,
            device_map="auto"
        )

    def generate(self, messages):
        """Run the pipeline on a chat-style message list.

        messages format:
        [
            {"role": "user", "content": "Who are you?"},
            ...
        ]

        Returns:
            Whatever the pipeline returns, unmodified.
        """
        return self.pipe(
            messages,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature
        )


# Inference

In [3]:
# Run configuration for this notebook session.
# NOTE(review): "rename" looks like a placeholder output name — change it per run.
OUTPUT_FILENAME = "rename"

# BoolQ dev split, read from the Kaggle input folder.
dataset = DatasetWrapper(
    dataset_tag="boolq",
    base_dir="/kaggle/input/boolq-dataset",
    split="dev",
).get_dataset()

# Text-generation model under evaluation.
model = ModelWrapper(model_name=models["qwen2.5_1.5b"])

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

In [None]:
import csv
from tqdm import tqdm

# Run the model on every sample and write one CSV row per sample:
# the sample's idx and the raw text the model answered with.
columns = ["idx", "output"]

# The context manager closes (and flushes) the file even if generation fails
# part-way through, so rows written so far are not lost.
with open(f"{OUTPUT_FILENAME}.csv", "w", newline='', encoding="utf-8") as csvfile:
    csv_writer = csv.writer(csvfile, delimiter=",")
    csv_writer.writerow(columns)

    # dataset[:] materializes the samples as a list for tqdm to iterate.
    for data in tqdm(dataset[:]):

        question = data["question"]
        passage = data["passage"]

        prompt = (
            "You are given the following context:\n"
            f"\n{passage}\n"
            "Answer the given question as only 'true' or 'false':\n"
            f"{question}\n"
        )

        messages = [
            {"role": "user", "content": prompt}
        ]

        output = model.generate(messages)
        # Take the content of the last message in the first returned conversation.
        extracted_output = output[0]["generated_text"][-1]["content"]

        csv_writer.writerow([data["idx"], extracted_output])
        # Print every 100th answer as a progress sanity check.
        if data["idx"] % 100 == 0:
            print(extracted_output)

    