In [10]:
import torch
from torch.utils.data import Dataset
from typing import List, Dict, Any
import os
import sys

project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    

class InstructFineTuningDataset(Dataset):
    def __init__(self, examples: List[Dict[str, Any]], tokenizer, max_length: int, stride: int = 0):
        self.input_sequences: List[torch.Tensor] = []
        self.target_sequences: List[torch.Tensor] = []
        self.loss_masks: List[torch.Tensor] = []
        
        print(f"Initializing InstructFineTuningDataset with {len(examples)} examples")
        
        for example in examples:
            prompt = f"Instruction:\n{example['instruction']}\n\nInput:\n{example['input']}\n\nResponse:\n"
            full_sequence = prompt + example["output"]
            prompt_tokens = tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
            full_tokens = tokenizer.encode(full_sequence, allowed_special={"<|endoftext|>"})
            
            if len(full_tokens) > max_length:
                print(f"Sequence too long ({len(full_tokens)} > {max_length}), skipping")
                continue
            
            loss_mask = [0] * len(prompt_tokens) + [1] * (len(full_tokens) - len(prompt_tokens))
            
            self.input_sequences.append(torch.tensor(prompt_tokens, dtype=torch.long))
            self.target_sequences.append(torch.tensor(full_tokens, dtype=torch.long))
            self.loss_masks.append(torch.tensor(loss_mask, dtype=torch.bool))
        
        print(f"Created {len(self.input_sequences)} training examples")
    
    def __len__(self) -> int:
        return len(self.input_sequences)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            self.input_sequences[idx],
            self.target_sequences[idx], 
            self.loss_masks[idx]
        )

In [11]:
import json
from src.data.tokenizer import get_tokenizer

tokenizer = get_tokenizer("gpt2")
alpaca_data = json.load(open("../data/alpaca_data.json"))
instruct_finetuning_dataset = InstructFineTuningDataset(alpaca_data, tokenizer, max_length=256, stride=0)

Initializing InstructFineTuningDataset with 52002 examples
Sequence too long (290 > 256), skipping
Sequence too long (322 > 256), skipping
Sequence too long (257 > 256), skipping
Sequence too long (264 > 256), skipping
Sequence too long (263 > 256), skipping
Sequence too long (262 > 256), skipping
Sequence too long (409 > 256), skipping
Sequence too long (275 > 256), skipping
Sequence too long (410 > 256), skipping
Sequence too long (320 > 256), skipping
Sequence too long (293 > 256), skipping
Sequence too long (309 > 256), skipping
Sequence too long (284 > 256), skipping
Sequence too long (292 > 256), skipping
Sequence too long (264 > 256), skipping
Sequence too long (461 > 256), skipping
Sequence too long (361 > 256), skipping
Sequence too long (264 > 256), skipping
Sequence too long (292 > 256), skipping
Sequence too long (276 > 256), skipping
Sequence too long (350 > 256), skipping
Sequence too long (295 > 256), skipping
Sequence too long (262 > 256), skipping
Sequence too long (29

In [None]:
from src.data.tokenizer import token_ids_to_text

token_ids_to_text(instruct_finetuning_dataset.input_sequences[0], tokenizer)
token_ids_to_text(instruct_finetuning_dataset.target_sequences[0], tokenizer)





'Instruction:\nGive three tips for staying healthy.\n\nInput:\n\n\nResponse:\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.'