## Handling command-line parameters with an argparse.ArgumentParser object

In [None]:
import argparse
def get_args(argv=None):
    """Parse the command-line options and return them as a namespace.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``exp`` (str or None) and
        ``lr`` (float or None).
    """
    parser = argparse.ArgumentParser()  # build the ArgumentParser object
    parser.add_argument('--exp', type=str)
    # A learning rate is normally fractional, so parse it as float
    # (type=int would reject values such as 0.001).
    parser.add_argument('--lr', type=float)
    # ... further parser.add_argument(...) calls go here ...
    # parse_args() returns an object exposing each option as an attribute.
    args = parser.parse_args(argv)
    return args

# Use the args object: e.g. the --exp parameter points to a JSON config file.
import json

...
# Read an option with attribute access, e.g. args.exp; the `with` block
# closes the config file as soon as it has been parsed.
with open(args.exp, 'r', encoding='utf-8') as config_file:
    opts = json.load(config_file)
    
    

## load model

In [None]:
## Two ways to load a model in PyTorch
import io

import torch

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# way 1: merge the saved parameters into the model's current state dict
saved_state_dict = torch.load('model_path')
model_dict = model.state_dict()  # get the model's current state dict
# dict.update() works in place and returns None, so the result must not
# be assigned back to model_dict.
model_dict.update(saved_state_dict)
model.load_state_dict(model_dict)  # load the updated state dict
# way 2: load the saved state dict directly
# NOTE(review): subscript access implies `args` is a dict here (not an
# argparse Namespace) -- confirm against the code that builds it.
saved_state_dict = torch.load(args['--snapshots'])
model.load_state_dict(saved_state_dict)
## load all parameters onto the CPU or onto a specific GPU
torch.load('model_path', map_location=lambda storage, loc: storage)  # onto the CPU
torch.load('model_path', map_location=lambda storage, loc: storage.cuda(1))  # onto GPU 1
torch.load('model_path', map_location={'cuda:1': 'cuda:0'})  # remap tensors from GPU 1 to GPU 0
## another way to load a model: from an in-memory bytes buffer
with open('model_.pth', 'rb') as f:  # checkpoints are binary, so read bytes
    buffer = io.BytesIO(f.read())  # wrap the bytes in a seekable buffer
torch.load(buffer)

##  save checkpoint

In [None]:
def save_checkpoint(params):
    """Write the current training state to a checkpoint file.

    The checkpoint holds the epoch, the global step and the state dicts of
    the model, the optimizer and the learning-rate scheduler. It is saved
    as ``<exp_dir>/checkpoints/epoch<E>_step<S>.pth``.

    NOTE(review): the body reads ``self`` and ``epoch``, neither of which
    is a parameter, and never uses ``params``. Presumably this is meant to
    be a trainer method such as ``save_checkpoint(self, epoch)`` -- confirm
    against the enclosing class before use.
    """
    save_state = {
        'epoch': epoch,
        'global_step': self.global_step,
        'state_dict': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),  # state_dict for optimizer
        'lr_decay': self.lr_decay.state_dict(),  # state_dict for lr_decay
    }
    save_name = os.path.join(self.opts['exp_dir'], 'checkpoints', 'epoch%d_step%d.pth'%(epoch, self.global_step))
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    torch.save(save_state, save_name)  # save the checkpoint
    

# pytorch 的数据准备和预处理(to pytorch data format)

## 基础知识
DataProvider.py  作用是将一张张图像和GT处理成torch能够处理的 [original_image.tensor, gt.tensor], 然后只需在train.py中导入即可.
1. torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
注意4个参数:
    batch_size:批处理数目; shuffle:是否每个epoch都打乱； num_workers: 采用几个子进程来load数据; collate_fn: 将sample到的数据组织成mini-batch（ merges a list of samples to form a mini-batch） #reference: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader 
    
2. 需要的库：
    import torch.utils.data #子类化自己的数据dataset
    import torch 
    from torchvision import transforms #对数据做预处理，比如旋转，高斯模糊等Data Augmentation
    
3.子类化(class: 属性+方法 )自己的数据：
    class DataProvider(torch.utils.data.Dataset):
   
4.添加必要的override函数
    def __init__(self): #初始化定义的属性
    
    def __len__(self): #提供数据集大小
    
    def __getitem__(self, idx): # 提供下标index索引

## DataProvider template 模板

In [None]:
import torch.utils.data as data
import torch
from torchvision import transforms

class DataProvider(torch.utils.data.Dataset):  # subclass Dataset
    """Template for a custom dataset that yields (image, gt) tensor pairs.

    This is a skeleton: ``imread``, ``img_path``, ``gt`` and
    ``self.imagenumber`` are placeholders that a concrete dataset must
    provide.
    """

    def __init__(self, img_root, transform=None, train=True):
        self.img_root = img_root
        # Keep the transform so a concrete __getitem__ can apply it.
        self.transform = transform
        self.train = train

    def __getitem__(self, idx):  # load one sample and return [image, GT]; idx selects the image
        img = imread(img_path)  # read the image data
        img = torch.from_numpy(img).float()  # convert the ndarray to a tensor; must be float
        ...
        return img, gt

    def __len__(self):
        return len(self.imagenumber)  # number of samples in the dataset


##  具体例子

In [None]:
import glob # for search the files in directory
import os

import numpy as np # for array reshaping and normalization
import torch
import torch.utils.data as data # for subclass of DataSet
from scipy.ndimage import imread # to read RGB image; from scipy import misc, or misc.imread() / misc.imsave()
from torchvision import transforms # for data augmentation op

## Collect the custom dataset as [original_img_path, GT_img_path] pairs for training.
## If the image/GT pair needs a transform before being returned,
## apply e.g. '''img = transforms.ToTensor()(img); gt = transforms.ToTensor()(gt)''' before the return.
def handle_dataset_train(root_path, train=True):
    """List the [image_path, gt_path] pairs of the training split.

    Every ``*.jpg`` file found in ``<root_path>/train/gt/`` is paired with
    the file of the same name under ``<root_path>/train/image/``. The image
    file is not checked for existence.

    Args:
        root_path: Dataset root directory.
        train: When false, nothing is scanned and an empty list is returned.

    Returns:
        A list of two-element lists ``[image_path, gt_path]``, in the order
        produced by ``glob.glob``.
    """
    if not train:
        return []

    image_dir = os.path.join(root_path, 'train/image/')
    label_dir = os.path.join(root_path, 'train/gt/')

    # Shell-style wildcard match (not a regular expression) on the GT folder.
    label_names = (
        os.path.basename(label_file)
        for label_file in glob.glob(os.path.join(label_dir, '*.jpg'))
    )
    return [
        [os.path.join(image_dir, name), os.path.join(label_dir, name)]
        for name in label_names
    ]

class DataProvider(data.Dataset):  # subclassing
    """Dataset yielding (image, gt) float tensor pairs for training.

    Args:
        img_root: Dataset root directory passed to ``handle_dataset_train``.
        transform: Optional transform; stored on the instance as
            ``self.transform`` but not applied by this class.
        train: Whether to collect the training pairs. The original header
            spelled this ``Train`` while the body read ``train``; the name
            now matches the body.
    """

    ## initialize attributes
    def __init__(self, img_root, transform=None, train=True):
        self.img_root = img_root
        self.transform = transform
        self.train = train
        # Always define the pair list so __len__/__getitem__ work even when
        # train is False (the dataset is then simply empty).
        self.train_set_pair = []
        if self.train:
            self.train_set_pair = handle_dataset_train(img_root, train)

    @staticmethod
    def _image_to_tensor(image):
        """Convert an HxW or HxWxC array to a min-max normalized CxHxW float tensor."""
        image = np.atleast_3d(image).transpose(2, 0, 1).astype(np.float32)
        lowest, highest = image.min(), image.max()
        if highest > lowest:
            image = (image - lowest) / (highest - lowest)  # normalize to [0, 1]
        else:
            # A constant image has zero range; avoid dividing by zero.
            image = np.zeros_like(image)
        return torch.from_numpy(image).float()  # convert to torch tensor

    @staticmethod
    def _gt_to_tensor(gt):
        """Convert an HxW or HxWxC array to a CxHxW float tensor scaled by 1/255."""
        gt = np.atleast_3d(gt).transpose(2, 0, 1).astype(np.float32)
        gt = gt / 255.0  # normalize a gray GT image
        return torch.from_numpy(gt).float()

    ## getitem func for indexing
    def __getitem__(self, idx):
        img_path, gt_path = self.train_set_pair[idx]
        # NOTE(review): scipy.ndimage.imread was removed from recent SciPy
        # releases -- confirm the installed version still provides it.
        image = self._image_to_tensor(imread(img_path))
        # The ground truth is read from gt_path, not from the image path.
        gt = self._gt_to_tensor(imread(gt_path))
        return image, gt

    ## return length func
    def __len__(self):
        return len(self.train_set_pair)

## Pytorch 的并行

In [None]:
    import multiprocessing.dummy as multiprocessing  ##python 线程池
    ## Process the dataset entries concurrently (multiprocessing.dummy is a thread pool, not separate processes)
    def read_dataset(self):
        """Process every ``*/*.json`` file under ``self.data_dir`` with a worker pool.

        Each file path is paired with ``self.opts`` and handed to
        ``process_info``; the pool size is ``self.opts['num_workers']``.

        Returns:
            The list of ``process_info`` results, in the same order as the
            globbed file list (``pool.map`` preserves input order).
        """
        json_paths = glob.glob(osp.join(self.data_dir, '*/*.json'))
        jobs = [[path, self.opts] for path in json_paths]
        pool = multiprocessing.Pool(self.opts['num_workers'])  # set the number of workers
        try:
            results = pool.map(process_info, jobs)
        finally:
            # Shut the pool down even if a worker raised.
            pool.close()
            pool.join()
        return results
    