In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# 1.Dataset Loading  

In [1]:
import json
import os
import glob

DATA_DIR = "/kaggle/input/legal-dataset/all-data"

files = glob.glob(os.path.join(DATA_DIR, "*.json"))
print("Total files:", len(files))
print("Example file:", files[0])

with open(files[0], "r", encoding="utf-8") as f:
    sample = json.load(f)

print("\nKeys in JSON file:")
print(sample.keys())

for k, v in sample.items():
    if isinstance(v, str):
        print(f"\n--- {k} ---")
        print(v[:800])  
    else:
        print(f"\n--- {k} (non-text) ---")
        print(v)


Total files: 12947
Example file: /kaggle/input/legal-dataset/all-data/001-100500_6-1.json

Keys in JSON file:
dict_keys(['case_id', 'article', 'judgment', 'parties', 'all_arguments', 'input_arguments', 'facts_section', 'law_section'])

--- case_id ---
001-100500

--- article ---
6-1

--- judgment ---
violation

--- parties (non-text) ---
['MONTENEGRO', 'GARZICIC']

--- all_arguments (non-text) ---
[{'paragraph_id': '42071bd7-658a-404b-94cc-2b609dbb2adc', 'start': 0, 'end': 204, 'text': "The applicant complained under Article 6 § 1 of the Convention that her right of access to court had been violated by the Supreme Court's refusal to consider her appeal on points of law on its merits.", 'agent': 'Applicant', 'paragraph_num': 22}, {'paragraph_id': '3ba197ad-0614-4087-919d-a637bd805735', 'start': 0, 'end': 31, 'text': 'Article 6 reads as follows:', 'agent': 'Non-Argument', 'paragraph_num': 23}, {'paragraph_id': '780360c0-a360-430d-a153-5c17af58d43b', 'start': 0, 'end': 137, 'text': '“In t

# 2.Extract Labels + Count Class Distribution 

In [2]:
import json
import os
import glob
from collections import Counter

DATA_DIR = "/kaggle/input/legal-dataset/all-data"

files = glob.glob(os.path.join(DATA_DIR, "*.json"))
print("Total files:", len(files))

labels = []

for fpath in files:
    with open(fpath, "r", encoding="utf-8") as f:
        data = json.load(f)
        label = data.get("judgment", "").strip().lower()
        labels.append(label)

print("Unique labels:", set(labels))
print("Counts:", Counter(labels))


Total files: 12947
Unique labels: {'violation', 'no-violation'}
Counts: Counter({'violation': 10646, 'no-violation': 2301})


In [3]:
import json
import os
import glob
from collections import Counter

DATA_DIR = "/kaggle/input/legal-dataset/all-data"

files = glob.glob(os.path.join(DATA_DIR, "*.json"))
print("Total files:", len(files))

labels = []

for fpath in files:
    with open(fpath, "r", encoding="utf-8") as f:
        data = json.load(f)
        label = data.get("judgment", "").strip().lower()
        labels.append(label)

print("Unique labels:", set(labels))
print("Counts:", Counter(labels))


Total files: 12947
Unique labels: {'violation', 'no-violation'}
Counts: Counter({'violation': 10646, 'no-violation': 2301})


# 3.Extract TEXT FIELDS (facts + arguments + law) 

In [6]:
def extract_text(data):
    parts = []

    if "facts_section" in data:
        try:
            for item in data["facts_section"]["elements"]:
                if isinstance(item, dict) and "content" in item:
                    parts.append(item["content"])
        except:
            pass

    if "all_arguments" in data:
        for arg in data["all_arguments"]:
            t = arg.get("text", "")
            if t:
                parts.append(t)

    if "law_section" in data:
        try:
            for item in data["law_section"]["elements"]:
                if isinstance(item, dict) and "content" in item:
                    parts.append(item["content"])
        except:
            pass

    text = "\n".join(parts)
    return text


records = []

for fpath in files:
    with open(fpath, "r", encoding="utf-8") as f:
        data = json.load(f)

    label = data.get("judgment", "").strip().lower()
    text = extract_text(data)

    records.append((text, label))

df = pd.DataFrame(records, columns=["text", "label"])

df.head()

Unnamed: 0,text,label
0,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation
1,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation
2,THE CIRCUMSTANCES OF THE CASE\nThe domestic pr...,violation
3,10. The applicants are lawfully and permanent...,violation
4,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation


In [7]:
df.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12947 entries, 0 to 12946
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    12947 non-null  object
 1   label   12947 non-null  object
dtypes: object(2)
memory usage: 202.4+ KB


# 4. Clean: Remove short texts + duplicates

In [8]:
initial_count = len(df)

df = df[df["text"].str.len() > 300]

df = df.drop_duplicates(subset=["text"])

print("Initial size:", initial_count)
print("After cleaning:", len(df))
print("Removed:", initial_count - len(df))

df.head()

Initial size: 12947
After cleaning: 11866
Removed: 1081


Unnamed: 0,text,label
0,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation
1,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation
2,THE CIRCUMSTANCES OF THE CASE\nThe domestic pr...,violation
3,10. The applicants are lawfully and permanent...,violation
4,I. THE CIRCUMSTANCES OF THE CASE\nII. RELEVA...,violation


# 5.Remove Leakage Phrases (Outcome-Revealing) 

In [9]:
import re

def remove_leakage(text):
    leakage_patterns = [
        r"there has accordingly been a violation",
        r"there has been a violation",
        r"no violation of article",
        r"the court finds that.*violation",
        r"the court finds.*no violation",
        r"the court finds that there.*violation",
        r"the court considers.*violation",
        r"the court awards",
        r"just satisfaction",
        r"admissible",
        r"inadmissible",
        r"manifestly ill-founded",
        r"compatible ratione",
        r"incompatible ratione",
        r"article 41",
        r"grants.*compensation",
        r"grants.*costs",
    ]

    cleaned = text.lower()
    for pattern in leakage_patterns:
        cleaned = re.sub(pattern, " ", cleaned)
    return cleaned

df["clean_text"] = df["text"].apply(remove_leakage)

print("Original sample:\n", df['text'].iloc[0][:400])
print("\nCleaned sample:\n", df['clean_text'].iloc[0][:400])


Original sample:
 I.  THE CIRCUMSTANCES OF THE CASE
II.  RELEVANT DOMESTIC LAW
The applicant complained under Article 6 § 1 of the Convention that her right of access to court had been violated by the Supreme Court's refusal to consider her appeal on points of law on its merits.
Article 6 reads as follows:
“In the determination of his/her civil rights and obligations ... everyone is entitled to a fair ... hearing .

Cleaned sample:
 i.  the circumstances of the case
ii.  relevant domestic law
the applicant complained under article 6 § 1 of the convention that her right of access to court had been violated by the supreme court's refusal to consider her appeal on points of law on its merits.
article 6 reads as follows:
“in the determination of his/her civil rights and obligations ... everyone is entitled to a fair ... hearing .


# 6. Balance Dataset

In [10]:
from sklearn.utils import resample

df_violation = df[df["label"] == "violation"]
df_no_violation = df[df["label"] == "no-violation"]

print("Before balancing:")
print("Violation:", len(df_violation))
print("No-violation:", len(df_no_violation))

TARGET_SIZE = 5000

df_violation_bal = df_violation.sample(n=TARGET_SIZE, random_state=42)

df_no_violation_bal = resample(df_no_violation,
                               replace=True,
                               n_samples=TARGET_SIZE,
                               random_state=42)

df_balanced = pd.concat([df_violation_bal, df_no_violation_bal], axis=0)
df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

print("\nAfter balancing:")
print(df_balanced['label'].value_counts())
print("\nBalanced dataset size:", len(df_balanced))


Before balancing:
Violation: 9772
No-violation: 2094

After balancing:
label
no-violation    5000
violation       5000
Name: count, dtype: int64

Balanced dataset size: 10000


# 7.Clean and Normalize Text 

In [11]:
import re

def final_clean(text):
    text = text.lower()
    
    text = re.sub(r"\n+", " ", text)
    
    text = re.sub(r"\s+", " ", text)
    
    text = text.replace("“", "\"").replace("”", "\"").replace("’", "'")
    
    text = text.strip()
    
    return text

df_balanced["clean_text"] = df_balanced["clean_text"].apply(final_clean)

df_balanced.sample(3)[["clean_text", "label"]]


Unnamed: 0,clean_text,label
3757,i. the circumstances of the case ii. relevant ...,no-violation
6942,2. the applicants' personal details are set ou...,no-violation
6312,i. the circumstances of the case the applicant...,violation


# 8.Train / Validation / Test Split (80 / 10 / 10) 

In [12]:
from sklearn.model_selection import train_test_split

X = df_balanced["clean_text"].values
y = df_balanced["label"].values

X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.20, random_state=42, stratify=y
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp
)

print("Train size:", len(X_train))
print("Validation size:", len(X_val))
print("Test size:", len(X_test))

print("\nLabel distribution (Train):")
print(pd.Series(y_train).value_counts())

print("\nLabel distribution (Val):")
print(pd.Series(y_val).value_counts())

print("\nLabel distribution (Test):")
print(pd.Series(y_test).value_counts())


Train size: 8000
Validation size: 1000
Test size: 1000

Label distribution (Train):
no-violation    4000
violation       4000
Name: count, dtype: int64

Label distribution (Val):
violation       500
no-violation    500
Name: count, dtype: int64

Label distribution (Test):
violation       500
no-violation    500
Name: count, dtype: int64


# 9. Load LegalBERT Tokenizer

In [13]:
from transformers import AutoTokenizer

MODEL_NAME = "nlpaueb/legal-bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("Tokenizer loaded.")


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Tokenizer loaded.


# 10.Tokenize the Dataset 

In [14]:
import torch

max_len = 512

def tokenize_texts(text_list):
    return tokenizer(
        list(text_list),
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )

train_encodings = tokenize_texts(X_train)
val_encodings = tokenize_texts(X_val)
test_encodings = tokenize_texts(X_test)

label_map = {"violation": 1, "no-violation": 0}

y_train_ids = torch.tensor([label_map[y] for y in y_train])
y_val_ids = torch.tensor([label_map[y] for y in y_val])
y_test_ids = torch.tensor([label_map[y] for y in y_test])

print("Tokenization complete.")
print("Train input_ids shape:", train_encodings["input_ids"].shape)
print("Validation input_ids shape:", val_encodings["input_ids"].shape)
print("Test input_ids shape:", test_encodings["input_ids"].shape)


Tokenization complete.
Train input_ids shape: torch.Size([8000, 512])
Validation input_ids shape: torch.Size([1000, 512])
Test input_ids shape: torch.Size([1000, 512])


# 11.Create PyTorch Dataset & DataLoader 

In [15]:
from torch.utils.data import Dataset, DataLoader

class LegalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = LegalDataset(train_encodings, y_train_ids)
val_dataset = LegalDataset(val_encodings, y_val_ids)
test_dataset = LegalDataset(test_encodings, y_test_ids)

batch_size = 4  

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print("Datasets and DataLoaders created.")
print("Train batches:", len(train_loader))
print("Val batches:", len(val_loader))
print("Test batches:", len(test_loader))


Datasets and DataLoaders created.
Train batches: 2000
Val batches: 250
Test batches: 250


# 12. Load LegalBERT + Prepare for Training

In [17]:
import torch
from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Model loaded on:", device)

class_counts = torch.tensor([5000, 5000], dtype=torch.float)
class_weights = 1.0 / class_counts
class_weights = class_weights.to(device)

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

epochs = 2
total_steps = len(train_loader) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

print("Optimizer and scheduler ready.")


2025-11-19 06:04:35.848172: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763532276.008377      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763532276.053213      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on: cuda
Optimizer and scheduler ready.


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

# 13.Training Loop

In [18]:
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
from tqdm.auto import tqdm

def evaluate(model, data_loader):
    model.eval()
    preds, true_labels = [], []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask)

            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(true_labels, preds)
    f1 = f1_score(true_labels, preds, average='weighted')
    return acc, f1


for epoch in range(epochs):
    print(f"\n===== Epoch {epoch+1}/{epochs} =====")
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    val_acc, val_f1 = evaluate(model, val_loader)

    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Validation Accuracy: {val_acc:.4f}")
    print(f"Validation F1: {val_f1:.4f}")



===== Epoch 1/2 =====


Training:   0%|          | 0/2000 [00:00<?, ?it/s]

Train Loss: 0.5675
Validation Accuracy: 0.7370
Validation F1: 0.7328

===== Epoch 2/2 =====


Training:   0%|          | 0/2000 [00:00<?, ?it/s]

Train Loss: 0.3639
Validation Accuracy: 0.8140
Validation F1: 0.8140


# 14. Evaluate on Test Set

In [19]:
# Final test evaluation
test_acc, test_f1 = evaluate(model, test_loader)

print("===== FINAL TEST RESULTS =====")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")



===== FINAL TEST RESULTS =====
Test Accuracy: 0.8080
Test F1 Score: 0.8080


# 15. Save Model + Tokenizer

In [20]:
save_path = "/kaggle/working/legalbert_echr_model"

model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print("Model and tokenizer saved to:", save_path)


Model and tokenizer saved to: /kaggle/working/legalbert_echr_model


# 16. Inference Function (Predict Outcome)

In [40]:
import torch
import re

def preprocess_input_text(text):

    text = text.lower()

    leakage_patterns = [
        r"there has accordingly been a violation",
        r"there has been a violation",
        r"no violation of article",
        r"the court finds that.*violation",
        r"the court finds.*no violation",
        r"the court awards",
        r"just satisfaction",
        r"admissible",
        r"inadmissible",
        r"manifestly ill-founded",
        r"article 41",
        r"compatible ratione",
        r"incompatible ratione",
    ]
    for pattern in leakage_patterns:
        text = re.sub(pattern, " ", text)

    text = re.sub(r"\n+", " ", text)
    text = re.sub(r"\s+", " ", text)

    text = text.replace("“", '"').replace("”", '"').replace("’", "'")

    return text.strip()


def predict_case(text, true_label=None):

    clean_text = preprocess_input_text(text)

    enc = tokenizer(
        clean_text,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=512
    )

    enc = {k: v.to(device) for k, v in enc.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**enc)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]

    pred_label = "violation" if probs[1] > probs[0] else "no-violation"
    pred_prob = probs[1] if pred_label == "violation" else probs[0]

    return {
        "actual_label": true_label,
        "predicted_label": pred_label,
        "confidence": float(pred_prob)
    }


In [39]:
results = []

for i in range(20):   
    text = X_test[i]
    true_label = y_test[i]
    
    pred = predict_case(text, true_label)
    results.append(pred)

for idx, r in enumerate(results):
    print(f"\n===== EXAMPLE {idx+1} =====")
    print(f"Actual Label:    {r['actual_label']}")
    print(f"Predicted Label: {r['predicted_label']}")
    print(f"Confidence:      {r['confidence']:.4f}")



===== EXAMPLE 1 =====
Actual Label:    violation
Predicted Label: violation
Confidence:      0.9400

===== EXAMPLE 2 =====
Actual Label:    no-violation
Predicted Label: no-violation
Confidence:      0.9583

===== EXAMPLE 3 =====
Actual Label:    violation
Predicted Label: no-violation
Confidence:      0.6041

===== EXAMPLE 4 =====
Actual Label:    no-violation
Predicted Label: no-violation
Confidence:      0.8115

===== EXAMPLE 5 =====
Actual Label:    violation
Predicted Label: violation
Confidence:      0.9762

===== EXAMPLE 6 =====
Actual Label:    no-violation
Predicted Label: no-violation
Confidence:      0.9336

===== EXAMPLE 7 =====
Actual Label:    violation
Predicted Label: violation
Confidence:      0.9802

===== EXAMPLE 8 =====
Actual Label:    no-violation
Predicted Label: no-violation
Confidence:      0.9533

===== EXAMPLE 9 =====
Actual Label:    no-violation
Predicted Label: no-violation
Confidence:      0.7694

===== EXAMPLE 10 =====
Actual Label:    violation
Predict

In [41]:
import shutil

shutil.make_archive("/kaggle/working/legalbert_echr_model", 'zip', "/kaggle/working/legalbert_echr_model")


'/kaggle/working/legalbert_echr_model.zip'

# 16. Reload Saved Model

In [43]:
from transformers import AutoModelForSequenceClassification

save_path = "/kaggle/working/legalbert_echr_model"

model = AutoModelForSequenceClassification.from_pretrained(save_path)
model.to(device)

print("Model reloaded.")


Model reloaded.


# 17. Setup Optimizer with Lower Learning Rate

In [44]:
from transformers import get_linear_schedule_with_warmup
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

epochs = 2
total_steps = len(train_loader) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)


# 18. Continue Training

In [45]:
from tqdm import tqdm
import torch.nn.functional as F

for epoch in range(1, epochs + 1):
    print(f"\n===== CONTINUING TRAINING EPOCH {epoch}/{epochs} =====")
    
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

    avg_loss = total_loss / len(train_loader)
    print(f"Train Loss: {avg_loss:.4f}")

    model.eval()
    correct = 0
    total = 0
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())

    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.4f}")



===== CONTINUING TRAINING EPOCH 1/2 =====


100%|██████████| 2000/2000 [13:06<00:00,  2.54it/s]


Train Loss: 0.2553
Validation Accuracy: 0.8400

===== CONTINUING TRAINING EPOCH 2/2 =====


100%|██████████| 2000/2000 [13:07<00:00,  2.54it/s]


Train Loss: 0.1406
Validation Accuracy: 0.8640


In [46]:
import os

save_path = "/kaggle/working/legalbert_echr_model_best"

model.save_pretrained(save_path)

tokenizer.save_pretrained(save_path)

print("Model saved to:", save_path)


Model saved to: /kaggle/working/legalbert_echr_model_best


In [47]:
import shutil

zip_path = "/kaggle/working/legalbert_echr_model_best"

shutil.make_archive(zip_path, 'zip', zip_path)

print("Zipped model created at:", zip_path + ".zip")


Zipped model created at: /kaggle/working/legalbert_echr_model_best.zip
