In [None]:
import datasets
# from datasets import DatasetDict
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC

# Load the "scotus" configuration of the LexGLUE dataset collection.
dataset = datasets.load_dataset(
    path="coastalcph/lex_glue",
    name="scotus",
)

# One tokenizer per transformer checkpoint that is fine-tuned below.
tokenizer_legalbert = AutoTokenizer.from_pretrained(
    "nlpaueb/legal-bert-base-uncased"
)
tokenizer_roberta = AutoTokenizer.from_pretrained(
    "roberta-base"
)

class TextClassifier(nn.Module):
    """Pretrained transformer encoder topped with one linear classification layer.

    Args:
        model_name: Checkpoint identifier handed to ``AutoModel.from_pretrained``.
        num_labels: Number of target classes, i.e. the width of the logits.
    """

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for a batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the pooled encoder
            representation.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            # Some encoders expose no pooling head; fall back to the hidden
            # state of the first token so the classifier still works.
            pooled = outputs.last_hidden_state[:, 0]
        return self.fc(pooled)

# Shared accumulator: every training routine appends its metric rows here.
metrics_data = []

# Two classifiers with identical 13-way heads, differing only in the checkpoint.
model_legalbert = TextClassifier(
    model_name="nlpaueb/legal-bert-base-uncased",
    num_labels=13,
)
model_roberta = TextClassifier(
    model_name="roberta-base",
    num_labels=13,
)

def _encode_split(tokenizer, split):
    """Tokenise one dataset split into a TensorDataset of (input_ids, attention_mask, label)."""
    # list(...) materialises the columns so the tokenizer and torch.tensor
    # receive plain Python lists regardless of the container the dataset
    # library hands back.
    encodings = tokenizer(list(split["text"]), truncation=True, padding=True, return_tensors="pt")
    labels = torch.tensor(list(split["label"]))
    return torch.utils.data.TensorDataset(encodings["input_ids"], encodings["attention_mask"], labels)


def train_model(model, tokenizer, dataset, model_name, epochs=10, batch_size=8, lr=2e-5):
    """Fine-tune ``model`` on the train split and score it on validation after each epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        tokenizer: Callable tokenizer matching ``model``.
        dataset: Mapping with "train" and "validation" splits, each exposing
            "text" and "label" columns.
        model_name: Label used in the log line and in the recorded metric rows.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        lr: Learning rate for AdamW.

    Returns:
        The list of per-epoch metric dicts produced by this call. Each dict is
        also appended to the module-level ``metrics_data`` list.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_loader = torch.utils.data.DataLoader(
        _encode_split(tokenizer, dataset["train"]), batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        _encode_split(tokenizer, dataset["validation"]), batch_size=batch_size
    )

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        # Sum (not mean) of the per-batch losses, as reported in the log line.
        total_loss = 0.0
        for input_ids, attention_mask, labels in train_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(input_ids, attention_mask), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for input_ids, attention_mask, labels in val_loader:
                logits = model(input_ids.to(device), attention_mask.to(device))
                all_preds.extend(logits.argmax(dim=1).tolist())
                # Labels are only needed for scoring, so they stay on the CPU.
                all_labels.extend(labels.tolist())

        accuracy = accuracy_score(all_labels, all_preds)
        # zero_division=0 keeps the scores scikit-learn would report anyway but
        # silences the warning raised when a class is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='weighted', zero_division=0
        )
        row = {"Model": model_name, "Epoch": epoch, "Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1}
        metrics_data.append(row)
        history.append(row)
        print(f"[{model_name}] Epoch {epoch}: Loss = {total_loss:.4f}, Accuracy = {accuracy:.4f}")

    return history

# Fine-tune each transformer with the default hyper-parameters of train_model.
train_model(
    model=model_legalbert,
    tokenizer=tokenizer_legalbert,
    dataset=dataset,
    model_name="legal-bert-base-uncased",
)
train_model(
    model=model_roberta,
    tokenizer=tokenizer_roberta,
    dataset=dataset,
    model_name="roberta-base",
)

def train_tfidf_svm(dataset, max_features=5000, epoch=10):
    """Fit a TF-IDF + linear SVM baseline and score it on the validation split.

    Args:
        dataset: Mapping with "train" and "validation" splits, each exposing
            "text" and "label" columns.
        max_features: Vocabulary size cap passed to ``TfidfVectorizer``.
        epoch: Value stored in the "Epoch" field of the recorded metric row.
            The baseline has no epochs; this only positions its point on an
            epoch axis next to the neural models.

    Returns:
        Tuple ``(svm_model, vectorizer)`` with the fitted classifier and the
        fitted vectorizer. One metric row is also appended to the module-level
        ``metrics_data`` list.
    """
    # Read each column once instead of re-indexing the dataset per use.
    train_texts = list(dataset["train"]["text"])
    train_labels = list(dataset["train"]["label"])
    val_texts = list(dataset["validation"]["text"])
    val_labels = list(dataset["validation"]["label"])

    vectorizer = TfidfVectorizer(max_features=max_features)
    X_train_tfidf = vectorizer.fit_transform(train_texts)
    X_val_tfidf = vectorizer.transform(val_texts)

    # probability=True is kept so the returned model still exposes
    # predict_proba to callers; note that it makes fitting noticeably slower.
    svm_model = SVC(kernel="linear", probability=True, random_state=42)
    svm_model.fit(X_train_tfidf, train_labels)
    y_pred = svm_model.predict(X_val_tfidf)

    accuracy = accuracy_score(val_labels, y_pred)
    # zero_division=0 keeps the reported scores but silences the warning
    # raised when a class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        val_labels, y_pred, average='weighted', zero_division=0
    )
    metrics_data.append({"Model": "TFIDF+SVM", "Epoch": epoch, "Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1})
    print(f"[TFIDF+SVM] Accuracy = {accuracy:.4f}")
    return svm_model, vectorizer

# Run the classical baseline; its metric row lands in metrics_data.
train_tfidf_svm(dataset)

# Tabulate everything recorded so far.
metrics_df = pd.DataFrame(metrics_data)
print(metrics_df)

# Validation accuracy per epoch, one line per model.
plt.figure(figsize=(10, 5))
accuracy_axes = sns.lineplot(
    x="Epoch",
    y="Accuracy",
    hue="Model",
    marker="o",
    data=metrics_df,
)
accuracy_axes.set_title("Accuracy per Epoch")
plt.show()

2025-04-08 10:30:31.990010: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[legal-bert-base-uncased] Epoch 1: Loss = 770.8020, Accuracy = 0.7086


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[legal-bert-base-uncased] Epoch 2: Loss = 409.8816, Accuracy = 0.7536


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[legal-bert-base-uncased] Epoch 3: Loss = 265.8586, Accuracy = 0.7757


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[legal-bert-base-uncased] Epoch 4: Loss = 164.1363, Accuracy = 0.7593


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[legal-bert-base-uncased] Epoch 5: Loss = 98.3053, Accuracy = 0.7764


KeyboardInterrupt: 