In [2]:
import pylibjpeg

In [15]:
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from timm import create_model
from tqdm.notebook import tqdm
from joblib import Parallel, delayed
import cv2
import dicomsdl
from PIL import Image
import torchvision
import evaluate


pd.set_option('display.max_rows', 1000)
pd.set_option('display.max_columns', 1000)
plt.rcParams['figure.figsize'] = (20, 5)



RSNA_2022_PATH = '/scratch/eecs545w23_class_root/eecs545w23_class/yngmkim'
PNG_TEST_IMAGES_PATH = '/scratch/eecs545w23_class_root/eecs545w23_class/yngmkim/png_bcd_roi_1024x_split'

###########################################################
# specify the directory that your best model file in here #
###########################################################
#MODELS_PATH = '/home/yngmkim/EECS545/codes/0418_org_model'
MODELS_PATH = '/scratch/eecs545w23_class_root/eecs545w23_class/yngmkim/vit_model'

#DCM_TEST_IMAGES_PATH = f'/kaggle/input/rsna-breast-cancer-detection/test_images'

AUX_TARGET_NCLASSES = [2, 2, 6, 2, 2, 2, 4, 5, 2, 10, 10]
#AUX_TARGET_NCLASSES = [2, 2, 5, 2, 2, 2, 4, 5, 2, 10, 10]

MODEL_MEAN = (0.485, 0.456, 0.406)
MODEL_STD = (0.229, 0.224, 0.225)

# seresnext50_32x4d
# x,y = (1024,512)
# mean=0.2179, std=0.0529

# deit3_base_patch16_384
# deit3_small_patch16_384_in21ft1k
#  'mean': (0.485, 0.456, 0.406),
#  'std': (0.229, 0.224, 0.225),

# eva02_large_patch14_448.mim_m38m_ft_in22k_in1k
# mean=(0.48145466, 0.4578275, 0.40821073)
# std=(0.26862954, 0.26130258, 0.27577711)

DEBUG = True
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cuda':
    BATCH_SIZE = 16
else:
    BATCH_SIZE = 36
     
    
class CFG:
    resize_dim = 1024
    aspect_ratio = True
    img_size = [1024, 512]

In [16]:
def load_df_test():
    df_test = pd.read_csv(f'{RSNA_2022_PATH}/test_data_split.csv')
    return df_test

df_test = load_df_test()

df_test

Unnamed: 0,site_id,patient_id,image_id,laterality,view,age,cancer,biopsy,invasive,BIRADS,implant,density,machine_id,difficult_negative_case
0,2,10050,588678397,L,MLO,67.0,0,0,0,,0,,29,False
1,2,10050,1749389520,L,CC,67.0,0,0,0,,0,,29,False
2,2,10050,1428987847,R,MLO,67.0,0,0,0,,0,,29,False
3,2,10050,1614607569,R,CC,67.0,0,0,0,,0,,29,False
4,1,10097,664674273,L,MLO,44.0,0,0,0,1.0,0,C,49,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10896,2,9954,846263495,R,CC,48.0,0,0,0,,0,,48,False
10897,2,9965,69676487,L,MLO,67.0,0,0,0,,0,,21,False
10898,2,9965,1523807232,L,CC,67.0,0,0,0,,0,,21,False
10899,2,9965,1229053771,R,CC,67.0,0,0,0,,0,,21,False


### Preprocess images ROI

In [17]:
def get_transforms():

    def transforms(img):
        img = img.convert('RGB')#.resize((512, 512))
        tfm = [torchvision.transforms.Resize((384, 384))]
        img = torchvision.transforms.Compose(tfm + [            
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=MODEL_MEAN, std=MODEL_STD),
            # mean=0.2179, std=0.0529
        ])(img)
        return img

    return lambda img: transforms(img)


In [18]:
class BreastCancerDataSet(torch.utils.data.Dataset):
    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df = df
        self.path = path
        self.transforms = transforms

    def __getitem__(self, i):

        path = f'{self.path}/test_images/{self.df.iloc[i].patient_id}/{self.df.iloc[i].image_id}.png'
        try:
            img = Image.open(path).convert('RGB')
        except Exception as ex:
            print(path, ex)
            return None

        if self.transforms is not None:
            img = self.transforms(img)
        
        #if TARGET in self.df.columns:
        cancer_target = torch.as_tensor(self.df.iloc[i].cancer)
        img_id = torch.as_tensor(self.df.iloc[i].image_id)
        #cat_aux_targets = torch.as_tensor(self.df.iloc[i][CATEGORY_AUX_TARGETS])
        return img, cancer_target, img_id

        #return img

    def __len__(self):
        return len(self.df)

ds_test = BreastCancerDataSet(df_test, PNG_TEST_IMAGES_PATH, get_transforms())

In [19]:
class BreastCancerModel(torch.nn.Module):
    def __init__(self, aux_classes, model_type='deit3_small_patch16_384_in21ft1k', dropout=0.):
        super().__init__()
        self.model = create_model(model_type, pretrained=False, num_classes=0, drop_rate=dropout)

        self.backbone_dim = self.model(torch.randn(1, 3, 384, 384)).shape[-1]

        self.nn_cancer = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_dim, 1),
        )
        self.nn_aux = torch.nn.ModuleList([
            torch.nn.Linear(self.backbone_dim, n) for n in aux_classes
        ])

    def forward(self, x):
        # returns logits
        x = self.model(x)

        cancer = self.nn_cancer(x).squeeze()
        aux = []
        for nn in self.nn_aux:
            aux.append(nn(x).squeeze())
        return cancer, aux

    def predict(self, x):
        cancer, aux = self.forward(x)
        sigaux = []
        for a in aux:
            sigaux.append(torch.softmax(a, dim=-1))
        return torch.sigmoid(cancer), sigaux

In [20]:
def load_model(name, dir='.', model=None):
    data = torch.load(os.path.join(dir, f'{name}'), map_location=DEVICE)
    if model is None:
        model = BreastCancerModel(AUX_TARGET_NCLASSES, data['model_type'])
    model.load_state_dict(data['model'])
    # print(data['threshold'], data['model_type'])
    return model, data['threshold'], data['model_type']


In [21]:
##########################################
# specify your best model file name here #
##########################################
model_file_name = 'nawat-model-f0'
model, thres, model_type = load_model(model_file_name, MODELS_PATH)
model = model.to(DEVICE)
model_thres = (model, thres)

In [22]:
def models_predict(model_thres, ds, max_batches=1e9):
    dl_test = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())
    model, thres= model_thres
    
    model.eval()

    with torch.no_grad():
        predictions = []
        ground_truth = []
        img_ids = []
        for idx, (X, target, img_id) in enumerate(tqdm(dl_test, mininterval=30)):
            #pred = torch.zeros(len(X))
            preds = model.predict(X.to(DEVICE))[0].squeeze().cpu().numpy()
            #print("pred: ", preds)
            pred_out = (preds > thres).astype(float)
            #print("pred_out: ", pred_out)
            #pred[:, idx] = preds.cpu()
            predictions.extend(pred_out)
            ground_truth.extend(target)
            img_ids.extend(img_id)
            #print(predictions)
            if idx >= max_batches:
                break
        #print(predictions)
        return predictions, ground_truth, img_ids

In [23]:
models_pred, ground_truth, img_ids = models_predict(model_thres, ds_test)



  0%|          | 0/682 [00:00<?, ?it/s]

In [51]:
#models_pred.count(1.0)

169

In [24]:
#all_references = df_test['cancer'].values
all_references = ground_truth
all_predictions = models_pred

In [25]:
metric_f1 = evaluate.load("f1")
metric_acc = evaluate.load("accuracy")
metric_pre = evaluate.load("precision")
metric_rec = evaluate.load("recall")

In [26]:
score_acc = metric_acc.compute(references=all_references, predictions=all_predictions)
score_f1 = metric_f1.compute(references=all_references, predictions=all_predictions)
score_pre = metric_pre.compute(references=all_references, predictions=all_predictions)
score_rec = metric_rec.compute(references=all_references, predictions=all_predictions)

In [27]:
score_acc, score_f1 , score_pre, score_rec

({'accuracy': 0.9214750940280708},
 {'f1': 0.06345733041575492},
 {'precision': 0.0424597364568082},
 {'recall': 0.12554112554112554})

In [34]:
wel_pred_img_ids = []
wrong_pred_img_ids = []
for refer, predic, img_id in zip(all_references, all_predictions, img_ids):
    if (refer == 1) and (predic == 1):
        wel_pred_img_ids.append(int(img_id))
    if refer != predic:
        wrong_pred_img_ids.append(int(img_id))

In [35]:
len(wel_pred_img_ids), len(wrong_pred_img_ids)

(29, 856)

In [36]:
wel_pred_img_ids

[1998384452,
 1060299310,
 1313318609,
 349554513,
 1600425758,
 511030638,
 104106029,
 1833201492,
 1507275789,
 1925841866,
 1593856707,
 406764921,
 601361926,
 802662148,
 1732965874,
 1919147280,
 1524699538,
 1900457997,
 1632383378,
 774155798,
 728779079,
 297086361,
 385806898,
 115057946,
 1213747044,
 592525150,
 1880660195,
 1559227517,
 18384498]

In [52]:
#######################################################
# predict using another model!! 
# for finding well predicted image over many model
######################################################
MODELS_PATH = '/home/yngmkim/EECS545/codes/0418_org_model'

In [53]:
def get_transforms():

    def transforms(img):
        img = img.convert('RGB')#.resize((512, 512))
        tfm = [torchvision.transforms.Resize((1024, 512))]
        img = torchvision.transforms.Compose(tfm + [            
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=0.2179, std=0.0529),
            # mean=0.2179, std=0.0529
        ])(img)
        return img

    return lambda img: transforms(img)


In [44]:
class BreastCancerModel(torch.nn.Module):
    def __init__(self, aux_classes, model_type='seresnext50_32x4d', dropout=0.):
        super().__init__()
        self.model = create_model(model_type, pretrained=False, num_classes=0, drop_rate=dropout)

        self.backbone_dim = self.model(torch.randn(1, 3, 512, 512)).shape[-1]

        self.nn_cancer = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_dim, 1),
        )
        self.nn_aux = torch.nn.ModuleList([
            torch.nn.Linear(self.backbone_dim, n) for n in aux_classes
        ])

    def forward(self, x):
        # returns logits
        x = self.model(x)

        cancer = self.nn_cancer(x).squeeze()
        aux = []
        for nn in self.nn_aux:
            aux.append(nn(x).squeeze())
        return cancer, aux

    def predict(self, x):
        cancer, aux = self.forward(x)
        sigaux = []
        for a in aux:
            sigaux.append(torch.softmax(a, dim=-1))
        return torch.sigmoid(cancer), sigaux

In [45]:
ds_test = BreastCancerDataSet(df_test, PNG_TEST_IMAGES_PATH, get_transforms())

In [46]:
model_file_name = 'model-f4'
model, thres, model_type = load_model(model_file_name, MODELS_PATH)
model = model.to(DEVICE)
model_thres = (model, thres)

In [47]:
models_pred, ground_truth, img_ids = models_predict(model_thres, ds_test)
all_references = ground_truth
all_predictions = models_pred
score_acc = metric_acc.compute(references=all_references, predictions=all_predictions)
score_f1 = metric_f1.compute(references=all_references, predictions=all_predictions)
score_pre = metric_pre.compute(references=all_references, predictions=all_predictions)
score_rec = metric_rec.compute(references=all_references, predictions=all_predictions)
score_acc, score_f1 , score_pre, score_rec



  0%|          | 0/682 [00:00<?, ?it/s]

({'accuracy': 0.9731217319511971},
 {'f1': 0.3138173302107728},
 {'precision': 0.34183673469387754},
 {'recall': 0.29004329004329005})

In [48]:
wel_pred_img_ids2 = []
wrong_pred_img_ids2 = []
for refer, predic, img_id in zip(all_references, all_predictions, img_ids):
    if (refer == 1) and (predic == 1):
        wel_pred_img_ids2.append(int(img_id))
    if refer != predic:
        wrong_pred_img_ids2.append(int(img_id))

In [49]:
len(wel_pred_img_ids2), len(wrong_pred_img_ids2)

(67, 293)

In [50]:
set(wel_pred_img_ids2).intersection(set(wel_pred_img_ids))

{104106029,
 115057946,
 349554513,
 406764921,
 1507275789,
 1524699538,
 1593856707,
 1833201492,
 1880660195,
 1900457997,
 1998384452}

In [None]:
len(set(wrong_pred_img_ids).intersection(set(wrong_pred_img_ids2)))

163