In [4]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from Utils.PTModel.Models import VGLCLSTMModel

# Name of the model family; used as the sub-directory under Models/ when the
# trained model is written to disk.
# NOTE(review): spelled "VGCL" here while the imported class is
# "VGLCLSTMModel" ("VGLC") — confirm which spelling is intended before
# renaming anything, since this string is part of an on-disk path.
MODELNAME = "VGCLLSTM"

In [84]:
import Utils.fastnumpyio as fnp

def TrainModelFromFiles(batchPaths, epochs, batchSize):
    """Train a fresh VGLCLSTMModel from batches stored as groups of four files.

    ``batchPaths`` is read in consecutive groups of four, indexed as
    ``[columnRef, xTrain, xTrainTargetIn, yTrain]``. Every file is loaded with
    ``fnp.load``; the three train arrays get a trailing axis of size 1 and are
    fed to the model in slices of ``batchSize`` rows together with the matching
    ``columnRef`` slice. The loss is MSE between the model output and yTrain.

    Args:
        batchPaths: Sequence of file paths, four per stored batch, in the
            order described above.
        epochs: Number of passes over the file groups.
        batchSize: Number of rows per optimisation step.

    Returns:
        The trained model, still in training mode and on the selected device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Place the model on the device before building the optimizer over its
    # parameters.
    model = VGLCLSTMModel()
    model.to(device)
    model.train()

    # NOTE(review): lr=0.1 is ten times torch's RMSprop default (0.01); the
    # recorded run printed losses going from ~1.9 to ~10 within the first
    # epoch, so the step size may be too large — confirm before a long run.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1, eps=1e-7)
    criterion = nn.MSELoss()

    def _toTensor(array, start):
        """Return rows ``start:start+batchSize`` of ``array`` as float32 on ``device``."""
        return torch.tensor(array[start:start + batchSize], dtype=torch.float32).to(device)

    losses = []

    for i in range(epochs):

        epochLosses = []
        losses.append(epochLosses)

        # NOTE(review): the upper bound is len(batchPaths)//2, so only the
        # first half of the supplied paths is ever visited. Kept as-is because
        # it may be a deliberate hold-out — confirm.
        for t in range(0, len(batchPaths)//2, 4):
            print(t)
            print(batchPaths[t:t+4])

            xTrain = fnp.load(batchPaths[t+1])
            xTrainTargetIn = fnp.load(batchPaths[t+2])
            yTrain = fnp.load(batchPaths[t+3])
            columnRef = fnp.load(batchPaths[t])

            numRows = xTrain.shape[0]
            if numRows == 0:
                # Nothing to train on; without this guard the per-file print
                # below would index an empty list or report a stale loss.
                print(f"Epoch {i} Batch {t}: no samples, skipped")
                continue

            # Add the trailing axis once per file instead of on every
            # mini-batch; the result does not depend on the inner loop.
            xTrain = xTrain.reshape(numRows, xTrain.shape[1], 1)
            xTrainTargetIn = xTrainTargetIn.reshape(xTrainTargetIn.shape[0], xTrainTargetIn.shape[1], 1)
            yTrain = yTrain.reshape(yTrain.shape[0], yTrain.shape[1], 1)

            for j in range(0, numRows, batchSize):

                xTrainTensor = _toTensor(xTrain, j)
                xTrainTargetInTensor = _toTensor(xTrainTargetIn, j)
                yTrainTensor = _toTensor(yTrain, j)
                columnRefTensor = _toTensor(columnRef, j)

                yPred = model(xTrainTensor, xTrainTargetInTensor, columnRefTensor)
                loss = criterion(yPred, yTrainTensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # .item() already yields a plain Python float.
                epochLosses.append(loss.item())

            # Loss of the last mini-batch of this file group.
            print(f"Epoch {i} Batch {t}: loss {epochLosses[-1]}")

        if epochLosses:
            print(f"Epoch {i}: loss {sum(epochLosses)/len(epochLosses)}")
        else:
            # Avoid dividing by zero when no optimisation step ran.
            print(f"Epoch {i}: no batches trained")

    return model

In [85]:
# Training run configuration.
DATA_GLOB = "lrVGCLLSTMData/*"
EPOCHS = 20
BATCH_SIZE = 32

# Sorting keeps the four files of each stored batch adjacent, in the order
# columnRef, xTrain, xTrainTargetIn, yTrain — the order in which
# TrainModelFromFiles indexes each group of four paths.
batchPaths = sorted(glob.glob(DATA_GLOB))
if not batchPaths:
    # Fail early with a clear message instead of training on nothing.
    raise FileNotFoundError(f"No training files matched {DATA_GLOB!r}")

model = TrainModelFromFiles(batchPaths, EPOCHS, BATCH_SIZE)

0
['lrVGCLLSTMData/batch0columnRef.npy', 'lrVGCLLSTMData/batch0xTrain.npy', 'lrVGCLLSTMData/batch0xTrainTargetIn.npy', 'lrVGCLLSTMData/batch0yTrain.npy']
Epoch 0 Batch 0: loss 1.9039191007614136
4
['lrVGCLLSTMData/batch10columnRef.npy', 'lrVGCLLSTMData/batch10xTrain.npy', 'lrVGCLLSTMData/batch10xTrainTargetIn.npy', 'lrVGCLLSTMData/batch10yTrain.npy']
Epoch 0 Batch 4: loss 10.220907211303711
8
['lrVGCLLSTMData/batch11columnRef.npy', 'lrVGCLLSTMData/batch11xTrain.npy', 'lrVGCLLSTMData/batch11xTrainTargetIn.npy', 'lrVGCLLSTMData/batch11yTrain.npy']
Epoch 0 Batch 8: loss 8.978261947631836
12
['lrVGCLLSTMData/batch12columnRef.npy', 'lrVGCLLSTMData/batch12xTrain.npy', 'lrVGCLLSTMData/batch12xTrainTargetIn.npy', 'lrVGCLLSTMData/batch12yTrain.npy']
Epoch 0 Batch 12: loss 8.088552474975586
16
['lrVGCLLSTMData/batch13columnRef.npy', 'lrVGCLLSTMData/batch13xTrain.npy', 'lrVGCLLSTMData/batch13xTrainTargetIn.npy', 'lrVGCLLSTMData/batch13yTrain.npy']
Epoch 0 Batch 16: loss 10.442082405090332
20
['lr

KeyboardInterrupt: 

In [63]:
# Persist the trained model under Models/<MODELNAME>/.
modelDir = f"Models/{MODELNAME}"
# Create the target directory first so a finished training run is not lost to
# a missing-directory error at save time.
os.makedirs(modelDir, exist_ok=True)
# NOTE(review): this pickles the whole module object, so loading it later
# needs the model class to be importable; saving model.state_dict() would be
# more portable but changes the on-disk format existing loaders expect.
torch.save(model, f"{modelDir}/LodeRunnerLSTM.pt")

In [64]:
# Report the element count of every named parameter, then the overall sum.
paramCounts = [(paramName, tensor.numel()) for paramName, tensor in model.named_parameters()]

for paramName, count in paramCounts:
    print(f"{paramName:<24}: {count:5}")

total = sum(count for _, count in paramCounts)

print(f"Total Params: {total}")

histLSTM.weight_ih_l0   :   512
histLSTM.weight_hh_l0   : 65536
histLSTM.bias_ih_l0     :   512
histLSTM.bias_hh_l0     :   512
colLSTM.weight_ih_l0    : 131072
colLSTM.weight_hh_l0    : 65536
colLSTM.bias_ih_l0      :   512
colLSTM.bias_hh_l0      :   512
textLSTM.weight_ih_l0   :   512
textLSTM.weight_hh_l0   : 65536
textLSTM.bias_ih_l0     :   512
textLSTM.bias_hh_l0     :   512
outputLayer.weight      :  1152
outputLayer.bias        :     9
Total Params: 332937
