In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
# Checkpoint id shared by the tokenizer and the model so the two cannot drift apart.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"

# use_auth_token=True makes both loaders authenticate with the locally stored
# Hugging Face token.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
# of `token` - switch once the installed version is confirmed to support it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=True)

# ChatDataset encodes with padding='max_length', which raises when the tokenizer
# has no pad token. Fall back to the EOS token in that case so padding works
# without adding a new token (and therefore without resizing the embeddings).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|███████████████████████████████████████████████| 2/2 [00:02<00:00,  1.30s/it]


In [None]:
class ChatDataset(Dataset):
    """Map-style dataset yielding one tokenized (source, target) pair per conversation.

    Every record in ``data`` must expose the conversation text under the
    ``'input'`` key. The text is split on newlines; lines containing
    ``[|Human|]`` are human messages and lines containing ``[|AI|]`` are
    assistant messages. Lines carrying neither tag are ignored.

    Args:
        data: Indexable collection of records that supports ``len()``.
        tokenizer: Object exposing ``encode_plus(text, **kwargs)``. Because
            encoding uses ``padding='max_length'``, a Hugging Face tokenizer
            must have a pad token configured.
        max_length: Length every encoded sequence is truncated/padded to.
            Defaults to 512, the value that was previously hard-coded.
    """

    HUMAN_TAG = '[|Human|]'
    AI_TAG = '[|AI|]'

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of conversations."""
        return len(self.data)

    @classmethod
    def _last_exchange(cls, conversation):
        """Return ``(source, target)`` text: the last human message and its reply.

        ``source`` is the last non-empty human message, or ``""`` when the
        conversation has none. ``target`` is the last non-empty AI message
        that appears *after* that human message, or ``""`` when there is none,
        so a reply to an earlier prompt is never paired with a later one.
        """
        source, target = "", ""
        for line in conversation.split('\n'):
            # Two independent checks (not if/elif): a line containing both
            # tags counts as both a human and an AI message.
            if cls.HUMAN_TAG in line:
                text = line.replace(cls.HUMAN_TAG, '').strip()
                if text:
                    # A new human turn opens a new exchange; whatever reply was
                    # collected so far answered an earlier prompt.
                    source, target = text, ""
            if cls.AI_TAG in line:
                text = line.replace(cls.AI_TAG, '').strip()
                if text:
                    target = text
        return source, target

    def _encode(self, text):
        """Tokenize ``text`` to exactly ``self.max_length`` positions."""
        return self.tokenizer.encode_plus(
            text,
            return_tensors='pt',
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
        )

    def __getitem__(self, idx):
        """Return the encoded ``(source, target)`` pair for conversation ``idx``.

        NOTE(review): with ``return_tensors='pt'`` on a single string, each
        returned tensor is expected to keep a leading dimension of 1, so the
        default DataLoader collation would produce ``(batch, 1, max_length)``
        - confirm downstream code squeezes that axis.
        """
        source_text, target_text = self._last_exchange(self.data[idx]['input'])
        return self._encode(source_text), self._encode(target_text)

# Wrap the conversations in a ChatDataset, then serve them in batches of 32
dataset = ChatDataset(data=data, tokenizer=tokenizer)
dataloader = DataLoader(dataset=dataset, batch_size=32)