# Training the U-Net to implement unflooding


## Loading the modules 

In [None]:


import matplotlib.pyplot as plt
import matplotlib as mlp
from sklearn.preprocessing import StandardScaler, MinMaxScaler,MaxAbsScaler, RobustScaler
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import numpy as np
from prep_data import Dataset
import torch
from torch import autograd
from torch.utils.data import DataLoader
from unet import UNet
import torch.nn as  nn
import random
from scipy.ndimage.filters import gaussian_filter1d
from util import plot_models1D, plot_models1D, plot_history, plot_models1D2, plot_r2, r2_score
from noise_layer import GaussianNoise
import time 




## Define the training function 

In [None]:




def run_training(model,opt,criterion,training_data,valid_data,num_epoch=100,batchsz=32):
	"""Train ``model`` and evaluate it on a validation set after every epoch.

	Parameters
	----------
	model : torch.nn.Module
		Network to train; its parameters are updated in place by ``opt``.
	opt : torch.optim.Optimizer
		Optimizer built over ``model.parameters()``; this is the optimizer that is stepped.
	criterion : callable
		Loss ``criterion(pred, targets)`` returning a scalar tensor.
	training_data, valid_data : Dataset
		Project ``prep_data.Dataset`` objects yielding ``(inputs, targets)`` pairs.
	num_epoch : int
		Number of passes over the training data.
	batchsz : int
		Mini-batch size used for both loaders.

	Returns
	-------
	tuple
		``(model, loss_train, loss_valid, R2_train, R2_valid)``. The last four are numpy
		arrays with one batch-averaged value per epoch. The model is left in eval mode.
	"""
	train_loader = DataLoader(dataset=training_data, batch_size=batchsz, shuffle=True)
	# Validation order does not matter for the metrics, so keep it deterministic.
	valid_loader = DataLoader(dataset=valid_data, batch_size=batchsz, shuffle=False)
	# Average over the real number of batches (the last batch may be partial).
	n_train_batches = max(len(train_loader), 1)
	n_valid_batches = max(len(valid_loader), 1)
	# Move every batch to wherever the model's parameters live (no-op if already there).
	device = next(model.parameters()).device

	loss_train = []
	loss_valid = []
	R2_train = []
	R2_valid = []
	# Noise layer used for regularization of the training inputs only;
	# changing the standard deviation might change the output.
	add_noise = GaussianNoise(0.01)
	t_start = time.time()
	for epoch in range(num_epoch):
		# ---- training pass ----
		model.train()  # train-time behaviour for layers such as dropout / batch-norm
		epoch_train_loss = 0.0
		R2_train_running = 0.0
		for inputs, targets in train_loader:
			inputs = add_noise(inputs.to(device))
			# Insert a channel axis: (batch, length) -> (batch, 1, length)
			targets = targets.to(device).view(targets.shape[0], 1, targets.shape[1])
			opt.zero_grad()
			pred = model(inputs)
			loss = criterion(pred, targets)
			loss.backward()
			opt.step()
			epoch_train_loss += loss.item()
			# Detach so the running metric never holds on to the autograd graph.
			R2_train_running += r2_score(targets, pred.detach())
		epoch_train_loss /= n_train_batches
		R2_train_running /= n_train_batches

		# ---- validation pass ----
		model.eval()  # inference behaviour (e.g. batch-norm running statistics)
		epoch_valid_loss = 0.0
		R2_valid_running = 0.0
		with torch.no_grad():
			for inputs, targets in valid_loader:
				inputs = inputs.to(device)
				targets = targets.to(device).view(targets.shape[0], 1, targets.shape[1])
				pred = model(inputs)
				epoch_valid_loss += criterion(pred, targets).item()
				R2_valid_running += r2_score(targets, pred)
		epoch_valid_loss /= n_valid_batches
		R2_valid_running /= n_valid_batches

		loss_train.append(epoch_train_loss)
		loss_valid.append(epoch_valid_loss)
		R2_train.append(R2_train_running)
		R2_valid.append(R2_valid_running)
		print(f'''epoch: {epoch+1:3}/{num_epoch:3}  Training_loss: {epoch_train_loss:.5e}  Validation_loss: {epoch_valid_loss:.5e}
		 R2_Training: {R2_train_running:.5}  R2_Validation: {R2_valid_running:.5}''')
	t_end = time.time()
	print("=================================================")
	print(f"Training time is {(t_end-t_start)/60} minutes.")
	print("=================================================")

	return model,np.array(loss_train),np.array(loss_valid),np.array(R2_train),np.array(R2_valid)





## Define the training hyperparameters

In [None]:
# Optimization settings
LR = 0.001       # learning rate handed to the Adam optimizer
batchsz = 32     # mini-batch size
num_epoch = 100  # number of passes over the training set

# Network settings
chanl = 2        # number of input channels of the UNet
feat = 16        # init_features of the UNet

netname = 'unet'  # tag appended to the saved file names



## Reading the data 

In [None]:

	# reading the inputs and targets
# Load the inverted, true and initial models for every file index in [ifile, endfile].
import os

data_path = '<data path>'  # directory holding inv_m<k>.npy, true_m<k>.npy and init_m<k>.npy

ifile = 1  # initial file index
endfile = 8000  # final file index (inclusive)
nfiles = endfile - ifile + 1


def _load_models(directory, k):
	"""Return the raw (inverted, true, initial) arrays stored under file index ``k``."""
	return tuple(
		np.load(os.path.join(directory, f'{prefix}{k}.npy'))
		for prefix in ('inv_m', 'true_m', 'init_m')
	)


# Read the first file to learn the model length
inp2, oup2, init2 = _load_models(data_path, ifile)  # inverted, true and FWI-initial models
print('shape of the first reading file ',init2.shape)

# Allocate one row per file; the last axis of each file holds the model samples.
nt = inp2.shape[-1]
inp = np.zeros((nfiles, nt))
oup = np.zeros((nfiles, nt))
init = np.zeros((nfiles, nt))
print('number of allocating array  for data is ',init.shape)

# Row 0 is the file already read; rows 1..nfiles-1 hold files ifile+1..endfile,
# so every file is read exactly once.
inp[0, :] = inp2.reshape(-1)
oup[0, :] = oup2.reshape(-1)
init[0, :] = init2.reshape(-1)
for row, k in enumerate(range(ifile + 1, endfile + 1), start=1):
	inp_tmp, oup_tmp, init_tmp = _load_models(data_path, k)
	inp[row, :] = inp_tmp.reshape(-1)
	oup[row, :] = oup_tmp.reshape(-1)
	init[row, :] = init_tmp.reshape(-1)
	if k % 1000 == 0:
		print('number of models', row + 1)

print ('loaded shapes for input and output',inp.shape, oup.shape)
path = './output/'  # from here on ``path`` is the output directory used by later cells
plot_models1D2(inp,oup,init,inp.shape[0],5,5)

## Reshape the data and concatenate to create the input for the network 

In [None]:
# Use the inverted model (channel 0) and the initial model (channel 1) as a
# two-channel network input: (N, nt) -> (N, 2, nt)
inp = np.stack((inp, init), axis=1)
init = init[:, np.newaxis, :]  # keep ``init`` as (N, 1, nt)

In [None]:
# Normalize the data by the salt velocity
import os

salt_velocity = 4.5  # normalization constant (salt velocity; presumably same units as the models — confirm)
inp = inp/salt_velocity    # scales both input channels (inverted and initial model)
oup = oup/salt_velocity
init = init/salt_velocity  # the initial model is also channel 1 of ``inp``


# split training and validation (80/20, random, not seeded)
x_train,x_valid, y_train, y_valid = train_test_split(inp,oup, test_size=0.2)
ntrain = x_train.shape[0]
nvalid = x_valid.shape[0]

# Persist the exact split; np.save does not create missing directories.
model_dir = path + 'NNmodel/'
os.makedirs(model_dir, exist_ok=True)
np.save(model_dir+'xtrain_%s'%netname,x_train)
np.save(model_dir+'ytrain_%s'%netname,y_train)
np.save(model_dir+'xvalid_%s'%netname,x_valid)
np.save(model_dir+'yvalid_%s'%netname,y_valid)




## Prepare the data for PyTorch and define the network

In [None]:
for label, value in (
	('shape of the training data is ', x_train.shape),
	('number of training: ', ntrain),
	('number of validation: ', nvalid),
	('Batch: ', batchsz),
):
	print(label, value)

# Wrap the numpy splits in the project Dataset so DataLoader can batch them
training_data, valid_data = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

# UNet with ``chanl`` input channels and a single output channel, placed on the GPU
# (Module.cuda() moves the parameters in place and returns the module itself)
model = UNet(in_channels=chanl, out_channels=1, init_features=feat).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()


## Run training 

In [None]:

# Run training. Pass the batch size explicitly so that changing the ``batchsz``
# hyperparameter actually changes the DataLoader batch size.
model, loss_train, loss_valid, r2_train, r2_valid = run_training(
	model, optimizer, criterion, training_data, valid_data,
	num_epoch=num_epoch, batchsz=batchsz)


print('Done training yaaaaaay --------')


## Save the model and the loss/R2 histories, then plot them

In [None]:

model_dir = path + 'NNmodel/'
torch.save(model.state_dict(), model_dir + netname)

# One .npy file per per-epoch curve (np.save appends the ".npy" suffix)
for prefix, curve in (
	('Training_loss', loss_train),
	('Validation_loss', loss_valid),
	('Training_R2', r2_train),
	('Validation_R2', r2_valid),
):
	np.save(model_dir + prefix + netname, np.array(curve))

# Plot the loss history first, then the R2 history
for train_curve, valid_curve in ((loss_train, loss_valid), (r2_train, r2_valid)):
	plot_history(train_curve, valid_curve, netname)

