In [1]:
import gc
import itertools
import json
import os
import random

import numpy as np
import torch
import torch.nn as nn
import yaml
from sklearn.metrics import f1_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import get_linear_schedule_with_warmup, AutoTokenizer

from utils.collator import RBERTCollator, BERTESCollator
from datasets.rbert_dataset import RBERTDataset
from datasets.bert_es_dataset import BERTESDataset
from encoder.vihealth_encoder import ViHealthBERTEncoder
from models.r_bert import RBERT
from models.bert_es import BERTES

def load_config(path="config.yaml"):
    """Load the experiment configuration from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to ``config.yaml`` in the
            current working directory.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    # The config file carries non-ASCII comments, so decode it explicitly as
    # UTF-8 instead of relying on the platform's default text encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Seed value handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_device():
    """Pick the torch device to run on.

    Returns:
        ``torch.device("mps")`` when the MPS backend is available, otherwise
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)

def clean_memory():
    """Free unused memory to keep the notebook from crashing.

    Runs the Python garbage collector, then empties the allocator cache of
    the MPS backend if it is available, or of CUDA otherwise (when
    available), and finally prints a confirmation message.
    """
    gc.collect()

    if not torch.backends.mps.is_available():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    else:
        torch.mps.empty_cache()

    print("üßπ Memory cleaned!")

In [17]:
%%writefile config.yaml
# --- General configuration ---
project_name: "medical_re_optimization_notebook"
seed: 42
output_dir: "./outputs_tuning"

# --- Model & Data ---
model_type: "rbert"         # "rbert" or "bertes"
encoder_type: "vihealth"    # "vihealth"
num_labels: 5

# --- Resources & training (fixed) ---
fixed_params:
  max_epochs: 5             # Keep the search quick
  patience: 2               # Stop early
  grad_clip_norm: 1.0
  accumulation_steps: 4     # Important for low-RAM machines (effective batch = batch_size * 4)

# --- Search space (grid search) ---
search_space:
  learning_rate: [1.0e-5, 2.0e-5, 3.0e-5]
  batch_size: [8]           # Keep at 8 to stay within RAM
  dropout_rate: [0.1, 0.2]
  warmup_ratio: [0.1]
  max_length: [256]         # Lower to 128 if OOM still occurs
  weight_decay: [0.01]

Writing config.yaml


In [None]:
def build_components(config, hparams, tokenizer=None):
    """Build the model, datasets, collator and tokenizer for one run.

    Args:
        config: Experiment configuration. Reads ``model_type``,
            ``encoder_type`` and ``num_labels``. The optional keys
            ``train_path`` and ``dev_path`` override the default dataset
            locations (``data/train.json`` and ``data/dev.json``).
        hparams: Hyper-parameters of the run. Reads ``max_length`` and
            ``dropout_rate``.
        tokenizer: Optional tokenizer to reuse. When ``None``, one is loaded
            for the selected encoder and the four entity marker tokens are
            registered as additional special tokens.

    Returns:
        Tuple ``(model, train_dataset, val_dataset, collator, tokenizer)``.

    Raises:
        NotImplementedError: If ``config['encoder_type']`` is not supported.
        ValueError: If ``config['model_type']`` is not recognised.
    """
    model_type = config['model_type']
    encoder_type = config['encoder_type']

    # 1. Choose the encoder class
    if encoder_type == "vihealth":
        EncoderClass = ViHealthBERTEncoder
        model_name = "demdecuong/vihealthbert-base-word"
    else:
        raise NotImplementedError(f"Encoder type {encoder_type} ch∆∞a ƒë∆∞·ª£c h·ªó tr·ª£ trong notebook n√†y.")

    # 2. Choose the model / dataset / collator classes that belong together
    if model_type == "rbert":
        ModelClass = RBERT
        DatasetClass = RBERTDataset
        CollatorClass = RBERTCollator
    elif model_type == "bertes":
        ModelClass = BERTES
        DatasetClass = BERTESDataset
        CollatorClass = BERTESCollator
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # 3. Load the tokenizer only if the caller did not supply one, so a
    #    single tokenizer can be shared across runs.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        tokenizer.add_special_tokens({'additional_special_tokens': ["<e1>", "</e1>", "<e2>", "</e2>"]})

    # 4. Datasets: paths are configurable but default to the original
    #    locations, so existing configs keep working unchanged.
    train_path = config.get('train_path', "data/train.json")
    dev_path = config.get('dev_path', "data/dev.json")

    train_dataset = DatasetClass(
        json_path=train_path,
        tokenizer=tokenizer,
        max_length=hparams['max_length']
    )
    val_dataset = DatasetClass(
        json_path=dev_path,
        tokenizer=tokenizer,
        max_length=hparams['max_length']
    )
    collator = CollatorClass(tokenizer)

    # 5. Encoder: when its own tokenizer has a different vocabulary size
    #    than ours, swap in ours and resize the embedding matrix to match.
    encoder = EncoderClass(model_name=model_name)

    if len(encoder.tokenizer) != len(tokenizer):
        encoder.tokenizer = tokenizer
        encoder.model.resize_token_embeddings(len(tokenizer))

    # 6. Assemble the model around the encoder
    model = ModelClass(
        encoder=encoder,
        hidden_size=encoder.hidden_size,
        num_labels=config['num_labels'],
        dropout_rate=hparams['dropout_rate']
    )

    return model, train_dataset, val_dataset, collator, tokenizer

In [None]:
def evaluate(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader``.

    The model is put in eval mode and run under ``torch.no_grad()``. An
    ``RBERT`` model is called with the entity masks and the labels and must
    return a mapping with ``'logits'`` and ``'loss'``; any other model is
    called with the entity positions and must return logits, from which a
    cross-entropy loss is computed here.

    Args:
        model: Model to evaluate.
        dataloader: Iterable of batch dicts holding ``input_ids``,
            ``attention_mask``, ``labels`` and either ``e1_mask``/``e2_mask``
            or ``e1_pos``/``e2_pos``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(macro_f1, avg_loss)`` where ``avg_loss`` is the mean of the
        per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    preds = []
    labels_list = []
    total_loss = 0.0
    # Count batches ourselves: avoids dividing by zero on empty input and
    # does not require the iterable to implement __len__.
    num_batches = 0
    loss_fct = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if isinstance(model, RBERT):
                # RBERT takes entity masks and computes its own loss
                outputs = model(
                    input_ids, attention_mask,
                    batch['e1_mask'].to(device),
                    batch['e2_mask'].to(device),
                    labels
                )
                logits = outputs['logits']
                loss = outputs['loss']
            else:
                # BERT-ES takes entity positions and returns logits only
                logits = model(
                    input_ids, attention_mask,
                    batch['e1_pos'].to(device),
                    batch['e2_pos'].to(device)
                )
                loss = loss_fct(logits, labels)

            total_loss += loss.item()
            num_batches += 1
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            labels_list.extend(labels.cpu().numpy())

    if num_batches == 0:
        raise ValueError("evaluate() received an empty dataloader; cannot compute metrics.")

    macro_f1 = f1_score(labels_list, preds, average='macro')
    avg_loss = total_loss / num_batches
    return macro_f1, avg_loss

def run_session(config, hparams, device, run_id):
    """Train and validate one hyper-parameter configuration.

    Builds fresh components, trains with gradient accumulation and a linear
    warmup/decay schedule, evaluates on the validation set after each epoch
    and stops early when macro-F1 has not improved for ``patience`` epochs.
    The large objects are deleted and memory is cleaned before returning.

    Args:
        config: Experiment configuration. Reads ``fixed_params`` with the
            keys ``max_epochs``, ``patience``, ``grad_clip_norm`` and
            ``accumulation_steps`` (plus whatever ``build_components`` reads).
        hparams: Hyper-parameters of this run. Reads ``batch_size``,
            ``learning_rate``, ``weight_decay`` and ``warmup_ratio``.
        device: Device the model and batches are moved to.
        run_id: Identifier used only in the log line.

    Returns:
        The best validation macro-F1 seen across the epochs that were run.
    """
    print(f"\n>>> [Run {run_id}] Config: {hparams}")

    # 1. Build components (the returned tokenizer is not needed here)
    model, train_ds, val_ds, collator, _ = build_components(config, hparams)
    model.to(device)

    # 2. Dataloaders (num_workers=0 is the safest choice on notebook/MPS)
    train_loader = DataLoader(train_ds, batch_size=hparams['batch_size'], shuffle=True, collate_fn=collator, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=hparams['batch_size'], shuffle=False, collate_fn=collator, num_workers=0)

    # 3. Optimizer & scheduler
    optimizer = AdamW(model.parameters(), lr=hparams['learning_rate'], weight_decay=hparams['weight_decay'])

    fixed = config['fixed_params']
    accum_steps = fixed['accumulation_steps']
    max_epochs = fixed['max_epochs']
    patience = fixed['patience']

    # The scheduler advances once per optimizer update, not once per
    # micro-batch, so size it by the number of updates. Sizing it by
    # micro-batches stretches warmup and never completes the decay.
    num_batches = len(train_loader)
    updates_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    total_steps = updates_per_epoch * max_epochs

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * hparams['warmup_ratio']),
        num_training_steps=total_steps
    )

    # 4. Training loop
    loss_fct = nn.CrossEntropyLoss()  # only used for non-RBERT models
    best_f1 = 0.0
    patience_counter = 0

    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            if isinstance(model, RBERT):
                outputs = model(input_ids, attention_mask, batch['e1_mask'].to(device), batch['e2_mask'].to(device), labels)
                loss = outputs['loss']
            else:
                logits = model(input_ids, attention_mask, batch['e1_pos'].to(device), batch['e2_pos'].to(device))
                loss = loss_fct(logits, labels)

            # Gradient accumulation: scale so the summed gradient of a full
            # window matches the mean over its micro-batches.
            (loss / accum_steps).backward()

            # Also update on the last batch of the epoch, otherwise the
            # gradients of a trailing partial window would be thrown away by
            # the zero_grad() at the start of the next epoch.
            window_full = (step + 1) % accum_steps == 0
            last_batch = (step + 1) == num_batches
            if window_full or last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), fixed['grad_clip_norm'])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        # Evaluate at the end of the epoch
        val_f1, val_loss = evaluate(model, val_loader, device)
        print(f"   Epoch {epoch+1}: F1={val_f1:.4f} | Loss={val_loss:.4f}")

        # Simple early-stopping check
        if val_f1 > best_f1:
            best_f1 = val_f1
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"   Early stopping t·∫°i epoch {epoch+1}")
                break

    # Important: drop the references to the large objects so the garbage
    # collector can reclaim them before the next run starts.
    del model, optimizer, scheduler, train_loader, val_loader, train_ds, val_ds
    clean_memory()

    return best_f1

In [None]:
# 1. Initial setup
clean_memory()  # Free memory before starting
cfg = load_config()
set_seed(cfg['seed'])
device = get_device()
print(f"Using Device: {device}")

# exist_ok makes this a no-op when the directory is already there and avoids
# the check-then-create race.
os.makedirs(cfg['output_dir'], exist_ok=True)

# 2. Build the search space (grid search): one dict per combination in the
#    Cartesian product of the value lists.
keys, values = zip(*cfg['search_space'].items())
search_space = [dict(zip(keys, v)) for v in itertools.product(*values)]
print(f"Total configurations to test: {len(search_space)}")

# 3. Run the search loop
results = []
best_search_score = -1
best_hparams = None

# tqdm shows overall progress
for i, hparams in enumerate(tqdm(search_space, desc="Grid Search Progress")):
    # Re-seed before every run so each configuration starts from the same
    # RNG state; otherwise a score also depends on the runs before it.
    set_seed(cfg['seed'])

    oom_hit = False
    try:
        # Train one configuration
        score = run_session(cfg, hparams, device, run_id=i+1)

        # Track the best result
        if score > best_search_score:
            best_search_score = score
            best_hparams = hparams
            print(f"NEW BEST FOUND: {score:.4f} with {hparams}")

        results.append({
            "run_id": i+1,
            "hparams": hparams,
            "best_f1": score
        })

    except RuntimeError as e:
        if "out of memory" not in str(e):
            # Bare raise re-raises the active exception unchanged
            raise
        print(f"OOM Error at Run {i+1}. Skipping config...")
        oom_hit = True

    if oom_hit:
        # Clean up only after the handler has finished. While it is running,
        # the exception's traceback still references the frames of the failed
        # run, so collecting inside the handler cannot release what they hold.
        clean_memory()

print("\n" + "="*40)
print(f"‚úÖ SEARCH FINISHED.")
print(f"Best Macro F1: {best_search_score:.4f}")
print(f"Best Hparams: {best_hparams}")
print("="*40)

üßπ Memory cleaned!
Using Device: mps
Total configurations to test: 6


Grid Search Progress:   0%|          | 0/6 [00:00<?, ?it/s]


>>> [Run 1] Config: {'learning_rate': 1e-05, 'batch_size': 8, 'dropout_rate': 0.1, 'warmup_ratio': 0.1, 'max_length': 256, 'weight_decay': 0.01}
‚ùå OOM Error at Run 1. Skipping config...
üßπ Memory cleaned!

>>> [Run 2] Config: {'learning_rate': 1e-05, 'batch_size': 8, 'dropout_rate': 0.2, 'warmup_ratio': 0.1, 'max_length': 256, 'weight_decay': 0.01}
‚ùå OOM Error at Run 2. Skipping config...
üßπ Memory cleaned!

>>> [Run 3] Config: {'learning_rate': 2e-05, 'batch_size': 8, 'dropout_rate': 0.1, 'warmup_ratio': 0.1, 'max_length': 256, 'weight_decay': 0.01}
‚ùå OOM Error at Run 3. Skipping config...
üßπ Memory cleaned!

>>> [Run 4] Config: {'learning_rate': 2e-05, 'batch_size': 8, 'dropout_rate': 0.2, 'warmup_ratio': 0.1, 'max_length': 256, 'weight_decay': 0.01}
‚ùå OOM Error at Run 4. Skipping config...
üßπ Memory cleaned!

>>> [Run 5] Config: {'learning_rate': 3e-05, 'batch_size': 8, 'dropout_rate': 0.1, 'warmup_ratio': 0.1, 'max_length': 256, 'weight_decay': 0.01}
‚ùå OOM Error 

In [None]:
if best_hparams is None:
    print("Kh√¥ng t√¨m th·∫•y c·∫•u h√¨nh n√†o ch·∫°y th√†nh c√¥ng. Vui l√≤ng ki·ªÉm tra l·∫°i.")
else:
    print(f"üöÄ Retraining FINAL MODEL with best params: {best_hparams}")
    clean_memory()

    # Train longer for the final fit than during the search
    cfg['fixed_params']['max_epochs'] = 10
    fixed = cfg['fixed_params']
    max_epochs = fixed['max_epochs']
    accum_steps = fixed['accumulation_steps']

    # 1. Rebuild the components with the best hyper-parameters
    model, train_ds, val_ds, collator, tokenizer = build_components(cfg, best_hparams)
    model.to(device)

    # 2. Set up the training components
    train_loader = DataLoader(train_ds, batch_size=best_hparams['batch_size'], shuffle=True, collate_fn=collator, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=best_hparams['batch_size'], shuffle=False, collate_fn=collator, num_workers=0)

    optimizer = AdamW(model.parameters(), lr=best_hparams['learning_rate'], weight_decay=best_hparams['weight_decay'])

    # The scheduler advances once per optimizer update, not once per
    # micro-batch, so size it by the number of updates. Sizing it by
    # micro-batches stretches warmup and never completes the decay.
    num_batches = len(train_loader)
    updates_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    total_steps = updates_per_epoch * max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * best_hparams['warmup_ratio']),
        num_training_steps=total_steps
    )

    loss_fct = nn.CrossEntropyLoss()  # only used for non-RBERT models
    best_final_f1 = 0.0

    # 3. Final training loop
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        progress_bar = tqdm(train_loader, desc=f"Final Epoch {epoch+1}", leave=False)

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if isinstance(model, RBERT):
                outputs = model(input_ids, attention_mask, batch['e1_mask'].to(device), batch['e2_mask'].to(device), labels)
                loss = outputs['loss']
            else:
                logits = model(input_ids, attention_mask, batch['e1_pos'].to(device), batch['e2_pos'].to(device))
                loss = loss_fct(logits, labels)

            # Accumulate the unscaled loss of every micro-batch so that
            # train_loss / (step + 1) is a true running mean.
            train_loss += loss.item()

            # Scale so the summed gradient of a full window matches the mean
            # over its micro-batches.
            (loss / accum_steps).backward()

            # Also update on the last batch of the epoch, otherwise the
            # gradients of a trailing partial window would be thrown away by
            # the zero_grad() at the start of the next epoch.
            window_full = (step + 1) % accum_steps == 0
            last_batch = (step + 1) == num_batches
            if window_full or last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), fixed['grad_clip_norm'])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.set_postfix({'loss': train_loss / (step + 1)})

        # Evaluate and save the best checkpoint
        val_f1, val_loss = evaluate(model, val_loader, device)
        print(f"   Epoch {epoch+1}: F1={val_f1:.4f} | Loss={val_loss:.4f}")

        if val_f1 > best_final_f1:
            best_final_f1 = val_f1
            print(f"   üíæ Saving new best model to {cfg['output_dir']}...")

            # Persist the encoder through its own save_pretrained method
            # (presumably tokenizer + encoder config; confirm in the encoder class).
            model.encoder.save_pretrained(cfg['output_dir'])

            # Save the full model state dict (the original note says this
            # includes the classifier head).
            torch.save(model.state_dict(), os.path.join(cfg['output_dir'], "best_model.pth"))

            # Save the best hyper-parameters for later reuse
            with open(os.path.join(cfg['output_dir'], "best_config.json"), 'w') as f:
                json.dump(best_hparams, f, indent=4)

    print(f"\n‚ú® DONE. Final Best F1: {best_final_f1:.4f}")