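**About**: This notebook validates trained models. It inspects an experiment's per-threshold dice scores, then runs out-of-fold inference with the experiment's per-fold checkpoints.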

In [None]:
# %load_ext nb_black
# Automatically reload imported modules before executing each cell, so edits made
# to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Initialization

### Imports

In [None]:
import os
import torch

print(f"{torch.__version__}")

# Expose only GPU "1" to this kernel. This only takes effect if nothing has
# initialised CUDA yet, so it must run before any CUDA call.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

device = torch.cuda.get_device_name(device=0)
print(device)

In [None]:
import os
import sys
import glob
import json
import timm
import torch
import operator
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.metrics import *
from numerize.numerize import numerize

# Wider console rendering and longer cell contents when displaying DataFrames.
pd.set_option("display.width", 500, "max_colwidth", 100)

In [None]:
from params import *
from util.plots import *
from data.dataset import *
from data.transforms import get_transfos
from data.preparation import *
from util.torch import init_distributed, count_parameters, load_model_weights, count_parameters
from util.plots import plot_sample
from model_zoo.models import define_model

## Data

In [None]:
# Experiment folder to validate. It must contain dices.json, config.json and the
# per-fold checkpoints. Keep the trailing slash: paths below are built by string
# concatenation.
EXP_FOLDER = "../logs/2023-07-03/35/"
# EXP_FOLDER = "../logs/2023-07-05/35/"


In [None]:

# Validation dice per binarisation threshold, saved with the experiment.
# Keys are thresholds serialised as strings.
with open(EXP_FOLDER + "dices.json", "r") as dices_file:
    dices = json.load(dices_file)

# Best threshold = the one with the highest dice (first one wins on ties).
th, dice = max(dices.items(), key=operator.itemgetter(1))
th = float(th)

In [None]:
# Dice as a function of the threshold, zoomed around the best threshold.
thresholds = np.array([float(k) for k in dices], dtype=float)
dice_values = np.array(list(dices.values()), dtype=float)
# JSON key order is not guaranteed to be sorted by threshold; sort before drawing a line.
sort_idx = np.argsort(thresholds)

fig, ax = plt.subplots()
ax.plot(thresholds[sort_idx], dice_values[sort_idx])
ax.axvline(th, c="salmon")
ax.set_xlim(th - 0.1, th + 0.1)
ax.set_ylim(dice - 0.01, dice + 0.002)
ax.set_xlabel("threshold")
ax.set_ylabel("dice")
ax.set_title(f'dice={dice:.3f}, th={th:.2f}')
plt.show()

In [None]:
# Build the data table from DATA_PATH. processed_folder presumably selects the
# pre-computed "false_color/" images — confirm in data.preparation.prepare_data.
df = prepare_data(DATA_PATH, processed_folder="false_color/")

In [None]:
# Attach fold assignments when the prepared data does not already carry them.
if "fold" not in df.columns:
    folds = pd.read_csv(DATA_PATH + "folds_4.csv")
    # Implicit merge: joins on every column shared by both frames (inner join).
    df_with_folds = df.merge(folds)
    # An inner join silently drops unmatched rows, and duplicate keys multiply rows.
    # Fail loudly instead of validating on a different set of images.
    assert len(df_with_folds) == len(df), (
        f"fold merge changed the row count: {len(df)} -> {len(df_with_folds)}"
    )
    df = df_with_folds

In [None]:
# Fold-0 validation rows, re-indexed from 0 (used below to build a dataset to inspect).
df_val = df.loc[df["fold"].eq(0)].reset_index(drop=True)

### Dataset

In [None]:
import torch
import json
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

from params import NUM_WORKERS
from data.dataset import ContrailDataset
from data.transforms import get_transfos
from model_zoo.models import define_model
from util.torch import load_model_weights


# NOTE: overrides the NUM_WORKERS imported from params a few lines above;
# predict() builds its DataLoader with this module-level value.
NUM_WORKERS = 8


class Config:
    """
    Attribute-style view over a mapping, e.g. a config loaded from a saved json.

    Each key of the mapping becomes an attribute of the instance holding its value.
    """

    def __init__(self, dic):
        for attr_name in dic:
            setattr(self, attr_name, dic[attr_name])


def predict(
    model,
    dataset,
    loss_config,
    batch_size=64,
    device="cuda",
    use_fp16=False,
    num_workers=None,
):
    """
    Perform model inference on a dataset.

    Args:
        model (nn.Module): Trained model, already placed on `device`. Its forward
            returns a pair whose first element is the main prediction.
        dataset (Dataset): Dataset to perform inference on. Each item unpacks into
            three values; only the first (the model input) is used.
        loss_config (dict): Loss configuration. `loss_config["activation"]` selects
            the activation applied to the raw outputs: "sigmoid", "softmax" (over
            the last dimension), or anything else for none.
        batch_size (int, optional): Batch size for inference. Defaults to 64.
        device (str, optional): Device inputs are moved to and autocast runs on.
            Defaults to "cuda".
        use_fp16 (bool, optional): Whether to use mixed precision inference. Defaults to False.
        num_workers (int, optional): DataLoader workers. Defaults to the
            module-level NUM_WORKERS.

    Returns:
        preds (numpy.ndarray): Activated predictions for the whole dataset,
            concatenated along the batch axis in dataset order.
        preds_aux (list): Always an empty list; auxiliary outputs are not collected.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    if num_workers is None:
        num_workers = NUM_WORKERS

    # "cuda:1" -> "cuda": autocast wants the device type, not a specific device.
    device_type = torch.device(device).type
    activation = loss_config["activation"]

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    model.eval()
    preds = []
    with torch.no_grad():
        for img, _, _ in tqdm(loader):
            with torch.autocast(device_type=device_type, enabled=use_fp16):
                pred, _ = model(img.to(device))

            # Convert logits to probabilities.
            if activation == "sigmoid":
                pred = pred.sigmoid()
            elif activation == "softmax":
                pred = pred.softmax(-1)

            preds.append(pred.cpu().numpy())

    if not preds:
        raise ValueError("predict() received an empty dataset, nothing to predict.")

    return np.concatenate(preds), []


def kfold_inference(
    df,
    exp_folder,
    debug=False,
    use_fp16=False,
    save=False,
):
    """
    Perform out-of-fold inference with the per-fold checkpoints of an experiment.

    For every fold in `config.selected_folds`, the checkpoint
    `{decoder_name}_{encoder_name}_{fold}.pt` is loaded. It is evaluated on the rows
    of `df` whose `fold` column equals that fold.

    Args:
        df (pd.DataFrame): DataFrame containing the data; must have a `fold` column.
        exp_folder (str): Path to the experiment folder, ending with a separator
            (paths are built by string concatenation). Must contain `config.json`.
        debug (bool, optional): Unused, kept for interface compatibility. Defaults to False.
        use_fp16 (bool, optional): Whether to use mixed precision inference. Defaults to False.
        save (bool, optional): Saving predictions is not implemented. Passing True only
            emits a warning. Defaults to False.

    Returns:
        np.ndarray: Out-of-fold predictions of the selected folds, concatenated fold
        by fold in `config.selected_folds` order. Rows within a fold keep their order
        in `df`. With a single selected fold this is exactly that fold's predictions.

    Raises:
        ValueError: If the config selects no folds.
    """
    import warnings

    with open(exp_folder + "config.json", "r") as config_file:
        config = Config(json.load(config_file))

    if len(config.selected_folds) == 0:
        raise ValueError(f"No selected_folds in {exp_folder}config.json")

    if save:
        warnings.warn("kfold_inference(save=True) is not implemented; nothing is saved.")

    model = define_model(
        config.decoder_name,
        config.encoder_name,
        num_classes=config.num_classes,
        n_channels=config.n_channels,
        reduce_stride=config.reduce_stride,
        use_pixel_shuffle=config.use_pixel_shuffle,
        use_hypercolumns=config.use_hypercolumns,
        center=config.center,
        use_cls=config.loss_config['aux_loss_weight'] > 0,
        pretrained=False,
    ).cuda().eval()

    preds = []
    for fold in config.selected_folds:
        print(f"\n- Fold {fold + 1}")

        weights = exp_folder + f"{config.decoder_name}_{config.encoder_name}_{fold}.pt"
        model = load_model_weights(model, weights, verbose=1)

        df_val = df[df['fold'] == fold].reset_index(drop=True)

        dataset = ContrailDataset(
            df_val,
            transforms=get_transfos(augment=False),
        )

        pred, _ = predict(
            model,
            dataset,
            config.loss_config,
            batch_size=config.data_config["val_bs"],
            use_fp16=use_fp16,
        )

        preds.append(pred)

    # Each fold is scored on a *different* subset of df, so fold predictions are
    # stacked, not averaged. Averaging would mix unrelated images, or fail when
    # folds have different sizes.
    return np.concatenate(preds)

In [None]:
# Out-of-fold predictions for the experiment's selected folds, in mixed precision.
preds = kfold_inference(df, exp_folder=EXP_FOLDER, use_fp16=True)

In [None]:
# Fold-0 dataset for inspecting samples. Unlike kfold_inference, no `transforms`
# are passed here. NOTE(review): confirm ContrailDataset's default transforms are
# what is wanted for comparison with `preds`. Aligning the two also assumes the
# experiment's selected_folds is [0].
dataset = ContrailDataset(df_val)

In [None]:
# Loads every sample of `dataset` and discards it; afterwards `img`/`mask` hold the
# last sample. NOTE(review): looks unfinished — nothing is computed per sample yet.
for i in range(len(dataset)):
    img, mask, _ = dataset[i]
    
    

In [None]:
# Shape of the predictions returned by kfold_inference.
preds.shape

Done!