In [57]:
import argparse
import torch

from histopatho.metric import auc
from histopatho.trainer import (
    slide_level_train_step,
    slide_level_val_step,
)

from histopatho.prediction import predict
from histopatho.utils import generate_output_csv
from histopatho.utils import (
    load_npy_from_dir,
    load_data_from_csv,
    shuffle_data,
    split_dataset_in_subset,
)

from HistoSSLscaling.rl_benchmarks.trainers import TorchTrainer
from HistoSSLscaling.rl_benchmarks.datasets import SlideFeaturesDataset
from HistoSSLscaling.rl_benchmarks.models import Chowder

In [58]:
import os

list_path = []
count = 0

for file in os.listdir('weights/'):
    if file.endswith('pth'):
        list_path.append(file)
        count += 1

In [59]:
device = "cuda" if torch.cuda.is_available() else "cpu"

models = [Chowder(    
    in_features=2048,
    out_features=1,
    n_top=5,
    n_bottom=5,
    mlp_hidden=[200, 100],
    mlp_activation=torch.nn.Sigmoid(),
    bias=True
    ).to(device) for _ in range(count)]
    

In [63]:
len(models)

3

In [60]:
for i in range(len(list_path)):
    models[i].load_state_dict(torch.load(os.path.join('weights', list_path[i])))


In [64]:
test_path = 'data/test_input/moco_features'
test_values = load_npy_from_dir(test_path)

In [65]:
test_values_tensor = torch.tensor(test_values)
test_values_tensor = test_values_tensor.to(device)

In [66]:
predictions = []
for model in models:
    prediction = predict(model, test_values_tensor)
    predictions.append(prediction)

In [None]:
predictions

In [68]:
ensemble_predictions = torch.stack(predictions).mean(dim=0)

In [69]:
ensemble_predictions

tensor([[0.3282],
        [0.4197],
        [0.1639],
        [0.1781],
        [0.2260],
        [0.5941],
        [0.1563],
        [0.3996],
        [0.1608],
        [0.4339],
        [0.2127],
        [0.4533],
        [0.1540],
        [0.2278],
        [0.2969],
        [0.2132],
        [0.4732],
        [0.6561],
        [0.1897],
        [0.3552],
        [0.3142],
        [0.1092],
        [0.2848],
        [0.7089],
        [0.4402],
        [0.2426],
        [0.6935],
        [0.5248],
        [0.2501],
        [0.1814],
        [0.1832],
        [0.3965],
        [0.1401],
        [0.5884],
        [0.5670],
        [0.3459],
        [0.1255],
        [0.2548],
        [0.2300],
        [0.5567],
        [0.1250],
        [0.2786],
        [0.3809],
        [0.1472],
        [0.0548],
        [0.1832],
        [0.1260],
        [0.3098],
        [0.5598],
        [0.2536],
        [0.5031],
        [0.3440],
        [0.2042],
        [0.2832],
        [0.4239],
        [0

In [70]:
len(ensemble_predictions)

149