In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import json

class LegalContrastiveDataset(Dataset):
    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_and_prepare_data(data_path)
    
    def load_and_prepare_data(self, data_path: str):
        # Load your legal clause dataset
        df = pd.read_csv(data_path)
        print(f"Dataset columns: {df.columns.tolist()}")
        print(f"Dataset shape: {df.shape}")
        
        # Your dataset has: clause_text, clause_type, totalwords, totalletters
        text_col = 'clause_text'
        category_col = 'clause_type'
        
        print(f"Using text column: '{text_col}'")
        print(f"Using category column: '{category_col}'")
        print(f"Unique clause types: {df[category_col].value_counts()}")
        
        # Create contrastive pairs
        data = []
        df_clean = df.dropna(subset=[text_col])  # Remove rows with missing text
        
        # Filter out very short clauses (less than 30 characters)
        df_clean = df_clean[df_clean['totalletters'] >= 30]
        
        print(f"After filtering: {len(df_clean)} clauses")
        
        for idx, row in df_clean.iterrows():
            clause = str(row[text_col]).strip()
            clause_type = str(row[category_col]).strip()
            
            # Create different question types based on clause type
            if clause_type.lower() == 'investments':
                questions = [
                    f"What are the investment restrictions in: {clause[:60]}...",
                    f"Explain the investment clause: {clause[:60]}...",
                ]
            elif clause_type.lower() == 'interest':
                questions = [
                    f"What does this interest clause specify: {clause[:60]}...",
                    f"Explain the interest terms: {clause[:60]}...",
                ]
            else:
                questions = [
                    f"What does this {clause_type.lower()} clause mean: {clause[:60]}...",
                    f"Explain this legal clause: {clause[:60]}...",
                ]
            
            # Add general questions
            questions.extend([
                f"Interpret this legal text: {clause[:60]}...",
                f"What are the key legal points in: {clause[:60]}..."
            ])
            
            # Use first 2 questions to avoid too much data
            for question in questions[:2]:
                data.append({
                    'anchor': question,
                    'positive': clause,
                    'negative': self.get_negative_sample(df_clean, idx, clause_type),
                    'label': 1
                })
        
        print(f"Created {len(data)} training examples")
        return data
    
    def get_negative_sample(self, df, current_idx, current_category):
        # Get a clause from different category as negative sample
        different_category = df[df['clause_type'] != current_category]
        if len(different_category) > 0:
            return str(different_category.sample(1)['clause_text'].iloc[0]).strip()
        
        # Fallback: random sampling excluding current row
        other_rows = df[df.index != current_idx]
        if len(other_rows) > 0:
            return str(other_rows.sample(1)['clause_text'].iloc[0]).strip()
        else:
            return "This is a sample negative legal clause for comparison."
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Tokenize anchor, positive, and negative
        anchor = self.tokenizer(
            item['anchor'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        positive = self.tokenizer(
            item['positive'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        negative = self.tokenizer(
            item['negative'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(),
            'positive_input_ids': positive['input_ids'].squeeze(),
            'positive_attention_mask': positive['attention_mask'].squeeze(),
            'negative_input_ids': negative['input_ids'].squeeze(),
            'negative_attention_mask': negative['attention_mask'].squeeze(),
        }

class LegalContrastiveModel(nn.Module):
    def __init__(self, model_name: str = "microsoft/Phi-3-mini-4k-instruct"):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        self.hidden_size = self.base_model.config.hidden_size
        self.projection = nn.Linear(self.hidden_size, 256)  # Project to smaller dim
        
    def get_embeddings(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
            # Use mean pooling of last hidden states
            embeddings = outputs.last_hidden_state.mean(dim=1)
        return self.projection(embeddings)
    
    def forward(self, anchor_ids, anchor_mask, pos_ids, pos_mask, neg_ids, neg_mask):
        anchor_emb = self.get_embeddings(anchor_ids, anchor_mask)
        positive_emb = self.get_embeddings(pos_ids, pos_mask)
        negative_emb = self.get_embeddings(neg_ids, neg_mask)
        
        return anchor_emb, positive_emb, negative_emb

def contrastive_loss(anchor, positive, negative, margin=1.0, temperature=0.07):
    # Cosine similarity
    pos_sim = F.cosine_similarity(anchor, positive, dim=1)
    neg_sim = F.cosine_similarity(anchor, negative, dim=1)
    
    # InfoNCE-style loss
    pos_exp = torch.exp(pos_sim / temperature)
    neg_exp = torch.exp(neg_sim / temperature)
    
    loss = -torch.log(pos_exp / (pos_exp + neg_exp))
    return loss.mean()

def train_model(data_path: str, epochs: int = 3, batch_size: int = 4, lr: float = 5e-5):
    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = LegalContrastiveModel()
    
    # Prepare dataset
    dataset = LegalContrastiveDataset(data_path, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Optimizer
    optimizer = AdamW(model.projection.parameters(), lr=lr)  # Only train projection layer
    
    model.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    print(f"Training on {device}")
    print(f"Dataset size: {len(dataset)}")
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            anchor_emb, pos_emb, neg_emb = model(
                batch['anchor_input_ids'], batch['anchor_attention_mask'],
                batch['positive_input_ids'], batch['positive_attention_mask'],
                batch['negative_input_ids'], batch['negative_attention_mask']
            )
            
            # Calculate loss
            loss = contrastive_loss(anchor_emb, pos_emb, neg_emb)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}')
    
    # Save model
    torch.save({
        'model_state_dict': model.state_dict(),
        'tokenizer': tokenizer
    }, 'legal_contrastive_model.pth')
    
    return model, tokenizer

def inference(model, tokenizer, query: str, legal_clauses: List[str]):
    """Find most relevant legal clause for a query"""
    device = next(model.parameters()).device
    model.eval()
    
    # Encode query
    query_tokens = tokenizer(
        query,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)
    
    with torch.no_grad():
        query_emb = model.get_embeddings(
            query_tokens['input_ids'],
            query_tokens['attention_mask']
        )
    
    # Find best matching clause
    best_score = -1
    best_clause = ""
    
    for clause in legal_clauses:
        clause_tokens = tokenizer(
            clause,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        
        with torch.no_grad():
            clause_emb = model.get_embeddings(
                clause_tokens['input_ids'],
                clause_tokens['attention_mask']
            )
            
            score = F.cosine_similarity(query_emb, clause_emb, dim=1).item()
            
            if score > best_score:
                best_score = score
                best_clause = clause
    
    return best_clause, best_score

# Usage example
if __name__ == "__main__":
    # Train the model
    model, tokenizer = train_model("/kaggle/input/contracts-clauses-datasets/legal_docs.csv")
    
    # Example inference
    query = "What are the termination clauses in employment contracts?"
    sample_clauses = [
        "Employee may terminate employment with 30 days notice...",
        "Company reserves right to terminate for cause...",
        "Confidentiality obligations survive termination..."
    ]
    
    best_clause, score = inference(model, tokenizer, query, sample_clauses)
    print(f"Best matching clause (score: {score:.3f}): {best_clause}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset columns: ['Unnamed: 0', 'clause_text', 'clause_type', 'totalwords', 'totalletters']
Dataset shape: (21187, 5)
Using text column: 'clause_text'
Using category column: 'clause_type'
Unique clause types: clause_type
interest                  1010
base-salary               1010
ownership_of_shares       1000
payment                   1000
taxes                     1000
investment-company-act    1000
compensation              1000
investments               1000
capitalization             930
loans                      920
Definitions                890
Headings                   860
WHEREAS                    730
Entire                     670
Assignment                 630
Counterparts               630
Representations            610
Termination                590
Severability               580
NOW                        540
Miscellaneous              530
Insurance                  470
Indemnification            370
dividends                  360
Confidentiality            360
Gove

AttributeError: 'CausalLMOutputWithPast' object has no attribute 'last_hidden_state'