In [1]:
import os
import torch
import json
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torch.nn.functional as F

from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from sklearn.metrics import confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
#!git clone https://huggingface.co/madatnlp/km-bert

## Tokenizer

In [20]:
km_bert_dir = "km-bert"
kmbert_tokenizer = AutoTokenizer.from_pretrained(km_bert_dir, do_lower_case=False)

max_seq_len = 256
stride = 128
input_dir = "Data/4Class_Dataset"
output_dir = "Data/4Class_Dataset_Tokenized"

os.makedirs(output_dir, exist_ok=True)

def tokenize_with_sliding_window(text: str):
    tokens = kmbert_tokenizer(text, return_tensors="pt", padding='longest', truncation=True, max_length=512)
    input_ids = tokens['input_ids'].squeeze(0)
    attention_mask = tokens['attention_mask'].squeeze(0)

    tokenized_windows = []
    num_windows = (len(input_ids) - max_seq_len + stride) // stride + 1
    for i in range(num_windows):
        start = i * stride
        end = start + max_seq_len
        input_ids_window = input_ids[start:end]
        attention_mask_window = attention_mask[start:end]

        if len(input_ids_window) < max_seq_len:
            padding_length = max_seq_len - len(input_ids_window)
            input_ids_window = torch.cat([input_ids_window, torch.zeros(padding_length, dtype=torch.long)])
            attention_mask_window = torch.cat([attention_mask_window, torch.zeros(padding_length, dtype=torch.long)])

        tokenized_windows.append({'input_ids': input_ids_window, 'attention_mask': attention_mask_window})

    return tokenized_windows

for root, _, files in os.walk(input_dir):
    for filename in files:
        if filename.endswith(".json"):
            filepath = os.path.join(root, filename)

            with open(filepath, 'r', encoding='utf-8') as file:
                data = json.load(file)

            if "modifiedquery" in data:
                text = data["modifiedquery"]

                tokenized_windows = tokenize_with_sliding_window(text)

                relative_path = os.path.relpath(root, input_dir)
                output_subdir = os.path.join(output_dir, relative_path)
                os.makedirs(output_subdir, exist_ok=True)

                for i, window in enumerate(tokenized_windows):
                    output_path = os.path.join(output_subdir, f"{os.path.splitext(filename)[0]}_window_{i}.pt")
                    torch.save(window, output_path)

print("All files have been tokenized and saved with sliding windows.")


All files have been tokenized and saved with sliding windows.


In [21]:
class PreprocessedDataset(Dataset):
    def __init__(self, root_dir):
        self.data = []
        self.class_to_idx = self._get_class_to_idx(root_dir)
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith('.pt'):
                    class_name = os.path.basename(root) 
                    label = self.class_to_idx[class_name]

                    self.data.append({
                        'file_path': os.path.join(root, file),
                        'label': label
                    })

    def _get_class_to_idx(self, root_dir):
        class_names = sorted(os.listdir(root_dir)) 
        return {class_name: idx for idx, class_name in enumerate(class_names)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_item = self.data[idx]
        data = torch.load(data_item['file_path'], weights_only=False)

        input_ids = data['input_ids']
        attention_mask = data['attention_mask']
        label = torch.tensor(data_item['label'], dtype=torch.long) 

        if torch.isnan(input_ids).any() or torch.isinf(input_ids).any():
            print(f"NaN or Inf detected in input_ids at index {idx}")
        if torch.isnan(attention_mask).any() or torch.isinf(attention_mask).any():
            print(f"NaN or Inf detected in attention_mask at index {idx}")


        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }


In [22]:
def train_step(model, dataloader, loss_fn, optimizer, device):
    model.train()
    train_running_loss = 0.0
    train_correct = 0
    total_train_samples = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss if hasattr(outputs, 'loss') else loss_fn(outputs.logits, labels)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_running_loss += loss.item()

        # Calculate training accuracy
        predictions = torch.argmax(logits, dim=1)
        train_correct += (predictions == labels).sum().item()
        total_train_samples += labels.size(0)

    avg_train_loss = train_running_loss / len(dataloader)
    train_accuracy = train_correct / total_train_samples
    return avg_train_loss, train_accuracy


In [28]:
def test_step(model, dataloader, loss_fn, device, num_classes, epoch, class_names_dir):
    model.eval()
    val_running_loss = 0.0
    val_correct = 0
    total_val_samples = 0
    all_predictions = []
    all_labels = []
    all_probs = []

    class_names = sorted(os.listdir(class_names_dir))

    result_dir = f"Models/4_class_classification/Result_{epoch}"
    os.makedirs(result_dir, exist_ok=True)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs.loss if hasattr(outputs, 'loss') else loss_fn(outputs.logits, labels)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs

            probs = F.softmax(logits, dim=1)

            val_running_loss += loss.item()

            predictions = torch.argmax(logits, dim=1)
            val_correct += (predictions == labels).sum().item()
            total_val_samples += labels.size(0)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_val_loss = val_running_loss / len(dataloader)
    val_accuracy = val_correct / total_val_samples

    conf_matrix = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(8, 6))
    plt.title(f"Confusion Matrix (Epoch {epoch})")
    plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.xticks(ticks=range(num_classes), labels=class_names, rotation=45)
    plt.yticks(ticks=range(num_classes), labels=class_names)
    plt.savefig(os.path.join(result_dir, f"confusion_matrix_epoch_{epoch}.png"))
    plt.close()

    plt.figure(figsize=(10, 8))
    all_labels_onehot = torch.nn.functional.one_hot(torch.tensor(all_labels), num_classes=num_classes).numpy()
    all_probs_array = torch.tensor(all_probs).numpy()

    for i in range(num_classes):
        fpr, tpr, _ = roc_curve(all_labels_onehot[:, i], all_probs_array[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC = {roc_auc:.2f})")

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve (Epoch {epoch})")
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(result_dir, f"roc_curve_epoch_{epoch}.png"))
    plt.close()

    return avg_val_loss, val_accuracy, all_predictions, all_labels


### Hyper-parameters

In [29]:
batch_size = 16
num_epochs = 60
learning_rate = 1e-5
num_classes = 4
root_dir = "Data/4Class_Dataset_Tokenized"
save_dir = "Models/4_class_classification"

In [32]:
kmbert_tokenizer = AutoTokenizer.from_pretrained(km_bert_dir, do_lower_case=False)
vocab_size = kmbert_tokenizer.vocab_size

class_names_dir = "Data/4Class_Dataset_Tokenized/Train"

# Dataset and DataLoader
train_dataset = PreprocessedDataset(os.path.join(root_dir, "Train"))
test_dataset = PreprocessedDataset(os.path.join(root_dir, "Test"))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("DataLoader Process Completed")

model = AutoModelForSequenceClassification.from_pretrained(km_bert_dir, num_labels=num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in tqdm(range(num_epochs)):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Train step
    train_loss, train_accuracy = train_step(model, train_dataloader, loss_fn, optimizer, device)
    print(f"\nTrain Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    torch.cuda.empty_cache()

    # Validation step
    test_loss, test_accuracy, test_predictions, test_labels = test_step(model, test_dataloader, loss_fn, device, num_classes, epoch, class_names_dir)
    f1 = f1_score(test_labels, test_predictions, average='weighted')
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, F1-Score: {f1:.4f}")

    torch.cuda.empty_cache()

    epoch_model_path = os.path.join(save_dir, f"4_Class_Classification_epoch_{epoch + 1}.pth")
    torch.save(model.state_dict(), epoch_model_path)
    print(f"Model for epoch {epoch + 1} saved to {epoch_model_path}")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at km-bert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DataLoader Process Completed


  0%|          | 0/60 [00:00<?, ?it/s]


Epoch 1/60

Train Loss: 0.6372, Train Accuracy: 0.7821
Test Loss: 0.2207, Test Accuracy: 0.9420, F1-Score: 0.9392
Model for epoch 1 saved to Models/4_class_classification\4_Class_Classification_epoch_1.pth

Epoch 2/60


KeyboardInterrupt: 