In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import math


In [2]:
# Constants: split file names shared by both branches, and where each branch stores them
TEST_EMBEDDING_FILE, TRAIN_EMBEDDING_FILE, VAL_EMBEDDING_FILE = (
    "test_embeddings.pt",
    "train_embeddings.pt",
    "val_embeddings.pt",
)

GNN_EMBEDDING_DIR = "../GATNN/embeddings"
BERT_EMBEDDING_DIR = "../BERT/embedding"

GNN_TEST_EMBEDDING = f"{GNN_EMBEDDING_DIR}/{TEST_EMBEDDING_FILE}"
GNN_TRAIN_EMBEDDING = f"{GNN_EMBEDDING_DIR}/{TRAIN_EMBEDDING_FILE}"
GNN_VAL_EMBEDDING = f"{GNN_EMBEDDING_DIR}/{VAL_EMBEDDING_FILE}"
BERT_TEST_EMBEDDING = f"{BERT_EMBEDDING_DIR}/{TEST_EMBEDDING_FILE}"
BERT_TRAIN_EMBEDDING = f"{BERT_EMBEDDING_DIR}/{TRAIN_EMBEDDING_FILE}"
BERT_VAL_EMBEDDING = f"{BERT_EMBEDDING_DIR}/{VAL_EMBEDDING_FILE}"



In [3]:
# Full (unsplit) embedding files for each branch
GNN_FILE, BERT_FILE = "gat_embeddings.pt", "all_embeddings.pt"
GNN_EMBEDDING = "../GATNN/embeddings/" + GNN_FILE
BERT_EMBEDDING = "../BERT/embedding/" + BERT_FILE

In [4]:
def load_embedding_file(path):
    """
    Load an embedding file written with `torch.save` and normalise it to a dict.

    Supported payloads:
    - torch.Tensor                      -> {'embeddings': tensor}
    - numpy.ndarray                     -> {'embeddings': torch.tensor(array)}
    - dict with one of the keys 'embeddings', 'emb', 'vectors', 'features'
      (checked in that order)           -> {'embeddings': tensor, 'ids': ids}
      where `ids` is data['mol_ids'] when that value is not None, otherwise
      data.get('mol_id') (None when neither key exists)
    - any other dict, list or tuple that torch.as_tensor accepts
                                        -> {'embeddings': tensor}

    Raises:
        RuntimeError: if torch.load fails.
        ValueError: if the payload cannot be turned into a tensor or has an
            unsupported type.

    NOTE(review): torch.load(..., weights_only=False) unpickles arbitrary objects
    and can execute code stored in the file; only point this at files you produced.
    """
    # Best-effort: allow-list numpy's array reconstructor for weights_only=True loads
    # elsewhere. add_safe_globals only exists in newer torch, so failures are ignored.
    # numpy >= 2 exposes the module as np._core; np.core is a deprecated alias that warns.
    try:
        np_core = np._core if hasattr(np, "_core") else np.core
        torch.serialization.add_safe_globals([np_core.multiarray._reconstruct])
    except Exception:
        pass

    try:
        data = torch.load(path, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"LỖI load file embedding: {e}") from e

    # Bare tensor
    if isinstance(data, torch.Tensor):
        return {'embeddings': data}

    # Bare numpy array
    if isinstance(data, np.ndarray):
        return {'embeddings': torch.tensor(data)}

    # Dict payload: look for a known embeddings key
    if isinstance(data, dict):
        for key in ('embeddings', 'emb', 'vectors', 'features'):
            if key in data:
                # Explicit None test: `a or b` would call bool() on the ids, which raises
                # for multi-element tensors/arrays and wrongly skips empty or zero ids.
                ids = data.get('mol_ids')
                if ids is None:
                    ids = data.get('mol_id')
                return {'embeddings': torch.as_tensor(data[key]), 'ids': ids}

        # Unknown dict layout: last-ditch attempt to convert the whole dict
        try:
            return {'embeddings': torch.as_tensor(data)}
        except Exception as e:
            raise ValueError(
                f"Lỗi: Dict trong file {path} có cấu trúc không hỗ trợ để chuyển sang Tensor"
            ) from e

    # List / tuple payload
    if isinstance(data, (list, tuple)):
        try:
            return {'embeddings': torch.as_tensor(data)}
        except Exception as e:
            raise ValueError(
                f"Lỗi: Không thể chuyển list/tuple trong file {path} sang Tensor"
            ) from e

    # Anything else is unsupported
    raise ValueError(
        f"Lỗi: Định dạng dữ liệu trong file {path} không được hỗ trợ. "
        f"Loại: {type(data)}"
    )


In [5]:
import os
import torch




def check_embedding_file(path):
    """
    Validate a single embedding file and print a short report.

    Returns True when `path` exists, loads through `load_embedding_file`, and its
    'embeddings' entry is a torch.Tensor with at least two dimensions. Otherwise the
    reason is printed and False is returned (load errors are reported, not raised).
    """
    print(f"KIỂM TRA FILE: {path} ======")

    if not os.path.exists(path):
        print("Lỗi: File không tồn tại.")
        return False

    try:
        loaded = load_embedding_file(path)
        emb = loaded["embeddings"]

        # Guard clauses: wrong type, then too few dimensions
        if not isinstance(emb, torch.Tensor):
            print("Lỗi: embeddings không phải Torch.Tensor.")
            return False
        if emb.ndim < 2:
            print(f"Lỗi: embeddings phải >= 2 chiều, hiện tại: {emb.ndim}")
            return False

        print("✅File hợp lệ!")
        has_ids = 'Có' if loaded.get('ids') is not None else 'Không'
        print(f"Kích thước: {tuple(emb.shape)}___Kiểu dữ liệu: {emb.dtype}___Có IDs không? {has_ids}")
        return True

    except Exception as e:
        print(f"Lỗi khi load file: {e}")
        return False


def test_all_embeddings():
    """Run `check_embedding_file` on every BERT file, then on every GNN file."""
    # (header, files) per branch; both file lists are resolved before anything is printed
    branches = (
        (
            "KIỂM TRA BERT BRANCH==================================",
            (BERT_TEST_EMBEDDING, BERT_TRAIN_EMBEDDING, BERT_VAL_EMBEDDING, BERT_EMBEDDING),
        ),
        (
            "KIỂM TRA GNN BRANCH====================================",
            (GNN_TEST_EMBEDDING, GNN_TRAIN_EMBEDDING, GNN_VAL_EMBEDDING, GNN_EMBEDDING),
        ),
    )

    for header, paths in branches:
        print(header)
        for file_path in paths:
            check_embedding_file(file_path)


# Inside a notebook kernel __name__ is "__main__", so this guard runs the checks
# every time the cell is executed.
if __name__ == "__main__":
    test_all_embeddings()



✅File hợp lệ!
Kích thước: (802, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có


  torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])


✅File hợp lệ!
Kích thước: (6411, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (8014, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (6404, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (8006, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (8014, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (6404, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (8006, 512)___Kiểu dữ 

In [6]:


# Load the per-split payloads for each branch (train, val, test order).
# NOTE(review): weights_only=False unpickles arbitrary objects — only load files you created.
bert_train, bert_val, bert_test = (
    torch.load(path, weights_only=False)
    for path in (BERT_TRAIN_EMBEDDING, BERT_VAL_EMBEDDING, BERT_TEST_EMBEDDING)
)

gat_train, gat_val, gat_test = (
    torch.load(path, weights_only=False)
    for path in (GNN_TRAIN_EMBEDDING, GNN_VAL_EMBEDDING, GNN_TEST_EMBEDDING)
)

In [7]:
# Full (unsplit) payloads for each branch
bert_all, gat_all = (torch.load(path, weights_only=False) for path in (BERT_EMBEDDING, GNN_EMBEDDING))

In [8]:
def check_alignment(bert_data, gat_data, split_name):
    """
    Check whether the BERT and GAT payloads list the same molecules in the same order.

    Args:
        bert_data: dict whose ids live under 'mol_id'.
        gat_data: dict whose ids live under 'mol_ids'.
        split_name: label used in the printed report.

    Returns:
        True only when both id sequences have the same length and are equal element by
        element. Prints a summary either way.
    """
    bert_ids = bert_data['mol_id']   # the BERT payload uses the singular key
    gat_ids = gat_data['mol_ids']    # the GAT payload uses the plural key

    if len(bert_ids) != len(gat_ids):
        print(f" {split_name}: Số lượng không khớp! BERT={len(bert_ids)}, GAT={len(gat_ids)}")
        # zip() would silently stop at the shorter sequence, so different lengths never match
        match = False
    elif isinstance(bert_ids, torch.Tensor) and isinstance(gat_ids, torch.Tensor):
        match = torch.equal(bert_ids, gat_ids)
    else:
        # Works for lists of str/int/etc. (and empty inputs) without indexing element 0
        match = all(str(b) == str(g) for b, g in zip(bert_ids, gat_ids))

    if match:
        print(f" {split_name}: Mol IDs khớp hoàn toàn!")
    else:
        print(f"{split_name}: Mol IDs KHÔNG khớp, cần align lại!")

    return match

# Check every split, then the full (unsplit) set
train_match, val_match, test_match, all_match = (
    check_alignment(bert_split, gat_split, label)
    for bert_split, gat_split, label in (
        (bert_train, gat_train, "Train"),
        (bert_val, gat_val, "Val"),
        (bert_test, gat_test, "Test"),
        (bert_all, gat_all, "All Data"),
    )
)

 Train: Số lượng không khớp! BERT=6411, GAT=6404
Train: Mol IDs KHÔNG khớp, cần align lại!
Val: Mol IDs KHÔNG khớp, cần align lại!
 Test: Số lượng không khớp! BERT=802, GAT=801
Test: Mol IDs KHÔNG khớp, cần align lại!
 All Data: Số lượng không khớp! BERT=8014, GAT=8006
All Data: Mol IDs KHÔNG khớp, cần align lại!


In [9]:
from torch.utils.data import Dataset, DataLoader

class FusionDataset(Dataset):
    """
    Pair BERT (text) and GAT (graph) embeddings row by row for multi-label training.

    Row i of every field is assumed to describe the same molecule, so the two branches
    must be aligned beforehand. Labels and ids are read from the BERT side; the GAT side
    only contributes its 'embeddings'.

    Items are dicts with keys 'text_emb', 'graph_emb', 'label' and 'mol_id'.
    """

    def __init__(self, bert_data, gat_data):
        # Embeddings from each branch
        self.text_embs = bert_data['embeddings']
        self.graph_embs = gat_data['embeddings']

        # Labels come from the BERT payload (assumed identical on the GAT side — not checked)
        self.labels = bert_data['labels']

        # Kept only for tracking / debugging individual samples
        self.mol_ids = bert_data['mol_id']

        # Fail fast on misaligned inputs: __len__ follows the labels, so a shorter
        # embedding tensor would otherwise raise IndexError mid-epoch and a longer one
        # would be silently truncated.
        sizes = {
            'text_emb': len(self.text_embs),
            'graph_emb': len(self.graph_embs),
            'label': len(self.labels),
            'mol_id': len(self.mol_ids),
        }
        if len(set(sizes.values())) > 1:
            raise ValueError(f"FusionDataset inputs have mismatched lengths: {sizes}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'text_emb': self.text_embs[idx],
            'graph_emb': self.graph_embs[idx],
            'label': self.labels[idx],
            'mol_id': self.mol_ids[idx]
        }



In [10]:
def align_datasets(bert_data, gat_data):
    """
    Align the BERT and GAT payloads on the mol_ids they share (labels are not cross-checked).

    Ids are compared as strings. Only ids present on both sides are kept, ordered by
    sorted(str(id)), so row k of both outputs refers to the same molecule.

    Args:
        bert_data: dict with 'mol_id', 'embeddings' and 'labels'.
        gat_data: dict with 'mol_ids', 'embeddings' and 'labels'.

    Returns:
        (aligned_bert, aligned_gat): dicts holding the selected 'embeddings' and 'labels'
        rows, plus the shared string ids under 'mol_id' (BERT) and 'mol_ids' (GAT).
    """
    bert_ids = bert_data['mol_id']
    gat_ids = gat_data['mol_ids']

    # str(id) -> row index for fast lookup; when an id repeats, the last row wins
    bert_dict = {str(m): i for i, m in enumerate(bert_ids)}
    gat_dict = {str(m): i for i, m in enumerate(gat_ids)}

    # Report duplicates instead of letting the dicts above drop rows silently
    for branch, ids, index in (("BERT", bert_ids, bert_dict), ("GAT", gat_ids, gat_dict)):
        n_duplicates = len(ids) - len(index)
        if n_duplicates:
            print(f"WARNING: {branch} has {n_duplicates} duplicated mol_id rows; "
                  f"only the last occurrence of each id is kept")

    # Intersection of the ids, in a deterministic order
    common_ids = sorted(bert_dict.keys() & gat_dict.keys())

    print(f"BERT có {len(bert_ids)} mẫu")
    print(f"GAT  có {len(gat_ids)} mẫu")
    print(f"Trùng mol_id: {len(common_ids)} mẫu")

    bert_indices = [bert_dict[m] for m in common_ids]
    gat_indices = [gat_dict[m] for m in common_ids]

    aligned_bert = {
        "embeddings": bert_data["embeddings"][bert_indices],
        "labels": bert_data["labels"][bert_indices],
        "mol_id": common_ids,
    }

    aligned_gat = {
        "embeddings": gat_data["embeddings"][gat_indices],
        "labels": gat_data["labels"][gat_indices],
        "mol_ids": common_ids,
    }

    return aligned_bert, aligned_gat


# Align the complete BERT and GAT sets once; train/val/test are split from the aligned result below.
all_bert_aligned, all_gat_aligned = align_datasets(bert_all, gat_all)

BERT có 8014 mẫu
GAT  có 8006 mẫu
Trùng mol_id: 8006 mẫu


In [11]:
import numpy as np
from sklearn.model_selection import train_test_split

# Split positions 0..N-1 of the aligned arrays 80/10/10 with a fixed seed.
# NOTE(review): plain random split — label frequencies are not stratified across splits.
N = len(all_bert_aligned["mol_id"])
print("Total aligned samples:", N)
indices = np.arange(N)

print(type(all_bert_aligned["labels"]))
train_idx, holdout_idx = train_test_split(indices, test_size=0.2, random_state=42)
val_idx, test_idx = train_test_split(holdout_idx, test_size=0.5, random_state=42)


Total aligned samples: 8006
<class 'torch.Tensor'>


In [12]:
def split_data(aligned_data, idx, mol_id):
    """
    Select rows `idx` from an aligned branch dict.

    `mol_id` is the name of the id key to read from `aligned_data` ('mol_id' for the
    BERT side, 'mol_ids' for the GAT side); the result always stores the selected ids
    under "mol_id".
    """
    picked_embeddings = aligned_data["embeddings"][idx]
    picked_labels = aligned_data["labels"][idx]
    picked_ids = [aligned_data[mol_id][row] for row in idx]
    return {"embeddings": picked_embeddings, "labels": picked_labels, "mol_id": picked_ids}
# Slice both aligned branches with the same index arrays, so row i of BERT and GAT
# still refers to the same molecule inside every split.
split_indices = (train_idx, val_idx, test_idx)

train_bert_aligned, val_bert_aligned, test_bert_aligned = (
    split_data(all_bert_aligned, rows, "mol_id") for rows in split_indices
)

train_dataset, val_dataset, test_dataset = (
    FusionDataset(bert_part, split_data(all_gat_aligned, rows, "mol_ids"))
    for bert_part, rows in zip((train_bert_aligned, val_bert_aligned, test_bert_aligned), split_indices)
)

print(f"\n Datasets created:")
print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")


 Datasets created:
Train: 6404 samples
Val: 801 samples
Test: 801 samples


In [13]:
class AttentionFusionMultiLabel(nn.Module):
    """
    Multi-label classifier that fuses a text (BERT) embedding with a graph (GAT) embedding.

    Strategies:
    - 'balanced', 'gat_priority', 'competitive': currently the SAME code path. Each
      modality gets its own Linear projection and classifier head. A router MLP over the
      concatenated projections produces per-sample softmax weights that mix the two
      heads' logits. A learnable `preference_bias`, initialised to [0.0, 0.5], nudges
      the router toward the GAT head at the start of training.
      NOTE(review): the names suggest different behaviour (e.g. a stronger GAT pathway
      for 'gat_priority'), but nothing in this class distinguishes them.
    - 'force_ratio': project both embeddings, blend them with the fixed weights
      (1 - gat_target_ratio, gat_target_ratio), then apply a single classifier.

    forward(text_emb, graph_emb) returns (logits [B, num_labels], weights [B, 1, 2]),
    where weights[..., 0] is the text share and weights[..., 1] the graph share.
    """

    ROUTED_STRATEGIES = ('balanced', 'gat_priority', 'competitive')
    STRATEGIES = ROUTED_STRATEGIES + ('force_ratio',)

    def __init__(self, text_dim=768, graph_dim=512, num_labels=12,
                 gat_target_ratio=0.5, strategy='balanced',
                 hidden_dim=256, dropout=0.3):
        """
        Args:
            text_dim: size of the text embedding.
            graph_dim: size of the graph embedding.
            num_labels: number of independent binary outputs.
            gat_target_ratio: graph weight used by 'force_ratio' (ignored otherwise).
            strategy: one of STRATEGIES.
            hidden_dim: width of the shared projection space (default 256, as before).
            dropout: dropout rate inside each classifier head (default 0.3, as before).

        Raises:
            ValueError: for an unknown strategy. Previously such a model had no layers
                and forward() silently returned None.
        """
        super().__init__()
        if strategy not in self.STRATEGIES:
            raise ValueError(f"Unknown strategy {strategy!r}; expected one of {self.STRATEGIES}")

        self.gat_target_ratio = gat_target_ratio
        self.strategy = strategy

        # Submodules keep their original names and creation order, so existing
        # checkpoints (state_dict keys) still load.
        if strategy in self.ROUTED_STRATEGIES:
            # BERT pathway
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.text_classifier = self._make_head(hidden_dim, num_labels, dropout)

            # GAT pathway (same capacity as the BERT pathway)
            self.graph_proj = nn.Linear(graph_dim, hidden_dim)
            self.graph_classifier = self._make_head(hidden_dim, num_labels, dropout)

            # Router: per-sample (text, graph) logits computed from both projections
            self.router = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 2)
            )

            # Learnable temperature; forward() squashes it with a sigmoid
            self.temperature = nn.Parameter(torch.tensor(1.0))

            # Added to the router logits; the +0.5 on the graph slot favours GAT at init
            self.preference_bias = nn.Parameter(torch.tensor([0.0, 0.5]))
        else:  # 'force_ratio'
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.graph_proj = nn.Linear(graph_dim, hidden_dim)
            self.classifier = self._make_head(hidden_dim, num_labels, dropout)

    @staticmethod
    def _make_head(hidden_dim, num_labels, dropout):
        """Classifier head: LayerNorm -> Linear(hidden, 128) -> ReLU -> Dropout -> Linear(128, num_labels)."""
        return nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_labels)
        )

    def forward(self, text_emb, graph_emb):
        """
        Args:
            text_emb: [B, text_dim]; graph_emb: [B, graph_dim].

        Returns:
            (logits [B, num_labels], weights [B, 1, 2]).
        """
        B = text_emb.size(0)
        t = self.text_proj(text_emb)    # [B, H]
        g = self.graph_proj(graph_emb)  # [B, H]

        if self.strategy in self.ROUTED_STRATEGIES:
            # Logits from each pathway
            t_logits = self.text_classifier(t)  # [B, num_labels]
            g_logits = self.graph_classifier(g) # [B, num_labels]

            # Router logits plus the learnable preference bias
            router_logits = self.router(torch.cat([t, g], dim=1)) + self.preference_bias  # [B, 2]

            # sigmoid maps the temperature into (0, 1), so the divisor lies in (0.1, 1.1):
            # the softmax can be sharpened up to 10x and is flattened by at most ~1.1x.
            temp = torch.sigmoid(self.temperature)
            weights = torch.softmax(router_logits / (temp + 0.1), dim=1)  # [B, 2]

            # Per-sample convex combination of the two heads' logits
            final_logits = (weights[:, 0].unsqueeze(-1) * t_logits
                            + weights[:, 1].unsqueeze(-1) * g_logits)
            return final_logits, weights.unsqueeze(1)

        # 'force_ratio': fixed blend of the projections feeding a single head
        w_text = 1 - self.gat_target_ratio
        w_gat = self.gat_target_ratio
        logits = self.classifier(w_text * t + w_gat * g)

        weights = torch.full((B, 2), 0.5, device=text_emb.device)
        weights[:, 0] = w_text
        weights[:, 1] = w_gat
        return logits, weights.unsqueeze(1)



class BalancedTrainer:
    """Static helpers for loss terms that discourage one fusion branch from dominating."""

    @staticmethod
    def balanced_loss(logits, weights, targets, alpha=0.5):
        """
        BCE-with-logits classification loss plus `alpha` * KL(uniform || w).

        Args:
            logits: [B, num_labels] raw scores.
            weights: router weights shaped [B, 1, 2], as AttentionFusionMultiLabel returns.
            targets: float multi-hot targets with the same shape as `logits`.
            alpha: strength of the balance penalty.

        The KL term uses reduction='batchmean', so it is the true per-sample KL divergence
        averaged over the batch. reduction='mean' (used previously) also divides by the
        2 weight columns, halving the penalty; PyTorch's kl_div documentation notes that
        'mean' does not return the true KL value.
        """
        # Standard multi-label classification loss
        ce_loss = nn.functional.binary_cross_entropy_with_logits(logits, targets)

        w = weights.squeeze(1)  # [B, 2]
        uniform = torch.full_like(w, 0.5)

        # kl_div(log q, p) = sum p * (log p - log q) = KL(p || q) with p = uniform, q = w;
        # 1e-8 avoids log(0) when the router fully collapses onto one branch.
        balance_loss = nn.functional.kl_div(
            torch.log(w + 1e-8),
            uniform,
            reduction='batchmean'
        )

        return ce_loss + alpha * balance_loss

    @staticmethod
    def entropy_regularization(weights, target_entropy=0.693):
        """
        Absolute gap between the batch-mean entropy of the router weights and
        `target_entropy` (default ≈ ln 2 ≈ 0.693, the entropy of a 50/50 split).
        """
        w = weights.squeeze(1)  # [B, 2]
        entropy = -(w * torch.log(w + 1e-8)).sum(dim=1).mean()
        return (entropy - target_entropy).abs()




In [14]:
# DataLoaders: training batches are reshuffled every epoch; evaluation loaders keep a fixed order
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_set, batch_size=64, shuffle=False) for eval_set in (val_dataset, test_dataset)
)

for split_label, split_set in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"{split_label}: {len(split_set)} samples")

Train: 6404 samples
Val: 801 samples
Test: 801 samples


In [None]:

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Read the model dimensions off one training batch (this run: 768 / 512 / 12)
sample_batch = next(iter(train_loader))
text_dim = sample_batch['text_emb'].shape[1]
graph_dim = sample_batch['graph_emb'].shape[1]
num_labels = sample_batch['label'].shape[1]

print(f"\n=== Configuration ===")
print(f"Text dim: {text_dim}")
print(f"Graph dim: {graph_dim}")
print(f"Num labels (organs): {num_labels}")

# Build the model from the inferred dimensions instead of the constructor defaults,
# so the layer sizes always match the data actually loaded.
# 'gat_priority' uses separate BERT/GAT heads mixed by a learned router.
model = AttentionFusionMultiLabel(
    text_dim=text_dim,
    graph_dim=graph_dim,
    num_labels=num_labels,
    gat_target_ratio=0.5,
    strategy='gat_priority',
)
model = model.to(device)

# Sanity-check one batch before training (reuses the batch drawn above)
batch = sample_batch
print("=== Batch Debug ===")
print(f"text_emb: {batch['text_emb'].shape}, dtype: {batch['text_emb'].dtype}")
print(f"graph_emb: {batch['graph_emb'].shape}, dtype: {batch['graph_emb'].dtype}")
print(f"label: {batch['label'].shape}, dtype: {batch['label'].dtype}")
print(f"label values: {batch['label'][:5]}")  # First 5 labels



Using device: cpu

=== Configuration ===
Text dim: 768
Graph dim: 512
Num labels (organs): 12
=== Batch Debug ===
text_emb: torch.Size([32, 768]), dtype: torch.float32
graph_emb: torch.Size([32, 512]), dtype: torch.float32
label: torch.Size([32, 12]), dtype: torch.float32
label values: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.]])


In [16]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, hamming_loss
import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer, criterion, device, balance_alpha=0.0):
    """
    Run one optimisation pass over `loader`.

    The optimised objective is `criterion(logits, labels)`. When `balance_alpha > 0`,
    `balance_alpha` times the mean squared distance of the router weights from 0.5 is
    added to it.

    Returns:
        (mean classification loss per batch, mean balance penalty per batch)
    """
    model.train()
    ce_sum = 0
    balance_sum = 0

    for batch in loader:
        text_emb, graph_emb, labels = (batch[k].to(device) for k in ('text_emb', 'graph_emb', 'label'))

        logits, weights = model(text_emb, graph_emb)  # weights: [batch, 1, 2]
        ce_loss = criterion(logits, labels)

        if balance_alpha > 0:
            w = weights.squeeze(1)  # [batch, 2]
            balance_loss = torch.mean((w - torch.ones_like(w) * 0.5) ** 2)
        else:
            balance_loss = torch.tensor(0.0, device=device)

        optimizer.zero_grad()
        (ce_loss + balance_alpha * balance_loss).backward()
        optimizer.step()

        ce_sum += ce_loss.item()
        balance_sum += balance_loss.item()

    n_batches = len(loader)
    return ce_sum / n_batches, balance_sum / n_batches


def evaluate_with_variance(model, loader, device, threshold=0.5):
    """
    Evaluate `model` on `loader` and summarise classification quality and router weights.

    Args:
        model: called as model(text_emb, graph_emb) -> (logits, weights), with weights
            shaped [B, 1, 2] (text, graph), as AttentionFusionMultiLabel returns.
        loader: yields dicts with 'text_emb', 'graph_emb' and 'label'.
        device: device the inputs are moved to.
        threshold: probability cut-off for a positive prediction.

    Returns:
        dict with
        - exact_match, hamming_loss, f1_macro, f1_micro;
        - auc_macro (0.0 when ROC-AUC is undefined, e.g. a label with one class only);
        - avg/std of the BERT and GAT weights;
        - 'mean_variance': mean((w - 0.5)**2), i.e. how far the weights are from a
          50/50 split. Despite the name, this is NOT the spread across samples. When
          each row sums to 1, it equals mean((w_gat - 0.5)**2);
        - 'gat_weight_variance': sample variance of the GAT weight across the loader
          (the actual spread of the router's decisions);
        - 'weights_tensor': the raw [N, 2] weights.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    all_weights = []

    with torch.no_grad():
        for batch in loader:
            text_emb = batch['text_emb'].to(device)
            graph_emb = batch['graph_emb'].to(device)
            labels = batch['label'].to(device)

            logits, weights = model(text_emb, graph_emb)
            probs = torch.sigmoid(logits)
            preds = (probs > threshold).float()

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
            all_probs.append(probs.cpu())
            all_weights.append(weights.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()
    all_probs = torch.cat(all_probs, dim=0).numpy()
    all_weights = torch.cat(all_weights, dim=0)  # [N, 1, 2]

    # Classification metrics
    exact_match = accuracy_score(all_labels, all_preds)
    hamming = hamming_loss(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)

    try:
        auc_macro = roc_auc_score(all_labels, all_probs, average='macro')
    except ValueError:
        # roc_auc_score raises ValueError when a label column has a single class here;
        # anything else (KeyboardInterrupt, real bugs) now propagates.
        auc_macro = 0.0

    # ===== Weight statistics =====
    w = all_weights.squeeze(1)  # [N, 2]

    # Per-sample squared distance from a 50/50 split (imbalance, not variance)
    sample_variance = ((w - 0.5) ** 2).mean(dim=1)  # [N]
    mean_variance = sample_variance.mean().item()

    bert_weights = w[:, 0]
    gat_weights = w[:, 1]

    return {
        'exact_match': exact_match,
        'hamming_loss': hamming,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'auc_macro': auc_macro,
        'avg_bert_weight': bert_weights.mean().item(),
        'avg_gat_weight': gat_weights.mean().item(),
        'bert_std': bert_weights.std().item(),
        'gat_std': gat_weights.std().item(),
        'mean_variance': mean_variance,
        'gat_weight_variance': gat_weights.var().item(),
        'weights_tensor': w
    }


# ============ Setup ============

# Per-label positive counts on the training split, read straight from the dataset
# (same totals as concatenating every train batch, without a full pass over the loader).
all_labels = torch.as_tensor(train_dataset.labels)
label_freq = all_labels.sum(dim=0)
print(f"\nLabel frequencies (train):")
for i, freq in enumerate(label_freq):
    print(f"  Organ {i}: {freq.item():.0f} samples ({freq.item()/len(all_labels)*100:.1f}%)")

total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

# pos_weight[j] = negatives / positives for label j, to counter label imbalance in BCE.
# clamp(min=1) keeps a label with no positive training example from getting a ~N/1e-6 weight.
pos_weight = (len(all_labels) - label_freq) / label_freq.clamp(min=1)
pos_weight = pos_weight.to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# ============ Training (FIXED) ============
# Up to 100 epochs: train one pass, evaluate on the validation split, adapt
# balance_alpha, checkpoint on the best validation F1-macro, and stop early after
# `patience` consecutive epochs without improvement.
print("\n" + "="*80)
print("=== Training Started (Variance-Aware) ===")
print("="*80)

best_val_f1 = 0
patience = 20
patience_counter = 0

# Start with NO balance loss - let router learn first
# Gradually introduce balance loss if weights collapse
balance_alpha = 0.0
min_variance_threshold = 0.001  # If variance drops below this, increase alpha

for epoch in range(100):
    train_loss, train_balance = train_epoch(
        model, train_loader, optimizer, criterion, device,
        balance_alpha=balance_alpha
    )
    val_metrics = evaluate_with_variance(model, val_loader, device)
    
    print(f"Epoch {epoch+1:3d} | "
          f"Loss: {train_loss:.4f} | "
          f"F1-M: {val_metrics['f1_macro']:.4f} | "
          f"BERT: {val_metrics['avg_bert_weight']:.3f}±{val_metrics['bert_std']:.3f} | "
          f"GAT: {val_metrics['avg_gat_weight']:.3f}±{val_metrics['gat_std']:.3f} | "
          f"Var: {val_metrics['mean_variance']:.4f} | "
          f"α={balance_alpha:.3f}")
    
    # ===== Dynamic balance alpha adjustment =====
    # NOTE(review): evaluate_with_variance computes 'mean_variance' as mean((w - 0.5)**2),
    # i.e. how far the router weights are from 50/50, not their spread across samples.
    # A LOW value therefore means the weights are already balanced, and a HIGH value means
    # one branch dominates (epoch 1 in the logged output: GAT 0.001, "Var" 0.2488).
    # The two branches below look inverted relative to their comments. TODO: confirm the
    # intended metric, e.g. val_metrics['weights_tensor'][:, 1].var() for the spread.
    # Note also that alpha is capped at 0.05 / 0.03 against an MSE penalty of at most 0.25,
    # which is small relative to the classification loss.
    # If variance is TOO LOW, increase regularization
    if val_metrics['mean_variance'] < min_variance_threshold and epoch > 5:
        balance_alpha = min(balance_alpha + 0.01, 0.05)
        print(f"    Variance too low! Increasing α to {balance_alpha:.3f}")
    
    # If variance is HIGH and GAT < 45%, increase balance gently
    elif val_metrics['mean_variance'] > 0.05 and val_metrics['avg_gat_weight'] < 0.45:
        balance_alpha = min(balance_alpha + 0.002, 0.03)
        print(f"   ℹ GAT low but variance good. Slight boost: α={balance_alpha:.3f}")
    
    # Early stopping on validation F1-macro; the checkpoint stores model and optimizer
    # state plus this epoch's val_metrics dict.
    if val_metrics['f1_macro'] > best_val_f1:
        best_val_f1 = val_metrics['f1_macro']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_metrics': val_metrics,
        }, 'best_fusion_multilabel_model.pt')
        patience_counter = 0
        print(f"    Best F1: {best_val_f1:.4f} | GAT: {val_metrics['avg_gat_weight']:.1%}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break


# ============ Test ============
# Reload the checkpoint with the best validation F1-macro and score it on the test split.
banner = "=" * 80
print("\n" + banner)
print("=== Final Test Results ===")
print(banner)

checkpoint = torch.load('best_fusion_multilabel_model.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

test_metrics = evaluate_with_variance(model, test_loader, device)

print("\nClassification Metrics:")
for prefix, metric_key in (
    ("   Exact Match:      ", 'exact_match'),
    ("   F1-Macro:         ", 'f1_macro'),
    ("   F1-Micro:         ", 'f1_micro'),
    ("   AUC-Macro:        ", 'auc_macro'),
):
    print(f"{prefix}{test_metrics[metric_key]:.4f}")

gat_share = test_metrics['avg_gat_weight']
# NOTE(review): evaluate_with_variance computes 'mean_variance' as mean((w - 0.5)**2),
# i.e. distance from a 50/50 split, so "> 0.001" flags imbalance rather than router
# diversity — TODO confirm the intended criterion.
imbalance = test_metrics['mean_variance']

print("\nWeight Distribution (= healthy):")
print(f"   BERT: {test_metrics['avg_bert_weight']:.3f} ± {test_metrics['bert_std']:.4f}")
print(f"   GAT:  {gat_share:.3f} ± {test_metrics['gat_std']:.4f}")
check_mark = '✓' if imbalance > 0.001 else '✗'
print(f"   Mean Variance: {imbalance:.4f} {check_mark}")

print(f"\n{banner}")
if gat_share >= 0.45 and imbalance > 0.001:
    print(f" SUCCESS: GAT {gat_share:.1%} with healthy variance!")
elif imbalance < 0.001:
    for advice_line in (
        " ISSUE: Variance too low - router not learning to differentiate",
        "  Solutions:",
        "    1. Reduce balance_alpha initial value",
        "    2. Use force_ratio=0.5 strategy for hard constraint",
        "    3. Check if router gradients are flowing properly",
    ):
        print(advice_line)
else:
    print(f" GAT at {gat_share:.1%} - increase balance_alpha slightly")


Label frequencies (train):
  Organ 0: 236 samples (3.7%)
  Organ 1: 178 samples (2.8%)
  Organ 2: 638 samples (10.0%)
  Organ 3: 250 samples (3.9%)
  Organ 4: 626 samples (9.8%)
  Organ 5: 279 samples (4.4%)
  Organ 6: 141 samples (2.2%)
  Organ 7: 765 samples (11.9%)
  Organ 8: 208 samples (3.2%)
  Organ 9: 295 samples (4.6%)
  Organ 10: 763 samples (11.9%)
  Organ 11: 340 samples (5.3%)

Total parameters: 529,949

=== Training Started (Variance-Aware) ===

=== Training Started (Variance-Aware) ===
Epoch   1 | Loss: 0.5391 | F1-M: 0.8382 | BERT: 0.999±0.002 | GAT: 0.001±0.002 | Var: 0.2488 | α=0.000
   ℹ GAT low but variance good. Slight boost: α=0.002
    Best F1: 0.8382 | GAT: 0.1%
Epoch   1 | Loss: 0.5391 | F1-M: 0.8382 | BERT: 0.999±0.002 | GAT: 0.001±0.002 | Var: 0.2488 | α=0.000
   ℹ GAT low but variance good. Slight boost: α=0.002
    Best F1: 0.8382 | GAT: 0.1%
Epoch   2 | Loss: 0.1501 | F1-M: 0.8977 | BERT: 1.000±0.000 | GAT: 0.000±0.000 | Var: 0.2498 | α=0.002
   ℹ GAT low 

In [17]:
# Average router weight per branch over the validation split
model.eval()
per_batch_weights = []

with torch.no_grad():
    for batch in val_loader:
        text_emb = batch['text_emb'].to(device)
        graph_emb = batch['graph_emb'].to(device)
        _, attn = model(text_emb, graph_emb)
        # attn is [batch, 1, 2]; keep it as [batch, 2]
        per_batch_weights.append(attn.squeeze(1).cpu())

all_weights = torch.cat(per_batch_weights, dim=0)

bert_contrib, gat_contrib = (all_weights[:, col].mean().item() for col in (0, 1))

print(" BERT contribution:", bert_contrib)
print(" GAT contribution :", gat_contrib)


 BERT contribution: 0.5885030627250671
 GAT contribution : 0.41149693727493286


In [18]:
# Sample variance of each weight column across validation samples. The columns match
# because the two softmax weights of every row sum to 1.
print("Variance:", torch.var(all_weights, dim=0))


Variance: tensor([0.0281, 0.0281])
