In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients, TokenReferenceBase
from captum.attr import visualization as viz

# Import your model
from Models.chemlt_f_model import DebertaMultiTaskModel

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:
TASKS = {
    0: ("BACE", 1, "classification"),
    1: ("HIV", 1, "classification"),
    2: ("BBBP", 1, "classification"),
    3: ("ClinTox", 2, "classification"),
    4: ("Tox21", 12, "classification"),
    5: ("MUV", 17, "classification"),
    6: ("SIDER", 27, "classification"),
    7: ("ToxCast", 617, "classification"),
    8: ("Delaney", 1, "regression"),
    9: ("FreeSolv", 1, "regression"),
    10: ("Lipo", 1, "regression"),
}

In [4]:
def load_model(model_dir):
    """Load CheMLT-F model from HuggingFace format directory."""
    
    print(f"Loading model from {model_dir}...")
    
    # Initialize model architecture
    model = DebertaMultiTaskModel(
        model_path1=model_dir,
        model_path2=model_dir,
        num_labels_list=[1, 1, 1, 2, 12, 17, 27, 617, 1, 1, 1],
        problem_type_list=[
            "classification", "classification", "classification",
            "classification", "classification", "classification",
            "classification", "classification",
            "regression", "regression", "regression"
        ]
    )
    
    # Load weights
    try:
        from safetensors.torch import load_file
        state_dict = load_file(f"{model_dir}/model.safetensors")
        print("✅ Loaded from model.safetensors")
    except:
        state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
        print("✅ Loaded from pytorch_model.bin")
    
    # Filter out encoder2 (for SMILES-only)
    filtered = {k: v for k, v in state_dict.items() if not k.startswith("encoder2.")}
    model.load_state_dict(filtered, strict=False)
    model.to(device)
    model.eval()
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    print("✅ Tokenizer loaded\n")
    
    return model, tokenizer

# Load the model
MODEL_DIR = "Weights/Scaffold_CheMLT-F"
model, tokenizer = load_model(MODEL_DIR)

Loading model from Weights/Scaffold_CheMLT-F...
768 768


Some weights of DebertaV2Model were not initialized from the model checkpoint at Weights/Scaffold_CheMLT-F and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key_proj.bias', 'encoder.layer.0.attention.self.key_proj.weight', 'encoder.layer.0.attention.self.query_proj.bias', 'encoder.layer.0.attention.self.query_proj.weight', 'encoder.layer.0.attention.self.value_proj.bias', 'encoder.layer.0.attention.self.value_proj.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0

✅ Loaded from model.safetensors
✅ Tokenizer loaded



In [5]:
class CheMLTWrapper(torch.nn.Module):
    """Wrapper for CheMLT model to work with Captum."""
    
    def __init__(self, model, task_index, label_index=0):
        """
        Args:
            model: The CheMLT multitask model
            task_index: Which task to interpret (0-10)
            label_index: For multi-label tasks, which label to interpret (default: 0)
        """
        super().__init__()
        self.model = model
        self.task_index = task_index
        self.label_index = label_index
        self.task_name, self.num_labels, self.task_type = TASKS[task_index]
    
    def forward(self, input_ids, attention_mask):
        """
        Forward pass returning scalar prediction for the target label.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_ids2=None,
            attention_mask2=None,
            input_ids3=None,
            attention_mask3=None,
            task_index=self.task_index
        )
        
        logits = outputs["logits"]
        
        # For classification, return probability
        if self.task_type == "classification":
            probs = torch.sigmoid(logits)
            if self.num_labels == 1:
                return probs.squeeze(-1)
            else:
                return probs[:, self.label_index]
        # For regression, return raw value
        else:
            return logits.squeeze(-1)

In [6]:
def predict_smiles(model, tokenizer, smiles, task_index, label_index=0):
    """
    Get prediction and confidence for a SMILES string.
    
    Returns:
        pred_value: The prediction (probability or regression value)
        pred_label: Human-readable prediction
        token_ids: Token IDs for interpretation
        tokens: Actual tokens
    """
    # Tokenize
    inputs = tokenizer(
        smiles,
        return_tensors="pt",
        max_length=512,
        padding="max_length",
        truncation=True
    )
    
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    
    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
    # Create wrapper
    wrapper = CheMLTWrapper(model, task_index, label_index)
    
    # Predict
    with torch.no_grad():
        pred_value = wrapper(input_ids, attention_mask).item()
    
    # Format prediction
    task_name, num_labels, task_type = TASKS[task_index]
    
    if task_type == "classification":
        pred_label = f"Active (prob={pred_value:.3f})" if pred_value > 0.5 else f"Inactive (prob={pred_value:.3f})"
    else:
        pred_label = f"Value={pred_value:.3f}"
    
    return pred_value, pred_label, input_ids, attention_mask, tokens


In [7]:
def compute_attributions(model, input_ids, attention_mask, task_index, label_index=0):
    """
    Compute token attributions using Layer Integrated Gradients.
    
    Returns:
        attributions: Attribution scores for each token
        delta: Approximation error
    """
    # Create wrapper
    wrapper = CheMLTWrapper(model, task_index, label_index)
    
    # Get embedding layer
    embeddings = model.encoder1.embeddings
    
    # Create Layer Integrated Gradients
    lig = LayerIntegratedGradients(wrapper, embeddings)
    
    # Create baseline (PAD token)
    baseline_ids = torch.zeros_like(input_ids).long()
    baseline_ids[:] = tokenizer.pad_token_id
    baseline_mask = torch.zeros_like(attention_mask)
    
    # Compute attributions
    attributions, delta = lig.attribute(
        inputs=(input_ids, attention_mask),
        baselines=(baseline_ids, baseline_mask),
        return_convergence_delta=True,
        n_steps=50
    )
    
    # Sum across embedding dimension
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
    
    return attributions, delta

In [12]:
def visualize_attributions(tokens, attributions, pred_label, smiles):
    """
    Create HTML visualization of token attributions.
    """
    # Filter out padding tokens
    valid_indices = [i for i, token in enumerate(tokens) if token != tokenizer.pad_token]
    tokens = [tokens[i] for i in valid_indices]
    attributions = attributions[valid_indices]
    
    # Normalize attributions for visualization
    attr_sum = np.abs(attributions).sum()
    normalized_attrs = attributions / attr_sum if attr_sum != 0 else attributions
    
    # Display header
    print(f"\n{'='*80}")
    print(f"SMILES: {smiles}")
    print(f"Prediction: {pred_label}")
    print(f"{'='*80}\n")
    
    # Create a single visualization record with all tokens
    vis_record = viz.VisualizationDataRecord(
        word_attributions=normalized_attrs.tolist(),
        pred_prob=0.0,  # Scalar float
        pred_class=pred_label,
        true_class="",
        attr_class=smiles,
        attr_score=float(np.abs(normalized_attrs).sum()),  # Scalar sum of attributions
        raw_input_ids=tokens,
        convergence_score=0.0
    )
    
    # Visualize
    html = viz.visualize_text([vis_record])
    return html


In [13]:
# Test molecules
test_molecules = {
    "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Expected: High BBB permeability
    "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",  # Expected: Low BBB permeability
    "Diazepam": "CN1C(=O)CN=C(c2ccccc2)c3cc(Cl)ccc13",  # Expected: High BBB permeability
}

task_idx = 2  # BBBP task
task_name = TASKS[task_idx][0]

print(f"\n{'#'*80}")
print(f"# Task: {task_name} - Blood-Brain Barrier Permeability")
print(f"{'#'*80}\n")

for drug_name, smiles in test_molecules.items():
    print(f"\n--- Analyzing: {drug_name} ---")
    
    # Get prediction
    pred_value, pred_label, input_ids, attention_mask, tokens = predict_smiles(
        model, tokenizer, smiles, task_idx
    )
    
    # Compute attributions
    attributions, delta = compute_attributions(
        model, input_ids, attention_mask, task_idx
    )
    
    # Visualize
    html = visualize_attributions(tokens, attributions, pred_label, smiles)
    
    # Show top contributing tokens
    valid_indices = [i for i, token in enumerate(tokens) if token != tokenizer.pad_token]
    valid_tokens = [tokens[i] for i in valid_indices]
    valid_attrs = attributions[valid_indices]
    
    # Sort by absolute attribution
    sorted_indices = np.argsort(np.abs(valid_attrs))[::-1]
    
    print("\nTop 5 Contributing Tokens:")
    for i in sorted_indices[:5]:
        token = valid_tokens[i]
        attr = valid_attrs[i]
        direction = "→ Active" if attr > 0 else "→ Inactive"
        print(f"  {token:10s} | Attribution: {attr:+.4f} {direction}")
    
    print(f"\nConvergence Delta: {delta.item():.6f}")



################################################################################
# Task: BBBP - Blood-Brain Barrier Permeability
################################################################################


--- Analyzing: Caffeine ---

SMILES: CN1C=NC2=C1C(=O)N(C(=O)N2C)C
Prediction: Active (prob=0.982)



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,Active (prob=0.982) (0.00),CN1C=NC2=C1C(=O)N(C(=O)N2C)C,1.0,[CLS] CN 1 C = NC 2 = C 1 C (= O ) N ( C (= O ) N 2 C ) C [SEP]
,,,,



Top 5 Contributing Tokens:
  CN         | Attribution: -0.4153 → Inactive
  O          | Attribution: -0.3787 → Inactive
  2          | Attribution: +0.3639 → Active
  =          | Attribution: +0.2674 → Active
  1          | Attribution: +0.2663 → Active

Convergence Delta: -0.011888

--- Analyzing: Aspirin ---

SMILES: CC(=O)Oc1ccccc1C(=O)O
Prediction: Inactive (prob=0.369)



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,Inactive (prob=0.369) (0.00),CC(=O)Oc1ccccc1C(=O)O,1.0,[CLS] CC (= O ) Oc 1 ccccc 1 C (= O ) O [SEP]
,,,,



Top 5 Contributing Tokens:
  O          | Attribution: -0.4621 → Inactive
  (=         | Attribution: -0.4348 → Inactive
  O          | Attribution: -0.3719 → Inactive
  1          | Attribution: +0.3693 → Active
  1          | Attribution: +0.2869 → Active

Convergence Delta: -0.009743

--- Analyzing: Diazepam ---

SMILES: CN1C(=O)CN=C(c2ccccc2)c3cc(Cl)ccc13
Prediction: Active (prob=0.981)



True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,Active (prob=0.981) (0.00),CN1C(=O)CN=C(c2ccccc2)c3cc(Cl)ccc13,1.0,[CLS] CN 1 C (= O ) CN = C ( c 2 ccccc 2 ) c 3 cc ( Cl ) ccc 13 [SEP]
,,,,



Top 5 Contributing Tokens:
  [CLS]      | Attribution: +0.4549 → Active
  ccccc      | Attribution: +0.3913 → Active
  2          | Attribution: +0.2779 → Active
  1          | Attribution: +0.2774 → Active
  O          | Attribution: -0.2770 → Inactive

Convergence Delta: -0.024323
