In [70]:
# Load IPython's autoreload extension so that edits to project modules
# (e.g. data_prepare, torch_utils) are picked up without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [72]:
import itertools
import logging

import ankh
import clearml
import numpy as np
import pandas as pd
import torch
import yaml
from clearml import Logger, Task
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from data_prepare import get_embed_clustered_df, make_folds
from torch_utils import (CustomBatchSampler, SequenceDataset,
                         custom_collate_fn, load_models, train_fn, validate_fn)

Частота классов по батчам

In [30]:
# Build the training frame from the train_p2 Ankh embeddings (.h5) and the matching
# split CSV, then divide it into training / validation folds.
train_embeddings_h5 = "../data/embeddings/ankh_embeddings/train_p2_2d.h5"
train_split_csv = "../data/splits/train_p2.csv"
df = get_embed_clustered_df(embedding_path=train_embeddings_h5, csv_path=train_split_csv)
train_folds, valid_folds = make_folds(df)

In [37]:
# Wrap the training frame in a dataset and iterate it through the custom batch
# sampler (64 samples per batch); custom_collate_fn assembles each batch.
dataset = SequenceDataset(df)
sampler = CustomBatchSampler(dataset, batch_size=64)
dataloader = DataLoader(
    dataset,
    batch_sampler=sampler,
    collate_fn=custom_collate_fn,
    num_workers=1,
)

In [67]:
# Size of the training frame loaded above (69872 in the saved run).
len(df)

69872

In [53]:
# Per-batch ratio of label-0 to label-1 samples, used to check how balanced the
# batches produced by CustomBatchSampler are. Keys are batch indices.
freq_ratio = {}
for batch_idx, (_, y) in enumerate(dataloader):
    zero_freq = torch.sum(y == 0).item()
    ones_freq = torch.sum(y == 1).item()
    # A batch with no label-1 samples used to raise ZeroDivisionError and abort the
    # whole scan; record it as +inf instead so summary statistics such as the
    # median remain well defined.
    freq_ratio[batch_idx] = zero_freq / ones_freq if ones_freq else float("inf")


In [58]:
# Per-batch ratios as a plain list, in batch order (dict insertion order).
ratio_values = [*freq_ratio.values()]

In [74]:
# Median label-0 / label-1 ratio across batches; a value close to 1.0 (1.03 in the
# saved run) means the sampler yields roughly balanced batches.
np.median(ratio_values)

1.0317460317460319

In [68]:
# Total number of samples yielded by one full pass over the training dataloader.
count = sum(len(y) for _, y in dataloader)

In [76]:
# How many 64-sample batches the samples seen would fill.
# NOTE(review): 64 duplicates batch_size from the CustomBatchSampler cell — keep them in sync.
# A fractional result (1091.75 here) presumably means the final batch is partial; confirm via len(sampler).
count / 64

1091.75

In [66]:
# Load the models via torch_utils.load_models; the inference cell iterates over all of
# them and averages their logits as an ensemble.
models = load_models()
# The first model is used for the parameter-count checks that follow.
model = models[0]

In [7]:
# Total number of scalar parameters in the inspected model.
total_params = sum(map(lambda param: param.numel(), model.parameters()))

In [9]:
# Scalar parameters that receive gradients (requires_grad=True).
trainable_params = sum(
    param.numel() for param in filter(lambda p: p.requires_grad, model.parameters())
)

In [10]:
# In the saved output both counts are equal, i.e. no parameter is frozen.
total_params, trainable_params

(13013775, 13013775)

In [2]:
# Ankh embeddings of the not-yet-annotated sequences.
# NOTE(review): `load_embeddings_to_df` is not in the import cell, so this raises NameError
# on a fresh kernel — import it from the module that defines it.
embed_df = load_embeddings_to_df("../data/embeddings/ankh_embeddings/not_annotated_seqs_v1_2d.h5")

In [3]:
# Dataset over the unannotated embeddings; its batches unpack as (id_, x) in later cells.
# NOTE(review): `InferenceDataset` is not in the import cell (NameError on a fresh kernel).
# The variable name is a typo for `inference_dataset`, kept because later cells refer to it.
inderence_dataset = InferenceDataset(embed_df)

In [5]:
from torch.utils.data import DataLoader

In [20]:
# One sample per batch, in dataset order, so each score can be paired with its identifier.
inference_dataloader = DataLoader(
    inderence_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [9]:
# Debug peek: pull the first (identifier, input) batch from the inference dataloader.
id_, x = next(iter(inference_dataloader))

In [None]:
# Identifier(s) of the first inference batch (this cell was never executed in the saved run).
id_

In [19]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [42]:
# Ensemble inference on the unannotated embeddings: average the raw logits of every
# model in `models`, then apply a sigmoid to get one probability per identifier.
#
# MAX_BATCHES makes the former hard-coded `break` explicit: with 1 only the first
# batch is scored (the previous behaviour); set it to None to score the whole dataloader.
MAX_BATCHES = 1

# Put every member in eval mode and on DEVICE once, instead of once per sample.
# Using a dedicated name also stops this cell from overwriting the global `model`.
ensemble = [member.to(DEVICE).eval() for member in models]

identifiers = []
scores = []
with torch.no_grad():
    # islice(..., None) iterates everything; islice(..., 1) stops without fetching a 2nd batch.
    for batch_ids, batch_x in itertools.islice(inference_dataloader, MAX_BATCHES):
        batch_x = batch_x.to(DEVICE)

        # Shape (n_members, *logit_shape) -> mean over members.
        member_logits = torch.stack([member(batch_x).logits for member in ensemble], dim=0)
        prob_score = torch.sigmoid(member_logits.mean(dim=0))

        identifiers.extend(batch_ids)
        # .item() requires a single element, which holds for the batch_size=1 inference loader.
        scores.append(prob_score.item())

Предсказание на неаннотированных данных

In [15]:
# Predicted scores for the unannotated sequences (expects `identifier` and `score` columns).
predictions_path = "../data/not_annotated/predictions.csv"
predictions_df = pd.read_csv(predictions_path)

In [16]:
# Highest-scoring predictions first.
predictions_df = predictions_df.sort_values("score", ascending=False)

In [17]:
# Organism annotations for the unannotated sequences.
# NOTE(review): this rebinds `df`, discarding the training frame loaded earlier in the notebook.
organisms_csv = "../data/not_annotated/not_annotated_seqs_with_org.csv"
df = pd.read_csv(organisms_csv)

In [18]:
# Distinct organisms in the annotation table; the saved output includes the
# placeholder value 'Organism name not found'.
df["Organism"].unique()

array(['Arabidopsis thaliana', 'Bacillus subtilis (strain 168)',
       'Saccharomyces cerevisiae (strain ATCC 204508 / S288c)',
       'Escherichia coli (strain K12)', 'Homo sapiens',
       'Organism name not found', 'Komagataella pastoris'], dtype=object)

In [21]:
# Attach organism information to each prediction. An inner join keeps only identifiers
# present in both frames, and the left frame's (score-sorted) key order is preserved.
df = predictions_df.merge(right=df, how="inner", on="identifier")

In [23]:
# One group of predictions per organism.
grouped = df.groupby(by="Organism")

In [26]:
# For each organism and each score cutoff, count predictions scoring strictly above it.
# Keys look like "<organism>_<cutoff>", e.g. "Homo sapiens_0.9".
thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]

# A comprehension keeps the loop variables local. The previous explicit loop leaked
# `name`, `sub_df` and `threshold`, and overwrote the global `count` computed earlier.
stats = {
    f"{organism}_{cutoff}": int((organism_df["score"] > cutoff).sum())
    for organism, organism_df in grouped
    for cutoff in thresholds
}