In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import os
import cv2
import glob
import torch
from Models.LargePNet import *
import tifffile
from Utils.TrainerSTED import TrainerSTED
import matplotlib.pyplot as plt
from Utils.AppendLoad import *
from Utils.DataAug2D import *

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Load the paired raw / ground-truth STED images from a single .npz archive.
# Keys read: 'X_train', 'y_train', 'X_val', 'y_val'. Each array is converted to
# a float32 tensor with a channel axis inserted at dim 1 via unsqueeze(1).
# NOTE(review): hard-coded absolute Windows path -- consider moving it to a
# dedicated config cell.
Data_dir = r'F:\Datasets\STED_deblurring\MitoInner\Training'

# os.path.join replaces the old `Data_dir + '\MitoDeblurdata.npz'`, whose '\M'
# is an invalid escape sequence (SyntaxWarning on Python 3.12+).
with np.load(os.path.join(Data_dir, 'MitoDeblurdata.npz')) as loaded_data:
    # Indexing an NpzFile reads the member fully into memory, so the arrays
    # stay valid after the context manager closes the archive's file handle.
    X_train = loaded_data['X_train']
    y_train = loaded_data['y_train']
    X_val = loaded_data['X_val']
    y_val = loaded_data['y_val']

# Fail fast if raw / ground-truth pairs are misaligned.
assert X_train.shape[0] == y_train.shape[0], "X_train / y_train length mismatch"
assert X_val.shape[0] == y_val.shape[0], "X_val / y_val length mismatch"

# Log directory (identical to Data_dir + r'\logfile' on Windows); makedirs with
# exist_ok also creates missing parents and is safe to re-run.
log_dir = os.path.join(Data_dir, 'logfile')
os.makedirs(log_dir, exist_ok=True)

# Create training and validation data: float32 tensors with a channel axis at dim 1.
X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
y_val = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)
print(X_train.shape)
print(y_train.shape)

# Another loading method: load .tif image file (kept for reference; as real
# comments it is no longer echoed as the cell's output string).
# head_dir = r"D:\SISR_Full\CCP"
# Training_GT_path = head_dir + '\\' + r'GT\*.tif'
# Training_Raw_path = head_dir + '\\' + r'Noisy\*.tif'
# Val_GT_path = head_dir + '\\' + r'ValGT\*.tif'
# Val_Raw_path = head_dir + '\\' + r'ValNoisy\*.tif'
#
# X_train = AppendLoad(Training_Raw_path)
# y_train = AppendLoad(Training_GT_path)
# X_val = AppendLoad(Val_Raw_path)
# y_val = AppendLoad(Val_GT_path)
#
# print("X_train.shape",X_train.shape)
# print("y_train.shape",y_train.shape)
# print("X_val.shape",X_val.shape)
# print("y_val.shape",y_val.shape)
#
# train_data = Data.TensorDataset(X_train,y_train)
# val_data = Data.TensorDataset(X_val,y_val)

torch.Size([288, 1, 1024, 1024])
torch.Size([288, 1, 1024, 1024])


' Another loading method: load .tif image file\nhead_dir = r"D:\\SISR_Full\\CCP"\nTraining_GT_path = head_dir + \'\\\' + r\'GT\\*.tif\'\nTraining_Raw_path = head_dir + \'\\\' + r\'Noisy\\*.tif\'\nVal_GT_path = head_dir + \'\\\' + r\'ValGT\\*.tif\'\nVal_Raw_path = head_dir + \'\\\' + r\'ValNoisy\\*.tif\'\n\nX_train = AppendLoad(Training_Raw_path)\ny_train = AppendLoad(Training_GT_path)\nX_val = AppendLoad(Val_Raw_path)\ny_val = AppendLoad(Val_GT_path)\n\nprint("X_train.shape",X_train.shape)\nprint("y_train.shape",y_train.shape)\nprint("X_val.shape",X_val.shape)\nprint("y_val.shape",y_val.shape)\n\ntrain_data = Data.TensorDataset(X_train,y_train)\nval_data = Data.TensorDataset(X_val,y_val)\n'

In [6]:
# Augment data: draw `aimingsize` crops from every full-size image with DataAug
# and keep only crops whose ground truth passes both background thresholds.
# The result is wrapped in float16 TensorDatasets for the trainer.
# NOTE(review): this cell rebinds X_train / y_train / X_val / y_val, so
# re-running it without first re-running the load cell augments the
# already-cropped images. Re-run the load cell before this one.
# TODO(review): DataAug (Utils.DataAug2D) is presumably random -- seed whichever
# RNG it uses so the crop set is reproducible.
img_num = X_train.shape[0]
img_num_val = X_val.shape[0]
wanted_num = 2000                # target number of training crops (before filtering)
aimingsize = 512                 # size argument passed to DataAug
background_control_max = 0.05    # keep a crop only if its GT max exceeds this ...
background_control_mean = 0.02   # ... and its GT mean exceeds this; typically 0.015-0.03

# Crops attempted per training image. Clamped to >= 1 so an empty dataset does
# not divide by zero and img_num > wanted_num does not produce zero crops.
crop_num = max(1, wanted_num // max(img_num, 1))
# Validation gets half as many attempts per image, but never zero.
crop_num_val = max(1, crop_num // 2)


def _augment_pairs(X, y, n_crops, size, bg_max, bg_mean, desc="Processing images"):
    """Run DataAug n_crops times per (raw, GT) pair and keep non-background crops.

    Args:
        X, y: sample-major containers indexed as ``[i, 0, ...]`` to get one
            image per sample (the torch tensors built by the load cell).
        n_crops: number of DataAug calls per image.
        size: crop size forwarded to DataAug.
        bg_max, bg_mean: a crop is kept only if its GT ``max() > bg_max`` and
            ``mean() > bg_mean``.
        desc: tqdm progress-bar label.

    Returns:
        ``(raw_crops, gt_crops)``: numpy arrays built from the squeezed crops.

    Raises:
        ValueError: if every crop was rejected by the background filter
            (previously this silently produced an empty dataset).
    """
    raw_crops, gt_crops = [], []
    for i in tqdm(range(X.shape[0]), desc=desc):
        # DataAug receives the image with an extra trailing axis (axis=2).
        raw = np.expand_dims(X[i, 0, ...], axis=2)
        gt = np.expand_dims(y[i, 0, ...], axis=2)
        for _ in range(n_crops):
            aug_raw, aug_gt = DataAug(raw, gt, size)
            if aug_gt.max() > bg_max and aug_gt.mean() > bg_mean:
                raw_crops.append(aug_raw.squeeze())
                gt_crops.append(aug_gt.squeeze())
    if not raw_crops:
        raise ValueError(
            f"no crops passed the background filter (max>{bg_max}, mean>{bg_mean}); "
            "lower the thresholds or check the data"
        )
    return np.array(raw_crops), np.array(gt_crops)


X_train, y_train = _augment_pairs(
    X_train, y_train, crop_num, aimingsize,
    background_control_max, background_control_mean,
)
X_val, y_val = _augment_pairs(
    X_val, y_val, crop_num_val, aimingsize,
    background_control_max, background_control_mean,
)

# float16 presumably chosen to halve host memory for the crop stacks.
# NOTE(review): confirm TrainerSTED casts batches to the model's dtype.
X_train = torch.tensor(X_train, dtype=torch.float16).unsqueeze(1)
y_train = torch.tensor(y_train, dtype=torch.float16).unsqueeze(1)
X_val = torch.tensor(X_val, dtype=torch.float16).unsqueeze(1)
y_val = torch.tensor(y_val, dtype=torch.float16).unsqueeze(1)

print(X_train.shape)
print(y_train.shape)

train_data = Data.TensorDataset(X_train, y_train)
val_data = Data.TensorDataset(X_val, y_val)

Processing images: 100%|██████████| 288/288 [00:12<00:00, 23.61it/s]
Processing images: 100%|██████████| 48/48 [00:01<00:00, 44.77it/s]


torch.Size([1584, 1, 512, 512])
torch.Size([1584, 1, 512, 512])


In [9]:
 class Options:  
    def __init__(self):  
        self.LR = 0.001
        self.batchsize = 1 
        self.epoch_num = 50 
        self.MSE_weight = 1
        self.MAE_weight = 0.05 
        self.SSIM_weight = 0 
        self.epoch_critic = 10 
        self.useinit = 0
        self.pre_epoch = 0
        self.ModelType = 1   
        self.StepLR = 1
        self.ExpoLR = 0
        self.val = 1
        self.test = 0
        self.use_dir =0
        self.use_norm = 1
        self.test_data_dir = ''
        self.innerpoint = []
        self.instanceimage = ''
opt = Options()

In [None]:
# Build the network and hand it to the trainer along with the datasets and options.
# NOTE(review): the positional arguments (1, 1, 1, 25, 4) are defined by
# Models.LargePNet -- document their meaning there or pass them by keyword.
# TODO(review): call torch.manual_seed(...) before building the model for a
# reproducible weight init.
model = LargePNet(1, 1, 1, 25, 4)
# Same as Data_dir + r'\logfile\MINet' on Windows, but portable to other OSes.
# How save_path is used (directory vs. file prefix) is up to TrainerSTED.
save_path = os.path.join(Data_dir, 'logfile', 'MINet')
TrainerSTED(model, train_data, val_data, save_path, opt)

drop path: Identity()
drop path: Identity()
drop path: Identity()
0


Training epoch:  19%|█▉        | 301/1584 [01:59<07:34,  2.82it/s, Epoch: 1, Loss: 0.002253551036119461] 