In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.autograd import Variable
from torch import Tensor, optim, nn
import wandb
from tqdm import tqdm
import pprint

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmoritz-palm[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
sweep_config = {
    'method': 'grid',
    'metric': {
        'name': 'val_accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'values': [0.01, 0.02, 0.05, 0.1, 0.2]
        },
        'epochs': {
            'values': [50, 100, 200]
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'num_layers': {
            'values': [1, 2, 3, 4]
        },
        'hidden_size': {
            'values': [32, 64, 128]
        },
        'dropout_prob': {
            'values': [0.0, 0.1, 0.2, 0.3]
        },
        'regularization_lambda': {
            'values': [0.0, 0.01, 0.1, 1.0]
        },
        'optimizer': {
            'values': ['adam', 'sgd']
        },
        'loss': {
            'values': ['CrossEntropyLoss']
        },
        'activation': {
            'values': ['ReLU']
        },
        'input_size': {
            'value': 229
        }
    }
}

In [3]:
pprint.pprint(sweep_config)

{'method': 'grid',
 'metric': {'goal': 'maximize', 'name': 'val_accuracy'},
 'parameters': {'activation': {'values': ['ReLU']},
                'batch_size': {'values': [32, 64, 128]},
                'dropout_prob': {'values': [0.0, 0.1, 0.2, 0.3]},
                'epochs': {'values': [50, 100, 200]},
                'hidden_size': {'values': [32, 64, 128]},
                'input_size': {'value': 229},
                'learning_rate': {'values': [0.01, 0.02, 0.05, 0.1, 0.2]},
                'loss': {'values': ['CrossEntropyLoss']},
                'num_layers': {'values': [1, 2, 3, 4]},
                'optimizer': {'values': ['adam', 'sgd']},
                'regularization_lambda': {'values': [0.0, 0.01, 0.1, 1.0]}}}


In [4]:
sweep_id = wandb.sweep(sweep_config, project='leaguify')

Create sweep with ID: 8g5ov2yg
Sweep URL: https://wandb.ai/moritz-palm/leaguify/sweeps/8g5ov2yg


In [5]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)
if torch.cuda.is_available():
    print(f'PyTorch version: {torch.__version__}')
    print('*' * 10)
    print(f'_CUDA version: ')
    !nvcc --version
    print('*' * 10)
    print(f'CUDNN version: {torch.backends.cudnn.version()}')
    print(f'Available GPU devices: {torch.cuda.device_count()}')
    print(f'Device Name: {torch.cuda.get_device_name()}')
print(f"Using {device} device")

PyTorch version: 2.1.0+cu121
**********
_CUDA version: 
nvcc: NVIDIA (R) Cuda compiler driver**********
CUDNN version: 8801

Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:09:35_Pacific_Daylight_Time_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Available GPU devices: 1
Device Name: NVIDIA GeForce RTX 2080
Using cuda device


In [6]:
class StaticDataset(Dataset):
    def __init__(self, data_dir, transform=None, target_transform=None):
        self.data = torch.tensor(np.load(data_dir)[:, :-1], dtype=torch.float32, device=device)
        self.labels = torch.tensor(np.load(data_dir)[:, -1], dtype=torch.int64, device=device)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx, 1:]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label

In [7]:
def get_train_data(slice=1):
    full_dataset = StaticDataset('../data/processed/train_static.npy')
    sub_dataset = torch.utils.data.Subset(full_dataset, range(0, len(full_dataset), slice))
    train_data = torch.utils.data.Subset(sub_dataset, range(0, int(len(sub_dataset) * 0.8)))
    val_data = torch.utils.data.Subset(sub_dataset, range(int(len(sub_dataset) * 0.8), len(sub_dataset)))
    return train_data, val_data

In [8]:
def get_test_data():
    return StaticDataset('../data/processed/test_static.npy')

In [9]:
def make_loader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [10]:
class NeuralNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout_prob, num_classes=2):
        super(NeuralNetwork, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential()
        for i in range(num_layers):
            if i == 0:
                self.linear_relu_stack.append(nn.Linear(input_size, hidden_size))
            else:
                self.linear_relu_stack.append(nn.Linear(hidden_size, hidden_size))
            self.linear_relu_stack.append(nn.ReLU())
            self.linear_relu_stack.append(self.dropout)
        self.linear_relu_stack.append(nn.Linear(hidden_size, num_classes))
        self.linear_relu_stack.append(nn.Sigmoid())

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [11]:
def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)
    else:
        raise ValueError("Optimizer not supported")
    return optimizer

In [12]:
def train_batch(matches, labels, model, optimizer, criterion):
    output = model(matches)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

In [13]:
def train_log(loss, example_count, epoch):
    wandb.log({"epoch": epoch, "loss": loss}, step=example_count)
    print(f"Loss after {str(example_count).zfill(5)} examples: {loss:.3f}")

In [14]:
def train(config=None):
    with wandb.init(config=config):

        config = wandb.config  #
        train_data, val_data = get_train_data(slice=1)
        train_loader = make_loader(train_data, batch_size=config.batch_size)
        val_loader = make_loader(val_data, batch_size=config.batch_size)
        model = NeuralNetwork(config.input_size, config.hidden_size, config.num_layers, config.dropout_prob).to(device)
        optimizer = build_optimizer(model, config.optimizer, config.learning_rate)
        print(f'optimizer: {optimizer}')
        criterion = nn.CrossEntropyLoss()
        wandb.watch(model, criterion, log="all", log_freq=10)

        total_batches = len(train_loader) * config.epochs
        example_count = 0
        batch_count = 0
        for epoch in tqdm(range(config.epochs)):
            for _, (matches, labels) in enumerate(train_loader):
                loss = train_batch(matches, labels, model, optimizer, criterion)
                example_count += len(matches)
                batch_count += 1
                if (batch_count + 1) % 25 == 0:
                    train_log(loss, example_count, epoch)
        test(model, val_loader)

In [15]:
def test(model, test_loader):
    model.eval()

    # Run the model on some test examples
    with torch.no_grad():
        correct, total = 0, 0
        for matches, labels in test_loader:
            matches, labels = matches.to(device), labels.to(device)
            outputs = model(matches)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {total} " +
              f"test matches: {correct / total:%}")

        wandb.log({"val_accuracy": correct / total})

In [None]:
wandb.agent(sweep_id, train, count=50)

[34m[1mwandb[0m: Agent Starting Run: 0t00s7ad with config:
[34m[1mwandb[0m: 	activation: ReLU
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	dropout_prob: 0
[34m[1mwandb[0m: 	epochs: 50
[34m[1mwandb[0m: 	hidden_size: 32
[34m[1mwandb[0m: 	input_size: 229
[34m[1mwandb[0m: 	learning_rate: 0.01
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	num_layers: 1
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: 	regularization_lambda: 0


optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)


  2%|▏         | 1/50 [00:02<01:48,  2.21s/it]

Loss after 00768 examples: 0.550
Loss after 01568 examples: 0.655
Loss after 02368 examples: 0.541
Loss after 03158 examples: 0.474


  4%|▍         | 2/50 [00:02<00:49,  1.03s/it]

Loss after 03958 examples: 0.521
Loss after 04758 examples: 0.525
Loss after 05548 examples: 0.608
Loss after 06348 examples: 0.448


  6%|▌         | 3/50 [00:02<00:30,  1.53it/s]

Loss after 07148 examples: 0.443
Loss after 07948 examples: 0.477
Loss after 08738 examples: 0.416
Loss after 09538 examples: 0.490


  8%|▊         | 4/50 [00:02<00:22,  2.09it/s]

Loss after 10338 examples: 0.476
Loss after 11128 examples: 0.382
Loss after 11928 examples: 0.450
Loss after 12728 examples: 0.512


 10%|█         | 5/50 [00:03<00:17,  2.64it/s]

Loss after 13528 examples: 0.446
Loss after 14318 examples: 0.519
Loss after 15118 examples: 0.575
Loss after 15918 examples: 0.427


 14%|█▍        | 7/50 [00:03<00:12,  3.53it/s]

Loss after 16708 examples: 0.386
Loss after 17508 examples: 0.499
Loss after 18308 examples: 0.405
Loss after 19098 examples: 0.408


 16%|█▌        | 8/50 [00:03<00:10,  3.86it/s]

Loss after 19898 examples: 0.389
Loss after 20698 examples: 0.449
Loss after 21498 examples: 0.365
Loss after 22288 examples: 0.447


 18%|█▊        | 9/50 [00:03<00:09,  4.12it/s]

Loss after 23088 examples: 0.447
Loss after 23888 examples: 0.467
Loss after 24678 examples: 0.408
Loss after 25478 examples: 0.380


 20%|██        | 10/50 [00:04<00:09,  4.18it/s]

Loss after 26278 examples: 0.431
Loss after 27078 examples: 0.345
Loss after 27868 examples: 0.351
Loss after 28668 examples: 0.410


 22%|██▏       | 11/50 [00:04<00:09,  4.32it/s]

Loss after 29468 examples: 0.353
Loss after 30258 examples: 0.460
Loss after 31058 examples: 0.377
Loss after 31858 examples: 0.406


 26%|██▌       | 13/50 [00:04<00:08,  4.59it/s]

Loss after 32648 examples: 0.506
Loss after 33448 examples: 0.379
Loss after 34248 examples: 0.376
Loss after 35048 examples: 0.389


 28%|██▊       | 14/50 [00:04<00:07,  4.68it/s]

Loss after 35838 examples: 0.470
Loss after 36638 examples: 0.492
Loss after 37438 examples: 0.355
Loss after 38228 examples: 0.469


 30%|███       | 15/50 [00:05<00:07,  4.74it/s]

Loss after 39028 examples: 0.414
Loss after 39828 examples: 0.556
Loss after 40628 examples: 0.408
Loss after 41418 examples: 0.380


 32%|███▏      | 16/50 [00:05<00:07,  4.74it/s]

Loss after 42218 examples: 0.443
Loss after 43018 examples: 0.364
Loss after 43808 examples: 0.408
Loss after 44608 examples: 0.356


 34%|███▍      | 17/50 [00:05<00:07,  4.63it/s]

Loss after 45408 examples: 0.349
Loss after 46198 examples: 0.346
Loss after 46998 examples: 0.470
Loss after 47798 examples: 0.411


 36%|███▌      | 18/50 [00:05<00:06,  4.66it/s]

Loss after 48598 examples: 0.345
Loss after 49388 examples: 0.377
Loss after 50188 examples: 0.380
Loss after 50988 examples: 0.407


 40%|████      | 20/50 [00:06<00:06,  4.78it/s]

Loss after 51778 examples: 0.378
Loss after 52578 examples: 0.381
Loss after 53378 examples: 0.400
Loss after 54178 examples: 0.445


 42%|████▏     | 21/50 [00:06<00:06,  4.83it/s]

Loss after 54968 examples: 0.545
Loss after 55768 examples: 0.325
Loss after 56568 examples: 0.352
Loss after 57358 examples: 0.470


 44%|████▍     | 22/50 [00:06<00:05,  4.80it/s]

Loss after 58158 examples: 0.376
Loss after 58958 examples: 0.345
Loss after 59748 examples: 0.346
Loss after 60548 examples: 0.377


 46%|████▌     | 23/50 [00:06<00:05,  4.62it/s]

Loss after 61348 examples: 0.313
Loss after 62148 examples: 0.401
Loss after 62938 examples: 0.313
Loss after 63738 examples: 0.323


 48%|████▊     | 24/50 [00:07<00:05,  4.57it/s]

Loss after 64538 examples: 0.408
Loss after 65328 examples: 0.346
Loss after 66128 examples: 0.345
Loss after 66928 examples: 0.440


 50%|█████     | 25/50 [00:07<00:05,  4.67it/s]

Loss after 67728 examples: 0.348
Loss after 68518 examples: 0.470
Loss after 69318 examples: 0.409
Loss after 70118 examples: 0.384


 52%|█████▏    | 26/50 [00:07<00:05,  4.55it/s]

Loss after 70908 examples: 0.376
Loss after 71708 examples: 0.345
Loss after 72508 examples: 0.345


 54%|█████▍    | 27/50 [00:07<00:05,  4.50it/s]

Loss after 73298 examples: 0.354
Loss after 74098 examples: 0.345
Loss after 74898 examples: 0.407
Loss after 75698 examples: 0.376


 58%|█████▊    | 29/50 [00:08<00:04,  4.56it/s]

Loss after 76488 examples: 0.351
Loss after 77288 examples: 0.429
Loss after 78088 examples: 0.314
Loss after 78878 examples: 0.376


 60%|██████    | 30/50 [00:08<00:04,  4.51it/s]

Loss after 79678 examples: 0.346
Loss after 80478 examples: 0.396
Loss after 81278 examples: 0.346
Loss after 82068 examples: 0.313


 62%|██████▏   | 31/50 [00:08<00:04,  4.51it/s]

Loss after 82868 examples: 0.439
Loss after 83668 examples: 0.432
Loss after 84458 examples: 0.408
Loss after 85258 examples: 0.345


 64%|██████▍   | 32/50 [00:08<00:03,  4.57it/s]

Loss after 86058 examples: 0.357
Loss after 86848 examples: 0.313
Loss after 87648 examples: 0.379
Loss after 88448 examples: 0.470


 66%|██████▌   | 33/50 [00:09<00:03,  4.62it/s]

Loss after 89248 examples: 0.377
Loss after 90038 examples: 0.353
Loss after 90838 examples: 0.345


 68%|██████▊   | 34/50 [00:09<00:03,  4.42it/s]

Loss after 91638 examples: 0.361
Loss after 92428 examples: 0.472
Loss after 93228 examples: 0.376
Loss after 94028 examples: 0.438


 70%|███████   | 35/50 [00:09<00:03,  4.49it/s]

Loss after 94828 examples: 0.376
Loss after 95618 examples: 0.438
Loss after 96418 examples: 0.313
Loss after 97218 examples: 0.376


 72%|███████▏  | 36/50 [00:09<00:03,  4.41it/s]

Loss after 98008 examples: 0.345
Loss after 98808 examples: 0.346
Loss after 99608 examples: 0.407


 74%|███████▍  | 37/50 [00:09<00:02,  4.40it/s]

Loss after 100398 examples: 0.533
Loss after 101198 examples: 0.313
Loss after 101998 examples: 0.414
Loss after 102798 examples: 0.420


 76%|███████▌  | 38/50 [00:10<00:02,  4.46it/s]

Loss after 103588 examples: 0.377
Loss after 104388 examples: 0.376
Loss after 105188 examples: 0.376


 78%|███████▊  | 39/50 [00:10<00:02,  4.40it/s]

Loss after 105978 examples: 0.345
Loss after 106778 examples: 0.442
Loss after 107578 examples: 0.447


 80%|████████  | 40/50 [00:10<00:02,  4.39it/s]

Loss after 108378 examples: 0.376
Loss after 109168 examples: 0.470
Loss after 109968 examples: 0.347
Loss after 110768 examples: 0.376


 84%|████████▍ | 42/50 [00:11<00:01,  4.60it/s]

Loss after 111558 examples: 0.345
Loss after 112358 examples: 0.407
Loss after 113158 examples: 0.409


 86%|████████▌ | 43/50 [00:11<00:01,  4.69it/s]

Loss after 113948 examples: 0.376
Loss after 114748 examples: 0.362
Loss after 115548 examples: 0.450
Loss after 116348 examples: 0.402


 88%|████████▊ | 44/50 [00:11<00:01,  4.69it/s]

Loss after 117138 examples: 0.314
Loss after 117938 examples: 0.407
Loss after 118738 examples: 0.346
Loss after 119528 examples: 0.408


 90%|█████████ | 45/50 [00:11<00:01,  4.55it/s]

Loss after 120328 examples: 0.345
Loss after 121128 examples: 0.360
Loss after 121928 examples: 0.385
Loss after 122718 examples: 0.359


 92%|█████████▏| 46/50 [00:11<00:00,  4.40it/s]

Loss after 123518 examples: 0.345
Loss after 124318 examples: 0.470
Loss after 125108 examples: 0.438
Loss after 125908 examples: 0.407


 94%|█████████▍| 47/50 [00:12<00:00,  4.46it/s]

Loss after 126708 examples: 0.345
Loss after 127498 examples: 0.376
Loss after 128298 examples: 0.323
Loss after 129098 examples: 0.377


 96%|█████████▌| 48/50 [00:12<00:00,  4.37it/s]

Loss after 129898 examples: 0.313
Loss after 130688 examples: 0.380
Loss after 131488 examples: 0.376


 98%|█████████▊| 49/50 [00:12<00:00,  4.24it/s]

Loss after 132288 examples: 0.370
Loss after 133078 examples: 0.407
Loss after 133878 examples: 0.438


100%|██████████| 50/50 [00:12<00:00,  3.88it/s]

Loss after 134678 examples: 0.321
Loss after 135478 examples: 0.407
Accuracy of the model on the 678 test matches: 70.353982%



