In [8]:
import numpy as np
import glob

import torch
import torch.nn as nn

from Utils.PTModel.Models import LSTMModel

# Name of the sub-directory under Models/ that the cells below use both to
# locate the level embeddings and to store the trained model file.
MODELNAME = "AutoEncoderwATT"

In [13]:
# Directory that holds one embedding file (Level*.npy) per Bubble Bobble level.
bbEmbeddingPath = f"Models/{MODELNAME}/LevelUnifiedRep/BubbleBobble"

# The values 1..13, each repeated 12 times consecutively (156 entries in total).
# NOTE(review): presumably 13 x 12 reflects the level layout - confirm.
columnRefArray = np.repeat(np.arange(1, 14), 12)

# Collect the embedding files and order them by path.
bbEmbeddingPaths = glob.glob(f"{bbEmbeddingPath}/Level*.npy")
bbEmbeddingPaths.sort()

In [14]:
# Build the training arrays: for every level and every offset j, pair the
# zero-padded history of the first j rows with the window of the next N rows.
N = 78
EMBED_DIM = 256  # width of each embedding row and of each one-hot column-reference row

# Constant rows prepended / appended to every target window.
sosArray = np.ones(shape=(1, EMBED_DIM)) * 9
eosArray = np.ones(shape=(1, EMBED_DIM)) * 5

xTrain = []
xTrainHotColumnRef = []  # never filled in this cell; kept in case another cell reads it
yTrain = []
xTrainTargetIn = []

columnRef = []
for levelEmbeddingPath in bbEmbeddingPaths:

    levelEmbeddingArray = np.load(levelEmbeddingPath)

    # NOTE(review): padLength = N - j turns negative once j > N, so this loop
    # assumes a level has at most 2*N + 1 rows - confirm against the data.
    for j in range(len(levelEmbeddingArray) - N):

        padLength = N - j

        # History input: the first j rows, left-padded with zero rows to exactly N rows.
        dataI = np.concatenate((np.zeros(shape=(padLength, EMBED_DIM)), levelEmbeddingArray[:j]), axis=0)

        # Target window: the N rows starting at j. targetIn is the window with the
        # SOS row in front; targetOut is the same window with the EOS row at the end.
        dataT = levelEmbeddingArray[j:j+N]
        targetIn = np.concatenate((sosArray, dataT))
        targetOut = np.concatenate((dataT, eosArray))

        # Column reference aligned with dataI: index 0 for padded rows, otherwise
        # the columnRefArray value of that row, one-hot encoded along the last axis.
        levelIdx = np.concatenate((np.zeros(shape=(padLength)), columnRefArray[:j]), axis=0)
        dataC = np.zeros(shape=(N, EMBED_DIM))
        # Vectorised one-hot; avoids a per-row loop that re-used (shadowed) the name j.
        dataC[np.arange(N), levelIdx.astype(int)] = 1

        columnRef.append(dataC)
        xTrain.append(dataI)
        xTrainTargetIn.append(targetIn)
        yTrain.append(targetOut)

# Stack the per-sample arrays so that axis 0 indexes the training sample.
xTrain = np.array(xTrain)
xTrainTargetIn = np.array(xTrainTargetIn)
yTrain = np.array(yTrain)
columnRef = np.array(columnRef)

In [15]:
def TrainModel(xTrain, xTrainTargetIn, yTrain, columnRef, epochs, batchSize, continueTraining=None, learningRate=0.001):
    """Train a model (new or existing) with Adam against an MSE loss.

    The four data arguments are sliced in lock-step along axis 0 into
    mini-batches, converted to float32 tensors and moved to the GPU when one
    is available, otherwise they stay on the CPU.

    Args:
        xTrain: first model input; its ``shape[0]`` defines the sample count.
        xTrainTargetIn: second model input.
        yTrain: target that the model output is compared with via ``nn.MSELoss``.
        columnRef: third model input.
        epochs: number of full passes over the data.
        batchSize: samples per optimisation step; the last batch may be smaller.
        continueTraining: an existing model to keep training. When ``None`` a
            fresh ``LSTMModel()`` is created.
        learningRate: learning rate handed to ``torch.optim.Adam``.

    Returns:
        The trained model, still in training mode and on the selected device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Identity check: `== None` would be routed through any __eq__ the passed object defines.
    model = LSTMModel() if continueTraining is None else continueTraining

    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate, eps=1e-7)
    criterion = nn.MSELoss()

    model.to(device)
    model.train()

    losses = []  # one list of per-batch loss values for every epoch

    for i in range(epochs):

        epochLosses = []

        for j in range(0, xTrain.shape[0], batchSize):

            xTrainTensor = torch.tensor(xTrain[j:j+batchSize], dtype=torch.float32).to(device)
            xTrainTargetInTensor = torch.tensor(xTrainTargetIn[j:j+batchSize], dtype=torch.float32).to(device)
            yTrainTensor = torch.tensor(yTrain[j:j+batchSize], dtype=torch.float32).to(device)
            columnRefTensor = torch.tensor(columnRef[j:j+batchSize], dtype=torch.float32).to(device)

            yPred = model(xTrainTensor, xTrainTargetInTensor, columnRefTensor)
            loss = criterion(yPred, yTrainTensor)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() already returns a detached Python float, whatever the device.
            epochLosses.append(loss.item())

        losses.append(epochLosses)

        # With no samples there are no batches; report nan rather than dividing by zero.
        meanLoss = sum(epochLosses) / len(epochLosses) if epochLosses else float("nan")
        print(f"Epoch {i}: loss {meanLoss}")

    return model

In [16]:
# Train a fresh model for a single epoch using mini-batches of 32 samples.
model = TrainModel(
    xTrain,
    xTrainTargetIn,
    yTrain,
    columnRef,
    epochs=1,
    batchSize=32,
    learningRate=0.001,
)

Epoch 0: loss 1.052663097864595


In [17]:
# Persist the trained model under the same Models/{MODELNAME}/ directory the
# embeddings were read from. The module object itself is passed (not
# model.state_dict()), so torch.save pickles the whole model and loading it
# back needs the model's class to be importable.
# NOTE(review): the target directory is not created here - presumably it already exists.
torch.save(model, f"Models/{MODELNAME}/BubbleBobbleLSTM.pt")

In [18]:
# Report the element count of every named parameter, followed by the grand total.
paramCounts = {name: param.numel() for name, param in model.named_parameters()}

for name, count in paramCounts.items():
    print(f"{name:<24}: {count:5}")

total = sum(paramCounts.values())

print(f"Total Params: {total}")

histLSTM.weight_ih_l0   : 131072
histLSTM.weight_hh_l0   : 65536
histLSTM.bias_ih_l0     :   512
histLSTM.bias_hh_l0     :   512
colLSTM.weight_ih_l0    : 131072
colLSTM.weight_hh_l0    : 65536
colLSTM.bias_ih_l0      :   512
colLSTM.bias_hh_l0      :   512
textLSTM.weight_ih_l0   : 131072
textLSTM.weight_hh_l0   : 65536
textLSTM.bias_ih_l0     :   512
textLSTM.bias_hh_l0     :   512
outputLayer.weight      : 32768
outputLayer.bias        :   256
Total Params: 625920
