In [1]:
# Libraries required for loading the LLM; they are not installed by default
# ! pip install accelerate
# ! pip install bitsandbytes

In [2]:
import gc
import hashlib
import json
import os
import pickle
import random
import warnings
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForCausalLM, 
    BitsAndBytesConfig, StoppingCriteriaList, StoppingCriteria,
    PreTrainedTokenizer
)

# Set the tokenizers parallelism flag to "false" before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Turn off the memory-efficient and flash scaled-dot-product-attention backends.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
# NOTE: this silences every warning for the rest of the session.
warnings.filterwarnings('ignore')

Using device: cpu


In [3]:
# Load environment variables from a .env file; override=True lets values from
# the file replace variables that are already set in the process environment.
load_dotenv(override=True)

# General variables.
# NOTE: os.getenv returns None for a missing variable, so every int(...) in this
# cell raises TypeError when its variable is not defined.
seed = int(os.getenv("SEED"))
batch_size = int(os.getenv("BATCH_SIZE"))
save_every = int(os.getenv("SAVE_EVERY"))

# Embeddings variables: path of the JSON corpus file.
corpus_path = os.getenv("CORPUS_PATH")

# Retrieval variables: path of the JSON file with the queries.
queries_path = os.getenv("QUERIES_PATH")

## LLM variables
hf_access_token = os.getenv("HF_ACCESS_TOKEN")
llm_id = os.getenv("LLM_ID")
# Number of documents placed in each prompt (Top-k).
num_docs = int(os.getenv("TOP_K"))
max_input_length = int(os.getenv("MAX_INPUT_LENGTH"))
max_output_length = int(os.getenv("MAX_OUTPUT_LENGTH"))
# True only when the variable is exactly the string "True"; anything else (or unset) is False.
normalize_queries = os.getenv("NORMALIZE_QUERIES") == "True"
# NOTE: despite the "_dir" suffix, this value is later passed to read_pickle as a file path.
context_retrieval_dir = os.getenv("CONTEXT_RETRIEVAL_DIR")
# Root directory under which the generated answers are written.
llm_response_dir = os.getenv("LLM_RESPONSE_DIR")

In [4]:
# Authenticate with the Hugging Face Hub using the token read from the environment.
login(token=hf_access_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /Users/sebastiaanbeekman/.cache/huggingface/token
Login successful


In [5]:
# Utility functions
def set_seeds(seed=42):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), stores the seed in ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN auto-tuning for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    
def read_pickle(file_path: str):
    """
    Load and return the object stored in a pickle file.

    Only use this on files you trust: unpickling can execute arbitrary code.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled Python object.
    """
    with open(file_path, "rb") as handle:
        return pickle.load(handle)


def write_pickle(data, file_path: str):
    """
    Serialize an object to a pickle file, replacing any existing file.

    Args:
        data: Object to pickle.
        file_path: Destination path of the pickle file.
    """
    with open(file_path, "wb") as handle:
        pickle.dump(data, handle)


def read_json(file_path: str):
    """
    Parse a JSON file and return its content.

    The file is opened in binary mode and decoded by ``json.load``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON value.
    """
    with open(file_path, "rb") as handle:
        return json.load(handle)

def read_corpus(corpus_path: str):
    """
    Load the corpus and tag every record with its position.

    The corpus file is a JSON object; its values are taken in iteration order
    and each record receives a ``full_corpus_idx`` key holding its 0-based
    position in the returned list.

    Args:
        corpus_path: Path of the JSON corpus file.

    Returns:
        List of corpus records, each carrying ``full_corpus_idx``.
    """
    records = list(read_json(corpus_path).values())
    for position, record in enumerate(records):
        record["full_corpus_idx"] = position
    return records

In [6]:
# Seed all random number generators with the configured SEED before any model
# or data loading happens, so the rest of the notebook is reproducible.
set_seeds(seed)
print(f"Seed set to {seed}")

In [7]:
class LLM:
    """
    Loads a causal language model with its tokenizer and generates text with it,
    with optional 4/8-bit quantization and custom stop sequences.

    Args:
        model_id (str): Identifier of the model to load.
        device (str): Device the tokenized inputs are moved to, e.g. 'cuda'.
        quantization_bits (Optional[int]): 4 or 8 to enable quantization; any other
            value loads the model without a quantization config.
        stop_list (Optional[List[str]]): Strings at which generation should stop;
            a default list is used when None.
        model_max_length (int): Maximum length of the model inputs.

    Attributes:
        device: Device passed at construction.
        model_max_length (int): Maximum input length used for truncation.
        stop_list (List[str]): Stop strings in use.
        bnb_config (Optional[BitsAndBytesConfig]): Quantization config, or None.
        model: The loaded model, in evaluation mode.
        tokenizer: The loaded tokenizer (left padding, left truncation).
        stopping_criteria (StoppingCriteriaList): Criteria built from ``stop_list``.
    """
    def __init__(
        self, 
        model_id: str, 
        device: str = 'cuda', 
        quantization_bits: Optional[int] = None, 
        stop_list: Optional[List[str]] = None, 
        model_max_length: int = 4096
    ):
        self.device = device
        self.model_max_length = model_max_length

        self.stop_list = stop_list
        if stop_list is None:
            self.stop_list = ['\nHuman:', '\n```\n', '\nQuestion:', '<|endoftext|>', '\n']
        
        # Order matters: the model needs the quantization config, and the
        # stopping criteria need the tokenizer.
        self.bnb_config = self._set_quantization(quantization_bits)
        self.model, self.tokenizer = self._initialize_model_tokenizer(model_id)
        self.stopping_criteria = self._define_stopping_criteria()
        

    def _set_quantization(self, quantization_bits: Optional[int]) -> Optional[BitsAndBytesConfig]:
        """
        Configure quantization settings based on the specified number of bits.

        Args:
            quantization_bits: 4 for NF4 with double quantization and bfloat16
                compute, 8 for 8-bit loading.

        Returns:
            The matching BitsAndBytesConfig, or None for any other value.
        """
        # Build the config through its constructor instead of mutating an empty
        # instance, so the options go through the config's own initialization.
        if quantization_bits == 4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        if quantization_bits == 8:
            return BitsAndBytesConfig(load_in_8bit=True)
        return None
 

    def _initialize_model_tokenizer(self, model_id: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """
        Initializes the model and tokenizer with the given model ID.

        Args:
            model_id: Identifier of the model to load.

        Returns:
            The model (in evaluation mode) and its tokenizer.
        """
        model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # NOTE(review): presumably only configs that define max_seq_len use this value — confirm.
        model_config.max_seq_len = self.model_max_length

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            quantization_config=self.bnb_config,
            torch_dtype=torch.bfloat16,
            device_map='auto',
        )
        model.eval() # Set the model to evaluation mode

        # Left padding/truncation keeps the end of the prompt next to the generated tokens.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, padding_side="left", truncation_side="left",
            model_max_length=self.model_max_length
        )
        # Most LLMs don't have a pad token by default
        tokenizer.pad_token = tokenizer.eos_token  

        return model, tokenizer


    def _define_stopping_criteria(self) -> StoppingCriteriaList:
        """
        Defines stopping criteria for text generation based on the provided stop_list.

        Returns:
            A StoppingCriteriaList with one criterion that fires when the first
            sequence of the batch ends with any tokenized stop string.
        """
        # NOTE(review): the tokenizer is called with its default settings, so any
        # special tokens it adds (e.g. a BOS token) become part of the stop
        # sequence and may prevent it from ever matching — confirm for the model in use.
        stop_token_ids = [
            torch.LongTensor(self.tokenizer(stop_string)['input_ids']).to(self.device)
            for stop_string in self.stop_list
        ]
        # An empty stop sequence cannot be compared against a suffix; drop it.
        stop_token_ids = [ids for ids in stop_token_ids if len(ids) > 0]

        class StopOnTokens(StoppingCriteria):
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                # NOTE(review): only the first sequence is inspected, so with a batch
                # of several prompts generation stops for all of them as soon as the
                # first one reaches a stop sequence.
                first_sequence = input_ids[0]
                for stop_ids in stop_token_ids:
                    # A stop sequence longer than the text so far cannot match, and
                    # comparing tensors of different lengths would raise.
                    if len(stop_ids) > len(first_sequence):
                        continue
                    # With device_map='auto' the ids may live on another device
                    # than the one the stop tensors were created on.
                    stop_ids = stop_ids.to(first_sequence.device)
                    if torch.eq(first_sequence[-len(stop_ids):], stop_ids).all():
                        return True
                return False

        return StoppingCriteriaList([StopOnTokens()])
    
    
    def generate(self, prompt: Union[str, List[str]], max_new_tokens: int = 15) -> List[str]:
        """
        Generates text based on the given prompt(s).
        
        Args:
            prompt (Union[str, List[str]]): A single prompt or a batch of prompts.
                Inputs are padded to a common length and truncated to
                ``model_max_length`` tokens.
            max_new_tokens (int): Maximum number of tokens to generate per prompt.
        
        Returns:
            List[str]: One decoded string per prompt, decoded from the sequences
            returned by ``model.generate`` with special tokens removed.
        """
        inputs = self.tokenizer(
            prompt, 
            padding=True, 
            truncation=True, 
            max_length=self.model_max_length, 
            return_tensors="pt"
        ).to(self.device)
        
        # Greedy decoding (do_sample=False); the EOS token doubles as padding token.
        generated_ids = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            stopping_criteria=self.stopping_criteria,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [8]:
# Load the configured model with 4-bit quantization and keep a handle on its
# tokenizer, which the prompt dataset uses to measure prompt lengths.
llm = LLM(
    llm_id,
    device,
    quantization_bits=4,
    model_max_length=max_input_length,
)
tokenizer = llm.tokenizer

In [8]:
# Load the corpus records (each tagged with its position as 'full_corpus_idx').
corpus = read_corpus(corpus_path)
# Load the pickled retrieval results; note that CONTEXT_RETRIEVAL_DIR is used as a file path here.
search_results = read_pickle(context_retrieval_dir)
print(f"Loaded {len(corpus)} records and {len(search_results)} search results for Top-{num_docs}.")

In [9]:
"""
adapted from chemdataextractor.text.normalize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tools for normalizing text.
https://github.com/mcs07/ChemDataExtractor
:copyright: Copyright 2016 by Matt Swain.
:license: MIT

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

#: Control characters.
CONTROLS = {
    '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
    '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
}
# There are further control characters, but they are instead replaced with a space by unicode normalization
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c',  '\u001d', '\u001e', '\u001f'


#: Hyphen and dash characters.
HYPHENS = {
    '-',  # \u002d Hyphen-minus
    '‐',  # \u2010 Hyphen
    '‑',  # \u2011 Non-breaking hyphen
    '⁃',  # \u2043 Hyphen bullet
    '‒',  # \u2012 figure dash
    '–',  # \u2013 en dash
    '—',  # \u2014 em dash
    '―',  # \u2015 horizontal bar
}

#: Minus characters.
MINUSES = {
    '-',  # \u002d Hyphen-minus
    '−',  # \u2212 Minus
    '－',  # \uff0d Full-width Hyphen-minus
    '⁻',  # \u207b Superscript minus
}

#: Plus characters.
PLUSES = {
    '+',  # \u002b Plus
    '＋',  # \uff0b Full-width Plus
    '⁺',  # \u207a Superscript plus
}

#: Slash characters.
SLASHES = {
    '/',  # \u002f Solidus
    '⁄',  # \u2044 Fraction slash
    '∕',  # \u2215 Division slash
}

#: Tilde characters.
TILDES = {
    '~',  # \u007e Tilde
    '˜',  # \u02dc Small tilde
    '⁓',  # \u2053 Swung dash
    '∼',  # \u223c Tilde operator #in mbert vocab
    '∽',  # \u223d Reversed tilde
    '∿',  # \u223f Sine wave
    '〜',  # \u301c Wave dash #in mbert vocab
    '～',  # \uff5e Full-width tilde #in mbert vocab
}

#: Apostrophe characters.
APOSTROPHES = {
    "'",  # \u0027
    '’',  # \u2019
    '՚',  # \u055a
    'Ꞌ',  # \ua78b
    'ꞌ',  # \ua78c
    '＇',  # \uff07
}

#: Single quote characters.
SINGLE_QUOTES = {
    "'",  # \u0027
    '‘',  # \u2018
    '’',  # \u2019
    '‚',  # \u201a
    '‛',  # \u201b

}

#: Double quote characters.
DOUBLE_QUOTES = {
    '"',  # \u0022
    '“',  # \u201c
    '”',  # \u201d
    '„',  # \u201e
    '‟',  # \u201f
}

#: Accent characters.
ACCENTS = {
    '`',  # \u0060
    '´',  # \u00b4
}

#: Prime characters.
PRIMES = {
    '′',  # \u2032
    '″',  # \u2033
    '‴',  # \u2034
    '‵',  # \u2035
    '‶',  # \u2036
    '‷',  # \u2037
    '⁗',  # \u2057
}

#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES

def normalize(text):
    for control in CONTROLS:
        text = text.replace(control, '')
    text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')

    for hyphen in HYPHENS | MINUSES:
        text = text.replace(hyphen, '-')
    text = text.replace('\u00ad', '')

    for double_quote in DOUBLE_QUOTES:
        text = text.replace(double_quote, '"')  # \u0022
    for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
        text = text.replace(single_quote, "'")  # \u0027
    text = text.replace('′', "'")     # \u2032 prime
    text = text.replace('‵', "'")     # \u2035 reversed prime
    text = text.replace('″', "''")    # \u2033 double prime
    text = text.replace('‶', "''")    # \u2036 reversed double prime
    text = text.replace('‴', "'''")   # \u2034 triple prime
    text = text.replace('‷', "'''")   # \u2037 reversed triple prime
    text = text.replace('⁗', "''''")  # \u2057 quadruple prime

    text = text.replace('…', '...').replace(' . . . ', ' ... ')  # \u2026

    for slash in SLASHES:
        text = text.replace(slash, '/')

    for tilde in TILDES:
       text = text.replace(tilde, '~')

    return text


In [10]:
def hash_document(text: str) -> str:
    """
    Return the hexadecimal SHA-256 digest of ``text``.

    The text is encoded with the default ``str.encode`` encoding before hashing.
    """
    digest = hashlib.sha256()
    digest.update(text.encode())
    return digest.hexdigest()

In [22]:
class PromptDataset(Dataset):
    """
    Dataset that turns every query of the queries file into a QA prompt.

    For each sample the context documents are gathered (either the oracle
    paragraphs stored with the sample or the retrieved top-k corpus documents),
    the prompt is built, and the sample is kept only if the prompt's token count
    stays below the configured limit.

    The queries path, the maximum input length, the query-normalization flag and
    the number of documents per prompt are read from the notebook-level settings
    ``queries_path``, ``max_input_length``, ``normalize_queries`` and ``num_docs``.

    Args:
        corpus: Corpus records; each needs 'title' and 'text'.
        tokenizer: Tokenizer used to measure the prompt length; its
            ``name_or_path`` selects the prompt template.
        search_results: Per-query ``(indices, scores)`` pairs, indexed by the
            position of the query in the queries file.
        with_oracle: When True, use the paragraphs in each sample's 'context'
            field instead of the search results.

    Raises:
        ValueError: If the configured document count or token limit is not positive.
        OSError: If the queries file cannot be read.
    """
    # Suffix that ends every default prompt; the MPT template uses its own response key instead.
    _ANSWER_SUFFIX = "\nAnswer:"

    def __init__(
        self, 
        corpus: List[Dict],  
        tokenizer: "PreTrainedTokenizer",
        search_results: List[Tuple[List[int], List[float]]],
        with_oracle: bool = False,
    ):
        super().__init__()
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.search_results = search_results
        self.data_path = queries_path
        # Keep two tokens of headroom below the model input limit
        # (presumably for special tokens added at tokenization time — TODO confirm).
        self.max_tokenized_length = max_input_length - 2
        self.do_normalize_query = normalize_queries
        self.num_documents_in_context = num_docs
        self.with_oracle = with_oracle

        self._validate_initialization_parameters()
        self._load_data()


    def _validate_initialization_parameters(self):
        """Raise ValueError when the document count or the token limit is not positive."""
        if self.num_documents_in_context <= 0:
            raise ValueError("num_documents_in_context must be positive.")
        
        if self.max_tokenized_length <= 0:
            raise ValueError("max_tokenized_length must be positive.")


    def _load_data(self):
        """
        Read the queries file and build the prompts.

        Raises:
            OSError: If the file cannot be read. The error is reported and then
                re-raised, because swallowing it would leave the dataset without
                its sample lists and fail later with an unrelated AttributeError.
        """
        try:
            with open(self.data_path, "r", encoding="utf-8") as reader:
                data = json.load(reader)
        except IOError as e:
            print(f"Error reading file {self.data_path}: {e}")
            raise
        self.process_file_data(data)


    def process_file_data(self, data: List[Dict]):  
        """
        Processes each example in the dataset to prepare prompts for the LLM.

        This involves assembling document contexts, normalizing the query when
        enabled, and skipping examples whose prompt reaches the maximum token length.

        Args:
            data (List[Dict]): The dataset; each entry needs '_id' and 'question',
            plus 'context' (pairs of title and list of text pieces) in oracle mode.
        """
        self.ids = []
        self.queries = []
        self.prompts = []
        self.excluded_samples_ids = []
        self.preprocessed_data = []
        self.prompt_tokens_lengths = []

        for i, sample in enumerate(data):
            sample_id = sample['_id']
            query = sample['question']

            if self.with_oracle:
                # Oracle mode: use every paragraph shipped with the sample, without the top-k limit.
                oracle_docs = [{'title': title, 'text': ' '.join(text)} for title, text in sample['context']]
                formatted_documents, document_indices = self._format_documents(oracle_docs, limit_context=False)
            else:
                # Retrieve the top-k documents for the query
                formatted_documents, document_indices = self.prepare_documents_for_prompt(i)

            # Normalize the query & build the prompt
            documents_str = '\n'.join(formatted_documents)
            if self.do_normalize_query:
                query = normalize(query)
            prompt = self.build_qa_prompt(query, documents_str)

            # Check if the prompt exceeds 'max_tokenized_length'
            tokens_len = len(self.tokenizer.tokenize(prompt))
            if tokens_len >= self.max_tokenized_length:
                self.excluded_samples_ids.append((i, sample_id))
                print("Skipping example {} due to prompt length.".format((i, sample_id)))
                continue

            if len(formatted_documents) < self.num_documents_in_context:
                print(f"Warning: Not enough documents for example {i}.")

            # If the prompt is valid, add it to the dataset
            self.preprocessed_data.append((formatted_documents, list(document_indices)))
            self.ids.append(sample_id)
            self.queries.append(query)
            self.prompts.append(prompt)
            self.prompt_tokens_lengths.append(tokens_len)


    def prepare_documents_for_prompt(self, i: int) -> Tuple[List[str], List[int]]:
        """Return the formatted retrieved documents and their corpus indices for query ``i``."""
        indices = self._get_indices(i) # Get the indices of the top-k documents
        return self._get_documents_from_indices(indices)


    def _get_indices(self, idx: int) -> List[int]:
        """Return the retrieved document indices for query ``idx``, dropping the scores."""
        indices, _ = self.search_results[idx]
        return indices
    

    def _get_documents_from_indices(self, indices: List[int]) -> Tuple[List[str], List[int]]:
        """Look up the corpus records at ``indices`` and format them for the prompt."""
        documents_info = [self.corpus[i] for i in map(int, indices)]
        return self._format_documents(documents_info)
    
    def _format_documents(self, documents_info: List[Dict], limit_context: bool = True, with_corpus_id: bool = True) -> Tuple[List[str], List[int]]:
        """
        Format documents as prompt lines, skipping documents with duplicate text.

        Args:
            documents_info: Records with 'title' and 'text'; 'full_corpus_idx' is optional.
            limit_context: When True, stop once ``num_documents_in_context``
                documents have been collected.
            with_corpus_id: Currently unused; kept for interface compatibility.

        Returns:
            The formatted document strings and the matching corpus indices
            (-1 for records without 'full_corpus_idx').
        """
        seen_hashes = set()
        # List to store the indices of documents actually added
        document_indices = []  
        formatted_documents = []
        for doc_info in documents_info:
            if limit_context and len(formatted_documents) == self.num_documents_in_context:
                break
            
            doc_idx = doc_info.get('full_corpus_idx', -1)
            title = doc_info['title']
            text = doc_info['text']

            doc_hash = hash_document(text)
            # Skip the document if it is a duplicate
            if doc_hash in seen_hashes:
                continue
            seen_hashes.add(doc_hash)
            
            formatted_documents.append(f"Document [{doc_idx}](Title: {title}) {text}")
            document_indices.append(doc_idx)

        return formatted_documents, document_indices


    def build_qa_prompt(self, query: str, documents_str: str) -> str:
        """
        Build the prompt for one query.

        Args:
            query: The question to answer.
            documents_str: The formatted context documents, one per line.

        Returns:
            The default prompt ending in "Answer:", or, when the tokenizer's
            ``name_or_path`` contains 'mpt', the instruction/response template
            ending in "### Response:".
        """
        task_instruction = "You are given a question and you MUST respond by EXTRACTING the answer (max 5 tokens) from one of the provided documents. If none of the documents contain the answer, respond with NO-RES."
        prompt_body = f"{task_instruction}\nDocuments:\n{documents_str}\nQuestion: {query}"

        # Custom prompt format for mpt models: the body is wrapped as an
        # instruction and the response key replaces the "Answer:" suffix.
        if 'mpt' in self.tokenizer.name_or_path:
            intro_blurb = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
            return f"{intro_blurb}\n### Instruction:\n{prompt_body}\n### Response:"

        return f"{prompt_body}{self._ANSWER_SUFFIX}"


    def __getitem__(self, idx: int):
        """Return id, query, prompt, document indices and prompt token count of sample ``idx``."""
        _, document_indices = self.preprocessed_data[idx]

        # NOTE(review): 'document_indices' is a plain list whose length can differ
        # between samples; batching such lists with the DataLoader's default
        # collation may fail for batch sizes above 1 — confirm.
        return {
            "id": self.ids[idx],
            "query": self.queries[idx],
            "prompt": self.prompts[idx],
            "document_indices": document_indices,
            "prompt_tokens_len": self.prompt_tokens_lengths[idx]
        }
    

    def __len__(self):
        """Return the number of samples that passed the length check."""
        return len(self.ids)

In [23]:
def initialize_dataset_and_loader(
    corpus: List[Dict], 
    search_results: List[Tuple[List[int], List[float]]], 
    tokenizer: PreTrainedTokenizer,
    with_oracle: bool = False,
    num_workers: int = 8,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Build the prompt dataset and wrap it in a sequential DataLoader.

    The batch size comes from the notebook-level ``batch_size`` setting and the
    samples are served in dataset order (no shuffling).

    Args:
        corpus: Corpus records passed to PromptDataset.
        search_results: Per-query retrieval results passed to PromptDataset.
        tokenizer: Tokenizer passed to PromptDataset.
        with_oracle: Passed to PromptDataset; selects the oracle documents.
        num_workers: Number of DataLoader worker processes (default 8, the
            value that was previously hard-coded).
        pin_memory: Whether the DataLoader pins memory (default True, the value
            that was previously hard-coded); set to False on CPU-only machines.

    Returns:
        The DataLoader over the prompt dataset.
    """
    prompt_ds = PromptDataset(
        corpus=corpus, 
        tokenizer=tokenizer,
        search_results=search_results,
        with_oracle=with_oracle
    )
    
    prompt_dataloader = DataLoader(
        prompt_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return prompt_dataloader

In [25]:
# Build the prompt loader in oracle mode: each prompt uses the paragraphs
# stored with the query instead of the retrieved documents.
prompt_dataloader = initialize_dataset_and_loader(
    corpus=corpus,
    search_results=search_results,
    tokenizer=tokenizer,
    with_oracle=True,
)
print(f"Initialized prompt dataset with {len(prompt_dataloader)} batches.")

You are given a question and you MUST respond by EXTRACTING the answer (max 5 tokens) from one of the provided documents. If none of the documents contain the answer, respond with NO-RES.
Documents:
Document [427816](Title: Is It Easy to Be Young?) Vai viegli būt jaunam? (Is It Easy to Be Young?) is a Soviet-era Latvian documentary film directed by Juris Podnieks. It was filmed in 1986 with dialog in both Latvian and Russian, and is considered to be among the most controversial movies of its era. It was one of the five winners of the 1987 International Documentary Association awards. The movie speaks about young people who perished as a result of growing up in Soviet society—their conflicts with parents and society, the patronizing attitudes of their teachers and the authorities, the fear that there is no meaning to their lives. Among the young people portrayed are high-schoolers looking for their place in life, a young mother worried about the future of her daughter after the Chernoby

In [None]:
def print_info():
    """Print the run configuration: data path, model, Top-k, batch size and save interval."""
    summary_lines = (
        "INFO:",
        f"DATA: {queries_path}",
        f"MODEL: {llm_id}",
        f"NUM DOCUMENTS IN CONTEXT: {num_docs}",
        f"BATCH SIZE: {batch_size}",
        f"SAVE EVERY: {save_every}",
    )
    for line in summary_lines:
        print(line)

In [None]:
# Show the configuration this run will use before starting generation.
print_info()

In [None]:
def generate_and_save(llm: LLM, prompt_dataloader: DataLoader):
    """
    Generate an answer for every prompt and save the results in pickle chunks.

    Results are written to ``{llm_response_dir}/{model folder}/{num_docs}_doc``.
    Every ``save_every`` batches, and after the last batch, the batches
    accumulated since the previous save are pickled to
    ``numdoc{num_docs}_info_{batch number}.pkl``.

    Args:
        llm: Model wrapper used for generation.
        prompt_dataloader: Loader yielding batches with a 'prompt' entry; each
            batch gets a 'generated_answer' entry added before it is saved.
    """
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{llm_response_dir}/{llm_folder}/{num_docs}_doc"
    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(saving_dir, exist_ok=True)

    # MPT has a different answer string in the prompt
    answer_string_in_prompt = "### Response:" if 'mpt' in llm_id else "Answer:"
    num_batches = len(prompt_dataloader)

    all_info = []  
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):
        prompts = prompt_batch['prompt']
        with torch.no_grad():
            generated_output = llm.generate(prompts, max_new_tokens=max_output_length)
        
        generated_answers = []
        for output in generated_output:
            # The response is whatever follows the first answer marker.
            # NOTE(review): a document that itself contains the marker would
            # shift the start of the extracted response — confirm this is acceptable.
            marker_position = output.find(answer_string_in_prompt)
            if marker_position == -1:
                # Marker missing: keep the whole output instead of slicing from
                # an offset derived from find()'s -1.
                response = output.strip()
            else:
                response = output[marker_position + len(answer_string_in_prompt):].strip()
            generated_answers.append(response)

        prompt_batch['generated_answer'] = generated_answers
        all_info.append(prompt_batch)
        
        if (idx + 1) % save_every == 0 or (idx + 1) == num_batches:
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_docs}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            # Start a fresh chunk so each file holds only the batches since the last save.
            all_info = []
        
        # Free per-batch memory before the next iteration.
        del prompts, generated_output
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Release cached GPU memory before starting the generation run.
torch.cuda.empty_cache()
# Generate answers for every batch and write the periodic pickle files.
generate_and_save(llm, prompt_dataloader)

In [None]:
import shutil

# Zip the responses of the configured model / Top-k. The source directory is
# derived from the same settings generate_and_save writes to, instead of
# hard-coded Kaggle paths that only match one model and one Top-k value.
shutil.make_archive(
    f"{llm_response_dir}/{num_docs}_doc",
    'zip',
    f"{llm_response_dir}/{llm_id.split('/')[1] if '/' in llm_id else llm_id}/{num_docs}_doc",
)