In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights
import torch

device = "cuda"


model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:00<00:00, 189.89it/s]


In [2]:
import json
import torch
from torch.utils.data import Dataset
import os
from transformers import AutoTokenizer
from abc import ABC, abstractmethod
from datasets import load_dataset
from typing import List, Dict, Optional, Iterator, Union
from pathlib import Path
import logging

logger = logging.getLogger(__name__)


class BaseDataset(Dataset, ABC):

    def __init__(self, 
                 tokenizer: AutoTokenizer, 
                 tokenizer_max_length: int = 768,
                 batch_size: int = 4,
                 min_len: int = 50
    ) -> None:
        self.tokenizer = tokenizer
        self.tokenizer_max_length = tokenizer_max_length
        self.batch_size = batch_size
        self.min_len = min_len
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = self._load_data()


    @abstractmethod
    def _load_data(self):
        pass


    def __len__(self) -> int:
        return len(self.data)


    def __getitem__(self, idx) -> Dict:
        item = self.data[idx]
        inputs = self.tokenizer(item, 
                                return_tensors="pt",
                                padding=True,
                                truncation=True,
                                max_length=self.tokenizer_max_length)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
        }


class JSONLDataset(BaseDataset):
    """
    Dataset class for loading and processing JSONL format data.

    Attributes:
        dataset_name (str): Name of the dataset from the data folder
        dataset_folder (str): Name of the folder where the dataset is located
    """

    def __init__(self, 
                 dataset_name: str, 
                 dataset_folder: str = 'data/',
                 **kwargs
    ) -> None:
        """Initialize the JSONL dataset."""
        self.dataset_name = dataset_name
        self.dataset_folder = dataset_folder
        super().__init__(**kwargs)


    def _load_data(self) -> List[Dict]:
        """
        Load and parse the JSONL file.

        Returns:
            List of dictionaries containing parsed data entries.
        """
        file_path = os.path.join(self.dataset_folder, self.dataset_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"The file {file_path} does not exist.")
        data = []
        for line in open(f"{file_path}", "r"):
            if "bio-forget-corpus" in self.dataset_name:
                raw_text = json.loads(line)['text']
            else:
                raw_text = line
            if len(raw_text) > self.min_len:
                data.append(str(raw_text))
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]


class WikitextDataset(BaseDataset):
    """
    Dataset class for handling the Wikitext dataset from HuggingFace.

    Attributes:
        dataset_version (str): Name of the version of wikitext to use
    """

    def __init__(self, 
                 dataset_version: str = 'wikitext-2-raw-v1', 
                 **kwargs
    ) -> None:
        self.dataset_version = dataset_version
        super().__init__(**kwargs)


    def _load_data(self) -> List[str]:
        """Load data from Wikitext dataset."""

        dataset = load_dataset("wikitext", self.dataset_version, split="test")
        if dataset is None:
            raise DatasetError("Failed to load Wikitext dataset")
        data = [item['text'] for item in dataset if len(item['text']) > self.min_len]
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]

In [3]:
# We want a Dataset class that loads in a jsonl file, tokenizes the dataset as expected and returns the input ids and attention mask.
import os
import json

class JsonlDataset():
  def __init__(self, tokenizer, tokenizer_max_length, batch_size, min_len, dataset_name, dataset_folder, device):
    self.tokenizer = tokenizer
    if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

    self.tokenizer_max_length = tokenizer_max_length
    self.batch_size = batch_size
    self.min_len = min_len
    self.dataset_name = dataset_name
    self.dataset_folder = dataset_folder
    self.data = []
    self.device = device

  def __getitem__(self, idx):
    item = self.data[idx]
    input_ids = self.tokenizer(item["text"], return_tensors="pt", padding=True, truncation=True, max_length=self.tokenizer_max_length)
    inputs = {key: value.to(self.device) for key, value in input_ids.items()}
    return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}

  def _load_dataset(self):
    dataset_path = os.path.join(self.dataset_folder, self.dataset_name)
    if not os.path.exists(dataset_path):
      raise FileNotFoundError(f"Dataset file not found at {dataset_path}")
    
    data_list = []
    with open(dataset_path, "r") as f:
      for line in f:
        data = json.loads(line)
        if len(data["text"]) > self.min_len:
          data_list.append(data)

    self.data = data_list
    
  def __len__(self):
    return len(self.data)



In [4]:
diogo_forget = JSONLDataset("cyber-forget-corpus.jsonl", tokenizer=tokenizer)


In [5]:
print(diogo_forget[0]["input_ids"].shape)

torch.Size([4, 768])


In [6]:
model(diogo_forget[0]["input_ids"])

CausalLMOutputWithPast(loss=None, logits=tensor([[[ -5.8438,  -5.7812,  -0.2119,  ...,  -4.1250,  -3.6406,  -3.8750],
         [ -8.7500,  -8.7500,   5.5625,  ...,  -5.1875,  -6.3438,  -5.8438],
         [ -6.4688,  -6.4375,   2.3750,  ...,  -5.8750,  -5.6250,  -4.8438],
         ...,
         [ -7.8750,  -8.0000,   5.8438,  ...,  -5.1250,  -7.4688,  -4.1250],
         [ -6.7188,  -6.3750,  11.1875,  ...,  -6.2812,  -8.0625,  -5.0312],
         [ -6.9688,  -5.6562,  10.5625,  ...,  -3.1094,  -5.1875,  -3.9688]],

        [[ -5.8438,  -5.7812,  -0.2119,  ...,  -4.1250,  -3.6406,  -3.8750],
         [ -6.0000,  -5.4062,  -0.7070,  ...,  -1.1484,   0.0430,   1.7266],
         [-10.1875, -10.3125,  -1.7578,  ...,  -9.3750,  -6.3750,  -5.8125],
         ...,
         [ -7.5938,  -6.5938,   8.0000,  ...,  -5.0000,  -2.0938,  -2.7500],
         [ -5.1562,  -4.8750,   9.6250,  ...,  -5.9375,  -5.9375,  -4.1250],
         [ -7.3438,  -6.0000,  10.4375,  ...,  -3.5625,  -4.6250,  -3.6094]],

   

In [10]:
cyber_forget = JsonlDataset(
      tokenizer=tokenizer, tokenizer_max_length=1024, batch_size=1,
      min_len=30, dataset_name="cyber-forget-corpus.jsonl", dataset_folder="data/",
      device=device
    )
cyber_forget._load_dataset()

cyber_retain = JsonlDataset(
      tokenizer=tokenizer, tokenizer_max_length=1024, batch_size=1,
      min_len=30, dataset_name="cyber-retain-corpus.jsonl", dataset_folder="data/",
      device=device
    )
cyber_retain._load_dataset()


In [12]:
print(len(cyber_forget.data))
print(len(cyber_retain.data))

1000
4385


In [8]:
my_input = cyber_forget[0]["input_ids"]
print(my_input.shape)

# What they do in my fork:
inputs = cyber_forget.data[0]["text"]

tokenized_inputs = tokenizer(inputs, max_length=1024, return_tensors="pt").to(device)
print(tokenized_inputs["input_ids"].shape)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


torch.Size([1, 1024])
torch.Size([1, 1024])


In [9]:

model(cyber_forget[0]["input_ids"])

CausalLMOutputWithPast(loss=None, logits=tensor([[[-5.8438, -5.7812, -0.2109,  ..., -4.1250, -3.6406, -3.8750],
         [-7.4688, -7.7500, -1.0547,  ..., -6.7500, -3.8438, -4.9375],
         [-7.6875, -7.9375,  0.2891,  ..., -6.5938, -6.4375, -7.3125],
         ...,
         [-8.6875, -7.7500,  6.8125,  ..., -7.1562, -6.6562, -3.8906],
         [-4.4375, -4.1875,  2.1094,  ..., -4.3125, -1.1172, -2.7344],
         [-9.2500, -8.7500,  9.6875,  ..., -5.9688, -7.8750, -4.8750]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), past_key_values=<transformers.cache_utils.DynamicCache object at 0x72ec12b01f70>, hidden_states=None, attentions=None)