In [None]:
!pip install transformers datasets torch scipy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from scipy.stats import pearsonr
from torch.cuda.amp import autocast, GradScaler

# =====================
# 1. Định nghĩa mô hình MTL sử dụng TinyBERT
# =====================
class MultiTaskTinyBERT(nn.Module):
    """Shared TinyBERT encoder with one linear head per task.

    The last-layer hidden state of the first ([CLS]) token goes through
    dropout and is then routed to the head selected by ``task``:

    * ``"sentiment"``  -> 2-way linear classifier, cross-entropy loss
    * ``"paraphrase"`` -> 2-way linear classifier, cross-entropy loss
    * ``"sts"``        -> 1-output linear regressor, mean-squared-error loss

    Args:
        model_name: model id or path handed to ``AutoModel.from_pretrained``.
        dropout: dropout probability applied to the pooled [CLS] vector
            (raised from 0.1 to 0.2 by the author).
    """

    def __init__(self, model_name="huawei-noah/TinyBERT_General_4L_312D", dropout=0.2):
        super().__init__()
        self.tinybert = AutoModel.from_pretrained(model_name)
        width = self.tinybert.config.hidden_size  # 312 for TinyBERT_General_4L_312D
        self.dropout = nn.Dropout(dropout)
        # The head attribute names below form the state_dict keys of the saved
        # checkpoints, so they must not be renamed. Their creation order is also
        # kept so that parameter initialisation consumes the RNG identically.
        self.sentiment_classifier = nn.Linear(width, 2)   # positive / negative
        self.paraphrase_classifier = nn.Linear(width, 2)  # duplicate / not duplicate
        self.sts_regressor = nn.Linear(width, 1)          # continuous similarity score

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Encode a batch and apply the head that belongs to ``task``.

        Returns:
            ``(loss, outputs)`` when ``labels`` is given, otherwise ``outputs``
            alone. ``outputs`` are ``(batch, 2)`` logits for the two
            classification tasks and ``(batch,)`` scores for ``"sts"``.

        Raises:
            ValueError: if ``task`` is not one of the supported names (raised
                after the encoder has run, as before).
        """
        encoded = self.tinybert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vector = self.dropout(encoded.last_hidden_state[:, 0])

        if task == "sts":
            scores = self.sts_regressor(cls_vector).squeeze(-1)
            if labels is None:
                return scores
            return nn.MSELoss()(scores, labels.float()), scores

        if task == "sentiment":
            head = self.sentiment_classifier
        elif task == "paraphrase":
            head = self.paraphrase_classifier
        else:
            raise ValueError("Task không hợp lệ!")

        logits = head(cls_vector)
        if labels is None:
            return logits
        return nn.CrossEntropyLoss()(logits, labels), logits

# =====================
# 2. Hàm tải & tiền xử lý dữ liệu (GLUE)
# =====================
def load_dataset_and_tokenize(task, tokenizer, max_length=128):
    """Load the GLUE dataset for ``task`` and tokenize every split.

    Task -> GLUE config (text columns):
      * ``"sentiment"``  -> ``sst2`` (``sentence``)
      * ``"paraphrase"`` -> ``qqp``  (``question1``, ``question2``)
      * ``"sts"``        -> ``stsb`` (``sentence1``, ``sentence2``)

    Every example is truncated and padded to ``max_length`` tokens. The
    ``label`` column is renamed to ``labels``, and the dataset is formatted to
    yield torch tensors for ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        ValueError: for an unknown ``task``; this is checked before
            ``load_dataset`` is called.
    """
    if task == "sentiment":
        glue_config, text_columns = "sst2", ("sentence",)
    elif task == "paraphrase":
        glue_config, text_columns = "qqp", ("question1", "question2")
    elif task == "sts":
        glue_config, text_columns = "stsb", ("sentence1", "sentence2")
    else:
        raise ValueError("Task không hợp lệ!")

    raw = load_dataset("glue", glue_config)

    def encode(examples):
        # One column -> single sentences; two columns -> sentence pairs.
        return tokenizer(*(examples[column] for column in text_columns),
                         truncation=True, padding="max_length", max_length=max_length)

    tokenized = raw.map(encode, batched=True)
    tokenized = tokenized.rename_column("label", "labels")
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return tokenized

# =====================
# 3. Hàm huấn luyện đa nhiệm (MTL) với early stopping
# =====================
def train_multitask(model, dataloaders, val_dataloaders, optimizer, scheduler, device, num_epochs=5, task_weights=None, patience=2):
    if task_weights is None:
        task_weights = {"sentiment": 1.0, "paraphrase": 1.0, "sts": 1.0}

    model.to(device)
    scaler = GradScaler()  # Cho mixed precision
    best_val_score = -float('inf')
    counter = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for task, dataloader in dataloaders.items():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                optimizer.zero_grad()
                with autocast():  # Mixed precision
                    loss, _ = model(input_ids, attention_mask, task=task, labels=labels)
                    loss = loss * task_weights.get(task, 1.0)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                total_loss += loss.item()

        avg_loss = total_loss / sum(len(dl) for dl in dataloaders.values())
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

        val_score = evaluate_validation(model, val_dataloaders, device)
        if val_score > best_val_score:
            best_val_score = val_score
            counter = 0
            torch.save(model.state_dict(), "best_mtl_model.pt")
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered!")
                break

# =====================
# 4. Hàm đánh giá trên tập validation
# =====================
def evaluate_validation(model, val_dataloaders, device):
    """Score the model on the three validation sets and return their mean.

    Sentiment and paraphrase are scored by accuracy, STS by Pearson
    correlation. The three numbers and their equally weighted mean are printed
    on a single line, and the mean is returned.
    """
    scores = {
        "sentiment": evaluate_classification(model, val_dataloaders["sentiment"], "sentiment", device),
        "paraphrase": evaluate_classification(model, val_dataloaders["paraphrase"], "paraphrase", device),
        "sts": evaluate_sts(model, val_dataloaders["sts"], device),
    }
    mean_score = (scores["sentiment"] + scores["paraphrase"] + scores["sts"]) / 3
    print(
        f"Validation - Sentiment ACC: {scores['sentiment']:.4f}, "
        f"Paraphrase ACC: {scores['paraphrase']:.4f}, "
        f"STS Pearson: {scores['sts']:.4f}, "
        f"Avg Score: {mean_score:.4f}"
    )
    return mean_score

# =====================
# 5. Hàm đánh giá (phân loại & hồi quy)
# =====================
def evaluate_classification(model, dataloader, task, device):
    """Return the accuracy of the ``task`` head over every batch in ``dataloader``.

    The model is put in eval mode (and left there) and run under
    ``torch.no_grad()``. Each batch must provide ``input_ids``,
    ``attention_mask`` and integer ``labels``. The prediction is the argmax over
    the class dimension. An empty ``dataloader`` raises ``ZeroDivisionError``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["labels"].to(device)
            predicted = model(ids, mask, task=task).argmax(dim=1)
            n_correct += (predicted == gold).sum().item()
            n_seen += gold.size(0)
    return n_correct / n_seen

def evaluate_sts(model, dataloader, device):
    """Return the Pearson correlation between STS predictions and gold scores.

    The model is put in eval mode (and left there) and run under
    ``torch.no_grad()`` with ``task="sts"``. Predictions and labels from all
    batches are collected into flat Python lists, which are then passed to
    ``scipy.stats.pearsonr``. Only the correlation coefficient is returned.
    """
    model.eval()
    predicted_scores = []
    gold_scores = []
    with torch.no_grad():
        for batch in dataloader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)
            outputs = model(ids, mask, task="sts")
            predicted_scores += outputs.cpu().numpy().tolist()
            gold_scores += targets.cpu().numpy().tolist()
    correlation, _p_value = pearsonr(predicted_scores, gold_scores)
    return correlation

# =====================
# 6. Hàm fine-tuning riêng cho từng task với mixed precision
# =====================
def fine_tune_task(model, dataloader, task, optimizer, scheduler, device, num_epochs=5):
    model.to(device)
    scaler = GradScaler()  # Cho mixed precision
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            with autocast():
                loss, _ = model(input_ids, attention_mask, task=task, labels=labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Task {task} - Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), f"tinybert_finetuned_{task}.pt")
    print(f"Model fine-tuned cho task {task} đã được lưu!")

# =====================
# 7. Quy trình tổng hợp
# =====================
def _build_optimizer_and_scheduler(model, num_training_steps):
    """Create the AdamW optimizer and linear-decay LR schedule used by every stage.

    lr=1e-5 (lowered from 2e-5) with weight decay 0.01 and no warmup.
    ``torch.optim.AdamW`` replaces the deprecated ``transformers.AdamW``;
    eps=1e-6 keeps the transformers implementation's default epsilon.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01, eps=1e-6)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                num_training_steps=num_training_steps)
    return optimizer, scheduler


def _fine_tune_from_checkpoint(model, start_checkpoint, task, dataloader, device, num_epochs):
    """Reset ``model`` to ``start_checkpoint``, fine-tune it on ``task`` alone, return the saved path.

    ``fine_tune_task`` saves to ``tinybert_finetuned_<task>.pt`` by default,
    and that is the path returned here.
    """
    model.load_state_dict(torch.load(start_checkpoint, map_location=device))
    optimizer, scheduler = _build_optimizer_and_scheduler(model, len(dataloader) * num_epochs)
    fine_tune_task(model, dataloader, task, optimizer, scheduler, device, num_epochs=num_epochs)
    return f"tinybert_finetuned_{task}.pt"


def _average_state_dicts(state_dicts):
    """Return the key-wise mean of state dicts that share the same keys.

    Floating-point tensors are averaged. Any other entry (e.g. an integer
    buffer) is taken from the first dict, because true division would turn an
    integer tensor into a float tensor rather than a meaningful average.
    """
    first = state_dicts[0]
    averaged = {}
    for key, value in first.items():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            averaged[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
        else:
            averaged[key] = value
    return averaged


def main():
    """Run MTL training, per-task fine-tuning and Pi-tuning end to end.

    Stages:
      1. Train one shared TinyBERT jointly on SST-2, QQP and STS-B. Reload the
         epoch with the best mean validation score and save it as
         ``multi_task_tinybert.pt``.
      2. For each task, reset to that MTL checkpoint and fine-tune on the task
         alone (saved as ``tinybert_finetuned_<task>.pt``).
      3. Average the three fine-tuned state dicts ("Pi-tuning"), save the
         result as ``tinybert_pi_tuned.pt`` and report validation metrics.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    num_epochs = 5  # raised from 3; drives both the LR schedules and the training loops
    model_name = "huawei-noah/TinyBERT_General_4L_312D"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load and tokenize the three GLUE tasks.
    sst2_dataset = load_dataset_and_tokenize("sentiment", tokenizer)
    qqp_dataset = load_dataset_and_tokenize("paraphrase", tokenizer)
    stsb_dataset = load_dataset_and_tokenize("sts", tokenizer)

    batch_size = 16  # raised from 8
    dataloaders = {
        "sentiment": DataLoader(sst2_dataset["train"], batch_size=batch_size, shuffle=True),
        "paraphrase": DataLoader(qqp_dataset["train"], batch_size=batch_size, shuffle=True),
        "sts": DataLoader(stsb_dataset["train"], batch_size=batch_size, shuffle=True),
    }
    val_dataloaders = {
        "sentiment": DataLoader(sst2_dataset["validation"], batch_size=batch_size),
        "paraphrase": DataLoader(qqp_dataset["validation"], batch_size=batch_size),
        "sts": DataLoader(stsb_dataset["validation"], batch_size=batch_size),
    }

    model = MultiTaskTinyBERT(model_name=model_name)

    # ---- Stage 1: joint multi-task training ----
    total_steps = num_epochs * sum(len(dl) for dl in dataloaders.values())
    optimizer, scheduler = _build_optimizer_and_scheduler(model, total_steps)

    print("\n===== HUẤN LUYỆN MÔ HÌNH MTL BAN ĐẦU =====")
    train_multitask(model, dataloaders, val_dataloaders, optimizer, scheduler, device,
                    num_epochs=num_epochs, patience=2)

    # Restore the best epoch that train_multitask wrote to disk.
    model.load_state_dict(torch.load("best_mtl_model.pt", map_location=device))

    sentiment_acc = evaluate_classification(model, val_dataloaders["sentiment"], "sentiment", device)
    paraphrase_acc = evaluate_classification(model, val_dataloaders["paraphrase"], "paraphrase", device)
    sts_pearson = evaluate_sts(model, val_dataloaders["sts"], device)

    print("=== KẾT QUẢ VALIDATION SAU MTL BAN ĐẦU ===")
    print(f"Sentiment ACC: {sentiment_acc:.4f}")
    print(f"Paraphrase ACC: {paraphrase_acc:.4f}")
    print(f"STS Pearson: {sts_pearson:.4f}")

    mtl_checkpoint = "multi_task_tinybert.pt"
    torch.save(model.state_dict(), mtl_checkpoint)
    print("Model đa nhiệm đã được lưu thành công!\n")

    # ---- Stage 2: per-task fine-tuning ----
    # Every task starts from the same MTL weights, so the three experts share
    # one initialisation before they are averaged in stage 3. Previously each
    # stage reloaded the checkpoint the preceding stage had just written. That
    # was the state the model already held, so the "experts" were successive
    # snapshots of one sequential chain.
    print("===== FINE-TUNING SENTIMENT =====")
    sentiment_path = _fine_tune_from_checkpoint(
        model, mtl_checkpoint, "sentiment", dataloaders["sentiment"], device, num_epochs)
    sentiment_acc_ft = evaluate_classification(model, val_dataloaders["sentiment"], "sentiment", device)
    print(f"Fine-Tuned Sentiment - Validation Accuracy: {sentiment_acc_ft:.4f}\n")

    print("===== FINE-TUNING PARAPHRASE =====")
    paraphrase_path = _fine_tune_from_checkpoint(
        model, mtl_checkpoint, "paraphrase", dataloaders["paraphrase"], device, num_epochs)
    paraphrase_acc_ft = evaluate_classification(model, val_dataloaders["paraphrase"], "paraphrase", device)
    print(f"Fine-Tuned Paraphrase - Validation Accuracy: {paraphrase_acc_ft:.4f}\n")

    print("===== FINE-TUNING STS =====")
    sts_path = _fine_tune_from_checkpoint(
        model, mtl_checkpoint, "sts", dataloaders["sts"], device, num_epochs)
    sts_pearson_ft = evaluate_sts(model, val_dataloaders["sts"], device)
    print(f"Fine-Tuned STS - Validation Pearson Correlation: {sts_pearson_ft:.4f}\n")

    # ---- Stage 3: Pi-tuning (uniform weight averaging of the task experts) ----
    print("===== PI-TUNING =====")
    expert_state_dicts = [torch.load(path, map_location=device)
                          for path in (sentiment_path, paraphrase_path, sts_path)]
    pi_tuned_weights = _average_state_dicts(expert_state_dicts)

    torch.save(pi_tuned_weights, "tinybert_pi_tuned.pt")
    print("Mô hình Pi-Tuned đã được lưu thành công!")

    model.load_state_dict(pi_tuned_weights)
    sentiment_acc_pi = evaluate_classification(model, val_dataloaders["sentiment"], "sentiment", device)
    paraphrase_acc_pi = evaluate_classification(model, val_dataloaders["paraphrase"], "paraphrase", device)
    sts_pearson_pi = evaluate_sts(model, val_dataloaders["sts"], device)

    print("=== KẾT QUẢ VALIDATION SAU PI-TUNING ===")
    print(f"Pi-Tuned Sentiment ACC: {sentiment_acc_pi:.4f}")
    print(f"Pi-Tuned Paraphrase ACC: {paraphrase_acc_pi:.4f}")
    print(f"Pi-Tuned STS Pearson: {sts_pearson_pi:.4f}")

# Script entry point. Inside a notebook kernel `__name__` is also "__main__",
# so executing this cell runs the whole pipeline defined in main().
if __name__ == "__main__":
    main()

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

train-00000-of-00001.parquet:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/3.73M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/36.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/363846 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/40430 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/390965 [00:00<?, ? examples/s]

Map:   0%|          | 0/363846 [00:00<?, ? examples/s]

Map:   0%|          | 0/40430 [00:00<?, ? examples/s]

Map:   0%|          | 0/390965 [00:00<?, ? examples/s]

train-00000-of-00001.parquet:   0%|          | 0.00/502k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/151k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/114k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5749 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1379 [00:00<?, ? examples/s]

Map:   0%|          | 0/5749 [00:00<?, ? examples/s]

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Map:   0%|          | 0/1379 [00:00<?, ? examples/s]

pytorch_model.bin:   0%|          | 0.00/62.7M [00:00<?, ?B/s]




===== HUẤN LUYỆN MÔ HÌNH MTL BAN ĐẦU =====


  scaler = GradScaler()  # Cho mixed precision
  with autocast():  # Mixed precision


model.safetensors:   0%|          | 0.00/62.7M [00:00<?, ?B/s]

Epoch 1/5 - Loss: 0.3805
Validation - Sentiment ACC: 0.7294, Paraphrase ACC: 0.8075, STS Pearson: 0.8609, Avg Score: 0.7993
Epoch 2/5 - Loss: 0.3090
Validation - Sentiment ACC: 0.8280, Paraphrase ACC: 0.8426, STS Pearson: 0.8660, Avg Score: 0.8455
Epoch 3/5 - Loss: 0.2772
Validation - Sentiment ACC: 0.8406, Paraphrase ACC: 0.8683, STS Pearson: 0.8690, Avg Score: 0.8593
Epoch 4/5 - Loss: 0.2578
Validation - Sentiment ACC: 0.8704, Paraphrase ACC: 0.8737, STS Pearson: 0.8624, Avg Score: 0.8688
Epoch 5/5 - Loss: 0.2462
Validation - Sentiment ACC: 0.8922, Paraphrase ACC: 0.8739, STS Pearson: 0.8381, Avg Score: 0.8681
=== KẾT QUẢ VALIDATION SAU MTL BAN ĐẦU ===
Sentiment ACC: 0.8704
Paraphrase ACC: 0.8737
STS Pearson: 0.8624
Model đa nhiệm đã được lưu thành công!

===== FINE-TUNING SENTIMENT =====


  scaler = GradScaler()  # Cho mixed precision
  with autocast():


Task sentiment - Epoch 1/5 - Loss: 0.1617
Task sentiment - Epoch 2/5 - Loss: 0.1358
Task sentiment - Epoch 3/5 - Loss: 0.1178
Task sentiment - Epoch 4/5 - Loss: 0.1075
Task sentiment - Epoch 5/5 - Loss: 0.0994
Model fine-tuned cho task sentiment đã được lưu!
Fine-Tuned Sentiment - Validation Accuracy: 0.8922

===== FINE-TUNING PARAPHRASE =====
Task paraphrase - Epoch 1/5 - Loss: 0.2721
Task paraphrase - Epoch 2/5 - Loss: 0.2466
Task paraphrase - Epoch 3/5 - Loss: 0.2275
Task paraphrase - Epoch 4/5 - Loss: 0.2123
Task paraphrase - Epoch 5/5 - Loss: 0.2023
Model fine-tuned cho task paraphrase đã được lưu!
Fine-Tuned Paraphrase - Validation Accuracy: 0.8822

===== FINE-TUNING STS =====
Task sts - Epoch 1/5 - Loss: 0.7078
Task sts - Epoch 2/5 - Loss: 0.5419
Task sts - Epoch 3/5 - Loss: 0.4896
Task sts - Epoch 4/5 - Loss: 0.4510
Task sts - Epoch 5/5 - Loss: 0.4330
Model fine-tuned cho task sts đã được lưu!
Fine-Tuned STS - Validation Pearson Correlation: 0.8690

===== PI-TUNING =====
Mô hìn