# In [None]:
import datasets
# from datasets import DatasetDict
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC

# Load the dataset by its Hugging Face Hub identifier. The rest of this script
# reads the "train" and "validation" splits and their "text" / "label" columns.
dataset = datasets.load_dataset("victorambrose11/normalized_scotus")

# One tokenizer per transformer checkpoint; each is paired below with the
# TextClassifier built from the same checkpoint name.
tokenizer_legalbert = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
tokenizer_roberta = AutoTokenizer.from_pretrained("roberta-base")

class TextClassifier(nn.Module):
    """Pretrained transformer encoder topped with one linear classification layer.

    Args:
        model_name: Checkpoint identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits produced per input sequence.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # The head maps the encoder's hidden size to one logit per label.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            The output of the linear head applied to the encoder's pooled
            representation (one row of ``num_labels`` logits per sequence).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            # Not every AutoModel backbone ships a pooling head: the attribute
            # can be missing or None. Fall back to the hidden state of the
            # first token so the classifier still works with such encoders.
            pooled = outputs.last_hidden_state[:, 0]
        return self.fc(pooled)

# One classifier per checkpoint, each with a 13-way output layer.
# NOTE(review): num_labels=13 is hard-coded; it presumably matches the number
# of distinct values in the dataset's "label" column -- confirm.
model_legalbert = TextClassifier("nlpaueb/legal-bert-base-uncased", num_labels=13)
model_roberta = TextClassifier("roberta-base", num_labels=13)
# Shared accumulator of per-model metric rows (one dict per row); the
# training functions append to it and it is turned into a DataFrame at the end.
metrics_data = []

def tokenize_with_chunks(texts, tokenizer, max_length=512, stride=256):
    """Tokenize each text separately, keeping the overflowing chunks.

    Args:
        texts: Iterable of documents; the tokenizer is called once per item.
        tokenizer: Callable invoked as ``tokenizer(text, **options)`` whose
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Forwarded as ``max_length``; padding is requested up to
            this length as well.
        stride: Forwarded as ``stride`` (for Hugging Face tokenizers this is
            the overlap between consecutive overflowing chunks).

    Returns:
        A pair ``(input_ids_list, attention_mask_list)`` with one entry per
        input text, in input order.
    """
    # Identical options for every document; PyTorch tensors are requested.
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": max_length,
        "stride": stride,
        "return_overflowing_tokens": True,
        "return_tensors": "pt",
    }

    encoded_documents = [tokenizer(document, **tokenizer_options) for document in texts]
    all_input_ids = [encoded["input_ids"] for encoded in encoded_documents]
    all_attention_masks = [encoded["attention_mask"] for encoded in encoded_documents]
    return all_input_ids, all_attention_masks

def _average_chunk_logits(model, chunked_inputs, chunked_masks, device):
    """Run the model on each chunk of one document and average the logits.

    Chunks are fed one at a time (each with a leading batch dimension of 1)
    and the resulting logits are averaged over the chunk dimension.
    """
    per_chunk_logits = [
        model(chunk_ids.unsqueeze(0).to(device), chunk_mask.unsqueeze(0).to(device))
        for chunk_ids, chunk_mask in zip(chunked_inputs, chunked_masks)
    ]
    return torch.stack(per_chunk_logits).mean(dim=0)


def train_model(model, tokenizer, dataset, model_name, epochs=10, batch_size=1, lr=2e-5):
    """Fine-tune ``model`` on chunked documents and log validation metrics.

    Every document is tokenized into chunks by ``tokenize_with_chunks``; the
    document-level prediction is the average of the per-chunk logits. After
    each epoch the validation split is scored and one row is appended to the
    module-level ``metrics_data`` list; a summary line is also printed.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits.
        tokenizer: Tokenizer forwarded to ``tokenize_with_chunks``.
        dataset: Mapping with ``"train"`` and ``"validation"`` splits, each
            exposing ``"text"`` and ``"label"`` columns.
        model_name: Label used in the metrics rows and the printed summary.
        epochs: Number of passes over the training split.
        batch_size: Number of documents whose gradients are accumulated
            before each optimizer step. The default of 1 updates the weights
            after every document.
        lr: Learning rate for AdamW.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Tokenize both splits once, up front; each entry holds all chunks of one document.
    train_input_ids, train_attention_masks = tokenize_with_chunks(dataset["train"]["text"], tokenizer)
    val_input_ids, val_attention_masks = tokenize_with_chunks(dataset["validation"]["text"], tokenizer)

    train_labels = torch.tensor(dataset["train"]["label"])
    val_labels = torch.tensor(dataset["validation"]["label"])

    train_data = list(zip(train_input_ids, train_attention_masks, train_labels))
    val_data = list(zip(val_input_ids, val_attention_masks, val_labels))

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for start in range(0, len(train_data), batch_size):
            batch = train_data[start:start + batch_size]
            optimizer.zero_grad()

            for chunked_inputs, chunked_masks, label in batch:
                avg_logits = _average_chunk_logits(model, chunked_inputs, chunked_masks, device)
                loss = criterion(avg_logits, label.unsqueeze(0).to(device))
                # Scale so the accumulated gradient is the mean over the batch;
                # calling backward per document frees each graph right away.
                (loss / len(batch)).backward()
                # The reported loss stays the unscaled sum over all documents.
                total_loss += loss.item()

            optimizer.step()

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for chunked_inputs, chunked_masks, label in val_data:
                avg_logits = _average_chunk_logits(model, chunked_inputs, chunked_masks, device)
                pred = torch.argmax(avg_logits, dim=1)
                all_preds.append(pred.item())
                all_labels.append(label.item())

        accuracy = accuracy_score(all_labels, all_preds)
        # zero_division=0 keeps the same values as the default but without a
        # warning each epoch for classes that were never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='weighted', zero_division=0
        )
        metrics_data.append({"Model": model_name, "Epoch": epoch+1, "Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1": f1})
        print(f"[{model_name}] Epoch {epoch+1}: Loss = {total_loss:.4f}, Accuracy = {accuracy:.4f}")

# Fine-tune both transformer classifiers with the default hyper-parameters;
# each call appends one metrics row per epoch to metrics_data.
train_model(model_legalbert, tokenizer_legalbert, dataset, "legal-bert-base-uncased")
train_model(model_roberta, tokenizer_roberta, dataset, "roberta-base")

def train_tfidf_svm(dataset):
    """Fit a TF-IDF + linear SVM baseline and score it on the validation split.

    The vectorizer is fitted on the training texts only and then applied to
    the validation texts. One metrics row (tagged as epoch 10) is appended to
    the module-level ``metrics_data`` list and the accuracy is printed.

    Args:
        dataset: Mapping with ``"train"`` and ``"validation"`` splits, each
            exposing ``"text"`` and ``"label"`` columns.

    Returns:
        A ``(svm_model, vectorizer)`` tuple holding the fitted estimators.
    """
    train_texts = dataset["train"]["text"]
    train_targets = dataset["train"]["label"]
    val_texts = dataset["validation"]["text"]
    val_targets = dataset["validation"]["label"]

    # Vocabulary is capped at the 5000 highest-ranked terms.
    tfidf = TfidfVectorizer(max_features=5000)
    train_features = tfidf.fit_transform(train_texts)
    val_features = tfidf.transform(val_texts)

    classifier = SVC(kernel="linear", probability=True, random_state=42)
    classifier.fit(train_features, train_targets)
    val_predictions = classifier.predict(val_features)

    accuracy = accuracy_score(val_targets, val_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        val_targets, val_predictions, average='weighted'
    )

    baseline_row = {
        "Model": "TFIDF+SVM",
        "Epoch": 10,
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1": f1,
    }
    metrics_data.append(baseline_row)
    print(f"[TFIDF+SVM] Accuracy = {accuracy:.4f}")
    return classifier, tfidf

# Run the classical baseline; its fitted estimators are not needed afterwards,
# only the metrics row it appends to metrics_data.
train_tfidf_svm(dataset)
# Collect every logged row (all models, all epochs) into one table.
metrics_df = pd.DataFrame(metrics_data)
print(metrics_df)
# One accuracy line per model across epochs. The TFIDF+SVM baseline logs a
# single row at Epoch 10, so it appears as a lone marker.
plt.figure(figsize=(10, 5))
sns.lineplot(data=metrics_df, x="Epoch", y="Accuracy", hue="Model", marker="o")
plt.title("Accuracy per Epoch")
plt.show()