In [78]:
import torch
from torch import nn
from sklearn.metrics import confusion_matrix
from torchvision.models import resnet50
from datasets.nih_cxr import NIHCXRDataset
from tasks.binary_classification import BinaryClassificationTask

# User configuration: replace each `...` placeholder with a real path before running.
saved_model_path = ...  # Path to the PyTorch Lightning Bolts .ckpt checkpoint file
data_dir = ...          # Path to the NIH CXR data folder

In [37]:
# Re-create the PyTorch Lightning module and restore it from the saved checkpoint.
#
# The backbone is built WITHOUT ImageNet weights (pretrained=False): the weights
# restored by `load_from_checkpoint` replace whatever the backbone was initialised
# with, so fetching the pretrained weights first would only add a network/cache
# dependency to this evaluation notebook.
# NOTE(review): this assumes the checkpoint stores the full backbone state — confirm.
model = resnet50(pretrained=False)
# Swap the ImageNet classification head for a single output logit (binary task).
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=1)
task = BinaryClassificationTask.load_from_checkpoint(saved_model_path, model=model)
# Evaluation mode, so layers such as BatchNorm use their stored running statistics.
task.eval()
# Keep only the bare network for the inference cells below.
model = task.model

                not been set for this class (Accuracy). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


In [38]:
# Build the validation split of the NIH CXR dataset from `data_dir`, requesting
# the binary-label variant (binary=True) to match the single-logit model head.
val_dataset = NIHCXRDataset(data_dir, split='val', binary=True)

In [92]:
# todo: for improvement
#
# PyTorch Lightning has a bunch of metrics,
# and the confusion matrix is one of them.
# We would be able to compute the confusion
# matrix on-the-fly rather than at the end.

# Run inference over the validation set in batches instead of one datapoint
# at a time. shuffle=False keeps the predictions in dataset order, so
# `ypreds[i]` and `labels[i]` still refer to `val_dataset[i]`.
# NOTE(review): batching requires every image in the dataset to have the same
# shape — confirm the dataset resizes/crops its images.
batch_size = 32
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

ypreds = []
labels = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for images, batch_labels in val_loader:
        logits = model(images)
        # Sigmoid maps each logit to a probability; flatten to one value per sample.
        ypreds.append(torch.sigmoid(logits).reshape(-1))
        labels.append(batch_labels)

# Concatenate the per-batch results into one tensor each, in dataset order.
ypreds = torch.cat(ypreds)
labels = torch.cat(labels)

In [93]:
# Threshold the sigmoid probabilities at 0.5 to obtain hard 0/1 predictions and
# tabulate them against the ground-truth labels. In scikit-learn's convention
# rows are the true classes and columns are the predicted classes.
confusion_matrix(labels, (ypreds > 0.5).float())

array([[2, 1],
       [0, 1]])