In [1]:
# import torch

# a = torch.randn((2,41,96,96))
# b = torch.randn((2,250,96,96))
# #
# c = torch.square(a-b)


In [2]:
# official modules
import argparse
import json
import os
import time
import numpy as np
from math import floor
from shutil import copy2
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
# self-defined module
from utils.helper import init_DDP, Logger, print_log, load_labels, build_model
from utils.data import dataloader
from utils.train_model import train_model
from utils.test_model import test_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
    parser = argparse.ArgumentParser(description='3d localization')
    # phase
    parser.add_argument('--train_or_test', type=str, default='other', help='train or test')
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--gpu_number', type=str, default=None, help='assign gpu')
    # data info
    parser.add_argument('--num_im', type=int, default=None, help='Number of samples used, train:val=9:1')
    parser.add_argument('--H', type=int, default=96, help='Height of image')
    parser.add_argument('--W', type=int, default=96, help='Width of image')
    parser.add_argument('--zmin', type=int, default=-20, help='min zeta')
    parser.add_argument('--zmax', type=int, default=20, help='max zeta')
    parser.add_argument('--clear_dist', type=int, default=1, help='safe margin for z axis')
    parser.add_argument('--D', type=int, default=250, help='num grid of zeta axis')
    parser.add_argument('--scaling_factor', type=int, default=800, help='entry value for existence of pts')
    parser.add_argument('--upsampling_factor', type=int, default=2, help='grid dim=[H,W]*upsampling_factor')
    # train info
    parser.add_argument('--model_use', type=str, default='LocNet')
    parser.add_argument('--postpro',  action='store_true', default=False, help='whether do post processing in dnn')
    parser.add_argument('--batch_size', type=int, default=1, help='when training on multi GPU, is the batch size on each GPU')
    parser.add_argument('--initial_learning_rate', type=float, default=None, help='initial learning rate for adam')
    parser.add_argument('--lr_decay_per_epoch', type=int, default=None, help='number of epochs learning rate decay')
    parser.add_argument('--lr_decay_factor', type=float, default=None, help='lr decay factor')
    parser.add_argument('--max_epoch', type=int,   default=None, help='number of training epoches')
    parser.add_argument('--save_epoch', type=int, default=None, help='save model per save_epoch')
    # test info
    parser.add_argument('--test_id_loc', type=str, default=None)
    # path
    parser.add_argument('--checkpoint_path', type=str,  default=None,  help='checkpoint to resume from')
    parser.add_argument('--data_path', type=str, default='/home/lingjia/Documents/microscope/Data/training_images_zrange20', help='path for train and val data')
    parser.add_argument('--save_path', type=str, default=None, help='path for save models and results')
    # output
    parser.add_argument('--name_time', type=str, default=None, help='string of running time')
    # for nonconvex loss
    parser.add_argument('--port', type=str, default=None, help='DDP master port')
    parser.add_argument('--weight', type=str, default=None, help='lambda CEL0')
    parser.add_argument('--extra_loss', type=str, default=None, help='indicate whether use cel0 for gaussian or nc for possion')
    # for extra losses
    parser.add_argument('--cel0_mu', type=float, default=None, help='mu in cel0 loss')
    parser.add_argument('--klnc_a', type=float, default=None, help='a for nonconvex loss in KLNC')
    parser.add_argument('--log_comment', type=str, default=None)
    
    # opt = parser.parse_args()
    opt,_=parser.parse_known_args()

In [4]:
opt.gpu_number='1'
opt.train_or_test='train'
opt.name_time="2022-99-99-99-99-99"
opt.num_im=10000
opt.H=96
opt.W=96
opt.zmin=-20
opt.zmax=20
opt.clear_dist=1
opt.D=250
opt.scaling_factor=800
opt.upsampling_factor=2
opt.model_use='LocNet'
opt.batch_size=16
opt.initial_learning_rate=1e-3
opt.lr_decay_per_epoch=7
opt.lr_decay_factor=0.5
opt.max_epoch=2
opt.save_epoch=10
opt.data_path='/media/hdd/lingjia/hdd_rpsf/20220917_nonconvex_loss/data/gaussian_10k_pt50L5'
opt.save_path='./temp'
opt.port='123789'
opt.weight='1_1_1_1'
opt.extra_loss='mse3d_cel0_klnc_forward'
opt.cel0_mu=1
opt.klnc_a=10
opt.rank = 0
opt.world_size = 1

rank = opt.rank

In [5]:
# calculate zoom ratio of z-axis
opt.pixel_size_axial = (opt.zmax - opt.zmin + 1 + 2*opt.clear_dist) / opt.D

# split dataset to train, validation 9:1
train_IDs = np.arange(1,floor(opt.num_im*0.9)+1,1).tolist()
val_IDs = np.arange(floor(opt.num_im*0.9)+1,opt.num_im+1).tolist()

opt.partition = {'train': train_IDs, 'valid': val_IDs}
opt.ntrain, opt.nval = len(train_IDs), len(val_IDs)

# output folder name
name_time = opt.name_time if opt.name_time else time.strftime('%Y-%m-%d-%H-%M-%S')
save_name = name_time + '-lr'+str(opt.initial_learning_rate) + \
    '-bs'+str(opt.batch_size) + \
    '-D'+str(opt.D) + \
    '-Ep'+str(opt.max_epoch) + \
    '-nT'+str(opt.ntrain)
if opt.extra_loss:
    save_name = save_name + '-w' + str(opt.weight) + '-' + str(opt.extra_loss)
if opt.cel0_mu:
    save_name = save_name + '-' + str(opt.cel0_mu)
if opt.klnc_a:
    save_name = save_name + '-' + str(opt.klnc_a)
save_name = save_name + '-' + str(opt.model_use)

if opt.resume:
    save_name = save_name + '-resume'
if opt.postpro:
    save_name = save_name + '-postpro'
opt.save_path = os.path.join(opt.save_path,save_name)
os.makedirs(opt.save_path, exist_ok=True)

if rank == 0:
    log = open(os.path.join(opt.save_path, '{}_log.txt'.format(time.strftime('%H-%M-%S'))), 'w')
    logger = Logger(os.path.join(opt.save_path, '{}_tensorboard'.format(time.strftime('%H-%M-%S'))))

    print_log('[INFO==>] setup_params:',log)
    for key,value in opt._get_kwargs():
        if not key == 'partition':
            print_log('{}: {}'.format(key,value),log)
    print_log(f'[INFO==>] Dataset: Train {len(train_IDs)} Val {len(val_IDs)}',log)

device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

if opt.rank==0:
    # save setup parameters in result folder as well
    with open(os.path.join(opt.save_path,'setup_params.json'),'w') as handle:
        json.dump(opt.__dict__, handle, indent=2)

# Load labels and generate dataset
labels = load_labels(os.path.join(opt.data_path,'label.txt'))

# Parameters for dataloaders
params_train = {'batch_size': opt.batch_size, 'shuffle': True,  'partition': opt.partition['train']}
params_val = {'batch_size': opt.batch_size, 'shuffle': False, 'partition': opt.partition['valid']}

training_generator = dataloader(opt.data_path, labels, params_train, opt, num_workers=0)
validation_generator = dataloader(opt.data_path, labels, params_val, opt, num_workers=0)

# model
model = build_model(opt)
model.to(device)

[INFO==>] setup_params:
D: 250
H: 96
W: 96
batch_size: 16
cel0_mu: 1
checkpoint_path: None
clear_dist: 1
data_path: /media/hdd/lingjia/hdd_rpsf/20220917_nonconvex_loss/data/gaussian_10k_pt50L5
extra_loss: mse3d_cel0_klnc_forward
gpu_number: 1
initial_learning_rate: 0.001
klnc_a: 10
log_comment: None
lr_decay_factor: 0.5
lr_decay_per_epoch: 7
max_epoch: 2
model_use: LocNet
name_time: 2022-99-99-99-99-99
ntrain: 9000
num_im: 10000
nval: 1000
pixel_size_axial: 0.172
port: 123789
postpro: False
rank: 0
resume: False
save_epoch: 10
save_path: ./temp/2022-99-99-99-99-99-lr0.001-bs16-D250-Ep2-nT9000-w1_1_1_1-mse3d_cel0_klnc_forward-1-10-LocNet
scaling_factor: 800
test_id_loc: None
train_or_test: train
upsampling_factor: 2
weight: 1_1_1_1
world_size: 1
zmax: 20
zmin: -20
[INFO==>] Dataset: Train 9000 Val 1000


ResLocalizationCNN(
  (norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Conv2DLeakyReLUBN(
    (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer2): ResConv2DLeakyReLUBN(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer3): ResConv2DLeakyReLUBN(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer4): ResConv2DLeakyReLUBN(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(

In [6]:

optimizer = Adam(list(model.parameters()), lr=opt.initial_learning_rate)

# opt.scheduler_type = 'ReduceLROnPlateau'
# if opt.scheduler_type == 'StepLR':
# scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_per_epoch, gamma=opt.lr_decay_factor)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=opt.lr_decay_factor, patience=opt.lr_decay_per_epoch)

if opt.rank == 0:
    # print_log(model, log)
    print_log("[INFO==>] Number of parameters: {}".format(sum(param.numel() for param in model.parameters())),log)

[INFO==>] Number of parameters: 1594162


In [7]:
# train_model(model,optimizer,scheduler,device,training_generator,validation_generator,log,logger,opt)
# official modules
import os
import time
from time import localtime, strftime
from math import ceil
import pickle5 as pickle
import numpy as np
from collections import defaultdict
import torch
from torch.cuda.amp import autocast, GradScaler
import torch.distributed as dist
# self-defined modules
from utils.loss import calculate_loss
from utils.helper import print_log, print_metrics, print_time, print_metric_format


In [8]:
learning_results = defaultdict(list)
max_epoch = opt.max_epoch

steps_train = ceil(opt.ntrain / opt.batch_size / opt.world_size)
steps_val = ceil(opt.nval / opt.batch_size / opt.world_size)
params_val = {'batch_size': opt.batch_size, 'shuffle': False}

# loss function
loss_type = ['loss']
extra_weight = []
if opt.extra_loss: # None or string
    extra_loss = opt.extra_loss.split('_')
    extra_weight = [float(n) for n in opt.weight.split('_')]
    loss_type = loss_type + extra_loss
    if not len(extra_loss) == len(extra_weight):
        raise Exception(f'Input {len(extra_loss)} weight with {len(extra_weight)} extra loss')
print_log(f'[INFO==>] Loss types: {loss_type}',log)
calc_loss = calculate_loss(opt,loss_type,extra_weight)

scaler = GradScaler()

[INFO==>] Loss types: ['loss', 'mse3d', 'cel0', 'klnc', 'forward']
(6, 2, 2)


In [9]:
# start from scratch
start_epoch, end_epoch = 0, max_epoch
learning_results = {'val_max': [], 'val_sum': [], 'steps_per_epoch': steps_train}
for loss in loss_type:
    learning_results['train_'+loss] = []
    learning_results['val_'+loss] = []
best_val_loss = float('Inf')

# starting time of training
train_start_time = time.time()
not_improve = 0

print_log(f'[INFO==>] Start training from {start_epoch} to {end_epoch} rank {opt.rank}\n',log)

epoch = 0
# starting time of current epoch
epoch_start_time = time.time()

if opt.rank == 0:
    print_log(f'Epoch {epoch+1}/{end_epoch} | {strftime("%Y-%m-%d %H:%M:%S", localtime())} | lr {optimizer.param_groups[0]["lr"]}', log, arrow=True)

# training phase
model.train()
metric, metrics = defaultdict(float), defaultdict(float)

[INFO==>] Start training from 0 to 2 rank 0

---> Epoch 1/2 | 2022-12-24 15:25:18 | lr 0.001


In [10]:
with torch.set_grad_enabled(True):
    batch_ind = 0
    (inputs, targets, target_ims, fileids) = next(iter(training_generator))
    print(f"{inputs.shape} \n{targets.shape} \n{target_ims.shape} \n{fileids}")
    # for batch_ind, (inputs, targets, target_ims, fileids) in enumerate(training_generator):

    inputs = inputs.to(device)
    targets = targets.to(device)
    target_ims = target_ims.to(device)

    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
    print(f"Output shape:\n{outputs.shape}")
    loss = calc_loss(upgrid=outputs.float(), gt_upgrid=targets, gt_im=target_ims, metric=metric, metrics=metrics)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if opt.rank==0:
        print(f'Epoch {epoch}/{end_epoch-1} Train {batch_ind}/{steps_train-1} MaxOut {outputs.max():.2f} {print_metric_format(metric)}')



torch.Size([16, 1, 96, 96]) 
torch.Size([16, 250, 192, 192]) 
torch.Size([16, 96, 96]) 
['45', '8188', '4690', '5863', '7076', '1660', '7640', '8977', '7346', '559', '4862', '5116', '5005', '3935', '6155', '6027']
Output shape:
torch.Size([16, 250, 192, 192])
upgrid:
torch.Size([16, 250, 192, 192])
spikes_pred:
torch.Size([16, 250, 96, 96])
pool info: (2, 2)
spikes_pred:
torch.Size([16, 250, 96, 96])
norm_ai2:
torch.Size([250, 96, 96])
abs_heat:
torch.Size([16, 250, 96, 96])
Epoch 0/1 Train 0/562 MaxOut 31.69 mse3d 55.7752  cel0 nan  klnc 4513019.0000  forward 752305984.0000  loss nan  


In [7]:
import numpy as np
a = np.nan
print(a)
print(a*0)
print(18*1+a*0)

nan
nan
nan
