In [1]:
from dataset.FootballDataset import FootballDataset
from dataset.FootballDisplay import FootballDisplay
from dataset.Transforms import *
from nn.BasicNN import BasicNN

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils

In [3]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The path to save and load model
MODEL_SAVE_PATH = '/projects/research/football/pytorch_nn/models/basicNN_model.ckpt'

# Hyper-parameters
batch_size = 4
num_epochs = 2
learning_rate = 0.01
input_size = 44
hidden_size = 500
output_size = 2

print(device)

cuda


In [2]:
train_dataset = FootballDataset(train=True, 
                                transform=transforms.Compose([ToTensor(),
                                                            YTo2D()]))
test_dataset = FootballDataset(train=False,
                               transform=transforms.Compose([ToTensor(),
                                                            YTo2D()]))

In [4]:
# Data loader
train_loader = DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

In [5]:
model = BasicNN(input_size, hidden_size, output_size).to(device)
model.load_state_dict(torch.load(MODEL_SAVE_PATH))

<All keys matched successfully>

In [6]:
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, sample_batched in enumerate(train_loader):
        Xs, Ys = sample_batched['X'], sample_batched['Y']
        # Move tensors to the configured device
        for j in range(train_dataset.Tx):
            X = Xs[:, j, :].to(device)
            Y = Ys[:, j, :].to(device)

            # Forward pass
            outputs = model(X)
            loss = criterion(outputs, Y)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, total_step, loss.item()))


Epoch [1/2], Step [100/2588], Loss: 0.0205
Epoch [1/2], Step [200/2588], Loss: 0.0249
Epoch [1/2], Step [300/2588], Loss: 0.0170
Epoch [1/2], Step [400/2588], Loss: 0.0800
Epoch [1/2], Step [500/2588], Loss: 0.0218
Epoch [1/2], Step [600/2588], Loss: 0.0329
Epoch [1/2], Step [700/2588], Loss: 0.0079
Epoch [1/2], Step [800/2588], Loss: 0.0152
Epoch [1/2], Step [900/2588], Loss: 0.0167
Epoch [1/2], Step [1000/2588], Loss: 0.0660
Epoch [1/2], Step [1100/2588], Loss: 0.0098
Epoch [1/2], Step [1200/2588], Loss: 0.0115
Epoch [1/2], Step [1300/2588], Loss: 0.0070
Epoch [1/2], Step [1400/2588], Loss: 0.0106
Epoch [1/2], Step [1500/2588], Loss: 0.0183
Epoch [1/2], Step [1600/2588], Loss: 0.0759
Epoch [1/2], Step [1700/2588], Loss: 0.0106
Epoch [1/2], Step [1800/2588], Loss: 0.0447
Epoch [1/2], Step [1900/2588], Loss: 0.0297
Epoch [1/2], Step [2000/2588], Loss: 0.0153
Epoch [1/2], Step [2100/2588], Loss: 0.0201
Epoch [1/2], Step [2200/2588], Loss: 0.0158
Epoch [1/2], Step [2300/2588], Loss: 0.01

In [0]:
# Test the model
with torch.no_grad():
    total_loss = 0 
    total = 0
    for i, sample_batched in enumerate(test_loader):
        Xs, Ys = sample_batched['X'], sample_batched['Y']
        # Move tensors to the configured device
        for j in range(train_dataset.Tx):
            X = Xs[:, j, :].to(device)
            Y = Ys[:, j, :].to(device)

            # Forward pass
            outputs = model(X)
            loss = criterion(outputs, Y)
            total_loss += loss.item()
            total += 1

            if i < 1:
                print('loss[i][j]: %f' %loss.item())
                print(outputs)
                print(Y)

    print('Test loss: {} %'.format(
        loss / total))

In [0]:
# Save the model checkpoint
torch.save(model.state_dict(), MODEL_SAVE_PATH)