In [70]:
# Enable the autoreload extension so edits to the local modules imported below
# (data_prepare, torch_utils) are picked up without restarting the kernel
%load_ext autoreload

# Mode 2: reload all modules before executing each cell
%autoreload 2

In [2]:
import logging

import ankh
import clearml
import numpy as np
import pandas as pd
import torch
import yaml
from clearml import Logger, Task
from data_prepare import get_embed_clustered_df, make_folds
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch_utils import (CustomBatchSampler, SequenceDataset,
                         custom_collate_fn, train_fn, validate_fn, load_models)

  from .autonotebook import tqdm as notebook_tqdm


Частота классов по батчам

In [30]:
# Build the training frame from the Ankh embeddings and the split CSV via
# get_embed_clustered_df, then derive the fold assignments from it.
train_embedding_path = "../data/embeddings/ankh_embeddings/train_p2_2d.h5"
train_csv_path = "../data/splits/train_p2.csv"

df = get_embed_clustered_df(embedding_path=train_embedding_path, csv_path=train_csv_path)
train_folds, valid_folds = make_folds(df)

In [37]:
# Wrap the training frame in a dataset and batch it with the project's custom
# sampler (batch_size=64); used for the per-batch class-frequency checks.
dataset = SequenceDataset(df)
sampler = CustomBatchSampler(dataset, batch_size=64)
dataloader = DataLoader(
    dataset,
    batch_sampler=sampler,
    collate_fn=custom_collate_fn,
    num_workers=1,
)

In [67]:
df.shape[0]  # number of rows in the training frame

69872

In [53]:
# Ratio of class-0 to class-1 labels in each batch, keyed by batch index.
# A batch with no class-1 labels gets `inf` instead of raising
# ZeroDivisionError and aborting the whole pass over the dataloader.
freq_ratio = {}
for batch_idx, (_, labels) in enumerate(dataloader):
    n_zeros = torch.sum(labels == 0).item()
    n_ones = torch.sum(labels == 1).item()
    freq_ratio[batch_idx] = n_zeros / n_ones if n_ones else float("inf")


In [58]:
ratio_values = [*freq_ratio.values()]  # per-batch ratios, in batch order

In [74]:
np.median(np.asarray(ratio_values))  # typical class-0 : class-1 ratio per batch

1.0317460317460319

In [68]:
# Total number of labels produced by one full pass over the dataloader.
count = sum(len(labels) for _, labels in dataloader)

In [76]:
# Samples seen, expressed in units of the 64-sample batch size; a fractional
# result means the total is not a multiple of 64.
samples_per_batch = 64
count / samples_per_batch

1091.75

In [66]:
models = load_models()  # ensemble members; all of them are used for inference
model = models[0]       # one member, inspected for parameter counts

In [7]:
total_params = sum(map(torch.numel, model.parameters()))  # all parameters

In [9]:
# Parameters that receive gradients (requires_grad=True).
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

In [10]:
(total_params, trainable_params)  # equal values => every parameter is trainable

(13013775, 13013775)

In [1]:
# Model size in millions of parameters, derived from `total_params` rather than
# a number copied from an earlier output (which goes stale if the model changes).
total_params / 1e6

13.013775

In [2]:
# NOTE(review): `load_embeddings_to_df` is not imported in any visible cell;
# presumably it comes from a local module (data_prepare / torch_utils) and
# should be added to the imports cell so this runs on a fresh kernel.
unannotated_embedding_path = "../data/embeddings/ankh_embeddings/not_annotated_seqs_v1_2d.h5"
embed_df = load_embeddings_to_df(unannotated_embedding_path)

In [3]:
# Dataset over the unannotated embeddings, exposed under a correctly spelled
# name; the misspelled `inderence_dataset` is kept as an alias because the
# dataloader cell below refers to it.
# NOTE(review): `InferenceDataset` is not imported in any visible cell.
inference_dataset = inderence_dataset = InferenceDataset(embed_df)

In [5]:
from torch.utils.data import DataLoader

In [20]:
# One sample per batch, in dataset order (no shuffling).
inference_dataloader = DataLoader(
    inderence_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [9]:
# Peek at the first batch to sanity-check the loader; presumably
# (identifiers, embeddings) — confirm against InferenceDataset.
first_batch = next(iter(inference_dataloader))
id_, x = first_batch

In [None]:
id_  # identifiers of the first inference batch (sanity check)

In [19]:
# First CUDA device when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [42]:
# Ensemble inference on the unannotated embeddings: average the raw logits of
# every model in `models`, then apply a sigmoid to get a probability per
# identifier.
# MAX_BATCHES limits the run for quick checks (the previous version silently
# stopped after the first batch); set it to None to score the whole loader.
MAX_BATCHES = 1

# Move each member to DEVICE and switch it to eval mode once, instead of
# repeating both for every model on every batch (`.eval()` returns the module).
ensemble = [member.to(DEVICE).eval() for member in models]

identifiers = []
scores = []
with torch.no_grad():
    for batch_idx, (id_, x) in enumerate(inference_dataloader):
        x = x.to(DEVICE)

        # Stack along a new leading ensemble axis and average over it.
        member_logits = torch.stack([member(x).logits for member in ensemble], dim=0)
        prob_score = torch.sigmoid(member_logits.mean(dim=0))

        # Flatten instead of `.item()` so batch_size > 1 also works; require
        # exactly one score per identifier so the two lists stay aligned.
        batch_scores = prob_score.cpu().reshape(-1).tolist()
        if len(batch_scores) != len(id_):
            raise ValueError(
                f"expected one score per identifier, got {len(batch_scores)} "
                f"scores for {len(id_)} identifiers"
            )
        identifiers.extend(id_)
        scores.extend(batch_scores)

        if MAX_BATCHES is not None and batch_idx + 1 >= MAX_BATCHES:
            break

Предсказание на неаннотированных данных

In [3]:
# Scores produced by the ensemble for the unannotated sequences.
predictions_path = "../data/not_annotated/predictions.csv"
predictions_df = pd.read_csv(predictions_path)

In [11]:
predictions_df  # pandas truncates long frames in the display; .head()/.tail() give a compact view

Unnamed: 0,identifier,score
57449,O07917,9.999994e-01
38847,F4HRT7,9.999990e-01
29644,C9J381,9.999990e-01
11514,A0A1P8AWJ0,9.999987e-01
58007,O31637,9.999986e-01
...,...,...
8895,A0A1I9LM80,2.624521e-09
57646,O22882,2.242381e-09
81755,Q9LYX3,2.171691e-09
73864,Q8LCX3,2.169861e-09


In [12]:
# Highest-confidence predictions first.
predictions_df = predictions_df.sort_values(by=["score"], ascending=[False])

In [13]:
# Sequences and source organism of the unannotated proteins.
# Note: this rebinds `df`, which earlier held the training frame.
seqs_with_org_path = "../data/not_annotated/not_annotated_seqs_with_org.csv"
df = pd.read_csv(seqs_with_org_path)

In [14]:
df  # frame read from not_annotated_seqs_with_org.csv in the previous cell

Unnamed: 0,identifier,sequence,Organism
0,A0A178U7Y7,MITGDSIDGMDFISSLPDEILHHILSSVPTKSAIRTSLLSKRWRYV...,Arabidopsis thaliana
1,A0A178U8Q8,MRNAVVALRHHRRVSLLRRIVAFRDDNTICLSPSLKNNELSRQTNS...,Arabidopsis thaliana
2,A0A178U9N5,MRAKGKKQSEEAPEHAIAMALLCPSLPSPNSRLFRCRSSNISSKYH...,Arabidopsis thaliana
3,A0A178UBI5,MEMDPLLRSYPDIGYKAFGNTGRVIVSIFMNLELYLVATSFLILEG...,Arabidopsis thaliana
4,A0A178UCR2,MMKMSIGTTTSGDGEMELRPGGMVVQKRTDHSSSVPRGIRVRVKYG...,Arabidopsis thaliana
...,...,...,...
86911,A0A1B2JJA1,MNGYVLVTGGAGYIGSHTVVELLNNDYLVIVVDNLSNSSYHVIKRI...,Komagataella pastoris
86912,Q96X16,MRLTNLLSLTTLVALAVAVPDFYQKREAVSSKEAALLRRDASAECV...,Komagataella pastoris
86913,A0A1B2J6E6,MSYSAEDIEVLDKKFPSLKDEFHIPTFKSLGIAPPQSKDENDDSIY...,Komagataella pastoris
86914,A0A1B2J9A8,MPSPHGGVLQDLIKRDASIKEDLLKEVPQLQSIVLTGRQLCDLELI...,Komagataella pastoris


In [15]:
# Attach sequence and organism to every prediction (inner join on identifier).
# Dropping any existing `score` column from the right-hand frame makes the cell
# safe to re-run; otherwise a second run merges `score` twice and produces
# `score_x` / `score_y`. `validate` fails loudly on duplicate identifiers,
# which would otherwise silently inflate the per-organism counts.
df = predictions_df.merge(
    df.drop(columns=["score"], errors="ignore"),
    on="identifier",
    validate="one_to_one",
)

In [18]:
df  # predictions joined with sequence/organism on `identifier`

Unnamed: 0,identifier,score,sequence,Organism
0,O07917,9.999994e-01,MAYVKATAILPEKLISEIQKYVQGKTIYIPKPESSHQKWGACSGTR...,Bacillus subtilis (strain 168)
1,F4HRT7,9.999990e-01,MSSSTNDYNDGNNNGVYPLSLYLSSLSGHQDIIHNPYNHQLKASPG...,Arabidopsis thaliana
2,C9J381,9.999990e-01,MADYLISGGTGYVPEDGLTAQQLFASADGLTYNDFLILPGFIDFIA...,Homo sapiens
3,A0A1P8AWJ0,9.999987e-01,MFPSLDTNGYDLFDPFIPHQTTMFPSFITHIQSPNSHHHYSSPSFP...,Arabidopsis thaliana
4,O31637,9.999986e-01,MTDQMIAWEIEEWIRDYKFMLREIKRLNRVLNKVDFISTKLTATYG...,Bacillus subtilis (strain 168)
...,...,...,...,...
86911,A0A1I9LM80,2.624521e-09,MKKSLALIVLLYFLLLMVVHIPANEALRYLPNERLGNLQFLQKGEV...,Arabidopsis thaliana
86912,O22882,2.242381e-09,MDATKIKFDVILLSFLLIISGIPSNLGLSTSVRGTTRSEPEAFHGG...,Arabidopsis thaliana
86913,Q9LYX3,2.171691e-09,MKSTVMMIFLIIYLLIAVPCFAKGSEQTDSEVYEIDYRGPETHNSR...,Arabidopsis thaliana
86914,Q8LCX3,2.169861e-09,MKRQVMIFVMLVAFFVVFLDVKQVEAMRPFPTAADEIRFVFQALQR...,Arabidopsis thaliana


In [19]:
grouped = df.groupby(by="Organism")  # one group of predictions per organism

In [20]:
# Number of proteins per organism whose predicted score is strictly above each
# threshold, keyed "<organism>_<threshold>".
# Built with a comprehension so no loop variables leak into the notebook; the
# previous loop overwrote the global `count` (samples seen by the dataloader).
thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]
stats = {
    "_".join((name, str(threshold))): int((sub_df["score"] > threshold).sum())
    for name, sub_df in grouped
    for threshold in thresholds
}

In [23]:
stats  # {"<organism>_<threshold>": number of proteins with score > threshold}

{'Arabidopsis thaliana_0.5': 1073,
 'Arabidopsis thaliana_0.6': 951,
 'Arabidopsis thaliana_0.7': 854,
 'Arabidopsis thaliana_0.8': 747,
 'Arabidopsis thaliana_0.9': 591,
 'Bacillus subtilis (strain 168)_0.5': 217,
 'Bacillus subtilis (strain 168)_0.6': 177,
 'Bacillus subtilis (strain 168)_0.7': 155,
 'Bacillus subtilis (strain 168)_0.8': 129,
 'Bacillus subtilis (strain 168)_0.9': 93,
 'Escherichia coli (strain K12)_0.5': 151,
 'Escherichia coli (strain K12)_0.6': 131,
 'Escherichia coli (strain K12)_0.7': 114,
 'Escherichia coli (strain K12)_0.8': 100,
 'Escherichia coli (strain K12)_0.9': 74,
 'Homo sapiens_0.5': 1597,
 'Homo sapiens_0.6': 1389,
 'Homo sapiens_0.7': 1205,
 'Homo sapiens_0.8': 1012,
 'Homo sapiens_0.9': 792,
 'Komagataella pastoris_0.5': 238,
 'Komagataella pastoris_0.6': 214,
 'Komagataella pastoris_0.7': 181,
 'Komagataella pastoris_0.8': 162,
 'Komagataella pastoris_0.9': 131,
 'Organism name not found_0.5': 0,
 'Organism name not found_0.6': 0,
 'Organism name n

In [None]:
# The original cell (`stats[]`) was a SyntaxError. Show the counts as an
# organism x threshold table instead; keys of `stats` are "<organism>_<threshold>",
# so split each key on its last underscore.
stats_series = pd.Series(stats)
stats_series.index = pd.MultiIndex.from_tuples(
    [tuple(key.rsplit("_", 1)) for key in stats_series.index],
    names=["Organism", "threshold"],
)
stats_series.unstack("threshold")