In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import random
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

import cv2

import albumentations
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score

from tqdm import tqdm
import argparse
import os, sys, yaml

sys.path.append('/workspace/siim-rsna-2021')
from src.logger import setup_logger, LOGGER
from src.meter import mAPMeter, AUCMeter, APMeter, AverageValueMeter
from src.utils import plot_sample_images
from src.segloss import SymmetricLovaszLoss


# import neptune.new as neptune
import wandb
import pydicom

import time
from contextlib import contextmanager

import timm

import warnings

target_columns = [
    "Negative for Pneumonia", "Typical Appearance", "Indeterminate Appearance", "Atypical Appearance", "is_none"
]


@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f'[{name}] done in {time.time() - t0:.0f} s')

In [2]:
from exp.exp414.train import Net as Net414
from exp.exp415.train import Net as Net415
from exp.exp416.train import Net as Net416
from exp.exp417.train import Net as Net417

from exp.exp418.train import Net as Net418
from exp.exp419.train import Net as Net419
from exp.exp420.train import Net as Net420

from exp.exp520.train import Net as Net520
from exp.exp551.train import Net as Net551
from exp.exp552.train import Net as Net552
from exp.exp553.train import Net as Net553

from exp.exp604.train import Net as Net604
from exp.exp605.train import Net as Net605
from exp.exp606.train import Net as Net606

In [3]:
class CustomDataset(Dataset):
    def __init__(self,
                 df,
                 image_size=512,
                 transform=None,
                 ):
        self.df = df
        self.image_size = image_size
        self.transform = transform
        self.cols = target_columns

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        images = cv2.imread(row.npy_path)

        # original image size
        original_h = images.shape[0]
        original_w = images.shape[1]
        images = cv2.resize(images, (512, 512))

        if self.transform is not None:
            aug = self.transform(image=images)
            images_only = aug['image'].astype(np.float32).transpose(2, 0, 1) / 255
        return {
            "image": torch.tensor(images_only, dtype=torch.float),
        }


In [4]:
def get_val_transforms(image_size=512):
    return albumentations.Compose([
        albumentations.Resize(image_size, image_size),
])

In [5]:
study_sub_list = [

    # prediction set, b6 mask
    [
        # backbone
        Net414("tf_efficientnetv2_m_in21k"),
        "tf_efficientnetv2_m_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp414/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp414/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp414/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp414/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp414/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp414_hflip"
    ],
    
    # prediction set, b6 mask
    [
        # backbone
        Net415("tf_efficientnetv2_m_in21k"),
        "tf_efficientnetv2_m_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp415/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp415/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp415/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp415/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp415/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp415_hflip"
    ],

    # prediction set, b7 mask
    [
        # backbone
        Net416("tf_efficientnetv2_m_in21k"),
        "tf_efficientnetv2_m_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp416/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp416/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp416/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp416/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp416/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp416_hflip"
    ],
    
    # prediction set, b6 mask
    [
        # backbone
        Net417("tf_efficientnetv2_m_in21k"),
        "tf_efficientnetv2_m_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp417/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp417/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp417/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp417/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp417/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp417_hflip"
    ],

    # ===========================================
    # Eff v2 L
    # ===========================================


    # prediction set, b7 map
    [
        # backbone
        Net419("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp419/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp419/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp419/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp419/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp419/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp419_hflip"
    ],

    # prediction set, b6, b7
    [
        # backbone
        Net420("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp420/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp420/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp420/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp420/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp420/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp420_hflip"
    ],
    
    # prediction set, b6, b7
    [
        # backbone
        Net420("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp520/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp520/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp520/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp520/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp520/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp520_hflip"
    ],
    
    # prediction set, b6
    [
        # backbone
        Net418("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp551/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp551/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp551/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp551/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp551/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp551_hflip"
    ],
    
    # prediction set, b6
    [
        # backbone
        Net552("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp552/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp552/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp552/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp552/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp552/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp552_hflip"
    ],
    
    # prediction set, b6
    [
        # backbone
        Net553("tf_efficientnetv2_l_in21k"),
        "tf_efficientnetv2_l_in21k",
        # img_size
        512,
        # weight list
        [
            "/workspace/output/exp553/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp553/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp553/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp553/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp553/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp553_hflip"
    ],

    # ===========================================
    # Swin transformer
    # ===========================================

    # prediction set
    [
        # backbone
        Net604("swin_base_patch4_window12_384"),
        "swin_base_patch4_window12_384",
        # img_size
        384,
        # weight list
        [
            "/workspace/output/exp604/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp604/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp604/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp604/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp604/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp604_hflip"
    ],
    
    # prediction set
    [
        # backbone
        Net605("swin_base_patch4_window12_384"),
        "swin_base_patch4_window12_384",
        # img_size
        384,
        # weight list
        [
            "/workspace/output/exp605/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp605/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp605/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp605/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp605/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp605_hflip"
    ],
    
    # prediction set
    [
        # backbone
        Net606("swin_base_patch4_window12_384"),
        "swin_base_patch4_window12_384",
        # img_size
        384,
        # weight list
        [
            "/workspace/output/exp606/model/cv0_weight_checkpoint_best.pth",
            "/workspace/output/exp606/model/cv1_weight_checkpoint_best.pth",
            "/workspace/output/exp606/model/cv2_weight_checkpoint_best.pth",
            "/workspace/output/exp606/model/cv3_weight_checkpoint_best.pth",
            "/workspace/output/exp606/model/cv4_weight_checkpoint_best.pth",
        ],
        "exp606_hflip"
    ],
    
    

]

In [6]:
device = "cuda:1"
hflip = True

In [7]:
df_original = pd.read_csv("/workspace/data/df_train_study_level_npy640_3_w_bbox.csv")
df = df_original.groupby('image_id').first().reset_index()

In [9]:
from copy import deepcopy

# key: image size, value: model
model_dict = {}

image_size_list = []

for model_set in tqdm(study_sub_list):
    # 画像サイズをkey, modelのリストをvalueにする
    # keyがまだない場合はからのリストを登録
    model_dict.setdefault(model_set[2], [])
    model_list = []
    exp_name = model_set[-1]
    for cv, ckpt in tqdm(enumerate(model_set[3])):
        model = model_set[0].to(device)
        weight = torch.load(ckpt, map_location=device)
        model.load_state_dict(weight["state_dict"])
        model.eval()
        
        df_val = df[df.cv == cv].reset_index(drop=True)
        
        dataset = CustomDataset(df=df_val, transform=get_val_transforms(model_set[2]))
        test_loader = DataLoader(
            dataset,
            shuffle=False,
            batch_size=32,
            num_workers=0,
            pin_memory=True,
        )
        
        pred_list1 = []
        for i, image in enumerate(test_loader):
            pred_list2 = []
            pred_mask2 = []
            image = image["image"].to(device)
            with torch.no_grad():
                preds, *_ = model(image)
                preds = preds.cpu().detach()
                pred_list2.append(preds.sigmoid())
                
                if hflip:
                    preds, *_ = model(image.flip(-1))
                    pred_list2.append(preds.cpu().detach().sigmoid())

                # average prediction
                pred_list1.append(torch.stack(pred_list2, 0).mean(0))

        preds = torch.cat(pred_list1).numpy()
        print(f"preds.shape: {preds.shape}")
        np.save(f'/workspace/output/oof/{exp_name}_cv{cv}.npy', preds)

  0%|          | 0/13 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [01:57, 117.82s/it][A

preds.shape: (1223, 5)



2it [03:53, 117.31s/it][A

preds.shape: (1220, 5)



3it [05:51, 117.28s/it][A

preds.shape: (1221, 5)



4it [07:54, 119.14s/it][A

preds.shape: (1228, 5)



5it [09:56, 119.30s/it][A
  8%|▊         | 1/13 [09:56<1:59:17, 596.49s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:02, 122.40s/it][A

preds.shape: (1223, 5)



2it [03:56, 119.77s/it][A

preds.shape: (1220, 5)



3it [05:49, 117.94s/it][A

preds.shape: (1221, 5)



4it [07:46, 117.74s/it][A

preds.shape: (1228, 5)



5it [09:47, 117.51s/it][A
 15%|█▌        | 2/13 [19:44<1:48:51, 593.80s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [01:55, 115.67s/it][A

preds.shape: (1223, 5)



2it [03:51, 115.76s/it][A

preds.shape: (1220, 5)



3it [05:48, 116.04s/it][A

preds.shape: (1221, 5)



4it [07:49, 117.72s/it][A

preds.shape: (1228, 5)



5it [09:50, 118.13s/it][A
 23%|██▎       | 3/13 [29:34<1:38:48, 592.86s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [01:57, 117.12s/it][A

preds.shape: (1223, 5)



2it [03:58, 118.42s/it][A

preds.shape: (1220, 5)



3it [05:56, 118.16s/it][A

preds.shape: (1221, 5)



4it [07:55, 118.41s/it][A

preds.shape: (1228, 5)



5it [09:53, 118.80s/it][A
 31%|███       | 4/13 [39:28<1:28:58, 593.20s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:17, 137.82s/it][A

preds.shape: (1223, 5)



2it [04:41, 139.57s/it][A

preds.shape: (1220, 5)



3it [06:58, 138.69s/it][A

preds.shape: (1221, 5)



4it [09:18, 139.29s/it][A

preds.shape: (1228, 5)



5it [11:43, 140.67s/it][A
 38%|███▊      | 5/13 [51:11<1:23:29, 626.24s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:21, 141.52s/it][A

preds.shape: (1223, 5)



2it [04:45, 142.14s/it][A

preds.shape: (1220, 5)



3it [07:08, 142.52s/it][A

preds.shape: (1221, 5)



4it [09:33, 143.31s/it][A

preds.shape: (1228, 5)



5it [11:54, 142.81s/it][A
 46%|████▌     | 6/13 [1:03:06<1:16:08, 652.58s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:21, 141.20s/it][A

preds.shape: (1223, 5)



2it [04:39, 140.35s/it][A

preds.shape: (1220, 5)



3it [06:58, 140.03s/it][A

preds.shape: (1221, 5)



4it [09:19, 140.14s/it][A

preds.shape: (1228, 5)



5it [11:41, 140.23s/it][A
 54%|█████▍    | 7/13 [1:14:47<1:06:42, 667.15s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:20, 140.98s/it][A

preds.shape: (1223, 5)



2it [04:37, 139.75s/it][A

preds.shape: (1220, 5)



3it [06:55, 139.01s/it][A

preds.shape: (1221, 5)



4it [09:16, 139.56s/it][A

preds.shape: (1228, 5)



5it [11:42, 140.43s/it][A
 62%|██████▏   | 8/13 [1:26:29<56:28, 677.64s/it]  
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:18, 138.82s/it][A

preds.shape: (1223, 5)



2it [04:41, 140.03s/it][A

preds.shape: (1220, 5)



3it [07:00, 139.74s/it][A

preds.shape: (1221, 5)



4it [09:20, 139.73s/it][A

preds.shape: (1228, 5)



5it [11:41, 140.32s/it][A
 69%|██████▉   | 9/13 [1:38:10<45:39, 684.83s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:18, 138.81s/it][A

preds.shape: (1223, 5)



2it [04:38, 139.08s/it][A

preds.shape: (1220, 5)



3it [06:57, 138.93s/it][A

preds.shape: (1221, 5)



4it [09:20, 140.22s/it][A

preds.shape: (1228, 5)



5it [11:43, 140.64s/it][A
 77%|███████▋  | 10/13 [1:49:54<34:31, 690.34s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:06, 126.12s/it][A

preds.shape: (1223, 5)



2it [04:12, 126.12s/it][A

preds.shape: (1220, 5)



3it [06:18, 126.17s/it][A

preds.shape: (1221, 5)



4it [08:29, 127.74s/it][A

preds.shape: (1228, 5)



5it [10:40, 128.16s/it][A
 85%|████████▍ | 11/13 [2:00:34<22:30, 675.49s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)



1it [02:10, 130.93s/it][A

preds.shape: (1223, 5)



2it [04:19, 130.29s/it][A

preds.shape: (1220, 5)



3it [06:25, 128.88s/it][A

preds.shape: (1221, 5)



4it [08:34, 128.84s/it][A

preds.shape: (1228, 5)



5it [10:46, 129.31s/it][A
 92%|█████████▏| 12/13 [2:11:21<11:06, 666.80s/it]
0it [00:00, ?it/s][A

preds.shape: (1225, 5)


0it [00:00, ?it/s]
 92%|█████████▏| 12/13 [2:11:22<10:56, 656.84s/it]


RuntimeError: Error(s) in loading state_dict for Net:
	Unexpected key(s) in state_dict: "mask1.0.weight", "mask1.0.bias", "mask1.2.weight", "mask1.2.bias", "mask1.4.weight", "mask1.4.bias", "mask2.0.weight", "mask2.0.bias", "mask2.2.weight", "mask2.2.bias", "mask2.4.weight", "mask2.4.bias". 