In [20]:
import torch
print(torch.cuda.is_available()) # Must print True, i.e. PyTorch can see a CUDA device

True


In [21]:
import os
# Debugging aid: ask CUDA to run kernel launches synchronously so that a device
# error is reported at the call that caused it.
# NOTE(review): this usually slows training down; consider removing it once the
# notebook is stable.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [22]:
# Fetch the project repository (presumably the source of the utils/, datasets/,
# encoder/ and models/ packages imported further down — confirm).
!git clone https://github.com/NguyenPhanNhatLan/medical_re.git

Cloning into 'medical_re'...
remote: Enumerating objects: 305, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 305 (delta 66), reused 44 (delta 34), pack-reused 208 (from 1)[K
Receiving objects: 100% (305/305), 6.78 MiB | 3.98 MiB/s, done.
Resolving deltas: 100% (150/150), done.


In [23]:
# Work from inside the clone so project-relative imports and data/ paths resolve.
# NOTE(review): the recorded output is /content/medical_re/medical_re, which
# suggests this cell was run from inside an earlier clone; re-running clone + cd
# is not idempotent — confirm on a fresh runtime.
%cd medical_re

/content/medical_re/medical_re


In [24]:
import os
import gc
import yaml
import torch
import torch.nn as nn
import numpy as np
import random
import itertools
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup, AutoTokenizer
from torch.optim import AdamW

from utils.collator import RBERTCollator, BERTESCollator
from datasets.rbert_dataset import RBERTDataset
from datasets.bert_es_dataset import BERTESDataset
from encoder.vihealth_encoder import ViHealthBERTEncoder
from models.r_bert import RBERT
from models.bert_es import BERTES

def load_config(path="config.yaml"):
    """Read a YAML configuration file and return its parsed content.

    Args:
        path: Location of the YAML file. Defaults to "config.yaml" relative to
            the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a dict for a
        mapping document).
    """
    # Decode explicitly as UTF-8: the config file carries non-ASCII comments and
    # the platform default text encoding is not UTF-8 everywhere.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def set_seed(seed):
    """Seed the random number generators used by this notebook.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG and
    PyTorch (``torch.manual_seed``); when CUDA is available it additionally
    calls ``torch.cuda.manual_seed_all``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_device():
    """Pick the torch device to run on.

    Preference order: Apple MPS, then CUDA, then CPU.

    Returns:
        A ``torch.device`` for the first backend that reports itself available.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

def clean_memory():
    """Free unused memory to keep the notebook from crashing.

    Runs the Python garbage collector, then empties the cache of the active
    accelerator backend (MPS if available, otherwise CUDA if available), and
    prints a confirmation line.
    """
    gc.collect()
    mps_ready = torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.empty_cache()
    if not mps_ready and torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("üßπ Memory cleaned!")

In [25]:
%%writefile config.yaml
# --- General settings ---
project_name: "medical_re_optimization_notebook"
seed: 42
output_dir: "./outputs_tuning"

# --- Model & Data ---
model_type: "rbert"         # "rbert" or "bertes"
encoder_type: "vihealth"    # "vihealth"
num_labels: 5

# --- Resources & Training (fixed) ---
# NOTE(review): `patience` is not read by any code visible in this notebook — confirm.
fixed_params:
  max_epochs: 5             # Quick search
  patience: 2               # Stop early
  grad_clip_norm: 1.0
  accumulation_steps: 4     # Important for machines with little RAM (effective batch = batch_size * 4)

# --- Search space (Grid Search) ---
# NOTE(review): the visible Optuna objective defines its own ranges and does not
# appear to read this section — confirm whether it is still needed.
search_space:
  learning_rate: [1.0e-5, 2.0e-5, 3.0e-5]
  batch_size: [8]           # Kept at 8 to stay safe on RAM
  dropout_rate: [0.1, 0.15]
  warmup_ratio: [0.1]
  max_length: [256]         # Lower to 128 if out-of-memory errors persist
  weight_decay: [0.01]

Overwriting config.yaml


In [31]:
def build_components(config, hparams, tokenizer=None,
                     train_path="data/1_train_model.json",
                     dev_path="data/2_dev_model.json"):
    """Assemble the model, datasets, collator and tokenizer for one run.

    Args:
        config: Mapping with at least ``model_type`` ("rbert" or "bertes"),
            ``encoder_type`` ("vihealth") and ``num_labels``.
        hparams: Mapping with at least ``max_length`` and ``dropout_rate``.
        tokenizer: Optional tokenizer to reuse. When ``None`` a new one is
            loaded and the four entity markers are registered on it. A tokenizer
            passed in is used as-is (presumably it already contains the entity
            markers — verify at the call site).
        train_path: JSON file used for the training dataset.
        dev_path: JSON file used for the validation dataset.

    Returns:
        Tuple ``(model, train_dataset, val_dataset, collator, tokenizer)``.

    Raises:
        NotImplementedError: if ``encoder_type`` is not supported.
        ValueError: if ``model_type`` is not recognised.
    """
    model_type = config['model_type']
    encoder_type = config['encoder_type']

    # 1. Pick the encoder class and its pretrained checkpoint name.
    if encoder_type == "vihealth":
        EncoderClass = ViHealthBERTEncoder
        model_name = "demdecuong/vihealthbert-base-word"
    else:
        raise NotImplementedError(f"Encoder type {encoder_type} ch∆∞a ƒë∆∞·ª£c h·ªó tr·ª£ trong notebook n√†y.")

    # 2. Pick the model together with its matching dataset and collator.
    if model_type == "rbert":
        ModelClass = RBERT
        DatasetClass = RBERTDataset
        CollatorClass = RBERTCollator
    elif model_type == "bertes":
        ModelClass = BERTES
        DatasetClass = BERTESDataset
        CollatorClass = BERTESCollator
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # 3. Load the tokenizer only when the caller did not supply one; loading it
    #    once and sharing it saves RAM and time.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        tokenizer.add_special_tokens({'additional_special_tokens': ["<e1>", "</e1>", "<e2>", "</e2>"]})

    # 4. Datasets and collator share the same tokenizer and sequence length.
    max_length = hparams['max_length']
    train_dataset = DatasetClass(
        json_path=train_path,
        tokenizer=tokenizer,
        max_length=max_length
    )
    val_dataset = DatasetClass(
        json_path=dev_path,
        tokenizer=tokenizer,
        max_length=max_length
    )
    collator = CollatorClass(tokenizer)

    # 5. Encoder: if its own tokenizer has a different vocabulary size, swap in
    #    ours and resize the embedding matrix to match.
    encoder = EncoderClass(model_name=model_name)

    if len(encoder.tokenizer) != len(tokenizer):
        encoder.tokenizer = tokenizer
        encoder.model.resize_token_embeddings(len(tokenizer))

    # 6. Classification model on top of the encoder.
    model = ModelClass(
        encoder=encoder,
        hidden_size=encoder.hidden_size,
        num_labels=config['num_labels'],
        dropout_rate=hparams['dropout_rate']
    )

    return model, train_dataset, val_dataset, collator, tokenizer

In [32]:
from torch.cuda.amp import autocast
from sklearn.metrics import precision_recall_fscore_support

def evaluate_refined(model, dataloader, device, amp_enabled=True):
    """Evaluate ``model`` on ``dataloader`` and return micro-F1 and mean loss.

    Args:
        model: Model to evaluate. RBERT instances are called with entity masks
            and labels and are expected to return a dict with ``logits`` and
            ``loss``; any other model is called with entity positions and is
            expected to return logits.
        dataloader: Iterable of batches (dicts holding ``input_ids``,
            ``attention_mask``, ``labels`` and the entity fields).
        device: Device the batch tensors are moved to.
        amp_enabled: Request CUDA autocast during the forward pass. Ignored
            when ``device`` is not a CUDA device.

    Returns:
        Tuple ``(f1, avg_loss)``: micro-averaged F1 over all classes and the
        loss averaged over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    preds = []
    labels_list = []
    total_loss = 0.0
    num_batches = 0
    loss_fct = nn.CrossEntropyLoss()
    # torch.cuda.amp.autocast only applies to CUDA; requesting it for cpu/mps
    # just produces a warning, so enable it only where it has an effect.
    use_amp = bool(amp_enabled) and torch.device(device).type == "cuda"

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Mixed-precision forward pass when enabled.
            with autocast(enabled=use_amp):
                if isinstance(model, RBERT):
                    outputs = model(
                        input_ids, attention_mask,
                        batch['e1_mask'].to(device),
                        batch['e2_mask'].to(device),
                        labels
                    )
                    logits = outputs['logits']
                    loss = outputs['loss']
                else:
                    logits = model(
                        input_ids, attention_mask,
                        batch['e1_pos'].to(device),
                        batch['e2_pos'].to(device)
                    )
                    loss = loss_fct(logits, labels)

            total_loss += loss.item()
            num_batches += 1
            preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            labels_list.extend(labels.cpu().tolist())

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("evaluate_refined received an empty dataloader; nothing to evaluate.")

    # Micro-F1 over every class.
    # NOTE(review): if class 0 is "Other/No Relation", consider passing
    # labels=[1, 2, 3, 4] to exclude it from the metric.
    _, _, f1, _ = precision_recall_fscore_support(labels_list, preds, average='micro')

    avg_loss = total_loss / num_batches
    return f1, avg_loss

In [28]:
!pip install optuna



In [33]:
import optuna
from torch.cuda.amp import GradScaler, autocast

def objective_optimized(trial):
    """Optuna objective: train one sampled configuration, return its best dev F1.

    Reads the notebook-level globals ``cfg`` (loaded config) and ``device``.

    Args:
        trial: The ``optuna`` trial used to sample hyper-parameters, report the
            per-epoch dev F1 and decide on pruning.

    Returns:
        The highest dev-set F1 seen over the epochs of this trial.

    Raises:
        optuna.TrialPruned: when the pruner asks to stop the trial early.
    """
    hparams = {
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [16, 32]),  # larger batches because FP16 is used
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.3),
        'warmup_ratio': trial.suggest_float('warmup_ratio', 0.05, 0.2),
        'weight_decay': trial.suggest_float('weight_decay', 0.01, 0.1),
        'max_length': 256
    }

    # Fixed settings
    accum_steps = cfg['fixed_params']['accumulation_steps']
    grad_clip_norm = cfg['fixed_params'].get('grad_clip_norm', 1.0)
    max_epochs = 3
    # CUDA AMP and pinned memory only make sense on a CUDA device.
    on_cuda = torch.device(device).type == "cuda"

    # Pre-bind so the cleanup in `finally` works even if building fails.
    model = optimizer = scheduler = scaler = None
    try:
        # TODO: a tokenizer is loaded per trial; load it once outside the
        # objective and pass it in to speed the search up.
        model, train_ds, val_ds, collator, tokenizer = build_components(cfg, hparams)
        model.to(device)

        train_loader = DataLoader(train_ds, batch_size=hparams['batch_size'], shuffle=True,
                                  collate_fn=collator, num_workers=2, pin_memory=on_cuda)
        val_loader = DataLoader(val_ds, batch_size=hparams['batch_size'], shuffle=False,
                                collate_fn=collator, num_workers=2, pin_memory=on_cuda)

        optimizer = AdamW(model.parameters(), lr=hparams['learning_rate'], weight_decay=hparams['weight_decay'])
        scaler = GradScaler(enabled=on_cuda)  # loss scaling for FP16

        # The scheduler advances once per optimizer step, i.e. once per
        # accumulation group, so its horizon is counted in optimizer steps.
        num_batches = len(train_loader)
        steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps
        total_steps = steps_per_epoch * max_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(total_steps * hparams['warmup_ratio']),
            num_training_steps=total_steps
        )

        best_f1_trial = 0.0

        for epoch in range(max_epochs):
            model.train()
            optimizer.zero_grad()
            for step, batch in enumerate(train_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                # --- Mixed-precision forward pass ---
                with autocast(enabled=on_cuda):
                    if isinstance(model, RBERT):
                        outputs = model(input_ids, attention_mask, batch['e1_mask'].to(device), batch['e2_mask'].to(device), labels)
                        loss = outputs['loss']
                    else:
                        logits = model(input_ids, attention_mask, batch['e1_pos'].to(device), batch['e2_pos'].to(device))
                        loss = nn.CrossEntropyLoss()(logits, labels)

                    loss = loss / accum_steps

                # Scale the loss and accumulate gradients.
                scaler.scale(loss).backward()

                # Step on every full accumulation group and also on the last
                # batch, so a trailing partial group is applied instead of
                # leaking into the next epoch.
                if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()

            # Evaluate on the DEV set (not the test set).
            val_f1, val_loss = evaluate_refined(model, val_loader, device, amp_enabled=on_cuda)

            trial.report(val_f1, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

            best_f1_trial = max(best_f1_trial, val_f1)

        return best_f1_trial
    finally:
        # Runs for finished, pruned and failed trials alike. The scheduler is
        # dropped too because it keeps a reference to the optimizer.
        del model, optimizer, scheduler, scaler
        clean_memory()

In [None]:
import optuna

# Setup
clean_memory()
cfg = load_config()
device = get_device()

# Seed the RNGs and the Optuna sampler so the search can be reproduced.
seed = cfg.get('seed', 42)
set_seed(seed)

# Create the study.
# direction="maximize" because a higher F1 is better.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=seed),
    pruner=optuna.pruners.MedianPruner(),
)

print("üöÄ STARTING OPTUNA STUDY...")
# n_trials=10: try 10 different configurations.
study.optimize(objective_optimized, n_trials=10)

print("\n" + "="*40)
print("‚úÖ OPTUNA FINISHED.")
print(f"Best value (F1): {study.best_value}")
print(f"Best params: {study.best_params}")
print("="*40)

# Keep the best params for the retraining cell.
best_hparams = dict(study.best_params)
# Add the fixed params the retraining code below needs.
best_hparams['max_length'] = 256
# NOTE(review): this replaces the batch size the study selected (16 or 32) with
# 8, which also changes the effective batch the tuned learning rate was found
# for — confirm this is intended (e.g. for memory in the FP32 retraining run).
best_hparams['batch_size'] = 8

[I 2026-01-22 02:29:39,947] A new study created in memory with name: no-name-b18ba3c8-78df-45b7-ae71-be696e246976


üßπ Memory cleaned!
üöÄ STARTING OPTUNA STUDY...


In [None]:
import json

if best_hparams is None:
    print("Kh√¥ng t√¨m th·∫•y c·∫•u h√¨nh n√†o ch·∫°y th√†nh c√¥ng. Vui l√≤ng ki·ªÉm tra l·∫°i.")
else:
    print(f"üöÄ Retraining FINAL MODEL with best params: {best_hparams}")
    clean_memory()

    # Train longer for the final run than during the search.
    cfg['fixed_params']['max_epochs'] = 10
    max_epochs = cfg['fixed_params']['max_epochs']
    accum_steps = cfg['fixed_params']['accumulation_steps']
    grad_clip_norm = cfg['fixed_params']['grad_clip_norm']
    output_dir = cfg['output_dir']

    # 1. Rebuild the model with the best hyper-parameters.
    model, train_ds, val_ds, collator, tokenizer = build_components(cfg, best_hparams)
    model.to(device)

    # 2. Set up the training components.
    train_loader = DataLoader(train_ds, batch_size=best_hparams['batch_size'], shuffle=True, collate_fn=collator, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=best_hparams['batch_size'], shuffle=False, collate_fn=collator, num_workers=0)

    optimizer = AdamW(model.parameters(), lr=best_hparams['learning_rate'], weight_decay=best_hparams['weight_decay'])

    # The scheduler advances once per optimizer step (one per accumulation
    # group), so its horizon is counted in optimizer steps, not batches.
    num_batches = len(train_loader)
    steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    total_steps = steps_per_epoch * max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * best_hparams['warmup_ratio']),
        num_training_steps=total_steps
    )

    best_final_f1 = 0.0
    # Ensure the checkpoint directory exists before anything is written to it.
    os.makedirs(output_dir, exist_ok=True)

    # 3. Final training loop.
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        progress_bar = tqdm(train_loader, desc=f"Final Epoch {epoch+1}", leave=False)

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if isinstance(model, RBERT):
                outputs = model(input_ids, attention_mask, batch['e1_mask'].to(device), batch['e2_mask'].to(device), labels)
                loss = outputs['loss']
            else:
                logits = model(input_ids, attention_mask, batch['e1_pos'].to(device), batch['e2_pos'].to(device))
                loss = nn.CrossEntropyLoss()(logits, labels)

            # Track the unscaled loss of every batch so the running mean shown
            # in the progress bar is a true per-batch average.
            train_loss += loss.item()

            (loss / accum_steps).backward()

            # Step on every full accumulation group and on the last batch, so a
            # trailing partial group is applied rather than discarded.
            if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.set_postfix({'loss': train_loss / (step + 1)})

        # Evaluate and keep the best checkpoint. Training here runs in full
        # precision, so evaluation does too.
        val_f1, val_loss = evaluate_refined(model, val_loader, device, amp_enabled=False)
        print(f"   Epoch {epoch+1}: F1={val_f1:.4f} | Loss={val_loss:.4f}")

        if val_f1 > best_final_f1:
            best_final_f1 = val_f1
            print(f"   üíæ Saving new best model to {cfg['output_dir']}...")

            # Save the encoder (presumably tokenizer and encoder config — the
            # encoder class is defined outside this notebook; confirm).
            model.encoder.save_pretrained(output_dir)

            # Save all model weights, including the classifier head.
            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))

            # Save the best hyper-parameters for later reuse.
            with open(os.path.join(output_dir, "best_config.json"), 'w') as f:
                json.dump(best_hparams, f, indent=4)

    print(f"\nDONE. Final Best F1: {best_final_f1:.4f}")