# ChemBERTA model 

## imports

In [1]:
import os, json, math, time, random
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from sklearn.utils import resample

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

try:
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
except Exception:
    MultilabelStratifiedShuffleSplit = None
    print("iterstrat not available; will fall back to random split.")

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

## dataset paths

In [2]:
# Prefer LightGBM CSV so we share labels/splits
LGB_CSV = Path("tox21_lightgb_pipeline/Data_v6/processed/tox21.csv")
CHEMBERTA_CSV = Path("tox21_chembera_pipeline/Data_v6/processed/tox21.csv")
DATA_CSV = LGB_CSV if LGB_CSV.exists() else CHEMBERTA_CSV

ROOT = Path("tox21_chembera_pipeline")
MODELS_DIR = ROOT / "models" / "chemberta_v1"
SPLITS_DIR = ROOT / "Data_v6" / "splits"
OUT_DIR = ROOT / "outputs"

for p in [MODELS_DIR, SPLITS_DIR, OUT_DIR]:
    p.mkdir(parents=True, exist_ok=True)

print("Using data:", DATA_CSV.resolve())
print("Saving models to:", MODELS_DIR.resolve())
print("Saving outputs to:", OUT_DIR.resolve())


Using data: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\tox21_lightgb_pipeline\Data_v6\processed\tox21.csv
Saving models to: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\tox21_chembera_pipeline\models\chemberta_v1
Saving outputs to: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\tox21_chembera_pipeline\outputs


## loading CSV, Labels/meta columns

In [3]:
df = pd.read_csv(DATA_CSV)
display(df.head())

# Assumption: first 12 columns are labels, then 'mol_id','smiles'
all_cols = df.columns.tolist()
assert 'smiles' in all_cols, "CSV must contain a 'smiles' column."

label_cols = all_cols[:12]
meta_cols = all_cols[12:]
print("Label columns:", label_cols)
print("Meta columns:", meta_cols)

# --- Clean labels: coerce to numeric, replace inf with NaN, then fill NaN with 0 and cast to int ---
for c in label_cols:
    df[c] = pd.to_numeric(df[c], errors='coerce')  # convert strings like 'nan' to NaN
# replace +/-inf -> NaN
df[label_cols] = df[label_cols].replace([np.inf, -np.inf], np.nan)

# report NaNs before filling
nan_counts = df[label_cols].isna().sum()
if nan_counts.sum() > 0:
    print("NaNs detected in label columns (filling with 0):")
    print(nan_counts[nan_counts > 0])

# fill NaNs with 0 (common practice for Tox21 preprocessing); clip to [0,1] and cast to int
df[label_cols] = df[label_cols].fillna(0.0).clip(lower=0.0, upper=1.0).astype(int)

# sanity check
unique_vals = {c: sorted(df[c].unique().tolist()) for c in label_cols}
print("Unique values per label (should be [0,1]):")
print(unique_vals)

Y = df[label_cols].values.astype(np.float32)
len(df), Y.shape


Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
0,0.0,0.0,1.0,,,0.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX3021,CCOc1ccc2nc(S(N)(=O)=O)sc2c1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3020,CCN1C(=O)NC(c2ccccc2)C1=O
2,,,,,,,,0.0,,0.0,,,TOX3024,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3027,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX20800,CC(O)(P(=O)(O)O)P(=O)(O)O


Label columns: ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
Meta columns: ['mol_id', 'smiles']
NaNs detected in label columns (filling with 0):
NR-AR             566
NR-AR-LBD        1073
NR-AhR           1282
NR-Aromatase     2010
NR-ER            1638
NR-ER-LBD         876
NR-PPAR-gamma    1381
SR-ARE           1999
SR-ATAD5          759
SR-HSE           1364
SR-MMP           2021
SR-p53           1057
dtype: int64
Unique values per label (should be [0,1]):
{'NR-AR': [0, 1], 'NR-AR-LBD': [0, 1], 'NR-AhR': [0, 1], 'NR-Aromatase': [0, 1], 'NR-ER': [0, 1], 'NR-ER-LBD': [0, 1], 'NR-PPAR-gamma': [0, 1], 'SR-ARE': [0, 1], 'SR-ATAD5': [0, 1], 'SR-HSE': [0, 1], 'SR-MMP': [0, 1], 'SR-p53': [0, 1]}


(7831, (7831, 12))

## creating our data folds, similar to the LightGBM model

In [4]:
train_idx_path = SPLITS_DIR / "train_idx.npy"
val_idx_path   = SPLITS_DIR / "val_idx.npy"
test_idx_path  = SPLITS_DIR / "test_idx.npy"

if train_idx_path.exists():
    train_idx = np.load(train_idx_path)
    val_idx = np.load(val_idx_path)
    test_idx = np.load(test_idx_path)
    print(f"Loaded splits: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")
else:
    n = len(df)
    test_size = 0.15
    val_size = 0.15

    indices = np.arange(n)
    if MultilabelStratifiedShuffleSplit is not None:
        # Split off test
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=SEED)
        train_val_idx, test_idx = next(msss.split(indices, Y))
        # Split train/val
        rel_val = val_size / (1.0 - test_size)
        msss2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=rel_val, random_state=SEED)
        train_idx, val_idx = next(msss2.split(train_val_idx, Y[train_val_idx]))
    else:
        rng = np.random.default_rng(SEED)
        rng.shuffle(indices)
        tcount = int(test_size * n)
        vcount = int(val_size * n)
        test_idx = indices[:tcount]
        val_idx  = indices[tcount:tcount+vcount]
        train_idx = indices[tcount+vcount:]

    np.save(train_idx_path, train_idx); np.save(val_idx_path, val_idx); np.save(test_idx_path, test_idx)
    print(f"Saved new splits: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")


Loaded splits: train=5481, val=1175, test=1175


## tokenizer and dataset class

In [5]:
# === Config (you can tweak) ===
MODEL_NAME   = "seyonec/ChemBERTa-zinc-base-v1"
MAX_LEN      = 256          # 128 is faster; bump to 256 if you see truncation issues
AUG_PROB     = 0.25         # train-time SMILES randomization probability
CANONICALIZE = True         # canonicalize base SMILES once on dataset build

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    padding_side="right",
    truncation_side="right",
    use_fast=True
)

import torch
from torch.utils.data import Dataset
import numpy as np

# --- RDKit helpers (optional but recommended) ---
try:
    from rdkit import Chem
except Exception as e:
    Chem = None
    print("RDKit not available; SMILES validation/augmentation will be skipped.")

def canonicalize_smiles(s: str) -> str:
    if Chem is None:
        return s
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return None
    # Canonical SMILES with stereochemistry kept
    return Chem.MolToSmiles(mol, canonical=True)

def randomize_smiles(s: str) -> str:
    """Return a randomized (but equivalent) SMILES; fallback to original if RDKit missing/fails."""
    if Chem is None:
        return s
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return s
    # doRandom=True creates a randomized traversal (augmentation)
    return Chem.MolToSmiles(mol, canonical=False, doRandom=True)

class SmilesDataset(Dataset):
    """
    Dynamic-padding ready: we DO NOT pad in __getitem__; the DataCollator will pad per batch.
    Augmentation (randomized SMILES) occurs only in train mode with probability AUG_PROB.
    """
    def __init__(self, df, indices, label_cols, tokenizer, max_len=128,
                 mode="train", augment_prob=0.25, do_canonicalize=True):
        self.df = df.iloc[indices].reset_index(drop=True).copy()
        smiles = self.df["smiles"].astype(str).tolist()

        # Validate/canonicalize once up-front (drop invalids)
        if do_canonicalize and Chem is not None:
            keep = []
            canon = []
            for s in smiles:
                cs = canonicalize_smiles(s)
                if cs is None:
                    keep.append(False)
                    canon.append(s)
                else:
                    keep.append(True)
                    canon.append(cs)
            if not all(keep):
                dropped = int(np.sum(np.logical_not(keep)))
                print(f"[SmilesDataset] Dropping {dropped} invalid SMILES.")
                self.df = self.df.loc[np.where(keep)[0]].reset_index(drop=True)
                smiles = [canon[i] for i in np.where(keep)[0]]
            else:
                smiles = canon

        self.smiles = smiles
        self.labels = self.df[label_cols].values.astype(np.float32)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mode = mode
        self.augment_prob = augment_prob if mode == "train" else 0.0

    def __len__(self):
        return len(self.smiles)

    def _maybe_augment(self, s: str) -> str:
        if self.augment_prob > 0 and np.random.rand() < self.augment_prob:
            rs = randomize_smiles(s)
            return rs if isinstance(rs, str) and len(rs) > 0 else s
        return s

    def __getitem__(self, idx):
        s = self.smiles[idx]
        s = self._maybe_augment(s)

        # NOTE: no padding here; let the collator handle it per-batch.
        enc = self.tokenizer(
            s,
            padding=False,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        item = {k: v.squeeze(0) for k, v in enc.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item


## dataloaders with class imbalance weights

In [6]:
from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import DataCollatorWithPadding
import numpy as np
import torch

# --- Config ---
BATCH_SIZE   = 64          # increased from 32
NUM_WORKERS  = 0           # safe on Windows
PIN_MEMORY   = False
USE_WEIGHTED_SAMPLER = True   # set False if you see overfitting/instability

# --- Data collator: dynamic padding to the longest in the batch (amp-friendly) ---
collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

# --- Build datasets (augmentation only on train) ---
train_ds = SmilesDataset(
    df, train_idx, label_cols, tokenizer, max_len=MAX_LEN,
    mode="train", augment_prob=0.25, do_canonicalize=True
)
val_ds = SmilesDataset(
    df, val_idx, label_cols, tokenizer, max_len=MAX_LEN,
    mode="val", augment_prob=0.0, do_canonicalize=True
)
test_ds = SmilesDataset(
    df, test_idx, label_cols, tokenizer, max_len=MAX_LEN,
    mode="test", augment_prob=0.0, do_canonicalize=True
)

# --- Class imbalance handling in the LOSS: pos_weight = (#neg / #pos) on TRAIN ONLY ---
y_train = train_ds.labels  # (N_train, C)
pos_counts = y_train.sum(axis=0)
neg_counts = y_train.shape[0] - pos_counts
pos_counts_safe = np.clip(pos_counts, 1, None)
pos_weight = torch.tensor(neg_counts / pos_counts_safe, dtype=torch.float32, device=device)

# --- OPTIONAL: imbalance-aware sampler (multi-label) ---
# Per-sample weight = sum_c inv_freq[c] * y_ic; if no positives, give a small baseline weight
if USE_WEIGHTED_SAMPLER:
    N = y_train.shape[0]
    pos_freq = pos_counts / N
    inv_freq = 1.0 / np.clip(pos_freq, 1e-8, None)      # inverse prevalence
    inv_freq = inv_freq / inv_freq.mean()                # normalize (mean ~1)

    sample_weights = (y_train * inv_freq).sum(axis=1)    # higher if sample has rare positives
    # Give some weight to all-negative rows so we still sample negatives:
    sample_weights = np.where(sample_weights > 0, sample_weights, 0.2)
    # Normalize for stability (not strictly necessary):
    sample_weights = sample_weights / sample_weights.mean()

    sampler = WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.double),
        num_samples=len(train_ds),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, sampler=sampler,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collator
    )
else:
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collator
    )

val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collator
)
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collator
)

print("Loaders OK — Train/Val/Test sizes:", len(train_ds), len(val_ds), len(test_ds))
print("pos_weight (rounded):", pos_weight.detach().cpu().numpy().round(2))
if USE_WEIGHTED_SAMPLER:
    print("Weighted sampler enabled. Example weights stats:",
          f"min={sample_weights.min():.2f}  max={sample_weights.max():.2f}  mean={sample_weights.mean():.2f}")




Loaders OK — Train/Val/Test sizes: 5481 1175 1175
pos_weight (rounded): [24.26 31.43  9.36 26.82  9.04 23.47 45.85  7.33 28.47 20.75  7.88 17.71]
Weighted sampler enabled. Example weights stats: min=0.30  max=13.71  mean=1.00


## ChemBERTa Model

In [7]:
# MODEL BUILD — eager mode (no torch.compile), safer on Windows

from transformers import AutoModelForSequenceClassification, AutoConfig
import torch

num_labels = len(label_cols)

# Config with slightly higher dropout (regularisation)
config = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
if hasattr(config, "hidden_dropout_prob"):
    config.hidden_dropout_prob = 0.15
if hasattr(config, "attention_probs_dropout_prob"):
    config.attention_probs_dropout_prob = 0.15

# Build model (no torch.compile)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    config=config,
).to(device)

# Training-time flags
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False
if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable()  # reduces VRAM, OK to keep

# Re-init classification head (better than default)
if hasattr(model, "classifier"):
    if hasattr(model.classifier, "dense"):
        torch.nn.init.xavier_uniform_(model.classifier.dense.weight)
        if model.classifier.dense.bias is not None:
            torch.nn.init.zeros_(model.classifier.dense.bias)
    if hasattr(model.classifier, "out_proj"):
        torch.nn.init.xavier_uniform_(model.classifier.out_proj.weight)
        if model.classifier.out_proj.bias is not None:
            torch.nn.init.zeros_(model.classifier.out_proj.bias)

# (Optional) freeze first N encoder layers for warm-up
FREEZE_N_LAYERS = 0  # set to 2–4 if you want
if FREEZE_N_LAYERS > 0 and hasattr(model, "roberta"):
    for i, layer in enumerate(model.roberta.encoder.layer):
        if i < FREEZE_N_LAYERS:
            for p in layer.parameters():
                p.requires_grad = False

# Report
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
drop_hidden = getattr(model.config, "hidden_dropout_prob", None)
drop_attn = getattr(model.config, "attention_probs_dropout_prob", None)
print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")
print(f"Dropout — hidden: {drop_hidden}  attn: {drop_attn}")
print("Gradient checkpointing:", "on" if hasattr(model, "gradient_checkpointing") else "enabled")
print("First", FREEZE_N_LAYERS, "encoder layer(s) frozen.")


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at seyonec/ChemBERTa-zinc-base-v1 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total params: 44,113,164 | Trainable: 44,113,164
Dropout — hidden: 0.15  attn: 0.15
Gradient checkpointing: enabled
First 0 encoder layer(s) frozen.


## Training Setup — ASL + LLRD + EMA

In [12]:
# =========================
# Training setup: ASL + LLRD + EMA
# (Place this cell RIGHT AFTER your model is built and moved to `device`)
# =========================

import math, numpy as np, torch
from copy import deepcopy
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

# ---------- Tuned hyperparameters (safe defaults) ----------
EPOCHS                 = 40
BASE_LR                = 1.5e-5      # backbone base lr (decayed by depth)
HEAD_LR                = 3e-4        # classifier head lr
WEIGHT_DECAY           = 0.01
WARMUP_RATIO           = 0.10        # 10% warmup
PATIENCE               = 8
GRAD_CLIP_NORM         = 1.0
USE_COSINE_SCHED       = True
GRAD_ACCUM_STEPS       = 2
USE_AMP                = True
BACKBONE_WARMUP_EPOCHS = 1
LLRD_DECAY             = 0.85        # layer-wise lr decay per depth step
EMA_DECAY              = 0.999

# ---------- Loss: Asymmetric Loss (ASL) with clamped pos_weight ----------
class AsymmetricLossMultiLabel(torch.nn.Module):
    """
    ASL for multilabel classification with optional pos_weight.
    """
    def __init__(self, gamma_pos=0.5, gamma_neg=3.0, clip=0.05, eps=1e-8, reduction='mean', pos_weight=None):
        super().__init__()
        self.gamma_pos  = gamma_pos
        self.gamma_neg  = gamma_neg
        self.clip       = clip
        self.eps        = eps
        self.reduction  = reduction
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        # logits, targets: (B, C)
        xs_pos = torch.sigmoid(logits)
        xs_neg = 1.0 - xs_pos

        # optional clipping for negatives
        if self.clip is not None and self.clip > 0:
            xs_neg = torch.clamp(xs_neg + self.clip, max=1)

        log_pos = torch.log(xs_pos.clamp(min=self.eps))
        log_neg = torch.log(xs_neg.clamp(min=self.eps))

        loss_pos = targets * log_pos
        loss_neg = (1 - targets) * log_neg

        # focal modulation
        if self.gamma_pos > 0 or self.gamma_neg > 0:
            pt_pos = xs_pos * targets + (1 - targets)
            pt_neg = xs_neg * (1 - targets) + targets
            loss_pos *= (1 - pt_pos) ** self.gamma_pos
            loss_neg *= (1 - pt_neg) ** self.gamma_neg

        loss = -(loss_pos + loss_neg)

        # optional per-class weighting
        if self.pos_weight is not None:
            loss = loss * (targets * (self.pos_weight - 1) + 1)

        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# Compute pos_weight from TRAIN ONLY and clamp extremes for stability
y_train = train_ds.labels  # (N_train, C) — from your earlier dataset construction
pos_counts = y_train.sum(axis=0)
neg_counts = y_train.shape[0] - pos_counts
pos_counts_safe = np.clip(pos_counts, 1, None)

pos_weight = torch.tensor(neg_counts / pos_counts_safe, dtype=torch.float32, device=device)
pos_weight = torch.clamp(pos_weight, 1.0, 50.0)  # clamp to avoid huge gradients

asl_loss = AsymmetricLossMultiLabel(
    gamma_pos=0.5, gamma_neg=3.0, clip=0.05, reduction='mean',
    pos_weight=pos_weight.to(device)
)
print(f"[ASL] Using Asymmetric Loss (gamma_pos=0.5, gamma_neg=3.0, clip=0.05) + clamped pos_weight.")

# ---------- LLRD: Layer-wise LR Decay parameter groups ----------
decay_exclusions = ("bias", "LayerNorm.weight", "LayerNorm.bias")
def is_decay_param(n): return not any(nd in n for nd in decay_exclusions)

# Map parameter name -> layer id (embeddings=0, encoder.layers=1..L, classifier=1000)
def get_layer_id(name: str):
    if name.startswith("classifier") or name.startswith("lm_head"):
        return 1000
    if name.startswith("roberta.embeddings"):
        return 0
    if name.startswith("roberta.encoder.layer."):
        try:
            idx = int(name.split("roberta.encoder.layer.")[1].split(".")[0])
            return idx + 1  # 1..L
        except Exception:
            return 0
    # fallback for other backbones (bert, etc.)
    if name.startswith("bert.embeddings"): return 0
    if name.startswith("bert.encoder.layer."):
        try:
            idx = int(name.split("bert.encoder.layer.")[1].split(".")[0])
            return idx + 1
        except Exception:
            return 0
    return 0

# Determine max encoder layer for decay computation
max_layer = 0
for n, p in model.named_parameters():
    if not p.requires_grad: continue
    lid = get_layer_id(n)
    if lid != 1000:
        max_layer = max(max_layer, lid)

# Build param groups
optimizer_grouped_parameters = []
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if n.startswith("classifier") or n.startswith("lm_head"):
        lr = HEAD_LR
        wd = WEIGHT_DECAY if is_decay_param(n) else 0.0
        optimizer_grouped_parameters.append({"params": [p], "lr": lr, "weight_decay": wd})
    else:
        lid = get_layer_id(n)  # 0..max_layer
        lr = BASE_LR * (LLRD_DECAY ** (max_layer - lid))
        wd = WEIGHT_DECAY if is_decay_param(n) else 0.0
        optimizer_grouped_parameters.append({"params": [p], "lr": lr, "weight_decay": wd})

optimizer = AdamW(optimizer_grouped_parameters)

# Optional 1-epoch head-only warmup
def set_backbone_trainable(trainable: bool):
    for n, p in model.named_parameters():
        if n.startswith("classifier") or n.startswith("lm_head"):
            continue  # head always trainable
        p.requires_grad = trainable

set_backbone_trainable(BACKBONE_WARMUP_EPOCHS <= 0)

# ---------- Scheduler, AMP scaler ----------
num_update_steps_per_epoch = math.ceil(len(train_loader) / max(1, GRAD_ACCUM_STEPS))
num_training_steps = EPOCHS * num_update_steps_per_epoch
num_warmup_steps   = int(WARMUP_RATIO * num_training_steps)

scheduler = (
    get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    if USE_COSINE_SCHED else
    get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
)

scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

# ---------- EMA (Exponential Moving Average) ----------
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for name, param in model.state_dict().items():
            if param.dtype.is_floating_point:
                self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self, model):
        for name, param in model.state_dict().items():
            if name in self.shadow and param.dtype.is_floating_point:
                self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=(1.0 - self.decay))

    @torch.no_grad()
    def apply_shadow(self, model):
        self.backup = {}
        for name, param in model.state_dict().items():
            if name in self.shadow and param.dtype.is_floating_point:
                self.backup[name] = param.detach().clone()
                param.data.copy_(self.shadow[name].data)

    @torch.no_grad()
    def restore(self, model):
        for name, param in model.state_dict().items():
            if hasattr(self, "backup") and name in self.backup:
                param.data.copy_(self.backup[name].data)
        self.backup = {}

ema = EMA(model, decay=EMA_DECAY)

# ---------- Report ----------
total_params    = sum(p.numel() for p in model.parameters())
trainable_params= sum(p.numel() for p in model.parameters() if p.requires_grad)
drop_hidden = getattr(model.config, "hidden_dropout_prob", None)
drop_attn   = getattr(model.config, "attention_probs_dropout_prob", None)

print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")
print(f"Dropout — hidden: {drop_hidden}  attn: {drop_attn}")
print(f"[LLRD] max_layer={max_layer} | decay={LLRD_DECAY:.2f} | HEAD_LR={HEAD_LR:.1e} | BASE_LR={BASE_LR:.1e}")
print(f"[EMA]  decay={EMA_DECAY}")
print(f"Training steps: {num_training_steps}, warmup steps: {num_warmup_steps}")
print(f"AMP: {'on' if USE_AMP else 'off'} | GradAccum: {GRAD_ACCUM_STEPS} | GradClip: {GRAD_CLIP_NORM}")
print(f"Backbone warmup epochs: {BACKBONE_WARMUP_EPOCHS} (classifier-only training if >0)")


[ASL] Using Asymmetric Loss (gamma_pos=0.5, gamma_neg=3.0, clip=0.05) + clamped pos_weight.
Total params: 44,113,164 | Trainable: 599,820
Dropout — hidden: 0.15  attn: 0.15
[LLRD] max_layer=6 | decay=0.85 | HEAD_LR=3.0e-04 | BASE_LR=1.5e-05
[EMA]  decay=0.999
Training steps: 1720, warmup steps: 172
AMP: on | GradAccum: 2 | GradClip: 1.0
Backbone warmup epochs: 1 (classifier-only training if >0)


## metrics (macro AUROC/AP, macro F1 at thresholds)

In [13]:
# =========================
# Metrics + Train/Validate loop (macro-AP early stopping)
# =========================
import time, numpy as np, torch
from copy import deepcopy
from sklearn.metrics import f1_score

# --- Metrics (your existing function; unchanged) ---
def multilabel_metrics(y_true, y_prob, thresholds=None):
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    N, C = y_true.shape
    per_auc, per_ap = [], []
    for c in range(C):
        yt = y_true[:, c]; yp = y_prob[:, c]
        if len(np.unique(yt)) < 2:
            per_auc.append(np.nan); per_ap.append(np.nan); continue
        try:
            from sklearn.metrics import roc_auc_score, average_precision_score
            per_auc.append(roc_auc_score(yt, yp))
        except ValueError: per_auc.append(np.nan)
        try:
            per_ap.append(average_precision_score(yt, yp))
        except ValueError: per_ap.append(np.nan)
    macro_auc = np.nanmean(per_auc); macro_ap = np.nanmean(per_ap)

    micro_auc = np.nan; micro_ap = np.nan
    if np.unique(y_true).size > 1:
        try:
            from sklearn.metrics import roc_auc_score, average_precision_score
            micro_auc = roc_auc_score(y_true.ravel(), y_prob.ravel())
            micro_ap  = average_precision_score(y_true.ravel(), y_prob.ravel())
        except Exception: pass

    macro_f1 = np.nan
    if thresholds is not None:
        thr = np.asarray(thresholds)
        y_pred = (y_prob >= thr).astype(int)
        per_f1 = []
        for c in range(y_true.shape[1]):
            yt = y_true[:, c]; yp = y_pred[:, c]
            if len(np.unique(yt)) < 2:
                per_f1.append(np.nan); continue
            per_f1.append(f1_score(yt, yp, zero_division=0))
        macro_f1 = np.nanmean(per_f1)

    pos_counts = y_true.sum(axis=0)
    prevalence = (pos_counts / N).astype(float)

    return {
        "per_auc": per_auc, "per_ap": per_ap,
        "macro_auc": macro_auc, "macro_ap": macro_ap,
        "micro_auc": micro_auc, "micro_ap": micro_ap,
        "macro_f1": macro_f1,
        "label_stats": {"pos_counts": pos_counts, "prevalence": prevalence}
    }

# --- Utilities ---
def sigmoid_np(x): return 1.0 / (1.0 + np.exp(-x))

@torch.no_grad()
def predict_with_tta(model, base_ds, n_tta=8, p_aug=0.8, batch_size=64):
    """Average probs across randomized SMILES. Flip to 'train' mode for augmentation path in dataset."""
    from copy import deepcopy
    ds = deepcopy(base_ds)
    ds.mode = "train"; ds.augment_prob = p_aug
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collator)
    all_probs = []
    model.eval()
    for _ in range(n_tta):
        _, logits, _, _ = run_epoch(loader, train=False, log_every=10_000)
        all_probs.append(sigmoid_np(logits))
    return np.mean(all_probs, axis=0)

def calibrate_thresholds(y_true, y_prob, grid=np.linspace(0.05, 0.95, 19)):
    """Choose per-label threshold by max F1 on VAL."""
    y_true = np.asarray(y_true); y_prob = np.asarray(y_prob)
    C = y_true.shape[1]; thr = np.zeros(C, dtype=float)
    for c in range(C):
        yt = y_true[:, c]; yp = y_prob[:, c]
        if len(np.unique(yt)) < 2: thr[c] = 0.5; continue
        best, best_f1 = 0.5, -1
        for t in grid:
            f1 = f1_score(yt, (yp >= t).astype(int), zero_division=0)
            if f1 > best_f1: best, best_f1 = t, f1
        thr[c] = best
    return thr

# --- Core epoch runner (uses your globals: model, asl_loss, optimizer, scheduler, scaler, ema, GRAD_* etc.) ---
def run_epoch(dataloader, train=False, log_every=50):
    model.train() if train else model.eval()
    total_loss = 0.0
    logits_list, labels_list = [], []
    accum = max(1, GRAD_ACCUM_STEPS)
    optimizer.zero_grad(set_to_none=True)
    step_in_accum = 0
    t0 = time.time()

    for step, batch in enumerate(dataloader, 1):
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)

        if train and USE_AMP:
            with torch.amp.autocast('cuda'):
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                loss   = asl_loss(logits, labels) / accum
        else:
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss   = asl_loss(logits, labels) / accum

        if train:
            if USE_AMP: scaler.scale(loss).backward()
            else:       loss.backward()
            step_in_accum += 1

            if step_in_accum >= accum:
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
                if USE_AMP:
                    scaler.step(optimizer); scaler.update()
                else:
                    optimizer.step()
                ema.update(model)          # EMA after each step
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                step_in_accum = 0

        total_loss += (loss.item() * accum) * input_ids.size(0)
        logits_list.append(logits.detach().cpu().numpy())
        labels_list.append(labels.detach().cpu().numpy())

        if train and (step % log_every == 0):
            elapsed = time.time() - t0
            print(f" step {step:5d}/{len(dataloader):5d} | loss {loss.item()*accum:.4f} | {elapsed:.1f}s")
            t0 = time.time()

    avg_loss   = total_loss / len(dataloader.dataset)
    logits_all = np.concatenate(logits_list, axis=0)
    labels_all = np.concatenate(labels_list, axis=0)
    probs_all  = sigmoid_np(logits_all)
    return avg_loss, logits_all, probs_all, labels_all

# --- Train/validate with early stopping on macro-AP (EMA weights on VAL) ---
best_state_ema = None
best_val_ap    = -np.inf
no_improve     = 0

for epoch in range(1, EPOCHS + 1):
    t_epoch = time.time()

    # optional head-only warmup
    if epoch == 1 and BACKBONE_WARMUP_EPOCHS > 0:
        set_backbone_trainable(False)

    tr_loss, _, _, _ = run_epoch(train_loader, train=True, log_every=50)

    # unfreeze after warmup
    if BACKBONE_WARMUP_EPOCHS > 0 and epoch == BACKBONE_WARMUP_EPOCHS + 1:
        set_backbone_trainable(True)

    # Validate with EMA shadow
    ema.apply_shadow(model)
    va_loss, va_logits, va_prob, va_y = run_epoch(val_loader, train=False, log_every=10_000)
    ema.restore(model)

    va_metrics = multilabel_metrics(va_y, va_prob)
    print(f"Epoch {epoch:02d} | train_loss {tr_loss:.4f} | val_loss {va_loss:.4f} | "
          f"VAL(EMA) macro_AUROC {va_metrics['macro_auc']:.4f} | macro_AP {va_metrics['macro_ap']:.4f} | "
          f"time {time.time()-t_epoch:.1f}s", flush=True)

    # Early stopping by macro-AP
    if va_metrics['macro_ap'] > best_val_ap + 1e-4:
        best_val_ap = va_metrics['macro_ap']
        ema.apply_shadow(model)
        best_state_ema = deepcopy(model.state_dict())
        ema.restore(model)
        no_improve = 0
    else:
        no_improve += 1

    if no_improve >= PATIENCE:
        print(f"Early stopping (macro-AP) after {epoch} epochs.")
        break

# Restore best EMA checkpoint
if best_state_ema is not None:
    model.load_state_dict(best_state_ema)

# --- Threshold calibration on VAL (optional, for reporting macro-F1) ---
val_thr = calibrate_thresholds(va_y, va_prob, grid=np.linspace(0.05, 0.95, 19))
val_metrics_thr = multilabel_metrics(va_y, va_prob, thresholds=val_thr)

# --- FINAL TEST EVAL (with TTA for better ranking metrics) ---
# Note: requires test_ds/test_loader from your earlier setup
te_prob = predict_with_tta(model, test_ds, n_tta=8, p_aug=0.8, batch_size=BATCH_SIZE)
te_y    = test_ds.labels
test_metrics      = multilabel_metrics(te_y, te_prob)
test_metrics_thr  = multilabel_metrics(te_y, te_prob, thresholds=val_thr)

print("\n==== SUMMARY ====")
print(f"VAL  macro_AUROC: {va_metrics['macro_auc']:.4f} | macro_AP: {va_metrics['macro_ap']:.4f} | macro_F1@thr: {val_metrics_thr['macro_f1']:.4f}")
print(f"TEST macro_AUROC: {test_metrics['macro_auc']:.4f} | macro_AP: {test_metrics['macro_ap']:.4f} | macro_F1@thr: {test_metrics_thr['macro_f1']:.4f}")

# --- Optional: build a per-label table like yours and save JSON summary ---
per_label = []
for i, lab in enumerate(label_cols):
    va_auc = va_metrics["per_auc"][i]; va_ap = va_metrics["per_ap"][i]
    te_auc = test_metrics["per_auc"][i]; te_ap = test_metrics["per_ap"][i]
    per_label.append({
        "label": lab,
        "val_AUROC": None if np.isnan(va_auc) else round(va_auc, 4),
        "val_AP":    None if np.isnan(va_ap)  else round(va_ap,  4),
        "test_AUROC":None if np.isnan(te_auc) else round(te_auc, 4),
        "test_AP":   None if np.isnan(te_ap)  else round(te_ap,  4),
    })

summary = {
    "VAL":  {"macro_AUROC": float(va_metrics['macro_auc']),  "macro_AP": float(va_metrics['macro_ap']),  "macro_F1_thr": float(val_metrics_thr['macro_f1'])},
    "TEST": {"macro_AUROC": float(test_metrics['macro_auc']),"macro_AP": float(test_metrics['macro_ap']),"macro_F1_thr": float(test_metrics_thr['macro_f1'])},
    "thresholds": {label_cols[i]: float(val_thr[i]) for i in range(len(label_cols))},
    "per_label": per_label,
}
(OUT_DIR / "metrics").mkdir(parents=True, exist_ok=True)
np.save(OUT_DIR / "metrics" / "val_prob.npy",  va_prob)
np.save(OUT_DIR / "metrics" / "test_prob.npy", te_prob)
with open(OUT_DIR / "summary.json","w") as f:
    import json; json.dump(summary, f, indent=2)

print("✅ Saved predictions and summary to", OUT_DIR)




 step    50/   86 | loss 0.3773 | 2.4s
Epoch 01 | train_loss 0.3287 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 4.9s




 step    50/   86 | loss 0.2616 | 2.1s
Epoch 02 | train_loss 0.3323 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 9.7s
 step    50/   86 | loss 0.3610 | 9.1s
Epoch 03 | train_loss 0.3305 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 27.0s
 step    50/   86 | loss 0.3154 | 7.0s
Epoch 04 | train_loss 0.3189 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 14.6s
 step    50/   86 | loss 0.3705 | 8.4s
Epoch 05 | train_loss 0.3252 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 16.6s
 step    50/   86 | loss 0.2567 | 9.8s
Epoch 06 | train_loss 0.3231 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 751.5s
 step    50/   86 | loss 0.3116 | 9.3s
Epoch 07 | train_loss 0.3257 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 | macro_AP 0.3450 | time 18.2s
 step    50/   86 | loss 0.2643 | 10.2s
Epoch 08 | train_loss 0.3190 | val_loss 0.7409 | VAL(EMA) macro_AUROC 0.7742 |

## validation & test metrics (pre-threshold) + per-label table

In [14]:
import numpy as np
import pandas as pd

# Reuse run_epoch() and multilabel_metrics() from earlier cells

# --- Validation ---
val_loss, val_logits, val_prob, val_y = run_epoch(val_loader, train=False)
val_metrics = multilabel_metrics(val_y, val_prob)

print("VAL — macro AUROC:", round(val_metrics['macro_auc'], 4),
      "| macro AP:", round(val_metrics['macro_ap'], 4))

# --- Test ---
test_loss, test_logits, test_prob, test_y = run_epoch(test_loader, train=False)
test_metrics = multilabel_metrics(test_y, test_prob)

print("TEST — macro AUROC:", round(test_metrics['macro_auc'], 4),
      "| macro AP:", round(test_metrics['macro_ap'], 4))

# --- Per-label table (AUROC/AP) for both VAL and TEST ---
per_label = []
for i, lab in enumerate(label_cols):
    va_auc = val_metrics["per_auc"][i]
    va_ap  = val_metrics["per_ap"][i]
    te_auc = test_metrics["per_auc"][i]
    te_ap  = test_metrics["per_ap"][i]
    per_label.append({
        "label": lab,
        "val_AUROC": None if np.isnan(va_auc) else round(va_auc, 4),
        "val_AP":    None if np.isnan(va_ap)  else round(va_ap, 4),
        "test_AUROC":None if np.isnan(te_auc) else round(te_auc, 4),
        "test_AP":   None if np.isnan(te_ap)  else round(te_ap, 4),
    })

per_label_df = pd.DataFrame(per_label)
display(per_label_df)


VAL — macro AUROC: 0.7742 | macro AP: 0.345
TEST — macro AUROC: 0.8651 | macro AP: 0.5397


Unnamed: 0,label,val_AUROC,val_AP,test_AUROC,test_AP
0,NR-AR,0.7049,0.3264,0.814,0.5261
1,NR-AR-LBD,0.8516,0.5073,0.9547,0.6608
2,NR-AhR,0.8287,0.4252,0.8763,0.5431
3,NR-Aromatase,0.7625,0.1989,0.8561,0.3669
4,NR-ER,0.705,0.4047,0.771,0.4654
5,NR-ER-LBD,0.7401,0.2957,0.8778,0.6023
6,NR-PPAR-gamma,0.7334,0.2272,0.8702,0.585
7,SR-ARE,0.7361,0.3458,0.8266,0.4572
8,SR-ATAD5,0.816,0.2775,0.8888,0.6902
9,SR-HSE,0.7815,0.3393,0.8787,0.5797


## Model save

In [17]:
from pathlib import Path

# === Save ChemBERTa model & tokenizer ===
save_dir = Path("tox21_chembera_pipeline/models/chemberta_v1")
save_dir.mkdir(parents=True, exist_ok=True)

# save best EMA weights into model before saving
if best_state_ema is not None:
    model.load_state_dict(best_state_ema)

# Hugging Face format
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print(f"✅ ChemBERTa model + tokenizer saved to {save_dir}")


✅ ChemBERTa model + tokenizer saved to tox21_chembera_pipeline\models\chemberta_v1


## threshold calibration (per label) + F1 on VAL/TEST

In [15]:
import numpy as np
from sklearn.metrics import f1_score

def calibrate_thresholds(y_true, y_prob, grid=np.linspace(0.05, 0.95, 19)):
    """
    Returns:
      thresholds: (C,) best threshold per label (by F1 on VAL)
      f1s:        (C,) best F1 per label (on VAL)
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    C = y_true.shape[1]
    thresholds = np.zeros(C, dtype=float)
    f1s = np.zeros(C, dtype=float)

    for c in range(C):
        yt = y_true[:, c]
        yp = y_prob[:, c]
        if len(np.unique(yt)) < 2:
            thresholds[c] = 0.5
            f1s[c] = np.nan
            continue
        best_f1 = -1.0
        best_thr = 0.5
        for thr in grid:
            ypred = (yp >= thr).astype(int)
            f1 = f1_score(yt, ypred, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_thr = thr
        thresholds[c] = best_thr
        f1s[c] = best_f1
    return thresholds, f1s

# Calibrate on VAL
thr_vec, f1s_val = calibrate_thresholds(val_y, val_prob)
print("Per-label VAL F1 and chosen thresholds:")
for lab, f1v, thr in zip(label_cols, f1s_val, thr_vec):
    f1_show = "nan" if np.isnan(f1v) else f"{f1v:.3f}"
    print(f"{lab:12s}  F1={f1_show:>6}  thr={thr:.2f}")

# Evaluate macro-F1 on VAL/TEST with calibrated thresholds
val_metrics_thr  = multilabel_metrics(val_y,  val_prob,  thresholds=thr_vec)
test_metrics_thr = multilabel_metrics(test_y, test_prob, thresholds=thr_vec)

print("\nVAL — macro F1 (thr-calibrated): ", round(val_metrics_thr['macro_f1'], 4))
print("TEST — macro F1 (thr-calibrated):", round(test_metrics_thr['macro_f1'], 4))


Per-label VAL F1 and chosen thresholds:
NR-AR         F1= 0.378  thr=0.95
NR-AR-LBD     F1= 0.436  thr=0.95
NR-AhR        F1= 0.436  thr=0.90
NR-Aromatase  F1= 0.228  thr=0.95
NR-ER         F1= 0.423  thr=0.95
NR-ER-LBD     F1= 0.346  thr=0.95
NR-PPAR-gamma  F1= 0.235  thr=0.95
SR-ARE        F1= 0.374  thr=0.90
SR-ATAD5      F1= 0.323  thr=0.95
SR-HSE        F1= 0.460  thr=0.95
SR-MMP        F1= 0.511  thr=0.90
SR-p53        F1= 0.366  thr=0.95

VAL — macro F1 (thr-calibrated):  0.3764
TEST — macro F1 (thr-calibrated): 0.5231


## saving the output

In [13]:
import json

# Save HuggingFace model + tokenizer
model.save_pretrained(MODELS_DIR)
tokenizer.save_pretrained(MODELS_DIR)

# Save thresholds & label names
with open(MODELS_DIR / "thresholds.json", "w") as f:
    json.dump({lab: float(thr) for lab, thr in zip(label_cols, thr_vec)}, f, indent=2)

with open(MODELS_DIR / "labels.json", "w") as f:
    json.dump(list(label_cols), f, indent=2)

# Metadata about training
metadata = {
    "model_name": MODEL_NAME,
    "max_len": MAX_LEN,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    "base_lr": BASE_LR,
    "head_lr": HEAD_LR,
    "weight_decay": WEIGHT_DECAY,
    "warmup_ratio": WARMUP_RATIO,
    "best_val_macro_auc": float(val_metrics["macro_auc"]),
    "best_val_macro_ap": float(val_metrics["macro_ap"]),
    "val_macro_f1_thr": float(val_metrics_thr["macro_f1"]),
    "test_macro_auc": float(test_metrics["macro_auc"]),
    "test_macro_ap": float(test_metrics["macro_ap"]),
    "test_macro_f1_thr": float(test_metrics_thr["macro_f1"]),
}
with open(MODELS_DIR / "metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("✅ Saved model, tokenizer, thresholds, labels, and metadata to", MODELS_DIR)


✅ Saved model, tokenizer, thresholds, labels, and metadata to tox21_chembera_pipeline\models\chemberta_v1


## prediction & summary

In [16]:
# Save arrays for reproducibility
np.save(OUT_DIR / "val_logits.npy", val_logits)
np.save(OUT_DIR / "val_prob.npy", val_prob)
np.save(OUT_DIR / "val_y.npy", val_y)

np.save(OUT_DIR / "test_logits.npy", test_logits)
np.save(OUT_DIR / "test_prob.npy", test_prob)
np.save(OUT_DIR / "test_y.npy", test_y)

# Summarize key results
summary = {
    "VAL": {
        "macro_AUROC": float(val_metrics['macro_auc']),
        "macro_AP": float(val_metrics['macro_ap']),
        "macro_F1_thr": float(val_metrics_thr['macro_f1'])
    },
    "TEST": {
        "macro_AUROC": float(test_metrics['macro_auc']),
        "macro_AP": float(test_metrics['macro_ap']),
        "macro_F1_thr": float(test_metrics_thr['macro_f1'])
    }
}

with open(OUT_DIR / "summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print("✅ Saved predictions and summary metrics to", OUT_DIR)
summary


✅ Saved predictions and summary metrics to tox21_chembera_pipeline\outputs


{'VAL': {'macro_AUROC': 0.774190344876498,
  'macro_AP': 0.3450011528294685,
  'macro_F1_thr': 0.3763729999290725},
 'TEST': {'macro_AUROC': 0.8650696516173912,
  'macro_AP': 0.539660679964537,
  'macro_F1_thr': 0.5230949784582123}}