In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
import cv2
import glob
from Models.LargePNet import *
import tifffile
from Models.Discriminator import *
from Utils.TrainerRaGAN32 import *
from Utils.AppendLoad import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load Raw .npz data
Data_dir = r'F:\Datasets\STED_deblurring\MitoInner\Training'
loaded_data = np.load(Data_dir + '\MitoDeblurdata.npz')
X_train = loaded_data['X_train']  
y_train = loaded_data['y_train']
X_val = loaded_data['X_val']
y_val = loaded_data['y_val']
log_dir = Data_dir + r'\logfile'
if not os.path.exists(log_dir):
    os.mkdir(log_dir)
# Create training and validation data
X_train = torch.tensor(X_train, dtype = torch.float32).unsqueeze(1)
y_train = torch.tensor(y_train, dtype = torch.float32).unsqueeze(1)
X_val = torch.tensor(X_val, dtype = torch.float32).unsqueeze(1)
y_val = torch.tensor(y_val, dtype = torch.float32).unsqueeze(1)
print(X_train.shape)
print(y_train.shape)

''' Another loading method: load .tif image file
head_dir = r"E:\SMLM\MT\512data\Training"
Training_GT_path = head_dir + '\\' + r'GT\*.tif'
Training_Raw_path = head_dir + '\\' + r'Noisy\*.tif'
Val_GT_path = head_dir + '\\' + r'ValGT\*.tif'
Val_Raw_path = head_dir + '\\' + r'ValNoisy\*.tif'

X_train = AppendLoad(Training_Raw_path)
y_train = AppendLoad(Training_GT_path)
X_val = AppendLoad(Val_Raw_path)
y_val = AppendLoad(Val_GT_path)

print("X_train.shape",X_train.shape)
print("y_train.shape",y_train.shape)
print("X_val.shape",X_val.shape)
print("y_val.shape",y_val.shape)

train_data = Data.TensorDataset(X_train,y_train)
val_data = Data.TensorDataset(X_val,y_val)
'''

torch.Size([288, 1, 1024, 1024])
torch.Size([288, 1, 1024, 1024])


' Another loading method: load .tif image file\nhead_dir = r"E:\\SMLM\\MTŊdata\\Training"\nTraining_GT_path = head_dir + \'\\\' + r\'GT\\*.tif\'\nTraining_Raw_path = head_dir + \'\\\' + r\'Noisy\\*.tif\'\nVal_GT_path = head_dir + \'\\\' + r\'ValGT\\*.tif\'\nVal_Raw_path = head_dir + \'\\\' + r\'ValNoisy\\*.tif\'\n\nX_train = AppendLoad(Training_Raw_path)\ny_train = AppendLoad(Training_GT_path)\nX_val = AppendLoad(Val_Raw_path)\ny_val = AppendLoad(Val_GT_path)\n\nprint("X_train.shape",X_train.shape)\nprint("y_train.shape",y_train.shape)\nprint("X_val.shape",X_val.shape)\nprint("y_val.shape",y_val.shape)\n\ntrain_data = Data.TensorDataset(X_train,y_train)\nval_data = Data.TensorDataset(X_val,y_val)\n'

In [3]:
# Augment data
img_num = X_train.shape[0]
img_num_val = X_val.shape[0]
wanted_num = 2000
aimingsize = 512
background_control_max = 0.05
background_control_mean = 0.02 # typically 0.015-0.03

X_train_expand = []
y_train_expand = []
X_val_expand = []
y_val_expand = []
crop_num = wanted_num // img_num

count = 0
for i in tqdm(range(img_num), desc="Processing images"):  
    raw = X_train[i,0,...]
    gt = y_train[i,0,...]
    raw = np.expand_dims(raw,axis=2)
    gt = np.expand_dims(gt,axis=2)
    for j in range (crop_num):
        Augraw, Auggt = DataAug(raw,gt,aimingsize)
        if Auggt.max()>background_control_max:
            if Auggt.mean()>background_control_mean:
                X_train_expand.append(Augraw.squeeze())
                y_train_expand.append(Auggt.squeeze())
                count = count + 1
X_train = np.array(X_train_expand)
y_train = np.array(y_train_expand)

count = 0
for i in tqdm(range(img_num_val), desc="Processing images"):  
    raw = X_val[i,0,...]
    gt = y_val[i,0,...]
    raw = np.expand_dims(raw,axis=2)
    gt = np.expand_dims(gt,axis=2)
    for j in range (crop_num//2):
        Augraw, Auggt = DataAug(raw,gt,aimingsize)
        if Auggt.max()>background_control_max:
            if Auggt.mean()>background_control_mean:
                X_val_expand.append(Augraw.squeeze())
                y_val_expand.append(Auggt.squeeze())
                count = count + 1
X_val = np.array(X_val_expand)
y_val = np.array(y_val_expand)

X_train = torch.tensor(X_train, dtype = torch.float32).unsqueeze(1)
y_train = torch.tensor(y_train, dtype = torch.float32).unsqueeze(1)
X_val = torch.tensor(X_val, dtype = torch.float32).unsqueeze(1)
y_val = torch.tensor(y_val, dtype = torch.float32).unsqueeze(1)

print(X_train.shape)
print(y_train.shape)

train_data = Data.TensorDataset(X_train,y_train)
val_data = Data.TensorDataset(X_val,y_val)

Processing images: 100%|██████████| 288/288 [00:11<00:00, 24.14it/s]
Processing images: 100%|██████████| 48/48 [00:01<00:00, 42.07it/s]


torch.Size([1580, 1, 512, 512])
torch.Size([1580, 1, 512, 512])


In [5]:
class Options:  
    def __init__(self):  
        self.LR = 0.001
        self.batchsize = 1 
        self.epoch_num = 50 
        self.MSE_weight = 0
        self.MAE_weight = 1   
        self.SSIM_weight = 0 
        self.epoch_critic = 10
        self.useinit = 0
        self.ModelType = 1   
        self.loss_content = 0.006
        self.loss_GAN = 0.001
        self.D_decay = 500   # Decay of dicriminator leraning rate over generator
        self.warmupepoch = 0
        self.Dinterval = 1
        self.StepLR = 1
        self.ExpoLR = 0
        self.val = 1
        self.test = 0
        self.use_dir = 0
        self.use_norm = 1
        self.val_data_dir = ''
        self.innerpoint = []
        self.instanceimage = ''
        self.val_start = 10
opt = Options()

In [6]:
netG = LargePNet(1,1,1,25,4)
hr_shape = (512,512)
netD = MultiScaleDiscriminator(input_shape = (1, *hr_shape))
save_path = Data_dir+r'\logfile\LargePGAN'
TrainerRaGAN32(netG, netD, train_data, val_data, save_path, opt)

drop path: Identity()
drop path: Identity()
drop path: Identity()


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


0


Training epoch: 100%|██████████| 1580/1580 [09:17<00:00,  2.83it/s, Epoch: 1, Loss: 0.020300325006246567]


Epoch: 0, Loss_D: 0.3889, Loss_G: 0.0438, D(x): 0.3889, D(z): 0.3889
Epoch: 0, Val Loss: 0.001099, Val MSE: 0.001099, Val MAE: 0.018017, Val SSIM: 0.923970
1


Training epoch: 100%|██████████| 1580/1580 [09:45<00:00,  2.70it/s, Epoch: 2, Loss: 0.029411477968096733]


Epoch: 1, Loss_D: 0.3500, Loss_G: 0.0219, D(x): 0.3500, D(z): 0.3500
Epoch: 1, Val Loss: 0.001071, Val MSE: 0.001071, Val MAE: 0.017589, Val SSIM: 0.928700
2


Training epoch: 100%|██████████| 1580/1580 [09:39<00:00,  2.73it/s, Epoch: 3, Loss: 0.011040061712265015]


Epoch: 2, Loss_D: 0.3301, Loss_G: 0.0214, D(x): 0.3301, D(z): 0.3301
Epoch: 2, Val Loss: 0.000978, Val MSE: 0.000978, Val MAE: 0.016832, Val SSIM: 0.933691
3


Training epoch: 100%|██████████| 1580/1580 [09:32<00:00,  2.76it/s, Epoch: 4, Loss: 0.025355413556098938]


Epoch: 3, Loss_D: 0.3464, Loss_G: 0.0210, D(x): 0.3464, D(z): 0.3464
Epoch: 3, Val Loss: 0.001151, Val MSE: 0.001151, Val MAE: 0.018021, Val SSIM: 0.924868
4


Training epoch: 100%|██████████| 1580/1580 [10:09<00:00,  2.59it/s, Epoch: 5, Loss: 0.02618771605193615] 


Epoch: 4, Loss_D: 0.3654, Loss_G: 0.0209, D(x): 0.3654, D(z): 0.3654
Epoch: 4, Val Loss: 0.001697, Val MSE: 0.001697, Val MAE: 0.021280, Val SSIM: 0.914445
5


Training epoch: 100%|██████████| 1580/1580 [08:58<00:00,  2.93it/s, Epoch: 6, Loss: 0.02076788991689682] 


Epoch: 5, Loss_D: 0.3717, Loss_G: 0.0206, D(x): 0.3717, D(z): 0.3717
Epoch: 5, Val Loss: 0.000940, Val MSE: 0.000940, Val MAE: 0.016714, Val SSIM: 0.933892
6


Training epoch: 100%|██████████| 1580/1580 [08:53<00:00,  2.96it/s, Epoch: 7, Loss: 0.019438298419117928] 


Epoch: 6, Loss_D: 0.3842, Loss_G: 0.0204, D(x): 0.3842, D(z): 0.3842
Epoch: 6, Val Loss: 0.000893, Val MSE: 0.000893, Val MAE: 0.016277, Val SSIM: 0.937126
7


Training epoch: 100%|██████████| 1580/1580 [08:52<00:00,  2.97it/s, Epoch: 8, Loss: 0.012843376025557518] 


Epoch: 7, Loss_D: 0.3846, Loss_G: 0.0202, D(x): 0.3846, D(z): 0.3846
Epoch: 7, Val Loss: 0.000860, Val MSE: 0.000860, Val MAE: 0.016306, Val SSIM: 0.932409
8


Training epoch: 100%|██████████| 1580/1580 [08:53<00:00,  2.96it/s, Epoch: 9, Loss: 0.015912024304270744]


Epoch: 8, Loss_D: 0.3984, Loss_G: 0.0201, D(x): 0.3984, D(z): 0.3984
Epoch: 8, Val Loss: 0.001014, Val MSE: 0.001014, Val MAE: 0.016909, Val SSIM: 0.936306
9


Training epoch: 100%|██████████| 1580/1580 [08:52<00:00,  2.97it/s, Epoch: 10, Loss: 0.012492291629314423]


Epoch: 9, Loss_D: 0.4074, Loss_G: 0.0200, D(x): 0.4074, D(z): 0.4074
Epoch: 9, Val Loss: 0.001230, Val MSE: 0.001230, Val MAE: 0.018632, Val SSIM: 0.927455
10


Training epoch: 100%|██████████| 1580/1580 [08:52<00:00,  2.97it/s, Epoch: 11, Loss: 0.033062901347875595]


Epoch: 10, Loss_D: 0.4196, Loss_G: 0.0197, D(x): 0.4196, D(z): 0.4196
Epoch: 10, Val Loss: 0.000934, Val MSE: 0.000934, Val MAE: 0.017245, Val SSIM: 0.926208
11


Training epoch:  12%|█▏        | 195/1580 [01:10<08:17,  2.78it/s, Epoch: 12, Loss: 0.02089753746986389] 


KeyboardInterrupt: 