In [3]:
import pandas as pd
import numpy as np
import logging

from groq import Groq
from dotenv import load_dotenv
import os


load_dotenv()
logging.basicConfig(level=logging.INFO)

In [4]:
df = pd.read_csv("./articles.csv", nrows = 15)
logging.info(msg="Dataset loaded successfully")

INFO:root:Dataset loaded successfully


In [5]:
client = Groq(api_key = os.environ.get("GROQ_API_KEY"))

SUMMARIZER_PROMPT = "Summarize the following legal content in a crisp manner with the important details kept intact. Only give the summary, without the starting  line Here is a crisp summary"

SUMMARIZER_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

logging.info("Model initialized with the instructed prompt")

INFO:root:Model initialized with the instructed prompt


In [6]:
summaries = []
for idx in range(len(df['article_desc'])):
    completion = client.chat.completions.create(
    model=SUMMARIZER_MODEL,
    messages=[
      {
        "role": "user",
        "content": f"{SUMMARIZER_PROMPT}: {df['article_desc'][idx]}"
      }
    ],
    temperature=1,
    max_completion_tokens=1024,
    top_p=1,
    stream=True,
    stop=None,
    )


    full_summary = ""
    for chunk in completion:
        if chunk.choices[0].delta.content:
            full_summary += chunk.choices[0].delta.content

    summaries.append(full_summary)

logging.info("Summaries generated for the legal content")

INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://ap

In [7]:
df['summary'] = summaries

In [20]:
df['complete_desc'] = df['article_id'] + df['article_desc']
df.head(5)

Unnamed: 0,article_id,article_desc,summary,complete_desc
0,Article 1 of Indian Constitution,"Name and territory of the Union India, that is...","India, also known as Bharat, is a Union of Sta...",Article 1 of Indian ConstitutionName and terri...
1,Article 2 of Indian Constitution,Admission or establishment of new States: Parl...,Parliament may admit or establish new States t...,Article 2 of Indian ConstitutionAdmission or e...
2,Article 2A of Indian Constitution,Sikkim to be associated with the Union Rep by ...,Sikkim became an associated state with the Uni...,Article 2A of Indian ConstitutionSikkim to be ...
3,Article 3 of Indian Constitution,Formation of new States and alteration of area...,The Parliament can form new states or alter ex...,Article 3 of Indian ConstitutionFormation of n...
4,Article 4 of Indian Constitution,Laws made under Articles 2 and 3 to provide fo...,Laws made under Articles 2 and 3 to amend the ...,Article 4 of Indian ConstitutionLaws made unde...


In [21]:
df.drop(labels=['article_id','article_desc'], axis = 1)

Unnamed: 0,summary,complete_desc
0,"India, also known as Bharat, is a Union of Sta...",Article 1 of Indian ConstitutionName and terri...
1,Parliament may admit or establish new States t...,Article 2 of Indian ConstitutionAdmission or e...
2,Sikkim became an associated state with the Uni...,Article 2A of Indian ConstitutionSikkim to be ...
3,The Parliament can form new states or alter ex...,Article 3 of Indian ConstitutionFormation of n...
4,Laws made under Articles 2 and 3 to amend the ...,Article 4 of Indian ConstitutionLaws made unde...
5,"Every person domiciled in India, who was born ...",Article 5 of Indian ConstitutionCitizenship at...
6,A person who migrated to India from Pakistan s...,Article 6 of Indian ConstitutionRights of citi...
7,A person who migrated from India to Pakistan a...,Article 7 of Indian ConstitutionRights of citi...
8,Certain persons of Indian origin residing outs...,Article 8 of Indian ConstitutionRights of citi...
9,No person shall be a citizen of India if they ...,Article 9 of Indian ConstitutionPerson volunta...


##### Finetuning of Flant5-small

In [22]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset

In [23]:
dataset = Dataset.from_pandas(df[['complete_desc','summary']])

In [24]:
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

In [26]:
max_input_length = 512
max_target_length = 512

def preprocess(example):
    inputs = "summarize the legal document with the legal terms intact: " + example['complete_desc']
    model_inputs = tokenizer(inputs, max_length = max_input_length, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(example['summary'], max_length=max_target_length,truncation=True)
    
    model_inputs["labels"] = labels['input_ids']
    return model_inputs


tokenized_dataset = dataset.map(preprocess, batched=False)

Map: 100%|██████████| 15/15 [00:00<00:00, 360.63 examples/s]


In [27]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-small-legal-finetuned",
    eval_strategy="epoch",
    learning_rate=2e-5,
    fp16=True,
    num_train_epochs=20,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2
)

In [28]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [30]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)

trainer.train()



Epoch,Training Loss,Validation Loss
1,No log,2.099972
2,No log,2.014731
3,No log,1.932259
4,No log,1.871072
5,No log,1.821996
6,No log,1.772391
7,No log,1.730087
8,No log,1.694354
9,No log,1.663444
10,No log,1.633808




TrainOutput(global_step=160, training_loss=1.9871181488037108, metrics={'train_runtime': 654.227, 'train_samples_per_second': 0.459, 'train_steps_per_second': 0.245, 'total_flos': 19705845768192.0, 'train_loss': 1.9871181488037108, 'epoch': 20.0})

##### Use the fine-tuned model

In [35]:
trainer.save_model("./flan-t5-small-legal-finetuned")

input_text = "summarize: " + df['complete_desc'][13]
input_ids = tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
output = model.generate(input_ids, max_length=512)
print("Generated Summary:", tokenizer.decode(output[0], skip_special_tokens=True))

Generated Summary: The State shall not make any law which takes away or abridges the rights conferred by this Part and any law made in contravention of this clause shall, to the extent of the contravention, be void.
