In [1]:
import os
import sys
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
from glob import glob
from multiprocessing import Pool
from torch.optim import optimizer
from collections import OrderedDict
from torch.utils.data import DataLoader
from scipy.optimize import linear_sum_assignment
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.common_utils.dataset_building import split_multi_individuals_datasets
from src.common_utils.prints import get_checkpoint_timestamp, print_log_message
from src.preprocessing.training.utils import extract_annos
from src.recognition.training.configs import e1 as cfg
from src.recognition.training.extract_training_data import extract_preprocessing_json, select_ids
from src.recognition.inference.feature_maker.load_features import make_one_volume_neuronal_features
from src.common_utils.prints import print_info_message
from src.common_utils.metric.rec import top_1_accuracy_score, top_k_accuracy, top_1_accuracy_score_torch
from src.recognition.training.extract_training_data import neurons2data
from src.recognition.training.exps.e1 import *
from src.recognition.inference.dataset import RecFeatureDataset
from src.recognition.inference.network import RecFuseNetworkLinear, RecMarginalCosLossNetwork

In [3]:
# ----- configuration --------------------
# All tunable settings live in one argparse namespace so the notebook mirrors the CLI script.
parser = argparse.ArgumentParser(description = '')
parser.add_argument('--name-reg', type = str, default = r"[iI]ma?ge?_?[sS]t(?:ac)?k_?\d+_dk?\d+.*[wW]\d+_?Dt\d{6}")
parser.add_argument('--load-preprocess-result-root', type = str, default = "data/dataset/proofreading")
parser.add_argument('--data-root', type = str, default = "")
parser.add_argument('--random-seed', type = int, default = 520)
# NOTE: this default is evaluated here, before the RNGs are seeded below, so the run suffix
# differs between runs even with a fixed --random-seed.
parser.add_argument('--checkpoint-timestamp', type = str, default = f"{get_checkpoint_timestamp()}_{np.random.randint(100000)}")
parser.add_argument('--label-root', type = str, default = "data/dataset/label")
# neural points setting
parser.add_argument('--rec-tflag', default = 130, type = int)
parser.add_argument('--rec-fea-mode', default = 0, type = int)
parser.add_argument('--rec-others-class', default = 1, type = int)
parser.add_argument('--rec-xoy-unit', type = float, default = 0.3, help = "um/pixel")
parser.add_argument('--rec-z-unit', type = float, default = 1.5, help = "um/pixel")
parser.add_argument('--rec-worm-diagonal-line', type = float, default = 400.0)
# knn feature
parser.add_argument('--rec-knn-k', type = int, default = 25)
# neural density feature
parser.add_argument('--rec-des-len', type = int, default = 20)
# neuron recognition (train)
parser.add_argument('--rec-fp16', action = "store_true")
parser.add_argument('--rec-epoch', default = 300, type = int)
parser.add_argument('--rec-num-workers', default = 8, type = int)
parser.add_argument('--rec-batch-size', default = 256, type = int)
parser.add_argument('--rec-model-load-path', type = str, default = "")
parser.add_argument('--rec-model-save-path', type = str, default = "models/supp/e1")
parser.add_argument('--rec-shuffle', default = 1, type = int)

# embedding method
parser.add_argument('--rec-channel-base', type = int, default = 32)
parser.add_argument('--rec-group-base', type = int, default = 4)
parser.add_argument('--rec-len-embedding', type = int, default = 56)
parser.add_argument('--rec-hypersphere-radius', type = int, default = 32)
parser.add_argument('--rec-loss-coefficients', type = float, nargs = "+", default = [1.05, 0.0, 0.05])
# tensorboard
parser.add_argument('--rec-tensorboard-root', type = str, default = "tb_log/supp/e1")

# Parse an explicit empty argv: inside Jupyter sys.argv holds the kernel's own flags.
# A list (not a string, which argparse would split into single characters) yields the defaults.
args = parser.parse_args([])
# Ratio of z voxel size to xy pixel size (presumably used downstream to rescale z coordinates).
args.rec_z_scale = args.rec_z_unit / args.rec_xoy_unit

# Seed every RNG the pipeline touches (python, numpy, torch).
random.seed(args.random_seed)
np.random.seed(args.random_seed)
# The trailing ';' suppresses the torch.Generator repr that would otherwise be displayed.
torch.manual_seed(args.random_seed);

<torch._C.Generator at 0x7f344c09a0d0>

In [4]:
# ----- data preparation --------------------
# vols_ccords: {vol_name: [[xmin, ymin], mass_of_center, anterior_y, posterior_y, ventral_x, dorsal_x]}
vols_xymin, vols_ccords = extract_preprocessing_json(os.path.join(args.load_preprocess_result_root, "*/*.json"))


def _merge_individual_annotations(label_files):
    """Merge the annotations of every (file_name, idxes) entry of one animal into a single dict.

    Files are read in order; on a duplicate key the later file's value wins while the key keeps
    its first-seen position.
    """
    merged = dict()
    for file_name, idxes in label_files:
        annos = extract_annos(os.path.join(args.label_root, file_name), idxes, args.name_reg)
        for vol_name, vol_annos in annos.items():
            merged[vol_name] = vol_annos
    return merged


def _to_volume_relative(vol_name, neurons):
    """Shift every annotation entry of one volume by that volume's (xmin, ymin) offset.

    Each entry ``p`` becomes ``[p[1] - xmin, p[2] - ymin, p[3] - xmin, p[4] - ymin, p[0]]``,
    i.e. p[0] is moved to the end (presumably p[1:5] are box corners — confirm against extract_annos).
    """
    shifted = dict()
    for neuron_key, entries in neurons.items():
        rows = []
        for p in entries:
            offset = vols_xymin[vol_name]
            rows.append([p[1] - offset[0], p[2] - offset[1], p[3] - offset[0], p[4] - offset[1], p[0]])
        shifted[neuron_key] = rows
    return shifted


# one dict per animal, in cfg order: {vol_name: {neuron_key: [entries]}}
raw_labels = [_merge_individual_annotations(idv_label) for idv_label in cfg.dataset['animals']['label'].values()]
labels = [{vol_name: _to_volume_relative(vol_name, vol) for vol_name, vol in idv_labels.items()}
          for idv_labels in raw_labels]

In [5]:
# ==========  within dataset  ====================
# Index ranges handed to split_multi_individuals_datasets for the first animal
# (presumably train / val / within-animal test; actual ordering depends on shuffle_type).
t_idx = args.rec_tflag
NUM_VAL_VOLS = 20  # width of the second (validation) index range
val_idx = t_idx + NUM_VAL_VOLS
NUM_SELECTED_IDS = 225  # passed to select_ids; see that helper for its exact meaning — TODO document
processing_ids = select_ids(labels[0], NUM_SELECTED_IDS)[0]
within_labels, dataset_names = split_multi_individuals_datasets(labels = labels[:1], indexes = [[(0, t_idx), (t_idx, val_idx), (val_idx, len(labels[0]))]], shuffle_type = args.rec_shuffle)
# ----- samples making --------------------
within_engineering_feature = make_fea_multiprocess(within_labels, vols_ccords, args, mode = args.rec_fea_mode)
trains, vals, tests, num_ids, id_map, processing_ids = neurons2data(within_engineering_feature.copy(), dataset_names = dataset_names,
                                                                    include_others_class = args.rec_others_class, given_ids = sorted(processing_ids), verbose = False)
# ----- dataloader --------------------
train_dataset = RecFeatureDataset(Xs = trains[0], ys = trains[1], names = trains[2], is_train = True, is_fp16 = args.rec_fp16)
train_dataloader = DataLoader(train_dataset, batch_size = args.rec_batch_size, drop_last = True, shuffle = True, pin_memory = False, num_workers = 1)
val_dataset = RecFeatureDataset(Xs = vals[0], ys = vals[1], names = vals[2], is_train = False, is_fp16 = args.rec_fp16)
val_dataloader = DataLoader(val_dataset, batch_size = args.rec_batch_size, drop_last = False, shuffle = True, pin_memory = False, num_workers = 1)
within_test_dataset = RecFeatureDataset(tests[0], tests[1], tests[2], is_train = False, is_fp16 = args.rec_fp16)
within_test_dataloader = DataLoader(within_test_dataset, batch_size = args.rec_batch_size, drop_last = False, shuffle = True, pin_memory = False, num_workers = 1)

# ==========  across dataset  ====================
# Build the other animals' features with the SAME feature mode as the training data; otherwise
# changing --rec-fea-mode would evaluate across-animal data with features the model never saw.
across_test_infos = [make_across_testset(make_fea_multiprocess(idv_label, vols_ccords, args, mode = args.rec_fea_mode)) for idv_label in labels[1:]]
across_test_datasets = [RecFeatureDataset(info[0], info[1], info[2], is_train = False, is_fp16 = args.rec_fp16) for info in across_test_infos]
across_test_vol_names = [info[3] for info in across_test_infos]
across_test_dataloaders = [DataLoader(testset, batch_size = args.rec_batch_size) for testset in across_test_datasets]

In [6]:
# ----- network --------------------
# input_dim pairs the knn feature size (rec_knn_k * 3) with the density feature size (rec_des_len * 4);
# the per-neighbour / per-bin factors are assumed from the feature builder — TODO confirm.
model = RecFuseNetworkLinear(input_dim = (args.rec_knn_k * 3, args.rec_des_len * 4), output_dim = args.rec_len_embedding, num_ids = num_ids,
                             channel_base = args.rec_channel_base, group_base = args.rec_group_base,
                             dropout_ratio = 0.2, activation_method = "celu").cuda()
model = model.half() if args.rec_fp16 else model
# An empty path means "train from scratch". Any other value must point at an existing checkpoint;
# previously a mistyped path silently fell back to random initialisation.
if args.rec_model_load_path:
    if not os.path.isfile(args.rec_model_load_path):
        raise FileNotFoundError(f"--rec-model-load-path does not point to a file: {args.rec_model_load_path}")
    # torch.load unpickles the checkpoint: only load files from trusted sources.
    model.load_state_dict(torch.load(args.rec_model_load_path, map_location = 'cuda:0')['network'])

criterion = RecMarginalCosLossNetwork(len_embedding = args.rec_len_embedding, coefficients = args.rec_loss_coefficients, hypersphere_radius = args.rec_hypersphere_radius).cuda()
criterion = criterion.half() if args.rec_fp16 else criterion

In [7]:
# ----- training --------------------
# Make sure the checkpoint directory exists before training writes into it.
os.makedirs(args.rec_model_save_path, exist_ok = True)
tb_writer = SummaryWriter(os.path.join(args.rec_tensorboard_root, args.checkpoint_timestamp))
try:
    train_val_procedure(model = model,
                        criterion = criterion,
                        train_dataloader = train_dataloader,
                        val_dataloader = val_dataloader,
                        within_test_dataloader = within_test_dataloader,
                        across_test_dataloaders = across_test_dataloaders,
                        across_test_vol_names = across_test_vol_names,
                        cfg = cfg,
                        test_vols = tests[3],
                        batch_size = args.rec_batch_size,
                        optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3),
                        is_fp16 = args.rec_fp16,
                        num_ids = num_ids,
                        processing_ids = processing_ids,
                        hr = args.rec_hypersphere_radius,
                        num_epochs = args.rec_epoch,
                        checkpoint_timestamp = args.checkpoint_timestamp,
                        tb_writer = tb_writer,
                        model_save_path = os.path.join(args.rec_model_save_path, f"rec_{args.checkpoint_timestamp}.ckpt"),
                        )
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    # SummaryWriter.close() ignores a second close, so this is safe if the procedure already closed it.
    tb_writer.close()

2022-07-29 09:25:50 - [32m[1mINFO   [0m - ------------------------------------------- >>>>>>>>>>>>>>>>>>>>>>>>>>>> Training procedure starts! 

2022-07-29 09:35:29 - [32m[1mINFO   [0m - ------------------------------------------- >>>>>>>>>>>>>>>>>>>>>>>>>>>> Training procedure finished! 
The best val epoch of 2022_07_29_09_25_26_8335 model is 260 	 Top-1 accuracy: 87.68 % 	 
2022-07-29 09:35:30 - [32m[1mINFO   [0m - Within testset: accuracy: 95.48 	 num_hits: 150.67
2022-07-29 09:35:31 - [32m[1mINFO   [0m - Across testset: 2 animals 	 accuracy: 66.70 	 num_gts: 139.19 	 num_hits: 92.84


In [8]:
# ==========  evaluate benchmarks  ====================
from src.benchmarks.test2 import evaluate_benchmark
from src.benchmarks.datasets.CeNDeR import Dataset_CeNDeR

model.eval()
test_batch_size = 32  # keep the same with fDNC
BENCHMARK_ROOT = os.path.join(args.data_root, "data/benchmarks/supp/e1")


def load_benchmark(subdir):
    """Build a ``Dataset_CeNDeR`` and its ``DataLoader`` from every ``*.npy`` file in ``BENCHMARK_ROOT/subdir``.

    Paths are sorted because ``glob`` returns them in arbitrary, filesystem-dependent order, while
    ``refer_idx`` is presumably a position in this list — sorting makes the reference volume reproducible.

    Raises:
        FileNotFoundError: if the directory holds no ``*.npy`` file (e.g. wrong ``--data-root``).
    """
    bench_dir = os.path.join(BENCHMARK_ROOT, subdir)
    paths = sorted(glob(os.path.join(bench_dir, "*.npy")))
    if not paths:
        raise FileNotFoundError(f"No *.npy benchmark volumes found in {bench_dir!r}")
    bench_dataset = Dataset_CeNDeR(paths, is_fp16 = args.rec_fp16, num_pool = args.rec_num_workers)
    return bench_dataset, DataLoader(bench_dataset, args.rec_batch_size)


def evaluate_every_reference(bench_dataset, bench_dataloader):
    """Run ``evaluate_benchmark`` once with each volume as reference; one row of metrics per reference."""
    return np.array([evaluate_benchmark(bench_dataloader, model, refer_idx = ref_idx, test_batch_size = test_batch_size)
                     for ref_idx in range(bench_dataset.num_vols)])


def format_mean_std(title, accs):
    """Format mean ± std over references of columns 0 (accuracy, shown in %), 2 (num_gts) and 4 (num_hits)."""
    return (f"{title}: accuracy: {np.mean(accs[:, 0] * 100):.2f} ± {np.std(accs[:, 0] * 100):.2f} \t num_gts: {np.mean(accs[:, 2]):.2f} ± {np.std(accs[:, 2]):.2f} \t "
            f"num_hits: {np.mean(accs[:, 4]):.2f} ± {np.std(accs[:, 4]):.2f}")


# benchmark leifer 2017 (NeRVE): a single fixed reference volume
dataset, dataloader = load_benchmark("test_tracking")
accs = evaluate_benchmark(dataloader, model, refer_idx = 16, verbose = False, test_batch_size = test_batch_size)
print_info_message(f"NeRVE: accuracy:  {accs[0] * 100:.2f} \t num_gts: {accs[2]:.2f} \t num_hits: {accs[4]:.2f}")

# benchmark NeuroPAL (Yu): every volume serves in turn as the reference
dataset, dataloader = load_benchmark("test_neuropal_our")
accs = evaluate_every_reference(dataset, dataloader)
print_info_message(format_mean_std("NeuroPAL Yu", accs))

# benchmark NeuroPAL Chaudhary: every volume serves in turn as the reference
dataset, dataloader = load_benchmark("test_neuropal_Chaudhary")
accs = evaluate_every_reference(dataset, dataloader)
print_info_message(format_mean_std("NeuroPAL Chaudhary", accs))


2022-07-29 09:35:36 - [32m[1mINFO   [0m - NeRVE: accuracy:  50.08 	 num_gts: 61.41 	 num_hits: 30.76
2022-07-29 09:35:37 - [32m[1mINFO   [0m - NeuroPAL Yu: accuracy: 27.79 ± 1.11 	 num_gts: 38.81 ± 2.87 	 num_hits: 10.82 ± 1.32
2022-07-29 09:35:38 - [32m[1mINFO   [0m - NeuroPAL Chaudhary: accuracy: 42.97 ± 2.35 	 num_gts: 49.62 ± 0.71 	 num_hits: 21.31 ± 0.90
