#### Code to run inference with trained models

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project sources are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# Switch to the src/ directory; the project-local imports below
# (params, data, model_zoo, utils, inference) are presumably resolved from there.
cd ../src

## Imports

In [None]:
import os
import cv2
import json
import glob
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import *
from collections import Counter
from tqdm.notebook import tqdm

In [None]:
from params import *

In [None]:
from data.dataset import CovidInfDataset
from data.transforms import get_tranfos_inference

from model_zoo.models import get_model

from utils.logger import Config

## Data

### Load

## Dataset

In [None]:
from data.dataset import CovidInfDataset
from data.transforms import get_tranfos_inference

In [None]:
# Test metadata table and the matching image folder; both are keyed by SIZE.
df = pd.read_csv(f"{DATA_PATH}df_test_{SIZE}.csv")
root = f"{DATA_PATH}test_{SIZE}/"

In [None]:
# Build the inference transforms with their default arguments and wrap the
# test dataframe in the inference dataset for a quick visual sanity check.
transforms = get_tranfos_inference()
dataset = CovidInfDataset(df, root, transforms=transforms)

In [None]:
# Display the first sample of the dataset.
img = dataset[0]
# Move the first axis last for matplotlib (presumably CHW -> HWC; confirm).
img_hwc = img.numpy().transpose(1, 2, 0)
# (x + 1) / 2 presumably maps a [-1, 1] normalised image back to [0, 1];
# confirm against the defaults of get_tranfos_inference.
plt.imshow((img_hwc + 1) / 2)
plt.show()

In [None]:
# Preview the first rows of the test metadata.
df.head()

In [None]:
# Metadata tables of the three external datasets, loaded in a fixed order.
meta_files = ("meta_bim.csv", "meta_covidx.csv", "meta_ricord.csv")
df_bim, df_covidx, df_ricord = (
    pd.read_csv(DATA_PATH + meta_file) for meta_file in meta_files
)

In [None]:
# Mirror the "path" column into "save_name" for every external table.
for meta_df in (df_bim, df_covidx, df_ricord):
    meta_df["save_name"] = meta_df["path"]

In [None]:
# Stack the external tables under a fresh 0..n-1 index.
external_frames = [df_bim, df_covidx, df_ricord]
df_ext = pd.concat(external_frames).reset_index(drop=True)
# Drop every column that contains a missing value; columns that are not
# present in all three tables are NaN-filled by concat and so removed here.
df_ext = df_ext.dropna(axis=1)

# Main test

In [None]:
from inference.predict import predict
from utils.torch import load_model_weights

def inference(
    config,
    weights,
    df,
    root_dir="",
    log_folder=None,
    flip_tta=False,
    scale_tta=False,
    suffix="test",
    save_all_folds=False,
):
    """
    Inference on the test data, averaging the predictions of several checkpoints.

    Args:
        config (Config): Parameters. The attributes selected_model, use_unet,
            num_classes, device and val_bs are read.
        weights (list of strings): Model weights, one checkpoint per entry. Must not be empty.
        df (pandas DataFrame): Metadata handed to CovidInfDataset.
        root_dir (str, optional): Root directory handed to CovidInfDataset. Defaults to "".
        log_folder (None or str, optional): Folder to save the predictions to. It is used as a
            plain string prefix, so it has to end with a path separator.
            Nothing is saved if None. Defaults to None.
        flip_tta (bool, optional): Whether to use hflip tta. Defaults to False.
        scale_tta (bool, optional): Whether to use scale tta. Defaults to False.
        suffix (str, optional): Tag inserted in the names of the saved files. Defaults to "test".
        save_all_folds (bool, optional): If True, save the predictions of each checkpoint
            separately (indexed by its position in weights) instead of the averaged ones.
            Defaults to False.

    Returns:
        tuple: (preds_study, preds_img), the outputs of predict averaged over all checkpoints.

    Raises:
        ValueError: If weights is empty.
    """
    # Fail early: averaging over zero checkpoints would only produce NaN predictions.
    if len(weights) == 0:
        raise ValueError("`weights` is empty: at least one checkpoint is required.")

    model = get_model(
        config.selected_model,
        use_unet=config.use_unet,
        num_classes=config.num_classes,
    ).to(config.device)
    model.zero_grad()

    # The normalisation statistics come from the model itself.
    dataset = CovidInfDataset(
        df,
        root_dir=root_dir,
        transforms=get_tranfos_inference(mean=model.mean, std=model.std),
    )

    # File name tag recording which TTA variants were used.
    tta_suffix = ("_flip" if flip_tta else "") + ("_scale" if scale_tta else "")

    preds_study, preds_img = [], []
    for i, weight in enumerate(weights):
        # The same model instance is reused; only its weights are swapped.
        load_model_weights(model, weight)

        pred_study, pred_img = predict(
            model,
            dataset,
            batch_size=config.val_bs,
            num_classes=config.num_classes,
            flip_tta=flip_tta,
            scale_tta=scale_tta,
            device=config.device,
        )

        if log_folder is not None and save_all_folds:
            np.save(log_folder + f'preds_{suffix}_study{tta_suffix}_{i}.npy', pred_study)
            np.save(log_folder + f'preds_{suffix}_img{tta_suffix}_{i}.npy', pred_img)

        preds_study.append(pred_study)
        preds_img.append(pred_img)

    # Ensemble by averaging over the checkpoint axis.
    preds_study = np.mean(preds_study, 0)
    preds_img = np.mean(preds_img, 0)

    if log_folder is not None and not save_all_folds:
        np.save(log_folder + f'preds_{suffix}_study{tta_suffix}.npy', preds_study)
        np.save(log_folder + f'preds_{suffix}_img{tta_suffix}.npy', preds_img)

    return preds_study, preds_img

In [None]:
# EXP_FOLDER = LOG_PATH + "2021-07-30/4/"
# EXP_FOLDER = LOG_PATH + "2021-07-31/0/"
# EXP_FOLDER = LOG_PATH +  "aphrodeep_v2s_lung/"
# NOTE(review): every EXP_FOLDER assignment above is commented out; uncomment
# one (or make sure EXP_FOLDER is defined elsewhere) before running this cell.

# Read the experiment configuration through a context manager so the file
# handle is closed right away.
with open(EXP_FOLDER + "config.json", 'r') as config_file:
    config = Config(json.load(config_file))

# Sorted so that the position of a checkpoint in the list is reproducible.
weights = sorted(glob.glob(EXP_FOLDER + "*.pt"))

In [None]:
# ext data: run on the concatenated external tables.
# root stays empty, presumably because their paths are already complete (confirm).
suffix, df, root = "ext", df_ext, ""

In [None]:
# test data: run on the test table and its SIZE-specific image folder.
suffix = "test"
df = pd.read_csv(f"{DATA_PATH}df_test_{SIZE}.csv")
root = f"{DATA_PATH}test_{SIZE}/"

In [None]:
# Predict on the split selected above with horizontal-flip TTA only.
# save_all_folds=True writes the predictions of every checkpoint to EXP_FOLDER.
preds_study, preds_img = inference(
    config,
    weights,
    df,
    root_dir=root,
    suffix=suffix,
    log_folder=EXP_FOLDER,
    save_all_folds=True,
    flip_tta=True,
    scale_tta=False,
)