In [None]:
# Mount Google Drive at /content/drive so the later cells can read the test CSV,
# the pre-computed BERT outputs and the student checkpoint, all of which live under
# /content/drive/MyDrive/ML_DM. Requires a Google Colab runtime (google.colab).

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter
import os

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# In this notebook the evaluation cell reads MAX_LEN, BATCH_SIZE and DEVICE;
# the remaining values are training hyper-parameters kept for reference.

# Tokenization / batching
MAX_LEN = 64      # token ids per example after padding/truncation
BATCH_SIZE = 16

# Optimisation
EPOCHS = 50
LEARNING_RATE = 2e-5

# Early stopping (consumed by a training loop not shown here)
EARLY_STOPPING_PATIENCE = 7
EARLY_STOPPING_DELTA = 0.001

# Learning-rate reduction schedule
LR_PATIENCE = 2   # epochs to wait before reducing the learning rate
LR_FACTOR = 0.5   # multiplicative factor applied on each reduction
MIN_LR = 1e-6     # lower bound for the learning rate

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Custom Dataset pairing tokenized text with pre-computed BERT (teacher) outputs
class TextDataset(Dataset):
    """Dataset yielding token ids, label and cached BERT outputs for each text.

    Row ``i`` of ``bert_logits.npy`` / ``bert_features.npy`` is paired with
    ``texts[i]`` / ``labels[i]``. The constructor verifies that all four have the
    same length, so a stale or mismatched cache fails loudly instead of silently
    pairing a sample with another sample's teacher outputs.

    Args:
        texts: indexable sequence of texts; each item is converted with ``str``.
        labels: indexable sequence of integer class labels.
        tokenizer: Hugging Face-style tokenizer, called as ``tokenizer(text, ...)``.
        max_len: number of token ids to pad/truncate each text to.
        bert_outputs_dir: directory containing ``bert_logits.npy`` and
            ``bert_features.npy``.

    Raises:
        ValueError: if texts, labels and the two cached arrays differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_len, bert_outputs_dir):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Load pre-computed BERT outputs; row i must correspond to texts[i].
        self.bert_logits = np.load(os.path.join(bert_outputs_dir, 'bert_logits.npy'))
        self.bert_features = np.load(os.path.join(bert_outputs_dir, 'bert_features.npy'))

        lengths = {
            'texts': len(self.texts),
            'labels': len(self.labels),
            'bert_logits': len(self.bert_logits),
            'bert_features': len(self.bert_features),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"TextDataset inputs have mismatched lengths {lengths} "
                f"(bert_outputs_dir={bert_outputs_dir!r})"
            )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict with 'lstm_input', 'bert_logits', 'bert_features', 'label'."""
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Calling the tokenizer directly is the current API (encode_plus is deprecated).
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=False,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dim of 1; drop it -> (max_len,)
        lstm_input = encoding['input_ids'].squeeze(0)
        # torch.tensor copies, so callers can't mutate the cached arrays through a sample.
        bert_logits = torch.tensor(self.bert_logits[idx], dtype=torch.float)
        bert_features = torch.tensor(self.bert_features[idx], dtype=torch.float)
        return {
            'lstm_input': lstm_input,
            'bert_logits': bert_logits,
            'bert_features': bert_features,
            'label': torch.tensor(label, dtype=torch.long),
        }

# Student Model (LSTM): embedding -> single-layer BiLSTM -> mean pooling -> two heads
class StudentModel(nn.Module):
    """Bidirectional-LSTM student distilled from a BERT teacher.

    Args:
        vocab_size: number of token ids the embedding table covers.
        embedding_dim: embedding width.
        hidden_dim: LSTM hidden size per direction; the pooled vector is 2 * hidden_dim.
        num_classes: number of output classes.
        tokenizer: only ``tokenizer.pad_token_id`` is read; it is used as the
            embedding padding index (and as the pad id for masked pooling).
        teacher_dim: width of the teacher feature vector that ``match_hidden``
            projects to (768 reproduces the original bert-base setting).
        masked_pooling: if True, average LSTM outputs only over non-pad positions.
            The default (False) reproduces the original behaviour of averaging over
            every position, padding included; keep it False for checkpoints trained
            that way.

    ``forward(x)`` returns ``(logits, matched_hidden)``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, tokenizer,
                 teacher_dim=768, masked_pooling=False):
        super(StudentModel, self).__init__()
        # Plain attributes (not buffers), so the state_dict keys are unchanged.
        self.pad_idx = tokenizer.pad_token_id
        self.masked_pooling = masked_pooling
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=self.pad_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.25)
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)
        # Project the pooled BiLSTM state to the teacher's feature width for feature matching.
        self.match_hidden = nn.Linear(hidden_dim * 2, teacher_dim)

    def forward(self, x):
        """Encode token ids ``x`` (batch, seq_len) into class logits and teacher-sized features."""
        embedded = self.embedding(x)
        # batch_first + bidirectional -> (batch, seq_len, 2 * hidden_dim)
        lstm_out, _ = self.lstm(embedded)
        pooled = self._pool(lstm_out, x)
        pooled = self.dropout(pooled)
        matched_hidden = self.match_hidden(pooled)  # (batch, teacher_dim)
        logits = self.classifier(pooled)
        return logits, matched_hidden

    def _pool(self, lstm_out, x):
        """Mean-pool LSTM outputs over the sequence dimension, optionally ignoring padding."""
        if not self.masked_pooling or self.pad_idx is None:
            return torch.mean(lstm_out, dim=1)
        mask = (x != self.pad_idx).unsqueeze(-1).to(lstm_out.dtype)
        # clamp avoids division by zero for an all-padding row
        counts = mask.sum(dim=1).clamp(min=1.0)
        return (lstm_out * mask).sum(dim=1) / counts

# Distillation Loss: hard-label CE + softened-teacher CE + feature MSE
class DistillationLoss(nn.Module):
    """Knowledge-distillation objective combining three terms.

    total = (1 - alpha) * CE(student_logits, labels)
            + alpha * s * mean_batch( -sum softmax(teacher/T) * log_softmax(student/T) )
            + feature_weight * MSE(student_features, teacher_features)

    where ``s`` is ``T**2`` when ``scale_soft_by_t2`` is True and 1 otherwise.
    Hinton et al. recommend the ``T**2`` factor so that soft-target gradients keep a
    comparable magnitude as the temperature changes. It is off by default so the
    loss values match the original implementation.

    Args:
        alpha: weight of the soft (teacher) term; the hard term gets 1 - alpha.
        temperature: softmax temperature T applied to both logit sets.
        feature_weight: weight of the feature-matching MSE (originally hard-coded 0.1).
        scale_soft_by_t2: multiply the soft term by T**2.
    """

    def __init__(self, alpha=0.5, temperature=2.0, feature_weight=0.1, scale_soft_by_t2=False):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.feature_weight = feature_weight
        self.scale_soft_by_t2 = scale_soft_by_t2
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_features, teacher_features, labels):
        """Return the scalar combined loss for one batch."""
        T = self.temperature
        # Teacher outputs are targets only: detach so no gradient reaches a live teacher.
        soft_targets = F.softmax(teacher_logits.detach() / T, dim=-1)
        soft_prob = F.log_softmax(student_logits / T, dim=-1)
        # Cross-entropy against the softened teacher distribution, averaged over the batch.
        soft_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size(0)
        if self.scale_soft_by_t2:
            soft_loss = soft_loss * (T ** 2)

        # Hard targets loss
        hard_loss = self.ce_loss(student_logits, labels)

        # Feature-based loss
        feature_loss = self.mse_loss(student_features, teacher_features.detach())

        # Combine losses
        return (1 - self.alpha) * hard_loss + self.alpha * soft_loss + self.feature_weight * feature_loss


In [None]:
from sklearn.metrics import accuracy_score

# Load test data
test_df = pd.read_csv('/content/drive/MyDrive/ML_DM/final_news_test.csv')

# Initialize tokenizer. Its token ids index the student's embedding table, so it
# must be the same tokenizer the checkpoint was trained with.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create test dataset (update bert_outputs_dir if needed). TextDataset also loads
# the cached BERT arrays; evaluation below only uses 'lstm_input' and 'label'.
test_dataset = TextDataset(
    texts=test_df['text'].values,
    labels=test_df['label'].values,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    bert_outputs_dir='/content/drive/MyDrive/ML_DM/precomputed_bert/test'
)

# Keep the original row order so predictions line up with test_df.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=4, pin_memory=True)

# Initialize model (hyper-parameters must match the saved checkpoint)
student_model = StudentModel(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=256,
    hidden_dim=256,
    num_classes=4,
    tokenizer=tokenizer
).to(DEVICE)

# Load best model weights
checkpoint = torch.load('/content/drive/MyDrive/ML_DM/Student model/acc 87/best_student_model.pth', map_location=DEVICE)
student_model.load_state_dict(checkpoint['model_state_dict'])


def collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        (labels, preds): two lists of numpy integer scalars, in loader order.
    """
    model.eval()
    labels, preds = [], []
    with torch.no_grad():
        for batch in loader:
            logits, _ = model(batch['lstm_input'].to(device))
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            # Labels are only compared on the CPU, so they never need to visit the device.
            labels.extend(batch['label'].numpy())
    return labels, preds


# Evaluate
all_labels, all_preds = collect_predictions(student_model, test_loader, DEVICE)

accuracy = accuracy_score(all_labels, all_preds)
n_correct = int(np.sum(np.asarray(all_preds) == np.asarray(all_labels)))
print(f"Test Accuracy: {accuracy:.4f} ({n_correct}/{len(all_labels)})")
print(f"Test Accuracy: {accuracy*100:.2f}%")

# Print results
print("Classification Report:")
print(classification_report(all_labels, all_preds, digits=4))
print("Confusion Matrix:")
cm = confusion_matrix(all_labels, all_preds)
print(cm)

# Plot confusion matrix: save to disk and also display it inline
# (the previous bare plt.close() meant the figure was never shown).
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
fig.tight_layout()
fig.savefig('student_confusion_matrix.png')
plt.show()
plt.close(fig)





Test Accuracy: 0.8700 (17048/19596)
Test Accuracy: 87.00%
Classification Report:
              precision    recall  f1-score   support

           0     0.8790    0.8977    0.8882      7026
           1     0.8935    0.8848    0.8891      3905
           2     0.8730    0.8752    0.8741      5952
           3     0.8029    0.7656    0.7838      2713

    accuracy                         0.8700     19596
   macro avg     0.8621    0.8558    0.8588     19596
weighted avg     0.8695    0.8700    0.8696     19596

Confusion Matrix:
[[6307  185  342  192]
 [ 256 3455  148   46]
 [ 324  147 5209  272]
 [ 288   80  268 2077]]


In [None]:
# Re-print the headline metric. `accuracy` is created by the evaluation cell above,
# so that cell must have run in the same kernel session; running this cell alone on
# a fresh kernel raises NameError.
print(f"\nTest Accuracy: {accuracy:.4f}")

NameError: name 'accuracy' is not defined