In [8]:
from argparse import ArgumentParser
from fastprogress.fastprogress import master_bar, progress_bar
import torch
import pandas as pd
import os

In [None]:
def train_one_epoch(dataloader, model, criterion, optimizer, device, mb):

    # Put the model into training mode
    model.train()

    # Loop over the data using the progress_bar utility
    for _, (X, Y) in progress_bar(DataLoaderProgress(dataloader), parent=mb):
        X, Y = X.to(device), Y.to(device)

        # Compute model output and then loss
        output = model(X)
        loss = criterion(output, Y)

        # - zero-out gradients
        optimizer.zero_grad()
        # - compute new gradients
        loss.backward()
        # - update paramters
        optimizer.step()

In [None]:
def validate(dataloader, model, criterion, device, epoch, num_epochs, mb):

    # Put the model into validation/evaluation mode
    model.eval()

    N = len(dataloader.dataset)
    num_batches = len(dataloader)

    loss, num_correct = 0, 0

    # Tell pytorch to stop updating gradients when executing the following
    with torch.no_grad():

        for X, Y in dataloader:
            X, Y = X.to(device), Y.to(device)

            # Compute the model output
            output = model(X)

            # - compute loss
            loss += criterion(output, Y).item()
            # - compute the number of correctly classified examples
            num_correct += (output.argmax(1) == Y).type(torch.float).sum().item()

        loss /= num_batches
        accuracy = num_correct / N

    message = "Initial" if epoch == 0 else f"Epoch {epoch:>2}/{num_epochs}:"
    message += f" accuracy={100*accuracy:5.2f}%"
    message += f" and loss={loss:.3f}"
    mb.write(message)

In [None]:
def train(model, criterion, optimizer, train_loader, valid_loader, device, num_epochs):

    mb = master_bar(range(num_epochs))

    validate(valid_loader, model, criterion, device, 0, num_epochs, mb)

    for epoch in mb:
        train_one_epoch(train_loader, model, criterion, optimizer, device, mb)
        validate(valid_loader, model, criterion, device, epoch + 1, num_epochs, mb)

In [None]:
def main():

    aparser = ArgumentParser("FIFAI--Train a neural network to predict EPL scorelines.")
    aparser.add_argument("epl_data", type=str, help="Path to store/find the EPL games dataset")
    aparser.add_argument("--num_epochs", type=int, default=10)
    aparser.add_argument("--batch_size", type=int, default=128)
    aparser.add_argument("--learning_rate", type=float, default=0.01)
    aparser.add_argument("--momentum", type=float, default=0.9)
    aparser.add_argument("--gpu", action="store_true")

    args = aparser.parse_args()

    # Use GPU if requested and available
    device = "cuda" if args.gpu and torch.cuda.is_available() else "cpu"
    

    # Get data loaders
    # train_loader, valid_loader = get_epl_data_loaders(args.epl_data, args.batch_size, 0)
    
    
    # model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(in_features=?, out_features=10),)

    # TODO:
    # - create a CrossEntropyLoss criterion
    # - create an optimizer of your choice
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum)

    train(model, criterion, optimizer, train_loader, valid_loader, device, args.num_epochs)

In [10]:
def load_data():
    file_out = pd.read_csv("./Liverpool19_20_data.csv")
    print(file_out)
    for file in os.listdir("../teams_data_by_season"):
#         print(file)


In [11]:
load_data()

Man City19_20_data.csv
Arsenal18_19_data.csv
Man United20_21_data.csv
West Ham20_21_data.csv
Crystal Palace17_18_data.csv
Tottenham18_19_data.csv
Arsenal17_18_data.csv
West Ham18_19_data.csv
Crystal Palace16_17_data.csv
Arsenal19_20_data.csv
Man City17_18_data.csv
Leicester18_19_data.csv
Everton20_21_data.csv
Southampton16_17_data.csv
West Ham16_17_data.csv
Man City18_19_data.csv
Tottenham17_18_data.csv
Southampton20_21_data.csv
Everton16_17_data.csv
Everton17_18_data.csv
Chelsea20_21_data.csv
Liverpool17_18_data.csv
Man United16_17_data.csv
Liverpool19_20_data.csv
Chelsea17_18_data.csv
Man United17_18_data.csv
Tottenham19_20_data.csv
Leicester17_18_data.csv
Burnley19_20_data.csv
Arsenal20_21_data.csv
Burnley16_17_data.csv
Leicester19_20_data.csv
Liverpool16_17_data.csv
West Ham19_20_data.csv
Southampton17_18_data.csv
Man United19_20_data.csv
Southampton18_19_data.csv
Southampton19_20_data.csv
Chelsea18_19_data.csv
Man City16_17_data.csv
Crystal Palace19_20_data.csv
Burnley20_21_data.c