# Section 3: Inference using Together AI
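This section uses a model hosted on Together AI to generate an analysis for each case in the provided data, then evaluates those generated analyses.
It covers building the prompts, generating predictions through the Together API, saving the results, and scoring them against the reference case results with METEOR, ROUGE, and BERTScore.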

In [None]:
#!pip install together

import together as tg
import json
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime
import evaluate
import nltk
# NLTK tokenizer data fetched at start-up; presumably needed by the METEOR
# metric in the evaluation cell — confirm.
_NLTK_TOKENIZER_RESOURCE = 'punkt_tab'
nltk.download(_NLTK_TOKENIZER_RESOURCE)

def load_json_lines(filename):
    """Load a JSON Lines file into a list of parsed objects.

    Each non-blank line of *filename* is parsed independently with
    ``json.loads``. Blank or whitespace-only lines (for example an empty
    line at the end of the file) are skipped instead of raising
    ``json.JSONDecodeError``.

    Args:
        filename: Path of the JSON Lines file to read; decoded as UTF-8.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 (the JSON interchange encoding) so the result does not
    # depend on the platform's locale encoding, e.g. cp1252 on Windows.
    with open(filename, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(text) for text in stripped_lines if text]

# Load the case data (one JSON object per line).
f_cases_data = load_json_lines('./data/nia_cases_to_process.json')

# Cases ordered by judgment date (DD/MM/YYYY), oldest first. A case without
# 'JudgmentDate' raises KeyError here; a malformed date raises ValueError.
sorted_cases = sorted(f_cases_data, key=lambda x: datetime.strptime(x['JudgmentDate'], '%d/%m/%Y'))

# Load the prompt templates. They are read as UTF-8 explicitly so that
# non-ASCII text in a template decodes the same way on every platform
# (the default encoding of open() is locale-dependent).
with open("./prompts/prompt_with_nocontext.txt", "r", encoding="utf-8") as f:
    prompt_with_nocontext = f.read().strip()

with open("./prompts/prompt_with_context.txt", "r", encoding="utf-8") as f:
    prompt_template = f.read().strip()

with open("./prompts/dspy_template.txt", "r", encoding="utf-8") as f:
    dspy_template = f.read().strip()

# Function to format the prompt using a template that includes context.
def format_prompt(doc):
    return prompt_template.format(
        Case_Inputs=json.dumps(doc.get('Case_Inputs', ''), indent=1)[2:-2],
        Citation_context=json.dumps(doc.get('Citation_context', ''), indent=1)[2:-2],
        Similar_Cases_Analysis=json.dumps(doc.get('Similar_Cases_Analysis', ''), indent=1)[2:-2]
    )

# Function to format the prompt using a template from DSpy.
def dspy_format_prompt(doc):
    return dspy_template.format(
        Case_Inputs=json.dumps(doc.get('Case_Inputs', ''), indent=1)[2:-2],
        Citation_context=json.dumps(doc.get('Citation_context', ''), indent=1)[2:-2],
        Similar_Cases_Analysis=json.dumps(doc.get('Similar_Cases_Analysis', ''), indent=1)[2:-2]
    )

# Function to format the prompt without any context.
def format_prompt_nocontext(doc):
    return prompt_with_nocontext.format(
        Case_Inputs=json.dumps(doc.get('Case_Inputs', ''), indent=1)[2:-2]
    )

# Configure the Together API client. The key is read from the
# TOGETHER_API_KEY environment variable when it is set, so the secret need not
# be written into the notebook. Otherwise the placeholder below must be
# replaced with a real key.
tg.api_key = os.environ.get("TOGETHER_API_KEY", "Enter your Together API key")
tg.api_base = "https://api.together.xyz/v1"

def generate_text(prompt, max_tokens=800, *,
                  model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
                  temperature=1.0):
    """Generate one completion for *prompt* through the Together API.

    Args:
        prompt: Fully formatted prompt text.
        max_tokens: Maximum number of tokens to generate.
        model: Together model identifier. The default is the Llama 3.1 8B
            baseline. Pass e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1" to use
            Mixtral without editing this function.
        temperature: Sampling temperature (higher means more random output).
            The default matches the value previously hard-coded here.

    Returns:
        The text of the first returned choice, stripped of surrounding
        whitespace.
    """
    response = tg.Completion.create(
        model=model,
        prompt=prompt,
        max_tokens=max_tokens,
        n=1,  # a single completion; only choices[0] is read below
        stop=None,  # no stop sequences
        temperature=temperature,
    )
    return response.choices[0].text.strip()

# Start from the analyses saved by a previous run, if any; otherwise begin
# with an empty collection.
try:
    with open('./results/generated_analyses.json', 'r') as fh:
        existing_data = json.load(fh)
except FileNotFoundError:
    existing_data = []

# Accumulators filled by the generation loop. `references` and
# `generated_texts` stay index-aligned for the evaluation cell; `results` is
# not referenced elsewhere in this section.
results = []
references = []
generated_texts = []

# Previously saved records keyed by '_id', so a rerun overwrites a case's
# analysis in place rather than appending a duplicate record.
existing_data_lookup = {saved['_id']: saved for saved in existing_data}

# Column that receives this run's analyses. 'Gen_Analysis_TG_llama_NC' is the
# Llama baseline with the no-context prompt. For Mistral use
# 'Gen_Analysis_TG_MST_NC', 'Gen_Analysis_TG_MST_WC' or
# 'Gen_Analysis_TG_MST_dspy' for the first, second or third prompt.
analysis_column = 'Gen_Analysis_TG_llama_NC'

# Generate an analysis for every case through the Together API.
for case in tqdm(f_cases_data, desc='Generating Analysis'):
    # Read required fields before the (paid) API call so that a malformed
    # case fails fast instead of after its completion has been generated.
    case_id = case['_id']
    reference = case['Case_Result']

    formatted_prompt = format_prompt_nocontext(case)  # First: no-context prompt
    #formatted_prompt = format_prompt(case)           # Second: with-context prompt
    #formatted_prompt = dspy_format_prompt(case)      # Third: prompt created with the DSPy framework

    analysis = generate_text(formatted_prompt, max_tokens=800)

    if case_id in existing_data_lookup:
        # Known case: overwrite this run's column on the saved record.
        existing_data_lookup[case_id][analysis_column] = analysis
    else:
        # New case: attach the analysis and register it in the lookup. A
        # repeated '_id' later in the input then updates this record instead
        # of being appended a second time.
        case[analysis_column] = analysis
        existing_data.append(case)
        existing_data_lookup[case_id] = case

    # Keep references and predictions index-aligned for evaluation.
    references.append(reference)
    generated_texts.append(analysis)

# Save all analyses, previous and new. Create the results directory first:
# otherwise a fresh checkout fails here, after every API call has already
# been made.
os.makedirs('./results', exist_ok=True)
with open('./results/generated_analyses.json', 'w') as f:
    json.dump(existing_data, f, indent=2)

print("Analysis Generated")

In [None]:
# Score the generated analyses against the reference case results.

# METEOR over all prediction/reference pairs.
meteor = evaluate.load('meteor')
meteor_results = meteor.compute(predictions=generated_texts, references=references)
print("METEOR Score:", meteor_results)

# BERTScore: the per-example values are averaged further down rather than
# printed raw.
bertscore = evaluate.load("bertscore")
bertscore_results = bertscore.compute(
    predictions=generated_texts,
    references=references,
    lang="en",
)

# ROUGE with stemming enabled.
rouge = evaluate.load('rouge')
rouge_results = rouge.compute(
    predictions=generated_texts,
    references=references,
    use_stemmer=True,
)
print("ROUGE Results:", rouge_results)


def _bertscore_mean(field):
    """Return BERTScore's *field* values as an array together with their NaN-ignoring mean."""
    per_example = np.array(bertscore_results[field])
    return per_example, np.nanmean(per_example)


scores_array, average_score = _bertscore_mean('f1')
print(f"Avg F1 Score: {average_score}")

p_scores_array, p_average_score = _bertscore_mean('precision')
print(f"Avg Precision Score: {p_average_score}")

r_scores_array, r_average_score = _bertscore_mean('recall')
print(f"Avg Recall Score: {r_average_score}")