In [1]:
from tqdm import tqdm
import importlib
from pathlib import Path
import pandas as pd

import numpy as np
from collections import defaultdict

from torch.utils.data import DataLoader
import albumentations as albu 
import torch

from dataset import PneumoDataset
from helpers import load_yaml, mask2rle

import os

%load_ext autoreload
%autoreload 2

In [2]:
def inference_model(model, loader, device):
    model.eval()
    with torch.no_grad():
        mask_dict = {}
        for image_ids, images in tqdm(loader):
            images = images.to(device)
            predicted = model(images)
            masks = torch.sigmoid(predicted) 
            masks = masks.squeeze(1).cpu().detach().numpy()
            for name, mask in zip(image_ids, masks):
                mask_dict[name] = mask.astype(np.float32)
    return mask_dict

def run_binarizer(mask_dict, binarizer_fn, result_path, device):
    used_thresholds = binarizer_fn.thresholds
    for name, mask in tqdm(mask_dict.items()):
        mask = torch.tensor(mask).unsqueeze(0).unsqueeze(0).to(torch.float32)
        mask = mask.to(device)

        mask_generator = binarizer_fn.transform(mask)
        for current_thr, current_mask in zip(used_thresholds, mask_generator):
            csv_name = os.path.join(result_path, f"{current_thr}.csv")
            current_mask = current_mask.squeeze(0).squeeze(0).cpu().detach().numpy()
            build_csv(name, current_mask, csv_name)

def build_csv(name, mask, out_path):
    rle_mask = mask2rle(mask)
    mask_df = pd.DataFrame({"ImageId": [name], "EncodedPixels": [rle_mask]})
    if os.path.exists(out_path):
        df = pd.read_csv(out_path)
        df = pd.concat([df, mask_df], ignore_index=True)
        df.to_csv(out_path, index=False)
    else:
        mask_df.to_csv(out_path, index=False)

In [3]:
experiment_folder = Path("experiments")
config_folder = experiment_folder / "configs" / "Inference.yaml"
inference_config = load_yaml(config_folder)
print(inference_config)

{'SEED': 42, 'NUM_WORKERS': 4, 'DEVICE': 'cuda', 'BATCH_SIZE': 2, 'MODEL': {'PY': 'model', 'CLASS': 'ResUNet', 'ARGS': {'pretrained': False}}, 'CHECKPOINTS': {'FULL_FOLDER': 'resunet_1024_3', 'PIPELINE_PATH': 'experiments/resunet', 'PIPELINE_NAME': 'resunet_1024'}, 'USEFOLDS': [0, 1, 2, 3, 4], 'MASK_BINARIZER': {'PY': 'binarizer', 'CLASS': 'TripletMaskBinarization', 'ARGS': {'triplets': [[0.6, 3000, 0.25], [0.7, 3000, 0.3], [0.7, 2000, 0.3]]}}, 'RESULT_PATH': 'submission'}


In [4]:
batch_size = inference_config['BATCH_SIZE']
device = inference_config['DEVICE']

module = importlib.import_module(inference_config['MODEL']['PY'])
model_class = getattr(module, inference_config['MODEL']['CLASS'])
model = model_class(**inference_config['MODEL'].get('ARGS', None)).to(device)
model.eval()

pipeline_path = Path(inference_config['CHECKPOINTS']['PIPELINE_PATH'])
pipeline_name = inference_config['CHECKPOINTS']['PIPELINE_NAME']
checkpoints_list = []
checkpoints_folder = Path(pipeline_path, inference_config['CHECKPOINTS']['FULL_FOLDER'])
usefolds = inference_config['USEFOLDS']
for fold_id in usefolds:
    filename = '{}_fold{}.pth'.format(pipeline_name, fold_id)
    checkpoints_list.append(Path(checkpoints_folder, filename))

binarizer_module = importlib.import_module(inference_config['MASK_BINARIZER']['PY'])
binarizer_class = getattr(binarizer_module, inference_config['MASK_BINARIZER']['CLASS'])
binarizer_fn = binarizer_class(**inference_config['MASK_BINARIZER']['ARGS'])

result_path = Path(experiment_folder, inference_config['RESULT_PATH'])
os.makedirs(result_path, exist_ok=True)

test_transform = albu.Compose([
    albu.Resize(1024, 1024, always_apply=True),
    albu.Normalize()
])

num_workers = inference_config['NUM_WORKERS']

test_names = np.load("data/4test_imgs_npy/test_imgs_names.npy") + ".png"
fold_labels = np.load("data/4test_imgs_npy/fold_labels_test.npy")

In [5]:
for fold_id in range(np.max(fold_labels) + 1):

    print(f"Fold {fold_id}")

    dataset = PneumoDataset(
        mode='test', 
        fold_index=fold_id,
        test_names=test_names,
        fold_labels=fold_labels,
        transform=test_transform,
    )
    dataloader = DataLoader(
        dataset=dataset, 
        batch_size=batch_size, 
        num_workers=num_workers, 
        shuffle=False
    )

    fold_size = len(dataset)
    print(f"Data amount: {fold_size}")

    mask_dict = defaultdict(int)
    for pred_idx, checkpoint_path in enumerate(checkpoints_list):
        print(f"Loaded {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path))
        current_mask_dict = inference_model(model, dataloader, device)
        for name, mask in tqdm(current_mask_dict.items()):
            mask_dict[name] = (mask_dict[name] * pred_idx + mask) / (pred_idx + 1)

    run_binarizer(mask_dict, binarizer_fn, result_path, device)
    del mask_dict

  model.load_state_dict(torch.load(checkpoint_path))


Fold 0
Data amount: 641
Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold0.pth


100%|██████████| 321/321 [01:27<00:00,  3.66it/s]
100%|██████████| 641/641 [00:01<00:00, 489.70it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold1.pth


100%|██████████| 321/321 [01:08<00:00,  4.66it/s]
100%|██████████| 641/641 [00:02<00:00, 272.67it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold2.pth


100%|██████████| 321/321 [01:07<00:00,  4.79it/s]
100%|██████████| 641/641 [00:02<00:00, 303.55it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold3.pth


100%|██████████| 321/321 [01:05<00:00,  4.92it/s]
100%|██████████| 641/641 [00:02<00:00, 311.10it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold4.pth


100%|██████████| 321/321 [01:06<00:00,  4.79it/s]
100%|██████████| 641/641 [00:02<00:00, 277.97it/s]
100%|██████████| 641/641 [00:39<00:00, 16.26it/s]
  model.load_state_dict(torch.load(checkpoint_path))


Fold 1
Data amount: 641
Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold0.pth


100%|██████████| 321/321 [01:06<00:00,  4.85it/s]
100%|██████████| 641/641 [00:01<00:00, 506.30it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold1.pth


100%|██████████| 321/321 [01:09<00:00,  4.65it/s]
100%|██████████| 641/641 [00:02<00:00, 281.25it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold2.pth


100%|██████████| 321/321 [01:07<00:00,  4.79it/s]
100%|██████████| 641/641 [00:02<00:00, 290.10it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold3.pth


100%|██████████| 321/321 [01:06<00:00,  4.80it/s]
100%|██████████| 641/641 [00:02<00:00, 294.27it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold4.pth


100%|██████████| 321/321 [01:05<00:00,  4.87it/s]
100%|██████████| 641/641 [00:02<00:00, 292.51it/s]
100%|██████████| 641/641 [00:52<00:00, 12.22it/s]
  model.load_state_dict(torch.load(checkpoint_path))


Fold 2
Data amount: 641
Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold0.pth


100%|██████████| 321/321 [01:05<00:00,  4.93it/s]
100%|██████████| 641/641 [00:01<00:00, 482.49it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold1.pth


100%|██████████| 321/321 [01:05<00:00,  4.88it/s]
100%|██████████| 641/641 [00:02<00:00, 302.34it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold2.pth


100%|██████████| 321/321 [01:06<00:00,  4.83it/s]
100%|██████████| 641/641 [00:02<00:00, 286.81it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold3.pth


100%|██████████| 321/321 [01:05<00:00,  4.88it/s]
100%|██████████| 641/641 [00:02<00:00, 307.93it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold4.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 303.78it/s]
100%|██████████| 641/641 [00:56<00:00, 11.42it/s]
  model.load_state_dict(torch.load(checkpoint_path))


Fold 3
Data amount: 641
Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold0.pth


100%|██████████| 321/321 [01:05<00:00,  4.92it/s]
100%|██████████| 641/641 [00:01<00:00, 506.90it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold1.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 230.84it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold2.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 319.55it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold3.pth


100%|██████████| 321/321 [01:05<00:00,  4.92it/s]
100%|██████████| 641/641 [00:02<00:00, 296.26it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold4.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 288.99it/s]
100%|██████████| 641/641 [00:59<00:00, 10.84it/s]
  model.load_state_dict(torch.load(checkpoint_path))


Fold 4
Data amount: 641
Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold0.pth


100%|██████████| 321/321 [01:05<00:00,  4.93it/s]
100%|██████████| 641/641 [00:01<00:00, 517.75it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold1.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 320.45it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold2.pth


100%|██████████| 321/321 [01:05<00:00,  4.91it/s]
100%|██████████| 641/641 [00:02<00:00, 317.63it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold3.pth


100%|██████████| 321/321 [01:05<00:00,  4.93it/s]
100%|██████████| 641/641 [00:02<00:00, 319.30it/s]


Loaded experiments\resunet\resunet_1024_3\resunet_1024_fold4.pth


100%|██████████| 321/321 [01:05<00:00,  4.92it/s]
100%|██████████| 641/641 [00:01<00:00, 322.47it/s]
100%|██████████| 641/641 [01:06<00:00,  9.67it/s]
