In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter
import os
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

# Constants
# --- Tokenization / batching ---
MAX_LEN = 64        # tokens per example (the dataset pads/truncates to this length)
BATCH_SIZE = 16

# --- Optimization ---
EPOCHS = 50
LEARNING_RATE = 2e-5

# --- Early stopping: stop after PATIENCE epochs without a val-loss drop larger than DELTA ---
EARLY_STOPPING_PATIENCE = 7
EARLY_STOPPING_DELTA = 0.001

# --- ReduceLROnPlateau schedule ---
LR_PATIENCE = 2     # epochs without improvement before the learning rate is cut
LR_FACTOR = 0.5     # multiplier applied to the learning rate on each cut
MIN_LR = 1e-6       # lower bound for the learning rate

# --- Hardware ---
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Custom Dataset with pre-computed BERT outputs
class TextDataset(Dataset):
    """Pairs raw texts/labels with pre-computed BERT (teacher) outputs.

    Each item contains the tokenizer's ``input_ids`` for the student plus the
    teacher logits/features loaded from ``bert_logits.npy`` and
    ``bert_features.npy`` in ``bert_outputs_dir``.

    Row ``i`` of those arrays is used for ``texts[i]``: the arrays must have
    been produced from the same texts in the same order. This cannot be
    verified here beyond the row counts, which are checked at construction.

    Raises:
        ValueError: if ``labels`` or either teacher array does not have the
            same number of rows as ``texts``.
    """

    def __init__(self, texts, labels, tokenizer, max_len, bert_outputs_dir):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Load pre-computed BERT outputs
        self.bert_logits = np.load(os.path.join(bert_outputs_dir, 'bert_logits.npy'))
        self.bert_features = np.load(os.path.join(bert_outputs_dir, 'bert_features.npy'))

        # Teacher outputs are matched to texts purely by index, so a count
        # mismatch means rows are being paired with the wrong examples (or
        # __getitem__ will fail late). Fail fast instead.
        num_texts = len(self.texts)
        if len(self.labels) != num_texts:
            raise ValueError(
                f'texts ({num_texts}) and labels ({len(self.labels)}) differ in length'
            )
        for array_name, array in (('bert_logits', self.bert_logits),
                                  ('bert_features', self.bert_features)):
            if len(array) != num_texts:
                raise ValueError(
                    f'{array_name}.npy in {bert_outputs_dir!r} has {len(array)} rows '
                    f'but {num_texts} texts were given'
                )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Encode text using the BERT tokenizer (padded/truncated to max_len).
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=False,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dim of 1; drop it.
        lstm_input = encoding['input_ids'].squeeze(0)
        # Pre-computed teacher outputs for the same row.
        bert_logits = torch.tensor(self.bert_logits[idx], dtype=torch.float)
        bert_features = torch.tensor(self.bert_features[idx], dtype=torch.float)
        return {
            'lstm_input': lstm_input,
            'bert_logits': bert_logits,
            'bert_features': bert_features,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Student Model (LSTM)
class StudentModel(nn.Module):
    """Single-layer bidirectional LSTM student for distillation from BERT.

    forward(x) returns ``(logits, matched_hidden)``:
      * ``logits``: class scores of width ``num_classes``.
      * ``matched_hidden``: the pooled LSTM state linearly projected to 768
        dimensions so it can be compared with the teacher's feature vectors.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        pooled_width = 2 * hidden_dim  # forward and backward states are concatenated
        # Creation order is kept fixed so seeded weight init stays reproducible.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.25)
        self.classifier = nn.Linear(pooled_width, num_classes)
        self.match_hidden = nn.Linear(pooled_width, 768)  # project to BERT's hidden size

    def forward(self, x):
        # NOTE(review): the mean is taken over every time step, so when inputs
        # are padded (the dataset in this notebook pads to max_length) the pad
        # positions are averaged in too. Consider masked pooling / packed sequences.
        per_step_states = self.lstm(self.embedding(x))[0]
        pooled = self.dropout(per_step_states.mean(dim=1))
        return self.classifier(pooled), self.match_hidden(pooled)

# Distillation Loss
class DistillationLoss(nn.Module):
    """Knowledge-distillation objective combining three terms.

        total = (1 - alpha) * CE(student_logits, labels)
                + alpha * soft_CE(student_logits / T, teacher_logits / T)
                + feature_weight * MSE(student_features, teacher_features)

    where ``soft_CE`` is the cross-entropy between the teacher's
    temperature-softened distribution and the student's, averaged over the
    batch (first dimension).

    Args:
        alpha: weight of the soft (teacher) term; must lie in [0, 1].
        temperature: softmax temperature T; must be > 0.
        feature_weight: weight of the feature-matching MSE term (previously a
            hard-coded 0.1, which remains the default).

    NOTE(review): Hinton et al. (2015) multiply the soft term by T**2 so its
    gradient scale stays comparable to the hard term as T changes; this
    implementation does not. Worth revisiting if ``temperature`` is tuned.
    """

    def __init__(self, alpha=0.5, temperature=2.0, feature_weight=0.1):
        super(DistillationLoss, self).__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'alpha must be in [0, 1], got {alpha}')
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        self.alpha = alpha
        self.temperature = temperature
        self.feature_weight = feature_weight
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_features, teacher_features, labels):
        # Soft-target loss: cross-entropy against the softened teacher
        # distribution, summed over classes and divided by the batch size.
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_loss = -torch.sum(soft_targets * soft_log_probs) / soft_log_probs.size(0)

        # Hard-target loss against the ground-truth labels.
        hard_loss = self.ce_loss(student_logits, labels)

        # Feature-matching loss between student and teacher representations.
        feature_loss = self.mse_loss(student_features, teacher_features)

        return ((1 - self.alpha) * hard_loss
                + self.alpha * soft_loss
                + self.feature_weight * feature_loss)


In [2]:
def train_model(student_model, train_loader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        student_model: module returning ``(logits, features)`` for ``lstm_input``.
        train_loader: iterable of dict batches with keys ``lstm_input``,
            ``bert_logits``, ``bert_features`` and ``label``.
        optimizer: optimizer over ``student_model``'s parameters.
        criterion: callable ``(student_logits, teacher_logits, student_features,
            teacher_features, labels) -> scalar loss tensor``.
        device: device the batch tensors are moved to.

    Returns:
        Sum of batch losses divided by the number of batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    student_model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_loader):
        # non_blocking lets pinned-memory host->GPU copies overlap; it is a
        # no-op for CPU tensors/devices.
        lstm_input = batch['lstm_input'].to(device, non_blocking=True)
        teacher_logits = batch['bert_logits'].to(device, non_blocking=True)
        teacher_features = batch['bert_features'].to(device, non_blocking=True)
        labels = batch['label'].to(device, non_blocking=True)

        optimizer.zero_grad()
        student_logits, student_features = student_model(lstm_input)
        loss = criterion(student_logits, teacher_logits, student_features, teacher_features, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # No per-batch torch.cuda.empty_cache(): it forces a device sync and
        # throws away the caching allocator's blocks, only to re-allocate them
        # on the next batch. It slows training without lowering peak memory.

    if num_batches == 0:
        raise ValueError('train_model: train_loader yielded no batches')
    return total_loss / num_batches

def evaluate_model(student_model, data_loader, criterion, device):
    """Compute the mean per-batch loss on ``data_loader`` without updating weights.

    Puts the model in eval mode (disabling dropout) and runs under
    ``torch.no_grad()``.

    Args:
        student_model: module returning ``(logits, features)`` for ``lstm_input``.
        data_loader: iterable of dict batches with keys ``lstm_input``,
            ``bert_logits``, ``bert_features`` and ``label``.
        criterion: callable ``(student_logits, teacher_logits, student_features,
            teacher_features, labels) -> scalar loss tensor``.
        device: device the batch tensors are moved to.

    Returns:
        Sum of batch losses divided by the number of batches.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    student_model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in data_loader:
            lstm_input = batch['lstm_input'].to(device, non_blocking=True)
            teacher_logits = batch['bert_logits'].to(device, non_blocking=True)
            teacher_features = batch['bert_features'].to(device, non_blocking=True)
            labels = batch['label'].to(device, non_blocking=True)

            student_logits, student_features = student_model(lstm_input)
            loss = criterion(student_logits, teacher_logits, student_features, teacher_features, labels)
            total_loss += loss.item()
            num_batches += 1
            # Per-batch torch.cuda.empty_cache() intentionally omitted: it only
            # forces syncs and re-allocations without reducing peak memory.

    if num_batches == 0:
        raise ValueError('evaluate_model: data_loader yielded no batches')
    return total_loss / num_batches

def test_model(student_model, test_loader, criterion, device):
    """Evaluate the student on a held-out set and compute classification metrics.

    Args:
        student_model: module returning ``(logits, features)`` for ``lstm_input``.
        test_loader: iterable of dict batches with keys ``lstm_input``,
            ``bert_logits``, ``bert_features`` and ``label``.
        criterion: callable ``(student_logits, teacher_logits, student_features,
            teacher_features, labels) -> scalar loss tensor``.
        device: device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy, cm, report)``: mean per-batch loss, fraction of
        argmax predictions equal to the label, sklearn confusion matrix, and
        sklearn classification report (4 digits).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    student_model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in test_loader:
            lstm_input = batch['lstm_input'].to(device, non_blocking=True)
            teacher_logits = batch['bert_logits'].to(device, non_blocking=True)
            teacher_features = batch['bert_features'].to(device, non_blocking=True)
            labels = batch['label'].to(device, non_blocking=True)

            student_logits, student_features = student_model(lstm_input)
            loss = criterion(student_logits, teacher_logits, student_features, teacher_features, labels)
            total_loss += loss.item()
            num_batches += 1

            preds = torch.argmax(student_logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            # Per-batch torch.cuda.empty_cache() intentionally omitted: it only
            # forces syncs and re-allocations without reducing peak memory.

    if num_batches == 0:
        raise ValueError('test_model: test_loader yielded no batches')

    avg_loss = total_loss / num_batches
    accuracy = correct / total
    cm = confusion_matrix(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, digits=4)
    return avg_loss, accuracy, cm, report

In [3]:
def save_checkpoint(model, optimizer, scheduler, epoch, best_metric, filename):
    """Write a resumable training snapshot to ``filename`` with ``torch.save``.

    The saved dict has the keys ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict`` (``None`` when no scheduler is given), ``epoch``
    and ``best_metric``.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            scheduler_state_dict=(None if not scheduler else scheduler.state_dict()),
            epoch=epoch,
            best_metric=best_metric,
        ),
        filename,
    )

def load_checkpoint(model, optimizer, scheduler, filename, map_location=None):
    """Restore model, optimizer and (optionally) scheduler state from ``filename``.

    Args:
        model: module whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: optimizer whose ``load_state_dict`` receives ``optimizer_state_dict``.
        scheduler: scheduler to restore, or ``None``/falsy to skip. It is also
            skipped when the checkpoint stores no (or an empty) scheduler state.
        filename: path of a checkpoint dict with the keys written by ``save_checkpoint``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` or a
            ``torch.device`` to load a checkpoint saved on another device; a
            CUDA-saved checkpoint cannot be loaded on a CPU-only machine
            without it. ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        ``(epoch, best_metric)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # .get() tolerates checkpoints written without a scheduler entry.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler_state and scheduler:
        scheduler.load_state_dict(scheduler_state)
    return checkpoint['epoch'], checkpoint['best_metric']

In [None]:

# Reproducibility: seed the RNGs behind weight init and DataLoader shuffling.
torch.manual_seed(42)
np.random.seed(42)

# Load and preprocess data
train_df = pd.read_csv('/kaggle/input/news-dataset/final_news_train.csv')
test_df = pd.read_csv('/kaggle/input/news-dataset/final_news_test.csv')

# Split train into train and validation.
# NOTE(review): TextDataset pairs texts with the precomputed teacher rows purely by
# position, so the .../train and .../val outputs must have been produced from this
# exact split (test_size=0.1, random_state=42, stratified) in this order — confirm
# against the precomputation script.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_df['text'].values,
    train_df['label'].values,
    test_size=0.1,
    random_state=42,
    stratify=train_df['label'].values
)

# Initialize tokenizer
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create datasets with pre-computed BERT outputs
BERT_OUTPUTS_ROOT = '/kaggle/input/precomputed-bert/precomputed_bert'
train_dataset = TextDataset(
    texts=train_texts,
    labels=train_labels,
    tokenizer=bert_tokenizer,
    max_len=MAX_LEN,
    bert_outputs_dir=os.path.join(BERT_OUTPUTS_ROOT, 'train')
)
val_dataset = TextDataset(
    texts=val_texts,
    labels=val_labels,
    tokenizer=bert_tokenizer,
    max_len=MAX_LEN,
    bert_outputs_dir=os.path.join(BERT_OUTPUTS_ROOT, 'val')  # separate validation outputs
)
test_dataset = TextDataset(
    texts=test_df['text'].values,
    labels=test_df['label'].values,
    tokenizer=bert_tokenizer,
    max_len=MAX_LEN,
    bert_outputs_dir=os.path.join(BERT_OUTPUTS_ROOT, 'test')
)

# Create dataloaders. Pinned memory only helps host->GPU copies (and warns on CPU).
PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=PIN_MEMORY)

# The student's output width must match the teacher logits it is distilled from.
NUM_CLASSES = 4
teacher_num_classes = train_dataset.bert_logits.shape[-1]
assert teacher_num_classes == NUM_CLASSES, (
    f'teacher logits have {teacher_num_classes} classes, expected {NUM_CLASSES}'
)

# Initialize student model
student_model = StudentModel(
    vocab_size=bert_tokenizer.vocab_size,
    embedding_dim=256,
    hidden_dim=256,
    num_classes=NUM_CLASSES
).to(DEVICE)

# Initialize optimizer and criterion
optimizer = torch.optim.Adam(student_model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
criterion = DistillationLoss(alpha=0.5, temperature=2.0)

# Learning-rate scheduler. `verbose` is deprecated in recent PyTorch; LR changes are
# reported explicitly in the loop below instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=LR_FACTOR,
    patience=LR_PATIENCE,
    min_lr=MIN_LR
)

# Always start training from scratch
start_epoch = 0
best_val_loss = float('inf')
# Initialised up front: previously it was only assigned in the "improved" branch, so a
# first epoch without improvement (e.g. a NaN val loss) raised NameError on `+= 1`.
patience_counter = 0
train_losses = []
val_losses = []

# Training loop
for epoch in range(start_epoch, EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    # Release cached GPU blocks between epochs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    train_loss = train_model(student_model, train_loader, optimizer, criterion, DEVICE)
    val_loss = evaluate_model(student_model, val_loader, criterion, DEVICE)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f'Training Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}')

    # Update learning rate based on validation loss
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Current Learning Rate: {current_lr:.2e}')
    if current_lr < previous_lr:
        print(f'Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}')

    # Early stopping: only improvements larger than EARLY_STOPPING_DELTA reset patience.
    if val_loss < best_val_loss - EARLY_STOPPING_DELTA:
        patience_counter = 0
        best_val_loss = val_loss
        save_checkpoint(
            student_model,
            optimizer,
            scheduler,
            epoch,
            best_val_loss,
            'best_student_model.pth'
        )
        print('Best model saved!')
    else:
        patience_counter += 1
        print(f'EarlyStopping counter: {patience_counter} out of {EARLY_STOPPING_PATIENCE}')
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print('Early stopping triggered')
            break

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Epoch 1/50


100%|██████████| 12823/12823 [02:57<00:00, 72.23it/s]


Training Loss: 1.0830 | Validation Loss: 0.8992
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 2/50


100%|██████████| 12823/12823 [02:57<00:00, 72.19it/s]


Training Loss: 0.8102 | Validation Loss: 0.7659
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 3/50


100%|██████████| 12823/12823 [02:57<00:00, 72.18it/s]


Training Loss: 0.7211 | Validation Loss: 0.7321
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 4/50


100%|██████████| 12823/12823 [02:57<00:00, 72.36it/s]


Training Loss: 0.6773 | Validation Loss: 0.7194
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 5/50


100%|██████████| 12823/12823 [02:57<00:00, 72.26it/s]


Training Loss: 0.6499 | Validation Loss: 0.6864
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 6/50


100%|██████████| 12823/12823 [02:57<00:00, 72.30it/s]


Training Loss: 0.6283 | Validation Loss: 0.6507
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 7/50


100%|██████████| 12823/12823 [02:57<00:00, 72.18it/s]


Training Loss: 0.6092 | Validation Loss: 0.6337
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 8/50


100%|██████████| 12823/12823 [02:57<00:00, 72.36it/s]


Training Loss: 0.5920 | Validation Loss: 0.6433
Current Learning Rate: 2.00e-05
EarlyStopping counter: 1 out of 7
Epoch 9/50


100%|██████████| 12823/12823 [02:57<00:00, 72.11it/s]


Training Loss: 0.5749 | Validation Loss: 0.6099
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 10/50


100%|██████████| 12823/12823 [02:57<00:00, 72.06it/s]


Training Loss: 0.5595 | Validation Loss: 0.5915
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 11/50


100%|██████████| 12823/12823 [02:57<00:00, 72.34it/s]


Training Loss: 0.5456 | Validation Loss: 0.5887
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 12/50


100%|██████████| 12823/12823 [02:57<00:00, 72.27it/s]


Training Loss: 0.5341 | Validation Loss: 0.5741
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 13/50


100%|██████████| 12823/12823 [02:57<00:00, 72.25it/s]


Training Loss: 0.5234 | Validation Loss: 0.5730
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 14/50


100%|██████████| 12823/12823 [02:57<00:00, 72.25it/s]


Training Loss: 0.5145 | Validation Loss: 0.5639
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 15/50


100%|██████████| 12823/12823 [02:57<00:00, 72.28it/s]


Training Loss: 0.5063 | Validation Loss: 0.5592
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 16/50


100%|██████████| 12823/12823 [02:57<00:00, 72.25it/s]


Training Loss: 0.4995 | Validation Loss: 0.5617
Current Learning Rate: 2.00e-05
EarlyStopping counter: 1 out of 7
Epoch 17/50


100%|██████████| 12823/12823 [02:57<00:00, 72.17it/s]


Training Loss: 0.4935 | Validation Loss: 0.5538
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 18/50


100%|██████████| 12823/12823 [02:57<00:00, 72.10it/s]


Training Loss: 0.4880 | Validation Loss: 0.5525
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 19/50


100%|██████████| 12823/12823 [02:57<00:00, 72.24it/s]


Training Loss: 0.4831 | Validation Loss: 0.5490
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 20/50


100%|██████████| 12823/12823 [02:57<00:00, 72.25it/s]


Training Loss: 0.4787 | Validation Loss: 0.5552
Current Learning Rate: 2.00e-05
EarlyStopping counter: 1 out of 7
Epoch 21/50


100%|██████████| 12823/12823 [02:57<00:00, 72.31it/s]


Training Loss: 0.4743 | Validation Loss: 0.5489
Current Learning Rate: 2.00e-05
EarlyStopping counter: 2 out of 7
Epoch 22/50


100%|██████████| 12823/12823 [02:57<00:00, 72.12it/s]


Training Loss: 0.4710 | Validation Loss: 0.5526
Current Learning Rate: 2.00e-05
EarlyStopping counter: 3 out of 7
Epoch 23/50


100%|██████████| 12823/12823 [02:57<00:00, 72.05it/s]


Training Loss: 0.4679 | Validation Loss: 0.5423
Current Learning Rate: 2.00e-05
Best model saved!
Epoch 24/50


100%|██████████| 12823/12823 [02:58<00:00, 71.91it/s]


Training Loss: 0.4645 | Validation Loss: 0.5464
Current Learning Rate: 2.00e-05
EarlyStopping counter: 1 out of 7
Epoch 25/50


100%|██████████| 12823/12823 [02:58<00:00, 71.99it/s]


Training Loss: 0.4618 | Validation Loss: 0.5466
Current Learning Rate: 2.00e-05
EarlyStopping counter: 2 out of 7
Epoch 26/50


100%|██████████| 12823/12823 [02:58<00:00, 71.77it/s]


Training Loss: 0.4592 | Validation Loss: 0.5459
Current Learning Rate: 1.00e-05
EarlyStopping counter: 3 out of 7
Epoch 27/50


100%|██████████| 12823/12823 [02:58<00:00, 71.97it/s]


Training Loss: 0.4416 | Validation Loss: 0.5539
Current Learning Rate: 1.00e-05
EarlyStopping counter: 4 out of 7
Epoch 28/50


100%|██████████| 12823/12823 [02:58<00:00, 71.99it/s]


Training Loss: 0.4387 | Validation Loss: 0.5394
Current Learning Rate: 1.00e-05
Best model saved!
Epoch 29/50


100%|██████████| 12823/12823 [02:58<00:00, 71.86it/s]


Training Loss: 0.4365 | Validation Loss: 0.5510
Current Learning Rate: 1.00e-05
EarlyStopping counter: 1 out of 7
Epoch 30/50


100%|██████████| 12823/12823 [02:58<00:00, 72.03it/s]


Training Loss: 0.4347 | Validation Loss: 0.5569
Current Learning Rate: 1.00e-05
EarlyStopping counter: 2 out of 7
Epoch 31/50


100%|██████████| 12823/12823 [02:58<00:00, 71.85it/s]


Training Loss: 0.4334 | Validation Loss: 0.5603
Current Learning Rate: 5.00e-06
EarlyStopping counter: 3 out of 7
Epoch 32/50


100%|██████████| 12823/12823 [02:58<00:00, 72.03it/s]


Training Loss: 0.4226 | Validation Loss: 0.5486
Current Learning Rate: 5.00e-06
EarlyStopping counter: 4 out of 7
Epoch 33/50


 86%|████████▌ | 11025/12823 [02:33<00:24, 72.71it/s]

In [None]:
# After training finishes: plot the loss curves and save the figure.
LOSS_PLOT_PATH = 'loss_curve.png'
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(range(start_epoch + 1, start_epoch + 1 + len(train_losses)), train_losses, label='Train Loss')
ax.plot(range(start_epoch + 1, start_epoch + 1 + len(val_losses)), val_losses, label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Train/Validation Loss')
ax.legend()
ax.grid(True)
fig.tight_layout()
# The original comment promised to save the figure but never did.
fig.savefig(LOSS_PLOT_PATH, dpi=150)
plt.show()
plt.close(fig)

# Evaluate on the test set with the best saved model.
print('Evaluating on test set with best model...')
# map_location makes the checkpoint load onto whatever device is active now,
# even if it was written from a different one.
best_checkpoint = torch.load('best_student_model.pth', map_location=DEVICE)
student_model.load_state_dict(best_checkpoint['model_state_dict'])
test_loss, test_acc, cm, report = test_model(student_model, test_loader, criterion, DEVICE)
print(f'Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}')
print('Confusion Matrix:')
print(cm)
print('Classification Report:')
print(report)