In [10]:
import torch
import torch.nn as nn
import numpy as np
import os
import sys
sys.path.append("./src/models/")
sys.path.append("./src/")
import utils as utils
from torch.utils.data import DataLoader
from MWU_CNN import MW_Unet
from tqdm import tqdm
import importlib


In [30]:
model_save_path  = "./experiments/baseline/baseline.pt"
def backprop(optimizer,model_output,target):
    optimizer.zero_grad()
    loss_fn = nn.MSELoss()
    loss = loss_fn(model_output,target)
    loss.backward()
    optimizer.step()
    return loss
def get_PSNR(model_output,target):
    I_hat = model_output.cpu().detach().numpy()
    I = target.cpu().detach().numpy()
    mse = (np.square(I-I_hat)).mean(axis=None)
    PSNR = 10 * np.log10(1.0/mse)
    return PSNR
def train(args):
    """
    train model
    """
    
    ####################################### Initializing Model #######################################
    
    step = 0.01
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print_every = int(args['--print_every'])
    num_epochs = int(args['--num_epochs'])
    save_every = int(args['--save_every'])
    save_path = model_save_path
    num_batches = int(args['--num_batches'])
    batch_size = int(args['--batch_size'])
    
    model = MW_Unet(num_conv=0,in_ch=1)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=step)
    
    ######################################### Loading Data ##########################################
    train_data_path = args['--train_data_path']
    n  = 3000
    dataset_total = utils.patchesDataset(patches_path=train_data_path,n=n)
    train_split,val_split,test_split = 0.8,0.1,0.1
    train_max_idx = int(train_split*len(dataset_total))
#     print(train_max_idx)
    dataset_train = torch.utils.data.Subset(dataset_total,range(0,train_max_idx))
    val_max_idx = train_max_idx+int(val_split*len(dataset_total))
    dataset_val = torch.utils.data.Subset(dataset_total,range(train_max_idx,val_max_idx))
    test_max_idx = val_max_idx + int(test_split*len(dataset_total))
#     print(test_max_idx)
    
#     dataset_test = torch.utils.data.Subset(dataset_total,range(val_max_idx,test_max_idx))
    
    dataloader_train = DataLoader(dataset_train,batch_size=batch_size)
    dataloader_val = DataLoader(dataset_val,batch_size=batch_size)
    
    
    print("length of train set: ",len(dataset_train))
    print("length of val set: ",len(dataset_val))
#     print("length of test set: ",len(dataset_test))
    
    train_PSNRs = []
    train_losses = []
    val_PSNRs = []
    val_losses = []
    init_epoch = 0 
    
    best_val_PSNR = 0.0
    try:
        for epoch in range(1, num_epochs + 1):
                #INITIATE dataloader_train
                print("epoch: ",epoch)
                with tqdm(total = len(dataloader_train)) as pbar:
                    for index, sample in enumerate(dataloader_train):

                        model.train()

                        target,model_input = sample['target'],sample['input']
                        target = target.to(device)
                        model_input = model_input.to(device)

                        output = model.forward(model_input)

                        train_loss = backprop(optimizer,output,target)

                        train_PSNR = get_PSNR(output,target)

                        avg_val_PSNR = []
                        avg_val_loss = []
                        model.eval()
                        with torch.no_grad():
                            for index,sample in enumerate(dataloader_val):
                                target,model_input = sample['target'],sample['input']
                                
                                target = target.to(device)
                                model_input = model_input.to(device)
                                
                                output = model.forward(model_input)
                                loss_fn = nn.MSELoss()
                                loss_val = loss_fn(output,target)
                                PSNR = get_PSNR(output,target)
                                avg_val_PSNR.append(PSNR)
                                avg_val_loss.append(loss_val.cpu().detach().numpy())
                        avg_val_PSNR = np.mean(avg_val_PSNR)
                        avg_val_loss = np.mean(avg_val_loss)
                        val_PSNRs.append(avg_val_PSNR)
                        val_losses.append(avg_val_loss)

                        train_losses.append(train_loss.cpu().detach().numpy())
                        train_PSNRs.append(train_PSNR)
                    if epoch % print_every == 0: 
                        print ("Epoch: {}, Loss: {}, Training PSNR: {}".format(epoch, train_loss, train_PSNR))
                        print ("Epoch: {}, Avg Val Loss: {},Avg Val PSNR: {}".format(epoch, avg_val_loss, avg_val_PSNR))
                    if epoch % save_every == 0 and best_val_PSNR < avg_val_PSNR:
                        best_val_PSNR = avg_val_PSNR
                        print("new best Avg Val PSNR: {}".format(best_val_PSNR))
                        print ("Saving model to {}".format(save_path))
                        torch.save({'epoch': epoch,
                            'model_state_dict': model.state_dict(), 
                            'optimizer_state_dict': optimizer.state_dict(), 
                            'loss': train_loss}, 
                             save_path)
                        print ("Saved successfully to {}".format(save_path))
    except KeyboardInterrupt:
        print("Training interupted...")
        print ("Saving model to {}".format(save_path))
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict(), 
                    'loss': train_loss}, 
                     save_path)
        print ("Saved successfully to {}".format(save_path))          
        

    print("Training completed.")


        
            

                    
                            
                            
                            
                            
                    
                    
                    

    
    
    
    


In [31]:
args = {'--print_every':5,
        '--num_epochs':10,
        '--save_every':1,
        '--num_batches':32,
        '--train_data_path':'./data/patches_Train/',
        '--batch_size':128}
train(args)

channel_1: 16, channel_2: 32
loading patches from patches directory


100%|██████████| 3000/3000 [00:05<00:00, 592.62it/s]


completed loading patches from directory!


  0%|          | 0/19 [00:00<?, ?it/s]

(3000, 240, 240)
shape of target: (3000, 1, 240, 240) shape of noisy: (3000, 1, 240, 240)
length of train set:  2400
length of val set:  300
epoch:  1



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 11.907683155434512
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  2



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 17.660998640712528
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  3



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 19.848217989370642
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  4



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 21.37140048914917
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  5



  0%|          | 0/19 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.004983397666364908, Training PSNR: 23.02474293184393
Epoch: 5, Avg Val Loss: 0.004608292132616043,Avg Val PSNR: 23.37104088612215
new best Avg Val PSNR: 23.37104088612215
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  6



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 24.321259420411906
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  7



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 25.394704574084795
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  8



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 25.85827734586638
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  9



  0%|          | 0/19 [00:00<?, ?it/s]

new best Avg Val PSNR: 26.24341347220103
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
epoch:  10
Epoch: 10, Loss: 0.0029898989014327526, Training PSNR: 25.243434963886774
Epoch: 10, Avg Val Loss: 0.0022299785632640123,Avg Val PSNR: 26.51986384918006
new best Avg Val PSNR: 26.51986384918006
Saving model to ./experiments/baseline/baseline.pt
Saved successfully to ./experiments/baseline/baseline.pt
Training completed.



