In [1]:
import numpy as np
import glob

import torch
import torch.nn as nn
import Utils.fastnumpyio as fnp

from Utils.PTModel.Models import LSTMModel

# Name of the sub-directory under Models/ that this notebook reads training
# data and embeddings from and writes the trained checkpoint to.
MODELNAME = "AutoEncoderwATT"

In [2]:
# Level geometry and embedding locations for the game being modelled.
gameName = "LodeRunner"
rowLength = 32
numOfRows = 22

lrEmbeddingPath = f"Models/{MODELNAME}/LevelUnifiedRep/{gameName}"
# glob returns matches in arbitrary order; sorting keeps the file order
# deterministic between runs.
lrEmbeddingPaths = sorted(glob.glob(f"{lrEmbeddingPath}/Level*.npy"))

# Column indices 0..rowLength-1 repeated once per row for (numOfRows + 5)
# rows, as a flat 1-D array. Derived from rowLength so the two cannot drift
# apart.
# NOTE(review): the 5 extra rows are not explained in this notebook —
# presumably padding/history rows; confirm against the data-generation code.
columnRefArray = np.tile(np.arange(rowLength), numOfRows + 5)

In [3]:
def _toFloatTensor(array, start, stop, device):
    """Copy rows ``start:stop`` of ``array`` into a float32 tensor on ``device``."""
    return torch.tensor(array[start:stop], dtype=torch.float32, device=device)


def TrainModelFromFiles(batchPaths, epochs, batchSize, continueTraining=None, learningRate=0.001):
    """Train a model on array files that are stored in groups of four.

    ``batchPaths`` is consumed four paths at a time. Within each group the
    files are read by position:

    * index 0 -> column reference, passed as the third model argument
    * index 1 -> model input, passed as the first model argument
    * index 2 -> target-side input, passed as the second model argument
    * index 3 -> ground truth compared against the model output with MSE

    Every file is loaded with ``fnp.load`` and then sliced along its first
    axis in steps of ``batchSize``. The number of samples is taken from the
    model-input file; the other three files are assumed to have the same
    leading length (not checked here).

    Args:
        batchPaths: Sequence of file paths whose length is a multiple of 4,
            ordered as described above.
        epochs: Number of passes over all file groups.
        batchSize: Number of samples per optimisation step.
        continueTraining: Existing model to keep training. When ``None`` a
            fresh ``LSTMModel()`` is created.
        learningRate: Learning rate for the RMSprop optimizer.

    Returns:
        The trained model (the same object as ``continueTraining`` when one
        was supplied), left on the training device and in train mode.

    Raises:
        ValueError: If ``len(batchPaths)`` is not a multiple of 4.
    """
    filesPerGroup = 4
    if len(batchPaths) % filesPerGroup != 0:
        # Fail before any training happens instead of hitting an IndexError
        # on the trailing, incomplete group in the middle of an epoch.
        raise ValueError(
            f"batchPaths must contain groups of {filesPerGroup} files, got {len(batchPaths)} paths"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = LSTMModel() if continueTraining is None else continueTraining

    # Move the model to its device before the optimizer captures its
    # parameters, as recommended by the torch.optim documentation.
    model.to(device)
    model.train()

    optimizer = torch.optim.RMSprop(model.parameters(), lr=learningRate, eps=1e-7)
    #optimizer = torch.optim.Adam(model.parameters(), lr=learningRate, eps=1e-7)

    criterion = nn.MSELoss()

    for i in range(epochs):

        epochLosses = []

        for t in range(0, len(batchPaths), filesPerGroup):
            # Files are re-read every epoch so only one group is resident in
            # memory at a time.
            columnRef = fnp.load(batchPaths[t])
            xTrain = fnp.load(batchPaths[t + 1])
            xTrainTargetIn = fnp.load(batchPaths[t + 2])
            yTrain = fnp.load(batchPaths[t + 3])

            lastLoss = None

            for j in range(0, xTrain.shape[0], batchSize):
                stop = j + batchSize

                xTrainTensor = _toFloatTensor(xTrain, j, stop, device)
                xTrainTargetInTensor = _toFloatTensor(xTrainTargetIn, j, stop, device)
                yTrainTensor = _toFloatTensor(yTrain, j, stop, device)
                columnRefTensor = _toFloatTensor(columnRef, j, stop, device)

                yPred = model(xTrainTensor, xTrainTargetInTensor, columnRefTensor)
                loss = criterion(yPred, yTrainTensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lastLoss = loss.item()
                epochLosses.append(lastLoss)

            if lastLoss is None:
                # The group held no samples; report that rather than
                # indexing into an empty (or stale) loss list.
                print(f"Epoch {i} Batch {t}: no samples")
            else:
                print(f"Epoch {i} Batch {t}: loss {lastLoss}")

        if epochLosses:
            print(f"Epoch {i}: loss {sum(epochLosses)/len(epochLosses)}")
        else:
            # Avoid dividing by zero when nothing was trained this epoch.
            print(f"Epoch {i}: no training steps run")

    return model

In [5]:
# Train a fresh model for one epoch over every file in the LSTM data folder.
# The paths are sorted because TrainModelFromFiles reads them positionally in
# groups of four.
model = TrainModelFromFiles(
    batchPaths=sorted(glob.glob(f"Models/{MODELNAME}/LRLSTMData/*")),
    epochs=1,
    batchSize=32,
    learningRate=0.001,
)

0
Epoch 0 Batch 0: loss 0.7136383056640625
4
Epoch 0 Batch 4: loss 0.4626832902431488
8
Epoch 0 Batch 8: loss 0.4953828752040863
12
Epoch 0 Batch 12: loss 0.40381479263305664
16
Epoch 0 Batch 16: loss 0.6392042636871338
20
Epoch 0 Batch 20: loss 0.5832244753837585
24
Epoch 0 Batch 24: loss 0.7669863104820251
28
Epoch 0 Batch 28: loss 0.41166049242019653
32
Epoch 0 Batch 32: loss 0.4107339084148407
36
Epoch 0 Batch 36: loss 0.6719116568565369
40
Epoch 0 Batch 40: loss 0.5179440975189209
44
Epoch 0 Batch 44: loss 0.5352583527565002
48
Epoch 0 Batch 48: loss 0.4715504050254822
52
Epoch 0 Batch 52: loss 0.5633155703544617
56
Epoch 0 Batch 56: loss 0.6713377833366394
60
Epoch 0 Batch 60: loss 0.6178293228149414
64
Epoch 0 Batch 64: loss 0.6185478568077087
68
Epoch 0 Batch 68: loss 0.6466265916824341
72
Epoch 0 Batch 72: loss 0.6233478784561157
76
Epoch 0 Batch 76: loss 0.45325231552124023
80
Epoch 0 Batch 80: loss 0.5382699966430664
84
Epoch 0 Batch 84: loss 0.5210961103439331
88
Epoch 0 Ba

In [6]:
# Persist the trained model next to the other artefacts of this model family.
# NOTE(review): passing the module itself (rather than model.state_dict())
# pickles the whole object, so loading it later needs the same class
# definitions to be importable and should only be done with trusted files.
# Consider saving the state_dict instead if downstream loaders can be updated.
torch.save(model, f"Models/{MODELNAME}/LodeRunnerLSTM.pt")

In [7]:
# Report the element count of every named parameter, then the overall total.
paramSizes = [(name, param.numel()) for name, param in model.named_parameters()]

for name, size in paramSizes:
    print(f"{name:<24}: {size:5}")

total = sum(size for _, size in paramSizes)

print(f"Total Params: {total}")

histLSTM.weight_ih_l0   : 131072
histLSTM.weight_hh_l0   : 65536
histLSTM.bias_ih_l0     :   512
histLSTM.bias_hh_l0     :   512
colLSTM.weight_ih_l0    : 131072
colLSTM.weight_hh_l0    : 65536
colLSTM.bias_ih_l0      :   512
colLSTM.bias_hh_l0      :   512
textLSTM.weight_ih_l0   : 131072
textLSTM.weight_hh_l0   : 65536
textLSTM.bias_ih_l0     :   512
textLSTM.bias_hh_l0     :   512
outputLayer.weight      : 32768
outputLayer.bias        :   256
Total Params: 625920
