In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
cd ../src

/home/theo/kaggle/pulmonary_embolism/src


## Imports

In [3]:
import os
import torch
import warnings
import pydicom
import datetime
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from sklearn.model_selection import *

In [4]:
from params import *
from utils.logger import *
from data.dataset import *
from data.transforms import get_transfos

from model_zoo.models import define_model

## Data

In [5]:
# Training metadata table; its StudyInstanceUID / SeriesInstanceUID columns
# are used below to build the folds and the image folder paths.
df = pd.read_csv(DATA_PATH + "train.csv")

In [6]:
# Group-wise 5-fold split: rows sharing a StudyInstanceUID always land in the
# same validation fold.
gkf = GroupKFold(n_splits=5)
splits = list(
    gkf.split(X=df, y=df, groups=df['StudyInstanceUID'])
)

# fold_idx[row] = index of the split in which `row` belongs to the validation part.
fold_idx = np.zeros(len(df))
for fold_number, split in enumerate(splits):
    fold_idx[split[1]] = fold_number  # split = (train rows, validation rows)

df['fold'] = fold_idx

## Dataset

In [7]:
class PatientDataset(Dataset):
    """
    Dataset over all the files of a single image folder, used for feature extraction.

    Items are indexed following the lexicographic order of the file names.

    Args:
        path (str): Folder containing the images, with or without a trailing separator.
        transforms (callable, optional): If set, called as ``transforms(image=image)``;
            the value stored under the ``"image"`` key is returned. Defaults to None.
    """
    def __init__(self, path, transforms=None):
        self.path = path
        # Sorted so that indices are stable from one run to the next.
        self.img_paths = sorted(os.listdir(path))

        self.transforms = transforms

    def __len__(self):
        """Number of files found in the folder."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Loads the idx-th image of the folder.

        Args:
            idx (int): Position of the image in the sorted file list.

        Returns:
            tuple: (image, idx), where image is the transformed image if transforms are set.

        Raises:
            OSError: If the file cannot be decoded as an image.
        """
        # os.path.join also works when `path` was given without a trailing separator.
        img_file = os.path.join(self.path, self.img_paths[idx])

        # cv2.imread does not raise on failure, it returns None: fail loudly here
        # instead of crashing later inside the transforms or the batch collation.
        image = cv2.imread(img_file)
        if image is None:
            raise OSError(f"Could not read image: {img_file}")

        if self.transforms:
            image = self.transforms(image=image)["image"]

        return image, idx

In [8]:
# Hard-coded sample image folder (laid out as <study>/<series>/ under IMG_PATH,
# like the paths built further down), used to try the dataset on one series.
jpg_path = IMG_PATH + "6897fa9de148/2bfbb7fd2e8b/"

In [9]:
# Transform pipeline built with augmentation turned off; the same object is
# passed to every PatientDataset created in this notebook.
transforms = get_transfos(augment=False)
# Use `transforms = None` instead to get untransformed images; this also
# enables the preview cell below.

In [10]:
# Dataset over the sample folder above, to check that loading works.
dataset = PatientDataset(jpg_path, transforms=transforms)

In [11]:
# Visual check of the first two images of the sample dataset. Only runs when no
# transforms are applied, i.e. when the dataset returns images as read from disk.
if transforms is None:
    plt.subplot(1, 2, 1)
    plt.imshow(dataset[0][0])  # dataset[i] is (image, idx)
    plt.subplot(1, 2, 2)
    plt.imshow(dataset[1][0])

In [12]:
# Image folder of each row: IMG_PATH/<study>/<series>/
paths = [
    IMG_PATH + f"{study}/{series}/"
    for study, series in zip(df['StudyInstanceUID'], df['SeriesInstanceUID'])
]
df['path'] = paths

# Keep one row per distinct (path, study, series, fold) combination:
# these are the folders to run the extraction on.
folder_columns = ['path', 'StudyInstanceUID', 'SeriesInstanceUID', 'fold']
unique_df = df[folder_columns].drop_duplicates()

## Model

In [13]:
# Folder the model checkpoints are listed and loaded from.
CP_PATH = "../logs/weights/"

In [14]:
from model_zoo.models import define_model

In [15]:
from utils.torch_utils import load_model_weights

In [19]:
# Checkpoint files whose name contains "efficientnet", in lexicographic order.
# The position in this list is later used as the fold index of the model.
weights = sorted(f for f in os.listdir(CP_PATH) if "efficientnet" in f)

In [27]:
# Load one model per checkpoint, in the order of `weights`:
# models[k] holds the weights of the k-th checkpoint file.
models = []

for weight in weights:
    model = load_model_weights(
        define_model('efficientnet-b3').cuda(),
        CP_PATH + weight,
    )
    models.append(model)

# NOTE: to extract with another backbone (e.g. 'resnext50_32x4d'), change the
# architecture name above and the file name filter used to build `weights`.

Loaded pretrained weights for efficientnet-b3

 -> Loading weights from ../logs/weights/efficientnet-b3__0.pt

Loaded pretrained weights for efficientnet-b3

 -> Loading weights from ../logs/weights/efficientnet-b3__1.pt

Loaded pretrained weights for efficientnet-b3

 -> Loading weights from ../logs/weights/efficientnet-b3__2.pt

Loaded pretrained weights for efficientnet-b3

 -> Loading weights from ../logs/weights/efficientnet-b3__3.pt

Loaded pretrained weights for efficientnet-b3

 -> Loading weights from ../logs/weights/efficientnet-b3__4.pt



## Extract features

In [28]:
from torch.utils.data import DataLoader

def extract_features(model, dataset, batch_size=4):
    """
    Runs a model over a dataset and collects its predictions and features.

    The model is put in eval mode and applied without gradient tracking. Batches are
    moved to the GPU with ``.cuda()`` before being fed to ``model.extract_ft``.

    Args:
        model: Model exposing ``nb_ft`` (feature size) and ``extract_ft(x)``, which
            returns a ``(y, ft)`` pair.
        dataset: Dataset yielding ``(image, idx)`` items; iterated in order.
        batch_size (int, optional): Batch size of the loader. Defaults to 4.

    Returns:
        tuple of np.ndarray: ``(preds, fts)`` where ``preds`` is the sigmoid of ``y``
        and ``fts`` the features, both concatenated over all batches along axis 0.
    """
    model.eval()

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, drop_last=False
    )

    # Start from empty float64 arrays so the outputs keep the same dtype and shape
    # as before, including for an empty dataset: (0,) and (0, nb_ft).
    pred_chunks = [np.empty(0)]
    ft_chunks = [np.empty((0, model.nb_ft))]

    with torch.no_grad():
        for x, _ in loader:
            y, ft = model.extract_ft(x.cuda())
            ft_chunks.append(ft.detach().cpu().numpy())
            pred_chunks.append(torch.sigmoid(y).detach().cpu().numpy())

    # Concatenate once at the end: concatenating inside the loop copies everything
    # accumulated so far at every batch, which is quadratic in the number of batches.
    return np.concatenate(pred_chunks), np.concatenate(ft_chunks)

In [29]:
# Output folder for the .npy files written by the extraction loop below.
SAVE_PATH = FEATURES_PATH + "b3/"

In [30]:
# Create the output folder up front: np.save does not create missing folders,
# and failing after the first extraction would waste the work.
os.makedirs(SAVE_PATH, exist_ok=True)

for path, study, series, fold in tqdm(unique_df.values):
    dataset = PatientDataset(path, transforms)

    # The model is picked by the fold the series is assigned to as validation
    # (presumably the model that did not train on it - confirm against the training setup).
    preds, features = extract_features(models[int(fold)], dataset, batch_size=32)

    # path ends with "<study>/<series>/", so this gives "<study>_<series>".
    series_tag = '_'.join(path.split('/')[-3:-1])

    np.save(f"{SAVE_PATH}/features_{series_tag}.npy", features)
    np.save(f"{SAVE_PATH}/preds_{series_tag}.npy", preds)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7279.0), HTML(value='')))


