In [4]:
!pip install --upgrade datasets




In [5]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import AutoModel
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

In [None]:

# Configuration — every tunable value used by the notebook lives here.

# --- Model ---
MODEL_NAME = 'bert-base-uncased'
MAX_LENGTH = 512              # maximum tokens per example
DROPOUT_RATE = 0.3            # dropout applied before the classification head
FREEZE_LAYERS = True          # master switch for the two freeze settings below
FREEZE_EMBEDDINGS = True      # freeze the encoder's embedding layer
FREEZE_EARLY_LAYERS = 6       # freeze this many of the lowest encoder layers

# --- Optimisation ---
BATCH_SIZE = 32
LEARNING_RATE = 2e-5
EPOCHS = 3
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1            # fraction of total training steps used for linear LR warmup
MAX_GRAD_NORM = 1.0           # gradient-norm clipping threshold

# --- Data ---
VALIDATION_SPLIT = 0.2        # fraction of the training split held out for validation
DATASET_NAME = "fancyzhx/ag_news"
TEXT_COLUMN = 'text'
LABEL_COLUMN = 'label'

# --- Runtime / reproducibility ---
RANDOM_SEED = 42
NUM_WORKERS = 2

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Colab-style absolute path; adjust when running elsewhere.
CHECKPOINT_PATH = '/content/sample_data/checkpoints/BERT_Text_Classifier.pth'

# Seed numpy and torch (CPU + CUDA) so that the classifier-head initialisation,
# dropout masks and DataLoader shuffling are repeatable. The train/val split is
# seeded separately through train_test_split(random_state=RANDOM_SEED).
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)   # silently ignored when CUDA is unavailable



In [26]:
# Custom classification dataset: tokenizes one raw text per item, on demand.
class ClfDataset(Dataset):
  """Map-style dataset pairing raw texts with integer class labels.

  Tokenization happens lazily in ``__getitem__``, so nothing is pre-encoded.
  Every item is padded/truncated to exactly ``max_length`` tokens
  (``padding='max_length'``), which makes all items the same length and lets
  the default DataLoader collation stack them.
  NOTE(review): if typical texts are much shorter than ``max_length``, dynamic
  padding in a collate_fn would avoid a lot of wasted compute on pad tokens.

  Args:
    texts: indexable sequence of texts (each is passed through ``str()``).
    labels: indexable sequence of integer labels, aligned with ``texts``.
    tokenizer: callable tokenizer invoked with ``return_tensors='pt'``.
    max_length: fixed sequence length in tokens.
  """

  def __init__(self, texts, labels, tokenizer, max_length=512):
    self.texts, self.labels = texts, labels
    self.tokenizer = tokenizer
    self.max_length = max_length

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    raw_text = str(self.texts[idx])
    target = self.labels[idx]

    encoded = self.tokenizer(
        raw_text,
        max_length=self.max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )

    # The tokenizer returns (1, max_length) tensors for a single text;
    # flatten them to 1-D so the DataLoader can stack them into a batch.
    sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
    sample['labels'] = torch.tensor(target, dtype=torch.long)
    return sample

In [27]:
# Tokenizer setup: load the tokenizer for MODEL_NAME, the same checkpoint the
# classifier's encoder is loaded from, so token ids match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [28]:


def load_ag_news_dataset():
  """Load ``DATASET_NAME`` from the Hugging Face Hub as pandas DataFrames.

  The dataset id and label column come from the config cell
  (``DATASET_NAME``, ``LABEL_COLUMN``), so editing the config actually changes
  what is loaded. Expects 'train' and 'test' splits and a label feature that
  exposes a ``.names`` list of class names.

  Returns:
    ``(train_df, test_df, id2label, label2id)`` on success. On any exception
    the error is printed and ``(None, None, None, None)`` is returned instead
    of raising, so callers must check for ``None`` before use.
  """
  try:
    dataset = load_dataset(DATASET_NAME)
    train_ds = dataset['train']
    test_ds = dataset['test']

    # Use the datasets library's own DataFrame export rather than feeding the
    # Dataset object to the DataFrame constructor row by row.
    train_df = train_ds.to_pandas()
    test_df = test_ds.to_pandas()

    # id <-> name maps built from the label feature's class names.
    label_names = train_ds.features[LABEL_COLUMN].names
    id2label = dict(enumerate(label_names))
    label2id = {name: i for i, name in id2label.items()}

    print('Dataset loaded successfully')
    print(f'Training samples length: {len(train_df)}')
    print(f'Test samples length: {len(test_df)}')
    print(f'Dataset columns: {train_df.columns.tolist()}')
    print(f'Label mappings: {id2label}')
    print(f'First samples:\n{train_df.head(2)}')

    return train_df, test_df, id2label, label2id

  except Exception as e:
    # Best-effort: report and return Nones; the calling cell decides whether to stop.
    print(f"Error loading dataset: {e}")
    return None, None, None, None


In [29]:
train_df, test_df, id2label, label2id = load_ag_news_dataset()
# load_ag_news_dataset() returns all-None on failure instead of raising; stop
# here with a clear message rather than failing later with a NoneType error.
if train_df is None:
  raise RuntimeError("Dataset failed to load — see the error printed above.")

Dataset loaded successfully
Training samples length: 120000
Test samples length: 7600
Dataset columns: ['text', 'label']
Label mappings: {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
First samples:
                                                text  label
0  Wall St. Bears Claw Back Into the Black (Reute...      2
1  Carlyle Looks Toward Commercial Aerospace (Reu...      2


In [30]:
# Prepare train/validation DataLoaders from the training split.
def prepare_dataloader():
  """Split ``train_df`` into training and validation DataLoaders.

  ``VALIDATION_SPLIT`` of the rows (seeded by ``RANDOM_SEED``) are held out
  for validation; the dataset's separate test split is not touched here.
  Sequences are tokenized to ``MAX_LENGTH`` tokens.

  Returns:
    ``(train_dataloader, val_dataloader)``; only the training loader shuffles.
  """
  texts = train_df[TEXT_COLUMN].tolist()
  labels = train_df[LABEL_COLUMN].tolist()

  train_texts, val_texts, train_labels, val_labels = train_test_split(
      texts, labels, test_size=VALIDATION_SPLIT, random_state=RANDOM_SEED
  )

  train_dataset = ClfDataset(train_texts, train_labels, tokenizer, max_length=MAX_LENGTH)
  # The validation set must come from the held-out split, not the training
  # split, otherwise validation metrics only measure fit to training data.
  val_dataset = ClfDataset(val_texts, val_labels, tokenizer, max_length=MAX_LENGTH)

  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
  val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

  return train_dataloader, val_dataloader

In [31]:
# Build the train/validation DataLoaders from the config-cell settings.
train_dataloader, val_dataloader = prepare_dataloader()

In [32]:
# Sanity-check the pipeline: fetch a single training batch and print the
# tensor shape of each field it carries.
batch = next(iter(train_dataloader), None)
if batch is not None:
  for field in ('input_ids', 'attention_mask', 'labels'):
    print(f"Batch {field} shape: {batch[field].shape}")


Batch input_ids shape: torch.Size([32, 512])
Batch attention_mask shape: torch.Size([32, 512])
Batch labels shape: torch.Size([32])


In [33]:
class BERTClassifier(nn.Module):
  """Pretrained transformer encoder + dropout + linear classification head.

  Args:
    model_name: Hugging Face model id; falls back to ``MODEL_NAME`` when not given.
    num_classes: number of output logits. Required.
    dropout_rate: dropout probability before the head; falls back to
      ``DROPOUT_RATE`` only when ``None``, so an explicit ``0.0`` is honoured.

  Raises:
    ValueError: if ``num_classes`` is not provided.
  """

  def __init__(self, model_name=None, num_classes=None, dropout_rate=None):
    super().__init__()
    if num_classes is None:
      raise ValueError("num_classes must be provided")
    model_name = model_name or MODEL_NAME
    # `is None` rather than `or`: `0.0 or DROPOUT_RATE` would discard a
    # deliberate request for no dropout.
    dropout_rate = DROPOUT_RATE if dropout_rate is None else dropout_rate

    self.bert = AutoModel.from_pretrained(model_name)
    self.dropout = nn.Dropout(dropout_rate)
    self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    if FREEZE_LAYERS:
      self.freeze_layers()

  def freeze_layers(self):
    """Disable gradients for the embeddings and/or the lowest encoder layers.

    Controlled by ``FREEZE_EMBEDDINGS`` and ``FREEZE_EARLY_LAYERS``. Assumes a
    BERT-style encoder exposing ``embeddings`` and ``encoder.layer``.
    """
    if FREEZE_EMBEDDINGS:
      for param in self.bert.embeddings.parameters():
        param.requires_grad = False

    if FREEZE_EARLY_LAYERS > 0:
      for layer in self.bert.encoder.layer[:FREEZE_EARLY_LAYERS]:
        for param in layer.parameters():
          param.requires_grad = False

  def forward(self, input_ids, attention_mask):
    """Return unnormalised class logits, one row per input sequence."""
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    pooled_output = getattr(outputs, 'pooler_output', None)
    if pooled_output is None:
      # Encoders without a pooling layer: use the first token's hidden state.
      pooled_output = outputs.last_hidden_state[:, 0]
    pooled_output = self.dropout(pooled_output)
    return self.classifier(pooled_output)

  def get_model_info(self):
    """Return total / trainable / frozen parameter counts as a dict."""
    total_params = sum(p.numel() for p in self.parameters())
    trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

    return {
      'total_params': total_params,
      'trainable_params': trainable_params,
      'frozen_params': total_params - trainable_params
    }


In [34]:
# Build the classifier with one output per label, then move it to DEVICE.
model = BERTClassifier(num_classes=len(label2id))
model = model.to(DEVICE)

In [35]:
info = model.get_model_info()
print(f"Model created with {len(label2id)} classes")
# Report each parameter count with thousands separators.
for count_key, heading in (('total_params', 'Total'),
                           ('trainable_params', 'Trainable'),
                           ('frozen_params', 'Frozen')):
  print(f"{heading} parameters: {info[count_key]:,}")


Model created with 4 classes
Total parameters: 109,485,316
Trainable parameters: 43,120,900
Frozen parameters: 66,364,416


In [36]:

# Initialize the optimizer, learning-rate scheduler and loss function.
# NOTE(review): all model parameters are handed to AdamW, including frozen
# ones; presumably harmless because they never receive gradients — confirm.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# train_epoch steps the scheduler once per batch, so the schedule spans
# batches-per-epoch * EPOCHS steps, with the first WARMUP_RATIO of them as warmup.
total_steps = EPOCHS * len(train_dataloader)
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

loss_fn = nn.CrossEntropyLoss()

In [37]:
# Train for one epoch.
def train_epoch(train_dataloader):
  """Run one optimisation pass over ``train_dataloader``.

  Uses the notebook-level ``model``, ``optimizer``, ``scheduler`` and
  ``loss_fn``. Gradients are clipped to ``MAX_GRAD_NORM`` and the scheduler is
  stepped once per batch.

  Returns:
    Mean per-batch training loss, or NaN if the loader yielded no batches.
  """
  model.train()
  total_loss = 0.0
  num_batches = 0

  train_progress = tqdm(train_dataloader, desc="Training")
  for batch in train_progress:
    input_ids = batch['input_ids'].to(DEVICE)
    attention_mask = batch['attention_mask'].to(DEVICE)
    labels = batch['labels'].to(DEVICE)

    optimizer.zero_grad()

    logits = model(input_ids, attention_mask)
    loss = loss_fn(logits, labels)
    loss.backward()

    # Clip before stepping so the update uses the clipped gradients.
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

    optimizer.step()
    scheduler.step()

    # Read the scalar once and reuse it, avoiding a second device->host copy.
    batch_loss = loss.item()
    total_loss += batch_loss
    num_batches += 1
    train_progress.set_postfix({'loss': batch_loss})

  # Guard against an empty loader instead of dividing by zero.
  avg_loss = total_loss / num_batches if num_batches else float('nan')
  print(f"Training Loss: {avg_loss}")
  return avg_loss


In [38]:
# Evaluation on a held-out loader.
def evaluate(val_dataloader):
  """Compute mean loss and accuracy of the notebook-level ``model``.

  Runs with ``model.eval()`` and without gradient tracking. Accuracy is the
  fraction of samples whose argmax logit equals the label.

  Returns:
    ``(avg_loss, accuracy)``; both are NaN if the loader yielded no batches.
  """
  model.eval()
  total_loss = 0.0
  num_batches = 0
  all_predictions = []
  all_labels = []
  val_progress = tqdm(val_dataloader, desc="Validation")

  with torch.no_grad():
    for batch in val_progress:
      input_ids = batch['input_ids'].to(DEVICE)
      attention_mask = batch['attention_mask'].to(DEVICE)
      labels = batch['labels'].to(DEVICE)

      logits = model(input_ids, attention_mask)
      loss = loss_fn(logits, labels)

      # Read the scalar once and reuse it, avoiding a second device->host copy.
      batch_loss = loss.item()
      total_loss += batch_loss
      num_batches += 1
      val_progress.set_postfix({'loss': batch_loss})

      prediction = torch.argmax(logits, dim=1)
      all_predictions.extend(prediction.cpu().numpy())
      all_labels.extend(labels.cpu().numpy())

  if num_batches:
    avg_loss = total_loss / num_batches
    accuracy = np.mean(np.array(all_predictions) == np.array(all_labels))
  else:
    # Empty loader: report NaN rather than raising ZeroDivisionError.
    avg_loss = accuracy = float('nan')
  print(f"Validation Loss: {avg_loss}")
  print(f"Validation Accuracy: {accuracy}")

  return avg_loss, accuracy



In [39]:
def load_checkpoint(model, optimizer, filename):
  """Restore model and optimizer state from ``filename`` if the file exists.

  Tensors are mapped onto ``DEVICE`` while loading.

  Returns:
    The epoch index to resume from (the saved epoch + 1), or 0 when no
    checkpoint file is present.
  """
  if not os.path.exists(filename):
    print("No checkpoint found — starting from scratch.")
    return 0

  state = torch.load(filename, map_location=DEVICE)
  model.load_state_dict(state['model_state_dict'])
  optimizer.load_state_dict(state['optimizer_state_dict'])
  resume_epoch = state['epoch'] + 1
  print(f"Checkpoint loaded — resuming from epoch {resume_epoch}")
  return resume_epoch

# NOTE(review): the LR scheduler's state is not stored in the checkpoint, so a
# resumed run restarts the warmup/decay schedule from step 0.
start_epoch = load_checkpoint(model, optimizer, filename=CHECKPOINT_PATH)

No checkpoint found — starting from scratch.


In [40]:

def save_checkpoint(model, optimizer, epoch, loss, filename):
  """Write a resumable training checkpoint to ``filename``.

  Args:
    model: module whose ``state_dict`` is saved.
    optimizer: optimizer whose ``state_dict`` is saved.
    epoch: index of the epoch just completed (resuming starts at ``epoch + 1``).
    loss: loss value recorded alongside the weights.
    filename: destination path; missing parent directories are created.
  """
  directory = os.path.dirname(filename)
  # dirname() is '' for a bare filename, and os.makedirs('') would raise.
  if directory:
    os.makedirs(directory, exist_ok=True)
  checkpoint = {
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss
  }
  # Save to a temporary file, then atomically swap it into place, so an
  # interrupted save cannot leave a truncated file as the only checkpoint.
  tmp_filename = filename + '.tmp'
  torch.save(checkpoint, tmp_filename)
  os.replace(tmp_filename, filename)

In [41]:
# Training loop.
def train(train_dataloader, val_dataloader, epochs=EPOCHS, start_epoch=0, best_model_path=None):
  """Train for epoch indices ``start_epoch`` .. ``epochs - 1``, validating after each.

  After every epoch a resumable checkpoint is written to ``CHECKPOINT_PATH``.
  It is overwritten each epoch, so it holds the *latest* epoch, not
  necessarily the best one.

  Args:
    train_dataloader: loader passed to ``train_epoch``.
    val_dataloader: loader passed to ``evaluate``.
    epochs: exclusive upper bound on the epoch index (total epochs, not epochs remaining).
    start_epoch: first epoch index to run, e.g. the value ``load_checkpoint`` returns.
    best_model_path: optional path. When given, the model's ``state_dict`` is
      saved there each time validation accuracy improves, so the best weights
      survive later epochs overwriting the resume checkpoint.

  Returns:
    Dict with per-epoch ``train_losses``, ``val_losses`` and ``val_accuracies``
    for the epochs run in this call, plus ``best_val_accuracy`` and
    ``best_epoch`` (``None`` if no epoch beat 0). Best-tracking only covers
    epochs run in this call.
  """
  train_losses = []
  val_losses = []
  val_accuracies = []
  best_val_accuracy = 0
  best_epoch = None

  for epoch in range(start_epoch, epochs):
    train_loss = train_epoch(train_dataloader)
    val_loss, val_accuracy = evaluate(val_dataloader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    if val_accuracy > best_val_accuracy:
      best_val_accuracy = val_accuracy
      best_epoch = epoch
      print(f"Best validation accuracy: {best_val_accuracy}")
      if best_model_path is not None:
        os.makedirs(os.path.dirname(best_model_path) or '.', exist_ok=True)
        torch.save(model.state_dict(), best_model_path)

    # Resume checkpoint: always the latest epoch so training can continue from it.
    save_checkpoint(model, optimizer, epoch, val_loss, CHECKPOINT_PATH)

  print("\nTraining completed!")
  print(f"Best validation accuracy: {best_val_accuracy:.4f}")

  return {
      'train_losses': train_losses,
      'val_losses': val_losses,
      'val_accuracies': val_accuracies,
      'best_val_accuracy': best_val_accuracy,
      'best_epoch': best_epoch
  }



In [None]:
# Pass start_epoch so a run restored by load_checkpoint continues from the
# saved epoch instead of silently repeating training from epoch 0.
training = train(train_dataloader, val_dataloader, epochs=EPOCHS, start_epoch=start_epoch)


Training:  14%|█▍        | 422/3000 [12:56<1:19:19,  1.85s/it, loss=0.591]