BERT Model Implementation for Classification (using sliding window, strides 256)

In [8]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

df = pd.read_csv('Input.csv')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CaseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        inputs = self.tokenizer(text, 
                                truncation=True, 
                                padding='max_length', 
                                max_length=self.max_length, 
                                return_tensors='pt')

        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': torch.tensor(label)}

label_encoder = LabelEncoder()
df['Label'] = label_encoder.fit_transform(df['Label'])

train_texts = df['Text'].tolist()
train_labels = df['Label'].tolist()
dataset = CaseDataset(train_texts, train_labels, tokenizer)




In [2]:
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
import torch

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_encoder.classes_))

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

def custom_collate_fn(batch):
    input_ids = torch.cat([item['input_ids'] for item in batch], dim=0)
    attention_mask = torch.cat([item['attention_mask'] for item in batch], dim=0)
    labels = torch.cat([item['labels'] for item in batch], dim=0)
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=custom_collate_fn,
)

trainer.train()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▌         | 10/192 [02:03<35:57, 11.86s/it]

{'loss': 1.4151, 'grad_norm': 9.126773834228516, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.16}


 10%|█         | 20/192 [03:59<32:26, 11.32s/it]

{'loss': 1.4614, 'grad_norm': 8.845014572143555, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.31}


 16%|█▌        | 30/192 [05:55<31:34, 11.70s/it]

{'loss': 1.3808, 'grad_norm': 13.717350006103516, 'learning_rate': 3e-06, 'epoch': 0.47}


 21%|██        | 40/192 [07:52<29:37, 11.69s/it]

{'loss': 1.3785, 'grad_norm': 6.439785003662109, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.62}


 26%|██▌       | 50/192 [09:48<27:04, 11.44s/it]

{'loss': 1.3526, 'grad_norm': 9.21692943572998, 'learning_rate': 5e-06, 'epoch': 0.78}


 31%|███▏      | 60/192 [11:44<25:28, 11.58s/it]

{'loss': 1.3157, 'grad_norm': 8.14303970336914, 'learning_rate': 6e-06, 'epoch': 0.94}


 36%|███▋      | 70/192 [13:33<23:00, 11.31s/it]

{'loss': 1.2841, 'grad_norm': 13.480112075805664, 'learning_rate': 7.000000000000001e-06, 'epoch': 1.09}


 42%|████▏     | 80/192 [15:28<21:25, 11.48s/it]

{'loss': 1.1889, 'grad_norm': 13.900252342224121, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.25}


 47%|████▋     | 90/192 [17:22<19:19, 11.36s/it]

{'loss': 1.1497, 'grad_norm': 15.24687385559082, 'learning_rate': 9e-06, 'epoch': 1.41}


 52%|█████▏    | 100/192 [19:17<17:37, 11.49s/it]

{'loss': 1.174, 'grad_norm': 11.422685623168945, 'learning_rate': 1e-05, 'epoch': 1.56}


 57%|█████▋    | 110/192 [21:12<15:41, 11.49s/it]

{'loss': 1.1488, 'grad_norm': 12.940317153930664, 'learning_rate': 1.1000000000000001e-05, 'epoch': 1.72}


 62%|██████▎   | 120/192 [23:07<13:44, 11.46s/it]

{'loss': 1.0594, 'grad_norm': 17.908597946166992, 'learning_rate': 1.2e-05, 'epoch': 1.88}


 68%|██████▊   | 130/192 [24:52<10:48, 10.45s/it]

{'loss': 1.0077, 'grad_norm': 10.598592758178711, 'learning_rate': 1.3000000000000001e-05, 'epoch': 2.03}


 73%|███████▎  | 140/192 [26:48<10:01, 11.57s/it]

{'loss': 0.9339, 'grad_norm': 14.211957931518555, 'learning_rate': 1.4000000000000001e-05, 'epoch': 2.19}


 78%|███████▊  | 150/192 [28:39<07:56, 11.36s/it]

{'loss': 0.8812, 'grad_norm': 13.261638641357422, 'learning_rate': 1.5e-05, 'epoch': 2.34}


 83%|████████▎ | 160/192 [30:35<06:05, 11.41s/it]

{'loss': 0.8604, 'grad_norm': 13.637377738952637, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.5}


 89%|████████▊ | 170/192 [32:29<04:09, 11.34s/it]

{'loss': 0.7776, 'grad_norm': 11.430459976196289, 'learning_rate': 1.7000000000000003e-05, 'epoch': 2.66}


 94%|█████████▍| 180/192 [34:22<02:16, 11.34s/it]

{'loss': 0.7316, 'grad_norm': 8.220366477966309, 'learning_rate': 1.8e-05, 'epoch': 2.81}


 99%|█████████▉| 190/192 [36:17<00:22, 11.33s/it]

{'loss': 0.602, 'grad_norm': 4.435051918029785, 'learning_rate': 1.9e-05, 'epoch': 2.97}


100%|██████████| 192/192 [36:35<00:00, 11.44s/it]

{'train_runtime': 2195.5219, 'train_samples_per_second': 0.693, 'train_steps_per_second': 0.087, 'train_loss': 1.1079342539111774, 'epoch': 3.0}





TrainOutput(global_step=192, training_loss=1.1079342539111774, metrics={'train_runtime': 2195.5219, 'train_samples_per_second': 0.693, 'train_steps_per_second': 0.087, 'total_flos': 400199101526016.0, 'train_loss': 1.1079342539111774, 'epoch': 3.0})

In [3]:
model.save_pretrained('./fine-tuned-bert512-sliding256')
tokenizer.save_pretrained('./fine-tuned-bert512-sliding256')
def predict(text):
    inputs = tokenizer(text, 
                       truncation=True, 
                       padding='max_length', 
                       max_length=512, 
                       stride=256, 
                       return_overflowing_tokens=True, 
                       return_tensors='pt')

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        avg_logits = torch.mean(logits, dim=0)
        predicted_class_id = avg_logits.argmax().item()
        return label_encoder.inverse_transform([predicted_class_id])[0]

example_text = '''A property owner files a lawsuit against their neighbor for encroaching on their land and demands that the boundary be restored to its rightful place.'''
predicted_class = predict(example_text)
print(f"Predicted class: {predicted_class}")

Predicted class: Civil Case


In [12]:
from transformers import BertTokenizer, BertForSequenceClassification
from collections import Counter
import joblib
import torch
tokenizer = BertTokenizer.from_pretrained('./fine-tuned-bert512-sliding256')
model = BertForSequenceClassification.from_pretrained('./fine-tuned-bert512-sliding256')
label_encoder = joblib.load('label_encoder.joblib')

In [13]:
def predict(text):
    inputs = tokenizer(text, 
                       truncation=True, 
                       padding='max_length', 
                       max_length=512, 
                       stride=256, 
                       return_overflowing_tokens=True, 
                       return_tensors='pt')

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        avg_logits = torch.mean(logits, dim=0)
        predicted_class_id = avg_logits.argmax().item()
        return label_encoder.inverse_transform([predicted_class_id])[0]

example_text = '''A property owner files a lawsuit against their neighbor for encroaching on their land and demands that the boundary be restored to its rightful place.'''
predicted_class = predict(example_text)
print(f"Predicted class: {predicted_class}")

Predicted class: Civil Case
