In [1]:
# Re-import edited modules (e.g. the local `src` package) automatically before
# each cell executes, so library changes are picked up without restarting the kernel.
# (Comments are kept on their own lines: text after a line magic is passed to it.)
%load_ext autoreload
%autoreload 2

In [2]:
import cv2
import torch
import torchvision

import pandas as pd
import numpy as np
import seaborn as sns 
import albumentations as A
import matplotlib.pyplot as plt

from pathlib import Path
from tqdm.cli import tqdm
from pytorch_toolbelt.inference import tta

from src.data import EyeDataset
from src.utils import load_splits
from src.utils import custom_to_std_tensor
from src.utils import denormalize
from src.utils import eye_blend
from src.config import load_config_from_ckpt
from src.models import ModelEnsemble
from src.train import validation

In [3]:
# Project directory layout, relative to the notebook's working directory.
data_folder, folds_folder, models_folder, configs_folder = (
    Path(name) for name in ("data", "folds", "models", "configs")
)
figs_folder, output_folder, submissions_folder = (
    Path(name) for name in ("figs", "output", "submissions")
)

# Test-set table read from data/test.csv.
test_df = pd.read_csv(data_folder.joinpath("test.csv"))

## Single-model validation

Validate the `baseline` checkpoint (epoch 38) with left-right flip test-time augmentation.

In [39]:
# Load the baseline config together with its epoch-38 checkpoint.
config_filename = configs_folder / "baseline.yml"
checkpoint_filename = models_folder / "baseline-epoch-38-ckpt.pt"
config = load_config_from_ckpt(config_filename, checkpoint_filename)

# Wrap the model with left-right flip test-time augmentation
# (pytorch_toolbelt's `fliplr_image2mask`); `config.model` is read only once.
model = tta.TTAWrapper(config.model, tta.fliplr_image2mask)

# Last expression: the returned metrics dict is displayed as the cell output.
validation(
    model,
    config.device,
    config.criterion,
    config.metrics,
    config.dataloaders.val,
)

Overwriting device = gpu (was gpu)
Overwriting non-existing attribute checkpoint.filename = models/baseline-epoch-38-ckpt.pt
Overwriting non-existing attribute checkpoint.filename = models/baseline-epoch-38-ckpt.pt
Overwriting non-existing attribute checkpoint.model = True
{'model': {'name': 'models.smp.Unet', 'params': {'encoder_name': 'efficientnet-b2', 'encoder_weights': 'imagenet', 'in_channels': 1, 'classes': 4, 'activation': None}}, 'device': 'gpu', 'criterion': {'name': 'L.JaccardLoss', 'params': {'mode': 'multiclass'}}, 'optimizer': {'name': 'torch.optim.Adam', 'params': {'lr': 0.001}}, 'scheduler': {'name': 'torch.optim.lr_scheduler.ReduceLROnPlateau', 'params': {'patience': 10, 'factor': 0.5, 'mode': 'max'}}, 'dataloaders': {'train': {'name': 'torch.utils.data.DataLoader', 'params': {'dataset': {'name': 'data.EyeDataset', 'params': {'df': {'name': 'utils.load_splits', 'params': {'folds_folder': {'path': 'folds'}, 'val_folds': [0], 'only_train': True}}, 'mode': 'train', 'trans

validation phase:   0%|          | 0/101 [00:00<?, ?it/s]

Load model from state dict


validation phase: 100%|██████████| 101/101 [00:11<00:00,  8.98it/s]


{'val_loss': 0.05054121527505751,
 'mean_iou': 0.9136595248582944,
 'mean_with_bg_iou': 0.9335828500600597,
 'sclera_iou': 0.8811745228458993,
 'iris_iou': 0.950562122449353,
 'pupil_iou': 0.9092418634476354}

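## Ensemble validation

Validate an ensemble of the `baseline` (epoch 38) and `baseline-crossentoryloss` (epoch 30) checkpoints, again with left-right flip test-time augmentation. Device, loss, metrics and validation loader are taken from the first model's config.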

In [5]:
# One config per ensemble member: (config file, checkpoint file), loaded quietly.
configs = [
    load_config_from_ckpt(
        configs_folder / config_name,
        models_folder / checkpoint_name,
        verbose=False,
    )
    for config_name, checkpoint_name in (
        ("baseline.yml", "baseline-epoch-38-ckpt.pt"),
        ("baseline-crossentoryloss.yml", "baseline-crossentoryloss-epoch-30-ckpt.pt"),
    )
]

# Combine the members with `ModelEnsemble` (how outputs are merged is defined
# in src.models), then apply left-right flip TTA on top of the whole ensemble.
model = tta.TTAWrapper(
    ModelEnsemble([member.model for member in configs]),
    tta.fliplr_image2mask,
)

# Device, loss, metrics and validation loader all come from the first member's
# config; presumably every member shares the same validation split — verify.
validation(
    model,
    configs[0].device,
    configs[0].criterion,
    configs[0].metrics,
    configs[0].dataloaders.val,
)

Overwriting device = gpu (was gpu)
Overwriting non-existing attribute checkpoint.filename = models/baseline-epoch-38-ckpt.pt
Overwriting non-existing attribute checkpoint.filename = models/baseline-epoch-38-ckpt.pt
Overwriting non-existing attribute checkpoint.model = True
Load model from state dict
Overwriting device = gpu (was gpu)
Overwriting non-existing attribute checkpoint.filename = models/baseline-crossentoryloss-epoch-30-ckpt.pt
Overwriting non-existing attribute checkpoint.filename = models/baseline-crossentoryloss-epoch-30-ckpt.pt
Overwriting non-existing attribute checkpoint.model = True


validation phase:   0%|          | 0/101 [00:00<?, ?it/s]

Load model from state dict


validation phase: 100%|██████████| 101/101 [00:23<00:00,  4.24it/s]


{'val_loss': 0.051355925291331844,
 'mean_iou': 0.912806399426057,
 'mean_with_bg_iou': 0.9329317407821541,
 'sclera_iou': 0.8798373810094388,
 'iris_iou': 0.9499838897837928,
 'pupil_iou': 0.9085978637287273}