In [None]:
import datasets
import re
import nltk
import torch
import torch.nn as nn
import torch.optim as optim
import requests
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords

# Load dataset: the "scotus" configuration of LexGLUE from the Hugging Face Hub.
dataset = datasets.load_dataset(path="coastalcph/lex_glue", name="scotus")

# One tokenizer per backbone, created in the same order as the models below;
# each model must be fed ids produced by its own vocabulary.
tokenizer_legalbert, tokenizer_roberta = (
    AutoTokenizer.from_pretrained(checkpoint)
    for checkpoint in ("nlpaueb/legal-bert-base-uncased", "roberta-base")
)

def tokenize_function(examples, max_length=512):
    """Tokenize a batch of examples with the LEGAL-BERT tokenizer.

    Meant for ``dataset.map(tokenize_function, batched=True)``, where
    ``examples["text"]`` is a list of documents.

    Args:
        examples: batch mapping with a ``"text"`` column.
        max_length: maximum number of tokens kept per document. It is passed
            explicitly so truncation happens even when the tokenizer has no
            configured ``model_max_length``.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...) as
        Python lists, padded to the longest document in the batch.
        ``return_tensors`` is not requested because ``Dataset.map`` stores the
        output as Arrow columns regardless.
    """
    return tokenizer_legalbert(
        examples["text"],
        truncation=True,
        padding=True,
        max_length=max_length,
    )

# Pre-tokenize every split; batched=True hands tokenize_function lists of examples.
# NOTE(review): train_model below re-tokenizes raw text itself, so this result is
# not consumed by the training code in this file.
tokenized_datasets = dataset.map(function=tokenize_function, batched=True)

# Define model class
class TextClassifier(nn.Module):
    """Pretrained transformer encoder followed by a single linear classification head.

    The head reads the encoder's pooled output when the backbone provides one.
    Otherwise it reads the hidden state at the first sequence position.
    """

    def __init__(self, model_name, num_labels):
        """Load the encoder ``model_name`` and attach a ``num_labels``-way linear head.

        Args:
            model_name: checkpoint identifier passed to ``AutoModel.from_pretrained``.
            num_labels: number of logits produced per example by ``forward``.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits for a batch of token ids.

        Args:
            input_ids: token ids, presumably shaped (batch, seq_len); forwarded
                unchanged to the encoder.
            attention_mask: mask aligned with ``input_ids``.

        Returns:
            Output of the linear head: one row of ``num_labels`` logits per example.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            # Some encoders have no pooling layer and return no pooler_output;
            # use the representation at position 0 of last_hidden_state instead.
            pooled = outputs.last_hidden_state[:, 0]
        return self.fc(pooled)

# Initialize models.
# The head size is derived from the dataset's label feature instead of being
# hard-coded: any gold label id >= num_labels makes CrossEntropyLoss fail with
# an out-of-range target during training.
_label_feature = dataset["train"].features["label"]
num_labels = getattr(_label_feature, "num_classes", None)
if num_labels is None:
    # Not a ClassLabel feature: assume contiguous integer ids starting at 0.
    num_labels = max(dataset["train"]["label"]) + 1

model_legalbert = TextClassifier("nlpaueb/legal-bert-base-uncased", num_labels=num_labels)
model_roberta = TextClassifier("roberta-base", num_labels=num_labels)

# Initialize metrics storage: train_model appends one dict per (model, epoch).
metrics_data = []

def _make_loader(tokenizer, texts, labels, batch_size, max_length, shuffle=False):
    """Tokenize ``texts`` and wrap ids, attention masks and ``labels`` in a DataLoader.

    Each text is truncated to at most ``max_length`` tokens, and all rows are
    padded to the longest remaining sequence, so every tensor row has the
    same length.
    """
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    )
    tensor_dataset = torch.utils.data.TensorDataset(
        encodings["input_ids"], encodings["attention_mask"], torch.tensor(labels)
    )
    return torch.utils.data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=shuffle)


def _predict(model, loader, device):
    """Run ``model`` over ``loader`` without gradients; return (predictions, gold labels)."""
    model.eval()
    preds, golds = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = model(input_ids, attention_mask)
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            golds.extend(labels.cpu().numpy())
    return preds, golds


def train_model(model, tokenizer, dataset, model_name, epochs=3, batch_size=8, lr=2e-5, max_length=512):
    """Fine-tune ``model`` on ``dataset["train"]`` and record validation metrics per epoch.

    10% of ``dataset["train"]`` (``random_state=42``) is held out for
    validation; no other split of ``dataset`` is read. After every epoch the
    held-out part is scored. A row with accuracy and weighted
    precision/recall/F1 is appended to the module-level ``metrics_data`` list
    and printed together with the mean training loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns per-class logits; it is moved to CUDA when available.
        tokenizer: tokenizer matching ``model``'s vocabulary.
        dataset: mapping whose ``"train"`` entry has ``"text"`` and ``"label"`` columns.
        model_name: value stored in the ``"Model"`` field of each metrics row.
        epochs: number of passes over the training portion.
        batch_size: examples per optimization and evaluation step.
        lr: AdamW learning rate.
        max_length: maximum tokens kept per document. It is passed explicitly
            so truncation happens even if the tokenizer has no configured
            ``model_max_length``.

    Returns:
        The same ``model`` object, trained in place.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_texts, val_texts, train_labels, val_labels = train_test_split(
        dataset["train"]["text"], dataset["train"]["label"], test_size=0.1, random_state=42)

    train_loader = _make_loader(tokenizer, train_texts, train_labels, batch_size, max_length, shuffle=True)
    val_loader = _make_loader(tokenizer, val_texts, val_labels, batch_size, max_length)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            optimizer.zero_grad()
            loss = criterion(model(input_ids, attention_mask), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluation on the held-out 10%
        all_preds, all_labels = _predict(model, val_loader, device)

        accuracy = accuracy_score(all_labels, all_preds)
        # zero_division=0 scores classes that are never predicted as 0, which
        # gives the same numbers as the default but without an
        # UndefinedMetricWarning on every epoch.
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='weighted', zero_division=0)

        metrics_data.append({"Model": model_name, "Epoch": epoch+1, "Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1})

        print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader)}, Accuracy = {accuracy:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}, F1 = {f1:.4f}")

    return model

# Fine-tune each backbone in turn; every train_model call appends its
# per-epoch rows to metrics_data.
trained_legalbert = train_model(
    model=model_legalbert,
    tokenizer=tokenizer_legalbert,
    dataset=dataset,
    model_name="legal-bert-base-uncased",
)
trained_roberta = train_model(
    model=model_roberta,
    tokenizer=tokenizer_roberta,
    dataset=dataset,
    model_name="roberta-base",
)

# Tabulate the collected rows (one per model and epoch) and show them.
metrics_df = pd.DataFrame(data=metrics_data)
print(metrics_df)

# Accuracy per epoch, one line per model, on its own figure.
plt.figure(figsize=(10, 5))
sns.lineplot(data=metrics_df, x="Epoch", y="Accuracy", hue="Model", marker="o")
plt.title("Accuracy per Epoch")
plt.show()

# Precision, recall and F1 side by side, drawn in that order on three panels.
_panel_specs = [
    ("Precision", "Precision per Epoch"),
    ("Recall", "Recall per Epoch"),
    ("F1", "F1 Score per Epoch"),
]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for _panel_ax, (_metric_column, _panel_title) in zip(axes, _panel_specs):
    sns.lineplot(data=metrics_df, x="Epoch", y=_metric_column, hue="Model", marker="o", ax=_panel_ax)
    _panel_ax.set_title(_panel_title)

plt.show()
