In [1]:
import os
import json
import gc

import torch 
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoConfig , TrainingArguments, AutoModelForCausalLM
import trl
from trl import apply_chat_template

from outlines import models, generate
from typing import List, Dict
from pydantic import BaseModel
from tqdm import tqdm_notebook
from datasets import Dataset


2024-10-19 19:40:16.658629: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-19 19:40:16.658698: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-19 19:40:16.659996: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-19 19:40:16.668137: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
TRAIN_DATASET_PATH = '../dataset/train.json'
TEST_DATASET_PATH = '../dataset/test.json'
DEV_DATASET_PATH = '../dataset/dev.json'

system_prompt = '''
You are a legal agent. You are provided with statements from a contract and a hyothesis.
Your task is to evaluate wheather the hyothesis statement is:
1. Entailment: The hypothesis is directly supported by one or more of the statements from contract i.e 
if the statement is true then the hypothesis is true. 
The implication should be unabiguious without assumptions.
2. Contradiction: The hypothesis is contradiction of one of more of the statements from contract i.e 
if the statement is true then hypothesis is false.
This implication should be unabiguios without assumptions.
3. Not Mentioned: In case none of the above conculusions can be drawn without assumption.

Contract:
{0}

Evalution Method:
1. First repharse each of the statement in simple language.
2. Repharse the hypothesis also in simple language
3. Select repharsed statements that implies or denies the repharsed hypothesis. Dont make assumptions.
4. Choose Not Mentioned if no such statement are present.

Input:
Hypothesis:
<hypothesis_statement>

Output: (should be strictly as follows and should be properly formatted as json only):
{{
    "choice": "<one of Entailement, Contradiction, NotMentioned>"
    "spans": ["<only mention upto 3 most relevant statement ids, in same format as mentioned above that directly supports the choice in itself without support of any other statement>"]
    "reason": "<step by step reasoning, first mention the statement selected as is by statement id, then elaborate above evalution method for given hypthesis>"

}}
Give response as json

'''

ft_system_prompt = '''
You are a legal agent. You are provided with statements from a contract and a hypothesis. Your task is to evaluate whether the hypothesis statement is:
1. Entailment: The hypothesis is directly supported by one or more of the statements from the contract, meaning if the statement is true, then the hypothesis is true. The implication should be unambiguous and without assumptions.
2. Contradiction: The hypothesis contradicts one or more of the statements from the contract, meaning if the statement is true, then the hypothesis is false. This implication should be unambiguous and without assumptions.
3. Not Mentioned: In cases where none of the above conclusions can be drawn without assumptions.

Contract:
{0}

Input:
Evaluate following hypothesis:
<hypothesis_statement>

Output: (should be strictly formatted as JSON only):
```json
{{
    "choice": "<one of Entailment, Contradiction, NotMentioned>",
    "spans": ["<mention up to 3 most relevant statement ids that directly support the choice without the support of any other statement>"],
}}
```
'''

chat_template = '''
<|system|>
{0}
<|user|>
{1}<|end|>
<|assistant|>
'''

In [13]:
def load_dataset(path):
    with open(path,'r') as f:
        dataset =  json.load(f)
    return dataset

def get_embedding(model, tokenizer, text: str, aggregation_method: str = None):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden_state = outputs.hidden_states[-1]

    if aggregation_method == 'mean':
        # Mean pool the token embeddings to get a sentence embedding
        embedding = last_hidden_state.mean(dim=1)
    elif aggregation_method == 'last_token':
        # Get the hidden state of the last token
        embedding = last_hidden_state[:, -1, :]  # Shape: [batch_size, hidden_size]
    else:
        # If no valid aggregation method is specified, return the last hidden state
        embedding = last_hidden_state

    # Normalize the embedding (if applicable)
    if aggregation_method in ['mean', 'last_token']:
        embedding = F.normalize(embedding, p=2, dim=1)

    return embedding

def construct_statement_annotation_pair(dataset_input):
    statement_annotation_pair = {}
    for i,document in enumerate(dataset_input):
        #print(document['id'])
        statement_annotation_pair[i] = {}
        statement_annotation_pair[i]['statement'] = ''
        for idx,span in enumerate(document['spans']):
            #print(span)
            statement_annotation_pair[i]['statement'] += f'''[[statement_id_{idx}]]: {document['text'][span[0]:span[1]]}\n'''
            statement_annotation_pair[i]['target'] =  document['annotation_sets'][0]
    return statement_annotation_pair


def prepare_dataset_for_finetuning_method1(train_contract_statements_and_annotations,hypothesis_set,tokeniser):
    '''
    {
        "prompt": [{"role": "user", "content": "What color is the sky?"}],
        "completion": [{"role": "assistant", "content": "It is blue."}]
    }
    '''
    dataset = []
    for idx,contract in tqdm_notebook(train_contract_statements_and_annotations.items()):
        current_system_prompt = ft_system_prompt.format(contract['statement'])
        for key,value in hypothesis_set.items():
            current_user_prompt = f"Evaluate following hypothesis: \n{value['hypothesis']}"
            data_point = {
                "prompt": [
                    {
                        "role": "system",
                        "content": current_system_prompt
                    },
                    {
                        "role": "user",
                        "content": current_user_prompt
                    }
                ],
                "completion": [
                    {
                        "role": "assistant",
                        "content": f"```json\n{contract['target']['annotations'][key]}\n```"
                    }
                ]
            }
            dataset.append(apply_chat_template(data_point,tokeniser))
    return dataset

'''
With All Statements
'''
def generate_message_target_pair(idx,train_contract_statements_and_annotations,nda,model=None,tokeniser=None):
    sim_scores = {}
    example_data = train_contract_statements_and_annotations[idx]
    current_system_prompt = system_prompt.format(example_data['statement'])
    hypothesis_embed = get_embedding(model,tokeniser,hypothesis_set[nda]['hypothesis'],'mean')
    if model is not None and tokeniser is not None:
        for idx,statement in enumerate(example_data['statement'].split('\n')):
            if len(statement) > 0:
                statement_embed = get_embedding(model,tokeniser,statement,'mean')
                sim_scores[statement.split(']]')[0].split('[[')[-1]] = ((statement_embed@hypothesis_embed.T).item())

    current_user_prompt = hypothesis_set[nda]['hypothesis']
    chat_template_current = chat_template.format(current_system_prompt,current_user_prompt)
    messages = [ 
        {"role": "system", "content": current_system_prompt}, 
        {"role": "user", "content": current_user_prompt},
    ]
    target = example_data['target']['annotations'][nda]

    # Get the top 5 items sorted by value
    top_5_statements = sorted(sim_scores.items(), key=lambda item: item[1], reverse=True)[:5]


    return messages,target,sim_scores,top_5_statements

def generate_message_target_pair_with_top_5_staements(idx, train_contract_statements_and_annotations, nda, model=None, tokeniser=None):
    sim_scores = {}
    example_data = train_contract_statements_and_annotations[idx]
    hypothesis_embed = get_embedding(model, tokeniser, hypothesis_set[nda]['hypothesis'], 'mean')

    if model is not None and tokeniser is not None:
        for idx, statement in enumerate(example_data['statement'].split('\n')):
            if len(statement) > 0:
                statement_embed = get_embedding(model, tokeniser, statement, 'mean')
                # Store similarity scores
                statement_id = statement.split(']]')[0].split('[[')[-1]
                sim_scores[int(statement_id.split('statement_id_')[-1])] = (statement_embed @ hypothesis_embed.T).item()

    current_user_prompt = hypothesis_set[nda]['hypothesis']
    target = example_data['target']['annotations'][nda]

    # Get the top 5 items sorted by value
    top_5_statements = sorted(sim_scores.items(), key=lambda item: item[1], reverse=True)[:5]

    # Initialize system prompt components
    formatted_statements = []
    ids_to_format = set()  # To store unique IDs for formatting

    for statement_id, _ in top_5_statements:
        # Assume that statement_id is an integer or can be converted to an integer
        statement_id = int(statement_id)
        ids_to_format.update({statement_id-1, statement_id, statement_id + 1})
    ids_to_format = list(map(lambda x: f'statement_id_{x}', ids_to_format))
    print(ids_to_format)

    # Format statements only for the collected IDs
    for statement in example_data['statement'].split('\n'):
        # Assuming statement has a format where the ID is part of it
        statement_id =  statement.split(']]')[0].split('[[')[-1]  # Extract the ID
        if statement_id in ids_to_format:
            formatted_statements.append(statement)

    # Construct the system prompt using the filtered statements
    current_system_prompt = system_prompt.format('\n'.join(formatted_statements))

    # Construct the chat messages
    messages = [
        {"role": "system", "content": current_system_prompt},
        {"role": "user", "content": current_user_prompt},
    ]

    return messages, target, sim_scores, top_5_statements


In [10]:
train_dataset = load_dataset(TRAIN_DATASET_PATH) 
hypothesis_set = train_dataset['labels']
dev_dataset = load_dataset(DEV_DATASET_PATH)

In [11]:
train_contract_statements_and_annotations = construct_statement_annotation_pair(train_dataset['documents'])

**Direct Inferencing**

In [6]:
# class Annotation(BaseModel):
#     choice: str
#     spans: List[int]
#     reason: str

# class AnnotationsModel(BaseModel):
#     annotations: Dict[str, Annotation]

# model = models.transformers("microsoft/Phi-3-mini-4k-instruct",device='auto')

In [None]:
len(output)

: 

: 

: 

In [77]:

gc.collect()  # Run garbage collection to release unreferenced memory
torch.cuda.empty_cache()  # Clears cache


In [101]:

model_id = "microsoft/Phi-3-mini-128k-instruct"
model = AutoModelForCausalLM.from_pretrained( 
    model_id,  
    device_map="auto", 
    trust_remote_code=True,  
) 

tokenizer = AutoTokenizer.from_pretrained(model_id) 

pipe = pipeline( 
    "text-generation", 
    model=model, 
    tokenizer=tokenizer, 
) 

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [177]:
nda = 'nda-1'
messages, target,scores,top_5_statements = generate_message_target_pair(34,train_contract_statements_and_annotations,nda,model,tokenizer)

{'short_description': 'Explicit identification', 'hypothesis': 'All Confidential Information shall be expressly identified by the Disclosing Party.'}
{'choice': 'Entailment', 'spans': [4, 5, 6]}


In [179]:
generation_args = { 
    "max_new_tokens": 500, 
    "return_full_text": False, 
    "temperature": 0.0, 
    "do_sample": False, 
} 
output = pipe(messages, **generation_args) 
print(output[0]['generated_text']) 



 {
    "choice": "Contradiction",
    "spans": ["statement_id_5"],
    "reason": "The hypothesis states that 'All Confidential Information shall be expressly identified by the Disclosing Party.' However, according to statement_id_5, 'Confidential Information' is defined as technologies disclosed to COMPANY from time to time after the Effective Date of this Agreement by ROCHESTER and identified with particularity at time of disclosure and marked confidential. This implies that not all Confidential Information needs to be expressly identified by the Disclosing Party, but only those disclosed after the Effective Date and identified at the time of disclosure. Therefore, the hypothesis is contradicted by statement_id_5."
}



**Finetuning Phi Mode**

In [15]:
model_id = "microsoft/Phi-3-mini-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
train_dataset_dict = prepare_dataset_for_finetuning_method1(train_contract_statements_and_annotations,hypothesis_set,tokenizer)
train_dataset = Dataset.from_list(train_dataset_dict)
# args = TrainingArguments(
#     output_dir="./test-galore",
#     max_steps=100,
#     per_device_train_batch_size=2,
#     optim="galore_adamw_8bit",
#     optim_target_modules=["attn", "mlp"],
# )

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,  # This is per GPU
    optim="galore_adamw_8bit",
    optim_target_modules=["attn", "mlp"],
    # fp16=True,  # Enable mixed precision to optimize GPU usage
    # dataloader_num_workers=4,  # Parallel data loading
    # gradient_accumulation_steps=4,  # Accumulate gradients if batch size is too large for GPU memory
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=5,
    logging_dir="./logs",
    logging_steps=1,
    report_to="wandb",  # Optionally report to Weights and Biases
    # ddp_find_unused_parameters=False  # Required for multi-GPU
)

config = AutoConfig.from_pretrained(model_id)


model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model, 
    args=args,
    train_dataset=train_dataset,
    max_seq_length=5000,
)

trainer.train()


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx,contract in tqdm_notebook(train_contract_statements_and_annotations.items()):


  0%|          | 0/423 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


OutOfMemoryError: CUDA out of memory. Tried to allocate 36.00 MiB. GPU 

In [232]:
train_dataset['completion'][0]

"```json\n{'choice': 'Entailment', 'spans': [14]}\n```<|end|>\n<|endoftext|>"