In [21]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, hamming_loss, precision_recall_curve
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import os
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [22]:
# DEVICE
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 1. LOAD 3 DATASETS

In [23]:
print("\nLOADING DATASETS...")

# --- Tox21 ---
TOX21_PATH = "/kaggle/input/dataset/tox21.csv"  # Attachment
df_tox = pd.read_csv(TOX21_PATH)
tox_labels = ["NR-AR","NR-AR-LBD","NR-AhR","NR-Aromatase","NR-ER","NR-ER-LBD",
              "NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"]
# ensure these columns exist
for c in tox_labels:
    if c not in df_tox.columns:
        df_tox[c] = 0
df_tox["labels"] = df_tox[tox_labels].fillna(0).astype(float).values.tolist()
df_tox["dataset"] = "tox21"
print(f"Tox21: {len(df_tox)} samples, 12 labels")

# --- SIDER ---
SIDER_PATH = "/kaggle/input/dataset/sider.csv"  # Attachment
df_sider = pd.read_csv(SIDER_PATH)
# assume first column is molecule identifier, rest are labels
sider_labels = df_sider.columns[1:].tolist()
# convert NaN -> 0 and ensure numeric
df_sider[sider_labels] = df_sider[sider_labels].fillna(0).astype(float)
df_sider["labels"] = df_sider[sider_labels].values.tolist()
df_sider["dataset"] = "sider"
print(f"SIDER: {len(df_sider)} samples, {len(sider_labels)} labels")

# --- DrugBank XML Parse ---
DRUGBANK_PATH = "/kaggle/input/dataset/drugbank_full_database.xml"  # Attachment (if not present, you can use drugbank.csv instead)
common_adrs = ['nausea', 'vomiting', 'headache', 'dizziness', 'fatigue',
               'diarrhea', 'rash', 'liver injury', 'cardiotoxicity', 'hair loss']

df_drug = pd.DataFrame()
if os.path.exists(DRUGBANK_PATH):
    try:
        tree = ET.parse(DRUGBANK_PATH)
        root = tree.getroot()
        ns = {'db': 'http://www.drugbank.ca'}  # typical namespace; if your xml has none you can remove ns usage

        data = []
        # Try with namespace; if no results, try without ns
        drugs = root.findall('db:drug', ns)
        if len(drugs) == 0:
            drugs = root.findall('drug')  # fallback to tag without ns

        for drug in drugs:
            # name (safe)
            name_tag = drug.find('db:name', ns) if ns and drug.find('db:name', ns) is not None else drug.find('name')
            name = name_tag.text.strip() if name_tag is not None and name_tag.text else None

            # find SMILES robustly
            smiles = None
            # try paths with namespace first
            props = drug.findall("db:calculated-properties/db:property", ns) if ns else drug.findall("calculated-properties/property")
            if not props:
                # fallback to property nodes without ns
                props = drug.findall("calculated-properties/property") or drug.findall("property")

            for prop in props:
                kind_tag = prop.find('db:kind', ns) if ns and prop.find('db:kind', ns) is not None else prop.find('kind')
                val_tag = prop.find('db:value', ns) if ns and prop.find('db:value', ns) is not None else prop.find('value')
                if kind_tag is None or val_tag is None:
                    continue
                kind = kind_tag.text.strip().lower()
                if kind == 'smiles':
                    smiles = val_tag.text.strip()
                    break
            if not smiles:
                continue

            # ADR text from toxicity
            toxicity_tag = drug.find('db:toxicity', ns) if ns and drug.find('db:toxicity', ns) is not None else drug.find('toxicity')
            adr_text = toxicity_tag.text.strip() if (toxicity_tag is not None and toxicity_tag.text) else ""

            adr_labels = [1.0 if adr in adr_text.lower() else 0.0 for adr in common_adrs]

            data.append({'smiles': smiles, 'labels': adr_labels, 'dataset': 'drugbank', 'adr_text': adr_text, 'name': name})

        df_drug = pd.DataFrame(data)
        print(f"DrugBank parsed: {len(df_drug)} samples, {len(common_adrs)} ADR labels")
    except Exception as e:
        print("Error parsing DrugBank XML:", e)
else:
    # fallback in case user supplied a csv instead of XML
    DRUGBANK_CSV = "drugbank.csv"
    if os.path.exists(DRUGBANK_CSV):
        df_drug = pd.read_csv(DRUGBANK_CSV)
        # assume 'smiles' and ADR columns or provide ADR text
        if 'smiles' in df_drug.columns:
            if 'adr_text' in df_drug.columns:
                df_drug['adr_text'] = df_drug['adr_text'].fillna("")
                df_drug['labels'] = df_drug['adr_text'].apply(lambda t: [1.0 if adr in str(t).lower() else 0.0 for adr in common_adrs])
            else:
                # if drugbank csv has explicit ADR columns
                adr_cols = [c for c in df_drug.columns if c.lower() in common_adrs]
                if adr_cols:
                    df_drug[adr_cols] = df_drug[adr_cols].fillna(0).astype(float)
                    df_drug['labels'] = df_drug[adr_cols].values.tolist()
                else:
                    # fallback: all zeros
                    df_drug['labels'] = [[0.0]*len(common_adrs) for _ in range(len(df_drug))]
        else:
            df_drug = pd.DataFrame()  # couldn't parse fallback
    if df_drug.empty:
        print("No DrugBank data found (XML or drugbank.csv). DrugBank will be skipped.")



LOADING DATASETS...
Tox21: 8014 samples, 12 labels
SIDER: 1427 samples, 27 labels
DrugBank parsed: 12313 samples, 10 ADR labels


In [24]:
print(df_drug.head())
print(df_drug.shape)

                                              smiles  \
0  CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...   
1  CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...   
2  CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...   
3  NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...   
4  CC(C)C[C@H](NC(=O)[C@@H](CCCNC(N)=O)NC(=O)[C@H...   

                                              labels   dataset  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  drugbank   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  drugbank   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  drugbank   
3  [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  drugbank   
4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  drugbank   

                                            adr_text          name  
0  Based on a study by Gleason et al., the no-obs...   Bivalirudin  
1  No experience of overdosage from clinical trials.     Goserelin  
2                                                     Grami

## 2. CREATE INPUT TEXT WITH [SEP]

In [25]:
def create_input(row):
    smiles = row.get("smiles") if isinstance(row, dict) else row["smiles"]
    active = [i for i, v in enumerate(row["labels"]) if float(v) == 1.0]
    if not active:
        text = "No activity"
    else:
        if row["dataset"] == "tox21":
            text = "Tox21 activity: " + ", ".join([tox_labels[i] for i in active])
        elif row["dataset"] == "sider":
            text = "Side effects: " + ", ".join([sider_labels[i] for i in active])
        else:
            text = "ADR: " + ", ".join([common_adrs[i] for i in active])
    return str(smiles) + " [SEP] " + text

# Only create input_text for non-empty dfs
if not df_tox.empty:
    df_tox["input_text"] = df_tox.apply(create_input, axis=1)
if not df_sider.empty:
    df_sider["input_text"] = df_sider.apply(create_input, axis=1)
if not df_drug.empty:
    df_drug["input_text"] = df_drug.apply(create_input, axis=1)

In [26]:
# Gộp (skip any empty)
parts = []
for d in [df_tox, df_sider, df_drug]:
    if not d.empty:
        # ensure columns consistent
        if 'smiles' not in d.columns and 'SMILES' in d.columns:
            d = d.rename(columns={'SMILES':'smiles'})
        parts.append(d[['smiles', 'labels', 'dataset', 'input_text']])
df_all = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame(columns=['smiles','labels','dataset','input_text'])
print(f"Total: {len(df_all)} samples")

Total: 21754 samples


In [27]:
print(df_all.head())
print("Label lengths example:", [len(x) for x in df_all['labels'].head().tolist()])

                                              smiles  \
0                       CCOc1ccc2nc(S(N)(=O)=O)sc2c1   
1                          CCN1C(=O)NC(c2ccccc2)C1=O   
2  CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...   
3                    CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C   
4                          CC(O)(P(=O)(O)O)P(=O)(O)O   

                                              labels dataset  \
0  [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...   tox21   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   tox21   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   tox21   
3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   tox21   
4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   tox21   

                                          input_text  
0  CCOc1ccc2nc(S(N)(=O)=O)sc2c1 [SEP] Tox21 activ...  
1        CCN1C(=O)NC(c2ccccc2)C1=O [SEP] No activity  
2  CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...  
3  CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C [SEP] No activity  
4  

## 3. ToxBERT

In [28]:
PREFERRED = ["Exscientia/ToxBERT", "seyonec/ChemBERTa-zinc-base-v1", "bert-base-uncased"]
MODEL_NAME = None
tokenizer = None
backbone = None

for cand in PREFERRED:
    try:
        print(f"Trying to load model: {cand} ...")
        tokenizer = AutoTokenizer.from_pretrained(cand)
        backbone = AutoModel.from_pretrained(cand)
        MODEL_NAME = cand
        print(f"Loaded model: {cand}")
        break
    except Exception as e:
        print(f"Could not load {cand}: {str(e)[:200]}")

if MODEL_NAME is None:
    raise RuntimeError("Failed to load any model. Please provide a local pretrained model or set INTERNET access to HuggingFace.")


Trying to load model: Exscientia/ToxBERT ...
Could not load Exscientia/ToxBERT: Exscientia/ToxBERT is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to t
Trying to load model: seyonec/ChemBERTa-zinc-base-v1 ...
Loaded model: seyonec/ChemBERTa-zinc-base-v1


## 4. Dataset Class

In [29]:
class MultiTaskDataset:
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        encoding = self.tokenizer(
            row["input_text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(row["labels"], dtype=torch.float),
            'dataset': row["dataset"]
        }

# If df_all is empty -> abort early
if df_all.empty:
    raise RuntimeError("No data in df_all. Check your input CSV/XML files.")

# Split (stratify by dataset)
train_df, temp_df = train_test_split(df_all, test_size=0.3, random_state=42, stratify=df_all["dataset"])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df["dataset"])

train_ds = MultiTaskDataset(train_df, tokenizer)
val_ds = MultiTaskDataset(val_df, tokenizer)
test_ds = MultiTaskDataset(test_df, tokenizer)

print(f"Train/Val/Test: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")


Train/Val/Test: 15227/3263/3264


In [30]:
# Split
train_df, temp_df = train_test_split(df_all, test_size=0.3, random_state=42, stratify=df_all["dataset"])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df["dataset"])

train_ds = MultiTaskDataset(train_df, tokenizer)
val_ds = MultiTaskDataset(val_df, tokenizer)
test_ds = MultiTaskDataset(test_df, tokenizer)

print(f"Train/Val/Test: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")

Train/Val/Test: 15227/3263/3264


## 5. Multi-Task Model

In [31]:
class ToxBERT_MultiTask(nn.Module):
    def __init__(self, backbone, n_tox=12, n_sider=27, n_adr=10):
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(0.1)
        self.head_tox = nn.Linear(self.backbone.config.hidden_size, n_tox)
        self.head_sider = nn.Linear(self.backbone.config.hidden_size, n_sider)
        self.head_adr = nn.Linear(self.backbone.config.hidden_size, n_adr)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.last_hidden_state[:, 0])
        return {
            'tox21': self.head_tox(pooled),
            'sider': self.head_sider(pooled),
            'adr': self.head_adr(pooled),
            'embedding': pooled
        }

model = ToxBERT_MultiTask(backbone, n_tox=12, n_sider=len(sider_labels), n_adr=len(common_adrs)).to(device)

## 6. Training Loop

In [32]:
def collate_fn(batch):
    input_ids = torch.stack([b['input_ids'] for b in batch])
    attention_mask = torch.stack([b['attention_mask'] for b in batch])
    labels = torch.nn.utils.rnn.pad_sequence([b['labels'] for b in batch], batch_first=True, padding_value=0.0)
    datasets = [b['dataset'] for b in batch]
    return input_ids, attention_mask, labels, datasets

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=32, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=32, collate_fn=collate_fn)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

print("\nTRAINING...")
model.train()
NUM_EPOCHS = 5  # thử 5 epoch rồi tăng nếu cần
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for input_ids, attn_mask, labels, datasets in tqdm(train_loader):
        input_ids, attn_mask, labels = input_ids.to(device), attn_mask.to(device), labels.to(device)
        outputs = model(input_ids, attn_mask)

        loss = 0.0
        # compute loss per dataset in batch
        for ds, head in [('tox21', 'tox21'), ('sider', 'sider'), ('drugbank', 'adr')]:
            idx = [i for i, d in enumerate(datasets) if d == ds]
            if len(idx) == 0:
                continue
            batch_labels = labels[idx]
            batch_logits = outputs[head][idx]
            # slice labels to match head size (B x num_labels)
            num_labels = batch_logits.shape[1]
            batch_labels = batch_labels[:, :num_labels]
            loss += criterion(batch_logits, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {avg_loss:.4f}")



TRAINING...


100%|██████████| 952/952 [05:10<00:00,  3.06it/s]


Epoch 1/5 - Loss: 0.5472


100%|██████████| 952/952 [05:17<00:00,  3.00it/s]


Epoch 2/5 - Loss: 0.3667


100%|██████████| 952/952 [05:17<00:00,  3.00it/s]


Epoch 4/5 - Loss: 0.2615


100%|██████████| 952/952 [05:17<00:00,  3.00it/s]

Epoch 5/5 - Loss: 0.2411





## 7. Evaluation + Threshold Tuning

In [33]:
def evaluate(loader, name):
    model.eval()
    all_probs = {'tox21': [], 'sider': [], 'drugbank': []}
    all_labels = {'tox21': [], 'sider': [], 'drugbank': []}
    all_embs = []

    with torch.no_grad():
        for input_ids, attn_mask, labels, datasets in loader:
            input_ids, attn_mask = input_ids.to(device), attn_mask.to(device)
            outputs = model(input_ids, attn_mask)
            all_embs.append(outputs['embedding'].cpu().numpy())

            for ds in set(datasets):
                idx = [i for i, d in enumerate(datasets) if d == ds]
                if not idx:
                    continue
                if ds == 'tox21':
                    probs = torch.sigmoid(outputs['tox21'][idx]).cpu().numpy()
                    batch_labels = labels[idx, :12].cpu().numpy()
                elif ds == 'sider':
                    probs = torch.sigmoid(outputs['sider'][idx]).cpu().numpy()
                    batch_labels = labels[idx, :len(sider_labels)].cpu().numpy()
                else:
                    probs = torch.sigmoid(outputs['adr'][idx]).cpu().numpy()
                    batch_labels = labels[idx, :len(common_adrs)].cpu().numpy()
                all_probs[ds].append(probs)
                all_labels[ds].append(batch_labels)

    embs = np.vstack(all_embs) if len(all_embs) > 0 else np.zeros((0, model.backbone.config.hidden_size))
    results = {}

    for ds in all_probs:
        if not all_probs[ds]:
            continue
        probs = np.vstack(all_probs[ds])
        labels = np.vstack(all_labels[ds])

        # Threshold tuning per class
        thresholds = []
        for i in range(probs.shape[1]):
            try:
                p, r, t = precision_recall_curve(labels[:, i], probs[:, i])
                f1 = 2 * p * r / (p + r + 1e-8)
                best_thr = t[np.argmax(f1)] if len(t) > 0 else 0.5
            except Exception:
                best_thr = 0.5
            thresholds.append(float(best_thr))
        thresholds = np.array(thresholds)

        preds = (probs > thresholds).astype(int)
        micro = f1_score(labels, preds, average='micro') if preds.sum() > 0 else 0.0
        macro = f1_score(labels, preds, average='macro') if preds.sum() > 0 else 0.0
        hamming = hamming_loss(labels, preds)

        results[ds] = {
            'micro_f1': micro,
            'macro_f1': macro,
            'hamming': hamming,
            'thresholds': thresholds,
            'probs': probs,
            'labels': labels,
            'preds': preds
        }
        print(f"{ds.upper()}: Micro F1={micro:.4f}, Macro F1={macro:.4f}, Hamming={hamming:.4f}")

    return results, embs

print("\nEVALUATING VAL...")
val_results, val_embs = evaluate(val_loader, "val")

print("\nEVALUATING TEST...")
test_results, test_embs = evaluate(test_loader, "test")


EVALUATING VAL...
TOX21: Micro F1=0.9551, Macro F1=0.9428, Hamming=0.0057
SIDER: Micro F1=0.8684, Macro F1=0.7981, Hamming=0.1687
DRUGBANK: Micro F1=0.8821, Macro F1=0.6261, Hamming=0.0034

EVALUATING TEST...
TOX21: Micro F1=0.9641, Macro F1=0.9577, Hamming=0.0046
SIDER: Micro F1=0.8695, Macro F1=0.7912, Hamming=0.1525
DRUGBANK: Micro F1=0.1056, Macro F1=0.5918, Hamming=0.2035


## 8. Extract Embeddings

In [34]:
os.makedirs("embeddings", exist_ok=True)

# Train embeddings
train_embs = []
with torch.no_grad():
    for input_ids, attn_mask, _, _ in train_loader:
        input_ids, attn_mask = input_ids.to(device), attn_mask.to(device)
        outputs = model(input_ids, attn_mask)
        train_embs.append(outputs['embedding'].cpu().numpy())
train_embs = np.vstack(train_embs) if train_embs else np.zeros((0, model.backbone.config.hidden_size))

torch.save(train_embs, "embeddings/train_embeddings.pt")
torch.save(val_embs, "embeddings/val_embeddings.pt")
torch.save(test_embs, "embeddings/test_embeddings.pt")

# Save unified .pt like toxbert_sep_embeddings.pt
torch.save({
    "train.embeddings": train_embs,
    "val.embeddings": val_embs,
    "test.embeddings": test_embs,
    "val.results": val_results,
    "test.results": test_results
}, "embeddings/toxbert_sep_embeddings.pt")

print("\nSAVED EMBEDDINGS:")
print("  embeddings/train_embeddings.pt")
print("  embeddings/val_embeddings.pt")
print("  embeddings/test_embeddings.pt")
print("  embeddings/toxbert_sep_embeddings.pt")

print("\nALL DONE! ToxBERT + [SEP] for 3 datasets.")


SAVED EMBEDDINGS:
  embeddings/train_embeddings.pt
  embeddings/val_embeddings.pt
  embeddings/test_embeddings.pt
  embeddings/toxbert_sep_embeddings.pt

ALL DONE! ToxBERT + [SEP] for 3 datasets.


In [36]:
import torch

# Load từng embeddings
train_embs = torch.load("/kaggle/working/embeddings/train_embeddings.pt", weights_only=False)
val_embs   = torch.load("/kaggle/working/embeddings/val_embeddings.pt", weights_only=False)
test_embs  = torch.load("/kaggle/working/embeddings/test_embeddings.pt", weights_only=False)

print("Train shape:", train_embs.shape)
print("Val shape:", val_embs.shape)
print("Test shape:", test_embs.shape)

# Load unified embeddings
toxbert_data = torch.load("/kaggle/working/embeddings/toxbert_sep_embeddings.pt", weights_only=False)

train_embs2 = toxbert_data["train.embeddings"]
val_embs2   = toxbert_data["val.embeddings"]
test_embs2  = toxbert_data["test.embeddings"]
val_results = toxbert_data["val.results"]
test_results = toxbert_data["test.results"]

print("\nUnified embeddings loaded:")
print("Train:", train_embs2.shape)
print("Val:", val_embs2.shape)
print("Test:", test_embs2.shape)
print("Val results:", val_results)
print("Test results:", test_results)


Train shape: (15227, 768)
Val shape: (3263, 768)
Test shape: (3264, 768)

Unified embeddings loaded:
Train: (15227, 768)
Val: (3263, 768)
Test: (3264, 768)
Val results: {'tox21': {'micro_f1': 0.9551422319474836, 'macro_f1': 0.9427803789830574, 'hamming': 0.005684969495285635, 'thresholds': array([0.29493877, 0.74007112, 0.90575749, 0.93684632, 0.33415473,
       0.33080855, 0.57326126, 0.78087133, 0.19239934, 0.75388807,
       0.88252515, 0.19842649]), 'probs': array([[5.4218085e-03, 2.8782657e-01, 4.9252347e-03, ..., 8.6479999e-02,
        1.3363129e-02, 8.4319063e-02],
       [5.3331151e-04, 4.0536796e-04, 8.0530671e-04, ..., 3.7391108e-04,
        8.4800116e-04, 5.6708563e-04],
       [1.0848003e-02, 7.0237080e-03, 9.0575749e-01, ..., 1.3857560e-02,
        9.8450953e-01, 2.3978163e-02],
       ...,
       [1.8545841e-01, 4.6415888e-02, 9.7472501e-01, ..., 1.1081588e-02,
        1.9614173e-02, 6.6311754e-02],
       [9.1101445e-02, 1.2810989e-01, 3.6322575e-03, ..., 5.3570112e-03,


In [37]:
toxbert_data

{'train.embeddings': array([[ 0.24542175,  0.4627934 ,  0.42310736, ..., -0.5836917 ,
         -0.6105582 , -0.5732417 ],
        [ 0.38368216, -1.0461295 , -1.4433533 , ...,  0.73406404,
          1.7207364 , -1.0521164 ],
        [-0.15846848,  0.768257  ,  0.3528445 , ..., -0.8126923 ,
         -0.78711903, -0.479071  ],
        ...,
        [ 0.20116866,  1.9362442 ,  0.834899  , ..., -0.7820928 ,
         -0.4003265 ,  0.03537117],
        [-0.31408077,  0.70126486,  0.47756952, ..., -0.8560412 ,
         -0.8567976 , -0.4752864 ],
        [-0.38004762, -0.53846455,  0.26190087, ..., -2.0576403 ,
          0.02896108, -0.25180188]], dtype=float32),
 'val.embeddings': array([[-0.41426033,  0.7869829 ,  0.43471563, ..., -0.82132787,
         -0.9143279 , -0.4941279 ],
        [ 1.006086  , -1.3784261 ,  0.1215761 , ...,  0.45979875,
         -1.4841064 , -0.27623513],
        [ 1.4632021 ,  0.94343305,  0.75765914, ..., -0.2837115 ,
          0.39640746, -0.38931254],
        ...,
 

In [None]:
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch

def test_bert_embeddings(
    bert_model_path,
    embedding_file,
    smiles_file,
    num_samples=20,
    tolerance=1e-5
):
    print("=== LOADING MODEL & TOKENIZER ===")
    tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
    model = AutoModel.from_pretrained(bert_model_path)
    model.eval()

    print("=== LOADING EMBEDDING FILE ===")
    emb = np.load(embedding_file)
    print("Embedding shape:", emb.shape)

    print("=== LOADING SMILES FILE ===")
    with open(smiles_file, "r") as f:
        smiles_list = [line.strip() for line in f]

    if len(smiles_list) != emb.shape[0]:
        print(f"⚠️ WARNING: SMILES count = {len(smiles_list)}, embedding count = {emb.shape[0]} → Không khớp!")
    else:
        print(f"OK: {len(smiles_list)} SMILES khớp với embedding count")

    print("\n=== TESTING FIRST SAMPLES ===")

    for i, smi in enumerate(smiles_list[:num_samples]):
        tokens = tokenizer(
            smi,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            outputs = model(**tokens)

            # Lấy embedding CLS token (chuẩn nhất)
            cls_emb = outputs.last_hidden_state[:, 0, :].squeeze(0).numpy()

        # Tính độ lệch
        diff = np.abs(cls_emb - emb[i]).mean()

        print(f"Sample {i} | SMILES: {smi[:50]}... | mean abs diff = {diff}")

        if diff > tolerance:
            print("❌ FAIL: embedding không khớp – file này KHÔNG PHẢI embedding từ mô hình này hoặc không phải từ SMILES này!")
            return False

    print("\n✅ SUCCESS: Toàn bộ embedding trùng khớp → file embedding là từ đúng mô hình + đúng dữ liệu.")
    return True


# ============================
# ==== CHẠY TEST TẠI ĐÂY =====
# ============================

bert_model_path = "../ToxBERT/checkpoint"           # path model BERT của bạn
embedding_file = "../ToxBERT/embedding/train_embeddings.npy"
smiles_file = "../Tox21/smiles_train.txt"          # file SMILES tương ứng với embeddings

test_bert_embeddings(
    bert_model_path=bert_model_path,
    embedding_file=embedding_file,
    smiles_file=smiles_file,
    num_samples=20,      # test 20 mẫu đầu tiên
    tolerance=1e-5       # độ sai lệch cho phép
)
 