In [None]:
# default_exp inference.general

# General inference methods

> Various general methods to be used during inference

In [None]:
# hide
from nbdev.showdoc import *
from fastcore.test import *

%load_ext autoreload
%autoreload 2

In [None]:
# export 
import os

import numpy as np
import pandas as pd
import torch
from pytorch_lightning.callbacks import BasePredictionWriter

In [None]:
# export
def convert_test_df(df):
    '''Convert a test-only df into a combined train/val df, as this is easier to use for fastai batch inference.

    The rows of `df` are duplicated: one copy is flagged as validation data
    (`is_valid=True`) and one as dummy training data (`is_valid=False`).

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding the test samples. It is not modified.

    Returns
    -------
    pd.DataFrame
        Validation rows first, followed by the dummy training rows, with a
        fresh 0..2n-1 index and an `is_valid` column.
    '''
    # `assign` returns a new frame, so the caller's `df` is left untouched.
    val = df.assign(is_valid=True)
    dummy_train = df.assign(is_valid=False)

    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the
    # supported equivalent and yields the same row order and index.
    return pd.concat([val, dummy_train], ignore_index=True)

In [None]:
# export
class PredictionWriter(BasePredictionWriter):
    '''Prediction-writer callback that saves all predictions of a run to a single `.pt` file.

    Parameters
    ----------
    output_dir : str
        Directory the prediction file is written to. It is created if it does
        not exist yet.
    output_file : str
        Name of the prediction file without extension; `.pt` is appended.
    write_interval : str
        Forwarded unchanged to `BasePredictionWriter`.
    splits : dict; optional
        Maps a group name to the number of consecutive samples that belong to
        that group, in prediction order. When given, the saved object is a dict
        `{group_name: (probs, gts)}` instead of a single `(probs, gts)` tuple.
        The sizes are handed to `torch.split`, so they have to add up to the
        total number of predicted samples.
    '''
    def __init__(self, output_dir, output_file, write_interval, splits=None):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.output_file = output_file
        self.splits = splits

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        '''Concatenate the per-batch predictions and save them to `<output_dir>/<output_file>.pt`.'''
        # Only the first entry is used -- presumably the predictions of the first
        # (only) prediction dataloader; confirm against the Lightning version in use.
        predictions = predictions[0]
        # Every per-batch prediction is unpacked as a (probs, gts) pair.
        probs, gts = zip(*predictions)
        # `torch.cat` is the canonical name; `torch.concat` is only an alias
        # available in newer torch releases.
        probs, gts = torch.cat(probs, dim=0), torch.cat(gts, dim=0)

        preds = (probs, gts)
        if self.splits:
            preds = self._split_preds(preds)

        # Make sure the target directory exists so the save does not fail on a
        # fresh output location ('' means the current working directory).
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        torch.save(preds, os.path.join(self.output_dir, f"{self.output_file}.pt"))

    def _split_preds(self, preds):
        '''Split a `(probs, gts)` tuple into per-group tuples keyed by the names in `self.splits`.'''
        group_sizes = list(self.splits.values())

        # Both tensors are cut along dim 0 into consecutive chunks of the given sizes.
        probs = torch.split(preds[0], group_sizes)
        gts = torch.split(preds[1], group_sizes)

        # Iterating the dict yields its keys in the same order as `.values()` above.
        return {
            group_name: (prob, gt)
            for group_name, prob, gt in zip(self.splits, probs, gts)
        }