In [16]:
import os
import torch
import numpy as np
import pandas as pd
from CellClass import CNN 
import matplotlib.pyplot as plt
import CellClass.CNN.dataset as D
import CellClass.CNN.training as T
from torch.utils.data import DataLoader

In [17]:
save_dir = "/home/simon_g/src/MICCAI/trained_models"

# Every sub-directory of `save_dir` is one training setup holding `.pt` checkpoint files.
# Ignore stray files at the top level (listing them as directories would crash), and sort
# because os.listdir order is arbitrary and we want a reproducible model order.
setups = [s for s in os.listdir(save_dir) if os.path.isdir(os.path.join(save_dir, s))]
setups.sort()

models = []
for setup in setups:
    setup_dir = os.path.join(save_dir, setup)
    # endswith, not `in`: a substring test would also pick up e.g. `.pth` / `.pt.bak` files.
    checkpoint_names = [name for name in os.listdir(setup_dir) if name.endswith(".pt")]
    checkpoint_names.sort()
    models.extend(os.path.join(setup_dir, name) for name in checkpoint_names)

# Only the stored validation loss is needed here. map_location="cpu" lets checkpoints that
# were saved on a GPU load on any machine without allocating GPU memory.
# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
losses = []
for model_path in models:
    ckpt = torch.load(model_path, map_location="cpu")
    losses.append({"model": model_path, "loss": ckpt["validation_loss"]})

In [18]:
# One row per checkpoint. The explicit columns keep the schema stable even when `losses` is
# empty, and the assert fails early with a clear message instead of a KeyError/IndexError
# in the ranking cell.
model_df = pd.DataFrame(losses, columns=["model", "loss"])
assert not model_df.empty, "no .pt checkpoints found under save_dir"

In [19]:
# Rank checkpoints by validation loss (ascending); the first row is the best model.
# Deliberately not named `sorted`, which would shadow the built-in for the whole session.
model_df_sorted = model_df.sort_values("loss")
best = model_df_sorted.iloc[0]
print(best)

model    /home/simon_g/src/MICCAI/trained_models/202206...
loss                                                0.0002
Name: 3, dtype: object


In [20]:
# The best checkpoint lives in <save_dir>/<setup>/<file>.pt; the setup's training
# parameters are stored next to it as one "key:value" pair per line.
best_setup = os.path.basename(os.path.dirname(best.model))

params = {}
with open(os.path.join(save_dir, best_setup, "training_parameters.txt"), "r") as fin:
    for raw_line in fin:
        entry = raw_line.strip()
        if not entry:
            continue  # tolerate blank / trailing lines instead of producing ragged rows
        # Split on the first ':' only, so values that themselves contain ':' stay intact.
        key, sep, value = entry.partition(":")
        params[key.strip()] = value.strip()

# One-row frame (index 0, one string-valued column per parameter).
setup_df = pd.DataFrame([params])
print(setup_df.down_steps)

0     2
Name: down_steps, dtype: object


In [21]:
# Load into a separate name: rebinding `best` (the ranking row) to the checkpoint dict made
# this cell fail with AttributeError when re-run.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(best.model, map_location="cpu")

model = CNN.ClassificationCNN(int(setup_df.down_steps.item()))
# Inference only: put layers such as dropout / batch-norm (if the network has any) into
# evaluation behaviour.
model.eval()
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [22]:
patches_dir = "/data_isilon_main/isilon_images/10_MetaSystems/MetaSystemsData/MYCN_SpikeIn/results/patches"

# A patch's series ID is the part of its filename before the first underscore.
series_ids = [fname.split("_")[0] for fname in os.listdir(patches_dir)]
# np.unique both de-duplicates and sorts the IDs.
dilutions = np.unique(series_ids)
print(dilutions)

['S11' 'S12' 'S19' 'S1b' 'S2' 'S20' 'S29' 'S3' 'S30' 'S31' 'S32' 'S33'
 'S34' 'S4' 'S6' 'S7' 'S8']


In [23]:
# Group patch files by series. List the directory once instead of once per series, and
# match on the filename prefix -- the same rule used to derive `dilutions` -- rather than a
# substring that could also hit "<series>_" appearing later in an unrelated filename.
patch_files = os.listdir(patches_dir)

dils = []
for dilution in dilutions:
    files = [name for name in patch_files if name.startswith(f"{dilution}_")]
    dils.append({"series": dilution, "files": files})
    print(dilution, len(files))

S11 113
S12 123
S19 597
S1b 40
S2 45
S20 116
S29 597
S3 37
S30 101
S31 20
S32 116
S33 94
S34 96
S4 28
S6 30
S7 36
S8 43


In [24]:
BATCH_SIZE = 128
N_LOAD = 10            # forwarded as `n` to D.load_dilution (semantics defined there)
SHOW_EXAMPLES = False  # True -> plot a 5x5 grid of predicted patches for every series

# Pure inference: disable autograd so no computation graph is kept in memory.
# (Harmless if T.predict_dilution already does this internally -- TODO confirm.)
with torch.no_grad():
    for dil in dilutions:
        dilution = D.load_dilution(patches_dir, f"{dil}_", n=N_LOAD, verbose=False)
        dilution_loader = DataLoader(dilution, batch_size=BATCH_SIZE)
        ims, labels, percentage = T.predict_dilution(model, dilution_loader)

        if SHOW_EXAMPLES:
            fig, axs = plt.subplots(5, 5, figsize=(10, 10))
            for ax, im, label in zip(axs.ravel(), ims[:25], labels[:25]):
                ax.imshow(im)
                ax.set_title(label)
                ax.set_xticks([])
                ax.set_yticks([])
            plt.show()
            plt.close(fig)  # one figure per series; don't let them accumulate in memory

        print(f"{dil} [{len(dilution)} Patches] with predicted dilution of: {percentage}")
    

S11 [3607 Patches] with predicted dilution of: 90.49
S12 [1975 Patches] with predicted dilution of: 1.06
S19 [1245 Patches] with predicted dilution of: 99.04
S1b [2551 Patches] with predicted dilution of: 80.24


KeyboardInterrupt: 