In [None]:
import numpy as np
import torch
import utility as util
import network as net

# data var
xLength             = 64      # passed as xLen: leading positions (axis 1) used as model input
yLength             = 512     # passed as yLen to the model call
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on a machine without a usable GPU.
device              = 'cuda:0' if torch.cuda.is_available() else 'cpu'
version             = '_1'

# Read the 'FID' array inside a context manager so the .npz archive (a zip file
# handle) is closed as soon as the array has been loaded into memory.
with np.load('simul_Data.npz') as simulArchive:
    dataSimulFID    = simulArchive['FID']

# train var
trainNum            = 10000   # first trainNum samples are used for training
validNum            = 1000    # the next validNum samples are used for validation
minibatchSize       = 128     # training batch size
epochNum            = 30      # number of passes over the training samples
minibatchSize_eval  = 1000    # batch size used when evaluating on the validation samples

learningRate        = 0.001   # Adam learning rate
hiddenDim           = 150     # forwarded to net.rRNN_encoder / net.rRNN_decoder
layerNum            = 1       # forwarded to net.rRNN_encoder / net.rRNN_decoder
teacherForce        = 0.5     # forwarded to the model during training (evaluation passes 0);
                              # presumably the teacher-forcing ratio -- confirm in network.py

######################################################################################################################################
# train data func
def getDataloader(xLen,data,batchSize,shuf):
    """Build a DataLoader of normalized (input, target) pairs split along axis 1.

    Args:
        xLen: Number of leading positions along axis 1 that form the model
            input; the remaining positions form the target.
        data: Array indexed with three indices (``data[:, :xLen, :]``).
        batchSize: Batch size handed to the DataLoader.
        shuf: ``'shufT'`` to shuffle, ``'shufF'`` to keep the sample order.

    Returns:
        A ``torch.utils.data.DataLoader`` whose items are ``(x, y)`` float32
        tensors located on the module-level ``device``.

    Raises:
        ValueError: If ``shuf`` is neither ``'shufT'`` nor ``'shufF'``.
    """
    # Validate the flag before any normalization / device transfer happens, so
    # a typo fails fast with a clear message instead of an UnboundLocalError.
    if shuf == 'shufT':
        shuffle = True
    elif shuf == 'shufF':
        shuffle = False
    else:
        raise ValueError("shuf must be 'shufT' or 'shufF', got {!r}".format(shuf))

    # Input and target are both normalized against the input window only.
    inputWindow = data[:,:xLen,:]

    class RNN_Dataset(torch.utils.data.Dataset):
        def __init__(self):
            self.x_data = torch.FloatTensor(util.Normalize(inputWindow,inputWindow)).to(device)
            self.y_data = torch.FloatTensor(util.Normalize(data[:,xLen:,:],inputWindow)).to(device)

        def __len__(self):
            return len(self.x_data)

        def __getitem__(self, idx):
            return self.x_data[idx], self.y_data[idx]

    return torch.utils.data.DataLoader(RNN_Dataset(), batch_size = batchSize, shuffle = shuffle)

In [2]:
######################################################################################################################################
# train func
def getRNN(xLen,yLen,LearnRate,hiddenNum,layerNum):
    """Train the encoder/decoder RNN on the first ``trainNum`` samples of ``dataSimulFID``.

    Reads the module-level settings ``dataSimulFID``, ``trainNum``, ``validNum``,
    ``minibatchSize``, ``minibatchSize_eval``, ``epochNum``, ``teacherForce``
    and ``device``.

    Args:
        xLen: Number of leading positions (axis 1) used as model input.
        yLen: Length argument forwarded to the model call.
        LearnRate: Learning rate for the Adam optimizer.
        hiddenNum: Forwarded to ``net.rRNN_encoder`` / ``net.rRNN_decoder``.
        layerNum: Forwarded to ``net.rRNN_encoder`` / ``net.rRNN_decoder``.

    Returns:
        The trained ``net.rRNN`` model. It is left in eval mode because the
        last thing each epoch does is the validation pass.

    After every epoch one line is printed with the sample-weighted mean
    training loss of that epoch and the validation loss from ``evalRNN``.
    """
    # data load
    dataloader_train = getDataloader(xLen,dataSimulFID[:trainNum,:,:],minibatchSize,'shufT')
    # model load
    encoder          = net.rRNN_encoder(hiddenNum, layerNum)
    decoder          = net.rRNN_decoder(hiddenNum, layerNum)
    model            = net.rRNN(encoder, decoder).to(device)
    # train
    optimizer        = torch.optim.Adam(model.parameters(), lr=LearnRate)
    criterion        = torch.nn.MSELoss()   # stateless: one instance serves every step

    for epoch in range(epochNum):
        model.train()   # evalRNN() below switches to eval mode, so re-enable every epoch
        costSum_train   = 0.0
        sampleNum_train = 0
        for x_train, y_train in dataloader_train:
            x_predict = model(x_train, y_train, xLen, yLen, teacherForce, device)
            cost      = criterion(x_predict, y_train)
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
            # Accumulate a sample-weighted sum so the reported value is the mean
            # over the whole epoch, not just the (possibly tiny) last minibatch.
            batchNum         = len(y_train)
            costSum_train    = costSum_train + cost.detach() * batchNum
            sampleNum_train += batchNum

        # No batches at all -> there is no mean to report.
        cost_train = float(costSum_train)/sampleNum_train if sampleNum_train else float('nan')
        cost_valid = evalRNN(xLen,yLen,model,dataSimulFID[trainNum:trainNum+validNum,:,:],minibatchSize_eval).item()
        print(epoch+1,'/',epochNum,
                'Train_cost: ',round(cost_train,6),
                'Valid_cost: ',round(cost_valid,6)
                )
    return model

def evalRNN(xLen,yLen,model_eval,data_eval,batchSize_eval):
    """Return the mean MSE of ``model_eval`` over ``data_eval``.

    The model is switched to eval mode and run without gradient tracking, with
    0 passed in the position where training passes ``teacherForce``.

    Args:
        xLen: Number of leading positions (axis 1) used as model input.
        yLen: Length argument forwarded to the model call.
        model_eval: Model to evaluate; left in eval mode on return.
        data_eval: Samples to evaluate on (same layout as the training data).
        batchSize_eval: Batch size for the evaluation DataLoader.

    Returns:
        0-dim tensor holding the sample-weighted mean of the per-batch MSE
        values (NaN when the loader yields no samples).
    """
    dataloader_eval = getDataloader(xLen,data_eval,batchSize_eval,'shufF')
    criterion       = torch.nn.MSELoss()
    model_eval.eval()
    # Start from a tensor so callers can always use .item() on the result.
    costSum_eval    = torch.zeros((), device=device)
    sampleNum_eval  = 0
    with torch.no_grad():
        for x_eval, y_eval in dataloader_eval:
            x_predict       = model_eval(x_eval, y_eval, xLen, yLen,  0, device)
            # Weight each batch mean by its size: dividing the plain sum by
            # len(data)/batchSize is wrong whenever the last batch is partial.
            batchNum        = len(y_eval)
            costSum_eval    = costSum_eval + criterion(x_predict, y_eval) * batchNum
            sampleNum_eval += batchNum
    return costSum_eval/sampleNum_eval
##################################################################################################################

Running

In [None]:
# Train the network with the hyper-parameters from the configuration cell.
model   = getRNN(
    xLen      = xLength,
    yLen      = yLength,
    LearnRate = learningRate,
    hiddenNum = hiddenDim,
    layerNum  = layerNum,
)

Reconstruction

In [None]:
def getFID_RNN(xLen,yLen,model_eval,data_eval,batchSize_eval):
    """Reconstruct full-length signals: the given input window followed by the model prediction.

    Args:
        xLen: Number of leading positions (axis 1) used as model input.
        yLen: Length argument forwarded to the model call.
        model_eval: Trained model; switched to eval mode here.
        data_eval: Samples to reconstruct (same layout as the training data).
        batchSize_eval: Batch size for the evaluation DataLoader.

    Returns:
        Whatever ``util.returnNormalize`` returns for the concatenated
        (input, prediction) array, using ``data_eval[:, :xLen, :]`` as reference
        (presumably the de-normalized signal -- confirm in utility.py).
    """
    dataloader_eval = getDataloader(xLen,data_eval,batchSize_eval,'shufF')
    model_eval.eval()
    # Collect per-batch results and concatenate once; growing a tensor with
    # torch.cat inside the loop re-copies everything gathered so far each time.
    predChunks = []
    with torch.no_grad():
        for x_eval, y_eval in dataloader_eval:
            x_predict = model_eval(x_eval, y_eval, xLen, yLen,  0, device)
            # Join the input window and the prediction along axis 1.
            predChunks.append(torch.cat([x_eval,x_predict], 1))
    if predChunks:
        predFID_eval_normal = torch.cat(predChunks, 0)
    else:
        # Same empty tensor the accumulator used to start from.
        predFID_eval_normal = torch.Tensor([]).to(device)
    predFID_eval = util.returnNormalize(predFID_eval_normal.detach().cpu().numpy(),data_eval[:,:xLen,:])
    return predFID_eval

In [4]:
# Take the last ten samples and reconstruct them in a single batch of 10.
test_FID  = dataSimulFID[-10:]
recon_FID = getFID_RNN(
    xLen           = xLength,
    yLen           = yLength,
    model_eval     = model,
    data_eval      = test_FID,
    batchSize_eval = 10,
)

In [None]:
import matplotlib.pyplot as plt
# Overlay channel 0 of the first test sample with the real part of its
# reconstruction, on an explicit Axes so the figure is labelled and self-contained.
fig, ax = plt.subplots()
ax.plot(test_FID[0,:,0], label='test_FID[0,:,0]')
ax.plot(recon_FID[0,:].real, label='recon_FID[0,:].real')
ax.set_xlabel('index along axis 1')
ax.set_ylabel('value')
ax.set_title('Test sample 0: original vs. reconstruction')
ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt