In [1]:
%reload_ext autoreload
%autoreload 2
## sys package
import os, sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # specify which GPU(s) to be used
sys.path.append("./prediction_models/input/prostate-cancer-grade-assessment/")
## warning off
import warnings
warnings.filterwarnings("ignore")

## general package
import random
from tqdm import tqdm_notebook as tqdm
import numpy as np
import pandas as pd
import torch
from torch.utils.data import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
# print(device)

## customized package
from input.inputPipeline import *
from model.resnext_ssl import Model_Infer as Model

In [2]:
# Input locations. Note: despite its name, TEST points at the *training* CSV, and
# DATA at the training images — this run predicts on the train split.
INPUT_DIR = './input/prostate-cancer-grade-assessment'
DATA = f'{INPUT_DIR}/train_images'           # folder of slide images fed to the dataset
TEST = f'{INPUT_DIR}/train.csv'              # CSV listing the images to predict
SAMPLE = f'{INPUT_DIR}/sample_submission.csv'  # fallback submission template

In [3]:
# Per-channel statistics for the three image channels; presumably computed on the
# training tiles — TODO confirm provenance.
mean = torch.tensor([0.90949707, 0.8188697, 0.87795304])
std = torch.tensor([0.36357649, 0.49984502, 0.40477625])

# Convert each patch to a tensor, then standardise channel-wise with the stats above.
tsfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

In [4]:
# Ensemble: one checkpoint per index i (presumably cross-validation folds).
N_FOLDS = 4
WEIGHTS_DIR = './train/weights/Resnext50_reg_medreso_12patch'

models = []
weights = [f'{WEIGHTS_DIR}/Resnext50_reg_medreso_12patch_{i}_best.pth.tar' for i in range(N_FOLDS)]
for path in weights:
    # Load onto CPU first so checkpoints saved on a GPU also load on CPU-only hosts;
    # the model is moved to `device` below.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model = Model(n = 1)  # n=1: presumably a single regression output ("_reg_" in the name)
    model.load_state_dict(state_dict)
    # Drop the checkpoint copy per iteration (a `del` after the loop would raise
    # NameError if `weights` were empty).
    del state_dict
    model.float()
    model.eval()
    model.to(device)
    models.append(model)

In [5]:
# Debug limiter: stop after this many batches (the original stopped after batch
# index 50, i.e. 51 batches). Set to None to predict on the whole dataset.
MAX_BATCHES = 51


def _dihedral_tta(img):
    """Stack the 8 dihedral views of ``img`` along a new dimension 1.

    Flips and transposes act on the last two dimensions, so the result has shape
    ``(img.shape[0], 8, *img.shape[1:])``.
    """
    img_t = img.transpose(-1, -2)
    return torch.stack([img, img.flip(-1), img.flip(-2), img.flip(-1, -2),
                        img_t, img_t.flip(-1), img_t.flip(-2), img_t.flip(-1, -2)], 1)


sub_df = pd.read_csv(SAMPLE)
## if DATA exists, run inference; otherwise the sample submission is written unchanged
## (in the Kaggle setup the test images only exist when the notebook is submitted)
if os.path.exists(DATA):
    sz = 256  # forwarded to the dataset as `sz` (presumably patch side length)
    bs = 4
    dataset = PandaPatchDatasetInfer(TEST, DATA, transform=tsfm, sz = sz)
    dataloader = DataLoader(dataset, batch_size=bs,
                            shuffle=False, num_workers=0, collate_fn=dataloader_collte_fn_infer)
    names, preds = [], []  ## record image names and predictions, one entry per batch
    ## Model inference
    with torch.no_grad():
        for idx, data in enumerate(tqdm(dataloader)):
            if MAX_BATCHES is not None and idx >= MAX_BATCHES:
                break
            img, name = data
            img = img.float().to(device)
            # Use a separate name so the configured `bs` is not overwritten by a
            # (possibly smaller) final batch.
            n_img, N, C, h, w = img.shape
            ## dihedral TTA: (n_img, 8, N, C, h, w) -> (n_img * 8, N, C, h, w), image-major
            img = _dihedral_tta(img).view(-1, N, C, h, w)
            p = torch.stack([model(img) for model in models], 1)  # (n_img * 8, n_models, out); out presumably 1
            p = p.view(n_img, 8 * len(models), -1)  # (n_img, 8 views * n_models, out)
            # Average over views and models, round the regression output to an integer
            # grade in [0, 5]. squeeze(-1) rather than squeeze(): a final batch of one
            # image must stay 1-D, otherwise torch.cat below fails on a 0-d tensor.
            p = p.mean(1).round().squeeze(-1).to(torch.int64).clamp_(min=0, max=5).cpu()
            names.append(name)
            preds.append(p)
        names = np.concatenate(names)
        preds = torch.cat(preds).numpy()
        sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})

sub_df.to_csv('submission.csv', index=False)
sub_df.head()

HBox(children=(FloatProgress(value=0.0, max=2654.0), HTML(value='')))




Unnamed: 0,image_id,isup_grade
0,0005f7aaab2800f6170c399693a96917,0
1,000920ad0b612851f8e01bcc880d9b3d,0
2,0018ae58b01bdadc8e347995b69f99aa,5
3,001c62abd11fa4b57bf7a6c603a11bb9,4
4,001d865e65ef5d2579c190a0e0350d8f,0


In [6]:
# The inference cell casts grades to int64 and clamps them to [0, 5]; assert those
# invariants so a Restart & Run All fails loudly instead of relying on eyeballing.
assert sub_df['isup_grade'].dtype == np.int64, sub_df['isup_grade'].dtype
assert sub_df['isup_grade'].between(0, 5).all()
sub_df['isup_grade'].dtype

dtype('int64')