In [93]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from datasets import Dataset

In [94]:
df = pd.read_csv("hf://datasets/darrow-ai/LegalLensNLI-SharedTask/NLI.csv")

In [95]:
df

Unnamed: 0,premise,hypothesis,legal_act,label,Unnamed: 4
0,DEFENDANT has reached a settlement in a class ...,Had to visit DEFENDANT a while back for some r...,privacy,Neutral,
1,A class action lawsuit has been certified agai...,"So, at 22, I was into this whole ""collect-and-...",consumer_protection,Entailed,
2,"DEFENDANT, an auto parts supplier, has agreed ...",As an employee of the aforementioned auto part...,consumer_protection,Contradict,
3,"DEFENDANT has agreed to pay $400,000 to settle...","Hey, got a call from DEFENDANT a while back, s...",privacy,Contradict,
4,DEFENDANT and other health benefit companies h...,"Just checked my mail, got a letter from DEFEND...",privacy,Neutral,
...,...,...,...,...,...
307,DEFENDANT has reached a settlement in a breach...,Feeling a bit perplexed today. I've been a loy...,consumer_protection,Entailed,
308,"DEFENDANT, a seafood restaurant operator in Ca...",Had a fantastic seafood dinner at this place l...,privacy,Neutral,
309,Consumers who received promotional text messag...,Hardly ever use my phone for anything other th...,tcpa,Neutral,
310,"DEFENDANT, a restaurant point-of-sale provider...","Upon my daily visits to the local diner, I fre...",privacy,Entailed,


In [96]:
# Remove unnecessary columns
df = df[["premise", "hypothesis", "label"]]

In [97]:
df

Unnamed: 0,premise,hypothesis,label
0,DEFENDANT has reached a settlement in a class ...,Had to visit DEFENDANT a while back for some r...,Neutral
1,A class action lawsuit has been certified agai...,"So, at 22, I was into this whole ""collect-and-...",Entailed
2,"DEFENDANT, an auto parts supplier, has agreed ...",As an employee of the aforementioned auto part...,Contradict
3,"DEFENDANT has agreed to pay $400,000 to settle...","Hey, got a call from DEFENDANT a while back, s...",Contradict
4,DEFENDANT and other health benefit companies h...,"Just checked my mail, got a letter from DEFEND...",Neutral
...,...,...,...
307,DEFENDANT has reached a settlement in a breach...,Feeling a bit perplexed today. I've been a loy...,Entailed
308,"DEFENDANT, a seafood restaurant operator in Ca...",Had a fantastic seafood dinner at this place l...,Neutral
309,Consumers who received promotional text messag...,Hardly ever use my phone for anything other th...,Neutral
310,"DEFENDANT, a restaurant point-of-sale provider...","Upon my daily visits to the local diner, I fre...",Entailed


In [98]:
from sklearn.model_selection import train_test_split

train_df, eval_df = train_test_split(df, test_size=0.3, random_state=42)

In [99]:
train_df

Unnamed: 0,premise,hypothesis,label
101,A settlement has been reached in a class actio...,"Having a blast with my computer, Wi-Fi's been ...",Contradict
193,DEFENDANT Aviation Services has agreed to pay ...,Been working at DEFENDANT Aviation Services fo...,Neutral
72,DEFENDANT-A and DEFENDANT-B Inc. have agreed t...,So I've been using this DEFENDANT-A software f...,Neutral
298,"DEFENDANT, an HR company that provides timekee...",Anyone else used those UKG time clocks at work...,Entailed
15,"DEFENDANT, a manufacturing company, has agreed...","Alright guys, remember that job I had at that ...",Entailed
...,...,...,...
188,"DEFENDANT, has agreed to pay $16 million to se...","Been using DEFENDANT for quite a while now, an...",Contradict
71,DEFENDANT has settled a class action lawsuit o...,Been working at this company for a while now a...,Contradict
106,DEFENDANT and its franchisee have agreed to es...,Despite the numerous phone calls I've received...,Contradict
270,DEFENDANT has agreed to pay $7.2 million to se...,"So, I've been getting a couple of calls from D...",Neutral


In [100]:
eval_df

Unnamed: 0,premise,hypothesis,label
228,DEFENDANT has agreed to a $5.25 million settle...,As a regular visitor to a certain company's fa...,Entailed
9,The DEFENDANT Text Message Class Action Settle...,Been receiving way too many texts from DEFENDA...,Entailed
57,DEFENDANT has agreed to pay $7.5 million to se...,Stumbled upon my former employer in the news t...,Contradict
60,"DEFENDANT, a hospital in Dixon, Illinois, has ...","So, there's this hospital in Dixon I went to a...",Neutral
25,"DEFENDANT, a company that provides ambulance a...",Recently started using the handprint clock-in ...,Neutral
...,...,...,...
304,A verdict has been reached against DEFENDANT f...,"It's rather interesting, I've been using DEFEN...",Neutral
19,"DEFENDANT, a home healthcare services company,...","Hey, folks! So, I've been using this home heal...",Neutral
147,DEFENDANT has agreed to a $12.75 million settl...,Feeling quite content with my employment situa...,Contradict
92,"DEFENDANT has agreed to pay $975,000 to settle...","Hey folks, I've been getting these calls from ...",Contradict


In [101]:
label_to_id = {'Entailed': 0, 'Contradict': 1, 'Neutral': 2}
id_to_label = {0: 'Entailed', 1: 'Contradict', 2: 'Neutral'}

In [114]:
def encode_labels(example):
    example['label'] = label_to_id[example['label']]
    return example

In [103]:
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(eval_df)

In [104]:
model_checkpoint = 'FacebookAI/roberta-base'

In [105]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



In [106]:
def preprocess(data):
    return tokenizer(data['premise'], data['hypothesis'], padding=True, truncation=True)

In [107]:
encoded_train_dataset = train_dataset.map(preprocess, batched=True)
encoded_eval_dataset = eval_dataset.map(preprocess, batched=True)


Map: 100%|██████████| 218/218 [00:00<00:00, 4547.35 examples/s]

Map: 100%|██████████| 94/94 [00:00<00:00, 4582.66 examples/s]


In [117]:
encoded_train_dataset = encoded_train_dataset.map(encode_labels)
encoded_eval_dataset = encoded_eval_dataset.map(encode_labels)


Map: 100%|██████████| 218/218 [00:00<00:00, 14318.62 examples/s]

Map: 100%|██████████| 94/94 [00:00<00:00, 11736.86 examples/s]


In [109]:
# Prepare the RoBERTa model
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=3, id2label = id_to_label, label2id = label_to_id)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [110]:
# Define the compute_metrics function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='macro')
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [119]:
# Define training arguments
training_args = TrainingArguments(
    output_dir=f"{model_checkpoint}_legal_nli_finetuned",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=10,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)



In [120]:
trainer.train()

  0%|          | 0/280 [05:38<?, ?it/s]
 10%|█         | 28/280 [00:26<04:51,  1.16s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                
[A                                    
 10%|█         | 28/280 [00:29<04:51,  1.16s/it]
[A

{'eval_loss': 1.097902536392212, 'eval_accuracy': 0.30851063829787234, 'eval_precision': 0.10283687943262411, 'eval_recall': 0.3333333333333333, 'eval_f1': 0.15718157181571815, 'eval_runtime': 3.3455, 'eval_samples_per_second': 28.097, 'eval_steps_per_second': 3.587, 'epoch': 1.0}


 20%|██        | 56/280 [00:50<02:17,  1.63it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                
[A                                    
 20%|██        | 56/280 [00:53<02:17,  1.63it/s]
[A

{'eval_loss': 1.0880720615386963, 'eval_accuracy': 0.30851063829787234, 'eval_precision': 0.10283687943262411, 'eval_recall': 0.3333333333333333, 'eval_f1': 0.15718157181571815, 'eval_runtime': 2.6837, 'eval_samples_per_second': 35.026, 'eval_steps_per_second': 4.471, 'epoch': 2.0}


 30%|███       | 84/280 [01:14<02:00,  1.62it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                
[A                                    
 30%|███       | 84/280 [01:17<02:00,  1.62it/s]
[A

{'eval_loss': 0.4565306305885315, 'eval_accuracy': 0.8404255319148937, 'eval_precision': 0.8411764705882353, 'eval_recall': 0.8436607123383281, 'eval_f1': 0.8422601049311083, 'eval_runtime': 2.6678, 'eval_samples_per_second': 35.235, 'eval_steps_per_second': 4.498, 'epoch': 3.0}


 40%|████      | 112/280 [01:38<01:43,  1.62it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 40%|████      | 112/280 [01:41<01:43,  1.62it/s]
[A

{'eval_loss': 0.41552871465682983, 'eval_accuracy': 0.8617021276595744, 'eval_precision': 0.8650840299547197, 'eval_recall': 0.8651660886824141, 'eval_f1': 0.8649963101841639, 'eval_runtime': 2.6719, 'eval_samples_per_second': 35.181, 'eval_steps_per_second': 4.491, 'epoch': 4.0}


 50%|█████     | 140/280 [02:02<01:26,  1.63it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 50%|█████     | 140/280 [02:05<01:26,  1.63it/s]
[A

{'eval_loss': 0.551285982131958, 'eval_accuracy': 0.8085106382978723, 'eval_precision': 0.8385598141695704, 'eval_recall': 0.8129730201313006, 'eval_f1': 0.8167064914376742, 'eval_runtime': 2.6531, 'eval_samples_per_second': 35.43, 'eval_steps_per_second': 4.523, 'epoch': 5.0}


 60%|██████    | 168/280 [02:26<01:09,  1.62it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 60%|██████    | 168/280 [02:29<01:09,  1.62it/s]
[A

{'eval_loss': 0.6107558608055115, 'eval_accuracy': 0.8723404255319149, 'eval_precision': 0.8783410138248847, 'eval_recall': 0.8749700102510415, 'eval_f1': 0.8764958211330102, 'eval_runtime': 2.6772, 'eval_samples_per_second': 35.112, 'eval_steps_per_second': 4.482, 'epoch': 6.0}


 70%|███████   | 196/280 [02:50<00:51,  1.62it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 70%|███████   | 196/280 [02:53<00:51,  1.62it/s]
[A

{'eval_loss': 0.6971310377120972, 'eval_accuracy': 0.8297872340425532, 'eval_precision': 0.8441558441558442, 'eval_recall': 0.83237366136666, 'eval_f1': 0.8363334840946782, 'eval_runtime': 2.7379, 'eval_samples_per_second': 34.333, 'eval_steps_per_second': 4.383, 'epoch': 7.0}


 80%|████████  | 224/280 [03:14<00:34,  1.62it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 80%|████████  | 224/280 [03:16<00:34,  1.62it/s]
[A

{'eval_loss': 0.6893373131752014, 'eval_accuracy': 0.8297872340425532, 'eval_precision': 0.8441558441558442, 'eval_recall': 0.83237366136666, 'eval_f1': 0.8363334840946782, 'eval_runtime': 2.6694, 'eval_samples_per_second': 35.214, 'eval_steps_per_second': 4.495, 'epoch': 8.0}


 90%|█████████ | 252/280 [03:38<00:17,  1.63it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
 90%|█████████ | 252/280 [03:40<00:17,  1.63it/s]
[A

{'eval_loss': 0.785468339920044, 'eval_accuracy': 0.8297872340425532, 'eval_precision': 0.8441558441558442, 'eval_recall': 0.83237366136666, 'eval_f1': 0.8363334840946782, 'eval_runtime': 2.7216, 'eval_samples_per_second': 34.539, 'eval_steps_per_second': 4.409, 'epoch': 9.0}


100%|██████████| 280/280 [04:02<00:00,  1.64it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                    
100%|██████████| 280/280 [04:06<00:00,  1.64it/s]
                                                 
100%|██████████| 280/280 [04:06<00:00,  1.13it/s]

{'eval_loss': 0.7709857821464539, 'eval_accuracy': 0.8297872340425532, 'eval_precision': 0.8441558441558442, 'eval_recall': 0.83237366136666, 'eval_f1': 0.8363334840946782, 'eval_runtime': 2.7644, 'eval_samples_per_second': 34.003, 'eval_steps_per_second': 4.341, 'epoch': 10.0}
{'train_runtime': 246.7565, 'train_samples_per_second': 8.835, 'train_steps_per_second': 1.135, 'train_loss': 0.35777220044817243, 'epoch': 10.0}





TrainOutput(global_step=280, training_loss=0.35777220044817243, metrics={'train_runtime': 246.7565, 'train_samples_per_second': 8.835, 'train_steps_per_second': 1.135, 'total_flos': 517572870701040.0, 'train_loss': 0.35777220044817243, 'epoch': 10.0})

In [121]:
evaluation_results = trainer.evaluate()

100%|██████████| 12/12 [00:02<00:00,  4.81it/s]


In [122]:
evaluation_results

{'eval_loss': 0.7709857821464539,
 'eval_accuracy': 0.8297872340425532,
 'eval_precision': 0.8441558441558442,
 'eval_recall': 0.83237366136666,
 'eval_f1': 0.8363334840946782,
 'eval_runtime': 2.8378,
 'eval_samples_per_second': 33.124,
 'eval_steps_per_second': 4.229,
 'epoch': 10.0}

In [123]:
from dotenv import load_dotenv
load_dotenv()

True

In [124]:
from huggingface_hub import login
import os

hf_token = os.environ["HF_TOKEN"]
login(token=hf_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/khalidrajan/.cache/huggingface/token
Login successful


In [125]:
trainer.push_to_hub()

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]
[A

[A[A
training_args.bin: 100%|██████████| 5.24k/5.24k [00:00<00:00, 34.9kB/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
model.safetensors: 100%|██████████| 499M/499M [00:15<00:00, 32.7MB/s]
Upload 2 LFS files: 100%|██████████| 2/2 [00:15<00:00,  7.71s/it]


CommitInfo(commit_url='https://huggingface.co/khalidrajan/roberta-base_legal_nli_finetuned/commit/9b84002518dcc97e113c8373eb3808374277d7f4', commit_message='End of training', commit_description='', oid='9b84002518dcc97e113c8373eb3808374277d7f4', pr_url=None, pr_revision=None, pr_num=None)