In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np


In [1]:
#constant
TEST_EMBEDDING_FILE = "test_embeddings.pt"
TRAIN_EMBEDDING_FILE = "train_embeddings.pt"
VAL_EMBEDDING_FILE = "val_embeddings.pt"

GNN_TEST_EMBEDDING = f"../GATNN/embeddings/{TEST_EMBEDDING_FILE}"
GNN_TRAIN_EMBEDDING = f"../GATNN/embeddings/{TRAIN_EMBEDDING_FILE}"
GNN_VAL_EMBEDDING = f"../GATNN/embeddings/{VAL_EMBEDDING_FILE}"
BERT_TEST_EMBEDDING = f"../BERT/embedding/{TEST_EMBEDDING_FILE}"
BERT_TRAIN_EMBEDDING = f"../BERT/embedding/{TRAIN_EMBEDDING_FILE}"
BERT_VAL_EMBEDDING = f"../BERT/embedding/{VAL_EMBEDDING_FILE}"



In [2]:
def load_embedding_file(path):
    """
    Hàm load file embedding từ đường dẫn `path`.
    Tự động nhận dạng nhiều kiểu dữ liệu khác nhau:
    - dict có key 'embeddings' hoặc 'emb', 'vectors', 'features'
    - trực tiếp là Tensor
    - list/tuple (chuyển sang Tensor)
    """
    # ---- Bỏ chặn numpy pickle cổ ----
    try:
        torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
    except Exception:
        pass  # không sao

    # Load thử với weights_only=False ----
    try:
        data = torch.load(path, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"LỖI load file embedding: {e}")
    
    # data = torch.load(path, map_location='cpu')

    #Chuẩn hóa về tensor ----
    # Nếu là tensor trực tiếp
    if isinstance(data, torch.Tensor):
        return {'embeddings': data}
    

    # Nếu là numpy
    if isinstance(data, np.ndarray):
        return {'embeddings': torch.tensor(data)}
    

    # Trường hợp 1: File lưu dạng dict
    if isinstance(data, dict):

        # TH1.1: Dict có key phổ biến 'embeddings'
        if 'embeddings' in data:
            return {
                'embeddings': torch.as_tensor(data['embeddings']),
                'ids': data.get('mol_ids', None)  or data.get('mol_id',None) # có thể không có
            }

        # TH1.2: Một số file embedding dùng key khác
        for key in ['emb', 'vectors', 'features']:
            if key in data:
                return {
                    'embeddings': torch.as_tensor(data[key]),
                    'ids': data.get('mol_ids', None) or data.get('mol_id',None)
                }

        # TH1.3: Dict không rõ cấu trúc → thử convert cả dict sang tensor
        try:
            return {'embeddings': torch.as_tensor(data)}
        except Exception:
            raise ValueError(
                f"Lỗi: Dict trong file {path} có cấu trúc không hỗ trợ để chuyển sang Tensor"
            )

    

    # Trường hợp 3: File là list hoặc tuple → convert sang Tensor
    elif isinstance(data, (list, tuple)):
        try:
            return {'embeddings': torch.as_tensor(data)}
        except Exception:
            raise ValueError(
                f"Lỗi: Không thể chuyển list/tuple trong file {path} sang Tensor"
            )

    # Trường hợp không thuộc loại nào
    else:
        raise ValueError(
            f"Lỗi: Định dạng dữ liệu trong file {path} không được hỗ trợ"
            f"Loại: {type(data)}"
        )


In [12]:
import os
import torch




def check_embedding_file(path):
    """Kiểm tra 1 file embedding và trả về True/False + in thông tin chi tiết."""

    print(f"KIỂM TRA FILE: {path} ======")

    if not os.path.exists(path):
        print("Lỗi: File không tồn tại.")
        return False

    try:
        result = load_embedding_file(path)
        emb = result["embeddings"]

        # --- Kiểm tra embedding có phải Tensor ---
        if not isinstance(emb, torch.Tensor):
            print("Lỗi: embeddings không phải Torch.Tensor.")
            return False

        # --- Kiểm tra số chiều ---
        if emb.ndim < 2:
            print(f"Lỗi: embeddings phải >= 2 chiều, hiện tại: {emb.ndim}")
            return False

        # --- In thông tin hợp lệ ---
        print("✅File hợp lệ!")
        print(f"Kích thước: {tuple(emb.shape)}___Kiểu dữ liệu: {emb.dtype}___Có IDs không? { 'Có' if result.get('ids') is not None else 'Không' }")
       
        return True

    except Exception as e:
        print(f"Lỗi khi load file: {e}")
        return False


def test_all_embeddings():
    """Test toàn bộ file BERT + GNN"""

    bert_files = [
        BERT_TEST_EMBEDDING,
        BERT_TRAIN_EMBEDDING,
        BERT_VAL_EMBEDDING,
    ]

    gnn_files = [
        GNN_TEST_EMBEDDING,
        GNN_TRAIN_EMBEDDING,
        GNN_VAL_EMBEDDING,
    ]

   
    print("KIỂM TRA BERT BRANCH==================================")
  

    for f in bert_files:
        check_embedding_file(f)

    
    print("KIỂM TRA GNN BRANCH====================================")

    for f in gnn_files:
        check_embedding_file(f)


if __name__ == "__main__":
    test_all_embeddings()



✅File hợp lệ!
Kích thước: (802, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (6411, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 768)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (6404, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có
✅File hợp lệ!
Kích thước: (801, 512)___Kiểu dữ liệu: torch.float32___Có IDs không? Có


In [24]:


# Load BERT embeddings
bert_train = torch.load(BERT_TRAIN_EMBEDDING, weights_only=False)
bert_val = torch.load(BERT_VAL_EMBEDDING, weights_only=False)
bert_test = torch.load(BERT_TEST_EMBEDDING, weights_only=False)

# Load GAT embeddings
gat_train = torch.load(GNN_TRAIN_EMBEDDING, weights_only=False)
gat_val = torch.load(GNN_VAL_EMBEDDING, weights_only=False)
gat_test = torch.load(GNN_TEST_EMBEDDING, weights_only=False)

In [13]:
def check_alignment(bert_data, gat_data, split_name):
    """Kiểm tra xem mol_ids có khớp nhau không"""
    bert_ids = bert_data['mol_id']  # hoặc 'mol_ids'
    gat_ids = gat_data['mol_ids']
    
    # So sánh
    if len(bert_ids) != len(gat_ids):
        print(f" {split_name}: Số lượng không khớp! BERT={len(bert_ids)}, GAT={len(gat_ids)}")
    
    # Kiểm tra thứ tự
    if isinstance(bert_ids[0], str) and isinstance(gat_ids[0], str):
        match = all(t == g for t, g in zip(bert_ids, gat_ids))
    else:
        match = torch.equal(bert_ids, gat_ids)
    
    if match:
        print(f" {split_name}: Mol IDs khớp hoàn toàn!")
    else:
        print(f"{split_name}: Mol IDs KHÔNG khớp, cần align lại!")
    
    return match

# Check tất cả splits
train_match = check_alignment(bert_train, gat_train, "Train")
val_match = check_alignment(bert_val, gat_val, "Val")
test_match = check_alignment(bert_test, gat_test, "Test")

 Train: Số lượng không khớp! BERT=6411, GAT=6404
Train: Mol IDs KHÔNG khớp, cần align lại!
Val: Mol IDs KHÔNG khớp, cần align lại!
 Test: Số lượng không khớp! BERT=802, GAT=801
Test: Mol IDs KHÔNG khớp, cần align lại!


In [10]:
from torch.utils.data import Dataset, DataLoader

class FusionDataset(Dataset):
    def __init__(self, bert_data, gat_data):
        # Embeddings
        self.text_embs = bert_data['embeddings']
        self.graph_embs = gat_data['embeddings']
        
        # Labels (lấy từ 1 trong 2, giả sử giống nhau)
        self.labels = bert_data['labels']
        
        # Optional: mol_ids để track
        self.mol_ids = bert_data['mol_id']
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'text_emb': self.text_embs[idx],
            'graph_emb': self.graph_embs[idx],
            'label': self.labels[idx],
            'mol_id': self.mol_ids[idx]
        }



In [17]:


def align_datasets(bert_data, gat_data):
    """Align 2 datasets theo mol_id"""
    bert_ids = bert_data['mol_id']
    gat_ids = gat_data['mol_ids']
    
    # Tạo dictionaries
    bert_dict = {str(mol_id): idx for idx, mol_id in enumerate(bert_ids)}
    gat_dict = {str(mol_id): idx for idx, mol_id in enumerate(gat_ids)}
    
    # Tìm common IDs
    common_ids = sorted(set(bert_dict.keys()) & set(gat_dict.keys()))
    
    print(f"BERT có {len(bert_ids)} samples")
    print(f"GAT có {len(gat_ids)} samples")
    print(f"Chung: {len(common_ids)} samples")
    
    # Extract aligned data
    bert_indices = [bert_dict[mol_id] for mol_id in common_ids]
    gat_indices = [gat_dict[mol_id] for mol_id in common_ids]
    
    # Convert to tensor nếu là list
    bert_embeddings = bert_data['embeddings']
    bert_labels = bert_data['labels']
    gat_embeddings = gat_data['embeddings']
    gat_labels = gat_data['labels']
    
    # Handle indexing based on type
    if isinstance(bert_embeddings, list):
        aligned_bert_embs = torch.stack([bert_embeddings[i] for i in bert_indices])
    else:
        aligned_bert_embs = bert_embeddings[bert_indices]
    
    if isinstance(bert_labels, list):
        aligned_bert_labels = torch.tensor([bert_labels[i] for i in bert_indices])
    else:
        aligned_bert_labels = bert_labels[bert_indices]
    
    if isinstance(gat_embeddings, list):
        aligned_gat_embs = torch.stack([gat_embeddings[i] for i in gat_indices])
    else:
        aligned_gat_embs = gat_embeddings[gat_indices]
    
    if isinstance(gat_labels, list):
        aligned_gat_labels = torch.tensor([gat_labels[i] for i in gat_indices])
    else:
        aligned_gat_labels = gat_labels[gat_indices]
    
    aligned_bert = {
        'embeddings': aligned_bert_embs,
        'labels': aligned_bert_labels,
        'mol_id': [bert_ids[i] for i in bert_indices]
    }
    
    aligned_gat = {
        'embeddings': aligned_gat_embs,
        'labels': aligned_gat_labels,
        'mol_ids': [gat_ids[i] for i in gat_indices]
    }
    
    return aligned_bert, aligned_gat

# Align từng split
train_bert_aligned, train_gat_aligned = align_datasets(bert_train, gat_train)
val_bert_aligned, val_gat_aligned = align_datasets(bert_val, gat_val)
test_bert_aligned, test_gat_aligned = align_datasets(bert_test, gat_test)

# Tạo datasets
train_dataset = FusionDataset(train_bert_aligned, train_gat_aligned)
val_dataset = FusionDataset(val_bert_aligned, val_gat_aligned)
test_dataset = FusionDataset(test_bert_aligned, test_gat_aligned)

print(f"\n Datasets created:")
print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

BERT có 6411 samples
GAT có 6404 samples
Chung: 5134 samples
BERT có 801 samples
GAT có 801 samples
Chung: 80 samples
BERT có 802 samples
GAT có 801 samples
Chung: 95 samples

 Datasets created:
Train: 5134 samples
Val: 80 samples
Test: 95 samples


In [18]:
class AttentionFusionMultiLabel(nn.Module):
    """Multi-label classification version"""
    def __init__(self, text_dim=768, graph_dim=512, num_labels=12):
        super().__init__()
        
        # Projection layers
        hidden_dim = 256
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.graph_proj = nn.Linear(graph_dim, hidden_dim)
        
        # Learnable query for attention
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))
        
        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=4,
            batch_first=True
        )
        
        # Multi-label classifier (sigmoid for each label)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_labels)  # Output: [batch, 12]
        )
    
    def forward(self, text_emb, graph_emb):
        batch_size = text_emb.size(0)
        
        # Project to same dimension
        text_proj = self.text_proj(text_emb)
        graph_proj = self.graph_proj(graph_emb)
        
        # Stack: [B, 2, hidden_dim]
        embeddings = torch.stack([text_proj, graph_proj], dim=1)
        
        # Attention fusion
        query = self.query.expand(batch_size, -1, -1)
        fused, attention_weights = self.attention(query, embeddings, embeddings)
        fused = fused.squeeze(1)
        
        # Multi-label logits (no softmax, use sigmoid later)
        logits = self.classifier(fused)
        
        return logits, attention_weights


In [19]:
# Tạo DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

Train: 5134 samples
Val: 80 samples
Test: 95 samples


In [20]:
# Kiểm tra trước khi train
print("=== Dimension Check ===")
sample = next(iter(train_loader))
print(f"Text embedding shape: {sample['text_emb'].shape}")
print(f"Graph embedding shape: {sample['graph_emb'].shape}")
print(f"Label shape: {sample['label'].shape}")
print(f"Label unique values: {torch.unique(sample['label'])}")

=== Dimension Check ===
Text embedding shape: torch.Size([32, 768])
Graph embedding shape: torch.Size([32, 512])
Label shape: torch.Size([32, 12])
Label unique values: tensor([0., 1.])


In [22]:

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Get dimensions
sample_batch = next(iter(train_loader))
text_dim = sample_batch['text_emb'].shape[1]  # 768
graph_dim = sample_batch['graph_emb'].shape[1]  # 512
num_labels = sample_batch['label'].shape[1]  # 12

print(f"\n=== Configuration ===")
print(f"Text dim: {text_dim}")
print(f"Graph dim: {graph_dim}")
print(f"Num labels (organs): {num_labels}")

# Initialize model
model = AttentionFusionMultiLabel(
    text_dim=text_dim, 
    graph_dim=graph_dim, 
    num_labels=num_labels
)
model = model.to(device)

# Chạy này trước khi train để kiểm tra
batch = next(iter(train_loader))
print("=== Batch Debug ===")
print(f"text_emb: {batch['text_emb'].shape}, dtype: {batch['text_emb'].dtype}")
print(f"graph_emb: {batch['graph_emb'].shape}, dtype: {batch['graph_emb'].dtype}")
print(f"label: {batch['label'].shape}, dtype: {batch['label'].dtype}")
print(f"label values: {batch['label'][:5]}")  # First 5 labels

# Test forward pass
model.eval()
with torch.no_grad():
    logits, _ = model(batch['text_emb'].to(device), batch['graph_emb'].to(device))
    print(f"logits shape: {logits.shape}")
    print(f"Expected: [batch_size, num_classes]")

Using device: cpu

=== Configuration ===
Text dim: 768
Graph dim: 512
Num labels (organs): 12
=== Batch Debug ===
text_emb: torch.Size([32, 768]), dtype: torch.float32
graph_emb: torch.Size([32, 512]), dtype: torch.float32
label: torch.Size([32, 12]), dtype: torch.float32
label values: tensor([[0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
logits shape: torch.Size([32, 12])
Expected: [batch_size, num_classes]


In [25]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, hamming_loss

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for batch in loader:
        text_emb = batch['text_emb'].to(device)
        graph_emb = batch['graph_emb'].to(device)
        labels = batch['label'].to(device)  # [batch, 12]
        
        # Forward
        logits, _ = model(text_emb, graph_emb)  # [batch, 12]
        
        # Multi-label loss
        loss = criterion(logits, labels)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)


def evaluate(model, loader, device, threshold=0.5):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in loader:
            text_emb = batch['text_emb'].to(device)
            graph_emb = batch['graph_emb'].to(device)
            labels = batch['label'].to(device)
            
            logits, _ = model(text_emb, graph_emb)
            probs = torch.sigmoid(logits)  # [batch, 12] in range [0, 1]
            preds = (probs > threshold).float()  # Binary predictions
            
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
            all_probs.append(probs.cpu())
    
    # Concatenate all batches
    all_preds = torch.cat(all_preds, dim=0).numpy()  # [N, 12]
    all_labels = torch.cat(all_labels, dim=0).numpy()  # [N, 12]
    all_probs = torch.cat(all_probs, dim=0).numpy()  # [N, 12]
    
    # Multi-label metrics
    # 1. Exact match ratio (all labels correct)
    exact_match = accuracy_score(all_labels, all_preds)
    
    # 2. Hamming loss (average per-label error)
    hamming = hamming_loss(all_labels, all_preds)
    
    # 3. Macro F1 (average F1 across labels)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    
    # 4. Micro F1 (global F1)
    f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    
    # 5. AUC-ROC per label (if possible)
    try:
        auc_macro = roc_auc_score(all_labels, all_probs, average='macro')
    except:
        auc_macro = 0.0
    
    return {
        'exact_match': exact_match,
        'hamming_loss': hamming,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'auc_macro': auc_macro
    }






# Check label distribution
all_labels = torch.cat([batch['label'] for batch in train_loader], dim=0)
label_freq = all_labels.sum(dim=0)
print(f"\nLabel frequencies (train):")
for i, freq in enumerate(label_freq):
    print(f"  Organ {i}: {freq.item():.0f} samples ({freq.item()/len(all_labels)*100:.1f}%)")



total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Multi-label loss: BCEWithLogitsLoss (combines sigmoid + BCE)
# Use pos_weight to handle class imbalance
pos_weight = (len(all_labels) - label_freq) / (label_freq + 1e-6)
pos_weight = pos_weight.to(device)
print(f"\nPos weights (for imbalanced labels): {pos_weight}")

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Training loop
print("\n=== Training Started ===")
best_val_f1 = 0
patience = 15
patience_counter = 0

for epoch in range(7):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_metrics = evaluate(model, val_loader, device)
    
    print(f"Epoch {epoch+1:3d} | "
          f"Loss: {train_loss:.4f} | "
          f"EM: {val_metrics['exact_match']:.4f} | "
          f"Hamming: {val_metrics['hamming_loss']:.4f} | "
          f"F1-Macro: {val_metrics['f1_macro']:.4f} | "
          f"F1-Micro: {val_metrics['f1_micro']:.4f} | "
          f"AUC: {val_metrics['auc_macro']:.4f}")
    
    # Early stopping based on macro F1
    if val_metrics['f1_macro'] > best_val_f1:
        best_val_f1 = val_metrics['f1_macro']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_metrics': val_metrics,
        }, 'best_fusion_multilabel_model.pt')
        patience_counter = 0
        print(f"   Saved best model with F1-Macro: {best_val_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# Test evaluation
print("\n=== Final Evaluation ===")
checkpoint = torch.load('best_fusion_multilabel_model.pt', map_location=device, weights_only = False)
model.load_state_dict(checkpoint['model_state_dict'])

test_metrics = evaluate(model, test_loader, device)
print(f" Test Results:")
print(f"   Exact Match: {test_metrics['exact_match']:.4f}")
print(f"   Hamming Loss: {test_metrics['hamming_loss']:.4f}")
print(f"   F1-Macro: {test_metrics['f1_macro']:.4f}")
print(f"   F1-Micro: {test_metrics['f1_micro']:.4f}")
print(f"   AUC-Macro: {test_metrics['auc_macro']:.4f}")


Label frequencies (train):
  Organ 0: 203 samples (4.0%)
  Organ 1: 166 samples (3.2%)
  Organ 2: 498 samples (9.7%)
  Organ 3: 188 samples (3.7%)
  Organ 4: 494 samples (9.6%)
  Organ 5: 225 samples (4.4%)
  Organ 6: 121 samples (2.4%)
  Organ 7: 606 samples (11.8%)
  Organ 8: 171 samples (3.3%)
  Organ 9: 237 samples (4.6%)
  Organ 10: 591 samples (11.5%)
  Organ 11: 279 samples (5.4%)

Total parameters: 626,572

Pos weights (for imbalanced labels): tensor([24.2906, 29.9277,  9.3092, 26.3085,  9.3927, 21.8178, 41.4298,  7.4719,
        29.0234, 20.6624,  7.6870, 17.4014])

=== Training Started ===




Epoch   1 | Loss: 0.0392 | EM: 0.9875 | Hamming: 0.0010 | F1-Macro: 0.9048 | F1-Micro: 0.9867 | AUC: nan
   Saved best model with F1-Macro: 0.9048




Epoch   2 | Loss: 0.0291 | EM: 0.9875 | Hamming: 0.0010 | F1-Macro: 0.9048 | F1-Micro: 0.9867 | AUC: nan




Epoch   3 | Loss: 0.0233 | EM: 0.9875 | Hamming: 0.0010 | F1-Macro: 0.9048 | F1-Micro: 0.9867 | AUC: nan




Epoch   4 | Loss: 0.0225 | EM: 0.9750 | Hamming: 0.0021 | F1-Macro: 0.9004 | F1-Micro: 0.9737 | AUC: nan




Epoch   5 | Loss: 0.0184 | EM: 0.9875 | Hamming: 0.0010 | F1-Macro: 0.9048 | F1-Micro: 0.9867 | AUC: nan




Epoch   6 | Loss: 0.0184 | EM: 0.9750 | Hamming: 0.0021 | F1-Macro: 0.9004 | F1-Micro: 0.9737 | AUC: nan
Epoch   7 | Loss: 0.0149 | EM: 0.9750 | Hamming: 0.0021 | F1-Macro: 0.9004 | F1-Micro: 0.9737 | AUC: nan

=== Final Evaluation ===
 Test Results:
   Exact Match: 0.9684
   Hamming Loss: 0.0035
   F1-Macro: 0.9599
   F1-Micro: 0.9722
   AUC-Macro: 0.9999


