In [None]:
run config

In [1]:
import numpy as np
from tqdm.notebook import tqdm
import os

from networks import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Runtime switches:
#   device      -> torch device string the model and every batch are moved to
#   tensorboard -> 1 to log train/validation losses to TensorBoard, 0 to disable
device, tensorboard = 'cuda:2', 0

In [3]:
t_split = 4000   # number of samples randomly drawn from the training directory
v_split = 100    # number of samples randomly drawn from the validation directory
batch_size = 12  # training batch size

max_epoch = 60   # number of training epochs
lr = 0.001       # initial learning rate
step_size = 50   # StepLR period: the learning rate is multiplied by `gamma` every `step_size` epochs
gamma = 0.5      # factor by which the learning rate decays

In [4]:
############################## Choose Network you wish to train ###################################################

In [5]:
# Free-form label for this run; it is embedded in checkpoint and log file names.
model_name = "Img2Img_Mixer"

# Img2Img-Mixer hyper-parameters (symbols P, C, N, f as in the model description).
model = Img2Img_Mixer(
    img_size=256,      # side of the (square) input image, here 256 x 256
    img_channels=3,    # 3 for RGB, 1 for greyscale
    patch_size=4,      # P: patch size
    embed_dim=128,     # C: embedding dimension
    num_layers=16,     # N: number of mixer layers
    f_hidden=4,        # f: multiplication factor for the hidden dimensions
)

In [6]:
#model_name = "Linear_Mixer"

#model = Linear_Mixer(
#
#        img_size = 256,   #Image Size (assumed to be square image), here 256 x 256
#        img_channels = 3, #Image Channels, 3 for RGB, 1 for greyscale
#        patch_size = 4,   #Patch Size, P
#        embed_dim = 140,  #Embedding Dimension, C
#        num_layers = 19,  #Number of Mixer Layers, N
#        f_hidden = 4,     #Multiplication Factor for Hidden Dimensions, f
#)

In [7]:
#model_name = "Original_Mixer"

#model = Original_Mixer(
#
#        image_size = 256,            #Image Size (assumed to be square image), here 256 x 256
#        channels = 3,                #Image Channels, 3 for RGB, 1 for greyscale
#        patch_size = 4,              #Patch Size, P
#        num_layers = 8,              #Number of Mixer Layers, N
#        hidden_dim = 128,            #Embedding Dimension, C
#        tokens_hidden_dim = 96,      #Hidden Dimension for Tokens
#        channels_hidden_dim = 256    #Hidden Dimension for Channels
#)

In [8]:
#model_name = "U_Mixer"

#model = U_Mixer(
#        
#        img_size = 256,    #Image Size (assumed to be square image), here 256 x 256
#        img_channels = 3,  #Image Channels, 3 for RGB, 1 for greyscale
#        embed_dim = 96,    #Embedding Dimension, C
#)

In [9]:
#model_name = "Unet"

#model = Unet (
    
#        in_chans = 3,  #Number of channels in the input to the U-Net model
#        out_chans = 3, #Number of channels in the output to the U-Net model
#        chans = 21,    #Number of output channels of the first convolution layer
#)    

In [10]:
# model_name = "ViT"

# net = VisionTransformer(
#     avrg_img_size=256, 
#     patch_size=10, 
#     in_chans=3, embed_dim=44, 
#     depth=4, num_heads=6, mlp_ratio=4., 
#     )

# model = ReconNet(net)

In [11]:
###################################################################################################################

In [12]:
# Move the network to the configured device and report its trainable parameter count.
model = model.to(device)
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print("Model size is: ", trainable_params)

Model size is:  3443456


In [None]:
## Load paths
# os.path.join works whether or not data_path ends with a separator (and also
# accepts a pathlib.Path). Each sub-directory keeps its trailing '/' so code
# that appends file names by plain string concatenation still works.
clean_train = os.path.join(data_path, 'clean_train/')
noisy_train = os.path.join(data_path, 'noisy_train/')

clean_val = os.path.join(data_path, 'clean_val/')
noisy_val = os.path.join(data_path, 'noisy_val/')

In [None]:
##Prepare data
class data(Dataset):
    """Paired clean/noisy samples stored as individual ``torch.save`` files.

    Sample ``idx`` is loaded from the file named ``'{idx:05}'`` (zero-padded,
    no extension) in both directories, so corresponding clean and noisy
    samples must share the same file name.

    Args:
        path_clean: directory holding the clean (target) samples.
        path_noisy: directory holding the noisy (input) samples.
    """

    def __init__(self, path_clean, path_noisy):
        self.path_clean = path_clean
        self.path_noisy = path_noisy

    def __len__(self):
        # Count only numerically named files. Counting every directory entry
        # would include stray files/folders (e.g. '.DS_Store',
        # '.ipynb_checkpoints') and produce indices with no file on disk.
        with os.scandir(self.path_clean) as entries:
            return sum(1 for entry in entries if entry.is_file() and entry.name.isdigit())

    def __getitem__(self, idx):
        """Return ``{'clean': ..., 'noisy': ...}`` for sample ``idx``."""
        file_name = '{0:05}'.format(idx)
        # os.path.join tolerates directory paths with or without a trailing separator.
        # NOTE: torch.load unpickles its input — only use this on trusted data.
        return {
            'clean': torch.load(os.path.join(self.path_clean, file_name)),
            'noisy': torch.load(os.path.join(self.path_noisy, file_name)),
        }


In [None]:
training_set = data(clean_train, noisy_train)
validation_set = data(clean_val, noisy_val)

n_train = len(training_set)
n_val = len(validation_set)

# random_split checks that the lengths sum to len(dataset) but not that each
# length is non-negative; a split larger than the dataset would silently
# return a smaller subset, so reject it explicitly.
if not (0 <= t_split <= n_train and 0 <= v_split <= n_val):
    raise ValueError(
        f'Requested t_split={t_split}, v_split={v_split} but only '
        f'{n_train} training and {n_val} validation samples are available'
    )

# Samples left out of this run.
t_diff = n_train - t_split
v_diff = n_val - v_split

# Seeded generator so the selected subsets are reproducible across runs.
split_generator = torch.Generator().manual_seed(0)
train_set, train_rest = torch.utils.data.random_split(
    training_set, [t_split, t_diff], generator=split_generator)
val_set, val_rest = torch.utils.data.random_split(
    validation_set, [v_split, v_diff], generator=split_generator)

# DataLoaders: shuffled mini-batches for training, one sample at a time for validation.
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_set, batch_size=1, shuffle=False)


In [None]:
##Functions

##Loss
def mse(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``gt`` and ``pred``, averaged over all elements.

    Same result as ``torch.nn.MSELoss()(gt, pred)`` with the default
    ``reduction='mean'``; returns a 0-dim tensor that supports ``backward()``.
    """
    return torch.nn.functional.mse_loss(gt, pred)

#train
def train(model, optimizer, sample):
    """Run one optimisation step on a single batch and return its loss.

    The network output is subtracted from the noisy input, so it acts as a
    noise estimate. The result is clamped to [0, 1] and compared with the
    clean target using ``mse``.

    Args:
        model: network applied to the noisy batch; its output must be
            broadcast-compatible with that batch (presumably the same shape).
        optimizer: optimizer over ``model``'s parameters.
        sample: dict with 'clean' and 'noisy' tensors; both are moved to the
            global ``device``.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    target = sample['clean'].to(device)
    noisy_input = sample['noisy'].to(device)

    restored = (noisy_input - model(noisy_input)).clamp(0, 1)
    batch_loss = mse(restored, target)

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

#test function
def test(model, sample):
    """Evaluate ``model`` on one batch without tracking gradients.

    Uses the same noisy-minus-prediction, clamp-to-[0, 1] reconstruction as
    ``train``. Returns the ``mse`` loss against the clean target as a
    Python float.
    """
    model.eval()
    with torch.no_grad():
        target = sample['clean'].to(device)
        noisy_input = sample['noisy'].to(device)
        restored = (noisy_input - model(noisy_input)).clamp(0, 1)
        batch_loss = mse(restored, target)
    return batch_loss.item()

In [None]:
if tensorboard:
    # TensorFlow is only needed for loss logging, so it is imported lazily here.
    import tensorflow as tf
    from tensorflow import summary

    # One writer per phase, both under <logs_path><model_name>.
    train_log_dir = logs_path + model_name + '/train'
    val_log_dir = logs_path + model_name + '/validate'
    train_summary_writer = summary.create_file_writer(train_log_dir)
    val_summary_writer = summary.create_file_writer(val_log_dir)

In [None]:
# Best validation loss so far. Starting at +inf means the first finite loss
# always counts as an improvement, whatever the loss scale.
least_loss = float('inf')

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)
# Multiplies the learning rate by `gamma` every `step_size` scheduler steps (one step per epoch below).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

tmp_path = models_path + 'temp_' + model_name + '.pth'    # latest checkpoint, overwritten every epoch
best_path = models_path + 'best_' + model_name + '.pth'   # checkpoint with the lowest validation loss

for epoch in tqdm(range(max_epoch)):
    train_loss = val_loss = 0.0

    # ------------------------------------------------------------------ train phase
    # Each batch adds loss / num_batches, so train_loss is the mean per-batch loss.
    # leave=False keeps finished per-epoch bars from piling up in the output.
    for sample in tqdm(train_dl, desc='train', leave=False):
        curr_loss = train(model, optimizer, sample)
        train_loss += curr_loss / len(train_dl)

    # Step once per epoch, after that epoch's optimizer updates.
    scheduler.step()

    if tensorboard:
        # `tf` and the writers are created in the TensorBoard setup cell.
        with train_summary_writer.as_default():
            tf.summary.scalar('loss', train_loss, step=epoch)

    # Save the latest state. The epoch and scheduler state are included so a run
    # can be resumed with the correct learning rate; the model/optimizer keys
    # are unchanged for existing loaders.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    torch.save(checkpoint, tmp_path)

    # ------------------------------------------------------------- validation phase
    for sample in tqdm(val_dl, desc='val', leave=False):
        curr_loss = test(model, sample)
        val_loss += curr_loss / len(val_dl)

    if tensorboard:
        with val_summary_writer.as_default():
            tf.summary.scalar('loss', val_loss, step=epoch)

    print(epoch, train_loss, val_loss)

    # Keep the checkpoint with the lowest validation loss. A strict '<' means a
    # NaN loss never replaces the best value: the previous min()/== approach
    # stored NaN, which let the next, possibly worse, loss overwrite the best model.
    # Validation does not modify the weights, so `checkpoint` matches the evaluated model.
    if val_loss < least_loss:
        least_loss = val_loss
        torch.save(checkpoint, best_path)