In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch

print(torch.cuda.get_device_name())

import functools
import ipywidgets

import re

import numpy as np
import chardet
import pandas as pd
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.max_rows', 10000)
import random

import pickle
import spacy

from datasets import load_dataset

from sklearn.utils import shuffle

# Create a spacy natural-language processor for English
nlp = spacy.load("en_core_web_sm")

# build the vocabulary for the tf-idf
import string
from spacy.lang.en.stop_words import STOP_WORDS
from spacy.lang.en import English

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers import SentenceTransformer

from faiss import IndexFlatL2, read_index, write_index

NVIDIA A100 80GB PCIe


# DATA AND SUPPORT FUNCTIONS

In [2]:
#NEWS_PATH = "/data01/atrema/20newsgroups/"
#bad_lines = ["From:", "writes:", "wrote:",'.edu', '@', '.gov', 'Subject:']
bad_lines = ['"']
TOPICS = ["talk.politics.misc","sci.med","soc.religion.christian"]
DATASET_LINK = "sentence-transformers/wikipedia-en-sentences"#"Fraser/wiki_sentences"#

In [3]:
def prep_wiki(DATASET_LINK, size):
    data = load_dataset(DATASET_LINK, split='train')
    df = pd.DataFrame({'text' : data['sentence']})[:10**6]
    print('total size : ', len(df))    
    df.text = df.text.apply(lambda x : re.sub(r"\(.*?\)", "", x))
    df.text = df.text.apply(lambda x : rm_special_chars(x))
    df = df[(df.text.str.len() >= 50) & (df.text.str.len() <= 500)]
    print('length range size : ', len(df))

    #for training
    tmp = df[:size] #df.sample(n= size, random_state=88)
    tmp = shuffle(tmp)
    train_collection = tmp.sample(frac=0.8, random_state=99)
    train_set = tmp.drop(train_collection.index)

    #for testing
    gap = 10**5
    tmp = df[gap : gap + size]
    tmp = shuffle(tmp)
    test_collection = tmp.sample(frac=0.8, random_state=99)
    test_set = tmp.drop(test_collection.index)

    #for line in bad_lines:
    #    tmp_1 = df[df['text'].str.contains(line)].index.tolist()
    #    df = df.drop(tmp_1) 
    
    print(f'Train Collection size : {len(train_collection)}')
    print(f'Trainset size : {len(train_set)}')
    print(f'Test Collection size : {len(test_collection)}')
    print(f'Testset size : {len(test_set)}')
    
    return train_collection, train_set, test_collection, test_set


def prep_news(news_types):
    for news_type in news_types:
        with open(NEWS_PATH + news_type + ".txt", 'rb') as f:
            result = chardet.detect(f.read())
            #print(result)
        encoding = result['encoding']
        tmp = pd.read_csv(NEWS_PATH + news_type + ".txt", encoding=encoding, on_bad_lines='skip')
        #topic = tmp.columns[0].strip('Newsgroup:')
        tmp['topic'] = news_type
        tmp =tmp.rename(columns={f'{tmp.columns[0]}' : 'text'})
        #df = df.drop_duplicates(subset=[df.columns.tolist()[0]], keep='first')
        for line in bad_lines:
            tmp_1 = tmp[tmp['text'].str.contains(line)].index.tolist()
            tmp = tmp.drop(tmp_1)
        tmp['text'] = tmp['text'].apply(lambda x : re.sub(r'n[\'|’]t', ' not', x))
        tmp['text'] = tmp['text'].apply(lambda x : rm_special_chars(x))
        #empties = df[df.text == '']
        #print(empties)
        #df = df.drop(empties.index)
        tmp = tmp[tmp['text'].str.len() >= 40]
        tmp = tmp.drop_duplicates(subset=["text"], keep='first')
        tmp_1 = tmp.sample(frac=0.5, random_state=99) # half for collection, half for train-test
        if news_type == news_types[0]:
            collection = tmp.drop(tmp_1.index)
            train = tmp_1.sample(frac=0.8, random_state=88)
            test = tmp_1.drop(train.index)
        else:
            collection_tmp = tmp.drop(tmp_1.index)
            train_tmp =  tmp_1.sample(frac=0.8, random_state=88)
            test_tmp = tmp_1.drop(train_tmp.index)
            collection = pd.concat([collection, collection_tmp])
            train = pd.concat([train, train_tmp])
            test = pd.concat([test, test_tmp])

    train = train.reset_index(drop=True)
    collection = collection['text'].tolist()
    
    print(f"Topics : {train.topic.unique()}")
    print(f'Collection size : {len(collection)}')
    print(f'Trainset size : {len(train)}')
    print(f'Testset size : {len(test)}')
    return collection, train, test

def mask_word(s):
    text = s['text']
    #print(text)
    tmp = text.split(' ')
    tmp = pd.Series(tmp) # this way each word is associated to an index
    possibilities = tmp[tmp.str.len() >= 3] # only choose mask if the word is at least 3 chars long
    #print(possibilities)
    if possibilities.empty :
        s['groundTruth'] = 'DELETE'
        s['text'] = 'DELETE'
        return s
    ind = random.choice(possibilities.index.tolist()) # choose the index
    s['groundTruth'] = tmp.iloc[ind].lower() # save the ground truth
    mask_length = len(tmp.iloc[ind]) 
    tmp.iloc[ind] = '?' * mask_length # f'[[{mask_length} missing characters]]'# mask the word
    s['text'] = (' ').join(tmp)
    return s

def masking(df):
    df = df.apply(mask_word, axis=1)
    df = df.drop(df[df.text == 'DELETE'].index)
    df = df.reset_index(drop=True)
    #print(f'Masked Testset size : {len(df)}')
    return df

def ft_context(s):
    text = s['text']
    tmp = text.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word
    tmp[missing_word_index] = "[MASK]"
    mask_sent = (' ').join(tmp)
    s['context'] = search_lm(mask_sent, top_k=3)
    return s

def contexting(df):
    df = df.apply(ft_context, axis=1)
    return df

In [4]:
# remove special characters
def rm_special_chars(s):
    #stripped = re.sub(r'[^\w\s]', ' ', s)
    stripped = re.sub(r'[^A-Za-z0-9 ]+', ' ', s)
    stripped = re.sub(' br ', ' ', stripped)
    stripped = re.sub(' +', ' ', stripped)
    stripped = stripped.strip()
    return stripped
# remove first word, split each entry into single sentences, remove special characters
def prep_context(c):
    out=[]
    for i in c:
        #i = rm_first_word(i)
        #tmp = [rm_special_chars(element).lower() for element in i.split('.') if element]
        tmp = [element for element in i.split('.') if element]
        out.extend(tmp)
    return out

# Create our list of punctuation marks
punctuations = string.punctuation
# Create our list of stopwords
nlp = spacy.load('en_core_web_sm')
stop_words = spacy.lang.en.stop_words.STOP_WORDS
# Load English tokenizer, tagger, parser, NER and word vectors
parser = English()

# Creating tokenizer function
def spacy_tokenizer(sentence):
    # Creating our token object, which is used to create documents with linguistic annotations.
    sentence = rm_special_chars(sentence)
    mytokens = nlp(sentence)
    # Lemmatizing each token and converting each token into lowercase
    mytokens = [word.lemma_.lower() for word in mytokens ]
    # Removing stop words
    mytokens = [ word for word in mytokens if word not in stop_words and word not in punctuations ]

    # return preprocessed list of tokens
    return mytokens

In [5]:
def tf_idify(context):
    vectorizer = TfidfVectorizer(tokenizer= spacy_tokenizer, token_pattern=None)
    matrix = vectorizer.fit_transform(context)
    words = vectorizer.get_feature_names_out()
    #print("Matrix completed")
    doc_vectors = matrix.toarray()
    dimension = doc_vectors.shape[1]
    index = IndexFlatL2(dimension)
    index.add(doc_vectors)
    return matrix, vectorizer, words, index

def search_tfidf(query, vectorizer, index, documents, top_k):
    query_vector = vectorizer.transform([query]).toarray()
    distances, indices = index.search(query_vector, top_k)
    #print(indices)
    #results = [(news_list[i], distances[0][i]) for i in indices[0]]
    results = [documents[i] for i in indices[0]]
    return results

retriever = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def LM_retriever(documents):
    embeddings = retriever.encode(documents)
    embeddings_dimension = embeddings.shape[1]
    index = IndexFlatL2(embeddings_dimension)
    index.add(embeddings)
    return index

def search_lm(query, top_k):
    query_embedding = retriever.encode([query], show_progress_bar = False)
    distances, indices = lm_index.search(query_embedding, top_k)
    results = [collection[idx] for idx in indices[0]]
    return results

In [6]:
#from datasets import load_dataset, Subset

#ds = load_dataset("jordiclive/wikipedia-summary-dataset", split='train', streaming=True)

## CORRECT MISTAKES
Preprocess the data, then correct it with pretrained Llama

In [13]:
#collection, train_set, test_set = prep_news(TOPICS)
#col1, train1, test1 = prep_news(TOPICS)

Topics : ['talk.politics.misc' 'sci.med' 'soc.religion.christian']
Collection size : 16933
Trainset size : 13547
Testset size : 3387


In [7]:
#from spellchecker import SpellChecker
#spell = SpellChecker()
#train2 = train1[500:520].copy()
#train2['fixed'] = train2.text.progress_apply(spell.correction)

In [7]:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
# Load the LLama 3 8B model
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" 

model = LLM(BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

INFO 12-13 07:57:02 llm_engine.py:226] Initializing an LLM engine (v0.6.1.dev238+ge2c6e0a82) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B-Instruct, use_v2_block_manager=False, num_scheduler_steps

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


INFO 12-13 07:57:05 model_runner.py:1025] Loading model weights took 14.9595 GB
INFO 12-13 07:57:07 gpu_executor.py:122] # GPU blocks: 11991, # CPU blocks: 2048
INFO 12-13 07:57:08 model_runner.py:1329] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 12-13 07:57:08 model_runner.py:1333] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 12-13 07:57:24 model_runner.py:1456] Graph capturing finished in 16 secs.


In [8]:
correct_system = """
                You are an automated English spelling check machine.
                Correct any spelling mistakes in the sentence that you are given, and just return the corrected sentence, without any introduction.
                If you find no mistakes, simply return the sentence as it is.
                Use {$examples} to better understand the task.
               """
correct_examples =  """
        These are the examples that you can use to understand how to solve the task :
        
        -Example 1: INPUT= I wishe I was en Dixie. OUTPUT=I wish I was in Dixie. ;
        -Example 2: INPUT= Maybe you coud try to act nycely OUTPUT= Maybe you could try to act nicely ;
        -Example 3: INPUT= The ents justyfy the meens. OUTPUT= The ends justify the means.
            """
CORRECT_PARAMS = SamplingParams(n=1, max_tokens=300)
def correct(s):
    sent = s['text']
    messages = [
                {'role' : 'system', 'content' : correct_system},
                {'role' : 'examples', 'content' : correct_examples},
                {'role' : 'problem', 'content' : sent}
            ]
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True,)
    output = model.generate(formatted_prompt, sampling_params=CORRECT_PARAMS, use_tqdm=False)
    out = output[0].outputs
    s['text'] = out[0].text
    return s

In [9]:
from tqdm import tqdm
tqdm.pandas(leave=None)

In [10]:
train_set = train_set.progress_apply(correct, axis=1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13547/13547 [1:17:56<00:00,  2.90it/s]


In [11]:
test_set = test_set.progress_apply(correct, axis=1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3387/3387 [19:21<00:00,  2.92it/s]


In [12]:
df_collection = pd.DataFrame({'text' : collection})
df_collection = df_collection.progress_apply(correct, axis=1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16933/16933 [1:37:13<00:00,  2.90it/s]


In [75]:
#clean collection
to_del = df_collection[df_collection.text.str.contains("s the corrected sentence:")]
df_collection = df_collection.drop(to_del.index)
to_del = df_collection[df_collection.text.str.contains("I cannot correct")]
df_collection = df_collection.drop(to_del.index)
to_del = df_collection[df_collection.text.str.contains("I can't correct")]
df_collection = df_collection.drop(to_del.index)

df_collection.text = df_collection.text.apply(lambda x : rm_special_chars(x))

#clean train_set
to_del = train_set[train_set.text.str.contains("s the corrected sentence:")]
train_set = train_set.drop(to_del.index)
to_del = train_set[train_set.text.str.contains("I cannot correct")]
train_set = train_set.drop(to_del.index)
to_del = train_set[train_set.text.str.contains("I can't correct")]
train_set = train_set.drop(to_del.index)

train_set.text = train_set.text.apply(lambda x : rm_special_chars(x))

#clean test_set
to_del = test_set[test_set.text.str.contains("s the corrected sentence:")]
test_set = test_set.drop(to_del.index)
to_del = test_set[test_set.text.str.contains("I cannot correct")]
test_set = test_set.drop(to_del.index)
to_del = test_set[test_set.text.str.contains("I can't correct")]
test_set = test_set.drop(to_del.index)

test_set.text = test_set.text.apply(lambda x : rm_special_chars(x))

## SAVE DATA

In [14]:
train_collection, train_set, test_collection, test_set = prep_wiki(DATASET_LINK, 20000)

Using the latest cached version of the dataset since sentence-transformers/wikipedia-en-sentences couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/atrema/.cache/huggingface/datasets/sentence-transformers___wikipedia-en-sentences/default/0.0.0/4a0972dcb781b5b5d27799798f032606421dd422 (last modified on Mon Dec 16 08:51:10 2024).


total size :  1000000
length range size :  916404
Train Collection size : 16000
Trainset size : 4000
Test Collection size : 16000
Testset size : 4000


### Training

In [15]:
train_collection_list = train_collection.text.tolist()

In [16]:
#create the index with ST embeddings
train_index = LM_retriever(train_collection_list)
write_index(train_index, 'saved_files/train_data/train_index')

In [17]:
with open('saved_files/train_data/train_collection.pickle', 'wb') as f:
    pickle.dump(train_collection_list, f)

train_set.to_csv('saved_files/train_data/train_set_unmasked', index=False)
masked_train = masking(train_set)
masked_train.to_csv('saved_files/train_data/train_set', index=False)

### Testing

In [18]:
test_collection_list = test_collection.text.tolist()

In [19]:
# get the tf-idf matrix, the vectorizer, the set of words in it , for TFIDFRAG
matrix_3, vectorizer_3, words_3, index_3 = tf_idify(test_collection_list)
#create the index with ST embeddings
test_index = LM_retriever(test_collection_list)

In [20]:
write_index(test_index, 'saved_files/test_data/test_index')
write_index(index_3, 'saved_files/test_data/tfidf_index')
with open('saved_files/test_data/vectorizer', 'wb') as f:
    pickle.dump(vectorizer_3, f)
with open('saved_files/test_data/matrix', 'wb') as f:
    pickle.dump(matrix_3, f)

In [21]:
with open('saved_files/test_data/test_collection.pickle', 'wb') as f:
    pickle.dump(test_collection_list, f)
 
test_set.to_csv('saved_files/test_data/test_set_unmasked', index=False)
masked_test = masking(test_set)
masked_test.to_csv('saved_files/test_data/test_set', index=False)

# FINE TUNING

In [16]:
#unless you want an OUT OF MEMORY Error
del model
del tokenizer

In [9]:
from unsloth import FastLanguageModel, is_bfloat16_supported

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


## PREPARE DATA
For RAG model

In [6]:
train_set = pd.read_csv('saved_files/train_data/train_set')
lm_index = read_index('saved_files/train_data/train_index')
with open('saved_files/train_data/train_collection.pickle', 'rb') as f:
    collection = pickle.load(f)

In [7]:
#train_set_cont = contexting(train_set)
#train_set_cont.to_csv('saved_files/train_data/train_set_context', index=False)
model_folder = "outputs/ragged_model"
train_set = pd.read_csv('saved_files/train_data/train_set_context')

For no-RAG model

In [17]:
model_folder = "outputs/unragged_model"
train_set = pd.read_csv('saved_files/train_data/train_set')

## FINE-TUNE

In [18]:
# Check CUDA compatibility
major_version, minor_version = torch.cuda.get_device_capability()

# Instantiate FastLanguageModel
max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    max_seq_length=4096,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

==((====))==  Unsloth 2024.11.8: Fast Llama patching. Transformers = 4.46.3.
   \\   /|    GPU: NVIDIA A100 80GB PCIe. Max memory: 79.151 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.4.0+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.27.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!




In [19]:
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

In [20]:
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    if 'context' in examples.data:
        prompt = """
            You are an English-speaking oracle.
            You are given a sentence where a word has been masked with a number of question marks (?) of equal length.
            Return the masked word, do not produce any more words than the one you have to return!

            ### Instruction:
            {}

            ### Context:
            {}

            ### Response:
            {}
            """
        instructions = examples["text"]
        contexts = examples["context"]
        responses = examples["groundTruth"]
        prompts = []
        for i,j,k  in zip(instructions,contexts,responses):
            text = prompt.format(i,j,k) + EOS_TOKEN
            prompts.append(text)
    else:
        prompt = """
            You are an English-speaking oracle.
            You are given a sentence where a word has been masked with a number of question marks (?) of equal length.
            Return the masked word!

            ### Instruction:
            {}

            ### Response:
            {}
            """
        instructions = examples["text"]
        responses = examples["groundTruth"]
        prompts = []
        for i,j  in zip(instructions,responses):
            text = prompt.format(i,j) + EOS_TOKEN
            prompts.append(text)
    return { "prompt" : prompts, }

from datasets import load_dataset, Dataset
dataset = Dataset.from_pandas(train_set)    
dataset = dataset.map(formatting_prompts_func, batched = True)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [21]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "prompt",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        #num_train_epochs = 5,
        max_steps = 100,
        per_device_train_batch_size = 16,
        gradient_accumulation_steps = 4,
        warmup_steps = 2,
        
        learning_rate = 0.0005,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

# Training 
trainer_stats = trainer.train()

Map (num_proc=2):   0%|          | 0/4000 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 4,000 | Num Epochs = 2
O^O/ \_/ \    Batch size per device = 16 | Gradient Accumulation steps = 4
\        /    Total batch size = 64 | Total steps = 100
 "-____-"     Number of trainable parameters = 41,943,040
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss
1,3.8978
2,3.9119
3,3.5668
4,2.5213
5,1.697
6,1.4327
7,1.1649
8,1.185
9,1.1757
10,1.1629


In [22]:
model_folder

'outputs/unragged_model'

In [23]:
model.save_pretrained_merged(model_folder, tokenizer, save_method = "merged_16bit",)

Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 1450.08 out of 2003.89 RAM for saving.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 78.34it/s]


Unsloth: Saving tokenizer... Done.
Unsloth: Saving model... This might take 5 minutes for Llama-7b...
Done.


# VLLM RESTORATION


In [24]:
#unless you want an OUT OF MEMORY Error
del model
del tokenizer

## Load Test Data

In [7]:
lm_index = read_index('saved_files/test_data/test_index')
index_3 = read_index('saved_files/test_data/tfidf_index')

with open('saved_files/test_data/vectorizer', 'rb') as f:
    vectorizer_3 = pickle.load(f)
words_3 = vectorizer_3.get_feature_names_out()
with open('saved_files/test_data/matrix', 'rb') as f:
    matrix_3 = pickle.load(f)

test_set = pd.read_csv('saved_files/test_data/test_set')
lm_index = read_index('saved_files/test_data/test_index')
with open('saved_files/test_data/test_collection.pickle', 'rb') as f:
    collection = pickle.load(f)

## Generation and Metrics Functions
Generation functions for the different cases.

#### OLD

In [None]:
def gen_b(s):
    sent = s['text']
    length = sent.count('?')
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'problem', 'content' : sent}
            ]

    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True,)
    #print(formatted_prompt)
    output = model.generate(formatted_prompt, sampling_params=BASELINE_PARAMS, use_tqdm=False,)
    temp = []
    out = output[0].outputs
    #print('Restorations: \n \n')
    for o in out:
        pred = o.text.lower()
        #print(pred)
        pred = rm_special_chars(pred)
        if pred not in temp and len(pred) == length :
            temp.append(pred)
        #print("-----------------------------\n")
    s['restorations'] = sorted(temp)
    return s

#### Generations

In [8]:
def to_gen(msg):
    formatted_prompt = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True,)
    output = model.generate(formatted_prompt, sampling_params=SAMPLING_PARAMS, use_tqdm=False)
    out = output[0].outputs
    preds = []
    for o in out:
        pair = (o.text, np.exp(o.cumulative_logprob))
        preds.append(pair)
    preds = list(set(preds))
    preds.sort(key= lambda x : x[1], reverse=True)
    return preds

def gen_0(s):
    sent = s['text']
    length = sent.count('?')
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'problem', 'content' : sent}
            ]
    s['restorations'] = to_gen(messages)
    return s
    
def gen_1(s):
    sent = s['text']
    #print("Sentence : ",sent, f"\n \n")
    length = sent.count('?')
    tmp = sent.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word
    tmp[missing_word_index] = "[MASK]"
    mask_sent = (' ').join(tmp)
    context = f'{context_header} {search_lm(mask_sent, CONTEXT_LEN)}'
    #print(context)
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'user', 'content' : sent},
                {'role' : 'context', 'content' : context}
            ]
    s['restorations'] = to_gen(messages)
    return s

def gen_2(s):
    sent = s['text']
    length = sent.count('?')
    tmp = sent.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word
    tmp[missing_word_index] = "[MASK]"
    mask_sent = (' ').join(tmp)
    
    context = search_lm(mask_sent, CONTEXT_LEN*5)
    #print(f"Initial context : {context}")
    # this is the part where the context is reduced through tf-idf
    matrix, vectorizer, words, index = tf_idify(context) # get the tf-idf matrix, the vectorizer, the set of words in it
    #tmp_sent = re.sub(r'\[\[(.*) missing characters\]\]', '?'*length, sent)
    tmp = sent.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word
    #subword = tmp[missing_word_index].replace('?', '')
    matches = []
    #for cont in context: # take the possible words from within the context
     #   cont = cont.split(" ")
     #   matches.extend([word for word in cont if len(word)==length and word not in matches]) # if word.startswith(subword)
    
    if matches == []:# no matches -> use the vanilla query
        #print("no matches")
        final_context = search_tfidf(sent, vectorizer, index, context, top_k=CONTEXT_LEN)
        #print(final_context)
    else: # take the match with highest tf-idf score
        #print(matches)
        final_context = []
        for j in range(len(matches)):
            tmp1 = tmp
            tmp1[missing_word_index] = matches[j]
            matches[j] = (' ').join(tmp1)
            fin_con = search_tfidf(matches[j], vectorizer, index, context, top_k=CONTEXT_LEN)
            #print(fin_con)
            final_context.extend(fin_con)
        final_context = list(set(final_context))[:CONTEXT_LEN]
        #print("final context:\n ", final_context)
    #print(final_context)
    context = f'{context_header} {final_context}'
    # the subset thus determined is used to generate the answer
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'user', 'content' : sent},
                {'role' : 'context', 'content' : context}
            ]
    s['restorations'] = to_gen(messages)
    return s

def gen_3(s):
    sent = s['text']
    #print("Sentence : ",sent, f"\n \n")
    length = sent.count('?')
    # this is the part where the context is found through tf-idf
    # {matrix,vectorizer,words}_3 are used
    #tmp_sent = re.sub(r'\[\[(.*) missing characters\]\]', '?'*length, sent)
    tmp = sent.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word

    #get the possible matches amongst the words present in the tf-idf matrix
    
    final_context = search_tfidf(sent, vectorizer_3, index_3, collection, top_k=CONTEXT_LEN) # FOR NO SUBWORDS
    #same_len_words = [word for word in words_3 if len(word)==length]
    #if len(same_len_words) <= CONTEXT_LEN:
    #    matches= same_len_words
    #else:
    #    matches= random.choices(same_len_words, k=CONTEXT_LEN) #cannot use all subwords otherwise it's unfeasible
    # take the match with highest tf-idf score
    #print(matches)
    #final_context = []
    #for j in range(len(matches)):
    #    tmp1 = tmp
    #    tmp1[missing_word_index] = matches[j]
    #    matches[j] = (' ').join(tmp1)
    #    fin_con = search_tfidf(matches[j], vectorizer_3, index_3, collection, top_k=CONTEXT_LEN)
        #print(fin_con)
    #    final_context.extend(fin_con)
    #final_context = list(set(final_context))[:CONTEXT_LEN]
    context = f'{context_header} {final_context}'
    #print(context)
    # the subset thus determined is used to generate the answer
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'user', 'content' : sent},
                {'role' : 'context', 'content' : context}
            ]
    s['restorations'] = to_gen(messages)
    return s

def gen_random(s):
    sent = s['text']
    #print("Sentence : ",sent, f"\n \n")
    length = sent.count('?')
    tmp = sent.split(' ')
    missing_word_index = tmp.index([s for s in tmp if '?' in s][0]) # get the index of the missing word
    context = f'{context_header} {random.sample(collection, CONTEXT_LEN)}'
    #print(context)
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'user', 'content' : sent},
                {'role' : 'context', 'content' : context}
            ]
    s['restorations'] = to_gen(messages)
    return s

In [9]:
def generation(test, mode):
    modes = ['noRAG', 'RAGnormal', 'RAGwithTFIDF', 'TFIDFRAG', 'base', 'random']
    mode = modes[mode]
    if mode == 'noRAG':
        test = test.progress_apply(gen_0, axis=1)
    elif mode == 'RAGnormal':
        test = test.progress_apply(gen_1, axis=1)
    elif mode == 'RAGwithTFIDF':
        test = test.progress_apply(gen_2, axis=1)
    elif mode == 'TFIDFRAG': 
        test = test.progress_apply(gen_3, axis=1)
    elif mode == 'base': 
        test = test.progress_apply(gen_b, axis=1)
    elif mode == 'random': 
        test = test.progress_apply(gen_random, axis=1)
    return test

##### MODES
# [0] : noRAG
# [1] : RAGnormal
# [2] : RAGwithTFIDF, get context from embeddings and reduce with tfidf
# [3] : TFIDFRAG
# [4] : baseline
# [5] : random

#### Metrics

In [18]:
from torchmetrics.text import CharErrorRate
CER = CharErrorRate()

def metrics(result):
    gt = result.groundTruth
    res = result.restorations
    cer = CER(res, [gt]*len(res))
    result['characterErrorRate'] = cer.item()
    if gt in res :
        result['exactMatch'] = 1
    else:
        result['exactMatch'] = 0
    return result

def topic_score(results):
    for t in results.topic.unique():
        tmp = results[results['topic'] == t]
        #print(len(tmp))
        print(f"Exact match score for {t} : {round(tmp.apply(metrics, axis=1)['exactMatch'].mean(axis=0) * 100, 2)}%")
    print(f"\n Average : {round(results.apply(metrics, axis=1)['exactMatch'].mean(axis=0)* 100, 2)}%")

def score(results):
    res = results.apply(metrics, axis=1)
    print(f"Exact match score : {round(res['exactMatch'].mean(axis=0)* 100, 2)}%")
    print(f"Average CER : {round(res['characterErrorRate'].mean(axis=0), 2)}")
    return res
    #print(f"Average CER : {round(results.apply(metrics, axis=1)['characterErrorRate'].mean(axis=0), 2)}%")

In [19]:
SLICES = [1, 5, 50, 200]
def ranks(result):
    gt = result.groundTruth
    res = result.restorations
    for slice in SLICES:
        choices = [x[0] for x in res[:slice] if gt in x]
        if gt in choices:
            result[f'match@{slice}'] = 1
        else:
            result[f'match@{slice}'] = 0
    return result
            
def rank_score(results):
    res = results.apply(ranks, axis=1)
    print(f"Match@1 score : {round(res['match@1'].mean(axis=0)* 100, 2)}%")
    print(f"Match@5 score : {round(res['match@5'].mean(axis=0)* 100, 2)}%")
    print(f"Match@50 score : {round(res['match@50'].mean(axis=0)* 100, 2)}%")
    print(f"Match@200 score : {round(res['match@200'].mean(axis=0)* 100, 2)}%")
    return res

## Load Model

In [None]:
del model
del tokenizer

In [None]:
#model_folder= "outputs/unragged_model"
model_folder= "outputs/ragged_model"
#model_folder= "meta-llama/Meta-Llama-3-8B-Instruct"

In [None]:
from vllm import LLM
from transformers import AutoTokenizer
from vllm import SamplingParams

model = LLM(model_folder)
tokenizer = AutoTokenizer.from_pretrained(model_folder) #"meta-llama/Meta-Llama-3-8B-Instruct")

## Prompt and Sampling Parameters

In [13]:
from transformers import LogitsProcessor

class SuppressTokenProcessor(LogitsProcessor):
    def __init__(self, token_ids):
        self.token_ids = token_ids
        
    def __call__(self, input_ids, scores):
        # Set the logit of the specific token to a large negative value
        for token_id in self.token_ids:
            scores[token_id] = -float("inf")
        return scores
        
token_id_to_block = [198, 271] #tokenizer.convert_tokens_to_ids(r"life")
suppress_processor = SuppressTokenProcessor(token_id_to_block)

In [14]:
#MESSAGES

SAMPLING_PARAMS = SamplingParams(n=200, temperature=1, max_tokens=10, top_k = 50, logprobs=1,logits_processors=[suppress_processor]) #, best_of=9, frequency_penalty=0.0, temperature=0.2 logit_bias = {198 : 0.0}
BASELINE_PARAMS = SamplingParams(n=1, temperature=1, max_tokens=10, top_k = 50)
CONTEXT_LEN = 3

system = """
        You are an English-speaking oracle.
        You are given a sentence where a word has been masked with a number of question marks (?) of equal length.
        Return the masked word, do not produce any more words than the one you have to return!
        Use {$examples} to better understand how to produce good solutions.
        Use {$context} to guide you towards the correct answer. 
         """
examples =  """
        These are the examples that you can use to understand how to solve the task :
        
        -Example 1: INPUT= the ???? is on the table. OUTPUT=book ;
        -Example 2: INPUT= I ???? you. OUTPUT=love ;
        -Example 3: INPUT= There was a ????? cat on my porch. OUTPUT=stray
            """

context_header = "Use these documents to find the correct answer to the task : "

# TESTING

In [15]:
from tqdm import tqdm
tqdm.pandas(leave=None)

## Plain LLAMA

In [49]:
a = res.restorations[0]
b = ['tissue']*len(a)

In [50]:
CER(a,b)

tensor(1.9103)

In [14]:
test_slice = test_set[:200]
if model_folder == 'meta-llama/Meta-Llama-3-8B-Instruct':
    RES_0 = generation(test_slice, 0)
    RES_1 = generation(test_slice, 1)
    RES_2 = generation(test_slice, 2)
    RES_3 = generation(test_slice, 3)
    RES_R = generation(test_slice, 5)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:32<00:00,  2.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:43<00:00,  1.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:58<00:00,  1.69it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [02:11<00:00,  1.52it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:39<00:00,  2.01it/s]


In [15]:
modes = ["No RAG", "Transformer RAG", "Mixed RAG", "TF-IDF RAG", "Random RAG"]
results = [RES_0, RES_1, RES_2, RES_3, RES_R]
for i in range(len(results)):
    print(f"{modes[i]} results : \n")
    score(results[i])
    print("\n -------------------------------------------")

No RAG results : 

Exact match score : 17.0%

 -------------------------------------------
Transformer RAG results : 

Exact match score : 0.0%

 -------------------------------------------
Mixed RAG results : 

Exact match score : 0.0%

 -------------------------------------------
TF-IDF RAG results : 

Exact match score : 0.0%

 -------------------------------------------
Random RAG results : 

Exact match score : 0.5%

 -------------------------------------------


## UNRAGGED MODEL

In [16]:
test_slice = test_set[:1000]
if model_folder == 'outputs/unragged_model':
    RES_0 = generation(test_slice, 0)
    RES_1 = generation(test_slice, 1)
    RES_2 = generation(test_slice, 2)
    RES_3 = generation(test_slice, 3)
    RES_R = generation(test_slice, 5)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:46<00:00,  3.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:49<00:00,  3.45it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [06:29<00:00,  2.57it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [06:55<00:00,  2.40it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:31<00:00,  3.68it/s]


In [20]:
modes = ["No RAG", "Transformer RAG", "Mixed RAG", "TF-IDF RAG", "Random RAG"]
results = [RES_0, RES_1, RES_2, RES_3, RES_R]
for i in range(len(results)):
    print(f"{modes[i]} results : \n")
    rank_score(results[i])
    print("\n -------------------------------------------")

No RAG results : 

Match@1 score : 40.9%
Match@5 score : 62.4%
Match@50 score : 73.4%
Match@200 score : 73.6%

 -------------------------------------------
Transformer RAG results : 

Match@1 score : 26.5%
Match@5 score : 46.6%
Match@50 score : 62.0%
Match@200 score : 62.0%

 -------------------------------------------
Mixed RAG results : 

Match@1 score : 26.8%
Match@5 score : 46.9%
Match@50 score : 63.6%
Match@200 score : 63.9%

 -------------------------------------------
TF-IDF RAG results : 

Match@1 score : 26.4%
Match@5 score : 47.4%
Match@50 score : 62.7%
Match@200 score : 62.8%

 -------------------------------------------
Random RAG results : 

Match@1 score : 28.2%
Match@5 score : 46.9%
Match@50 score : 62.4%
Match@200 score : 62.4%

 -------------------------------------------


## RAGGED MODEL

In [None]:
test_slice = test_set[:1000]
if model_folder == 'outputs/ragged_model':
    RES_0 = generation(test_slice, 0)
    RES_1 = generation(test_slice, 1)
    RES_2 = generation(test_slice, 2)
    RES_3 = generation(test_slice, 3)
    RES_R = generation(test_slice, 5)

In [None]:
modes = ["No RAG", "Transformer RAG", "Mixed RAG", "TF-IDF RAG", "Random RAG"]
results = [RES_0, RES_1, RES_2, RES_3, RES_R]
for i in range(len(results)):
    print(f"{modes[i]} results : \n")
    rank_score(results[i])
    print("\n -------------------------------------------")

## RESULTS

In [17]:
%%html
<style>
table {float:left}
</style>

| TEST TYPE | UNRAGGED@1 | UNRAGGED@5 | UNRAGGED@50 | RAGGED@1 | RAGGED@5 | RAGGED@50 |
|:--|---|---|---|---|---|---|
| NO RAG | 14.7% | 34.05% | 43.28% | **29.3%** | 50.3% | 63.4% |
| TRANSFORMER RAG | 21.5% | 40.0% | 56.3% | 20.9% | 37.15% | 49.62% |
| MIX RAG | 21.42% | 40.08%  | 56.45% | 21.08% | 37.48% | 49.28% |
| TFIDF RAG | **21.95%** | 40.1% | 56.3% | 20.42% | 36.78% | 48.92% |
| RANDOM RAG | 19.8% | 38.52% | 53.52% | 19.8% | 36.25% | 47.15% |

In [57]:
RES_COMP = RES_0.assign(res3 = RES_3.restorations.tolist())
RES_COMP

Unnamed: 0,text,groundTruth,restorations,res3
0,In the European Union a sperm bank must have a license according to the EU ?????? Directive,tissue,"[(tissue\n\n, 0.8015212861668876), (tissue, 0.0849823484126026), (tissues\n\n, 0.044999445451738...","[(tissue, 0.38545657312747666), (tissue\n\n, 0.22612123969402853), (directives, 0.05991818200197..."
1,Carbonel is a children s book series by Barbara Sleigh first published by Puffin ????? from 1955...,books,"[(books, 0.85199211443129), (books\n, 0.11495957935577418), (press, 0.013408353111830263), (pres...","[(books, 0.9305413571777537), (penguin, 0.014429245639035908), (book, 0.012764911731749403), (bo..."
2,Miopelodytes is an extinct genus of ??????????? amphibian,prehistoric,"[(terrestrial, 0.06965743789623417), (extinct\n, 0.039728992539329454), (amphibian, 0.0361697824...","[(reptiliates, 0.178133745279974), (amphibian, 0.10853905819744253), (reptile, 0.070264330985776..."
3,This ???? details Toronto FC II s league results from their inaugural season in 2015 as a member...,page,"[(this page\n, 0.5826369649500325), (this\n, 0.18384528868363628), (this\n\n, 0.0795800184578971...","[(the, 0.5064527395237556), (this, 0.4012886966887108), (their, 0.02255067046017188), (these, 0...."
4,Designed by Alex B Mahood it is the most significant example of ??? style in southern West Virginia,the,"[(this\n, 0.48238580255375724), (the\n, 0.2250722088679745), (this, 0.12290734459912071), (that\...","[(the, 0.6594956340774109), (this, 0.11119873605736523), (the\n, 0.06888191376778897), (the\n\n,..."
5,The department chairs and therefore the curriculum coordinators for their area are Jennifer Dani...,languages,"[(languages\n, 0.5588878227068412), (languages, 0.3404437805276191), (language, 0.03454838296268...","[(language, 0.40445978643099584), (languages, 0.4005024791688289), (langues, 0.02558471411505919..."
6,The ????? African Football Association is the second Football Association in South Africa to be ...,south,"[(third\n, 0.21318552437420488), (other\n, 0.18506445924844817), (first\n, 0.17563974769381283),...","[(the, 0.7227887877786677), (the\n\n, 0.12411718701271347), (south\n\n, 0.017376924752221475), (..."
7,After this ??? League of Nations formalized the UK s control of the area who renamed it Tanganyika,the,"[(the\n, 0.4209599613720795), (the\n\n, 0.30749458860242773), (and, 0.04954773466043427), (in\n,...","[(the, 0.421398291425599), (after\n, 0.10155288555918064), (this, 0.05554142915536445), (after, ..."
8,However despite its large black population it was also the last country in ??? western hemispher...,the,"[(the\n, 0.5026006124067871), (the, 0.3960566330181435), (in\n, 0.0781248916623479), (in, 0.0044...","[(the, 0.765764630035786), (and, 0.03271540957745643), (hemisphere, 0.027481986830546465), (heme..."
9,The CEP is comprehensive covering goods services investment and technical and ??????? quarantine...,hygiene,"[(financial, 0.21462900759461648), (non\n, 0.11117241125515175), (customs, 0.08231154810115898),...","[(financial, 0.2316252215319915), (and\n\n, 0.12993205872531727), (scientific, 0.066283065206732..."


# Data Analysis

In [34]:
masked_train = pd.read_csv('saved_files/train_data/train_set')
masked_test = pd.read_csv('saved_files/test_data/test_set')

In [35]:
train_masks = masked_train.groundTruth.tolist()
test_masks = masked_test.groundTruth.tolist()

In [36]:
intersection = set(train_masks).intersection(test_masks)
print(len(intersection))

670


In [55]:
for r in [RES_0, RES_1, RES_2, RES_3, RES_R]:
    print("mean number of restorations : ", r.restorations.str.len().mean(axis=0))

mean number of restorations :  36.28325
mean number of restorations :  49.69425
mean number of restorations :  49.4495
mean number of restorations :  49.3885
mean number of restorations :  50.17675


# Output Analysis

In [9]:
from torchmetrics.text import CharErrorRate
preds = "there is an other sample"
target = "pingus"
cer = CharErrorRate()
cer(preds, target)


tensor(3.5000)

In [137]:
def produce(s):
    messages = [
                {'role' : 'system', 'content' : system},
                {'role' : 'examples', 'content' : examples},
                {'role' : 'problem', 'content' : s}
            ]
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = model.generate(formatted_prompt, sampling_params=SAMPLING_PARAMS, use_tqdm=False)
    #out = output[0].outputs
    #s['text'] = out[0].text
    return output

In [138]:
sentence = "Samuel stared at the ???????? body of his comrade, his eyes blank as cement."
#output = produce(sentence)
#outs = output[0].outputs
TEMP = 1
TOP_K = 50
TOP_P = 0.9

In [140]:
output = produce(sentence)
preds = []
for out in output[0].outputs:
    pair = (out.text, np.exp(out.cumulative_logprob))
    preds.append(pair)
    preds = list(set(preds))

In [142]:
preds

[('deadly', 0.0005248176041512274),
 ('grayishfaceless', 1.7048421956564242e-07),
 ('lifeless', 0.4259134668555747),
 ('bleeding', 0.017377057214856363),
 ('bleibedneye', 1.2312695885729156e-08),
 ('coldly', 0.006265890365465455),
 ('black', 0.0036943084194194426),
 ('laborate', 0.00010760852485483042),
 ('lifeless', 0.4257827773842596),
 ('lifeless', 0.41691222045264265),
 ('blasted', 0.00028454739051560034),
 ('bleeding', 0.020839705233505104),
 ('lifeless', 0.34267461537226956),
 ('wound', 0.0007208693811802582),
 ('bloated', 0.019919431754298663),
 ('lifeless', 0.37317105036570625),
 ('deadlycoldly', 0.00021136759999117538),
 ('grayishpurple', 4.535614405373014e-05)]

In [141]:
output[0].outputs

[CompletionOutput(index=16, text='lifeless', token_ids=(14789, 1752, 128009), cumulative_logprob=-0.8535190827933548, logprobs=[{14789: Logprob(logprob=-0.852980375289917, rank=1, decoded_token='life')}, {1752: Logprob(logprob=-5.507317473529838e-05, rank=1, decoded_token='less')}, {128009: Logprob(logprob=-0.0004836343287024647, rank=1, decoded_token='')}], finish_reason=stop, stop_reason=None),
 CompletionOutput(index=7, text='lifeless', token_ids=(14789, 1752, 128009), cumulative_logprob=-0.8538259750057478, logprobs=[{14789: Logprob(logprob=-0.852980375289917, rank=1, decoded_token='life')}, {1752: Logprob(logprob=-0.00031609306461177766, rank=1, decoded_token='less')}, {128009: Logprob(logprob=-0.0005295066512189806, rank=1, decoded_token='')}], finish_reason=stop, stop_reason=None),
 CompletionOutput(index=8, text='lifeless', token_ids=(14789, 1752, 128009), cumulative_logprob=-0.8538259750057478, logprobs=[{14789: Logprob(logprob=-0.852980375289917, rank=1, decoded_token='life')

In [39]:
preds.sort(key= lambda x : x[1], reverse=True)

In [69]:
p = [x[0] for x in preds if 'lifeless' in x]

In [70]:
if 'lifeless' in p:
    print(1)

1


In [40]:
preds

[('lifeless', 0.24838454175026836),
 ('lifeless\n', 0.1694189835487959),
 ('dead\n', 0.08346258968935745),
 ('bloody\n', 0.06630782704125304),
 ('bloated\n', 0.04019607728779919),
 ('gray\n', 0.020963977036142396),
 ('mangled\n', 0.020626719397703386),
 ('still\n', 0.018628639097749224),
 ('bleeding\n', 0.018201381413304716),
 ('stiff\n', 0.018171465029939413),
 ('white\n', 0.012805584770657366),
 ('broken\n', 0.011269482438968253),
 ('the\n', 0.011251886977458766),
 ('frozen\n', 0.009453893780624775),
 ('burned\n', 0.00696852342913257),
 ('decaying\n', 0.005996403124781332),
 ('rotting\n', 0.005741898230900414),
 ('crushed\n', 0.005680046022226747),
 ('death\n', 0.005319049721216923),
 ('motionless', 0.004892055661451553),
 ('lifeless\n\n', 0.004521102830132614),
 ('terrible\n', 0.004280428674724361),
 ('decomposing\n', 0.004014663941239302),
 ('deceased\n', 0.00332542092961323),
 ('corpse', 0.0028902961921068024),
 ('ghastly\n', 0.0028029364559456747),
 ('burnt\n', 0.0025656808854883

In [176]:
SAMPLING_PARAMS = SamplingParams(n=1, temperature=TEMP, max_tokens=10, top_k =TOP_K, top_p=TOP_P, logprobs=20,) #, best_of=9, frequency_penalty=0.0, temperature=0.2
for i in range(10):
    print(f"generation n.{i}")
    output = produce(sentence)
    outs = output[0].outputs
    out = outs[0]
    print(f'{out.text}')
    print(np.exp(out.cumulative_logprob))
    #print(out.logprobs)
    print("-"*10)

generation n.0
lifeless
0.2788471149753743
----------
generation n.1
lifeless
0.2788471149753743
----------
generation n.2
dead

0.09264611688085804
----------
generation n.3
mutilated

0.007316358055034805
----------
generation n.4
lifeless

0.19164863254436829
----------
generation n.5
bloated

0.04845636856549938
----------
generation n.6
frozen

0.011065000737466316
----------
generation n.7
stiff

0.022699197158231117
----------
generation n.8
lifeless
0.2788471149753743
----------
generation n.9
lifeless

0.19164863254436829
----------


In [177]:
SAMPLING_PARAMS = SamplingParams(n=10, temperature=TEMP, max_tokens=10, top_k =TOP_K, top_p=TOP_P, logprobs=20,) #, best_of=9, frequency_penalty=0.0, temperature=0.2
output = produce(sentence)
outs = output[0].outputs
for out in outs:
    print(out.text)
    print(np.exp(out.cumulative_logprob))
    #print(out.logprobs)
    print("-"*10)

lifeless
0.2788471149753743
----------
lifeless

0.19164863254436829
----------
dead

0.09264611688085804
----------
bloody

0.07050360024637689
----------
bloody

0.07050360024637689
----------
gray

0.023424606745035447
----------
gray

0.023424606745035447
----------
battered

0.0032946792106591655
----------
corpses
0.0030395729810376404
----------
tangled

0.001441004922395057
----------


In [84]:
for out in outs:
    print(out.text)
    print(np.exp(out.cumulative_logprob))
    #print(out.logprobs)
    print("-"*10)

land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
land

0.8153627968904047
----------
until

0.018738851306622433
----------
blood

0.0012247584127716852
----------


# Other

In [23]:
#RES_COMP = RES_0.assign(res1 = RES_1.restorations.tolist())
#RES_COMP

In [39]:
sentence = "The invasion of Iraq was justified with ???????? which later proved to be false"
d = {'text': sentence, 'groundTruth' : 'evidence'}
qualitative = pd.Series(data = d)

print(qualitative)
print("-"*10)
qual_results = gen_0(qualitative)
print(f"Simple Gen: {qual_results.restorations}")
print("-"*10)
qual_results = gen_1(qualitative)
print(f"Simple RAG: {qual_results.restorations}")
print("-"*10)
qual_results = gen_2(qualitative)
print(f"Mixed RAG: {qual_results.restorations}")
print("-"*10)
qual_results = gen_3(qualitative)
print(f"TFIDF RAG: {qual_results.restorations}")

text           The invasion of Iraq was justified with ???????? which later proved to be false
groundTruth                                                                           evidence
dtype: object
----------
Simple Gen: ['evidence', 'premises']
----------
Use these documents to find the correct answer to the task :  ['Arafat and Saddam Hussein and defended the placing of car bombs', 'First off they could recognise Iraqu s responsibility in initiating the', '3 He sttod firm in refusing the advice of the false prophets and insisted', 'the theological realm The false prophets understood their relatioship', 'loathed and then supported Saddam when he mounted an unprovoked attack', 'in what is now Iraq If this document were to actually turn up', 'christians verbally attack people who might otherwise have been won to', '2 Does that make the speaker a false prophet', 'simulations based around this conflict Great to go and bomb Saddam s', 'truth who would admit the possibility of errors 