In [None]:
import pandas as pd
from Bio import SeqIO
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from sklearn.preprocessing import MultiLabelBinarizer

from pymmseqs.commands import easy_cluster
import os
import re

# Toxins
(taxonomy_id:33208) AND (cc_tissue_specificity:venom) AND (reviewed:true) AND (keyword:KW-0800) AND (fragment:false)

In [None]:
# Load the raw toxin table and keep only the rows that carry a protein-family annotation.
tox = (
    pd.read_csv('../data/raw/tox.tsv', sep='\t')
    .dropna(subset=["Protein families"])
)

tox

In [None]:
# Keep only the first listed family: cut at the first ',' and then at the first ';'.
for separator in (',', ';'):
    tox['Protein families'] = tox['Protein families'].str.split(separator).str[0]

In [None]:
# Exact-match renames that bring conotoxin (super)family name variants onto one spelling.
conotoxin_aliases = {
    'I1 superfamily': 'Conotoxin I1 superfamily',
    'O1 superfamily': 'Conotoxin O1 superfamily',
    'O2 superfamily': 'Conotoxin O2 superfamily',
    'E superfamily': 'Conotoxin E superfamily',
    'F superfamily': 'Conotoxin F superfamily',
    'Conotoxin M family': 'Conotoxin M superfamily',
    'Conotoxin B2 family': 'Conotoxin B2 superfamily',
    'Conotoxin O1 family': 'Conotoxin O1 superfamily',
    'Conotoxin O2 family': 'Conotoxin O2 superfamily',
}
tox['Protein families'] = tox['Protein families'].replace(conotoxin_aliases)

In [None]:
# Regex pattern -> umbrella family label. Patterns are applied one after another in
# insertion order with Series.str.replace, so only the matched part of a name is replaced.
mapping = {
    r'Conotoxin.*': 'Conotoxin family',
    r'Neurotoxin.*': 'Neurotoxin family',
    r'Scoloptoxin.*|Scolopendra.*': 'Scoloptoxin family',
    r'Caterpillar.*': 'Caterpillar family',
    r'Teretoxin.*': 'Teretoxin family',
    r'Limacoditoxin.*': 'Limacoditoxin family',
    r'Scutigerotoxin.*': 'Scutigerotoxin family',
    r'Cationic peptide.*': 'Cationic peptide family',
    r'Formicidae venom.*': 'Formicidae venom family',
    r'Bradykinin-potentiating peptide family|Natriuretic peptide family': 'Natriuretic, Bradykinin potentiating peptide family',
    r'.*phospholipase.*|.*Phospholipase.*': 'Phospholipase family'
}

# Apply mapping
for pattern, replacement in mapping.items():
    tox['Protein families'] = tox['Protein families'].str.replace(pattern, replacement, regex=True)

# Families with fewer than MIN_FAMILY_SIZE sequences are pooled into "other".
# (The previous comment said 3, but the threshold actually applied here is 10.)
MIN_FAMILY_SIZE = 10
family_sizes = tox["Protein families"].value_counts()
tox["Protein families"] = tox["Protein families"].where(
    tox["Protein families"].map(family_sizes) >= MIN_FAMILY_SIZE, "other"
)

tox['Protein families'].value_counts()

# Non-Toxins
(taxonomy_id:33208) AND (reviewed:true) AND (fragment:false) NOT (keyword:KW-0800) AND ((existence:1) OR (existence:2))

In [None]:
# Load the raw non-toxin table.
nontox = pd.read_csv('../data/raw/nontox.tsv', sep='\t')

# True for rows whose sequence has at most 2000 residues; everything else is dropped.
mask = nontox["Sequence"].str.len().le(2000)
removed = len(mask) - mask.sum()  # number of rows that get dropped

nontox = nontox.loc[mask].reset_index(drop=True)

nontox

### Fasta Generation

In [None]:
def write_fasta(df, filename):
    """Dump every row of ``df`` to ``filename`` as a two-line FASTA record.

    The header line is ``>`` followed by the row's ``Entry`` value and the next
    line is the row's ``Sequence`` value. An existing file is overwritten.
    """
    with open(filename, "w") as handle:
        handle.writelines(
            f">{entry}\n{sequence}\n"
            for entry, sequence in zip(df['Entry'], df['Sequence'])
        )

# Export both tables as FASTA next to the raw TSV files.
for _frame, _fasta_path in (
    (tox, "../data/raw/tox.fasta"),
    (nontox, "../data/raw/nontox.fasta"),
):
    write_fasta(_frame, _fasta_path)

## Remove signal peptides (SPs)

In [None]:
# Run SignalP 6 (eukarya, fast mode) on the raw toxin FASTA; output goes to ../data/sp6/tox/.
# NOTE(review): --model_dir is an absolute path on one machine; adjust it before running elsewhere.
!signalp6 --fastafile ../data/raw/tox.fasta --output_dir ../data/sp6/tox/ --organism eukarya --mode fast --model_dir /Users/selin/Desktop/Uni/signalp6/signalp-6-package/models/

In [None]:
!signalp6 --fastafile ../data/nontox.fasta --output_dir ../data/sp6/nontox/ --organism eukarya --mode fast --model_dir /Users/selin/Desktop/Uni/signalp6/signalp-6-package/models/

In [None]:
def fasta_to_dataframe(fasta_file):
    """Read a FASTA file into a DataFrame with columns ``identifier`` and ``Sequence``.

    ``identifier`` is the text after the last ``|`` of the record id (the whole id
    when it contains no ``|``); ``Sequence`` is the record's sequence as a string.

    An empty FASTA file yields an empty frame that still has both columns, so
    callers can merge on ``identifier`` without hitting a KeyError.
    """
    rows = [
        {"identifier": record.id.split('|')[-1], "Sequence": str(record.seq)}
        for record in SeqIO.parse(fasta_file, "fasta")
    ]
    # Explicit column list keeps the schema stable even when there are no records.
    return pd.DataFrame(rows, columns=["identifier", "Sequence"])

# Load the processed_entries.fasta files that SignalP6 wrote for toxins and non-toxins.
proc_tox, proc_nontox = (
    fasta_to_dataframe(processed_fasta)
    for processed_fasta in (
        "../data/sp6/tox/processed_entries.fasta",
        "../data/sp6/nontox/processed_entries.fasta",
    )
)

In [None]:
# Inspect the toxin sequences read from SignalP6's processed_entries.fasta.
proc_tox

In [None]:
# Names for the nine tab-separated columns of the SignalP6 GFF3 output.
cols = [
    'identifier', 'source', 'feature_type', 'start', 'end',
    'score', 'strand', 'phase', 'attributes'
]

# Read both GFF3 files without a header row; lines starting with '#' are skipped.
gff3_tox = pd.read_csv('../data/sp6/tox/output.gff3', sep='\t', comment='#', header=None)
gff3_nontox = pd.read_csv('../data/sp6/nontox/output.gff3', sep='\t', comment='#', header=None)

for gff3_frame in (gff3_tox, gff3_nontox):
    gff3_frame.columns = cols

def extract_seqid(full_seqid):
    """Return the part of ``full_seqid`` after its last ``|``, cut at the first space."""
    last_field = full_seqid.rsplit('|', 1)[-1]
    return last_field.split(' ', 1)[0]

# Shorten the GFF3 sequence ids so they can be matched against proc_tox / proc_nontox.
gff3_tox['identifier'] = gff3_tox['identifier'].map(extract_seqid)
gff3_nontox['identifier'] = gff3_nontox['identifier'].map(extract_seqid)

# Attach the processed sequence to each GFF3 row (inner join on identifier).
gff3_tox = gff3_tox.merge(proc_tox, on='identifier')
gff3_nontox = gff3_nontox.merge(proc_nontox, on='identifier')

In [None]:
# Toxin predictions whose score is below 0.8.
gff3_tox.loc[gff3_tox['score'] < 0.8]

### merge with SP6 predictions

In [None]:
# Replace 'Sequence' in tox / nontox with the SignalP6-processed sequence wherever the
# prediction score exceeds 0.8; rows without such a prediction keep their sequence.
#
# DataFrame.update aligns on the index. tox and nontox carry integer row labels, so
# updating them from a frame indexed by Entry matches nothing and changes nothing.
# The processed sequence is therefore looked up by Entry.
for target, predictions in ((tox, gff3_tox), (nontox, gff3_nontox)):
    confident = predictions.loc[predictions['score'] > 0.8, ['identifier', 'Sequence']]
    processed_by_entry = (
        confident.dropna(subset=['Sequence'])
        .drop_duplicates(subset='identifier')  # a unique lookup index is required by map()
        .set_index('identifier')['Sequence']
    )
    target['Sequence'] = target['Entry'].map(processed_by_entry).fillna(target['Sequence'])

In [None]:
# Write tox / nontox to the *_noSP FASTA files. The target folder is created first so
# that a checkout without ../data/interm/ does not fail with FileNotFoundError.
os.makedirs("../data/interm", exist_ok=True)
write_fasta(tox, "../data/interm/tox_noSP.fasta")
write_fasta(nontox, "../data/interm/nontox_noSP.fasta")

## Clustering
### Run MMseqs2 clustering at 90% minimum sequence identity per protein family

In [None]:
# Folder that receives one FASTA file per protein family; created if it does not exist.
out_dir = "../data/families/"
os.makedirs(out_dir, exist_ok=True)

def sanitize_filename(name):
    """Replace every character other than ASCII letters, digits, ``_`` and ``-`` with ``_``."""
    unsafe_char = re.compile(r"[^a-zA-Z0-9_-]")
    return unsafe_char.sub("_", name)

# Clustering runs that raised an error, stored as (fasta path, cluster prefix, tmp dir).
failed = []

for family, group in tox.groupby("Protein families"):
    safe_family = sanitize_filename(family)

    # Per-family FASTA that serves as the clustering input.
    fasta_path = os.path.join(out_dir, f"{safe_family}.fasta")
    write_fasta(group, fasta_path)

    # Family-specific mmseqs folder; creating its tmp sub-folder also creates the folder itself.
    family_mmseqs_dir = os.path.join("/Users/selin/PycharmProjects/ToxFam/data/mmseqs", safe_family)
    cluster_prefix = os.path.join(family_mmseqs_dir, "cluster")
    tmp_dir = os.path.join(family_mmseqs_dir, "tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    cluster_options = {
        "fasta_files": fasta_path,
        "cluster_prefix": cluster_prefix,
        "tmp_dir": tmp_dir,
        "min_seq_id": 0.9,
    }
    try:
        easy_cluster(**cluster_options)
    except Exception as e:
        # Best effort: report the family, remember it, and carry on with the next one.
        print(f"⚠️ Skipping {safe_family} due to error: {e}")
        failed.append((fasta_path, cluster_prefix, tmp_dir))

# Print mmseqs commands for failures
if failed:
    print("\n🔁 Manual mmseqs2 commands for failed entries:\n")
    print("\n".join(
        f"mmseqs easy-cluster {fasta} {out} {tmp} --min-seq-id 0.9"
        for fasta, out, tmp in failed
    ))

In [None]:
mmseqs_base_dir = "/Users/selin/PycharmProjects/ToxFam/data/mmseqs"

# Representative-sequence FASTA path for every sub-folder except "nontox".
rep_fasta_files = [
    os.path.join(mmseqs_base_dir, family_dir, "cluster_rep_seq.fasta")
    for family_dir in os.listdir(mmseqs_base_dir)
    if family_dir != "nontox"
    and os.path.isdir(os.path.join(mmseqs_base_dir, family_dir))
]

# One record per representative sequence; folders without a result file are skipped.
rep_seqs = [
    {"Entry": record.id, "Sequence": str(record.seq)}
    for rep_fasta in rep_fasta_files
    if os.path.exists(rep_fasta)
    for record in SeqIO.parse(rep_fasta, "fasta")
]

# Attach the family label of each representative (left join keeps every representative).
rep_df = pd.DataFrame(rep_seqs).merge(tox[["Entry", "Protein families"]], on="Entry", how="left")
rep_df

In [None]:
# Families left with fewer than 3 representative sequences are relabelled "other".
rep_family_counts = rep_df["Protein families"].value_counts()
has_enough_reps = rep_df["Protein families"].map(rep_family_counts) >= 3
rep_df["Protein families"] = rep_df["Protein families"].where(has_enough_reps, "other")
rep_df["Protein families"].value_counts()

### Train-Val-Test sets with 70:15:15 split

In [None]:
# Each representative carries exactly one family label, and a label may itself contain a
# comma ('Natriuretic, Bradykinin potentiating peptide family'). Splitting on ',' would turn
# that one label into two bogus classes, so every label is wrapped in a one-element list
# instead. rep_df is left untouched; it keeps plain string labels.
family_labels = rep_df['Protein families'].map(lambda family: [family])

# Binary indicator matrix (rows: sequences, columns: families) used for stratification
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(family_labels)

# Label order of the indicator-matrix columns
label_classes = mlb.classes_

# Train+val vs test split (test = 15 %)
msss1 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=42)
idx_train_val, idx_test = next(msss1.split(rep_df, y))

rep_df_train_val = rep_df.iloc[idx_train_val].copy()
y_train_val = y[idx_train_val]

# Train vs val split: 15 % of the full set equals 0.15 / (1 - 0.15) ~ 0.176 of train+val
val_size = 0.15 / (1 - 0.15)
msss2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=42)
idx_train, idx_val = next(msss2.split(rep_df_train_val, y_train_val))

# Independent copies so that later column assignments do not touch rep_df
train_df = rep_df_train_val.iloc[idx_train].copy()
val_df = rep_df_train_val.iloc[idx_val].copy()
test_df = rep_df.iloc[idx_test].copy()

In [None]:
def print_split_sizes(train_df, val_df, test_df, total_df):
    """Print the size of each split and its share of ``total_df`` in percent."""
    total = len(total_df)
    for label, split in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
        count = len(split)
        print(f"{label} size: {count} ({count / total * 100:.2f}%)")
# Report the toxin split sizes relative to all toxin representatives.
print_split_sizes(train_df=train_df, val_df=val_df, test_df=test_df, total_df=rep_df)

In [None]:
def plot_protein_family_distribution(train_df, val_df, test_df):
    """Plot the share of each protein family within the train, validation and test splits.

    Draws one stacked bar per split, saves the figure to
    ``../data/train_distribution.png`` and shows it.

    Parameters
    ----------
    train_df, val_df, test_df : pandas.DataFrame
        Frames with a ``'Protein families'`` column holding one family label per row.
    """
    def get_family_percentages(df):
        # Labels are counted as-is: each row holds a single family, and a label may
        # contain a comma itself, so splitting on ',' would invent extra families.
        return df['Protein families'].value_counts(normalize=True).sort_index()

    train_pct = get_family_percentages(train_df)
    val_pct = get_family_percentages(val_df)
    test_pct = get_family_percentages(test_df)

    all_families = sorted(set(train_pct.index) | set(val_pct.index) | set(test_pct.index))

    # Families missing from a split get a share of 0 so that all three rows line up.
    df_pct = pd.DataFrame({
        'Train': train_pct.reindex(all_families, fill_value=0),
        'Validation': val_pct.reindex(all_families, fill_value=0),
        'Test': test_pct.reindex(all_families, fill_value=0)
    }).T

    # One rainbow colour per family. plt.get_cmap is used because matplotlib.cm.get_cmap
    # is deprecated and no longer available in recent matplotlib releases.
    num_colors = len(all_families)
    cmap = plt.get_cmap('rainbow', num_colors)
    colors = [cmap(i) for i in range(num_colors)]

    # Create the figure explicitly and hand its axes to pandas. DataFrame.plot without
    # ax= opens a figure of its own, which left the sized figure empty and unused.
    fig, ax = plt.subplots(figsize=(10, 7), dpi=300)
    df_pct.plot(kind='bar', stacked=True, color=colors, ax=ax)

    ax.set_ylabel('Percentage within split')
    ax.set_title('Protein Families Distribution Across Splits')

    ax.legend(title='Protein Family', loc='upper center',
              bbox_to_anchor=(0.49, -0.55),  # below plot
              fontsize=4, title_fontsize='small', ncol=4)
    fig.subplots_adjust(bottom=0.55)  # add space below plot for legend

    fig.savefig('../data/train_distribution.png', dpi=300)
    plt.show()

# Draw the family distribution of the three toxin splits.
plot_protein_family_distribution(train_df=train_df, val_df=val_df, test_df=test_df)

In [None]:
# Inspect the toxin training split.
train_df

## redundancy reduction nontox

In [None]:
!mmseqs easy-cluster ../data/raw/nontox.fasta ../data/mmseqs/nontox/cluster ../data/mmseqs/nontox/tmp --min-seq-id 0.9

In [None]:
# Representative sequences of the non-toxin clustering; stays empty when no result exists.
nontox_dir = os.path.join(mmseqs_base_dir, "nontox")
rep_fasta = os.path.join(nontox_dir, "cluster_rep_seq.fasta")

rep_seqs = []
if os.path.isdir(nontox_dir) and os.path.exists(rep_fasta):
    rep_seqs = [
        {"Entry": record.id, "Sequence": str(record.seq)}
        for record in SeqIO.parse(rep_fasta, "fasta")
    ]

# Create DataFrame
rep_df = pd.DataFrame(rep_seqs)
rep_df

In [None]:
# Both splits shuffle with the same fixed seed.
split_options = dict(shuffle=True, random_state=42)

# 70 % of the representatives go to training, the remaining 30 % are held out ...
nontox_train_df, nontox_temp_df = train_test_split(rep_df, test_size=0.30, **split_options)

# ... and the held-out part is halved into validation and test (15 % each overall).
nontox_val_df, nontox_test_df = train_test_split(nontox_temp_df, test_size=0.50, **split_options)

print_split_sizes(
    train_df=nontox_train_df,
    val_df=nontox_val_df,
    test_df=nontox_test_df,
    total_df=rep_df,
)

## generate final training data
### merge tox and nontox

In [None]:
# Label the non-toxin splits and tag every frame with the split it belongs to.
# assign() builds new frames rather than writing columns into the row subsets returned
# by train_test_split, which can trigger pandas' SettingWithCopyWarning.
nontox_train_df = nontox_train_df.assign(**{'Protein families': 'nontox', 'Split': 'train'})
nontox_val_df = nontox_val_df.assign(**{'Protein families': 'nontox', 'Split': 'val'})
nontox_test_df = nontox_test_df.assign(**{'Protein families': 'nontox', 'Split': 'test'})

train_df = train_df.assign(Split='train')
val_df = val_df.assign(Split='val')
test_df = test_df.assign(Split='test')

# Stack toxin and non-toxin splits into one table with a fresh 0..n-1 index
training_data = pd.concat([
    train_df, val_df, test_df,
    nontox_train_df, nontox_val_df, nontox_test_df
], ignore_index=True)

training_data

In [None]:
# Number of distinct labels in the assembled table (NaN counts as a value, as it did
# with len(unique())).
# The table is deliberately not re-read from ../data/interm/training_data.csv here: that
# file is only written by the following cell, so loading it first fails on a fresh run
# or silently replaces the in-memory table with a stale copy.
training_data["Protein families"].nunique(dropna=False)

In [None]:
# Persist the combined table as CSV without the row index.
training_data.to_csv("../data/interm/training_data.csv", index=False)

In [None]:
import matplotlib.pyplot as plt

# Copy of the training table with an added per-sequence length column.
test = training_data.assign(Length=training_data["Sequence"].str.len())

# Sequences longer than 10000 residues.
long_sequences = test.loc[test["Length"] > 10000]

print(long_sequences[["Entry", "Length"]])

# Plot the length distribution
plt.figure(figsize=(10, 6))
plt.hist(test["Length"], bins=100, color="skyblue", edgecolor="black")
plt.title("Distribution of Protein Sequence Lengths")
plt.xlabel("Sequence Length")
plt.ylabel("Count")
plt.grid(True)
plt.show()


In [None]:
import subprocess

# Split the training table into at most NUM_CHUNKS consecutive parts, write each part
# as FASTA and run ../generate_embeds.py on it.
NUM_CHUNKS = 20
total      = len(training_data)
chunk_size = (total + NUM_CHUNKS - 1) // NUM_CHUNKS   # ceiling division

for part in range(NUM_CHUNKS):
    start = part * chunk_size
    if start >= total:            # nothing left
        break
    end = min(start + chunk_size, total)

    chunk = training_data.iloc[start:end]
    fasta_path = f"../data/interm/training_data_part_{part+1}.fasta"
    # Defined before the progress message: the message used embed_path before it was
    # assigned, which raised NameError on the first iteration of a fresh kernel.
    embed_path = f"../data/interm/training_embeds_{part+1}.h5"

    write_fasta(chunk, fasta_path)
    print(f"✓ part {part+1:2d}: {len(chunk):>5} seq → {fasta_path} → {embed_path}")

    # check=True stops the loop with CalledProcessError if the script exits non-zero.
    subprocess.run(
        [
            "python", "../generate_embeds.py",
            "-i", fasta_path,
            "-o", embed_path,
            "--per-protein",
            "--max-batch", "1",
            "--max-residues", "1000",
            "--max-seq-len", "500",
        ],
        check=True,
    )


## generate embeddings