In [1]:
!pip install deepchem transformers peft onnxruntime onnx

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import deepchem as dc
from peft import get_peft_model, LoraConfig, TaskType
import time
from sklearn.metrics import accuracy_score
from rdkit import Chem
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType

In [3]:
# Define paths and parameters
MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"
MAX_LENGTH = 128
BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 2e-5

# LoRA Configuration
LORA_R = 8  # Rank of LoRA
LORA_ALPHA = 16
LORA_DROPOUT = 0.1

In [12]:
class ClinToxDataset(Dataset):
    def __init__(self, data_path, tokenizer, split='train', max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        # Load ClinTox dataset from DeepChem
        tasks, datasets, transformers = dc.molnet.load_clintox()
        train_dataset, valid_dataset, test_dataset = datasets

        # Convert to SMILES and labels
        self.smiles_train, self.labels_train = self.remove_invalid_smiles(train_dataset.ids, train_dataset.y)
        self.smiles_valid, self.labels_valid = self.remove_invalid_smiles(valid_dataset.ids, valid_dataset.y)
        self.smiles_test, self.labels_test = self.remove_invalid_smiles(test_dataset.ids, test_dataset.y)

        # Set active split
        if split == 'train':
            self.smiles = self.smiles_train
            self.labels = self.labels_train
        elif split == 'valid':
            self.smiles = self.smiles_valid
            self.labels = self.labels_valid
        elif split == 'test':
            self.smiles = self.smiles_test
            self.labels = self.labels_test
        else:
            raise ValueError("Invalid split. Use 'train', 'valid', or 'test'.")

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles = self.smiles[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            smiles,
            return_tensors="pt",
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(label, dtype=torch.float)
        }

    def remove_invalid_smiles(self, smiles, labels):
        valid_indices = []
        for i, smile in enumerate(smiles):
            try:
                mol = Chem.MolFromSmiles(smile)
                if mol is not None:
                    valid_indices.append(i)
            except:
                pass

        return smiles[valid_indices], labels[valid_indices]

In [5]:
def setup_lora_model(model_name):
    """
    Set up a model with LoRA configuration
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,  # Binary classification for ClinTox
        return_dict=True
    )

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["query", "value"]  # Target attention modules
    )

    model = get_peft_model(model, peft_config)
    return model

In [6]:
def train(model, tokenizer):
    # Create dataset for train split
    train_dataset = ClinToxDataset("clintox", tokenizer, split="train", max_length=MAX_LENGTH)

    # Create dataloader
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

    return model

In [14]:
def export(model, tokenizer, save_path, quantize=False):
    """
    Export a HuggingFace model to ONNX format. Optionally apply dynamic quantization.

    Args:
        model: Trained model to export.
        tokenizer: Associated tokenizer.
        save_path: Path to save the ONNX (or quantized) model.
        quantize: Whether to apply dynamic quantization (QInt8).
    """

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Example input for tracing
    encoded = tokenizer(
        "CC(=O)Oc1ccccc1C(=O)O",  # SMILES example
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Use temp path if quantizing
    export_path = "temp_model.onnx" if quantize else save_path

    # Export to ONNX
    torch.onnx.export(
        model,
        (input_ids, attention_mask),
        export_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}},
        opset_version=14
    )

    # Quantize if requested
    if quantize:
        quantize_dynamic(export_path, save_path, weight_type=QuantType.QInt8)
        os.remove(export_path)
        print(f"Quantized model saved to: {save_path}")
    else:
        print(f"Baseline ONNX model saved to: {save_path}")

    # Report model size
    size_mb = os.path.getsize(save_path) / (1024 * 1024)
    print(f"Model size: {size_mb:.2f} MB")

In [23]:
def evaluate_onnx_model(onnx_path):
    session = ort.InferenceSession(onnx_path)

    test_dataset = ClinToxDataset("clintox", tokenizer, split="test", max_length=MAX_LENGTH)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    all_labels = []
    all_preds = []

    start_time = time.time()

    for batch in test_dataloader:
        input_ids = batch['input_ids'].numpy()
        attention_mask = batch['attention_mask'].numpy()
        labels = batch['labels'].numpy()

        # Run inference with ONNX
        ort_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
        ort_outs = session.run(None, ort_inputs)  # Returns a list; usually one element: logits

        logits = ort_outs[0]
        probs = 1 / (1 + np.exp(-logits))  # Sigmoid manually
        preds = (probs > 0.5).astype(int)

        all_preds.extend(preds)
        all_labels.extend(labels)

    end_time = time.time()
    total_time = end_time - start_time

    acc = accuracy_score(all_labels, all_preds)

    print(f"Test Accuracy: {acc:.4f}")
    print(f"Total Inference Time: {total_time:.2f} s")

In [20]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = setup_lora_model(MODEL_NAME)

trained_model = train(model, tokenizer)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at seyonec/ChemBERTa-zinc-base-v1 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/5: 100%|██████████| 37/37 [00:02<00:00, 15.51it/s, loss=0.198]
Epoch 2/5: 100%|██████████| 37/37 [00:02<00:00, 15.88it/s, loss=0.228]
Epoch 3/5: 100%|██████████| 37/37 [00:02<00:00, 15.85it/s, loss=0.269]
Epoch 4/5: 100%|██████████| 37/37 [00:02<00:00, 15.90it/s, loss=0.125]
Epoch 5/5: 100%|██████████| 37/37 [00:02<00:00, 15.92it/s, loss=0.154]


In [24]:
export(trained_model, tokenizer, save_path="model-baseline.onnx", quantize=False)

evaluate_onnx_model("model-baseline.onnx")

Baseline ONNX model saved to: model-baseline.onnx
Model size: 169.02 MB
Test Accuracy: 0.9459
Total Inference Time: 5.82 s


In [25]:
export(trained_model, tokenizer, save_path="model-quant.onnx", quantize=True)

evaluate_onnx_model("model-quant.onnx")



Quantized model saved to: model-quant.onnx
Model size: 42.75 MB
Test Accuracy: 0.9459
Total Inference Time: 2.03 s
