In [1]:
# Torch device used for the model and every batch moved in train/test;
# hard-coded to GPU index 2 on this machine.
device = 'cuda:2'

In [2]:
run config

In [3]:
import numpy as np
from tqdm.notebook import tqdm
import os

from include import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tensorflow import summary
import tensorflow as tf

In [4]:
##Load paths
# Directories of per-sample tensor files. `data_path` is defined outside this
# cell (presumably by the config run); the trailing '/' on each sub-directory
# is kept as in the original.
clean_train, noisy_train = data_path + 'clean_train/', data_path + 'noisy_train/'
clean_val, noisy_val = data_path + 'clean_val/', data_path + 'noisy_val/'

In [5]:
##Prepare data
class data(Dataset):
    """Paired clean/noisy dataset stored as one ``torch.save`` file per sample.

    Sample ``idx`` is loaded from the file named ``idx`` zero-padded to five
    digits (e.g. index 7 -> ``00007``) in both directories.

    Args:
        path_clean: directory with the clean (target) tensors.
        path_noisy: directory with the noisy (input) tensors; assumed to use the
            same file names as ``path_clean``.
    """

    def __init__(self, path_clean, path_noisy):
        self.path_clean = path_clean
        self.path_noisy = path_noisy

    def __len__(self):
        # The sample count is taken from the clean directory only; the noisy
        # directory is assumed to mirror it.
        return len(os.listdir(self.path_clean))

    def __getitem__(self, idx):
        """Return ``{'clean': ..., 'noisy': ...}`` as loaded by ``torch.load``."""
        fname = '{0:05}'.format(idx)
        # os.path.join works with or without a trailing '/' on the directory;
        # plain string concatenation silently built a wrong path without one.
        return {
            'clean': torch.load(os.path.join(self.path_clean, fname)),
            'noisy': torch.load(os.path.join(self.path_noisy, fname)),
        }


In [6]:
# Experiment knobs for the subset sizes and batching
SEED = 0               # seeds torch's global RNG for reproducible subsets
N_TRAIN_SAMPLES = 100
N_VAL_SAMPLES = 100
BATCH_SIZE = 32

# Seed before any torch randomness so random_split picks the same samples on
# every Restart & Run All; later torch draws in this notebook (e.g. shuffling,
# weight init) then also start from a fixed state.
torch.manual_seed(SEED)

training_set = data(clean_train, noisy_train)
validation_set = data(clean_val, noisy_val)

# Clamp the subset sizes: with fewer files than requested, the remainder length
# would be negative and random_split would raise.
t_split = min(N_TRAIN_SAMPLES, len(training_set))
t_diff = len(training_set) - t_split

v_split = min(N_VAL_SAMPLES, len(validation_set))
v_diff = len(validation_set) - v_split

# The second element of each split (the unused remainder) is discarded.
train_set, nth = torch.utils.data.random_split(training_set, [t_split, t_diff])
val_set, nth = torch.utils.data.random_split(validation_set, [v_split, v_diff])

#Dataloader
train_dl = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_set, batch_size=1, shuffle=False)


In [7]:
# Build the network (Unet_Mixer is defined outside this view, presumably in
# `include`) and move its parameters to `device`.
model = Unet_Mixer()
model = model.to(device)

In [8]:
# Number of trainable parameters (shown as the cell output)
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
n_trainable_params

6451136

In [9]:
##Functions

##Loss
def mse(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Mean squared error averaged over all elements.

    Same computation as ``torch.nn.MSELoss()`` with its default
    ``reduction='mean'``; the result is a 0-dim tensor that carries gradients.
    """
    return torch.nn.functional.mse_loss(gt, pred)

#train
def train(model, optimizer, sample):
    """Run a single optimisation step on one batch and return its loss.

    The model output is subtracted from the noisy input (presumably the network
    predicts the noise); the result is clamped to [0, 1] and compared with the
    clean target via ``mse``.

    Args:
        model: network being trained; put into training mode here.
        optimizer: optimizer over ``model``'s parameters.
        sample: dict with ``'clean'`` and ``'noisy'`` tensors (one DataLoader
            batch); both are moved to the global ``device``.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    target = sample['clean'].to(device)
    noisy_in = sample['noisy'].to(device)

    denoised = torch.clamp(noisy_in - model(noisy_in), 0, 1)
    batch_loss = mse(denoised, target)

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

#test function
def test(model, sample):
    """Evaluate one batch without tracking gradients and return its loss.

    Mirrors ``train``: the clamped ``noisy - model(noisy)`` is scored against the
    clean target with ``mse``, but no optimizer step is taken.

    Args:
        model: network to evaluate; put into eval mode here.
        sample: dict with ``'clean'`` and ``'noisy'`` tensors, moved to the
            global ``device``.

    Returns:
        The batch loss as a Python float.
    """
    model.eval()
    with torch.no_grad():
        target = sample['clean'].to(device)
        noisy_in = sample['noisy'].to(device)
        denoised = torch.clamp(noisy_in - model(noisy_in), 0, 1)
        batch_loss = mse(denoised, target)
    return batch_loss.item()

In [10]:
#prepare tensorboard
# One log directory and one file writer per phase (training / validation).
train_log_dir = logs_path + 'mixer_100/train'
val_log_dir = logs_path + 'mixer_100/validate'

train_summary_writer = summary.create_file_writer(train_log_dir)
val_summary_writer = summary.create_file_writer(val_log_dir)

In [11]:
# Adam with no weight decay; StepLR multiplies the learning rate by 0.5 every
# 50 calls to scheduler.step().
optimizer = optim.Adam(model.parameters(), lr=2e-3, weight_decay=0)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=50)

In [None]:
max_epoch = 150
least_loss = float('inf')  # lowest validation loss seen so far
tmp_path = models_path + 'temp_mixer_100.pth'   # overwritten every epoch
best_path = models_path + 'best_mixer_100.pth'  # rewritten only when val loss improves


def _epoch_mean_loss(step_fn, loader):
    """Run ``step_fn`` on every batch of ``loader``; return the per-sample mean loss.

    ``step_fn(sample)`` returns the batch-*mean* loss as a float. Each batch is
    weighted by its size so a smaller final batch (e.g. 100 samples with
    batch_size=32 leave a last batch of 4) does not count as much as a full
    batch. Returns 0.0 for an empty loader.
    """
    loss_sum, n_seen = 0.0, 0
    for sample in tqdm(loader, leave=False):  # leave=False: no per-epoch bar pile-up
        n_batch = len(sample['clean'])
        loss_sum += step_fn(sample) * n_batch
        n_seen += n_batch
    return loss_sum / n_seen if n_seen else 0.0


for epoch in tqdm(range(max_epoch)):
    # ------------------------------------------------------------ Train phase
    train_loss = _epoch_mean_loss(lambda s: train(model, optimizer, s), train_dl)
    scheduler.step()

    ## Write the current loss to Tensorboard
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss, step=epoch)

    # ------------------------------------------------------------ Checkpoint
    # Latest state. Epoch and scheduler state are stored too so a resumed run
    # continues with the correct learning-rate schedule.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    torch.save(checkpoint, tmp_path)

    # ------------------------------------------------------------ Validate phase
    val_loss = _epoch_mean_loss(lambda s: test(model, s), val_dl)

    ## Write the current validation loss to Tensorboard
    with val_summary_writer.as_default():
        tf.summary.scalar('loss', val_loss, step=epoch)

    print(epoch, train_loss, val_loss)

    ##Save best model: strict improvement only (a NaN loss never qualifies).
    # No training step runs between the checkpoint above and this save, so it
    # still holds the weights that produced val_loss.
    if val_loss < least_loss:
        least_loss = val_loss
        torch.save(dict(checkpoint, val_loss=val_loss), best_path)

HBox(children=(FloatProgress(value=0.0, max=150.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))