In [33]:
import sys
sys.path.append('..')

import json
import torch
from pathlib import Path
import numpy as np
from sklearn.metrics import confusion_matrix

#import pytorch_lightning as pl
from multiprocessing import cpu_count
from libraries.lightningDMBACNN import ImageFolderLightningDataModule, WhoiDataModule, ZooscanDataModule, LenslessDataModule, ZooLakeDataModule#from libraries.helper_functions import class_histogram, sample_images, imshow
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import seaborn as sns
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall, MulticlassPrecision

#torch.cuda.empty_cache()


# Experiment whose saved results are post-processed in this notebook.
# Earlier experiment names: 'test_hierar_Zooscan_v4', 'test_hierar_Zooscan_v5_storage'
NAME_OF_EXPERIMENT = 'test_hierar_Zooscan_v6_storage'

# All paths are resolved relative to the directory the notebook is launched from.
WORKING_DIR = Path.cwd()
DATA_DIR = WORKING_DIR.parent / 'data/Zooscan_hierarchical_imagefolder'

EXPERIMENT_DIR = WORKING_DIR / f'experiments_folder/{NAME_OF_EXPERIMENT}'
MODELS_DIR, RESULTS_DIR = EXPERIMENT_DIR / 'models', EXPERIMENT_DIR / 'results'

MEAN_STD_PATH = WORKING_DIR / 'libraries/saved_mean_std/Zooscan_mean_std_224'


In [44]:
# Class counts used by the post-processing below.
# NOTE(review): NUM_CLASSES[0] looks like the per-level totals of the 3-level view and
# NUM_CLASSES[1] the sizes of the 5 heads ordered [binary, 1_1, 2_1, 1_2, 2_2]
# (4 + 9 == 13, 12 + 81 == 93) -- confirm against the training code.
NUM_CLASSES = [
    [2, 13, 93],
    [2, 4, 9, 12, 81],
]

In [45]:
# Build the image-folder datamodule for the hierarchical Zooscan dataset.
datamodule = ImageFolderLightningDataModule(
    data_dir=DATA_DIR,
    mean_std_path=MEAN_STD_PATH,
    image_size=224,
    batch_size=8,
    split_seed=42,
    num_workers=cpu_count(),
    sampler=True,      # Use True for Weighted Sampler or False to not use any
    pin_memory=True,   # Default is False
)
datamodule.setup()

# NOTE(review): not used in the cells shown here (NUM_CLASSES is hard-coded instead);
# confirm the two agree.
num_classes = datamodule.num_classes

In [46]:
# Peek at the first five class names; they appear to encode hierarchy levels separated by '___'.
datamodule.dataset.classes[0:5]

['IRRELEVANT___ARTEFACT___artefact',
 'IRRELEVANT___ARTEFACT___badfocus__Copepoda',
 'IRRELEVANT___ARTEFACT___badfocus__artefact',
 'IRRELEVANT___ARTEFACT___bubble',
 'IRRELEVANT___DETRITUS___detritus']

# Transform the 5 prediction tensors into 3 per-level results

In [52]:
def calculate_final_distrib(c_pred_binary, c_pred_1_1, c_pred_1_2, c_pred_2_1, c_pred_2_2, NUM_CLASSES):
    """Collapse the five per-head prediction tensors into three per-level class indices.

    The binary head chooses the branch. Softmax is applied to ``c_pred_binary`` along its
    last dimension. A sample goes to branch 1 when the probability in column 0 is
    >= 0.5; otherwise it goes to branch 2.

    Args:
        c_pred_binary: 2-D scores of the binary head (presumably logits of shape (N, 2)
            -- TODO confirm).
        c_pred_1_1, c_pred_1_2: 2-D scores of the branch-1 heads used for level 2 and
            level 3 respectively; each is reduced with argmax over dim 1.
        c_pred_2_1, c_pred_2_2: 2-D scores of the branch-2 heads used for level 2 and
            level 3 respectively; each is reduced with argmax over dim 1.
        NUM_CLASSES: nested list. ``NUM_CLASSES[1][1]`` is added to branch-2 level-2
            indices and ``NUM_CLASSES[1][3]`` to branch-2 level-3 indices, so both
            branches share one label space. ``NUM_CLASSES[0]`` is not used.

    Returns:
        Tuple ``(c_pred_1, c_pred_2, c_pred_3)`` of 1-D int64 tensors of length N:
        the binary prediction (argmax of the softmax) and the combined level-2 and
        level-3 predictions.
    """
    # Offsets that shift branch-2 indices past the branch-1 classes.
    # NOTE(review): assumes NUM_CLASSES[1] is ordered [binary, 1_1, 2_1, 1_2, 2_2]
    # (with the notebook's values, 4 + 9 == 13 and 12 + 81 == 93) -- confirm.
    level_2_offset = NUM_CLASSES[1][1]
    level_3_offset = NUM_CLASSES[1][3]

    binary_probs = torch.softmax(c_pred_binary.float(), dim=-1)
    # With two classes, ">= 0.5 on column 0" agrees with argmax (ties go to index 0).
    in_branch_1 = binary_probs[:, 0] >= 0.5

    # Vectorized branch selection; argmax yields int64, so all outputs share that dtype.
    c_pred_2 = torch.where(
        in_branch_1,
        torch.argmax(c_pred_1_1.float(), dim=1),
        torch.argmax(c_pred_2_1.float(), dim=1) + level_2_offset,
    )
    c_pred_3 = torch.where(
        in_branch_1,
        torch.argmax(c_pred_1_2.float(), dim=1),
        torch.argmax(c_pred_2_2.float(), dim=1) + level_3_offset,
    )
    c_pred_1 = torch.argmax(binary_probs, dim=1)

    return c_pred_1, c_pred_2, c_pred_3

In [53]:
def update_true_labels(c_true, NUM_CLASSES):
    """Map 5-column ground-truth labels onto the 3-level label space used for predictions.

    Args:
        c_true: 2-D tensor/array of labels whose columns are read as
            [binary, 1_1, 1_2, 2_1, 2_2]. Rows whose binary label is <= 0.5 take
            columns 1 and 2. Other rows take columns 3 and 4, shifted by the offsets below.
        NUM_CLASSES: nested list. ``NUM_CLASSES[1][1]`` and ``NUM_CLASSES[1][3]`` are
            added to the branch-2 level-2 and level-3 labels respectively.

    Returns:
        int64 tensor of shape (N, 3) holding [binary, level-2, level-3] per row;
        shape (0, 3) for empty input.
    """
    labels = torch.as_tensor(c_true)
    binary = labels[:, 0]
    in_branch_1 = binary <= 0.5

    level_2 = torch.where(in_branch_1, labels[:, 1], labels[:, 3] + NUM_CLASSES[1][1])
    level_3 = torch.where(in_branch_1, labels[:, 2], labels[:, 4] + NUM_CLASSES[1][3])

    # Labels are class indices, so keep them integral (int64), like argmax outputs.
    return torch.stack((binary, level_2, level_3), dim=1).long()

In [54]:
# Raw outputs of the five heads, loaded onto the CPU.
c_pred_binary, c_pred_1_1, c_pred_1_2, c_pred_2_1, c_pred_2_2 = (
    torch.load(RESULTS_DIR / f'{head}.pt', map_location=torch.device('cpu'))
    for head in ('c_pred_binary', 'c_pred_1_1', 'c_pred_1_2', 'c_pred_2_1', 'c_pred_2_2')
)

# A later cell overwrites c_true_test.pt with the 3-column labels. On a re-run, that file
# must not be treated as the raw 5-column labels; otherwise the c_true_5.pt backup would be
# clobbered and update_true_labels would fail on the missing columns.
c_true = torch.load(RESULTS_DIR / 'c_true_test.pt', map_location=torch.device('cpu'))
if c_true.shape[1] == len(NUM_CLASSES[1]):
    # Fresh 5-column labels: keep a backup before they are overwritten.
    torch.save(c_true, RESULTS_DIR / 'c_true_5.pt')
else:
    # Already converted by a previous run: recover the raw labels from the backup.
    c_true = torch.load(RESULTS_DIR / 'c_true_5.pt', map_location=torch.device('cpu'))

In [55]:
# Collapse the five head outputs into one prediction per hierarchy level ...
c_pred_1, c_pred_2, c_pred_3 = calculate_final_distrib(
    c_pred_binary,
    c_pred_1_1,
    c_pred_1_2,
    c_pred_2_1,
    c_pred_2_2,
    NUM_CLASSES,
)

# ... and express the ground truth in the same 3-level label space.
c_true_test = update_true_labels(c_true, NUM_CLASSES)

In [56]:
# Persist the per-level predictions and the converted labels.
# NOTE: c_true_test.pt is overwritten in place with the 3-column labels; the original
# 5-column labels are saved to c_true_5.pt by the loading cell.
for tensor_to_save, file_name in (
    (c_pred_1, 'c_pred_1_test.pt'),
    (c_pred_2, 'c_pred_2_test.pt'),
    (c_pred_3, 'c_pred_3_test.pt'),
    (c_true_test, 'c_true_test.pt'),
):
    torch.save(tensor_to_save, RESULTS_DIR / file_name)