In [None]:
!pip install fair-esm torch

Defaulting to user installation because normal site-packages is not writeable


In [None]:
import torch
print(torch.cuda.is_available())

True


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    matthews_corrcoef, cohen_kappa_score
)
import matplotlib.pyplot as plt
import esm


# Load the pretrained ESM-2 checkpoint (t6 / 8M variant) together with its alphabet.
# extract_esm_embeddings reads the layer-6 representations of this model.
esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Callable that turns (label, sequence) pairs into a batch of token ids.
batch_converter = alphabet.get_batch_converter()
# Evaluation mode: the ESM weights are never handed to the optimizer in this notebook.
esm_model.eval()

# Dataset Class
class PeptideDataset(Dataset):
    """In-memory dataset of (peptide sequence, toxin label) pairs.

    The CSV file must provide a ``sequence`` column and a ``toxin`` column;
    both are materialised as plain Python lists at construction time.
    """

    def __init__(self, csv_file):
        table = pd.read_csv(csv_file)
        sequence_column = table['sequence']
        label_column = table['toxin']
        self.sequences = sequence_column.tolist()
        self.labels = label_column.tolist()

    def __len__(self):
        """Number of rows read from the CSV."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at position ``idx``."""
        sequence = self.sequences[idx]
        label = self.labels[idx]
        return sequence, label


# ESM Embedding Extractor
@torch.no_grad()
def extract_esm_embeddings(sequences):
    """Return one mean-pooled layer-6 ESM embedding per input sequence.

    Sequences are upper-cased before tokenisation. For each sequence the
    token positions ``1 .. len(seq)`` are averaged, which skips position 0
    (the BOS token) and everything after the residues (EOS / padding).

    Returns:
        A tensor with one row per sequence, stacked in input order.
    """
    upper_seqs = [raw.upper() for raw in sequences]
    _, _, tokens = batch_converter([("seq", residues) for residues in upper_seqs])
    model_out = esm_model(tokens, repr_layers=[6], return_contacts=False)
    per_token = model_out["representations"][6]

    pooled = [
        per_token[row, 1:len(residues) + 1].mean(0)
        for row, residues in enumerate(upper_seqs)
    ]
    return torch.stack(pooled)


# TCN Block
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final axis of a 3-D tensor.

    Used after a padded ``Conv1d`` to remove the extra trailing positions the
    padding produced, so the output length matches the input length.

    Args:
        chomp_size: number of trailing positions to remove; ``0`` keeps the
            input unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` is an empty slice, so zero padding (e.g. a kernel
        # of size 1) must be handled explicitly instead of erasing the signal.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block made of two dilated conv -> chomp -> ReLU -> dropout stages.

    ``self.net`` holds the eight layers in a single ``nn.Sequential``; a 1x1
    convolution (``self.downsample``) adapts the residual path only when the
    input and output channel counts differ.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout):
        super().__init__()
        stages = []
        # Only the first stage changes the channel count; the second maps
        # n_outputs -> n_outputs with the same convolution settings.
        for stage_inputs in (n_inputs, n_outputs):
            stages.extend([
                nn.Conv1d(stage_inputs, n_outputs, kernel_size, stride=stride,
                          padding=padding, dilation=dilation),
                Chomp1d(padding),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
        self.net = nn.Sequential(*stages)

        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(net(x) + residual)``."""
        branch = self.net(x)
        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(x)
        return F.relu(branch + residual)


class TCN(nn.Module):
    """Stack of ``TemporalBlock`` layers whose dilation doubles at every level.

    Args:
        input_size: channel count of the input tensor.
        num_channels: output channel count of each successive block.
        kernel_size: convolution kernel width shared by all blocks.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, input_size, num_channels, kernel_size=3, dropout=0.2):
        super().__init__()
        blocks = []
        previous_width = input_size
        for depth, width in enumerate(num_channels):
            dilation = 2 ** depth
            blocks.append(
                TemporalBlock(
                    previous_width, width, kernel_size, stride=1,
                    dilation=dilation,
                    # Padding grows with the dilation; each block chomps the
                    # same amount off again after its convolutions.
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
            previous_width = width
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply every block in order."""
        return self.network(x)


# Adaptive Feature Fusion
class AdaptiveFusion(nn.Module):
    """Blend ESM and TCN feature vectors with a learned element-wise gate.

    Both inputs are projected to ``esm_dim``; the sigmoid of their sum decides,
    per element, how much of each projection ends up in the output.
    """

    def __init__(self, esm_dim, tcn_dim):
        super().__init__()
        self.fc_esm = nn.Linear(esm_dim, esm_dim)
        # The TCN summary is lifted to esm_dim so the two can be mixed.
        self.fc_tcn = nn.Linear(tcn_dim, esm_dim)
        self.gate = nn.Sigmoid()

    def forward(self, esm_feat, tcn_feat):
        """Return ``g * esm_proj + (1 - g) * tcn_proj`` with ``g = sigmoid(esm_proj + tcn_proj)``."""
        projected_esm = self.fc_esm(esm_feat)
        projected_tcn = self.fc_tcn(tcn_feat)
        mix = self.gate(projected_esm + projected_tcn)
        esm_share = mix * projected_esm
        tcn_share = (1 - mix) * projected_tcn
        return esm_share + tcn_share

# Complete Classifier Model
class MultimodalClassifier(nn.Module):
    """Toxin classifier that fuses pooled ESM embeddings with a TCN over one-hot residues.

    Args:
        esm_dim: width of the pooled ESM embedding (320 by default).
        tcn_input: one-hot channel count; 21 matches the alphabet used by
            ``sequence_to_onehot``.
        tcn_channels: output channels of each TCN block. The default is an
            immutable tuple so it cannot be shared and mutated across instances.
        lstm_hidden: hidden size of the LSTM.
        num_classes: number of output logits.
    """

    def __init__(self, esm_dim=320, tcn_input=21, tcn_channels=(64, 128), lstm_hidden=128, num_classes=2):
        super().__init__()
        tcn_channels = list(tcn_channels)
        self.tcn = TCN(tcn_input, tcn_channels)
        self.fusion = AdaptiveFusion(esm_dim, tcn_channels[-1])
        self.lstm = nn.LSTM(input_size=esm_dim, hidden_size=lstm_hidden,
                            num_layers=5, batch_first=True, dropout=0.3)
        self.classifier = nn.Linear(lstm_hidden, num_classes)

    def forward(self, esm_feats, onehot_seqs):
        """Return class logits for a batch.

        Args:
            esm_feats: pooled ESM embeddings, one row per sequence.
            onehot_seqs: one-hot residues; permuted here so the channel axis
                comes second, as ``Conv1d`` requires.
        """
        tcn_out = self.tcn(onehot_seqs.permute(0, 2, 1))
        tcn_summary = torch.mean(tcn_out, dim=2)  # Global pooling over positions
        fused = self.fusion(esm_feats, tcn_summary)
        # The fused vector has no time axis; it is repeated 10 times to give
        # the LSTM a (constant) sequence to consume.
        lstm_input = fused.unsqueeze(1).repeat(1, 10, 1)
        lstm_out, _ = self.lstm(lstm_input)
        # Classify from the LSTM output at the last step.
        out = self.classifier(lstm_out[:, -1])
        return out

# Utility Functions
# sequence_to_onehot upper-cases its input so the one-hot features match the upper-cased ESM input.
def sequence_to_onehot(sequences, max_len=100):
    """One-hot encode sequences into a ``(len(sequences), max_len, 21)`` tensor.

    Sequences are upper-cased, truncated to ``max_len`` and zero-padded on the
    right. Characters outside the 21-letter alphabet leave an all-zero row.
    """
    alphabet = 'RHKDESTNQCUGPAVILMFYW'
    channel_of = dict(zip(alphabet, range(len(alphabet))))
    encoded = torch.zeros(len(sequences), max_len, len(alphabet))
    for row, raw in enumerate(sequences):
        residues = raw.upper()[:max_len]
        for position, residue in enumerate(residues):
            channel = channel_of.get(residue)
            if channel is None:
                continue
            encoded[row, position, channel] = 1.0
    return encoded


def _batch_loss(model, criterion, sequences, labels):
    """Featurise one batch, run the model and return the criterion loss."""
    esm_feats = extract_esm_embeddings(sequences)
    onehot_seqs = sequence_to_onehot(sequences).float()
    # The DataLoader already collates labels into a tensor; as_tensor avoids
    # the copy-construct warning of torch.tensor(tensor) and pins the integer
    # class-index dtype that CrossEntropyLoss expects.
    targets = torch.as_tensor(labels, dtype=torch.long)
    return criterion(model(esm_feats, onehot_seqs), targets)


def _mean_loss(total, n_batches):
    """Average a summed loss over batches; NaN when the loader was empty."""
    return total / n_batches if n_batches else float("nan")


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    """Train ``model`` and record the mean per-batch loss of every epoch.

    Args:
        model: module called as ``model(esm_feats, onehot_seqs)``.
        train_loader: iterable of ``(sequences, labels)`` training batches.
        val_loader: iterable of ``(sequences, labels)`` validation batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``{"train_loss": [...], "val_loss": [...]}`` with one entry per epoch.
        An empty loader yields NaN for that entry instead of raising
        ``ZeroDivisionError``.
    """
    history = {"train_loss": [], "val_loss": []}
    for epoch in range(num_epochs):
        model.train()
        running_train = 0.0
        for sequences, labels in tqdm(train_loader):
            loss = _batch_loss(model, criterion, sequences, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_train += loss.item()
        train_loss = _mean_loss(running_train, len(train_loader))
        history["train_loss"].append(train_loss)

        # Validation pass: no gradients, dropout disabled.
        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for sequences, labels in val_loader:
                running_val += _batch_loss(model, criterion, sequences, labels).item()
        val_loss = _mean_loss(running_val, len(val_loader))
        history["val_loss"].append(val_loss)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    return history


def evaluate_model(model, data_loader):
    """Print classification metrics for ``model`` over ``data_loader``.

    The model is switched to evaluation mode first so dropout cannot perturb
    the predictions when this is called outside ``train_model``. Predictions
    and labels are collected as plain Python ints before being handed to
    scikit-learn.

    Args:
        model: module called as ``model(esm_feats, onehot_seqs)``.
        data_loader: iterable of ``(sequences, labels)`` batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for sequences, labels in data_loader:
            esm_feats = extract_esm_embeddings(sequences)
            onehot_seqs = sequence_to_onehot(sequences).float()
            outputs = model(esm_feats, onehot_seqs)
            preds = torch.argmax(outputs, dim=1)
            # .cpu() keeps this working if the tensors ever live on a GPU.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(torch.as_tensor(labels).cpu().tolist())

    print("Accuracy:", accuracy_score(all_labels, all_preds))
    print("Precision:", precision_score(all_labels, all_preds))
    print("Recall:", recall_score(all_labels, all_preds))
    print("F1-score:", f1_score(all_labels, all_preds))
    print("MCC:", matthews_corrcoef(all_labels, all_preds))
    print("Cohen’s Kappa:", cohen_kappa_score(all_labels, all_preds))


# Plotting
def plot_history(history):
    """Draw the per-epoch training and validation loss curves on one figure.

    Args:
        history: mapping with ``"train_loss"`` and ``"val_loss"`` sequences,
            as returned by ``train_model``.
    """
    curves = (("train_loss", "Train Loss"), ("val_loss", "Val Loss"))
    for key, legend_text in curves:
        plt.plot(history[key], label=legend_text)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training/Validation Loss")
    plt.show()

In [1]:
# Main Execution
if __name__ == "__main__":
    # Load every (sequence, toxin label) pair from the augmented CSV.
    dataset = PeptideDataset("toxinpred_augmented_data.csv")
    # Hold out 20% of the rows for validation.
    val_split = 0.2
    n_val = int(len(dataset) * val_split)
    # NOTE(review): no generator/seed is passed to random_split, so the split
    # presumably differs from run to run — confirm if reproducibility matters.
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - n_val, n_val])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8)

    model = MultimodalClassifier()
    # Device placement is disabled: neither the model nor the inputs are moved
    # with .to(device) anywhere in this cell.
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Only the classifier's parameters are optimised; the ESM model is not included.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    history = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50)
    plot_history(history)

    # Metrics on the data the model was fitted to.
    print("Train Set Evaluation:")
    evaluate_model(model, train_loader)

    # Metrics on the held-out split.
    print("Validation Set Evaluation:")
    evaluate_model(model, val_loader)

NameError: name 'PeptideDataset' is not defined

In [2]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, data_loader, title):
    """Plot the confusion matrix of ``model`` over ``data_loader`` as a heatmap.

    The model is put in evaluation mode first so dropout cannot perturb the
    predictions; predictions and labels are collected as plain Python ints.

    Args:
        model: module called as ``model(esm_feats, onehot_seqs)``.
        data_loader: iterable of ``(sequences, labels)`` batches.
        title: figure title.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for sequences, labels in data_loader:
            esm_feats = extract_esm_embeddings(sequences)
            onehot_seqs = sequence_to_onehot(sequences).float()
            outputs = model(esm_feats, onehot_seqs)
            preds = torch.argmax(outputs, dim=1)
            # .cpu() keeps this working if the tensors ever live on a GPU.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(torch.as_tensor(labels).cpu().tolist())

    # Rows are true labels, columns are predicted labels.
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(title)
    plt.show()

# Run after the training cell: `model`, `train_loader` and `val_loader` must
# already exist in the session (the recorded output below shows the NameError
# raised when they do not).
plot_confusion_matrix(model, train_loader, "Confusion Matrix - Training Set")
plot_confusion_matrix(model, val_loader, "Confusion Matrix - Validation Set")

NameError: name 'model' is not defined

In [None]:
# Class balance check: count the rows carrying each toxin label.
# (The alias must be `pd`; the previous `import pandas as df` bound the module
# to the name that is reused for the DataFrame on the next line.)
import pandas as pd
df = pd.read_csv("toxinpred_augmented_data.csv")
toxin_counts = df['toxin'].value_counts()
print(toxin_counts)

toxin
0    27590
1    27590
Name: count, dtype: int64


In [None]:
# Save only the trained weights (state_dict); loading them later requires
# instantiating MultimodalClassifier with the same architecture first.
torch.save(model.state_dict(), 'toxinpred_peptide_classifier.pth')

In [None]:
# Restore the trained classifier: build a model with the same architecture,
# then load the saved state dictionary into it.
loaded_model = MultimodalClassifier()

# Load the weights written by the save cell.
loaded_model.load_state_dict(torch.load('toxinpred_peptide_classifier.pth'))

# Evaluation mode disables dropout for inference.
loaded_model.eval()

print("Model loaded successfully!")

# Example inference on two short peptides: both feature views (pooled ESM
# embedding and one-hot residues) must be built for every sequence.
new_sequences = ["ARGLAKL", "AAVVRR"]
new_esm_feats = extract_esm_embeddings(new_sequences)
new_onehot_seqs = sequence_to_onehot(new_sequences).float()
with torch.no_grad():
     predictions = loaded_model(new_esm_feats, new_onehot_seqs)
     # argmax over the two logits gives the predicted class index per sequence.
     predicted_classes = torch.argmax(predictions, dim=1)
     print("Predictions:", predicted_classes)

Model loaded successfully!
Predictions: tensor([1, 0])


In [None]:
#!pip install tensorflow
#import tensorflow as tf
#print(tf.config.list_physical_devices("GPU"))

In [None]:
!pip install lime

Defaulting to user installation because normal site-packages is not writeable


In [None]:
# Restore the trained classifier under the name `model`, which the
# explanation cells below read as a global.
model = MultimodalClassifier()

# Load the weights written by the save cell.
model.load_state_dict(torch.load('toxinpred_peptide_classifier.pth'))

# Evaluation mode disables dropout for inference.
model.eval()

print("Model loaded successfully!")

# `model` can now be used for inference on new data.
# Example on two short peptides:
new_sequences = ["ARGLAKL", "AAVVRR"]
new_esm_feats = extract_esm_embeddings(new_sequences)
new_onehot_seqs = sequence_to_onehot(new_sequences).float()
with torch.no_grad():
     predictions = model(new_esm_feats, new_onehot_seqs)
     # argmax over the two logits gives the predicted class index per sequence.
     predicted_classes = torch.argmax(predictions, dim=1)
     print("Predictions:", predicted_classes)

Model loaded successfully!
Predictions: tensor([1, 0])


In [3]:
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# === Load dataset ===
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str)
df['label'] = df['toxin']

# === Tokenization into k-mers ===
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of ``seq`` joined by single spaces.

    A non-empty sequence shorter than ``k`` has no full k-mer; it is returned
    as a single token instead of an empty string, so short peptides are not
    silently dropped (``kmers_to_seq`` returns a lone token unchanged).
    """
    if 0 < len(seq) < k:
        return seq
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

k = 3
df['kmers'] = df['sequence'].apply(lambda seq: seq_to_kmers(seq, k))

# === Helper to reconstruct full sequence from k-mers ===
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from a whitespace-separated k-mer string.

    The first token is kept whole and every later token contributes only its
    last character. ``k`` is accepted for symmetry with ``seq_to_kmers`` but
    is not used.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *rest = tokens
    return first + ''.join(token[-1] for token in rest)

# === Prediction wrapper for LIME (accepts k-mer strings) ===
class_names = ['non-toxin', 'toxin']

def lime_predict_kmers(kmer_texts, batch_size=64):
    """Class-probability function for LIME operating on k-mer strings.

    Each k-mer string is decoded back to a sequence and scored by the global
    ``model``. Inference runs in chunks of ``batch_size`` so that scoring the
    whole dataset, or LIME's perturbation samples, never becomes one giant
    ESM batch.

    Args:
        kmer_texts: iterable of whitespace-separated k-mer strings.
        batch_size: number of sequences per forward pass.

    Returns:
        ``numpy`` array of shape ``(len(kmer_texts), len(class_names))``.
        Rows whose text decodes to an empty sequence get the neutral value 0.5.
    """
    # Convert k-mer text back to raw sequences
    sequences = [kmers_to_seq(text, k) for text in kmer_texts]

    full_probs = np.full((len(sequences), len(class_names)), 0.5)
    valid_rows = [row for row, seq in enumerate(sequences)
                  if isinstance(seq, str) and len(seq) > 0]

    model.eval()
    with torch.no_grad():
        for start in range(0, len(valid_rows), batch_size):
            rows = valid_rows[start:start + batch_size]
            batch = [sequences[row] for row in rows]
            esm_feats = extract_esm_embeddings(batch)      # one pooled embedding per sequence
            onehot_seqs = sequence_to_onehot(batch).float()  # [B, max_len, 21]
            outputs = model(esm_feats, onehot_seqs)
            full_probs[rows] = F.softmax(outputs, dim=1).cpu().numpy()
    return full_probs

# === LIME explainer with whitespace as k-mer separator ===
# split_expression makes every whitespace-delimited k-mer one LIME feature.
explainer = LimeTextExplainer(class_names=class_names, split_expression='\\s+')

# === Select top-N high-confidence toxic sequences ===
raw_sequences = df['sequence'].tolist()
labels = df['label'].tolist()
kmer_texts = df['kmers'].tolist()

# Score every row of the dataset, then keep rows labelled toxin (1) whose
# predicted toxin probability (column 1) exceeds 0.8.
probs = lime_predict_kmers(kmer_texts)
high_conf_ids = np.where((np.array(labels) == 1) & (probs[:, 1] > 0.8))[0]
# Only the first ten matches are explained.
selected_ids = high_conf_ids[:10]

# === Run LIME explanation ===
# For each selected sequence: fit a LIME explanation for class 1, print the
# weighted k-mers, mark them in the sequence and try to render the plots.
for i, idx in enumerate(selected_ids):
    kmer_input = kmer_texts[idx]
    original_seq = df.iloc[idx]['sequence']

    # 1000 perturbed k-mer strings are scored through lime_predict_kmers.
    explanation = explainer.explain_instance(kmer_input, lime_predict_kmers, num_features=10, labels=[1], num_samples=1000)

    print(f"\n🧬 Sequence {i+1}: {original_seq}")
    print("Top influential k-mers (toward toxin):")
    for token, weight in explanation.as_list(label=1):
        print(f"  {token}: {weight:.4f}")

    # Highlight influential regions in the original sequence
    highlighted = original_seq
    # Strongest k-mers (by absolute weight) are wrapped first.
    sorted_kmers = sorted(explanation.as_list(label=1), key=lambda x: abs(x[1]), reverse=True)
    for kmer, _ in sorted_kmers:
        # NOTE(review): str.replace runs on the already-marked string, so
        # overlapping k-mers may no longer match after an earlier wrap — confirm
        # this is acceptable for the printed highlight.
        highlighted = highlighted.replace(kmer, f"<{kmer}>")
    print("Highlighted:", highlighted)

    # Rendering is best-effort: a failure is reported and the loop continues.
    try:
        explanation.show_in_notebook()
        fig = explanation.as_pyplot_figure(label=1)
        plt.title(f"LIME Explanation for Sequence {i+1}")
        plt.show()
    except Exception as e:
        print(f"Could not generate plot for sequence {i+1}: {e}")

ModuleNotFoundError: No module named 'lime'

In [4]:
from lime.lime_text import LimeTextExplainer
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# === Load dataset ===
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str)
df['label'] = df['toxin']

# === Tokenization into k-mers ===
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of ``seq`` joined by single spaces.

    A non-empty sequence shorter than ``k`` has no full k-mer; it is returned
    as a single token instead of an empty string, so short peptides are not
    silently dropped (``kmers_to_seq`` returns a lone token unchanged).
    """
    if 0 < len(seq) < k:
        return seq
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

k = 3
df['kmers'] = df['sequence'].apply(lambda seq: seq_to_kmers(seq, k))

# === Helper to reconstruct full sequence from k-mers ===
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from a whitespace-separated k-mer string.

    The first token is kept whole and every later token contributes only its
    last character. ``k`` is accepted for symmetry with ``seq_to_kmers`` but
    is not used.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *rest = tokens
    return first + ''.join(token[-1] for token in rest)

# === Prediction wrapper for LIME (accepts k-mer strings) ===
# class_names[i] must name model output column i. The classifier is trained on
# the 'toxin' column, so column 0 is non-toxin and column 1 is toxin; listing
# the names in the opposite order mislabels every LIME output.
class_names = ['non-toxin', 'toxin']

def lime_predict_kmers(kmer_texts, batch_size=64):
    """Class-probability function for LIME operating on k-mer strings.

    Each k-mer string is decoded back to a sequence and scored by the global
    ``model``. Inference runs in chunks of ``batch_size`` so that scoring the
    whole dataset, or LIME's perturbation samples, never becomes one giant
    ESM batch.

    Args:
        kmer_texts: iterable of whitespace-separated k-mer strings.
        batch_size: number of sequences per forward pass.

    Returns:
        ``numpy`` array of shape ``(len(kmer_texts), len(class_names))``.
        Rows whose text decodes to an empty sequence get the neutral value 0.5.
    """
    # Convert k-mer text back to raw sequences
    sequences = [kmers_to_seq(text, k) for text in kmer_texts]

    full_probs = np.full((len(sequences), len(class_names)), 0.5)
    valid_rows = [row for row, seq in enumerate(sequences)
                  if isinstance(seq, str) and len(seq) > 0]

    model.eval()
    with torch.no_grad():
        for start in range(0, len(valid_rows), batch_size):
            rows = valid_rows[start:start + batch_size]
            batch = [sequences[row] for row in rows]
            esm_feats = extract_esm_embeddings(batch)      # one pooled embedding per sequence
            onehot_seqs = sequence_to_onehot(batch).float()  # [B, max_len, 21]
            outputs = model(esm_feats, onehot_seqs)
            full_probs[rows] = F.softmax(outputs, dim=1).cpu().numpy()
    return full_probs

# === LIME explainer with whitespace as k-mer separator ===
explainer = LimeTextExplainer(class_names=class_names, split_expression='\\s+')

# === Select top-N high-confidence non-toxin sequences ===
raw_sequences = df['sequence'].tolist()
labels = df['label'].tolist()
kmer_texts = df['kmers'].tolist()

# Keep rows labelled non-toxin (0) whose predicted non-toxin probability
# (column 0) exceeds 0.8. Filtering on column 1 would instead pick non-toxins
# that the model confidently calls toxin.
probs = lime_predict_kmers(kmer_texts)
high_conf_ids = np.where((np.array(labels) == 0) & (probs[:, 0] > 0.8))[0]
selected_ids = high_conf_ids[:10]

# === Run LIME explanation ===
for i, idx in enumerate(selected_ids):
    kmer_input = kmer_texts[idx]
    original_seq = df.iloc[idx]['sequence']

    # Explain class 0 (non-toxin); 1000 perturbed k-mer strings are scored.
    explanation = explainer.explain_instance(kmer_input, lime_predict_kmers, num_features=10, labels=[0], num_samples=1000)

    print(f"\n🧬 Sequence {i+1}: {original_seq}")
    print("Top influential k-mers (toward non-toxin):")
    for token, weight in explanation.as_list(label=0):
        print(f"  {token}: {weight:.4f}")

    # Highlight influential regions in the original sequence
    highlighted = original_seq
    sorted_kmers = sorted(explanation.as_list(label=0), key=lambda x: abs(x[1]), reverse=True)
    for kmer, _ in sorted_kmers:
        highlighted = highlighted.replace(kmer, f"<{kmer}>")
    print("Highlighted:", highlighted)

    # Rendering is best-effort: a failure is reported and the loop continues.
    try:
        explanation.show_in_notebook()
        fig = explanation.as_pyplot_figure(label=0)
        plt.title(f"LIME Explanation for Sequence {i+1}")
        plt.show()
    except Exception as e:
        print(f"Could not generate plot for sequence {i+1}: {e}")

ModuleNotFoundError: No module named 'lime'

In [None]:
!pip install anchor-exp

Defaulting to user installation because normal site-packages is not writeable




In [5]:
from anchor import anchor_text
import numpy as np
import torch
import pandas as pd

# === Load dataset ===
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str)
df['label'] = df['toxin']

# === Convert sequence to k-mers ===
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of ``seq`` joined by single spaces.

    A non-empty sequence shorter than ``k`` has no full k-mer; it is returned
    as a single token instead of an empty string, so short peptides are not
    silently dropped (``kmers_to_seq`` returns a lone token unchanged).
    """
    if 0 < len(seq) < k:
        return seq
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

# === Convert k-mers back to sequence ===
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from a whitespace-separated k-mer string.

    The first token is kept whole and every later token contributes only its
    last character. ``k`` is accepted for symmetry with ``seq_to_kmers`` but
    is not used.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *rest = tokens
    return first + ''.join(token[-1] for token in rest)

k = 3
df['kmers'] = df['sequence'].apply(lambda seq: seq_to_kmers(seq, k))
kmer_texts = df['kmers'].tolist()
raw_sequences = df['sequence'].tolist()
labels = df['label'].tolist()

# === Prediction using ESM + one-hot ===
def predict_fn_esm(sequences, batch_size=64):
    """Return class probabilities for raw sequences using the global ``model``.

    Inference runs in chunks of ``batch_size`` so that scoring the whole
    dataset never becomes a single ESM batch.

    Args:
        sequences: list of sequence strings.
        batch_size: number of sequences per forward pass.

    Returns:
        ``numpy`` array of shape ``(len(sequences), 2)``. Entries that are not
        non-empty strings get the neutral value 0.5.
    """
    full_probs = np.full((len(sequences), 2), 0.5)
    valid_rows = [row for row, seq in enumerate(sequences)
                  if isinstance(seq, str) and len(seq) > 0]

    model.eval()
    with torch.no_grad():
        for start in range(0, len(valid_rows), batch_size):
            rows = valid_rows[start:start + batch_size]
            batch = [sequences[row] for row in rows]
            esm_feats = extract_esm_embeddings(batch)
            onehot = sequence_to_onehot(batch).float()
            outputs = model(esm_feats, onehot)
            full_probs[rows] = torch.softmax(outputs, dim=1).cpu().numpy()
    return full_probs

# === Wrapper for AnchorText (returns class predictions) ===
def predict_probs_kmers_anchor(kmer_seqs):
    """Decode k-mer strings back to sequences and return their class probabilities."""
    return predict_fn_esm([kmers_to_seq(kmer_text, k=k) for kmer_text in kmer_seqs])

def predict_class_kmers_anchor(kmer_seqs):
    """Return the most probable class index for every k-mer string."""
    class_probabilities = predict_probs_kmers_anchor(kmer_seqs)
    return class_probabilities.argmax(axis=1)

# === Dummy tokenizer for AnchorText ===
class DummyToken:
    """Minimal stand-in for a spaCy token.

    Attributes are set in ``__init__``: ``text`` is the token string and
    ``idx`` is the character offset of the token in the tokenized text
    (spaCy defines ``Token.idx`` as a character offset).
    """

    def __init__(self, text, idx):
        self.text = text
        self.idx = idx


class DummyTokenizer:
    """Whitespace tokenizer used as the ``nlp`` callable handed to AnchorText."""

    def __call__(self, text):
        """Return one ``DummyToken`` per whitespace-delimited token of ``text``.

        ``idx`` is the offset where the token starts in ``text`` rather than
        the token's ordinal position, so consumers that highlight spans in the
        raw text point at the right characters.
        """
        import re  # local import: keeps the notebook cell self-contained

        return [DummyToken(match.group(), match.start())
                for match in re.finditer(r'\S+', text)]

# === Find top high-confidence toxic predictions ===
# Score every raw sequence, then keep rows labelled toxin (1) whose predicted
# toxin probability (column 1) exceeds 0.8; only the first ten are explained.
probs = predict_fn_esm(raw_sequences)
high_conf_ids = np.where((np.array(labels) == 1) & (probs[:, 1] > 0.8))[0]
selected_ids = high_conf_ids[:10]

# === Create AnchorText explainer ===
# class_names is ordered by model output column: 0 = non-toxin, 1 = toxin.
class_names = ['non-toxin', 'toxin']
explainer = anchor_text.AnchorText(nlp=DummyTokenizer(), class_names=class_names)

# === Explain selected instances ===
# NOTE(review): AnchorText presumably perturbs the text by swapping tokens for a
# mask string; kmers_to_seq would decode such a token like any k-mer (keeping
# its last character) — confirm the perturbed sequences are what is intended.
for i, idx in enumerate(selected_ids):
    print(f"\n🧬 Explaining Sequence {i+1} (Index {idx})")
    print("Original Sequence:", raw_sequences[idx])

    # The anchor is searched on the k-mer text; the classifier function returns
    # hard class indices, and 0.95 is the requested precision threshold.
    explanation = explainer.explain_instance(
        kmer_texts[idx],
        classifier_fn=predict_class_kmers_anchor,
        threshold=0.95
    )

    print("\n🔍 Anchor Explanation:")
    print('Anchor (if these k-mers present → toxin):', ' AND '.join(explanation.names()))
    print('Precision:', explanation.precision())
    print('Coverage:', explanation.coverage())
    explanation.show_in_notebook()

ModuleNotFoundError: No module named 'anchor'

In [6]:
from anchor import anchor_text
import numpy as np
import torch
import pandas as pd

# === Load dataset ===
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str)
df['label'] = df['toxin']

# === Convert sequence to k-mers ===
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of ``seq`` joined by single spaces.

    A non-empty sequence shorter than ``k`` has no full k-mer; it is returned
    as a single token instead of an empty string, so short peptides are not
    silently dropped (``kmers_to_seq`` returns a lone token unchanged).
    """
    if 0 < len(seq) < k:
        return seq
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

# === Convert k-mers back to sequence ===
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from a whitespace-separated k-mer string.

    The first token is kept whole and every later token contributes only its
    last character. ``k`` is accepted for symmetry with ``seq_to_kmers`` but
    is not used.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *rest = tokens
    return first + ''.join(token[-1] for token in rest)

k = 3
df['kmers'] = df['sequence'].apply(lambda seq: seq_to_kmers(seq, k))
kmer_texts = df['kmers'].tolist()
raw_sequences = df['sequence'].tolist()
labels = df['label'].tolist()

# === Prediction using ESM + one-hot ===
def predict_fn_esm(sequences, batch_size=64):
    """Return class probabilities for raw sequences using the global ``model``.

    Inference runs in chunks of ``batch_size`` so that scoring the whole
    dataset never becomes a single ESM batch.

    Args:
        sequences: list of sequence strings.
        batch_size: number of sequences per forward pass.

    Returns:
        ``numpy`` array of shape ``(len(sequences), 2)``. Entries that are not
        non-empty strings get the neutral value 0.5.
    """
    full_probs = np.full((len(sequences), 2), 0.5)
    valid_rows = [row for row, seq in enumerate(sequences)
                  if isinstance(seq, str) and len(seq) > 0]

    model.eval()
    with torch.no_grad():
        for start in range(0, len(valid_rows), batch_size):
            rows = valid_rows[start:start + batch_size]
            batch = [sequences[row] for row in rows]
            esm_feats = extract_esm_embeddings(batch)
            onehot = sequence_to_onehot(batch).float()
            outputs = model(esm_feats, onehot)
            full_probs[rows] = torch.softmax(outputs, dim=1).cpu().numpy()
    return full_probs

# === Wrapper for AnchorText (returns class predictions) ===
def predict_probs_kmers_anchor(kmer_seqs):
    """Decode k-mer strings back to sequences and return their class probabilities."""
    return predict_fn_esm([kmers_to_seq(kmer_text, k=k) for kmer_text in kmer_seqs])

def predict_class_kmers_anchor(kmer_seqs):
    """Return the most probable class index for every k-mer string."""
    class_probabilities = predict_probs_kmers_anchor(kmer_seqs)
    return class_probabilities.argmax(axis=1)

# === Dummy tokenizer for AnchorText ===
class DummyToken:
    """Minimal stand-in for a spaCy token.

    Attributes are set in ``__init__``: ``text`` is the token string and
    ``idx`` is the character offset of the token in the tokenized text
    (spaCy defines ``Token.idx`` as a character offset).
    """

    def __init__(self, text, idx):
        self.text = text
        self.idx = idx


class DummyTokenizer:
    """Whitespace tokenizer used as the ``nlp`` callable handed to AnchorText."""

    def __call__(self, text):
        """Return one ``DummyToken`` per whitespace-delimited token of ``text``.

        ``idx`` is the offset where the token starts in ``text`` rather than
        the token's ordinal position, so consumers that highlight spans in the
        raw text point at the right characters.
        """
        import re  # local import: keeps the notebook cell self-contained

        return [DummyToken(match.group(), match.start())
                for match in re.finditer(r'\S+', text)]

# === Find top high-confidence non-toxin predictions ===
# Model output column 0 is the non-toxin class (the classifier is trained on
# the 'toxin' column). Keep rows labelled non-toxin (0) whose predicted
# non-toxin probability exceeds 0.8; filtering on column 1 would instead pick
# non-toxins that the model confidently calls toxin.
probs = predict_fn_esm(raw_sequences)
high_conf_ids = np.where((np.array(labels) == 0) & (probs[:, 0] > 0.8))[0]
selected_ids = high_conf_ids[:10]

# === Create AnchorText explainer ===
# class_names is ordered by model output column: 0 = non-toxin, 1 = toxin.
class_names = ['non-toxin', 'toxin']
explainer = anchor_text.AnchorText(nlp=DummyTokenizer(), class_names=class_names)

# === Explain selected instances ===
# NOTE(review): AnchorText presumably perturbs the text by swapping tokens for a
# mask string; kmers_to_seq would decode such a token like any k-mer (keeping
# its last character) — confirm the perturbed sequences are what is intended.
for i, idx in enumerate(selected_ids):
    print(f"\n🧬 Explaining Sequence {i+1} (Index {idx})")
    print("Original Sequence:", raw_sequences[idx])

    # The anchor is searched on the k-mer text; the classifier function returns
    # hard class indices, and 0.95 is the requested precision threshold.
    explanation = explainer.explain_instance(
        kmer_texts[idx],
        classifier_fn=predict_class_kmers_anchor,
        threshold=0.95
    )

    print("\n🔍 Anchor Explanation:")
    print('Anchor (if these k-mers present → non-toxin):', ' AND '.join(explanation.names()))
    print('Precision:', explanation.precision())
    print('Coverage:', explanation.coverage())
    explanation.show_in_notebook()


ModuleNotFoundError: No module named 'anchor'

In [None]:
import shap
import torch
import numpy as np
import pandas as pd
from torch.nn import functional as F

# Load dataset
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str).tolist()
df['label'] = df['toxin'].tolist()

# Define a fixed max_len for consistent feature representation.
# NOTE(review): sequence_to_onehot is called with its default max_len=100 inside
# shap_predict_fn, so one-hot positions beyond 100 do not reach the TCN branch
# even though 512 positions are encoded here — confirm 512 is intended.
max_len_for_shap = 512

# Corrected tokenizer functions
def sequence_to_features(seq, max_len):
    """Convert a sequence to a flattened one-hot feature vector with padding.

    The sequence is upper-cased first, matching ``sequence_to_onehot`` and
    ``extract_esm_embeddings``, so lower-case residues are encoded instead of
    being dropped. Positions past ``max_len`` are ignored; characters outside
    the 21-letter alphabet leave an all-zero row.

    Returns:
        1-D ``numpy`` array of length ``max_len * 21``.
    """
    bases = 'RHKDESTNQCUGPAVILMFYW'
    column_of = {base: column for column, base in enumerate(bases)}
    onehot = np.zeros((max_len, len(bases)))

    for position, base in enumerate(seq.upper()[:max_len]):
        column = column_of.get(base)
        if column is not None:
            onehot[position, column] = 1
    return onehot.flatten()

def features_to_sequence(flattened_onehot, max_len):
    """Decode a flattened one-hot vector back into a sequence string.

    The vector is viewed as ``max_len`` rows of 21 columns. Rows that are
    entirely zero are skipped; every other row contributes the letter of its
    largest column, so gaps between residues are closed up.
    """
    bases = 'RHKDESTNQCUGPAVILMFYW'
    grid = flattened_onehot.reshape((max_len, len(bases)))
    occupied_rows = grid[np.any(grid != 0, axis=1)]
    return ''.join(bases[column] for column in occupied_rows.argmax(axis=1))

# Convert sequences to feature space: one row of max_len_for_shap * 21 values
# per sequence.
# NOTE: `sequences` is not assigned anywhere in this cell; it is defined in the
# k-mer SHAP cell further down, which therefore has to run first.
X = np.array([sequence_to_features(seq, max_len=max_len_for_shap) for seq in sequences])

# Use larger background for better SHAP approximation: up to 100 rows drawn
# without replacement.
background_indices = np.random.choice(len(X), size=min(100, len(X)), replace=False)
X_background = X[background_indices]

# Corrected prediction function
def shap_predict_fn(feature_vectors, batch_size=64):
    """Class-probability function for SHAP operating on flattened one-hot rows.

    Every row is decoded to a sequence once; rows that cannot be decoded or
    decode to an empty sequence keep the neutral prediction 0.5. The remaining
    sequences are scored by the global ``model`` in chunks of ``batch_size``.

    Args:
        feature_vectors: iterable of flattened one-hot vectors.
        batch_size: number of sequences per forward pass.

    Returns:
        ``numpy`` array of shape ``(len(feature_vectors), 2)``.
    """
    final_probs = np.zeros((len(feature_vectors), 2)) + 0.5  # Neutral prediction for invalid inputs

    valid_rows, reconstructed_seqs = [], []
    for row, fv in enumerate(feature_vectors):
        try:
            seq = features_to_sequence(fv, max_len=max_len_for_shap)
        except (ValueError, AttributeError):
            # Wrong-sized or non-array input cannot be reshaped; stay neutral
            # for this row instead of hiding every possible error.
            continue
        if seq:
            valid_rows.append(row)
            reconstructed_seqs.append(seq)

    if not reconstructed_seqs:
        return final_probs

    # Get model predictions
    model.eval()
    with torch.no_grad():
        for start in range(0, len(valid_rows), batch_size):
            rows = valid_rows[start:start + batch_size]
            batch = reconstructed_seqs[start:start + batch_size]
            esm_feats = extract_esm_embeddings(batch)
            onehot_seqs = sequence_to_onehot(batch).float()
            outputs = model(esm_feats, onehot_seqs)
            final_probs[rows] = F.softmax(outputs, dim=1).cpu().numpy()

    return final_probs

# Initialize explainer (one-hot feature variant, currently disabled)
#explainer = shap.KernelExplainer(shap_predict_fn, X_background)

# Initialize explainer with the k-mer prediction function.
# NOTE: `shap_predict_ids` and `X_padded_kmer_ids` are not defined in this cell;
# they come from the k-mer SHAP cell further down, which has to run first.
explainer = shap.KernelExplainer(shap_predict_ids, X_padded_kmer_ids[background_indices])

# Select examples to explain: rows labelled toxin (1) with predicted toxin
# probability above 0.8, limited to the first five. `labels` is likewise
# assigned in the k-mer SHAP cell.
predicted_probs = shap_predict_fn(X)  # Use feature vectors, not raw sequences
selected_ids = np.where((np.array(labels) == 1) & (predicted_probs[:, 1] > 0.8))[0][:5]

In [None]:
import pandas as pd
import numpy as np
import torch
import shap
from collections import defaultdict

# Load dataset
df = pd.read_csv("toxinpred_augmented_data.csv")
df['sequence'] = df['sequence'].astype(str).tolist()
df['label'] = df['toxin'].tolist()

# Function to tokenize peptide sequences into k-mers
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of ``seq`` joined by single spaces.

    A non-empty sequence shorter than ``k`` has no full k-mer; it is returned
    as a single token instead of an empty string, so short peptides are not
    silently dropped (``kmers_to_seq`` returns a lone token unchanged).
    """
    if 0 < len(seq) < k:
        return seq
    return ' '.join(seq[i:i + k] for i in range(len(seq) - k + 1))

k = 3
df['kmers'] = df['sequence'].apply(lambda seq: seq_to_kmers(seq, k))
kmer_texts = df['kmers'].tolist()

# Function to convert k-mer text back to full sequence
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from a whitespace-separated k-mer string.

    The first token is kept whole and every later token contributes only its
    last character. ``k`` is accepted for symmetry with ``seq_to_kmers`` but
    is not used.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *rest = tokens
    return first + ''.join(token[-1] for token in rest)

# 1. Build k-mer vocabulary and mapping to integer IDs
all_kmers = [kmer for text in kmer_texts for kmer in text.split()]
# Sorting makes the id assignment independent of set iteration order.
kmer_vocab = sorted(list(set(all_kmers)))
# Add a padding token to the vocabulary; it is appended after sorting, so it
# receives the largest id.
padding_token = "<PAD>"
if padding_token not in kmer_vocab:
    kmer_vocab.append(padding_token)
kmer_to_id = {kmer: i for i, kmer in enumerate(kmer_vocab)}
# Inverse mapping, used to turn SHAP's id rows back into k-mer text.
id_to_kmer = {i: kmer for kmer, i in kmer_to_id.items()}
padding_id = kmer_to_id[padding_token]


# 2. Convert k-mer text sequences to integer ID sequences (one list per row)
kmer_id_sequences = []
for text in kmer_texts:
    kmer_ids = [kmer_to_id[kmer] for kmer in text.split()]
    kmer_id_sequences.append(kmer_ids)

# Determine maximum k-mer sequence length for padding, capped at a safe upper
# limit so a few very long sequences do not inflate the feature matrix.
max_kmer_len_cap = 53
max_kmer_len = max(len(ids) for ids in kmer_id_sequences)
print(f"Maximum k-mer sequence length: {max_kmer_len}")
if max_kmer_len > max_kmer_len_cap:
    max_kmer_len = max_kmer_len_cap
    print(f"Adjusted max k-mer sequence length to: {max_kmer_len}")

# 3. Pad (and truncate) integer ID sequences.
# This is done exactly once: a second, untruncated copy loop would raise a
# broadcast error for every row longer than the capped length.
X_padded_kmer_ids = np.full((len(kmer_id_sequences), max_kmer_len), padding_id, dtype=int)
for i, ids in enumerate(kmer_id_sequences):
    length_to_copy = min(len(ids), max_kmer_len)
    X_padded_kmer_ids[i, :length_to_copy] = ids[:length_to_copy]


# Define the missing predict_fn function
def predict_fn(sequences):
    """
    Return class probabilities for a list of sequences.

    Builds the two inputs the global ``model`` is called with -- embeddings
    from ``extract_esm_embeddings`` and one-hot tensors from
    ``sequence_to_onehot`` -- runs the model in eval mode without gradient
    tracking, and applies a softmax over dimension 1.

    Args:
        sequences: list of sequence strings.

    Returns:
        numpy array of softmax probabilities with one row per input
        sequence. Callers index column 1, so two classes are presumably
        expected -- confirm against the model definition.
    """
    model.eval()  # Set model to evaluation mode

    # Feed the model on whatever device it lives on (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        # Generate ESM embeddings
        esm_feats = extract_esm_embeddings(sequences).to(device)

        # Generate one-hot encoded sequences
        # NOTE(review): tensor layout is defined by sequence_to_onehot -- confirm
        onehot_seqs = sequence_to_onehot(sequences).float().to(device)

        # Get model predictions
        outputs = model(esm_feats, onehot_seqs)
        # Tensor.numpy() only works for CPU tensors, so copy back first
        probs = torch.softmax(outputs, dim=1).cpu().numpy()

    return probs


# Now the shap_predict_ids function will work since predict_fn is defined
def shap_predict_ids(padded_kmer_id_arrays):
    """Prediction wrapper that accepts rows of padded k-mer IDs.

    Each row is mapped back to k-mer strings (padding IDs are dropped),
    rebuilt into a sequence with ``kmers_to_seq`` and the resulting list of
    sequences is passed to ``predict_fn``, whose output is returned as-is.
    A row made up only of padding is rebuilt as the empty string.
    """
    rebuilt_seqs = []
    for row in padded_kmer_id_arrays:
        tokens = (id_to_kmer[token_id] for token_id in row if token_id != padding_id)
        rebuilt_seqs.append(kmers_to_seq(' '.join(tokens), k))
    return predict_fn(rebuilt_seqs)


# Ground-truth labels and raw sequences, in dataframe row order
labels = df['label'].tolist()
sequences = df['sequence'].tolist()

# Score every sequence, then keep the first 10 rows that are labelled 1 and
# whose predicted class-1 probability exceeds 0.8
probs = predict_fn(sequences)
label_mask = np.array(labels) == 1
confidence_mask = probs[:, 1] > 0.8
high_conf_ids = np.where(label_mask & confidence_mask)[0]
selected_ids = high_conf_ids[:10]

# Explain each selected instance (contributions toward class 1)

for idx in selected_ids:
    if idx >= len(X_padded_kmer_ids):
        continue  # Skip if index is out of bounds
    print(f"Explaining sequence ID: {idx}")
    instance = X_padded_kmer_ids[idx:idx+1]  # slice keeps the row 2-D

    # Get the explanation for this instance. This is best-effort per
    # instance: on failure move on to the next one instead of falling through
    # with an undefined (or previous iteration's) `shap_values`.
    try:
        shap_values = explainer.shap_values(instance)
    except Exception as e:
        print(f"Failed to explain sequence {idx}: {str(e)}")
        continue

    # Extract SHAP values for class 1
    if hasattr(shap_values, 'values'):
        # Explanation-like object (newer SHAP versions)
        shap_values_instance = shap_values.values[0, :, 1]
    elif isinstance(shap_values, list):
        # Legacy multi-output format: one (n_samples, n_features) array per class
        shap_values_instance = shap_values[1][0]
    else:
        # Single array indexed as (sample, feature, class)
        shap_values_instance = shap_values[0, :, 1]

    # Get the base value (expected value) - different for PermutationExplainer
    if hasattr(explainer, 'expected_value'):
        base_value = explainer.expected_value[1]  # For class 1
    else:
        # No expected value is exposed, so recover it from SHAP's additivity
        # property: prediction = base value + sum of SHAP values. The model
        # runs in eval mode, so a single prediction is enough.
        prediction = shap_predict_ids(instance)[0][1]
        base_value = prediction - np.sum(shap_values_instance)

    # Get the actual k-mer strings for this padded sequence
    kmer_strings_for_viz = [id_to_kmer[token_id] for token_id in X_padded_kmer_ids[idx, :]]

    # Create a SHAP Explanation object for text visualization
    text_exp = shap.Explanation(
        values=shap_values_instance,
        base_values=base_value,
        data=kmer_strings_for_viz,
        feature_names=[str(i) for i in range(max_kmer_len)]
    )

    # Visualize
    shap.plots.text(text_exp)

Maximum k-mer sequence length: 36
Explaining sequence ID: 4415


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import pandas as pd
import numpy as np
import torch
import shap
from collections import defaultdict

# Function to tokenize peptide sequences into k-mers
def seq_to_kmers(seq, k=3):
    """Return the overlapping k-mers of *seq* as one space-separated string.

    Windows of length ``k`` are taken at every start offset from 0 up to
    ``len(seq) - k``. When ``seq`` is shorter than ``k`` there are no
    windows and the result is the empty string.
    """
    window_count = len(seq) - k + 1
    windows = (seq[start:start + k] for start in range(window_count))
    return ' '.join(windows)

# k-mer window size shared by the tokenisation and reconstruction helpers
k = 3
# Store the space-separated k-mer text of every sequence next to the sequence itself
df['kmers'] = df['sequence'].apply(seq_to_kmers, args=(k,))
kmer_texts = df['kmers'].tolist()

# Function to convert k-mer text back to full sequence
def kmers_to_seq(kmer_str, k=3):
    """Rebuild a sequence from whitespace-separated overlapping k-mers.

    The first k-mer is kept whole and every following k-mer contributes only
    its last character. A string containing no k-mers yields ``''``.
    ``k`` is accepted for call-site symmetry with ``seq_to_kmers`` but is not
    used by the reconstruction.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ''
    first, *following = tokens
    tail = ''.join(token[-1] for token in following)
    return first + tail

# 1. Build k-mer vocabulary and mapping to integer IDs
all_kmers = [kmer for text in kmer_texts for kmer in text.split()]
# Sorting the unique k-mers keeps the ID assignment deterministic between runs
kmer_vocab = sorted(set(all_kmers))
# Add a padding token to the vocabulary (appended last, so it gets the highest ID)
padding_token = "<PAD>"
if padding_token not in kmer_vocab:
    kmer_vocab.append(padding_token)
kmer_to_id = {kmer: i for i, kmer in enumerate(kmer_vocab)}
id_to_kmer = {i: kmer for kmer, i in kmer_to_id.items()}
padding_id = kmer_to_id[padding_token]


# 2. Convert k-mer text sequences to integer ID sequences
kmer_id_sequences = [
    [kmer_to_id[kmer] for kmer in text.split()]
    for text in kmer_texts
]

# 3. Pad (and, if needed, truncate) the ID sequences to one common length
max_kmer_len = max(len(ids) for ids in kmer_id_sequences)
print(f"Maximum k-mer sequence length: {max_kmer_len}")
# NOTE(review): 53 is a hard-coded upper limit; presumably it matches the
# longest input the downstream model accepts -- confirm.
if max_kmer_len > 53:
    max_kmer_len = 53  # Set a safe upper limit
    print(f"Adjusted max k-mer sequence length to: {max_kmer_len}")

# Every row starts fully padded; real IDs are copied over the front.
# Sequences longer than max_kmer_len are truncated. This is the only padding
# pass: a second, non-truncating pass used to follow and would raise a
# ValueError as soon as any sequence exceeded the clamped length.
X_padded_kmer_ids = np.full((len(kmer_id_sequences), max_kmer_len), padding_id, dtype=int)
for i, ids in enumerate(kmer_id_sequences):
    length_to_copy = min(len(ids), max_kmer_len)
    X_padded_kmer_ids[i, :length_to_copy] = ids[:length_to_copy]


# Define the missing predict_fn function
def predict_fn(sequences):
    """
    Return class probabilities for a list of sequences.

    Builds the two inputs the global ``model`` is called with -- embeddings
    from ``extract_esm_embeddings`` and one-hot tensors from
    ``sequence_to_onehot`` -- runs the model in eval mode without gradient
    tracking, and applies a softmax over dimension 1.

    Args:
        sequences: list of sequence strings.

    Returns:
        numpy array of softmax probabilities with one row per input
        sequence. Callers index column 1, so two classes are presumably
        expected -- confirm against the model definition.
    """
    model.eval()  # Set model to evaluation mode

    # Feed the model on whatever device it lives on (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        # Generate ESM embeddings
        esm_feats = extract_esm_embeddings(sequences).to(device)

        # Generate one-hot encoded sequences
        # NOTE(review): tensor layout is defined by sequence_to_onehot -- confirm
        onehot_seqs = sequence_to_onehot(sequences).float().to(device)

        # Get model predictions
        outputs = model(esm_feats, onehot_seqs)
        # Tensor.numpy() only works for CPU tensors, so copy back first
        probs = torch.softmax(outputs, dim=1).cpu().numpy()

    return probs


# Now the shap_predict_ids function will work since predict_fn is defined
def shap_predict_ids(padded_kmer_id_arrays):
    """Prediction wrapper that accepts rows of padded k-mer IDs.

    Each row is mapped back to k-mer strings (padding IDs are dropped),
    rebuilt into a sequence with ``kmers_to_seq`` and the resulting list of
    sequences is passed to ``predict_fn``, whose output is returned as-is.
    A row made up only of padding is rebuilt as the empty string.
    """
    rebuilt_seqs = []
    for row in padded_kmer_id_arrays:
        tokens = (id_to_kmer[token_id] for token_id in row if token_id != padding_id)
        rebuilt_seqs.append(kmers_to_seq(' '.join(tokens), k))
    return predict_fn(rebuilt_seqs)


# Ground-truth labels and raw sequences, in dataframe row order
labels = df['label'].tolist()
sequences = df['sequence'].tolist()

# Score every sequence, then keep the first 10 rows that are labelled 0 and
# whose predicted class-1 probability exceeds 0.8
# NOTE(review): this picks label-0 rows the model scores highly for class 1;
# if confident class-0 predictions were intended, the filter should use
# probs[:, 0] -- confirm.
probs = predict_fn(sequences)
label_mask = np.array(labels) == 0
confidence_mask = probs[:, 1] > 0.8
high_conf_ids = np.where(label_mask & confidence_mask)[0]
selected_ids = high_conf_ids[:10]

# Explain each selected instance (contributions toward class 0)

for idx in selected_ids:
    if idx >= len(X_padded_kmer_ids):
        continue  # Skip if index is out of bounds
    print(f"Explaining sequence ID: {idx}")
    instance = X_padded_kmer_ids[idx:idx+1]  # slice keeps the row 2-D

    # Get the explanation for this instance. This is best-effort per
    # instance: on failure move on to the next one instead of falling through
    # with an undefined (or previous iteration's) `shap_values`.
    try:
        shap_values = explainer.shap_values(instance)
    except Exception as e:
        print(f"Failed to explain sequence {idx}: {str(e)}")
        continue

    # Extract SHAP values for class 0
    if hasattr(shap_values, 'values'):
        # Explanation-like object (newer SHAP versions)
        shap_values_instance = shap_values.values[0, :, 0]
    elif isinstance(shap_values, list):
        # Legacy multi-output format: one (n_samples, n_features) array per class
        shap_values_instance = shap_values[0][0]
    else:
        # Single array indexed as (sample, feature, class)
        shap_values_instance = shap_values[0, :, 0]

    # Get the base value (expected value) - different for PermutationExplainer
    if hasattr(explainer, 'expected_value'):
        base_value = explainer.expected_value[0]  # For class 0
    else:
        # No expected value is exposed, so recover it from SHAP's additivity
        # property: prediction = base value + sum of SHAP values. The class-0
        # probability is used so the base value matches the class-0 SHAP
        # values above. The model runs in eval mode, so one prediction is enough.
        prediction = shap_predict_ids(instance)[0][0]
        base_value = prediction - np.sum(shap_values_instance)

    # Get the actual k-mer strings for this padded sequence
    kmer_strings_for_viz = [id_to_kmer[token_id] for token_id in X_padded_kmer_ids[idx, :]]

    # Create a SHAP Explanation object for text visualization
    text_exp = shap.Explanation(
        values=shap_values_instance,
        base_values=base_value,
        data=kmer_strings_for_viz,
        feature_names=[str(i) for i in range(max_kmer_len)]
    )

    # Visualize
    shap.plots.text(text_exp)