In [None]:
import torch
import pytorch_lightning as pl
from .deepsea import DeepSea, DeepSeaModule

In [None]:
# Let float32 matmuls use faster reduced-precision internal arithmetic
# (TF32 / bfloat16) where the hardware supports it.
torch.set_float32_matmul_precision('medium')

# Compute resources handed over by the Snakemake rule.
threads, devices = snakemake.threads, snakemake.params['devices']

# NOTE(review): prediction batch size is hard-coded; the rule's
# params['batch_size'] used to be read here but was disabled — confirm the
# override is intentional.
batch_size = 512

In [None]:
# Data module over the train/val/test files supplied by the rule; only its
# test split is consumed in this notebook.
datamodule = DeepSeaModule(
    *(snakemake.input[split] for split in ('train', 'val', 'test')),
    batch_size=batch_size,
    num_workers=threads,
)

# NOTE(review): test_dataloader() is called without datamodule.setup('test');
# presumably DeepSeaModule builds its datasets eagerly — confirm.
dl = datamodule.test_dataloader()

In [None]:
import h5py

# Model output width, taken from the first axis of the training file's
# `traindata` dataset. The file is opened read-only and closed right away
# instead of leaking an open HDF5 handle for the rest of the session.
# NOTE(review): assumes that axis indexes the prediction targets — confirm
# against DeepSeaModule._read_mat.
with h5py.File(snakemake.input['train'], 'r') as train_h5:
    output_size = train_h5['traindata'].shape[0]

model = DeepSea(output_size=output_size)

# map_location='cpu' lets a checkpoint saved from GPU tensors load on any host.
# load_state_dict copies the values into the model's existing parameters either
# way, so the resulting model is the same; the Trainer is expected to move it to
# `devices` at predict time.
model.load_state_dict(torch.load(snakemake.input['model'], map_location='cpu'))

In [None]:
# Run inference over the test loader under bf16 mixed precision, then gather
# every batch's output into a single float64 NumPy array.
trainer = pl.Trainer(precision="bf16-mixed", devices=devices)
batch_outputs = trainer.predict(model=model, dataloaders=dl, return_predictions=True)

# NOTE(review): with several devices each process only returns its own share of
# the predictions — confirm `devices` resolves to a single device here.
# Cast before .numpy(): NumPy has no bfloat16 dtype.
pred = torch.cat(batch_outputs).cpu().to(torch.float64).numpy()

In [None]:
# Ground-truth labels for the test split: the second element returned by
# DeepSeaModule's private `_read_mat` helper.
# NOTE(review): assumes rows of `y` line up with the order in which `dl`
# yields samples — confirm.
test_unused, y = datamodule._read_mat(snakemake.input['test'], 'test')
# The first element is not needed below. Binding it to a named throwaway
# (rather than IPython's `_` output-history name) and deleting it lets it be
# freed instead of staying alive for the rest of the session.
del test_unused

In [None]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Overlay one ROC curve per prediction target and report the mean ± std AUC.
# NOTE(review): assumes `pred` and `y` are both (n_samples, n_targets), so that
# iterating over their transposes yields one target per step — confirm.
# zip() would silently drop targets on a mismatch, so check shapes first.
assert np.shape(pred) == np.shape(y), (
    f'prediction/label shape mismatch: {np.shape(pred)} vs {np.shape(y)}'
)

aucs = list()
fig, ax = plt.subplots()

for y_pred, y_true in tqdm(zip(pred.T, y.T), total=np.shape(pred)[1], desc='ROC per target'):
    fpr, tpr, _thresholds = roc_curve(y_true, y_pred)
    aucs.append(auc(fpr, tpr))
    ax.plot(fpr, tpr, c='black', lw=1, alpha=0.1)

# A target whose labels contain only one class has an undefined ROC curve
# (scikit-learn warns and the AUC comes out NaN); ignore those in the summary
# instead of letting a single NaN turn the whole title into "nan±nan".
auc_mean = np.nanmean(aucs)
auc_std = np.nanstd(aucs)

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title(f'DeepSEA auc={auc_mean:.3f}±{auc_std:.3f}')
fig.savefig(snakemake.output['fig'], dpi=300)