In [1]:
import torch
torch.cuda.empty_cache()

In [2]:
import os
import nltk
import torch.nn.functional as F
import numpy as np
import transformers
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
from transformers import TrainingArguments,AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling, GenerationConfig
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from datetime import datetime
from trl import GRPOTrainer, GRPOConfig
from sentence_transformers import SentenceTransformer




In [3]:
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

In [4]:
def get_text_files(directory):
    return [f for f in os.listdir(directory) if f.endswith('.txt')]

In [5]:
def read_files(file_list, directory):
    content = {}
    for filename in file_list:
        with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
            content[filename] = file.read()
    return content

In [6]:
jdirectory = 'D://Legal//IND//IN-Abs//test-data//judgement_pre'
text_files = get_text_files(jdirectory)
judgement_content = read_files(text_files, jdirectory)
sdirectory = 'D://Legal//IND//IN-Abs//test-data//summary'
text_files = get_text_files(sdirectory)
summary_content = read_files(text_files, sdirectory)
X = [text for fn, text in judgement_content.items()]
y = [text for fn, text in summary_content.items()]

In [7]:
from transformers import AutoTokenizer

model_name = "./qwen"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize inputs and summaries
X_tokenized = [tokenizer(text, truncation=True, max_length=2048, return_tensors="pt") for text in X]
y_tokenized = [tokenizer(text, truncation=True, max_length=512, return_tensors="pt") for text in y]

In [8]:
def calculate_textrank(sentences):
    embeddings = embedding_model.encode(sentences)
    similarity_matrix = cosine_similarity(embeddings)
    graph = nx.from_numpy_array(similarity_matrix)
    scores = nx.pagerank(graph)
    return scores


In [9]:
def compute_reward(prompts, completions, **kwargs):
    rewards = []

    for generated_summary, input_text in zip(completions, prompts):
        gen_sentences = sent_tokenize(generated_summary)
        input_sentences = sent_tokenize(input_text)

        if not gen_sentences or not input_sentences:
            rewards.append(-1)  # Penalize empty outputs
            continue

        gen_scores = calculate_textrank(gen_sentences)
        input_scores = calculate_textrank(input_sentences)

        gen_embeddings = embedding_model.encode(gen_sentences)
        input_embeddings = embedding_model.encode(input_sentences)

        # Debug prints
        print(f"Gen Sentences: {len(gen_sentences)}, Input Sentences: {len(input_sentences)}")
        print(f"Gen Embeddings Shape: {gen_embeddings.shape}, Input Embeddings Shape: {input_embeddings.shape}")

        total_reward = 0
        for i, gen_embedding in enumerate(gen_embeddings):
            sentence_reward = 0
            for j, input_embedding in enumerate(input_embeddings):
                similarity = cosine_similarity(
                    gen_embedding.reshape(1, -1), input_embedding.reshape(1, -1)
                )[0, 0]
                sentence_reward += input_scores[j] * similarity
            total_reward += gen_scores[i] * sentence_reward

        # Length penalty
        length_penalty = -abs(len(gen_sentences) - 5) * 0.1
        total_reward += length_penalty

        rewards.append(total_reward)

    return rewards

In [10]:
X = [text for fn, text in judgement_content.items()]
y = [text for fn, text in summary_content.items()]

In [11]:
tokenizer = AutoTokenizer.from_pretrained("./qwen")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.generation_config = GenerationConfig()
data_dict = {"prompt": X, "summary": y}
dataset = Dataset.from_dict(data_dict)

In [14]:
def tokenize_function(examples):
    return tokenizer(examples["prompt"], examples["summary"], truncation=True, padding="longest", max_length=128)

In [15]:
tokenized_dataset = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [16]:
def format_dataset(example):
    return {"prompt": example["prompt"],
            "response": example["summary"], 
            "input_ids": example["input_ids"], 
            "attention_mask": example["attention_mask"]}

tokenized_dataset = tokenized_dataset.map(format_dataset)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [17]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [18]:
model = AutoModelForCausalLM.from_pretrained("./qwen", quantization_config=bnb_config, device_map="auto")
model.config.use_cache = False

In [19]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(

    task_type="CAUSAL_LM",

    r=8,

    lora_alpha=32,

    lora_dropout=0.1,

    target_modules=["q_proj", "v_proj"],

)

model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

trainable params: 1,089,536 || all params: 1,544,803,840 || trainable%: 0.0705


In [22]:
training_args = GRPOConfig(
    output_dir="Qwen2-1.5B-GRPO",
    learning_rate=1e-5,
    remove_unused_columns=False,  # to access the solution column in accuracy_reward
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    bf16=True,
    # Parameters that control de data preprocessing
    max_completion_length=64,  # default: 256
    num_generations=4,  # default: 8
    max_prompt_length=128,  # default: 512
    # Parameters related to reporting and saving
    report_to=["tensorboard"],
    logging_steps=10,
    push_to_hub=False,
    save_strategy="steps",
    save_steps=10,
)

In [23]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model, reward_funcs=compute_reward, args=training_args, train_dataset=tokenized_dataset
)

In [24]:
trainer.train()

Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 5, Input Sentences: 1
Gen Embeddings Shape: (5, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddi

Step,Training Loss


Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 2, Input Sentences: 1
Gen Embeddings Shape: (2, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddings Shape: (1, 384)
Gen Sentences: 1, Input Sentences: 1
Gen Embeddings Shape: (1, 384), Input Embeddi

TrainOutput(global_step=3, training_loss=2.4723121896386147e-05, metrics={'train_runtime': 1307.5186, 'train_samples_per_second': 0.076, 'train_steps_per_second': 0.002, 'total_flos': 0.0, 'train_loss': 2.4723121896386147e-05})

In [26]:
trainer.save_model(training_args.output_dir)