In [5]:
import yaml
import ankh

In [2]:
# Load the experiment settings from config.yml (path is relative to the
# working directory). The encoding is pinned to UTF-8 so decoding does not
# depend on the platform's locale default.
with open("config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [4]:
# Model hyperparameters: each name mirrors its key in the "model_config"
# section of the loaded config.
model_config = config["model_config"]
input_dim = model_config["input_dim"]
nhead = model_config["nhead"]
hidden_dim = model_config["hidden_dim"]
num_hidden_layers = model_config["num_hidden_layers"]
num_layers = model_config["num_layers"]
kernel_size = model_config["kernel_size"]
dropout = model_config["dropout"]
pooling = model_config["pooling"]

# Training settings: each name mirrors its key in the "training_config"
# section of the loaded config.
training_config = config["training_config"]
epochs = training_config["epochs"]
lr = training_config["lr"]
factor = training_config["factor"]
patience = training_config["patience"]
min_lr = training_config["min_lr"]
batch_size = training_config["batch_size"]
seed = training_config["seed"]
num_workers = training_config["num_workers"]

In [9]:
import torch

In [10]:
# Number of checkpoints to load; one file per index is read from
# checkpoints/pdb2272_best_model_{i}.pth for i in 0..NUM_MODELS-1.
NUM_MODELS = 3

models = []
for i in range(NUM_MODELS):
    # Every model is built with the same hyperparameter values; only the
    # checkpoint loaded into it differs.
    binary_classification_model = ankh.ConvBertForBinaryClassification(
        input_dim=input_dim,
        nhead=nhead,
        hidden_dim=hidden_dim,
        num_hidden_layers=num_hidden_layers,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout=dropout,
        pooling=pooling,
    )

    path_model = f"checkpoints/pdb2272_best_model_{i}.pth"
    # map_location="cpu" makes loading independent of the device the
    # checkpoint was saved from (e.g. a CUDA device that is absent here);
    # load_state_dict then copies the values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(path_model, map_location="cpu")
    binary_classification_model.load_state_dict(state_dict)
    binary_classification_model.eval()  # Set the model to evaluation mode
    models.append(binary_classification_model)

In [12]:
from data_prepare import prepare_embed_df

In [13]:
# Input locations for the pdb2272 set: an HDF5 file passed as the embedding
# source and a CSV passed alongside it.
# NOTE(review): both paths are relative to the notebook's working directory.
EMBEDDING_PATH = "../../../../ssd2/dbp_finder/ankh_embeddings/pdb2272_2d.h5"
CSV_PATH = "../data/embeddings/input_csv/pdb2272.csv"

test_df = prepare_embed_df(embedding_path=EMBEDDING_PATH, csv_path=CSV_PATH)

In [14]:
from torch.utils.data import DataLoader
from torch_utils import SequenceDataset

In [15]:
# Wrap the prepared frame in the project's dataset class and iterate it in
# its original order (shuffle=False), one item per batch.
testing_set = SequenceDataset(test_df)
testing_dataloader = DataLoader(
    dataset=testing_set,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
)

In [19]:
# Take the first batch from the test loader for inspection.
# NOTE(review): x is presumably the model input and y the label — confirm
# against SequenceDataset.
first_batch = next(iter(testing_dataloader))
x, y = first_batch

In [24]:
# Display the shape of the sampled input batch.
# NOTE(review): the leading axis presumably corresponds to batch_size=1 and
# the remaining axes to (sequence length, embedding dim) — confirm.
x.shape

torch.Size([1, 65, 1536])

In [21]:
# Add a new leading axis to the batch and tile it once per loaded model;
# the remaining three axes are left unrepeated.
n_copies = len(models)
stacked_x = x[None].repeat(n_copies, 1, 1, 1)

In [23]:
# Display the stacked shape: a leading axis of size len(models) in front of
# x's own shape.
stacked_x.shape

torch.Size([3, 1, 65, 1536])