In [16]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/tox21data/tox21.csv


In [17]:
!pip install torch-geometric
!pip install rdkit-pypi




In [18]:
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.explain import Explainer, GNNExplainer
from rdkit import Chem
from rdkit.Chem import Descriptors
import numpy as np


## Preprocessing

In [19]:
import pandas as pd
import numpy as np
from rdkit import Chem
import torch
from torch_geometric.data import Data
import random

# ============================================================================
# LOAD D·ªÆ LI·ªÜU
# ============================================================================

TOX21 = "/kaggle/input/tox21data/tox21.csv"
df = pd.read_csv(TOX21)

print('='*70)
print('üìä LOAD DATA')
print('='*70)
print(df.head())

label_cols = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
    "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma",
    "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
]

smiles_df = df["smiles"].tolist()
mol_ids_df = df["mol_id"].tolist()
labels_df = df[label_cols].values  # (n_samples, 12)

print(f"\n‚úÖ SMILES: {len(smiles_df)} molecules")
print(f"‚úÖ Labels shape: {labels_df.shape}")

# ============================================================================
# T·∫†O LABELS_CLEAN + MASK
# ============================================================================

labels_clean = np.nan_to_num(labels_df, nan=0.0)
mask = (~np.isnan(labels_df)).astype(np.float32)

print(f"‚úÖ Labels_clean shape: {labels_clean.shape}")
print(f"‚úÖ Mask shape: {mask.shape}")

# ============================================================================
# BUILD DATASET
# ============================================================================

dataset = []
failed = 0

print(f"\n{'='*70}")
print("üèóÔ∏è  BUILD DATASET")
print(f"{'='*70}")

for i, smi in enumerate(smiles_df):
    # Convert SMILES ‚Üí Mol
    try:
        mol = Chem.MolFromSmiles(smi, sanitize=False)
        Chem.SanitizeMol(mol)
    except:
        failed += 1
        continue
    
    if mol is None:
        failed += 1
        continue
    
    # ===== NODE FEATURES =====
    atom_features = []
    periodic_table = Chem.GetPeriodicTable()
    
    for atom in mol.GetAtoms():
        atomic_num = atom.GetAtomicNum()
        valence_electrons = periodic_table.GetNOuterElecs(atomic_num)
        features = [
            atomic_num,
            atom.GetTotalValence(),
            atom.GetTotalDegree(),
            int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
            valence_electrons,           
            atom.GetNumImplicitHs()      
        ]
        atom_features.append(features)
    
    x = torch.tensor(atom_features, dtype=torch.float)
    
    # ===== EDGE INDEX =====
    edge_index = []
    for bond in mol.GetBonds():
        a = bond.GetBeginAtomIdx()
        b = bond.GetEndAtomIdx()
        edge_index.append([a, b])
        edge_index.append([b, a])
    
    if len(edge_index) == 0:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    # ===== LABELS (GRAPH-LEVEL!) =====
    y = torch.tensor(labels_clean[i], dtype=torch.float)        # (12,) ‚úÖ
    m = torch.tensor(mask[i], dtype=torch.float)                # (12,) ‚úÖ
    
    # ===== T·∫†O DATA OBJECT =====
    data = Data(
        x=x,
        edge_index=edge_index,
        y=y,
        mask=m,
        mol_id=mol_ids_df[i],
        num_nodes=x.shape[0]
    )
    
    dataset.append(data)

print(f"‚úÖ Dataset created: {len(dataset)} molecules")
print(f"‚ùå Failed to parse: {failed} SMILES")

# Ki·ªÉm tra sample
sample = dataset[0]
print(f"\nüìå Sample Data Object:")
print(f"   x.shape: {sample.x.shape}")
print(f"   edge_index.shape: {sample.edge_index.shape}")
print(f"   y.shape: {sample.y.shape} ‚úÖ (ph·∫£i l√† (12,))")
print(f"   mask.shape: {sample.mask.shape} ‚úÖ (ph·∫£i l√† (12,))")
print(f"   num_nodes: {sample.num_nodes}")
print(f"   mol_id: {sample.mol_id}")

# ============================================================================
# SPLIT DATASET
# ============================================================================

print(f"\n{'='*70}")
print("üìä SPLIT DATASET")
print(f"{'='*70}")

random.seed(42)
random.shuffle(dataset)

total = len(dataset)
len_train = int(0.8 * total)
len_val = int(0.1 * total)
len_test = total - len_train - len_val

print(f"T·ªïng s·ªë m·∫´u: {total}")
print(f"Train: {len_train} (80%)")
print(f"Val:   {len_val} (10%)")
print(f"Test:  {len_test} (10%)")

train_dataset = dataset[:len_train]
val_dataset = dataset[len_train : len_train + len_val]
test_dataset = dataset[len_train + len_val :]

# ============================================================================
# VERIFY
# ============================================================================

print(f"\n{'='*70}")
print("‚úÖ VERIFY DATASET")
print(f"{'='*70}")

print(f"\nüìå Train Sample:")
print(f"   x.shape: {train_dataset[0].x.shape}")
print(f"   y.shape: {train_dataset[0].y.shape} ‚úÖ")

print(f"\nüìå Val Sample:")
print(f"   x.shape: {val_dataset[0].x.shape}")
print(f"   y.shape: {val_dataset[0].y.shape} ‚úÖ")

print(f"\nüìå Test Sample:")
print(f"   x.shape: {test_dataset[0].x.shape}")
print(f"   y.shape: {test_dataset[0].y.shape} ‚úÖ")

# ============================================================================
# SAVE
# ============================================================================

print(f"\n{'='*70}")
print("üíæ SAVE DATASET")
print(f"{'='*70}")

torch.save(train_dataset, "train.pt")
torch.save(val_dataset, "val.pt")
torch.save(test_dataset, "test.pt")

print("‚úÖ train.pt saved")
print("‚úÖ val.pt saved")
print("‚úÖ test.pt saved")

print(f"\n{'='*70}")
print("‚ú® DONE!")
print(f"{'='*70}")

üìä LOAD DATA
   NR-AR  NR-AR-LBD  NR-AhR  NR-Aromatase  NR-ER  NR-ER-LBD  NR-PPAR-gamma  \
0    0.0        0.0     1.0           NaN    NaN        0.0            0.0   
1    0.0        0.0     0.0           0.0    0.0        0.0            0.0   
2    NaN        NaN     NaN           NaN    NaN        NaN            NaN   
3    0.0        0.0     0.0           0.0    0.0        0.0            0.0   
4    0.0        0.0     NaN           0.0    0.0        0.0            0.0   

   SR-ARE  SR-ATAD5  SR-HSE  SR-MMP  SR-p53   mol_id  \
0     1.0       0.0     0.0     0.0     0.0  TOX3021   
1     NaN       0.0     NaN     0.0     0.0  TOX3020   
2     0.0       NaN     0.0     NaN     NaN  TOX3024   
3     NaN       0.0     NaN     0.0     0.0  TOX3027   
4     0.0       0.0     0.0     NaN     0.0  TOX3028   

                                              smiles  
0                       CCOc1ccc2nc(S(N)(=O)=O)sc2c1  
1                          CCN1C(=O)NC(c2ccccc2)C1=O  
2  CC[C@]1(O)C

  has_large_values = (abs_vals > 1e6).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()


‚úÖ Dataset created: 8014 molecules
‚ùå Failed to parse: 0 SMILES

üìå Sample Data Object:
   x.shape: torch.Size([16, 7])
   edge_index.shape: torch.Size([2, 34])
   y.shape: torch.Size([12]) ‚úÖ (ph·∫£i l√† (12,))
   mask.shape: torch.Size([12]) ‚úÖ (ph·∫£i l√† (12,))
   num_nodes: 16
   mol_id: TOX3021

üìä SPLIT DATASET
T·ªïng s·ªë m·∫´u: 8014
Train: 6411 (80%)
Val:   801 (10%)
Test:  802 (10%)

‚úÖ VERIFY DATASET

üìå Train Sample:
   x.shape: torch.Size([43, 7])
   y.shape: torch.Size([12]) ‚úÖ

üìå Val Sample:
   x.shape: torch.Size([31, 7])
   y.shape: torch.Size([12]) ‚úÖ

üìå Test Sample:
   x.shape: torch.Size([11, 7])
   y.shape: torch.Size([12]) ‚úÖ

üíæ SAVE DATASET
‚úÖ train.pt saved
‚úÖ val.pt saved
‚úÖ test.pt saved

‚ú® DONE!


In [20]:
train_dataset = torch.load("train.pt", weights_only=False)
val_dataset   = torch.load("val.pt", weights_only=False)
test_dataset  = torch.load("test.pt", weights_only=False)

In [21]:
data = train_dataset[0]
print(data)
print("Node feature matrix (x):")
print(data.x)
print("x shape:", data.x.shape)


Data(x=[43, 7], edge_index=[2, 98], y=[12], mask=[12], mol_id='TOX197', num_nodes=43)
Node feature matrix (x):
tensor([[ 6.,  4.,  4.,  0.,  0.,  4.,  3.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  1.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  3.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  2.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  0.],
        [ 6.,  4.,  3.,  0.,  0.,  4.,  0.],
        [ 8.,  2.,  1.,  0.,  0.,  6.,  0.],
        [ 7.,  3.,  3.,  0.,  0.,  5.,  0.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  2.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  2.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  2.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  0.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  0.],
        [ 8.,  2.,  2.,  0.,  0.,  6.,  1.],
        [ 8.,  2.,  2.,  0.,  0.,  6.,  0.],
        [ 6.,  4.,  4.,  0.,  0.,  4.,  0.],
        [ 7.,  3.,  3.,  0.,  0.,  5.,  1.],
        [ 6.,  4.,  3.,  0.,  0.,  4.,  0.],
        [ 8.,  2.,  1.,  0.,  0.,  6.,  0.],
        [ 6.,  4.,  4.,  0.,  0., 

In [22]:
def custom_collate_fn(data_list):
    """
    Custom collate function ƒë·ªÉ fix batch.y shape
    
    V·∫§Nƒê·ªÄ: 
    - PyG Batch n·ªëi batch.y th√†nh 1D: (7, 12) ‚Üí (84,)
    
    GI·∫¢I PH√ÅP:
    - Reshape l·∫°i th√†nh (num_graphs, 12)
    """
    
    # T·∫°o batch b√¨nh th∆∞·ªùng
    batch = Batch.from_data_list(data_list)
    
    # ===== FIX: RESHAPE batch.y =====
    num_graphs = batch.batch.max().item() + 1
    num_classes = 12
    
    # N·∫øu batch.y b·ªã flatten, reshape l·∫°i
    if batch.y.dim() == 1 and batch.y.shape[0] == num_graphs * num_classes:
        batch.y = batch.y.view(num_graphs, num_classes)
    
    # T∆∞∆°ng t·ª± cho mask
    if hasattr(batch, 'mask') and batch.mask.dim() == 1 and batch.mask.shape[0] == num_graphs * num_classes:
        batch.mask = batch.mask.view(num_graphs, num_classes)
    
    return batch

# ============================================================================
# PH·∫¶N 2: EXTRACT LABELS - ƒê∆†NGI·∫¢N (v√¨ batch.y ƒë√£ (num_graphs, 12))
# ============================================================================

def extract_graph_labels(batch, num_classes, device):
    """
    L·∫•y graph-level labels t·ª´ batch
    
    ƒê∆°n gi·∫£n v√¨ batch.y ƒë√£ ƒë∆∞·ª£c reshape th√†nh (num_graphs, 12)
    """
    
    num_graphs = batch.batch.max().item() + 1
    
    # ===== batch.y ƒë√£ l√† (num_graphs, 12) =====
    if batch.y.shape[0] == num_graphs and batch.y.shape[1] == num_classes:
        return batch.y
    
    # ===== Ho·∫∑c batch.y v·∫´n l√† 1D (fallback) =====
    elif batch.y.dim() == 1 and batch.y.shape[0] == num_graphs * num_classes:
        return batch.y.view(num_graphs, num_classes)
    
    else:
        print(f"‚ö†Ô∏è Unexpected batch.y shape: {batch.y.shape}")
        print(f"   Expected: ({num_graphs}, {num_classes}) or ({num_graphs * num_classes},)")
        return torch.zeros(num_graphs, num_classes, device=device)


## GAT MODEL

In [23]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, global_mean_pool, global_max_pool, global_add_pool
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, accuracy_score

# ============================================================================
# üöÄ IMPROVED GAT ARCHITECTURE - OPTIMIZED FOR YOUR DATA
# ============================================================================

class GAT(nn.Module):
    """
    State-of-the-art GAT cho multi-label classification v·ªõi imbalanced data:
    
    Features:
    - GATv2Conv (better than GAT)
    - Multiple pooling strategies (mean + max + sum)
    - Batch Normalization
    - Residual connections at graph level
    - Dropout regularization
    - Better weight initialization
    """
    def __init__(self, input_dim=7, hidden_dim=128, num_heads=8,
                 num_layers=3, embedding_dim=512, dropout=0.2, 
                 edge_dim=None, add_self_loops=True):
        super().__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        
        # GAT Layers - using GATv2 (more expressive)
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        # Input layer
        self.convs.append(
            GATv2Conv(
                input_dim, 
                hidden_dim, 
                heads=num_heads,
                dropout=dropout,
                concat=True,
                add_self_loops=add_self_loops,
                edge_dim=edge_dim
            )
        )
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim * num_heads))
        
        # Hidden layers
        for i in range(num_layers - 2):
            self.convs.append(
                GATv2Conv(
                    hidden_dim * num_heads,
                    hidden_dim,
                    heads=num_heads,
                    dropout=dropout,
                    concat=True,
                    add_self_loops=add_self_loops,
                    edge_dim=edge_dim
                )
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim * num_heads))
        
        # Final layer - average attention heads
        self.convs.append(
            GATv2Conv(
                hidden_dim * num_heads,
                hidden_dim,
                heads=num_heads,
                dropout=dropout,
                concat=False,  # Average instead of concat
                add_self_loops=add_self_loops,
                edge_dim=edge_dim
            )
        )
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        
        # Graph-level pooling projection
        # Concat mean + max + sum = 3 * hidden_dim
        pooling_dim = hidden_dim * 3
        
        # MLP for graph embedding
        self.graph_mlp = nn.Sequential(
            nn.Linear(pooling_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            
            nn.Linear(embedding_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        """Xavier initialization for better gradient flow"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
    
    def forward(self, x, edge_index, batch, edge_attr=None):
        """
        Args:
            x: Node features [num_nodes, input_dim]
            edge_index: Edge indices [2, num_edges]
            batch: Batch assignment [num_nodes]
            edge_attr: Edge features [num_edges, edge_dim] (optional)
        """
        
        # Message passing
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_attr=edge_attr)
            x = self.batch_norms[i](x)
            x = torch.relu(x)
            x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        
        # Multiple pooling strategies
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x_sum = global_add_pool(x, batch)
        
        # Concatenate poolings
        graph_repr = torch.cat([x_mean, x_max, x_sum], dim=-1)
        
        # Graph-level MLP
        graph_embedding = self.graph_mlp(graph_repr)
        
        return graph_embedding


In [24]:
class MultiLabelClassifier(nn.Module):
    """
    Advanced classifier v·ªõi temperature scaling v√† multiple hidden layers
    """
    def __init__(self, embedding_dim=512, num_classes=12, 
                 hidden_dim=256, dropout=0.3, temperature=1.0):
        super().__init__()
        
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
        # Temperature parameter for calibration
        self.temperature = nn.Parameter(torch.ones(1) * temperature)
        
        # Initialize
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, x):
        logits = self.classifier(x)
        # Apply temperature scaling
        scaled_logits = logits / torch.clamp(self.temperature, min=0.5, max=2.0)
        return scaled_logits



In [25]:
# ============================================================================
# üî• ASYMMETRIC LOSS - BEST FOR IMBALANCED MULTI-LABEL

class AsymmetricLossOptimized(nn.Module):
    """
    Asymmetric Loss specifically designed for imbalanced multi-label
   
    """
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, 
                 disable_torch_grad_focal_loss=True):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
    
    def forward(self, x, y):
        """
        Args:
            x: logits [batch_size, num_classes]
            y: targets [batch_size, num_classes] (0 or 1)
        """
        # Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE
        los_pos = y * torch.log(xs_pos.clamp(min=1e-8))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=1e-8))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            
            loss *= one_sided_w

        return -loss.mean()



# Train

In [26]:
def _find_optimal_thresholds(all_probs, all_labels, num_classes):
    """Improved threshold finding using multiple strategies"""
    from sklearn.metrics import precision_recall_curve
    
    optimal_thresholds = []
    per_class_f1 = []
    
    for i in range(num_classes):
        probs = all_probs[:, i]
        labels = all_labels[:, i]
        
        # Skip if no positive samples
        if labels.sum() == 0:
            optimal_thresholds.append(0.5)
            per_class_f1.append(0.0)
            continue
            
        # Strategy 1: Precision-Recall curve (more robust)
        try:
            precision, recall, thresholds = precision_recall_curve(labels, probs)
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_idx = np.argmax(f1_scores[:-1])
            thresh_pr = thresholds[best_idx]
            f1_pr = f1_scores[best_idx]
        except:
            thresh_pr = 0.5
            f1_pr = 0.0
        
        # Strategy 2: Adaptive grid search around PR threshold
        mean_prob = probs.mean()
        pos_ratio = labels.mean()
        
        # Smart search range based on class characteristics
        if pos_ratio < 0.1:  # Rare class
            search_min = max(0.01, thresh_pr - 0.1, mean_prob * 0.3)
            search_max = min(0.3, thresh_pr + 0.1, mean_prob * 2.0)
            n_points = 25
        elif pos_ratio > 0.9:  # Common class
            search_min = max(0.7, thresh_pr - 0.1, mean_prob * 0.8)
            search_max = min(0.99, thresh_pr + 0.1, mean_prob * 1.2)
            n_points = 25
        else:
            search_min = max(0.05, thresh_pr - 0.15, mean_prob * 0.5)
            search_max = min(0.95, thresh_pr + 0.15, mean_prob * 1.5)
            n_points = 20
        
        # Fine-tune with grid search
        best_f1 = f1_pr
        best_thresh = thresh_pr
        
        for thresh in np.linspace(search_min, search_max, n_points):
            preds = (probs > thresh).astype(int)
            f1 = f1_score(labels, preds, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = thresh
        
        # Ensure threshold is reasonable
        best_thresh = np.clip(best_thresh, 0.05, 0.95)
        
        optimal_thresholds.append(best_thresh)
        per_class_f1.append(best_f1)
    
    return optimal_thresholds, per_class_f1

In [32]:
def _ensure_shape_compatibility(y_true, logits, num_classes, device):
    """Ensure y_true has same shape as logits"""
    if y_true.shape != logits.shape:
        y_true = y_true.view(logits.shape[0], -1)
        if y_true.shape[1] < num_classes:
            pad = num_classes - y_true.shape[1]
            y_true = torch.cat([
                y_true,
                torch.zeros(y_true.shape[0], pad, device=device)
            ], dim=1)
        elif y_true.shape[1] > num_classes:
            y_true = y_true[:, :num_classes]
    return y_true

In [27]:
def train_gat(model, classifier, train_loader, val_loader, device,
                             num_classes=12, epochs=100, patience=20):
    """
    Improved training function with better threshold optimization and error handling
    """
    
    print("\n" + "="*80)
    print("üìä ANALYZING TRAINING DATA")
    print("="*80)
    
    # Calculate class statistics
    all_labels = []
    for batch in train_loader:
        y_true = extract_graph_labels(batch, num_classes, device)
        all_labels.append(y_true.cpu().numpy())
    
    all_labels = np.vstack(all_labels)
    pos_counts = all_labels.sum(axis=0)
    total_samples = len(all_labels)
    
    print("\nüìà Class Distribution:")
    print(f"{'Class':<8} {'Samples':<10} {'Ratio':<10} {'Imbalance'}")
    print("-" * 80)
    
    for i in range(num_classes):
        ratio = pos_counts[i] / total_samples
        imbalance = (total_samples - pos_counts[i]) / (pos_counts[i] + 1e-5)
        print(f"Class {i:<2d} {int(pos_counts[i]):<10} {ratio*100:<9.2f}% {imbalance:<.2f}:1")
    
    # Setup
    model.to(device)
    classifier.to(device)
    
    # Optimizer with different learning rates
    optimizer = torch.optim.AdamW([
        {'params': model.parameters(), 'lr': 0.001, 'weight_decay': 0.01},
        {'params': classifier.parameters(), 'lr': 0.002, 'weight_decay': 0.005}
    ])
    
    # Loss function
    criterion = AsymmetricLossOptimized(
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05
    )
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=8, min_lr=1e-6
    )
    
    # Warmup scheduler
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=5
    )
    
    best_f1 = 0.0
    best_epoch = 0
    patience_counter = 0
    
    print("\n" + "="*80)
    print("üöÄ STARTING TRAINING")
    print("="*80)
    
    for epoch in range(epochs):
        # ===== TRAINING =====
        model.train()
        classifier.train()
        
        total_loss = 0.0
        train_probs_list = []
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # Forward
            graph_embedding = model(batch.x, batch.edge_index, batch.batch)
            logits = classifier(graph_embedding)
            
            # Get labels
            y_true = extract_graph_labels(batch, num_classes, device)
            
            # Ensure shape compatibility
            y_true = _ensure_shape_compatibility(y_true, logits, num_classes, device)
            
            # Loss & backward
            loss = criterion(logits, y_true.float())
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                list(model.parameters()) + list(classifier.parameters()),
                max_norm=1.0
            )
            
            optimizer.step()
            total_loss += loss.item()
            
            # Monitor probabilities
            with torch.no_grad():
                train_probs_list.append(torch.sigmoid(logits).mean().item())
        
        # Warmup for first 5 epochs
        if epoch < 5:
            warmup_scheduler.step()
        
        avg_train_loss = total_loss / len(train_loader)
        avg_train_prob = np.mean(train_probs_list)
        
        # ===== VALIDATION =====
        model.eval()
        classifier.eval()
        
        val_loss = 0.0
        all_probs = []
        all_labels = []
        
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                
                graph_embedding = model(batch.x, batch.edge_index, batch.batch)
                logits = classifier(graph_embedding)
                probs = torch.sigmoid(logits)
                
                y_true = extract_graph_labels(batch, num_classes, device)
                y_true = _ensure_shape_compatibility(y_true, logits, num_classes, device)
                
                loss = criterion(logits, y_true.float())
                val_loss += loss.item()
                
                all_probs.append(probs.detach().cpu().numpy())
                all_labels.append(y_true.detach().cpu().numpy())
        
        all_probs = np.vstack(all_probs)
        all_labels = np.vstack(all_labels)
        
        # IMPROVED: Better threshold optimization
        optimal_thresholds, per_class_f1 = _find_optimal_thresholds(
            all_probs, all_labels, num_classes
        )
        
        # Apply thresholds
        all_preds = (all_probs > optimal_thresholds).astype(int)
        
        # Calculate metrics
        micro_f1 = f1_score(all_labels, all_preds, average='micro', zero_division=0)
        macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        micro_precision = precision_score(all_labels, all_preds, average='micro', zero_division=0)
        micro_recall = recall_score(all_labels, all_preds, average='micro', zero_division=0)
        hamming = hamming_loss(all_labels, all_preds)
        
        avg_val_loss = val_loss / len(val_loader)
        mean_val_prob = all_probs.mean()
        
        # Print progress
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"\n{'='*80}")
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"{'='*80}")
            print(f"üìâ Loss:        Train={avg_train_loss:.4f} | Val={avg_val_loss:.4f}")
            print(f"üìä Mean Prob:   Train={avg_train_prob:.4f} | Val={mean_val_prob:.4f}")
            print(f"üéØ Micro F1:    {micro_f1:.4f} (P={micro_precision:.4f}, R={micro_recall:.4f})")
            print(f"üéØ Macro F1:    {macro_f1:.4f}")
            print(f"üéØ Hamming:     {hamming:.4f}")
            
            # Show class statistics
            low_f1_classes = np.argsort(per_class_f1)[:3]
            print(f"\n‚ö†Ô∏è  Lowest F1 Classes:")
            for i in low_f1_classes:
                pos_ratio = all_labels[:, i].mean()
                print(f"   Class {i}: F1={per_class_f1[i]:.3f}, "
                      f"Thresh={optimal_thresholds[i]:.3f}, "
                      f"PosRatio={pos_ratio:.3f}")
        
        # LR scheduling
        if epoch >= 5:
            scheduler.step(micro_f1)
            
            # Print current learning rate
            if (epoch + 1) % 10 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"   Current LR: {current_lr:.6f}")
        
        # Save best model
        if micro_f1 > best_f1:
            best_f1 = micro_f1
            best_epoch = epoch
            patience_counter = 0
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_f1': best_f1,
                'optimal_thresholds': optimal_thresholds,
                'mean_probability': mean_val_prob,
                'per_class_f1': per_class_f1,
                'all_probs_val': all_probs,  # For analysis
                'all_labels_val': all_labels
            }, 'best_model_improved.pt')
            
            print(f"\n‚úÖ Saved best model!")
            print(f"   Micro F1: {best_f1:.4f}")
            print(f"   Mean Prob: {mean_val_prob:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch+1}")
                print(f"   Best F1 was {best_f1:.4f} at epoch {best_epoch+1}")
                break
    
    print("\n" + "="*80)
    print("‚ú® TRAINING COMPLETED")
    print("="*80)
    print(f"üèÜ Best Micro F1:     {best_f1:.4f}")
    print(f"üìç Best Epoch:        {best_epoch+1}")
    print(f"üìä Final Mean Prob:   {mean_val_prob:.4f}")
    print("="*80 + "\n")
    
    # Load best model
    checkpoint = torch.load('best_model_improved.pt', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    classifier.load_state_dict(checkpoint['classifier_state_dict'])
    
    return model, classifier, checkpoint['optimal_thresholds']


In [28]:
from sklearn.metrics import (
    accuracy_score, f1_score, hamming_loss, jaccard_score,
    precision_score, recall_score, classification_report
)
def evaluate_multilabel(model, classifier, data_loader, device, 
                         num_classes=12, optimal_thresholds=None):
    """
    Fixed evaluation function - no more RuntimeError!
    """
    
    model.eval()
    classifier.eval()
    
    all_probs = []
    all_labels = []
    
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            
            graph_embedding = model(batch.x, batch.edge_index, batch.batch)
            logits = classifier(graph_embedding)
            probs = torch.sigmoid(logits)
            
            y_true = extract_graph_labels(batch, num_classes, device)
            
            if y_true.shape != logits.shape:
                y_true = y_true.view(logits.shape[0], -1)
                if y_true.shape[1] < num_classes:
                    pad = num_classes - y_true.shape[1]
                    y_true = torch.cat([
                        y_true,
                        torch.zeros(y_true.shape[0], pad, device=device)
                    ], dim=1)
                elif y_true.shape[1] > num_classes:
                    y_true = y_true[:, :num_classes]
            
            # FIX: Add .detach()
            all_probs.append(probs.detach().cpu().numpy())
            all_labels.append(y_true.detach().cpu().numpy())
    
    all_probs = np.vstack(all_probs)
    all_labels = np.vstack(all_labels)
    
    # Apply thresholds
    if optimal_thresholds is None:
        optimal_thresholds = [0.05] * num_classes
    
    all_preds = np.zeros_like(all_probs)
    for i in range(num_classes):
        all_preds[:, i] = (all_probs[:, i] > optimal_thresholds[i]).astype(int)
    
    # Metrics
    from sklearn.metrics import accuracy_score
    micro_f1 = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    micro_p = precision_score(all_labels, all_preds, average='micro', zero_division=0)
    micro_r = recall_score(all_labels, all_preds, average='micro', zero_division=0)
    hamming = hamming_loss(all_labels, all_preds)
    subset_acc = accuracy_score(all_labels, all_preds)
    
    per_class_f1 = f1_score(all_labels, all_preds, average=None, zero_division=0)
    per_class_p = precision_score(all_labels, all_preds, average=None, zero_division=0)
    per_class_r = recall_score(all_labels, all_preds, average=None, zero_division=0)
    
    print("\n" + "="*80)
    print("üìä FINAL EVALUATION RESULTS")
    print("="*80)
    print(f"\nüéØ Overall Metrics:")
    print(f"   Micro F1:        {micro_f1:.4f}")
    print(f"   Macro F1:        {macro_f1:.4f}")
    print(f"   Micro Precision: {micro_p:.4f}")
    print(f"   Micro Recall:    {micro_r:.4f}")
    print(f"   Hamming Loss:    {hamming:.4f}")
    print(f"   Subset Accuracy: {subset_acc:.4f}")
    
    print(f"\nüìà Per-Class Results:")
    print(f"{'Class':<8} {'Threshold':<12} {'F1':<10} {'Precision':<12} {'Recall':<10} {'MeanProb':<10}")
    print("-" * 80)
    
    class_names = [
        "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
        "NR-ER", "NR-ER-LBD", "NR-PPAR-Œ≥",
        "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
    ]
    
    for i in range(num_classes):
        name = class_names[i] if i < len(class_names) else f"Class{i}"
        print(f"{name:<8} {optimal_thresholds[i]:<12.3f} "
              f"{per_class_f1[i]:<10.3f} {per_class_p[i]:<12.3f} "
              f"{per_class_r[i]:<10.3f} {all_probs[:, i].mean():<10.3f}")
    
    print("="*80 + "\n")
    
    return {
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels
    }



# Embedding

In [29]:
from tqdm import tqdm

def extract_embeddings(model, classifier, data_loader, device, 
                       num_classes=12, save_path='embeddings.pt'):
    """
    Extract graph embeddings t·ª´ GAT model v√† l∆∞u ra file .pt
    
    Args:
        model: Trained GAT model
        classifier: Trained classifier
        data_loader: DataLoader (train/val/test)
        device: torch device
        num_classes: S·ªë l∆∞·ª£ng classes
        save_path: ƒê∆∞·ªùng d·∫´n file output (.pt)
    
    Returns:
        embeddings_dict: Dictionary ch·ª©a embeddings v√† metadata
    """
    
    print("\n" + "="*80)
    print("üîç EXTRACTING GRAPH EMBEDDINGS")
    print("="*80)
    
    model.to(device)
    classifier.to(device)
    model.eval()
    classifier.eval()
    
    all_embeddings = []
    all_logits = []
    all_probs = []
    all_labels = []
    all_graph_ids = []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Processing batches")):
            batch = batch.to(device)
            
            # Extract embeddings t·ª´ GAT
            graph_embedding = model(batch.x, batch.edge_index, batch.batch)
            
            # Get logits v√† probabilities t·ª´ classifier
            logits = classifier(graph_embedding)
            probs = torch.sigmoid(logits)
            
            # Get labels
            y_true = extract_graph_labels(batch, num_classes, device)
            
            # Ensure shape compatibility
            if y_true.shape != logits.shape:
                y_true = y_true.view(logits.shape[0], -1)
                if y_true.shape[1] < num_classes:
                    pad = num_classes - y_true.shape[1]
                    y_true = torch.cat([
                        y_true,
                        torch.zeros(y_true.shape[0], pad, device=device)
                    ], dim=1)
                elif y_true.shape[1] > num_classes:
                    y_true = y_true[:, :num_classes]
            
            # Store embeddings
            all_embeddings.append(graph_embedding.cpu())
            all_logits.append(logits.cpu())
            all_probs.append(probs.cpu())
            all_labels.append(y_true.cpu())
            
            # Store graph IDs n·∫øu c√≥
            if hasattr(batch, 'graph_id'):
                all_graph_ids.extend(batch.graph_id.cpu().tolist())
            else:
                # Generate sequential IDs
                batch_size = graph_embedding.shape[0]
                start_id = batch_idx * data_loader.batch_size
                all_graph_ids.extend(range(start_id, start_id + batch_size))
    
    # Concatenate all batches
    embeddings = torch.cat(all_embeddings, dim=0)
    logits = torch.cat(all_logits, dim=0)
    probs = torch.cat(all_probs, dim=0)
    labels = torch.cat(all_labels, dim=0)
    
    print(f"\n‚úÖ Extracted {embeddings.shape[0]} graph embeddings")
    print(f"   Embedding dimension: {embeddings.shape[1]}")
    print(f"   Number of classes: {num_classes}")
    
    # Create embeddings dictionary
    embeddings_dict = {
        'embeddings': embeddings,              # [N, embedding_dim]
        'logits': logits,                      # [N, num_classes]
        'probabilities': probs,                # [N, num_classes]
        'labels': labels,                      # [N, num_classes]
        'graph_ids': all_graph_ids,           # List of graph IDs
        'embedding_dim': embeddings.shape[1],
        'num_graphs': embeddings.shape[0],
        'num_classes': num_classes,
        'model_info': {
            'model_type': model.__class__.__name__,
            'classifier_type': classifier.__class__.__name__,
        }
    }
    
    # Save to file
    torch.save(embeddings_dict, save_path)
    print(f"\nüíæ Saved embeddings to: {save_path}")
    print(f"   File size: {os.path.getsize(save_path) / (1024*1024):.2f} MB")
    
    # Print statistics
    print(f"\nüìä EMBEDDING STATISTICS:")
    print(f"   Mean: {embeddings.mean().item():.4f}")
    print(f"   Std:  {embeddings.std().item():.4f}")
    print(f"   Min:  {embeddings.min().item():.4f}")
    print(f"   Max:  {embeddings.max().item():.4f}")
    
    print(f"\nüìä PROBABILITY STATISTICS:")
    print(f"   Mean: {probs.mean().item():.4f}")
    print(f"   Std:  {probs.std().item():.4f}")
    
    print("="*80 + "\n")
    
    return embeddings_dict



In [30]:
def extract_embeddings_by_split(model, classifier, train_loader, val_loader, 
                                test_loader, device, num_classes=12, 
                                output_dir='embeddings'):
    """
    Extract embeddings cho t·∫•t c·∫£ c√°c splits (train/val/test) v√† l∆∞u ri√™ng
    
    Args:
        model: Trained GAT model
        classifier: Trained classifier
        train_loader, val_loader, test_loader: DataLoaders
        device: torch device
        num_classes: S·ªë l∆∞·ª£ng classes
        output_dir: Th∆∞ m·ª•c output
    
    Returns:
        Dictionary ch·ª©a paths ƒë·∫øn c√°c embedding files
    """
    import os
    os.makedirs(output_dir, exist_ok=True)
    
    paths = {}
    
    # Extract train embeddings
    if train_loader is not None:
        print("\nüîπ Extracting TRAIN embeddings...")
        train_path = os.path.join(output_dir, 'train_embeddings.pt')
        extract_embeddings(model, classifier, train_loader, device, 
                          num_classes, train_path)
        paths['train'] = train_path
    
    # Extract validation embeddings
    if val_loader is not None:
        print("\nüîπ Extracting VALIDATION embeddings...")
        val_path = os.path.join(output_dir, 'val_embeddings.pt')
        extract_embeddings(model, classifier, val_loader, device, 
                          num_classes, val_path)
        paths['val'] = val_path
    
    # Extract test embeddings
    if test_loader is not None:
        print("\nüîπ Extracting TEST embeddings...")
        test_path = os.path.join(output_dir, 'test_embeddings.pt')
        extract_embeddings(model, classifier, test_loader, device, 
                          num_classes, test_path)
        paths['test'] = test_path
    
    print("\n" + "="*80)
    print("‚ú® ALL EMBEDDINGS EXTRACTED SUCCESSFULLY")
    print("="*80)
    for split, path in paths.items():
        print(f"   {split.upper()}: {path}")
    print("="*80 + "\n")
    
    return paths


In [33]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"‚úÖ Dataset: {len(dataset)} samples")
print(f"   Train: {len(train_dataset)} samples")
print(f"   Val:   {len(val_dataset)} samples")
print(f"   Test:  {len(test_dataset)} samples\n")

# ============================================================================
# SETUP DEVICE
# ============================================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üì± Device: {device}\n")

# ============================================================================
# INITIALIZE MODELS (FIX: X√≥a d√≤ng classifier = nn.Linear th·ª´a)
# ============================================================================
print("üîß Initializing models...")

# 1. GAT Model
model = GAT(
    input_dim=7,          # s·ªë ƒë·∫∑c tr∆∞ng input c·ªßa node
    hidden_dim=128,       # hidden dimension
    num_heads=8,          # s·ªë attention heads
    num_layers=3,         # s·ªë layer GAT
    embedding_dim=512,    # k√≠ch th∆∞·ªõc embedding cu·ªëi
    dropout=0.2
)

# 2. Classifier (FIX: Ch·ªâ kh·ªüi t·∫°o 1 l·∫ßn)
classifier = MultiLabelClassifier(
    embedding_dim=512,
    num_classes=12,
    hidden_dim=256,
    dropout=0.3,
    temperature=1.0
)

print(f"‚úÖ GAT parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"‚úÖ Classifier parameters: {sum(p.numel() for p in classifier.parameters()):,}\n")

# ============================================================================
# TRAINING
# ============================================================================
print("üöÄ Starting training...\n")

model, classifier, optimal_thresholds = train_gat(
    model=model,
    classifier=classifier,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_classes=12,
    epochs=200,
    patience=20
)

# ============================================================================
# EVALUATION ON TEST SET
# ============================================================================
print("\nüìä Evaluating on test set...")

test_results = evaluate_multilabel(
    model=model,
    classifier=classifier,
    data_loader=test_loader,
    device=device,
    num_classes=12,
    optimal_thresholds=optimal_thresholds
)

# Print final results
print("\n" + "="*80)
print("üèÜ FINAL TEST RESULTS")
print("="*80)
print(f"Micro F1:       {test_results['micro_f1']:.4f}")
print(f"Macro F1:       {test_results['macro_f1']:.4f}")




print("="*80 + "\n")

# ============================================================================
# EXTRACT EMBEDDINGS (OPTIONAL)
# ============================================================================
print("üîç Extracting embeddings from trained model...")

embeddings_paths = extract_embeddings_by_split(
    model=model,
    classifier=classifier,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    num_classes=12,
    output_dir='embeddings'
)

  train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
  val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
  test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


‚úÖ Dataset: 8014 samples
   Train: 6411 samples
   Val:   801 samples
   Test:  802 samples

üì± Device: cuda

üîß Initializing models...
‚úÖ GAT parameters: 4,686,208
‚úÖ Classifier parameters: 166,541

üöÄ Starting training...


üìä ANALYZING TRAINING DATA

üìà Class Distribution:
Class    Samples    Ratio      Imbalance
--------------------------------------------------------------------------------
Class 0  255        3.98     % 24.14:1
Class 1  188        2.93     % 33.10:1
Class 2  643        10.03    % 8.97:1
Class 3  223        3.48     % 27.75:1
Class 4  644        10.05    % 8.95:1
Class 5  281        4.38     % 21.81:1
Class 6  154        2.40     % 40.63:1
Class 7  774        12.07    % 7.28:1
Class 8  228        3.56     % 27.12:1
Class 9  298        4.65     % 20.51:1
Class 10 755        11.78    % 7.49:1
Class 11 345        5.38     % 17.58:1

üöÄ STARTING TRAINING

Epoch 1/200
üìâ Loss:        Train=0.0886 | Val=0.0421
üìä Mean Prob:   Train=0.4334 | Val=0.4450




‚úÖ Saved best model!
   Micro F1: 0.2292
   Mean Prob: 0.4592





‚úÖ Saved best model!
   Micro F1: 0.2560
   Mean Prob: 0.4631





‚úÖ Saved best model!
   Micro F1: 0.2880
   Mean Prob: 0.4739





Epoch 5/200
üìâ Loss:        Train=0.0418 | Val=0.0376
üìä Mean Prob:   Train=0.4535 | Val=0.4643
üéØ Micro F1:    0.3009 (P=0.2393, R=0.4050)
üéØ Macro F1:    0.3027
üéØ Hamming:     0.1165

‚ö†Ô∏è  Lowest F1 Classes:
   Class 6: F1=0.167, Thresh=0.511, PosRatio=0.022
   Class 9: F1=0.168, Thresh=0.503, PosRatio=0.051
   Class 3: F1=0.193, Thresh=0.481, PosRatio=0.051

‚úÖ Saved best model!
   Micro F1: 0.3009
   Mean Prob: 0.4643





‚úÖ Saved best model!
   Micro F1: 0.3119
   Mean Prob: 0.4622





‚úÖ Saved best model!
   Micro F1: 0.3287
   Mean Prob: 0.4556





Epoch 10/200
üìâ Loss:        Train=0.0375 | Val=0.0340
üìä Mean Prob:   Train=0.4506 | Val=0.4304
üéØ Micro F1:    0.3185 (P=0.2778, R=0.3731)
üéØ Macro F1:    0.3201
üéØ Hamming:     0.0988

‚ö†Ô∏è  Lowest F1 Classes:
   Class 6: F1=0.143, Thresh=0.486, PosRatio=0.022
   Class 3: F1=0.222, Thresh=0.505, PosRatio=0.051
   Class 11: F1=0.223, Thresh=0.506, PosRatio=0.060
   Current LR: 0.001000





‚úÖ Saved best model!
   Micro F1: 0.3511
   Mean Prob: 0.4501





Epoch 15/200
üìâ Loss:        Train=0.0367 | Val=0.0338
üìä Mean Prob:   Train=0.4510 | Val=0.4418
üéØ Micro F1:    0.3417 (P=0.2694, R=0.4672)
üéØ Macro F1:    0.3544
üéØ Hamming:     0.1114

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.189, Thresh=0.445, PosRatio=0.051
   Class 6: F1=0.244, Thresh=0.498, PosRatio=0.022
   Class 11: F1=0.277, Thresh=0.529, PosRatio=0.060





‚úÖ Saved best model!
   Micro F1: 0.3779
   Mean Prob: 0.4584





‚úÖ Saved best model!
   Micro F1: 0.3827
   Mean Prob: 0.4322





Epoch 20/200
üìâ Loss:        Train=0.0356 | Val=0.0338
üìä Mean Prob:   Train=0.4438 | Val=0.4471
üéØ Micro F1:    0.3884 (P=0.3350, R=0.4622)
üéØ Macro F1:    0.3961
üéØ Hamming:     0.0901

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.221, Thresh=0.511, PosRatio=0.051
   Class 6: F1=0.312, Thresh=0.547, PosRatio=0.022
   Class 11: F1=0.312, Thresh=0.524, PosRatio=0.060
   Current LR: 0.001000

‚úÖ Saved best model!
   Micro F1: 0.3884
   Mean Prob: 0.4471





‚úÖ Saved best model!
   Micro F1: 0.3906
   Mean Prob: 0.4534





Epoch 25/200
üìâ Loss:        Train=0.0352 | Val=0.0376
üìä Mean Prob:   Train=0.4432 | Val=0.4652
üéØ Micro F1:    0.3930 (P=0.3485, R=0.4504)
üéØ Macro F1:    0.3975
üéØ Hamming:     0.0861

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.235, Thresh=0.537, PosRatio=0.051
   Class 9: F1=0.261, Thresh=0.547, PosRatio=0.051
   Class 6: F1=0.291, Thresh=0.534, PosRatio=0.022

‚úÖ Saved best model!
   Micro F1: 0.3930
   Mean Prob: 0.4652





‚úÖ Saved best model!
   Micro F1: 0.4149
   Mean Prob: 0.4483





Epoch 30/200
üìâ Loss:        Train=0.0346 | Val=0.0307
üìä Mean Prob:   Train=0.4415 | Val=0.4226
üéØ Micro F1:    0.4030 (P=0.3669, R=0.4471)
üéØ Macro F1:    0.3993
üéØ Hamming:     0.0820

‚ö†Ô∏è  Lowest F1 Classes:
   Class 6: F1=0.219, Thresh=0.476, PosRatio=0.022
   Class 3: F1=0.250, Thresh=0.517, PosRatio=0.051
   Class 11: F1=0.307, Thresh=0.520, PosRatio=0.060
   Current LR: 0.001000





‚úÖ Saved best model!
   Micro F1: 0.4238
   Mean Prob: 0.4428





Epoch 35/200
üìâ Loss:        Train=0.0341 | Val=0.0367
üìä Mean Prob:   Train=0.4367 | Val=0.4629
üéØ Micro F1:    0.4009 (P=0.3704, R=0.4370)
üéØ Macro F1:    0.3955
üéØ Hamming:     0.0808

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.250, Thresh=0.525, PosRatio=0.051
   Class 6: F1=0.261, Thresh=0.604, PosRatio=0.022
   Class 8: F1=0.279, Thresh=0.589, PosRatio=0.022





Epoch 40/200
üìâ Loss:        Train=0.0335 | Val=0.0362
üìä Mean Prob:   Train=0.4342 | Val=0.4479
üéØ Micro F1:    0.3944 (P=0.3369, R=0.4756)
üéØ Macro F1:    0.4063
üéØ Hamming:     0.0904

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.241, Thresh=0.494, PosRatio=0.051
   Class 9: F1=0.289, Thresh=0.559, PosRatio=0.051
   Class 6: F1=0.292, Thresh=0.529, PosRatio=0.022
   Current LR: 0.001000





‚úÖ Saved best model!
   Micro F1: 0.4287
   Mean Prob: 0.4575





Epoch 45/200
üìâ Loss:        Train=0.0321 | Val=0.0353
üìä Mean Prob:   Train=0.4268 | Val=0.4477
üéØ Micro F1:    0.4276 (P=0.3930, R=0.4689)
üéØ Macro F1:    0.4463
üéØ Hamming:     0.0777

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.263, Thresh=0.495, PosRatio=0.051
   Class 9: F1=0.328, Thresh=0.530, PosRatio=0.051
   Class 11: F1=0.350, Thresh=0.635, PosRatio=0.060





‚úÖ Saved best model!
   Micro F1: 0.4365
   Mean Prob: 0.4402





‚úÖ Saved best model!
   Micro F1: 0.4563
   Mean Prob: 0.4404





Epoch 50/200
üìâ Loss:        Train=0.0316 | Val=0.0351
üìä Mean Prob:   Train=0.4205 | Val=0.4460
üéØ Micro F1:    0.4419 (P=0.4277, R=0.4571)
üéØ Macro F1:    0.4500
üéØ Hamming:     0.0715

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.262, Thresh=0.543, PosRatio=0.051
   Class 9: F1=0.330, Thresh=0.552, PosRatio=0.051
   Class 11: F1=0.341, Thresh=0.632, PosRatio=0.060
   Current LR: 0.000500





‚úÖ Saved best model!
   Micro F1: 0.4789
   Mean Prob: 0.4339





Epoch 55/200
üìâ Loss:        Train=0.0312 | Val=0.0318
üìä Mean Prob:   Train=0.4177 | Val=0.4303
üéØ Micro F1:    0.4583 (P=0.4112, R=0.5176)
üéØ Macro F1:    0.4722
üéØ Hamming:     0.0757

‚ö†Ô∏è  Lowest F1 Classes:
   Class 11: F1=0.327, Thresh=0.542, PosRatio=0.060
   Class 9: F1=0.331, Thresh=0.509, PosRatio=0.051
   Class 3: F1=0.349, Thresh=0.543, PosRatio=0.051





Epoch 60/200
üìâ Loss:        Train=0.0306 | Val=0.0310
üìä Mean Prob:   Train=0.4110 | Val=0.4168
üéØ Micro F1:    0.4705 (P=0.4423, R=0.5025)
üéØ Macro F1:    0.4743
üéØ Hamming:     0.0700

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.310, Thresh=0.576, PosRatio=0.051
   Class 11: F1=0.343, Thresh=0.550, PosRatio=0.060
   Class 9: F1=0.366, Thresh=0.539, PosRatio=0.051
   Current LR: 0.000500





‚úÖ Saved best model!
   Micro F1: 0.4870
   Mean Prob: 0.4197





Epoch 65/200
üìâ Loss:        Train=0.0294 | Val=0.0321
üìä Mean Prob:   Train=0.4068 | Val=0.4198
üéØ Micro F1:    0.4801 (P=0.5056, R=0.4571)
üéØ Macro F1:    0.4893
üéØ Hamming:     0.0613

‚ö†Ô∏è  Lowest F1 Classes:
   Class 11: F1=0.343, Thresh=0.585, PosRatio=0.060
   Class 3: F1=0.361, Thresh=0.595, PosRatio=0.051
   Class 9: F1=0.391, Thresh=0.560, PosRatio=0.051





Epoch 70/200
üìâ Loss:        Train=0.0292 | Val=0.0320
üìä Mean Prob:   Train=0.4025 | Val=0.4208
üéØ Micro F1:    0.4744 (P=0.4297, R=0.5294)
üéØ Macro F1:    0.4908
üéØ Hamming:     0.0726

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.337, Thresh=0.545, PosRatio=0.051
   Class 11: F1=0.354, Thresh=0.566, PosRatio=0.060
   Class 9: F1=0.380, Thresh=0.536, PosRatio=0.051
   Current LR: 0.000250





‚úÖ Saved best model!
   Micro F1: 0.4911
   Mean Prob: 0.4103





Epoch 75/200
üìâ Loss:        Train=0.0283 | Val=0.0320
üìä Mean Prob:   Train=0.3944 | Val=0.4151
üéØ Micro F1:    0.4742 (P=0.4558, R=0.4941)
üéØ Macro F1:    0.4823
üéØ Hamming:     0.0678

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.340, Thresh=0.542, PosRatio=0.051
   Class 9: F1=0.348, Thresh=0.592, PosRatio=0.051
   Class 11: F1=0.390, Thresh=0.597, PosRatio=0.060





Epoch 80/200
üìâ Loss:        Train=0.0280 | Val=0.0301
üìä Mean Prob:   Train=0.3917 | Val=0.3934
üéØ Micro F1:    0.4869 (P=0.4622, R=0.5143)
üéØ Macro F1:    0.4979
üéØ Hamming:     0.0671

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.351, Thresh=0.536, PosRatio=0.051
   Class 9: F1=0.387, Thresh=0.536, PosRatio=0.051
   Class 11: F1=0.393, Thresh=0.582, PosRatio=0.060
   Current LR: 0.000125





‚úÖ Saved best model!
   Micro F1: 0.4926
   Mean Prob: 0.3954





Epoch 85/200
üìâ Loss:        Train=0.0276 | Val=0.0298
üìä Mean Prob:   Train=0.3884 | Val=0.3953
üéØ Micro F1:    0.4882 (P=0.4580, R=0.5227)
üéØ Macro F1:    0.4971
üéØ Hamming:     0.0678

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.374, Thresh=0.517, PosRatio=0.051
   Class 9: F1=0.400, Thresh=0.549, PosRatio=0.051
   Class 11: F1=0.403, Thresh=0.566, PosRatio=0.060





Epoch 90/200
üìâ Loss:        Train=0.0275 | Val=0.0304
üìä Mean Prob:   Train=0.3860 | Val=0.3949
üéØ Micro F1:    0.4933 (P=0.4619, R=0.5294)
üéØ Macro F1:    0.4986
üéØ Hamming:     0.0673

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.365, Thresh=0.530, PosRatio=0.051
   Class 11: F1=0.377, Thresh=0.565, PosRatio=0.060
   Class 9: F1=0.395, Thresh=0.550, PosRatio=0.051
   Current LR: 0.000125

‚úÖ Saved best model!
   Micro F1: 0.4933
   Mean Prob: 0.3949





Epoch 95/200
üìâ Loss:        Train=0.0272 | Val=0.0308
üìä Mean Prob:   Train=0.3831 | Val=0.3936
üéØ Micro F1:    0.4898 (P=0.4776, R=0.5025)
üéØ Macro F1:    0.4931
üéØ Hamming:     0.0648

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.330, Thresh=0.532, PosRatio=0.051
   Class 9: F1=0.391, Thresh=0.552, PosRatio=0.051
   Class 11: F1=0.406, Thresh=0.564, PosRatio=0.060





Epoch 100/200
üìâ Loss:        Train=0.0268 | Val=0.0304
üìä Mean Prob:   Train=0.3817 | Val=0.3906
üéØ Micro F1:    0.4808 (P=0.4594, R=0.5042)
üéØ Macro F1:    0.4931
üéØ Hamming:     0.0674

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.337, Thresh=0.533, PosRatio=0.051
   Class 9: F1=0.368, Thresh=0.551, PosRatio=0.051
   Class 11: F1=0.373, Thresh=0.559, PosRatio=0.060
   Current LR: 0.000063





Epoch 105/200
üìâ Loss:        Train=0.0265 | Val=0.0305
üìä Mean Prob:   Train=0.3776 | Val=0.3831
üéØ Micro F1:    0.4804 (P=0.4574, R=0.5059)
üéØ Macro F1:    0.4804
üéØ Hamming:     0.0677

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.330, Thresh=0.533, PosRatio=0.051
   Class 9: F1=0.348, Thresh=0.588, PosRatio=0.051
   Class 11: F1=0.377, Thresh=0.557, PosRatio=0.060





‚úÖ Saved best model!
   Micro F1: 0.4938
   Mean Prob: 0.3833





Epoch 110/200
üìâ Loss:        Train=0.0266 | Val=0.0307
üìä Mean Prob:   Train=0.3757 | Val=0.3848
üéØ Micro F1:    0.4848 (P=0.4532, R=0.5210)
üéØ Macro F1:    0.4919
üéØ Hamming:     0.0686

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.353, Thresh=0.529, PosRatio=0.051
   Class 9: F1=0.391, Thresh=0.556, PosRatio=0.051
   Class 11: F1=0.406, Thresh=0.560, PosRatio=0.060
   Current LR: 0.000063





Epoch 115/200
üìâ Loss:        Train=0.0266 | Val=0.0300
üìä Mean Prob:   Train=0.3746 | Val=0.3841
üéØ Micro F1:    0.4987 (P=0.4967, R=0.5008)
üéØ Macro F1:    0.4956
üéØ Hamming:     0.0623

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.328, Thresh=0.586, PosRatio=0.051
   Class 9: F1=0.368, Thresh=0.538, PosRatio=0.051
   Class 11: F1=0.417, Thresh=0.577, PosRatio=0.060

‚úÖ Saved best model!
   Micro F1: 0.4987
   Mean Prob: 0.3841





Epoch 120/200
üìâ Loss:        Train=0.0261 | Val=0.0299
üìä Mean Prob:   Train=0.3715 | Val=0.3779
üéØ Micro F1:    0.5000 (P=0.4835, R=0.5176)
üéØ Macro F1:    0.5041
üéØ Hamming:     0.0641

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.328, Thresh=0.586, PosRatio=0.051
   Class 9: F1=0.384, Thresh=0.531, PosRatio=0.051
   Class 11: F1=0.441, Thresh=0.554, PosRatio=0.060
   Current LR: 0.000063

‚úÖ Saved best model!
   Micro F1: 0.5000
   Mean Prob: 0.3779





Epoch 125/200
üìâ Loss:        Train=0.0261 | Val=0.0304
üìä Mean Prob:   Train=0.3698 | Val=0.3734
üéØ Micro F1:    0.4919 (P=0.4743, R=0.5109)
üéØ Macro F1:    0.4948
üéØ Hamming:     0.0653

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.324, Thresh=0.575, PosRatio=0.051
   Class 9: F1=0.382, Thresh=0.555, PosRatio=0.051
   Class 11: F1=0.415, Thresh=0.562, PosRatio=0.060





Epoch 130/200
üìâ Loss:        Train=0.0258 | Val=0.0306
üìä Mean Prob:   Train=0.3690 | Val=0.3791
üéØ Micro F1:    0.4874 (P=0.4457, R=0.5378)
üéØ Macro F1:    0.4930
üéØ Hamming:     0.0700

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.326, Thresh=0.506, PosRatio=0.051
   Class 9: F1=0.358, Thresh=0.548, PosRatio=0.051
   Class 11: F1=0.406, Thresh=0.567, PosRatio=0.060
   Current LR: 0.000031





‚úÖ Saved best model!
   Micro F1: 0.5056
   Mean Prob: 0.3706





Epoch 135/200
üìâ Loss:        Train=0.0258 | Val=0.0307
üìä Mean Prob:   Train=0.3679 | Val=0.3761
üéØ Micro F1:    0.4863 (P=0.4559, R=0.5210)
üéØ Macro F1:    0.4936
üéØ Hamming:     0.0681

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.317, Thresh=0.513, PosRatio=0.051
   Class 9: F1=0.400, Thresh=0.561, PosRatio=0.051
   Class 11: F1=0.435, Thresh=0.577, PosRatio=0.060





Epoch 140/200
üìâ Loss:        Train=0.0258 | Val=0.0303
üìä Mean Prob:   Train=0.3672 | Val=0.3750
üéØ Micro F1:    0.4947 (P=0.4781, R=0.5126)
üéØ Macro F1:    0.4913
üéØ Hamming:     0.0648

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.324, Thresh=0.590, PosRatio=0.051
   Class 9: F1=0.383, Thresh=0.534, PosRatio=0.051
   Class 11: F1=0.412, Thresh=0.557, PosRatio=0.060
   Current LR: 0.000031





Epoch 145/200
üìâ Loss:        Train=0.0256 | Val=0.0298
üìä Mean Prob:   Train=0.3661 | Val=0.3721
üéØ Micro F1:    0.4980 (P=0.4841, R=0.5126)
üéØ Macro F1:    0.4944
üéØ Hamming:     0.0640

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.333, Thresh=0.585, PosRatio=0.051
   Class 9: F1=0.380, Thresh=0.518, PosRatio=0.051
   Class 11: F1=0.443, Thresh=0.570, PosRatio=0.060





Epoch 150/200
üìâ Loss:        Train=0.0254 | Val=0.0309
üìä Mean Prob:   Train=0.3656 | Val=0.3769
üéØ Micro F1:    0.4963 (P=0.4840, R=0.5092)
üéØ Macro F1:    0.5024
üéØ Hamming:     0.0640

‚ö†Ô∏è  Lowest F1 Classes:
   Class 3: F1=0.328, Thresh=0.522, PosRatio=0.051
   Class 9: F1=0.400, Thresh=0.541, PosRatio=0.051
   Class 11: F1=0.422, Thresh=0.577, PosRatio=0.060
   Current LR: 0.000008





‚ö†Ô∏è  Early stopping triggered at epoch 152
   Best F1 was 0.5056 at epoch 132

‚ú® TRAINING COMPLETED
üèÜ Best Micro F1:     0.5056
üìç Best Epoch:        132
üìä Final Mean Prob:   0.3840


üìä Evaluating on test set...

üìä FINAL EVALUATION RESULTS

üéØ Overall Metrics:
   Micro F1:        0.4223
   Macro F1:        0.3892
   Micro Precision: 0.4038
   Micro Recall:    0.4425
   Hamming Loss:    0.0722
   Subset Accuracy: 0.5835

üìà Per-Class Results:
Class    Threshold    F1         Precision    Recall     MeanProb  
--------------------------------------------------------------------------------
NR-AR    0.732        0.533      0.800        0.400      0.380     
NR-AR-LBD 0.592        0.488      0.476        0.500      0.287     
NR-AhR   0.606        0.518      0.440        0.629      0.408     
NR-Aromatase 0.596        0.317      0.500        0.233      0.341     
NR-ER    0.656        0.250      0.438        0.175      0.481     
NR-ER-LBD 0.661        0.320      0.

Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 201/201 [00:01<00:00, 147.55it/s]



‚úÖ Extracted 6411 graph embeddings
   Embedding dimension: 512
   Number of classes: 12

üíæ Saved embeddings to: embeddings/train_embeddings.pt
   File size: 13.42 MB

üìä EMBEDDING STATISTICS:
   Mean: 0.2929
   Std:  0.4950
   Min:  0.0000
   Max:  21.8663

üìä PROBABILITY STATISTICS:
   Mean: 0.3699
   Std:  0.1538


üîπ Extracting VALIDATION embeddings...

üîç EXTRACTING GRAPH EMBEDDINGS


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 26/26 [00:00<00:00, 149.86it/s]



‚úÖ Extracted 801 graph embeddings
   Embedding dimension: 512
   Number of classes: 12

üíæ Saved embeddings to: embeddings/val_embeddings.pt
   File size: 1.68 MB

üìä EMBEDDING STATISTICS:
   Mean: 0.2857
   Std:  0.4648
   Min:  0.0000
   Max:  9.7805

üìä PROBABILITY STATISTICS:
   Mean: 0.3706
   Std:  0.1492


üîπ Extracting TEST embeddings...

üîç EXTRACTING GRAPH EMBEDDINGS


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 26/26 [00:00<00:00, 153.27it/s]


‚úÖ Extracted 802 graph embeddings
   Embedding dimension: 512
   Number of classes: 12

üíæ Saved embeddings to: embeddings/test_embeddings.pt
   File size: 1.68 MB

üìä EMBEDDING STATISTICS:
   Mean: 0.2921
   Std:  0.4867
   Min:  0.0000
   Max:  11.2754

üìä PROBABILITY STATISTICS:
   Mean: 0.3727
   Std:  0.1492


‚ú® ALL EMBEDDINGS EXTRACTED SUCCESSFULLY
   TRAIN: embeddings/train_embeddings.pt
   VAL: embeddings/val_embeddings.pt
   TEST: embeddings/test_embeddings.pt




