TRAINING BERT

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import BertModel, BertTokenizer
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import gc

# Custom Dataset
class BertDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }



2025-05-24 11:57:58.119659: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748087878.346484      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748087878.406651      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# BERT Model
class BertClassifier(nn.Module):
    def __init__(self, num_classes):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert.gradient_checkpointing_enable()  # Enable gradient checkpointing

        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        pooled_output = outputs[1]  # [CLS] token representation
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [3]:
# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=3, delta=0.001):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, path):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        if val_loss < self.val_loss_min:
            torch.save(model.state_dict(), path)
            self.val_loss_min = val_loss

In [4]:
def train_model(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc='Training'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Clear cache after each batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return total_loss / len(train_loader)

def evaluate_model(model, data_loader, criterion, device, desc='Evaluating'):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Clear cache after each batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    accuracy = correct / total
    return total_loss / len(data_loader), accuracy

In [5]:
## MAIN

# Constants
MAX_LEN = 64
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 1e-5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
EARLY_STOPPING_PATIENCE = 3  # Number of epochs to wait before early stopping
EARLY_STOPPING_DELTA = 0.001  # Minimum change in validation accuracy to be considered as improvement
LR_PATIENCE = 2  # Number of epochs to wait before reducing learning rate
LR_FACTOR = 0.2  # Factor to reduce learning rate by
MIN_LR = 1e-6  # Minimum learning rate
# Enable memory efficient attention
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Load data
train_df = pd.read_csv("/kaggle/input/news-dataset/final_news_train.csv")
test_df = pd.read_csv("/kaggle/input/news-dataset/final_news_test.csv")

# Split train into train and validation
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_df['text'].values,
    train_df['label'].values,
    test_size=0.1,
    random_state=42,
    stratify=train_df['label'].values
)

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create datasets
train_dataset = BertDataset(
    texts=train_texts,
    labels=train_labels,
    tokenizer=tokenizer,
    max_len=MAX_LEN
)

val_dataset = BertDataset(
    texts=val_texts,
    labels=val_labels,
    tokenizer=tokenizer,
    max_len=MAX_LEN
)

test_dataset = BertDataset(
    texts=test_df['text'].values,
    labels=test_df['label'].values,
    tokenizer=tokenizer,
    max_len=MAX_LEN
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=2)

# Initialize model
model = BertClassifier(num_classes=4).to(DEVICE)

# Initialize optimizer and criterion
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Initialize early stopping
early_stopping = EarlyStopping(patience=EARLY_STOPPING_PATIENCE, delta=EARLY_STOPPING_DELTA)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=LR_FACTOR,
    patience=LR_PATIENCE,
    min_lr=MIN_LR,
    verbose=True
)


cuda


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [6]:
# Clear GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Training loop
best_val_accuracy = 0
for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch + 1}/{EPOCHS}')

    # Train
    train_loss = train_model(model, train_loader, optimizer, criterion, DEVICE)
    print(f'Training Loss: {train_loss:.4f}')

    # Validate
    val_loss, val_accuracy = evaluate_model(model, val_loader, criterion, DEVICE, desc='Validating')
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    # Test
    test_loss, test_accuracy = evaluate_model(model, test_loader, criterion, DEVICE, desc='Testing')
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

    # Update learning rate based on validation loss
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Current Learning Rate: {current_lr:.2e}')
    
    # Early stopping check
    early_stopping(val_loss, model, '/kaggle/working/bert_classifier.pth')
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    # Save best model based on validation accuracy
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        # torch.save(model.state_dict(), 'bert_classifier.pth')
        print(f'New best model saved with validation accuracy: {val_accuracy:.4f}')

      # Clear memory after each epoch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


Epoch 1/10


Training: 100%|██████████| 12823/12823 [55:42<00:00,  3.84it/s]


Training Loss: 0.3054


Validating: 100%|██████████| 1425/1425 [01:24<00:00, 16.80it/s]


Validation Loss: 0.2438, Validation Accuracy: 0.9135


Testing: 100%|██████████| 1225/1225 [01:12<00:00, 16.91it/s]


Test Loss: 0.2577, Test Accuracy: 0.9067
Current Learning Rate: 1.00e-05
New best model saved with validation accuracy: 0.9135

Epoch 2/10


Training: 100%|██████████| 12823/12823 [55:51<00:00,  3.83it/s]


Training Loss: 0.1946


Validating: 100%|██████████| 1425/1425 [01:24<00:00, 16.91it/s]


Validation Loss: 0.2414, Validation Accuracy: 0.9173


Testing: 100%|██████████| 1225/1225 [01:12<00:00, 16.89it/s]


Test Loss: 0.2557, Test Accuracy: 0.9107
Current Learning Rate: 1.00e-05
New best model saved with validation accuracy: 0.9173

Epoch 3/10


Training: 100%|██████████| 12823/12823 [55:47<00:00,  3.83it/s]


Training Loss: 0.1310


Validating: 100%|██████████| 1425/1425 [01:24<00:00, 16.81it/s]


Validation Loss: 0.2715, Validation Accuracy: 0.9143


Testing: 100%|██████████| 1225/1225 [01:12<00:00, 16.84it/s]


Test Loss: 0.2855, Test Accuracy: 0.9096
Current Learning Rate: 1.00e-05

Epoch 4/10


Training: 100%|██████████| 12823/12823 [55:51<00:00,  3.83it/s]


Training Loss: 0.0863


Validating: 100%|██████████| 1425/1425 [01:25<00:00, 16.68it/s]


Validation Loss: 0.2956, Validation Accuracy: 0.9122


Testing: 100%|██████████| 1225/1225 [01:13<00:00, 16.63it/s]


Test Loss: 0.3148, Test Accuracy: 0.9075
Current Learning Rate: 1.00e-05

Epoch 5/10


Training: 100%|██████████| 12823/12823 [56:00<00:00,  3.82it/s]


Training Loss: 0.0584


Validating: 100%|██████████| 1425/1425 [01:25<00:00, 16.75it/s]


Validation Loss: 0.3326, Validation Accuracy: 0.9156


Testing: 100%|██████████| 1225/1225 [01:13<00:00, 16.75it/s]

Test Loss: 0.3682, Test Accuracy: 0.9089
Current Learning Rate: 2.00e-06
Early stopping triggered



