In [8]:
import torch
import os
import logging

# The root logger is configured at ERROR; this notebook's own logger is raised
# to DEBUG so its info/debug messages are still emitted.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Expose only physical GPUs 1-3 to this process. Set before the first CUDA
# query below so that the device count reflects the restriction.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

# SECURITY: never hard-code the Hugging Face access token in the notebook.
# Export HF_TOKEN in the environment that launches the kernel instead.
# NOTE: a token was previously committed on this line; it must be revoked.
if not os.environ.get('HF_TOKEN'):
    logger.warning(
        "HF_TOKEN is not set; gated or private Hugging Face models may fail to download."
    )

# Report the GPUs visible to this process.
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for gpu_id in range(num_gpus):
        print(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
else:
    print("No GPU available.")

Number of available GPUs: 3
GPU 0: NVIDIA A40
GPU 1: NVIDIA A40
GPU 2: NVIDIA A40


In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline
)

# Hub identifiers: the checkpoint that is loaded for generation and the
# base model whose tokenizer is used with it.
hg_legal_model = "Dhananjayg22/legal-triplet-extractor"
base_model = "google/gemma-7b-it"

# The tokenizer comes from the base model and pads on the right.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.padding_side = "right"

# Loading options for the checkpoint: automatic device placement, the
# flash-attention-2 implementation and bfloat16 weights.
# (A quantization_config=bnb_config option was left disabled in the original.)
model_load_kwargs = dict(
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(hg_legal_model, **model_load_kwargs)

# Text-generation pipeline shared by the extraction loop.
# NOTE(review): eos_token_id=107 is presumably Gemma's <end_of_turn> token id — confirm.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=3000,
    eos_token_id=107,
)

In [1]:
from utils.genral import get_precedients,get_precedients_triplets,get_precedients_anon,get_metadata
from utils.langchain import parser
import tqdm
import pickle

# Lookup tables returned by the project helpers; later cells index both of
# them by the ids loaded from the pickle below.
id2precedents = get_precedients()
id2precedents_triplets = get_precedients_triplets()
# id2precedents_anon = get_precedients_anon()
# metadata = get_metadata()

# Ids of the precedents selected for this run.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('small_precedents_ids.pkl', 'rb') as ids_file:
    small_precedents_ids = pickle.load(ids_file)

num_small_precedents = len(small_precedents_ids)
print("Number of small precedents: ", num_small_precedents)


Number of small precedents:  300


In [3]:
import json
import os

output_folder = "dpo-triplets-annon/all"
# open(..., "w") does not create missing directories, so create the target
# folder up front; exist_ok keeps the cell safe to re-run.
os.makedirs(output_folder, exist_ok=True)

# Re-serialise each stored triplet string as indented JSON, one file per id.
# The loop variable is doc_id rather than id to avoid shadowing the builtin.
for doc_id in tqdm.tqdm( small_precedents_ids ):
    # The stored value is a JSON string; json.loads only yields plain JSON
    # types, so json.dumps needs no custom `default` hook here.
    triplets = json.loads(id2precedents_triplets[doc_id])
    formatted_json_string = json.dumps(triplets, indent=2)

    with open(f"{ output_folder }/{ doc_id }.json", "w") as file:
        file.write(formatted_json_string)

100%|██████████| 300/300 [00:00<00:00, 2560.09it/s]


In [None]:
# non_annon_doc = id2precedents[small_precedents_ids[100]]
# annon_doc = id2precedents_anon[small_precedents_ids[100]]
# print(non_annon_doc)
# print(annon_doc)


GET TRIPLETS FROM NON-ANONYMIZED PRECEDENTS

In [None]:
from utils.langchain import langchain_chat_template,parser
from utils.genral import clean_string
import json
import traceback
import tqdm

output_folder = "dpo-triplets/all"
parse_error_folder = "dpo-triplets/parse_failed"

# open(..., "w") does not create missing directories; make sure both
# destinations exist before the (expensive) generation loop starts.
for folder in (output_folder, parse_error_folder):
    os.makedirs(folder, exist_ok=True)

# Generate and parse triplets for every selected precedent.
# The loop is resumable: a document that already has a file in either folder
# is skipped. Any unexpected error is logged and the loop moves on.
# The loop variable is doc_id rather than id to avoid shadowing the builtin.
for doc_id in tqdm.tqdm( small_precedents_ids ):
    output_path = f"{ output_folder }/{ doc_id }.json"
    parse_error_path = f"{ parse_error_folder }/{ doc_id }.json"

    try:
        #IF PRESENT SKIP
        if os.path.exists(output_path):
            logger.info(f"Skipping doc_id: {doc_id}")
            continue

        if os.path.exists(parse_error_path):
            logger.info(f"Skipping doc_id: {doc_id} Present in parse folder")
            continue

        logger.info(f"Processing doc_id: {doc_id}")

        #GENERATE INPUT PROMPT
        # The first two messages of the LangChain template are merged into a
        # single user turn before the tokenizer's chat template is applied.
        document = id2precedents[doc_id]
        chat_template = langchain_chat_template.format_messages(document=document)
        input_prompt = clean_string( chat_template[0].content + chat_template[1].content )
        chat = [
            {"role": "user", "content": input_prompt }
        ]
        prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        # NOTE(review): over-long prompts are not filtered; a guard at 8000
        # tokens was previously left commented out here.
        logger.info(f"Promt token length: { len( tokenizer.encode( prompt )) }")

        # Extract the text before opening any file, so a malformed pipeline
        # result cannot leave an empty parse-failed file behind.
        generated_text = pipe(prompt, return_full_text=False)[0]["generated_text"]

        #PARSE OUTPUT
        try:
            response = parser.parse(generated_text)
        except Exception as parse_error:
            logger.error(f"Error parsing doc_id: {doc_id}")
            logger.error(f"Parse error detail: {parse_error}")

            # Keep the raw generation so it can be inspected and fixed later.
            with open(parse_error_path, "w") as file:
                file.write(generated_text)
            continue

        #BEAUTIFY OUTPUT
        # The default hook falls back to an object's __dict__ for values that
        # json cannot serialise directly.
        formatted_json_string = json.dumps(response, default=lambda o: o.__dict__ , indent=2)

        with open(output_path, "w") as file:
            file.write(formatted_json_string)

    except Exception:
        logger.error(f"Error processing doc_id: {doc_id}")
        logger.error( clean_string(traceback.format_exc())[:1000] )

FIX PARSE FAILURES

In [7]:
import os
from utils.langchain import langchain_chat_template,parser

folder_path = "dpo-triplets/parse_failed"


def read_file(file_path):
    """Return the entire text content of the file at ``file_path``.

    The file is opened in text mode with the platform default encoding.
    """
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents

# Run the parser again over every regular file in the parse-failed folder.
# The parse result is discarded; the file name is printed before each attempt,
# so the last name printed identifies the file being parsed if an error occurs.
files = os.listdir(folder_path)

for file_name in files:
    file_path = os.path.join(folder_path, file_name)
    if not os.path.isfile(file_path):
        # Skip sub-directories and other non-file entries.
        continue
    text = read_file(file_path)
    print(f"File: {file_name}")
    parser.parse(text)


File: 168907050.json
File: 1900540.json
File: 4148349.json
File: 37254300.json
File: 97604604.json
File: 1351401.json
