In [2]:
import os
from tqdm import tqdm
import random
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from difflib import SequenceMatcher
from huggingface_hub import notebook_login

from pymilvus import connections, Collection, DataType, FieldSchema, CollectionSchema, utility
from milvus import default_server
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer


isLinux = True
default_linux_path = os.getcwd().replace("/Data", "/Documents/Downloaded") if "/Data" in os.getcwd() else os.getcwd() + "/Documents/Downloaded"
default_windows_path = os.getcwd().replace("\\Data", "\\Documents\\Downloaded") if "\\Data" in os.getcwd() else os.getcwd() + "\\Documents\\Downloaded"
default_path = default_linux_path if isLinux else default_windows_path

DEFAULT_SAVE_DIR = default_path.replace("/Downloaded", "/Generated") if isLinux else default_path.replace("\\Downloaded", "\\Generated")
LAWS_CSV = DEFAULT_SAVE_DIR + ('/laws.csv' if isLinux else '\\laws.csv')

# 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
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

4


### Check the best LLMs for answer legislative quizzes

In [None]:
# List of models to test
models = {
    "Saul": {'model_name': 'Equall/Saul-7B-Instruct-v1', 'context_window': 1024, 'prompt_function': lambda system_prompt, user_prompt: f"<|system|>\n{system_prompt}\n|<user>|\n{user_prompt}\n|<assistant>|\n\n"}, #Modello addestrato su testi legali
    #"Llamantino": {'model_name': 'swap-uniba/LLaMAntino-2-7b-hf-dolly-ITA', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"Di seguito è riportata un'istruzione che descrive un'attività, abbinata ad un input che fornisce ulteriore informazione.\nScrivi una risposta che soddisfi adeguatamente la richiesta.\n\n### Istruzione:\n{system_prompt}\n\n### Input:\n{user_prompt}\n\n### Risposta:\n"}, # Doesn't work with transformers
    "Meta-Llama 8B": {'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"},
    #"Meta-Llama 70B": {'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"},
    "Falcon-7B": {'model_name': 'tiiuae/falcon-7b-instruc t', 'context_window': 512, 'prompt_function': lambda system_prompt, user_prompt: f"User: {user_prompt}\nAssistant:{system_prompt}"},
    #"Mixtral-8x22B": {'model_name': 'mistralai/Mixtral-8x22B-Instruct-v0.1', 'context_window': 1024, 'prompt_function': lambda system_prompt, user_prompt: f"[INST] {system_prompt} {user_prompt}\n[/INST]"},
    #"Minerva-3B": {'model_name': 'sapienzanlp/Minerva-3B-base-v1.0', 'context_window': 512, 'prompt_function': lambda system_prompt, user_prompt: f"{system_prompt} {user_prompt}"}, # Modello italiano della Sapienza
    #"deepset/roberta-base-squad2" : {'model_name': 'deepset/roberta-base-squad2', 'context_window': 512}, # Modello per il question answering
    #"Phi-small" : {'model_name': 'microsoft/Phi-3-small-4k-instruct', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"<|system|>\n{system_prompt}\n|<user>|\n{user_prompt}\n|<assistant>|\n\n"},
    #"Phi-medium" : {'model_name': 'microsoft/Phi-3-medium-4k-instruct', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"<|system|>\n{system_prompt}\n|<user>|\n{user_prompt}\n|<assistant>|\n\n"},
    #"Phi-medium-quantized" : {'model_name': 'kaitchup/Phi-3-medium-128k-instruct-awq-4bit', 'context_window': 8000, 'prompt_function': lambda system_prompt, user_prompt: f"<|system|>\n{system_prompt}\n|<user>|\n{user_prompt}\n|<assistant>|\n\n"},
}

df_quiz = pd.read_csv(DEFAULT_SAVE_DIR + '/quiz_merged.csv')

# Initialize the models and generate answers
for model_name, model_data in models.items():
    model_id, context_window, prompt_function = model_data['model_name'], model_data['context_window'], model_data['prompt_function']
    print(f'Running model: {model_name}')
    
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cuda",
    )

    correct_count = 0
    for index, row in df_quiz.iterrows():
        messages = [
            {"role": "system", "content": "You are an expert in the field of law, and you are gonna replay to the following quiz. You have to choose the correct answer among the three options. Just use the question and the answers as context."},
            {"role": "user", "content": row['question']+row['answer_1']+row['answer_2']+row['answer_3']},
        ]

        # Generate answer
        outputs = pipeline(
            messages,
            max_new_tokens=1000,
        )
        ans = outputs[0]["generated_text"][-1]
        
        # Check which answer is more similar to the generated one
        answers = [row['answer_1'], row['answer_2'], row['answer_3']]
        similarities = [SequenceMatcher(None, ans, a).ratio() for a in answers]
        most_similar_answer = answers[similarities.index(max(similarities))]
        
        # Check if the correct answer is within the generated answer
        if most_similar_answer == row['answer_1']:
            correct_count += 1

    # Calculate and print accuracy
    accuracy = correct_count / len(df_quiz)
    print(f'Accuracy of {model_name}: {accuracy}')


### Test quizzes on Milvus RAG

In [6]:
def connect_to_milvus():
    try:
        connections.connect("default", host="0.0.0.0")
    except:
        default_server.start()
        
def drop_everything():
    collections = utility.list_collections()

    for collection in collections:
        utility.drop_collection(collection)

def create_collection():
    laws_fields = [
        FieldSchema(name="law_id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=4096),
        FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="article", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="comma", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="comma_content", dtype=DataType.VARCHAR, max_length=5000)
    ] 

    schema = CollectionSchema(laws_fields, "laws collection")

    laws_collection = Collection(name="laws_collection", schema=schema)
    laws_collection.create_index("embedding", {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}})    
    
    return laws_collection

def load_model(model_name):
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
        
    return model, tokenizer

def generate_embedding(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", max_length=128, truncation=True)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    embedding_tensor = outputs.last_hidden_state.mean(dim=1).squeeze()
    embedding_list = embedding_tensor.tolist()
    
    return embedding_list

def load_data_and_generate_embeddings(data, model, tokenizer):
    data = data[:3]

    embeddings = []
    for cc in tqdm(data["Comma content"], total=data.shape[0]):
        embeddings.append(generate_embedding(cc, tokenizer, model))
    data["Embedding"] = embeddings

    return data

def insert_data_into_milvus(collection, dataWithEmbeddings):
    source_list = dataWithEmbeddings["Source"].tolist()
    article_list = dataWithEmbeddings["Article"].tolist()
    comma_list = dataWithEmbeddings["Comma number"].tolist()
    comma_content_list = dataWithEmbeddings["Comma content"].tolist()
    embedding_list = dataWithEmbeddings["Embedding"].tolist()
        
    data = []
    for i in range(len(embedding_list)):
        data.append({
            "embedding": embedding_list[i],     # Embedding (FLOAT_VECTOR)
            "source": source_list[i],           # Source (VARCHAR)
            "article": article_list[i],         # Article (VARCHAR)
            "comma": comma_list[i],             # Comma (VARCHAR)
            "comma_content": comma_content_list[i]  # Comma content (VARCHAR)
        })
    
    collection.insert(data)
    
    collection.flush()
    collection.load()
    
def search_similar_text(collection, text, tokenizer, model, top_k=5):    
    # Generate the embedding for the input text
    embedding = generate_embedding(text, tokenizer, model)
    
    # Perform a search on the collection
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    results = collection.search([embedding], "embedding", search_params, limit=top_k, output_fields=["source", "article", "comma", "comma_content"])
    
    # Format the results
    formatted_results = []
    for result in results[0]:
        formatted_results.append({
            "score": result.score,
            "source": result.entity.get("source"),
            "article": result.entity.get("article"),
            "comma": result.entity.get("comma"),
            "comma_content": result.entity.get("comma_content")
        })
    
    return formatted_results

def generate_response(prompt, model, tokenizer, device):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=150,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.7,
            top_p=0.9
        )
    
    # Decode the generated tokens
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

connect_to_milvus()
drop_everything() # !!! WARNING !!!
laws_collection = create_collection()

model, tokenizer = load_model("meta-llama/Meta-Llama-3.1-8B-Instruct")#("BAAI/bge-m3")
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

dataWithEmbeddings = load_data_and_generate_embeddings(pd.read_csv(LAWS_CSV), model, tokenizer)
insert_data_into_milvus(laws_collection, dataWithEmbeddings)

while True:
    user_prompt = "Citami un articolo"#input("Insert a prompt: ")
    
    search_results = search_similar_text(laws_collection, user_prompt, tokenizer, model)
    print(search_results)

    # Combine retrieved documents into a single context
    context = ";".join([result["comma_content"] for result in search_results])
    system_prompt = "You are an expert in the field of law, and you are gonna replay to the following quiz. You have to choose the correct answer among the three options. These are some articles that could help you: " + context

    # Generate a final response using LLaMA 3 (optional, based on your needs)
    response_input = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    response = generate_response(response_input, model, tokenizer, device)

    print(response)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.96 GiB. GPU 3 has a total capacity of 23.68 GiB of which 1.09 GiB is free. Process 2438087 has 22.32 GiB memory in use. Including non-PyTorch memory, this process has 254.00 MiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
print(search_results)