In [None]:
# Fetch the `llmwrapper` repository. It presumably contains the `llmwrapper` package that
# provides DeepChemLLM, Tox21DataModule, Tox21LightningModule and
# extract_smiles_and_labels, all imported later in this notebook.
# NOTE(review): this clones the default branch unpinned; pin a commit/tag for reproducible results.
!git clone https://github.com/DarkAngel007-design/llmwrapper.git

# Tox21 Multitask Classification with ChemBERTa (Lightning + QLoRA)

This notebook demonstrates training a HuggingFace encoder model (ChemBERTa) on the DeepChem Tox21 dataset using PyTorch Lightning.

Features:
- DeepChem-style multitask masking (only label entries with weight `w > 0` contribute)
- Encoder-based HuggingFace backbone
- Supports a frozen backbone, full fine-tuning, or QLoRA
- ROC-AUC and PR-AUC evaluation

Dataset:
- Tox21 (12 toxicity tasks)

In [None]:
import os

# Enter the cloned repository so the `llmwrapper.*` imports below resolve from the
# working directory (presumably the repo root holds the `llmwrapper` package).
# Guarded so that re-running this cell in the same kernel does not descend into a
# nested directory of the same name, which would break those imports.
if os.path.basename(os.getcwd()) != "llmwrapper":
    os.chdir("llmwrapper")
print("Working directory:", os.getcwd())

In [None]:
!pip install -q \
  torch \
  transformers \
  scikit-learn \
  rdkit 
!pip install --pre deepchem

In [None]:
!pip install -U bitsandbytes peft accelerate

In [None]:
!pip uninstall numpy -y
!pip install "numpy<2"

In [None]:
import numpy as np
import torch

# Sanity-check the runtime before any data loading or training.
print("NumPy:", np.__version__)
print("Torch:", torch.__version__)
print("CUDA:", torch.cuda.is_available())

# Fail fast if the kernel still has NumPy 2.x loaded. This happens, for example, when the
# runtime was not restarted after the `numpy<2` pin earlier in the notebook.
assert int(np.__version__.split(".")[0]) < 2, (
    f"NumPy {np.__version__} is loaded but this notebook pins numpy<2; "
    "restart the kernel/runtime and re-run from the top."
)


In [None]:
import numpy as np
import torch
import pytorch_lightning as pl
import deepchem as dc

from llmwrapper.model import DeepChemLLM
from llmwrapper.datamodule import Tox21DataModule
from llmwrapper.lightning_module import Tox21LightningModule


## Load Tox21 dataset (DeepChem)

In [None]:
# Load Tox21 from MoleculeNet with a scaffold split.
# `tasks` is the list of task names; `datasets` holds the three splits (train, valid, test).
# The pipeline below feeds SMILES strings (via extract_smiles_and_labels) to ChemBERTa,
# so the ECFP fingerprints are presumably unused; NOTE(review): confirm before switching featurizer.
# NOTE: `transformers` here is DeepChem's list of dataset transformers. It shadows the
# HuggingFace package name inside this notebook and is not used further below.
tasks, datasets, transformers = dc.molnet.load_tox21(
    featurizer="ECFP",
    splitter="scaffold",
)

# The test split is deliberately left untouched.
train_ds, valid_ds, _ = datasets

## Extract SMILES, labels, and task weights

In [None]:
from llmwrapper.utils import extract_smiles_and_labels

# Turn each DeepChem split into (SMILES strings, labels y, task weights w).
# The weights are what the multitask masking (w > 0) keys on during training.
(train_smiles, y_train, w_train), (valid_smiles, y_valid, w_valid) = (
    extract_smiles_and_labels(split) for split in (train_ds, valid_ds)
)

## Model configuration
ChemBERTa is fine-tuned with QLoRA for memory-efficient training.


In [None]:
# Seed the Python/NumPy/torch RNGs before the model is built. This makes freshly initialised
# weights (e.g. any prediction head / LoRA adapters) and, with workers=True, dataloader
# worker seeding reproducible across Restart & Run All.
pl.seed_everything(42, workers=True)

# ChemBERTa backbone with QLoRA fine-tuning. The number of outputs is taken from the loaded
# Tox21 task list instead of a hard-coded 12, so it always matches the label columns.
model = DeepChemLLM(
    model_name="seyonec/ChemBERTa-zinc-base-v1",
    n_tasks=len(tasks),
    qlora=True,
)

## Lightning DataModule
Handles tokenization and batching.


In [None]:
# Each split is handed over as a (smiles, labels, weights) triple. The tokenizer checkpoint
# string is the same ChemBERTa checkpoint used when building the model.
train_split = (train_smiles, y_train, w_train)
valid_split = (valid_smiles, y_valid, w_valid)

dm = Tox21DataModule(
    model_name="seyonec/ChemBERTa-zinc-base-v1",
    train_data=train_split,
    valid_data=valid_split,
    batch_size=16,
)

## Training with PyTorch Lightning


In [None]:
# Wrap the backbone in the LightningModule that drives training/validation.
# NOTE: the LightningModule presumably holds a reference to `model`, so training updates its
# weights in place. Re-running only this cell continues from the already fine-tuned weights;
# re-run the model cell first for a fresh run.
lit_model = Tox21LightningModule(model)

trainer = pl.Trainer(
    accelerator="gpu",  # the QLoRA backbone is expected to need a CUDA device
    devices=1,
    max_epochs=5,
    # Explicit mixed-precision spelling; a bare `16` is a legacy alias that Lightning 2.x warns about.
    precision="16-mixed",
    log_every_n_steps=10,
)

trainer.fit(lit_model, dm)

## Final validation metrics


In [None]:
# `callback_metrics` holds the most recently logged values, so after `fit` the val_* entries
# come from the last validation run. Only the ROC-AUC / PR-AUC entries are reported, plus
# `val_n_tasks` (presumably the number of tasks the AUCs were averaged over).
metrics = trainer.callback_metrics
REPORTED_METRIC_KEYS = ("val_roc_auc", "val_pr_auc", "val_n_tasks")

print("\nFinal validation metrics:")
for metric_name, metric_value in metrics.items():
    if any(key in metric_name for key in REPORTED_METRIC_KEYS):
        print(f"{metric_name}: {metric_value.item()}")
