# Improved balance data


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import data
from dataset import rdDataset
from model import rdcnn_2
from math import log10


# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
# The run was configured for the second GPU ("cuda:1"); fall back to the first
# GPU when only one is visible instead of failing with an invalid device ordinal.
if use_cuda:
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
path = './data'

# Parameters
# Forwarded as keyword arguments to data.DatasetSplit (project helper;
# presumably builds the train/test DataLoaders — confirm in data.py).
params = {'test_split': .25,
          'shuffle_dataset': True,
          'batchsize': 32,
          'testBatchsize': 10,
          'random_seed': 42,
          'numworkers': 32,
          'pinmemory': True}
max_epoches = 100
learning_rate = 1e-3
drop_rate = 0.0  # passed to rdcnn_2; presumably a dropout probability — TODO confirm

print('===> Loading datasets')
# Load the whole dataset found under `path`.
dataset = rdDataset(path)

# Split into training and testing loaders.
training_data_loader, testing_data_loader = data.DatasetSplit(dataset, **params)

print('===> Building model')
model = rdcnn_2(drop_rate).to(device)
criterion = nn.MSELoss()
# Adam with a small L2 penalty via weight_decay.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)



def train(epoch):
    """Run one optimisation pass over ``training_data_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and ``device``.

    Args:
        epoch: Epoch number; used only in the progress message and echoed back.

    Returns:
        ``(epoch, avg_loss)`` where ``avg_loss`` is the mean per-batch
        criterion value over the epoch.
    """
    # Evaluation code may have left the model in eval mode; make sure layers
    # such as dropout/batch-norm (if rdcnn_2 has any) run in training mode.
    model.train()
    epoch_loss = 0.0
    for batch in training_data_loader:
        inputs = batch[0].to(device, torch.float)
        targets = batch[1].to(device, torch.float)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(training_data_loader)
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))
    return epoch, avg_loss
    
def test():
    """Evaluate the module-level ``model`` on ``testing_data_loader``.

    Two metrics are summed per batch and divided by the number of batches:

    * loss  - ``criterion(prediction, target)`` for the batch.
    * error - for each sample, mean of ``(prediction - target) ** 2`` divided
      by that sample's maximum target value; averaged over the batch.

    The model is put in eval mode for the pass and then restored to the mode
    it was in before the call.

    Returns:
        ``(avg_loss, avg_error)`` as Python floats.
    """
    n_batches = len(testing_data_loader)
    total_loss = 0.0
    total_error = 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in testing_data_loader:
                inputs = batch[0].to(device, torch.float)
                targets = batch[1].to(device, torch.float)
                prediction = model(inputs)
                # NOTE(review): normalising by the sample's peak target gives
                # inf/nan when that peak is zero or negative.
                per_sample = torch.stack([
                    torch.mean((prediction[j] - targets[j]) ** 2 / torch.max(targets[j]))
                    for j in range(len(prediction))
                ])
                total_error += per_sample.mean().item()
                total_loss += criterion(prediction, targets).item()
    finally:
        # Restore the caller's mode so a following train() call is unaffected.
        model.train(was_training)

    avg_loss = total_loss / n_batches
    avg_error = total_error / n_batches
    print("===> Avg. Loss: {:.4f} ".format(avg_loss))
    print("===> Avg. Error: {:.4f} ".format(avg_error))
    return avg_loss, avg_error

def checkpoint(epoch, out_dir="./checkpoint_largedata5_removetanh-4layerinput", net=None):
    """Pickle a whole model to ``<out_dir>/model_epoch_<epoch>.pth``.

    Args:
        epoch: Epoch number embedded in the file name.
        out_dir: Destination directory; created if it does not exist yet
            (``torch.save`` does not create missing directories).
        net: Module to save; defaults to the module-level ``model``.

    Returns:
        The path the checkpoint was written to.
    """
    import os

    if net is None:
        net = model
    os.makedirs(out_dir, exist_ok=True)
    model_out_path = os.path.join(out_dir, "model_epoch_{}.pth".format(epoch))
    torch.save(net, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))
    return model_out_path
    
    

In [None]:
# Per-epoch history of train loss, test loss and normalised test error.
L_train_loss = []
L_test_loss = []
L_test_error = []
for epoch in range(1, max_epoches + 1):
    # train() returns (epoch, avg_loss); keep only the loss so the history is numeric.
    _, train_loss = train(epoch)
    test_loss, test_error = test()
    checkpoint(epoch)
    L_train_loss.append(train_loss)
    L_test_loss.append(test_loss)
    L_test_error.append(test_error)

In [None]:
import importlib

# Pick up edits to the local `data` module without restarting the kernel
# (reload returns the same module object).
data = importlib.reload(data)

# Evaluate the checkpoints in this directory on the test split and save the result.
checkpoint_dir = "./checkpoint_largedata5_removetanh-4layerinput/"
test_error_epoch = data.ComputeErrorVsEpoch(checkpoint_dir, device, testing_data_loader)
np.savetxt('relu_4layer_test_error_epoch.out', np.transpose(test_error_epoch))

In [None]:
import importlib
import data

# Reload data.py so local edits take effect, then plot the current model's test errors.
importlib.reload(data)
data.TestErrorPlot(model, device, testing_data_loader)

In [None]:
# Load the final-epoch checkpoint. map_location lets a model saved from a GPU run
# be loaded on whatever `device` is available now.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects whole pickled modules; pass weights_only=False there (trusted files only).
model = torch.load(
    './checkpoint_largedata5_removetanh-4layerinput/model_epoch_{}.pth'.format(max_epoches),
    map_location=device,
)
model.eval()

In [None]:
from matplotlib import pyplot as plt

# Run the model over every test batch; afterwards `input`, `target` and
# `prediction` refer to the LAST batch only.
with torch.no_grad():
    for batch in testing_data_loader:
        input = batch[0].to(device, torch.float)
        target = batch[1].to(device, torch.float)
        prediction = model(input)

In [None]:
from matplotlib import pyplot as plt
import itertools

# Collect inputs, targets and predictions (left on `device`) for the first
# 10 test batches; the plotting cells below index these lists.
prediction_L = []
input_L = []
target_L = []

with torch.no_grad():
    for batch in itertools.islice(testing_data_loader, 10):
        input, target = batch[0].to(device, torch.float), batch[1].to(device, torch.float)
        prediction = model(input)
        input_L.append(input)
        target_L.append(target)
        prediction_L.append(prediction)

In [None]:
for i in range(10):
    input = input_L[i].cpu().numpy()
    target = target_L[i]
    # Ten samples per batch, drawn as two figures of five panels each.
    for row in range(2):
        fig, ax = plt.subplots(1, 5, figsize=(20, 5))
        for col in range(5):
            t = 5 * row + col
            im = ax[col].imshow(target[t][0].cpu(), cmap="jet")
            ax[col].axis('off')
            # Title shows input channels 3, 2, 1 (labelled D, K, t) read at pixel [0][0].
            ax[col].set_title("D = %s  K = %s  t = %s" % (input[t][3][0][0], input[t][2][0][0], input[t][1][0][0]), size=10)
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.84, 0.27, 0.01, 0.47])
        fig.colorbar(im, cax=cbar_ax)


plt.show()

In [None]:
for i in range(10):
    input = input_L[i].cpu().numpy()
    prediction = prediction_L[i]
    # Ten predicted fields per batch, drawn as two figures of five panels each.
    for row in range(2):
        fig, ax = plt.subplots(1, 5, figsize=(20, 5))
        for col in range(5):
            t = 5 * row + col
            im = ax[col].imshow(prediction[t][0].cpu(), cmap="jet")
            ax[col].axis('off')
            # Title shows input channels 3, 2, 1 (labelled D, K, t) read at pixel [0][0].
            ax[col].set_title("D = %s  K = %s  t = %s" % (input[t][3][0][0], input[t][2][0][0], input[t][1][0][0]), size=10)
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.84, 0.27, 0.01, 0.47])
        fig.colorbar(im, cax=cbar_ax)


plt.show()

In [None]:
for i in range(len(input_L)):
    input = input_L[i].cpu().numpy()
    target = target_L[i]
    prediction = prediction_L[i]
    for t in range(len(prediction)):
        pred_img = prediction[t][0].cpu()
        true_img = target[t][0].cpu()
        # Both panels share one colour scale. Without this, each imshow autoscales
        # independently and the single colorbar only describes the ground truth.
        vmin = min(pred_img.min().item(), true_img.min().item())
        vmax = max(pred_img.max().item(), true_img.max().item())

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(pred_img, cmap="jet", vmin=vmin, vmax=vmax)
        ax[0].axis('off')
        ax[0].set_title("Prediction")
        im = ax[1].imshow(true_img, cmap="jet", vmin=vmin, vmax=vmax)
        ax[1].axis('off')
        ax[1].set_title("Ground Truth data")

        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.84, 0.27, 0.01, 0.47])
        fig.colorbar(im, cax=cbar_ax)
        # Caption shows input channels 3, 2, 1 (labelled D, K, t) read at pixel [0][0].
        fig.text(0.35, 0.1, "D = " + str(input[t][3][0][0]) + "  K = " + str(input[t][2][0][0]) + "  t = " + str(input[t][1][0][0]), fontsize=10)
    plt.show()

In [None]:
from matplotlib import pyplot as plt

# Compute a per-sample normalised error for the whole test split, and refill
# the plotting lists with the first 10 batches (they are reset here and the
# following plotting cell indexes them).
prediction_L = []
input_L = []
target_L = []
error_L = []  # one error value per test sample, in loader order

model.eval()
with torch.no_grad():
    for i, batch in enumerate(testing_data_loader):
        input, target = batch[0].to(device, torch.float), batch[1].to(device, torch.float)
        prediction = model(input)
        for j in range(len(prediction)):
            # Same per-sample metric as test(): squared error divided by the
            # sample's maximum target value, averaged over all elements.
            sample_error = torch.mean((prediction[j] - target[j]) ** 2 / torch.max(target[j]))
            error_L.append(sample_error.item())
        if i < 10:
            input_L.append(input)
            target_L.append(target)
            prediction_L.append(prediction)
        

In [None]:
for i in range(10):
    input = input_L[i].cpu().numpy()
    target = target_L[i]
    # Ten ground-truth fields per batch, drawn as two figures of five panels each.
    for row in range(2):
        fig, ax = plt.subplots(1, 5, figsize=(20, 5))
        for col in range(5):
            t = 5 * row + col
            im = ax[col].imshow(target[t][0].cpu(), cmap="jet")
            ax[col].axis('off')
            # Title shows input channels 3, 2, 1 (labelled D, K, t) read at pixel [0][0].
            ax[col].set_title("D = %s  K = %s  t = %s" % (input[t][3][0][0], input[t][2][0][0], input[t][1][0][0]), size=10)
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.84, 0.27, 0.01, 0.47])
        fig.colorbar(im, cax=cbar_ax)


plt.show()  #%%
