In [3]:
import random

import numpy as np
import torch
import utility as util
import network as net

# data var
xLength             = 64
# use the GPU when one is present, otherwise fall back to CPU so the notebook still runs
device              = 'cuda:0' if torch.cuda.is_available() else 'cpu'
version             = '_1'

# np.load on an .npz keeps the archive open; the context manager closes it once the
# arrays have been read into memory
with np.load('simul_Data.npz') as dataSet:
    dataSimulFID        = dataSet['FID']
    dataSimulConc       = dataSet['Conc']
    dataSimulintactFID  = dataSet['FID_intact']
    intactBasis         = dataSet['Basis']

# train var
trainNum            = 256
validNum            = 10
minibatchSize       = 128
epochNum            = 3
minibatchSize_eval  = 10

learningRate        = 0.001
hiddenDim           = 300
layerNum            = 1
teacherForce        = 0.5

# seed every RNG that may be involved (torch: weight init and DataLoader shuffling;
# numpy; python `random`, in case net.cRNN draws teacher forcing from it — TODO confirm)
# so Restart & Run All reproduces the same numbers
randomSeed          = 0
random.seed(randomSeed)
np.random.seed(randomSeed)
torch.manual_seed(randomSeed)

# training and validation samples are taken from the front of the arrays
assert len(dataSimulFID) >= trainNum + validNum, 'not enough samples for the train/valid split'

######################################################################################################################################
# train data func
def getDataloader(xLen,dataX,dataY,batchSize,shuf):
    """Wrap the first `xLen` time points of (dataX, dataY) in a torch DataLoader.

    Both arrays are normalised with util.Normalize, and both use
    dataX[:, :xLen, :] as the reference. The results are stored as float
    tensors on the module-level `device`.

    Args:
        xLen: number of leading points kept along axis 1.
        dataX: input array, indexed as 3-D ([:, :xLen, :]). It is presumably
            (samples, time, channel); TODO confirm.
        dataY: target array with the same indexing as dataX.
        batchSize: minibatch size of the returned loader.
        shuf: 'shufT' to shuffle each epoch, 'shufF' to keep the original order.

    Returns:
        torch.utils.data.DataLoader yielding (x, y) tensor pairs.

    Raises:
        ValueError: if `shuf` is neither 'shufT' nor 'shufF'. Previously this
            fell through and raised UnboundLocalError.
    """
    shuffleFlags = {'shufT': True, 'shufF': False}
    # validate before any data is touched so a bad flag fails fast and clearly
    if shuf not in shuffleFlags:
        raise ValueError("shuf must be 'shufT' or 'shufF', got %r" % (shuf,))

    class RNN_Dataset(torch.utils.data.Dataset):
        def __init__(self):
            refX = dataX[:,:xLen,:]
            # targets are scaled with the *input's* reference so model outputs can be
            # mapped back with the same reference (see util.returnNormalize usage)
            self.x_data = torch.FloatTensor(util.Normalize(refX,refX)).to(device)
            self.y_data = torch.FloatTensor(util.Normalize(dataY[:,:xLen,:],refX)).to(device)
        def __len__(self):
            return len(self.x_data)

        def __getitem__(self, idx):
            return self.x_data[idx], self.y_data[idx]

    return torch.utils.data.DataLoader(RNN_Dataset(), batch_size = batchSize, shuffle = shuffleFlags[shuf])
######################################################################################################################################

In [6]:
######################################################################################################################################
# train func
def getRNN(xLen,LearnRate,hiddenNum,layerNum):
    """Build the encoder/decoder RNN and train it on the simulated training split.

    Reads these module-level values: dataSimulFID, dataSimulintactFID,
    trainNum, validNum, minibatchSize, minibatchSize_eval, epochNum,
    teacherForce and device.

    Args:
        xLen: number of leading time points used as model input and target.
        LearnRate: Adam learning rate.
        hiddenNum: hidden size passed to net.cRNN_encoder / net.cRNN_decoder.
        layerNum: layer count passed to net.cRNN_encoder / net.cRNN_decoder.

    Returns:
        The trained net.cRNN model. When epochNum > 0, the last call is
        evalRNN, so the model is returned in eval mode.

    Raises:
        ValueError: if the training split is empty.
    """
    # data
    dataloader_train = getDataloader(xLen,dataSimulFID[:trainNum,:,:],dataSimulintactFID[:trainNum,:,:],minibatchSize,'shufT')
    if len(dataloader_train.dataset) == 0:
        raise ValueError('getRNN: training split is empty (check trainNum and the dataset size)')
    validX           = dataSimulFID[trainNum:trainNum+validNum,:,:]
    validY           = dataSimulintactFID[trainNum:trainNum+validNum,:,:]
    # model
    encoder          = net.cRNN_encoder(hiddenNum, layerNum)
    decoder          = net.cRNN_decoder(hiddenNum, layerNum)
    map_first        = net.MappingFirstPoint(xLen)
    model            = net.cRNN(encoder, decoder, map_first).to(device)
    # train
    optimizer        = torch.optim.Adam(model.parameters(), lr=LearnRate)
    cost_train       = torch.nn.MSELoss()  # stateless: build once, not per minibatch

    for epoch in range(epochNum):
        model.train()
        costSum_epoch = 0.0
        sampleNum_epoch = 0
        for x_train, y_train in dataloader_train:
            x_predict         = model(x_train, y_train, xLen, teacherForce, device)
            cost              = cost_train(x_predict, y_train)
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
            # weight each batch mean by its size so a short final batch does not skew
            # the epoch mean; detach so no graph is kept alive across iterations
            costSum_epoch    += cost.detach() * len(x_train)
            sampleNum_epoch  += len(x_train)

        # report the epoch-mean training loss, which is comparable to the averaged
        # validation loss (previously only the last minibatch's loss was printed)
        cost_epoch = float(costSum_epoch) / sampleNum_epoch
        cost_valid = evalRNN(xLen,model,validX,validY,minibatchSize_eval).item()
        print(epoch+1,'/',epochNum,
                'Train_cost: ',round(cost_epoch,6),
                'Valid_cost: ',round(cost_valid,6)
                )
    return model

def evalRNN(xLen,model_eval,data_evalX,data_evalY,batchSize_eval):
    """Compute the mean-squared error of `model_eval` over an evaluation set.

    The model is run with a teacher-forcing ratio of 0, without gradients and
    in eval mode. The loss is averaged over every sample, so a partial last
    batch is weighted correctly. The old `len/batchSize` divisor was biased
    whenever the sample count was not a multiple of the batch size.

    Args:
        xLen: number of leading time points evaluated.
        model_eval: model called as model_eval(x, y, xLen, 0, device).
        data_evalX: evaluation inputs (indexed like in getDataloader).
        data_evalY: evaluation targets.
        batchSize_eval: minibatch size used for evaluation.

    Returns:
        0-d torch tensor holding the mean loss. Callers use `.item()`.

    Raises:
        ValueError: if the evaluation set is empty.
    """
    sampleNum_eval = len(data_evalX)
    if sampleNum_eval == 0:
        raise ValueError('evalRNN: evaluation set is empty')
    dataloader_eval    = getDataloader(xLen,data_evalX,data_evalY,batchSize_eval,'shufF')
    cost_eval          = torch.nn.MSELoss()
    model_eval.eval()
    costSum_eval       = 0.0
    with torch.no_grad():
        for x_eval, y_eval in dataloader_eval:
            x_predict      = model_eval(x_eval, y_eval, xLen,  0, device)
            # MSELoss averages within the batch; multiply back by the batch size so
            # the final division by the sample count gives the true overall mean
            costSum_eval  += cost_eval(x_predict, y_eval) * len(x_eval)
    return costSum_eval / sampleNum_eval

Training the RNN on the simulated training split

In [None]:
# train on the configured split; keyword arguments make the hyper-parameter mapping explicit
model   = getRNN(xLen=xLength, LearnRate=learningRate, hiddenNum=hiddenDim, layerNum=layerNum)

Reconstruction of the held-out test FIDs and concentration-error (MAPE) evaluation

In [23]:
def getTotalConc(Conc):
    """Append the four combined-metabolite totals to a 17-column concentration array.

    Output column order:
      0-16 copied from Conc:
        ['Ala','Asp','Cr','GABA','Glc','Gln','Glu','GPC','GSH','Lac',
         'mI','NAA','NAAG','PC','PCr','PE','Tau']
      17 Glx  = Gln  + Glu
      18 tCho = GPC  + PC
      19 tCr  = PCr  + Cr
      20 tNAA = NAAG + NAA

    Args:
        Conc: 2-D numpy array with 17 columns in the order above.

    Returns:
        float64 numpy array of shape (len(Conc), 21).
    """
    baseCount = 17
    # (first, second) source-column pairs, in the order their totals are appended
    combinedPairs = (
        (5, 6),     # Glx
        (7, 13),    # tCho
        (14, 2),    # tCr
        (12, 11),   # tNAA
    )
    result = np.zeros((len(Conc), baseCount + len(combinedPairs)))
    result[:, :baseCount] = Conc[:, :]
    for offset, (first, second) in enumerate(combinedPairs):
        result[:, baseCount + offset] = Conc[:, first] + Conc[:, second]
    return result

def getConc(X_intactFID,Y_conc,intactBasis):
    """Estimate concentrations from FIDs by linear least squares and score them.

    The real part of each FID is projected onto the pseudo-inverse of the real
    part of `intactBasis`. Both the estimates and the reference concentrations
    are then expanded with getTotalConc.

    Args:
        X_intactFID: FIDs to fit. Only `.real` is used.
        Y_conc: reference concentrations (17 columns, as getTotalConc expects).
        intactBasis: basis set. Only `.real` is used, via np.linalg.pinv.

    Returns:
        (X_conc, MAPE):
        - X_conc: estimated concentrations with the 4 totals appended.
        - MAPE: per-column mean of 100*|X_conc - Y|/Y over samples. A zero
          reference value yields inf/nan in that column.
    """
    basisPinv  = np.linalg.pinv(intactBasis.real)
    estimated  = getTotalConc(np.dot(X_intactFID.real, basisPinv))
    reference  = getTotalConc(Y_conc)
    pctError   = (100 * np.abs(estimated - reference)) / reference
    return estimated, np.mean(pctError, axis=0)

def getintactFID_RNN(xLen,model_eval,data_evalX,batchSize_eval):
    """Run the trained model over `data_evalX` and return de-normalised predictions.

    The inputs are passed as both x and y to the loader. With a
    teacher-forcing ratio of 0 the y argument is presumably unused by the
    model (TODO confirm in net.cRNN). Predictions are mapped back with
    util.returnNormalize, using data_evalX[:, :xLen, :] as the reference,
    which matches the normalisation in getDataloader.

    Args:
        xLen: number of leading time points to reconstruct.
        model_eval: model called as model_eval(x, y, xLen, 0, device).
        data_evalX: FIDs to reconstruct.
        batchSize_eval: minibatch size used for inference.

    Returns:
        Whatever util.returnNormalize returns for the stacked predictions.

    Raises:
        ValueError: if `data_evalX` contains no samples.
    """
    dataloader_eval    = getDataloader(xLen,data_evalX,data_evalX,batchSize_eval,'shufF')
    model_eval.eval()
    # collect batches in a list and concatenate once: torch.cat inside the loop
    # re-copies the accumulator on every batch and relied on legacy concatenation
    # with an empty 1-D tensor
    predChunks = []
    with torch.no_grad():
        for x_eval, y_eval in dataloader_eval:
            predChunks.append(model_eval(x_eval, y_eval, xLen,  0, device))
    if not predChunks:
        raise ValueError('getintactFID_RNN: no samples to reconstruct')
    predFID_eval_normal = torch.cat(predChunks, 0)
    preFID_eval = util.returnNormalize(predFID_eval_normal.cpu().numpy(),data_evalX[:,:xLen,:])
    return preFID_eval

In [24]:
testNum             = 10
# the test split comes from the end of the arrays, while train/valid come from the
# front; make sure they cannot overlap, or the reported MAPE would be optimistic
assert len(dataSimulFID) >= trainNum + validNum + testNum, 'test split overlaps the train/valid samples'
test_FID            = dataSimulFID[-testNum:,:xLength,:]
test_Conc           = dataSimulConc[-testNum:,:]
recon_FID           = getintactFID_RNN(xLength,model,test_FID,testNum)
recon_Conc, MAPE    = getConc(recon_FID,test_Conc,intactBasis)
MAPE

In [None]:
import matplotlib.pyplot as plt

# compare one test sample: the model input, its reconstruction, and the simulated
# intact FID it should match (the test split is the tail of the arrays)
sampleIdx = 0
targetRow = len(dataSimulintactFID) - len(test_FID) + sampleIdx

fig, ax = plt.subplots()
ax.plot(test_FID[sampleIdx,:,0], label='input FID (channel 0)')
ax.plot(dataSimulintactFID[targetRow,:xLength,0], label='target intact FID (channel 0)')
ax.plot(recon_FID[sampleIdx,:].real, label='reconstructed FID (real part)')
ax.set(title='Test sample %d: first %d points' % (sampleIdx, xLength),
       xlabel='time point index', ylabel='amplitude')
ax.legend()
plt.show()