In [1]:
from fastai.vision.all import *

In [2]:
# Let wide DataFrames display up to 100 columns before pandas truncates them.
pd.set_option("display.max_columns", 100)

In [3]:
# Data root folder and the labels table read from train.csv inside it.
# NOTE(review): a path starting with "/../" is unusual — confirm the intended location.
datapath = Path("/../rsna_data/")
train_csv = datapath/'train.csv'
train_df = pd.read_csv(train_csv)

In [4]:
# Preview the first rows of the labels table.
train_df.head()

Unnamed: 0,StudyInstanceUID,SeriesInstanceUID,SOPInstanceUID,pe_present_on_image,negative_exam_for_pe,qa_motion,qa_contrast,flow_artifact,rv_lv_ratio_gte_1,rv_lv_ratio_lt_1,leftsided_pe,chronic_pe,true_filling_defect_not_pe,rightsided_pe,acute_and_chronic_pe,central_pe,indeterminate
0,6897fa9de148,2bfbb7fd2e8b,c0f3cb036d06,0,0,0,0,0,0,1,1,0,0,1,0,0,0
1,6897fa9de148,2bfbb7fd2e8b,f57ffd3883b6,0,0,0,0,0,0,1,1,0,0,1,0,0,0
2,6897fa9de148,2bfbb7fd2e8b,41220fda34a3,0,0,0,0,0,0,1,1,0,0,1,0,0,0
3,6897fa9de148,2bfbb7fd2e8b,13b685b4b14f,0,0,0,0,0,0,1,1,0,0,1,0,0,0
4,6897fa9de148,2bfbb7fd2e8b,be0b7524ffb4,0,0,0,0,0,0,1,1,0,0,1,0,0,0


In [5]:
do_cv = True
FOLD = 1

# NOTE: later cells use train_pids / valid_pids unconditionally, so the notebook
# only runs end-to-end when do_cv is True.
if do_cv:
    cv_pids_dir = datapath/'cv_pids'
    cv_pids_dir.mkdir(exist_ok=True)
    # Distinct (study, exam-level negative flag) rows from the image-level table.
    cv_df = train_df[['StudyInstanceUID', 'negative_exam_for_pe']].drop_duplicates().reset_index(drop=True)
    all_pids = cv_df['StudyInstanceUID'].values

    # Study IDs held out for validation in this fold.
    # SECURITY: read_pickle unpickles arbitrary objects — only load trusted fold files.
    valid_pids = pd.read_pickle(cv_pids_dir/f'pids_fold{FOLD}.pkl')
    # Every other study is used for training. sorted() keeps the result a list and makes
    # its order reproducible (iteration order of a set of strings varies between kernels).
    train_pids = sorted(set(all_pids).difference(valid_pids))

In [6]:
# Sanity check: number of train studies, valid studies, and both combined.
len(train_pids), len(valid_pids), len(train_pids+valid_pids)

(5823, 1456, 7279)

In [7]:
metadata_path = datapath/'metadata'
# List the directory once and test membership against sets instead of
# re-listing it and scanning the ID lists for every file.
metadata_files = metadata_path.ls()
train_pid_set, valid_pid_set = set(train_pids), set(valid_pids)
# A metadata file belongs to a split when its stem is one of that split's study IDs.
train_metadata_paths = [o for o in metadata_files if o.stem in train_pid_set]
valid_metadata_paths = [o for o in metadata_files if o.stem in valid_pid_set]

In [8]:
# Number of metadata files matched to the train and valid splits.
len(train_metadata_paths), len(valid_metadata_paths)

(5823, 1456)

### Get files

In [9]:
# Map each image's SOPInstanceUID to its image-level `pe_present_on_image` label.
image_ids = train_df['SOPInstanceUID']
image_labels = train_df['pe_present_on_image']
labels_dict = dict(zip(image_ids, image_labels))

In [10]:
# If both numbers match, no SOPInstanceUID repeats in train_df (no dict key was overwritten).
len(labels_dict), len(train_df)

(1790594, 1790594)

In [11]:
# Folder holding the image files used below.
imgdatapath = datapath / 'full_raw_512'

In [12]:
# Collect image file paths under imgdatapath.
# Presumably searched recursively (files are later grouped by their parent folder name) — confirm.
files = get_image_files(imgdatapath)

In [13]:
# Split the image files by study: a file goes to a split when its parent folder
# name is one of that split's study IDs. Membership is tested against sets so the
# pass over all files stays linear instead of scanning an ID list per file.
train_pid_set, valid_pid_set = set(train_pids), set(valid_pids)
train_files = [o for o in files if o.parent.name in train_pid_set]
valid_files = [o for o in files if o.parent.name in valid_pid_set]

In [14]:
# Number of image files in the train and valid splits.
len(train_files), len(valid_files)

(1431401, 359193)

In [15]:
# Rebind `files` to the train files followed by the valid files
# (this overwrites the listing returned by get_image_files).
files = train_files + valid_files

In [16]:
# Total number of image files across both splits.
len(files)

1790594

### Load Model

In [17]:
def get_label(o):
    "Return the `labels_dict` entry for image file `o`"
    # The key is the second "_"-separated piece of the file stem.
    image_key = o.stem.split("_")[1]
    return labels_dict[image_key]

In [18]:
# Size tag used below in the exported-model filename and the embeddings sub-directory name.
resize = 512

In [19]:
# Load the exported learner for this size and fold; cpu=False so it is not forced onto the CPU.
# NOTE(review): load_learner unpickles the file — only load trusted exports. Presumably the
# export refers to `get_label` by name, which is why it is defined before this cell — confirm.
cnn_learner = load_learner(f"./models/xresnet34-{resize}-PR-fold{FOLD}-export.pkl", cpu=False)

In [20]:
# Labelled test DataLoader over ALL images: train files first, then valid files.
# NOTE: despite its name, `valid_dl` is not limited to the validation split.
valid_dl = cnn_learner.dls.test_dl(train_files+valid_files, with_labels=True)

### Generate Embeddings

In [21]:
# Root folder for saved CNN embeddings; created if missing (no check-then-create race).
embs_dir = datapath/"cnn_embs"
embs_dir.mkdir(parents=True, exist_ok=True)

In [22]:
# Number of items in the dataloader's dataset.
len(valid_dl.dataset)

1790594

In [23]:
class EmbeddingHook:
    """Collect a module's forward outputs and flush them to disk in chunks.

    A fastai `Hook` is registered on `m`; on each forward pass the module's output
    is handed to `hook_fn`, which buffers it. Once more than `csz` rows are buffered
    they are written to `embs_dir/subdir/train_embs-{i}.pkl` and the buffer is emptied.

    Batches are kept in a list and only concatenated when `embeddings` is read, so
    collecting N batches does not re-copy the whole buffer on every batch.

    Args:
        m: module to hook. Any forward hooks already registered on it are removed.
        subdir: sub-directory of the global `embs_dir` that receives the chunk files.
        csz: row threshold; a chunk is saved once the buffer holds more rows than this.
        n_init: index used for the first saved chunk file.

    Attributes:
        embeddings: tensor with every row buffered since the last flush.
        save_iter: index the next saved chunk file will use.
    """
    def __init__(self, m, subdir, csz=5000000, n_init=0):
        store_attr("m,subdir,csz")
        self.embeddings = tensor([])
        if len(m._forward_hooks) > 0: self.reset()
        self.hook = Hook(m, self.hook_fn, cpu=True)
        self.save_iter = n_init
        (embs_dir/self.subdir).mkdir(parents=True, exist_ok=True)

    @property
    def embeddings(self):
        "Rows buffered since the last flush, merged into a single tensor"
        if self._pending:
            # Merge lazily; the merged tensor is cached so repeated reads are cheap.
            self._merged = torch.cat([self._merged, *self._pending])
            self._pending = []
        return self._merged

    @embeddings.setter
    def embeddings(self, value):
        # Replace the buffer wholesale and restart the row count from `value`.
        self._merged, self._pending, self._n_rows = value, [], value.shape[0]

    def hook_fn(self, m, inp, out):
        "Buffer this batch's output; flush to disk once more than `csz` rows are held"
        # Copy so that later in-place ops on the activation cannot alter buffered rows.
        self._pending.append(out.clone())
        self._n_rows += out.shape[0]
        if self._n_rows > self.csz:
            self.save()
            self.embeddings = tensor([])

    def reset(self):
        "Remove every forward hook currently registered on the hooked module"
        self.m._forward_hooks = OrderedDict()

    def save(self):
        "Write the buffered rows to the next numbered chunk file and advance the index"
        torch.save(self.embeddings, embs_dir/self.subdir/f"train_embs-{self.save_iter}.pkl")
        self.save_iter += 1

In [24]:
# Attach the embedding hook to the model sub-module at index [1][1]; chunk files
# go to a sub-directory named after the size tag and fold.
hooked_module = cnn_learner.model[1][1]
chunk_subdir = f"full_sz{resize}_FOLD{FOLD}"
emb_hook = EmbeddingHook(hooked_module, subdir=chunk_subdir, n_init=0)

In [25]:
# Inspect the forward hooks registered on the hooked module and the current embedding buffer.
cnn_learner.model[1][1]._forward_hooks, emb_hook.embeddings

(OrderedDict([(0,
               <bound method Hook.hook_fn of <fastai.callback.hook.Hook object at 0x7f53c74315d0>>)]),
 tensor([]))

In [None]:
# Run inference over every item in valid_dl; each forward pass of the hooked module
# fires emb_hook, which accumulates that module's outputs.
# act=noop is passed as the activation — presumably so raw model outputs are returned; confirm.
preds, targs = cnn_learner.get_preds(dl=valid_dl, act=noop)

### Save

In [None]:
# Persist the embeddings still held by the hook, plus the file list in the same
# order that was passed to the dataloader.
# NOTE: `emb_hook.embeddings` only holds rows gathered since the hook's last chunk
# flush; any earlier chunks were already written as train_embs-*.pkl by the hook.
out_dir = embs_dir/f"full_sz{resize}_FOLD{FOLD}"
torch.save(emb_hook.embeddings, out_dir/"embeddings.pkl")
pd.to_pickle(train_files+valid_files, out_dir/"files.pkl")