In [1]:
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    use_cache=None,
    attn_implementation=None,
)
model.device

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
# Function to generate model responses
def generate_output(prompt, max_tokens=1000):
    batch = tokenizer(prompt, return_tensors="pt")
    batch = {k: v.to("cuda") for k, v in batch.items()}
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            max_new_tokens=max_tokens,
            do_sample=False,
            top_p=1.0,
            temperature=0,
            use_cache=True,
            top_k=50,
            repetition_penalty=1.2,
            length_penalty=1,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
    output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return output_text[len(prompt):].strip()

# Load input CSV
input_file = "legal_reasoning_30.csv"  # Path to input file
output_file = "/home/rmadhav6/legal_reasoning_with_answers.csv"  # Path to save the output
try:
    df = pd.read_csv(input_file, encoding='utf-8')
except UnicodeDecodeError:
    df = pd.read_csv(input_file, encoding='ISO-8859-1')

# Initialize empty lists to store outputs
output_text1_list = []
output_text2_list = []
output_text3_list = []

# Process each row
for index, row in df.iterrows():
    try:
        legal_context = row['Context']  # Assuming these are the column names
        question = row['Question']
        options = row['Options']
        analysis = row['Analysis']
        
        # Generate first output
        user_prompt1 = legal_context + question + options + analysis + '''Task:
        You are a helpful legal assistant. Choose the correct option by performing legal reasoning while strictly adhering to 
        the legal context and analysis provided.
        While answering make sure to use the following format:
        [explanation of your legal reasoning step by step as numbered points]'''
        output_text1 = generate_output(user_prompt1)
        output_text1_list.append(output_text1)

        # Generate second output
        user_prompt2 = legal_context + question + options + analysis + output_text1 + '''Task:
        You are a helpful legal assistant. 
        You need to identify critical terms from the legal reasoning steps based on the legal context, question, options and analysis.
        You need to generate verification questions for each of the legal reasoning steps and these questions must focus on the critical terms identified.'''
        output_text2 = generate_output(user_prompt2)
        output_text2_list.append(output_text2)

        # Generate third output
        #user_prompt3 = legal_context + question + options + output_text1 + output_text2 + ''' Task:
        #You are a helpful legal assistant. You need to identify the correct answer to the question.
        #Analyze the verification questions and their answers for each of the legal reasoning steps and based on that modify 
        #the legal reasoning and give the correct option to the question.'''
        user_prompt3 = legal_context + question + options + analysis + output_text1 + output_text2 + '''Task:
        You are a helpful legal assistant. You need to identify the correct answer to the question.
        Based on the answers to the verification questions, label each of the legal reasoning steps as correct or wrong.
        Then based on the evaluation of the legal reasoning steps, perform legal reasoning again adhering to the legal context 
        and analysis and find the correct option to the question.'''
        output_text3 = generate_output(user_prompt3, max_tokens=1500)
        output_text3_list.append(output_text3)
    
    except Exception as e:
        # Handle errors gracefully and log them
        output_text1_list.append(f"Error: {str(e)}")
        output_text2_list.append(f"Error: {str(e)}")
        output_text3_list.append(f"Error: {str(e)}")

# Add the outputs to the DataFrame
df['output_text1'] = output_text1_list
df['output_text2'] = output_text2_list
df['output_text3'] = output_text3_list

# Save the updated DataFrame to a CSV file
df.to_csv(output_file, index=False, encoding='utf-8')

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_