In [1]:
%%capture
!pip install transformers datasets sentencepiece sacrebleu

In [2]:
from datasets import load_dataset

dataset = load_dataset("Wasserstoff-AI/legalTransEn_Indic")


README.md:   0%|          | 0.00/2.31k [00:00<?, ?B/s]

data.jsonl:   0%|          | 0.00/27.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/89951 [00:00<?, ? examples/s]

In [3]:
from transformers import MarianTokenizer, MarianMTModel

model_name = "Helsinki-NLP/opus-mt-en-hi"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)


tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/812k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/1.07M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/306M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

In [4]:
def preprocess(example):
    input_text = example["english"]
    target_text = example["hindi"]
    model_inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=128)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(target_text, truncation=True, padding="max_length", max_length=128)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset["train"].map(preprocess, batched=True)


model.safetensors:   0%|          | 0.00/306M [00:00<?, ?B/s]

Map:   0%|          | 0/89951 [00:00<?, ? examples/s]



In [5]:
import torch

def get_token_bias_mask(input_ids, tokenizer, boost_tokens, boost_value=2.0):
    mask = torch.zeros_like(torch.tensor(input_ids))
    for token in boost_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        mask[input_ids == token_id] = boost_value
    return mask


In [6]:
from transformers.models.marian.modeling_marian import MarianAttention

from transformers.models.marian.modeling_marian import MarianAttention
import torch.nn.functional as F

class LegalBiasMarianAttention(MarianAttention):
    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        # Call super() to get original attention outputs
        outputs = super().forward(
            hidden_states,
            key_value_states=key_value_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )

        # If bias mask exists and we're doing encoder attention (not decoder)
        if hasattr(self, "token_bias_mask") and self.token_bias_mask is not None:
            # Get attention scores from output
            attn_weights = outputs[1] if output_attentions else None
            if attn_weights is not None:
                attn_weights += self.token_bias_mask.unsqueeze(1).unsqueeze(2).to(attn_weights.device)

        return outputs



In [7]:
from transformers.models.marian.modeling_marian import MarianEncoder

boost_tokens = [
    "▁agreement", "▁party", "▁shall", "▁hereby", "▁warrant", "▁represent", "▁assign", "▁claim",
    "▁consideration", "▁contract", "▁obligation", "▁liability", "▁damages", "▁breach", "▁dispute",
    "▁indemnify", "▁settlement", "▁covenant", "▁deed", "▁title", "▁trust", "▁estate", "▁successor",
    "▁interest", "▁license", "▁guarantee", "▁default", "▁notice", "▁consent", "▁waiver", "▁remedy",
    "▁jurisdiction", "▁venue", "▁hearing", "▁petition", "▁motion", "▁affidavit", "▁pleading",
    "▁injunction", "▁equity", "▁tort", "▁liens", "▁statute", "▁clause", "▁provision", "▁amendment",
    "▁termination", "▁assignment", "▁disclosure", "▁compliance", "▁confidentiality", "▁fiduciary",
    "▁bonds", "▁securities", "▁creditor", "▁debtor", "▁trustee", "▁beneficiary", "▁arbitration",
    "▁mediation", "▁enforcement", "▁force", "▁majeure", "▁prejudice", "▁release", "▁forfeit",
    "▁damages", "▁fine", "▁penalty", "▁appeal", "▁counterclaim", "▁testimony", "▁evidence",
    "▁discovery", "▁negligence", "▁representation", "▁warranty", "▁equitable", "▁specific",
    "▁performance", "▁material", "▁breach", "▁rescission", "▁novation", "▁fraud", "▁duress",
    "▁estoppel", "▁lien", "▁mortgage", "▁lease", "▁tenant", "▁landlord", "▁possession", "▁ownership",
    "▁severability", "▁discretion", "▁injury", "▁accounting", "▁entitlement", "▁grievance",
    "▁procurement", "▁tender", "▁writ"
]

for layer in model.model.encoder.layers:
    layer.self_attn = LegalBiasMarianAttention(
        embed_dim=layer.self_attn.embed_dim,
        num_heads=layer.self_attn.num_heads,
        dropout=layer.self_attn.dropout,
        is_decoder=False,
    )


In [8]:
from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader

class LegalBiasTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Inject attention bias
        input_ids = inputs["input_ids"]
        attention_bias = get_token_bias_mask(
            input_ids=input_ids,
            tokenizer=tokenizer,
            boost_tokens=boost_tokens,
            boost_value=2.0
        )
        for layer in model.model.encoder.layers:
            layer.self_attn.token_bias_mask = attention_bias

        return super().compute_loss(model, inputs, return_outputs)



In [9]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./marian-legal-hi",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=1000,
    save_total_limit=3,
    logging_dir="./logs",
    logging_steps=50,
    report_to="none",
)



In [10]:
trainer = LegalBiasTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)

trainer.train()


  trainer = LegalBiasTrainer(
  mask = torch.zeros_like(torch.tensor(input_ids))


Step,Training Loss
50,1.4363
100,0.7906
150,0.715
200,0.695
250,0.7058
300,0.6709
350,0.6916
400,0.6664
450,0.6388
500,0.661


  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))
  mask = torch.zeros_like(torch.tensor(input_ids))


KeyboardInterrupt: 

In [None]:
model.save_pretrained("./marian-legal-hi-final")
tokenizer.save_pretrained("./marian-legal-hi-final")


The parties hereby agree that any breach of this agreement shall result in immediate termination of the contract.

# incorrect
पक्ष सहमत हैं कि किसी भी उल्लंघन पर अनुबंध तुरंत समाप्त हो जाएगा।
पक्ष सहमत हैं कि कोई भी गलती अनुबंध समाप्त कर देगी।

# correct
पक्ष इस समझौते के किसी भी उल्लंघन पर अनुबंध के तात्कालिक समाप्ति के लिए सहमत हैं।