In [1]:
# Enable autoreload
%load_ext autoreload

# Set autoreload mode to automatically reload all modules
%autoreload 2

In [2]:
import logging
import ankh
import clearml
import numpy as np
import pandas as pd
import torch
import yaml
from clearml import Logger, Task
from sklearn.metrics import roc_curve, auc
from data_prepare import get_embed_clustered_df, make_folds
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch_utils import (CustomBatchSampler, SequenceDataset,
                         custom_collate_fn, train_fn, validate_fn, load_models)

In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
models = load_models(prefix_name="checkpoints/pdb2272_")

RuntimeError: Error(s) in loading state_dict for ConvBertForBinaryClassification:
	Unexpected key(s) in state_dict: "transformer_encoder.1.attention.self.query.weight", "transformer_encoder.1.attention.self.query.bias", "transformer_encoder.1.attention.self.key.weight", "transformer_encoder.1.attention.self.key.bias", "transformer_encoder.1.attention.self.value.weight", "transformer_encoder.1.attention.self.value.bias", "transformer_encoder.1.attention.self.key_conv_attn_layer.bias", "transformer_encoder.1.attention.self.key_conv_attn_layer.depthwise.weight", "transformer_encoder.1.attention.self.key_conv_attn_layer.pointwise.weight", "transformer_encoder.1.attention.self.conv_kernel_layer.weight", "transformer_encoder.1.attention.self.conv_kernel_layer.bias", "transformer_encoder.1.attention.self.conv_out_layer.weight", "transformer_encoder.1.attention.self.conv_out_layer.bias", "transformer_encoder.1.attention.output.dense.weight", "transformer_encoder.1.attention.output.dense.bias", "transformer_encoder.1.attention.output.LayerNorm.weight", "transformer_encoder.1.attention.output.LayerNorm.bias", "transformer_encoder.1.intermediate.dense.weight", "transformer_encoder.1.intermediate.dense.bias", "transformer_encoder.1.output.dense.weight", "transformer_encoder.1.output.dense.bias", "transformer_encoder.1.output.LayerNorm.weight", "transformer_encoder.1.output.LayerNorm.bias". 
	size mismatch for transformer_encoder.0.intermediate.dense.weight: copying a param with shape torch.Size([728, 1536]) from checkpoint, the shape in current model is torch.Size([1536, 1536]).
	size mismatch for transformer_encoder.0.intermediate.dense.bias: copying a param with shape torch.Size([728]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for transformer_encoder.0.output.dense.weight: copying a param with shape torch.Size([1536, 728]) from checkpoint, the shape in current model is torch.Size([1536, 1536]).

In [29]:
input_data = "pdb2272"
batch_size = 10

In [30]:
test_df = get_embed_clustered_df(
    embedding_path=f"../../../../ssd2/dbp_finder/ankh_embeddings/{input_data}_2d.h5",
    csv_path=f"../data/embeddings/input_csv/{input_data}.csv",
)

In [31]:
dataset = SequenceDataset(test_df)
# sampler = CustomBatchSampler(dataset, batch_size)
dataloader = DataLoader(
        dataset,
        num_workers=1,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate_fn,
    )

In [37]:
prev_dim = 0
unique_dim = []
for x, y in dataloader:
    dim_l = x.shape[1]
    if dim_l != prev_dim:
        unique_dim.append(dim_l)
        prev_dim = dim_l


In [38]:
unique_dim

[979,
 924,
 785,
 847,
 857,
 943,
 694,
 977,
 790,
 765,
 956,
 1017,
 732,
 960,
 1012,
 956,
 886,
 865,
 645,
 457,
 925,
 829,
 974,
 865,
 646,
 910,
 714,
 4911,
 1507,
 1627,
 985,
 1244,
 1412,
 1368,
 1530,
 1380,
 1622,
 648,
 1223,
 1129,
 1250,
 1226,
 1718,
 1128,
 2286,
 1250,
 1137,
 2649,
 4128,
 1351,
 1969,
 1235,
 829,
 1528,
 2716,
 811,
 1025,
 1129,
 1188,
 2717,
 1107,
 886,
 733,
 1132,
 1314,
 1333,
 1196,
 2517,
 3726,
 766,
 1756,
 2492,
 1613,
 1589,
 1461,
 982,
 1148,
 1581,
 1044,
 722,
 1380,
 742,
 880,
 1209,
 1193,
 1455,
 1487,
 866,
 3664,
 2452,
 2248,
 3130,
 1853,
 944,
 2175,
 1151,
 2150,
 1104,
 1403,
 1281,
 725,
 602,
 2130,
 2715,
 1174,
 1487,
 1432,
 1542,
 1323,
 1109,
 1048,
 1226,
 1895,
 1878,
 1044,
 1358,
 1383,
 1267,
 1513,
 1510,
 1311,
 1002,
 1089,
 970,
 941,
 999,
 712,
 563,
 1426,
 1074,
 804,
 2249,
 1137,
 2442,
 5183,
 607,
 893,
 811,
 2971,
 1420,
 738,
 667,
 1123,
 552,
 955,
 1034,
 1178,
 2034,
 1744,
 525,
 607

In [14]:
a.shape

torch.Size([32, 979, 1536])

In [77]:
valid_dataset = SequenceDataset(test_df)
valid_sampler = CustomBatchSampler(valid_dataset, batch_size=5)
valid_dataloader = DataLoader(
                              valid_dataset,
                              num_workers=1,
                              batch_sampler=valid_sampler,
                              collate_fn=custom_collate_fn)

In [79]:
a, b = next(iter(valid_dataloader))

In [81]:
a.shape

torch.Size([5, 354, 1536])

In [95]:
def validate_fn(model, valid_dataloader, DEVICE):
    model.eval()
    loss = 0.0
    all_preds = []
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for x, y in valid_dataloader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            y = y.unsqueeze(1)

            output = model(x, y)
            loss += output.loss.item()

            prob = torch.sigmoid(output.logits)
            preds = (prob > 0.5).float()

            all_logits.extend(output.logits.cpu().numpy().flatten())
            all_preds.extend(preds.cpu().numpy().flatten())
            all_labels.extend(y.cpu().numpy().flatten())
            return all_preds

In [96]:
model = models[0]

In [97]:
all_logits = validate_fn(model, valid_dataloader, DEVICE)

In [98]:
all_logits

[1.0, 0.0, 1.0, 1.0, 1.0]

In [52]:
all_logits, all_labels = ens_logits

In [57]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [58]:
all_logits = np.array(all_logits)
prob = sigmoid(all_logits)

In [60]:
fpr, tpr, thresholds = roc_curve(all_labels, prob)
youden_j = tpr - fpr
optimal_idx = np.argmax(youden_j)
optimal_threshold = thresholds[optimal_idx]

print("Optimal Threshold:", optimal_threshold)

Optimal Threshold: 0.3747879175469912


In [64]:
prob = torch.tensor(prob)

In [66]:
prob.shape

torch.Size([2272])

In [67]:
all_preds = (prob >= optimal_threshold).float().tolist()

In [32]:

from sklearn.metrics import recall_score, roc_auc_score, accuracy_score, matthews_corrcoef, precision_score, f1_score


def calculate_metrics(
    all_logits: list[float], all_labels: list[float], all_preds: list[float]
) -> dict[str, float]:
    auc = roc_auc_score(all_labels, all_logits)
    accuracy = accuracy_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    specificity = recall_score(all_labels, all_preds, pos_label=0)
    f1 = f1_score(all_labels, all_preds)

    metrics_dict = {
        "Accuracy": accuracy,
        "Sensitivity": recall,
        "Specificity": specificity,
        "Precision": precision,
        "AUC": auc,
        "F1": f1,
        "MCC": mcc,
    }
    return metrics_dict

In [74]:
metrics_dict = calculate_metrics(all_logits, all_labels, all_preds)

In [75]:
metrics_dict

{'Accuracy': 0.8543133802816901,
 'Sensitivity': 0.8256721595836947,
 'Specificity': 0.8838248436103664,
 'Precision': 0.8798521256931608,
 'AUC': 0.9170869480633728,
 'F1': 0.8519015659955257,
 'MCC': 0.7102204141321365}

In [35]:
metrics_dict

{'Accuracy': 0.8516725352112676,
 'Sensitivity': 0.805724197745013,
 'Specificity': 0.8990169794459338,
 'Precision': 0.8915547024952015,
 'AUC': 0.9170869480633728,
 'F1': 0.8464692482915718,
 'MCC': 0.7070871246705799}

In [26]:
metrics_dict

{'Accuracy': 0.8516725352112676,
 'Sensitivity': 0.805724197745013,
 'Specificity': 0.8990169794459338,
 'Precision': 0.8915547024952015,
 'AUC': 0.9170869480633728,
 'F1': 0.8464692482915718,
 'MCC': 0.7070871246705799}

In [12]:

# Example data
y_true = [0.0, 0.0, 1.0, 1.0]
y_scores = [0.1, 0.4, 0.35, 0.8]
fpr, tpr, thresholds = roc_curve(y_true, y_scores)

# Calculate the AUC
roc_auc = auc(fpr, tpr)
print(f"ROC AUC: {roc_auc}")

# Find the optimal threshold
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]
print(f"Optimal Threshold: {optimal_threshold}")

ROC AUC: 0.75
Optimal Threshold: 0.8


In [5]:
test_df = get_embed_clustered_df(
    embedding_path="../data/embeddings/ankh_embeddings/pdb20000_2d.h5",
    csv_path="../data/embeddings/input_csv/pdb20000.csv",
)

In [7]:
all_logits = torch.tensor([13, -2, 14, 31])
prob = torch.sigmoid(all_logits)
all_preds = (prob > 0.5).float()

In [12]:
all_preds.tolist()

[1.0, 0.0, 1.0, 1.0]

In [10]:
a = []
a.extend(all_preds.numpy())
a

[1.0, 0.0, 1.0, 1.0]

Частота классов по батчам

In [3]:
df = get_embed_clustered_df(
    embedding_path="../data/embeddings/ankh_embeddings/train_p3_2d.h5",
    csv_path="../data/splits/train_p3.csv",
)
train_folds, valid_folds = make_folds(df)

In [None]:
train_folds[0]

In [37]:
dataset = SequenceDataset(df)
sampler = CustomBatchSampler(dataset, batch_size=64)
dataloader = DataLoader(
        dataset,
        num_workers=1,
        batch_sampler=sampler,
        collate_fn=custom_collate_fn,
    )

In [67]:
len(df)

69872

In [53]:
freq_ratio = {}
for i, (_, y) in enumerate(dataloader):
    zero_freq = torch.sum(y == 0).item()
    ones_freq = torch.sum(y == 1).item()

    freq_ratio[i] = zero_freq / ones_freq


In [58]:
ratio_values = list(freq_ratio.values())

In [74]:
np.median(ratio_values)

1.0317460317460319

In [68]:
count = 0
for i, (_, y) in enumerate(dataloader):
    count += len(y)

In [76]:
count / 64

1091.75

In [66]:
models = load_models()
model = models[0]

In [7]:
total_params = sum(p.numel() for p in model.parameters())

In [9]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [10]:
total_params, trainable_params

(13013775, 13013775)

In [1]:
13013775 / 1e6

13.013775

In [2]:
embed_df = load_embeddings_to_df("../data/embeddings/ankh_embeddings/not_annotated_seqs_v1_2d.h5")

In [3]:
inderence_dataset = InferenceDataset(embed_df)

In [5]:
from torch.utils.data import DataLoader

In [20]:
inference_dataloader = DataLoader(
        inderence_dataset,
        num_workers=1,
        shuffle=False,
        batch_size=1,
    )

In [9]:
id_, x = next(iter(inference_dataloader))

In [None]:
id_

In [19]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [42]:
identifiers = []
scores = []
with torch.no_grad():
    for id_, x in inference_dataloader:
        x = x.to(DEVICE)
        ens_logits = []

        for model in models:
            model.eval()
            model = model.to(DEVICE)
            output = model(x)

            logits = output.logits
            ens_logits.append(logits)

        ens_logits = torch.stack(ens_logits, dim=0)
        ens_logits = torch.mean(ens_logits, dim=0)
        prob_score = torch.sigmoid(ens_logits)

        identifiers.extend(id_)
        scores.append(prob_score.cpu().numpy().item())
        break

Предсказание на неанноитрованных данных

In [3]:
predictions_df = pd.read_csv("../data/not_annotated/predictions.csv")

In [11]:
predictions_df

Unnamed: 0,identifier,score
57449,O07917,9.999994e-01
38847,F4HRT7,9.999990e-01
29644,C9J381,9.999990e-01
11514,A0A1P8AWJ0,9.999987e-01
58007,O31637,9.999986e-01
...,...,...
8895,A0A1I9LM80,2.624521e-09
57646,O22882,2.242381e-09
81755,Q9LYX3,2.171691e-09
73864,Q8LCX3,2.169861e-09


In [12]:
predictions_df = predictions_df.sort_values(by="score", ascending=False)

In [13]:
df = pd.read_csv("../data/not_annotated/not_annotated_seqs_with_org.csv")

In [14]:
df

Unnamed: 0,identifier,sequence,Organism
0,A0A178U7Y7,MITGDSIDGMDFISSLPDEILHHILSSVPTKSAIRTSLLSKRWRYV...,Arabidopsis thaliana
1,A0A178U8Q8,MRNAVVALRHHRRVSLLRRIVAFRDDNTICLSPSLKNNELSRQTNS...,Arabidopsis thaliana
2,A0A178U9N5,MRAKGKKQSEEAPEHAIAMALLCPSLPSPNSRLFRCRSSNISSKYH...,Arabidopsis thaliana
3,A0A178UBI5,MEMDPLLRSYPDIGYKAFGNTGRVIVSIFMNLELYLVATSFLILEG...,Arabidopsis thaliana
4,A0A178UCR2,MMKMSIGTTTSGDGEMELRPGGMVVQKRTDHSSSVPRGIRVRVKYG...,Arabidopsis thaliana
...,...,...,...
86911,A0A1B2JJA1,MNGYVLVTGGAGYIGSHTVVELLNNDYLVIVVDNLSNSSYHVIKRI...,Komagataella pastoris
86912,Q96X16,MRLTNLLSLTTLVALAVAVPDFYQKREAVSSKEAALLRRDASAECV...,Komagataella pastoris
86913,A0A1B2J6E6,MSYSAEDIEVLDKKFPSLKDEFHIPTFKSLGIAPPQSKDENDDSIY...,Komagataella pastoris
86914,A0A1B2J9A8,MPSPHGGVLQDLIKRDASIKEDLLKEVPQLQSIVLTGRQLCDLELI...,Komagataella pastoris


In [15]:
df = predictions_df.merge(df, on="identifier")

In [18]:
df

Unnamed: 0,identifier,score,sequence,Organism
0,O07917,9.999994e-01,MAYVKATAILPEKLISEIQKYVQGKTIYIPKPESSHQKWGACSGTR...,Bacillus subtilis (strain 168)
1,F4HRT7,9.999990e-01,MSSSTNDYNDGNNNGVYPLSLYLSSLSGHQDIIHNPYNHQLKASPG...,Arabidopsis thaliana
2,C9J381,9.999990e-01,MADYLISGGTGYVPEDGLTAQQLFASADGLTYNDFLILPGFIDFIA...,Homo sapiens
3,A0A1P8AWJ0,9.999987e-01,MFPSLDTNGYDLFDPFIPHQTTMFPSFITHIQSPNSHHHYSSPSFP...,Arabidopsis thaliana
4,O31637,9.999986e-01,MTDQMIAWEIEEWIRDYKFMLREIKRLNRVLNKVDFISTKLTATYG...,Bacillus subtilis (strain 168)
...,...,...,...,...
86911,A0A1I9LM80,2.624521e-09,MKKSLALIVLLYFLLLMVVHIPANEALRYLPNERLGNLQFLQKGEV...,Arabidopsis thaliana
86912,O22882,2.242381e-09,MDATKIKFDVILLSFLLIISGIPSNLGLSTSVRGTTRSEPEAFHGG...,Arabidopsis thaliana
86913,Q9LYX3,2.171691e-09,MKSTVMMIFLIIYLLIAVPCFAKGSEQTDSEVYEIDYRGPETHNSR...,Arabidopsis thaliana
86914,Q8LCX3,2.169861e-09,MKRQVMIFVMLVAFFVVFLDVKQVEAMRPFPTAADEIRFVFQALQR...,Arabidopsis thaliana


In [19]:
grouped = df.groupby("Organism")

In [20]:
thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]
stats = {}
for name, sub_df in grouped:
    for threshold in thresholds:
        count = len(sub_df[sub_df["score"] > threshold])
        group_name = name + "_" + str(threshold)
        stats[group_name] = count

In [23]:
stats

{'Arabidopsis thaliana_0.5': 1073,
 'Arabidopsis thaliana_0.6': 951,
 'Arabidopsis thaliana_0.7': 854,
 'Arabidopsis thaliana_0.8': 747,
 'Arabidopsis thaliana_0.9': 591,
 'Bacillus subtilis (strain 168)_0.5': 217,
 'Bacillus subtilis (strain 168)_0.6': 177,
 'Bacillus subtilis (strain 168)_0.7': 155,
 'Bacillus subtilis (strain 168)_0.8': 129,
 'Bacillus subtilis (strain 168)_0.9': 93,
 'Escherichia coli (strain K12)_0.5': 151,
 'Escherichia coli (strain K12)_0.6': 131,
 'Escherichia coli (strain K12)_0.7': 114,
 'Escherichia coli (strain K12)_0.8': 100,
 'Escherichia coli (strain K12)_0.9': 74,
 'Homo sapiens_0.5': 1597,
 'Homo sapiens_0.6': 1389,
 'Homo sapiens_0.7': 1205,
 'Homo sapiens_0.8': 1012,
 'Homo sapiens_0.9': 792,
 'Komagataella pastoris_0.5': 238,
 'Komagataella pastoris_0.6': 214,
 'Komagataella pastoris_0.7': 181,
 'Komagataella pastoris_0.8': 162,
 'Komagataella pastoris_0.9': 131,
 'Organism name not found_0.5': 0,
 'Organism name not found_0.6': 0,
 'Organism name n

In [13]:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Generate a synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=42)

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Instantiate the base models
lr = LogisticRegression()
svc = SVC(probability=True)
dt = DecisionTreeClassifier()

# Instantiate the voting classifier for hard voting
voting = VotingClassifier(estimators=[
    ('lr', lr),
    ('svc', svc),
    ('dt', dt)
], voting='hard')

# Fit the model
voting.fit(X_train, y_train)

# Predict classes using hard voting
y_pred = voting.predict(X_test)

# Instantiate the voting classifier for soft voting
voting_soft = VotingClassifier(estimators=[
    ('lr', lr),
    ('svc', svc),
    ('dt', dt)
], voting='soft')

# Fit the model
voting_soft.fit(X_train, y_train)

# Predict classes using soft voting
y_pred_soft = voting_soft.predict(X_test)


In [17]:
import numpy as np
from scipy.stats import mode

# Assuming these are your predicted labels
y_true = np.array([0, 1, 1, 0, 1])
y_pred_model1 = [0, 1, 0, 0, 1]
y_pred_model2 = [0, 0, 1, 0, 1]
y_pred_model3 = [1, 1, 1, 0, 0]

# Stack predictions for majority voting
predictions = np.vstack([y_pred_model1, y_pred_model2, y_pred_model3])

# Perform majority voting
y_pred_voted, _ = mode(predictions, axis=0)

# Flatten the result
y_pred_voted = y_pred_voted.flatten()

print("Ground Truth:", y_true)
print("Predicted Labels Model 1:", y_pred_model1)
print("Predicted Labels Model 2:", y_pred_model2)
print("Predicted Labels Model 3:", y_pred_model3)
print("Majority Voted Predictions:", y_pred_voted)

Ground Truth: [0 1 1 0 1]
Predicted Labels Model 1: [0, 1, 0, 0, 1]
Predicted Labels Model 2: [0, 0, 1, 0, 1]
Predicted Labels Model 3: [1, 1, 1, 0, 0]
Majority Voted Predictions: [0 1 1 0 1]


In [26]:
y_pred_voted.tolist()

[0, 1, 1, 0, 1]

In [22]:
from collections import defaultdict

In [23]:
pred_models = defaultdict(list)

In [24]:
pred_models[0]

[]