In [1]:
import glob

import torch
import torch.nn as nn

from Utils.PTModel.Models import VGLCLSTMModel

# Name of the model's sub-directory under Models/; interpolated into the
# training-data glob and the checkpoint path used by later cells.
# NOTE(review): spelled "VGCL" here while the imported class is VGLCLSTMModel --
# confirm this matches the directory name on disk before "fixing" it.
MODELNAME = "VGCLLSTM"

In [8]:
import Utils.fastnumpyio as fnp

def TrainModelFromFiles(batchPaths, epochs, batchSize, continueTraining=None, learningRate=0.001):
    """Train a model on pre-batched arrays stored on disk.

    ``batchPaths`` is consumed in consecutive groups of four files. Within a
    group, by list position: index 0 is the column reference, 1 is ``xTrain``,
    2 is ``xTrainTargetIn`` and 3 is ``yTrain``. Every file is read with
    ``fnp.load``. ``xTrain``, ``xTrainTargetIn`` and ``yTrain`` get a trailing
    axis of size 1 added; the column reference is used as loaded. Each
    mini-batch is fed to the model as
    ``model(xTrain, xTrainTargetIn, columnRef)`` and optimised with RMSprop
    against an MSE loss.

    Args:
        batchPaths: Sequence of file paths; its length must be a multiple of
            four and the files must be ordered as described above.
        epochs: Number of full passes over ``batchPaths``.
        batchSize: Number of rows (leading axis) per optimiser step.
        continueTraining: Existing model to keep training (it is modified in
            place). When ``None`` a fresh ``VGLCLSTMModel`` is created.
        learningRate: Learning rate passed to RMSprop.

    Returns:
        The trained model, left in train mode on the selected device.

    Raises:
        ValueError: If ``len(batchPaths)`` is not a multiple of four.

    Note:
        A new optimiser is built on every call, so RMSprop's running
        statistics are not carried over when resuming via ``continueTraining``.
    """
    filesPerGroup = 4
    if len(batchPaths) % filesPerGroup != 0:
        # Fail before any training happens instead of hitting an IndexError on
        # the last, incomplete group at the end of the first epoch.
        raise ValueError(
            f"batchPaths must contain a multiple of {filesPerGroup} files, got {len(batchPaths)}"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VGLCLSTMModel() if continueTraining is None else continueTraining

    # Move the model first: torch.optim recommends constructing the optimiser
    # only after the parameters live on their final device.
    model.to(device)
    model.train()

    optimizer = torch.optim.RMSprop(model.parameters(), lr=learningRate, eps=1e-7)

    # NOTE(review): the recorded run of this notebook shows a warning raised
    # from F.mse_loss, which is presumably the "target size differs from input
    # size" warning. If so, yPred and yTrain are being broadcast against each
    # other and this loss is not measuring what is intended -- confirm the
    # model's output shape against yTrain's (N, L, 1).
    criterion = nn.MSELoss()

    losses = []

    for i in range(epochs):

        epochLosses = []
        losses.append(epochLosses)

        for t in range(0, len(batchPaths), filesPerGroup):

            xTrain = fnp.load(batchPaths[t + 1])
            xTrainTargetIn = fnp.load(batchPaths[t + 2])
            yTrain = fnp.load(batchPaths[t + 3])
            columnRef = fnp.load(batchPaths[t])

            # Add the trailing feature axis once per file group rather than on
            # every mini-batch.
            xTrain = xTrain.reshape(xTrain.shape[0], xTrain.shape[1], 1)
            xTrainTargetIn = xTrainTargetIn.reshape(xTrainTargetIn.shape[0], xTrainTargetIn.shape[1], 1)
            yTrain = yTrain.reshape(yTrain.shape[0], yTrain.shape[1], 1)

            groupLoss = None

            for j in range(0, xTrain.shape[0], batchSize):
                rows = slice(j, j + batchSize)

                xTrainTensor = torch.tensor(xTrain[rows], dtype=torch.float32, device=device)
                xTrainTargetInTensor = torch.tensor(xTrainTargetIn[rows], dtype=torch.float32, device=device)
                yTrainTensor = torch.tensor(yTrain[rows], dtype=torch.float32, device=device)
                columnRefTensor = torch.tensor(columnRef[rows], dtype=torch.float32, device=device)

                yPred = model(xTrainTensor, xTrainTargetInTensor, columnRefTensor)
                loss = criterion(yPred, yTrainTensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                groupLoss = loss.item()
                epochLosses.append(groupLoss)

            if groupLoss is None:
                # An empty file group produced no step; do not report a stale
                # loss from an earlier group under this group's label.
                print(f"Epoch {i} Batch {t}: no samples")
            else:
                print(f"Epoch {i} Batch {t}: loss {groupLoss}")

        if epochLosses:
            print(f"Epoch {i}: loss {sum(epochLosses)/len(epochLosses)}")
        else:
            print(f"Epoch {i}: no batches processed")

    return model

In [15]:
# Resume training from the model already in the kernel when there is one. On a
# fresh kernel (Restart & Run All) no `model` exists yet, so fall back to None,
# which makes TrainModelFromFiles build a new model instead of raising NameError.
model = TrainModelFromFiles(
    sorted(glob.glob(f"Models/{MODELNAME}/LRLSTMData/*")),
    20,
    32,
    learningRate=0.0001,
    continueTraining=globals().get("model"),
)

0


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 0 Batch 0: loss 1.8350787162780762
4
Epoch 0 Batch 4: loss 10.182848930358887
8
Epoch 0 Batch 8: loss 8.909875869750977
12
Epoch 0 Batch 12: loss 8.01806640625
16
Epoch 0 Batch 16: loss 10.411195755004883
20
Epoch 0 Batch 20: loss 19.39141082763672
24
Epoch 0 Batch 24: loss 9.534256935119629
28
Epoch 0 Batch 28: loss 10.586026191711426
32
Epoch 0 Batch 32: loss 9.234169960021973
36
Epoch 0 Batch 36: loss 6.598904609680176
40
Epoch 0 Batch 40: loss 10.497014999389648
44
Epoch 0 Batch 44: loss 14.231578826904297
48
Epoch 0 Batch 48: loss 5.560441970825195
52
Epoch 0 Batch 52: loss 11.041385650634766
56
Epoch 0 Batch 56: loss 7.158257961273193
60
Epoch 0 Batch 60: loss 12.300336837768555
64
Epoch 0 Batch 64: loss 9.481467247009277
68
Epoch 0 Batch 68: loss 7.6795172691345215
72
Epoch 0 Batch 72: loss 10.083869934082031
76
Epoch 0 Batch 76: loss 8.255311965942383
80
Epoch 0 Batch 80: loss 5.891663551330566
84
Epoch 0 Batch 84: loss 2.5608203411102295
88
Epoch 0 Batch 88: loss 6.63900

In [14]:
# Persist the trained model. The module object itself is passed (rather than
# model.state_dict()), so torch.save pickles the whole module; loading this
# file later requires the model's class to be importable under the same path.
torch.save(model, f"Models/{MODELNAME}/LodeRunnerLSTM.pt")

In [64]:
# Report the element count of every named parameter, then the overall sum.
paramCounts = {name: param.numel() for name, param in model.named_parameters()}

for name, count in paramCounts.items():
    print(f"{name:<24}: {count:5}")

total = sum(paramCounts.values())
print(f"Total Params: {total}")

histLSTM.weight_ih_l0   :   512
histLSTM.weight_hh_l0   : 65536
histLSTM.bias_ih_l0     :   512
histLSTM.bias_hh_l0     :   512
colLSTM.weight_ih_l0    : 131072
colLSTM.weight_hh_l0    : 65536
colLSTM.bias_ih_l0      :   512
colLSTM.bias_hh_l0      :   512
textLSTM.weight_ih_l0   :   512
textLSTM.weight_hh_l0   : 65536
textLSTM.bias_ih_l0     :   512
textLSTM.bias_hh_l0     :   512
outputLayer.weight      :  1152
outputLayer.bias        :     9
Total Params: 332937
