In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.autograd import Variable
from torch import Tensor, optim, nn
import wandb
from tqdm import tqdm

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmoritz-palm[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
config = {
    "learning_rate": 0.02,
    "architecture": "NN",
    "dataset": "static_1.1",
    "epochs": 100,
    "classes": 2,
    "batch_size": 64,
    "num_layers": 2,
    "hidden_size": 64,
    "dropout_prob": 0.2,
    "input_size": 229,
    "output_size": 2,
    "optimizer": "Adam",
    "loss": "CrossEntropyLoss",
    "activation": "ReLU",
    "initializer": "Xavier",
    "regularization": "L2",
    "regularization_lambda": 0.01,
}

In [4]:
def model_pipeline(hyperparameters):
    with wandb.init(project="leaguify", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, test_loader, criterion, optimizer = make(config)
        print(model)

        # and use them to train the model
        train(model, train_loader, criterion, optimizer, config)

        # and test its final performance
        test(model, test_loader)

    return model

In [5]:
class StaticDataset(Dataset):
    def __init__(self, data_dir, transform=None, target_transform=None):
        self.data = torch.tensor(np.load(data_dir)[:, :-1], dtype=torch.float32, device=device)
        self.labels = torch.tensor(np.load(data_dir)[:, -1], dtype=torch.int64, device=device)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx, 1:]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label

In [6]:
def make(config):
    train, test = get_data(train=True), get_data(train=False)
    train_loader = make_loader(train, batch_size=config.batch_size)
    test_loader = make_loader(test, batch_size=config.batch_size)

    model = NeuralNetwork(config.input_size, config.hidden_size, config.num_layers, config.dropout_prob,
                          config.classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    return model, train_loader, test_loader, criterion, optimizer

In [7]:
def get_data(slice=1, train=True):
    if train:
        full_dataset = StaticDataset('../data/train_static.npy')
    else:
        full_dataset = StaticDataset('../data/test_static.npy')
    sub_dataset = torch.utils.data.Subset(full_dataset, range(0, len(full_dataset), slice))
    return sub_dataset

In [8]:
def make_loader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [9]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)
if torch.cuda.is_available():
    print(f'PyTorch version: {torch.__version__}')
    print('*' * 10)
    print(f'_CUDA version: ')
    !nvcc --version
    print('*' * 10)
    print(f'CUDNN version: {torch.backends.cudnn.version()}')
    print(f'Available GPU devices: {torch.cuda.device_count()}')
    print(f'Device Name: {torch.cuda.get_device_name()}')
print(f"Using {device} device")

PyTorch version: 2.1.0
**********
_CUDA version: 
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:42:34_Pacific_Daylight_Time_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0
**********
CUDNN version: 8801
Available GPU devices: 1
Device Name: NVIDIA GeForce RTX 2080
Using cuda device


In [10]:
class NeuralNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout_prob, classes=2):
        super(NeuralNetwork, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [11]:
def train(model, loader, criterion, optimizer, config):
    wandb.watch(model, criterion, log="all", log_freq=10)

    total_batches = len(loader) * config.epochs
    example_count = 0
    batch_count = 0
    for epoch in tqdm(range(config.epochs)):
        for _, (matches, labels) in enumerate(loader):
            loss = train_batch(matches, labels, model, optimizer, criterion)
            example_count += len(matches)
            batch_count += 1
            if (batch_count + 1) % 25 == 0:
                train_log(loss, example_count, epoch)


In [12]:
def train_batch(matches, labels, model, optimizer, criterion):
    output = model(matches)
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

In [13]:
def train_log(loss, example_count, epoch):
    wandb.log({"epoch": epoch, "loss": loss}, step=example_count)
    print(f"Loss after {str(example_count).zfill(5)} examples: {loss:.3f}")

In [14]:
def test(model, test_loader):
    model.eval()

    # Run the model on some test examples
    with torch.no_grad():
        correct, total = 0, 0
        for matches, labels in test_loader:
            matches, labels = matches.to(device), labels.to(device)
            outputs = model(matches)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {total} " +
              f"test matches: {correct / total:%}")

        wandb.log({"test_accuracy": correct / total})

In [15]:
model = model_pipeline(config)

NeuralNetwork(
  (dropout): Dropout(p=0.2, inplace=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=229, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
    (5): ReLU()
  )
)


  1%|          | 1/100 [00:01<02:45,  1.67s/it]

Loss after 01536 examples: 0.735
Loss after 03136 examples: 0.659


  2%|▏         | 2/100 [00:01<01:19,  1.23it/s]

Loss after 04732 examples: 0.689
Loss after 06332 examples: 0.497
Loss after 07928 examples: 0.484


  3%|▎         | 3/100 [00:02<00:52,  1.85it/s]

Loss after 09528 examples: 0.448
Loss after 11124 examples: 0.393
Loss after 12724 examples: 0.404


  5%|▌         | 5/100 [00:02<00:32,  2.96it/s]

Loss after 14320 examples: 0.404
Loss after 15920 examples: 0.469
Loss after 17516 examples: 0.416


  6%|▌         | 6/100 [00:02<00:27,  3.38it/s]

Loss after 19116 examples: 0.396
Loss after 20712 examples: 0.234


  7%|▋         | 7/100 [00:02<00:26,  3.56it/s]

Loss after 22312 examples: 0.448
Loss after 23908 examples: 0.258


  8%|▊         | 8/100 [00:03<00:23,  3.87it/s]

Loss after 25508 examples: 0.298
Loss after 27104 examples: 0.302
Loss after 28704 examples: 0.317


  9%|▉         | 9/100 [00:03<00:21,  4.14it/s]

Loss after 30304 examples: 0.389
Loss after 31900 examples: 0.233
Loss after 33500 examples: 0.206


 11%|█         | 11/100 [00:03<00:20,  4.44it/s]

Loss after 35096 examples: 0.276
Loss after 36696 examples: 0.284


 12%|█▏        | 12/100 [00:04<00:19,  4.44it/s]

Loss after 38292 examples: 0.315
Loss after 39892 examples: 0.253


 13%|█▎        | 13/100 [00:04<00:19,  4.53it/s]

Loss after 41488 examples: 0.245
Loss after 43088 examples: 0.291
Loss after 44684 examples: 0.202


 14%|█▍        | 14/100 [00:04<00:18,  4.55it/s]

Loss after 46284 examples: 0.233
Loss after 47880 examples: 0.249
Loss after 49480 examples: 0.290


 16%|█▌        | 16/100 [00:04<00:17,  4.72it/s]

Loss after 51076 examples: 0.201
Loss after 52676 examples: 0.231
Loss after 54272 examples: 0.159
Loss after 55872 examples: 0.218
Loss after 57472 examples: 0.231


 18%|█▊        | 18/100 [00:05<00:17,  4.77it/s]

Loss after 59068 examples: 0.197
Loss after 60668 examples: 0.261
Loss after 62264 examples: 0.201


 19%|█▉        | 19/100 [00:05<00:17,  4.75it/s]

Loss after 63864 examples: 0.197
Loss after 65460 examples: 0.186
Loss after 67060 examples: 0.133


 21%|██        | 21/100 [00:05<00:16,  4.75it/s]

Loss after 68656 examples: 0.143
Loss after 70256 examples: 0.322
Loss after 71852 examples: 0.148


 22%|██▏       | 22/100 [00:06<00:16,  4.79it/s]

Loss after 73452 examples: 0.128
Loss after 75048 examples: 0.212
Loss after 76648 examples: 0.251


 24%|██▍       | 24/100 [00:06<00:15,  4.79it/s]

Loss after 78244 examples: 0.172
Loss after 79844 examples: 0.125


 25%|██▌       | 25/100 [00:06<00:15,  4.82it/s]

Loss after 81440 examples: 0.135
Loss after 83040 examples: 0.248
Loss after 84640 examples: 0.340


 26%|██▌       | 26/100 [00:06<00:15,  4.82it/s]

Loss after 86236 examples: 0.158
Loss after 87836 examples: 0.179
Loss after 89432 examples: 0.175


 27%|██▋       | 27/100 [00:07<00:15,  4.75it/s]

Loss after 91032 examples: 0.181
Loss after 92628 examples: 0.080
Loss after 94228 examples: 0.149


 29%|██▉       | 29/100 [00:07<00:14,  4.80it/s]

Loss after 95824 examples: 0.173
Loss after 97424 examples: 0.095
Loss after 99020 examples: 0.147


 30%|███       | 30/100 [00:07<00:14,  4.73it/s]

Loss after 100620 examples: 0.114
Loss after 102216 examples: 0.159
Loss after 103816 examples: 0.075


 31%|███       | 31/100 [00:08<00:14,  4.74it/s]

Loss after 105412 examples: 0.062
Loss after 107012 examples: 0.140


 33%|███▎      | 33/100 [00:08<00:13,  4.80it/s]

Loss after 108608 examples: 0.082
Loss after 110208 examples: 0.096
Loss after 111804 examples: 0.159


 34%|███▍      | 34/100 [00:08<00:13,  4.78it/s]

Loss after 113404 examples: 0.213
Loss after 115004 examples: 0.110
Loss after 116600 examples: 0.098


 35%|███▌      | 35/100 [00:08<00:13,  4.80it/s]

Loss after 118200 examples: 0.160
Loss after 119796 examples: 0.118
Loss after 121396 examples: 0.119


 37%|███▋      | 37/100 [00:09<00:13,  4.79it/s]

Loss after 122992 examples: 0.089
Loss after 124592 examples: 0.132
Loss after 126188 examples: 0.080


 38%|███▊      | 38/100 [00:09<00:12,  4.81it/s]

Loss after 127788 examples: 0.149
Loss after 129384 examples: 0.124
Loss after 130984 examples: 0.369


 39%|███▉      | 39/100 [00:09<00:12,  4.80it/s]

Loss after 132580 examples: 0.071
Loss after 134180 examples: 0.110


 40%|████      | 40/100 [00:09<00:12,  4.68it/s]

Loss after 135776 examples: 0.092
Loss after 137376 examples: 0.091


 41%|████      | 41/100 [00:10<00:12,  4.72it/s]

Loss after 138972 examples: 0.108
Loss after 140572 examples: 0.205
Loss after 142172 examples: 0.271


 43%|████▎     | 43/100 [00:10<00:11,  4.75it/s]

Loss after 143768 examples: 0.103
Loss after 145368 examples: 0.117
Loss after 146964 examples: 0.237


 44%|████▍     | 44/100 [00:10<00:11,  4.69it/s]

Loss after 148564 examples: 0.040
Loss after 150160 examples: 0.057
Loss after 151760 examples: 0.181


 46%|████▌     | 46/100 [00:11<00:11,  4.75it/s]

Loss after 153356 examples: 0.157
Loss after 154956 examples: 0.046
Loss after 156552 examples: 0.033


 47%|████▋     | 47/100 [00:11<00:11,  4.67it/s]

Loss after 158152 examples: 0.089
Loss after 159748 examples: 0.072
Loss after 161348 examples: 0.109


 48%|████▊     | 48/100 [00:11<00:11,  4.71it/s]

Loss after 162944 examples: 0.072
Loss after 164544 examples: 0.162


 50%|█████     | 50/100 [00:12<00:10,  4.71it/s]

Loss after 166140 examples: 0.062
Loss after 167740 examples: 0.090
Loss after 169340 examples: 0.046


 51%|█████     | 51/100 [00:12<00:10,  4.73it/s]

Loss after 170936 examples: 0.079
Loss after 172536 examples: 0.109
Loss after 174132 examples: 0.081


 52%|█████▏    | 52/100 [00:12<00:10,  4.78it/s]

Loss after 175732 examples: 0.050
Loss after 177328 examples: 0.224
Loss after 178928 examples: 0.138


 54%|█████▍    | 54/100 [00:12<00:09,  4.78it/s]

Loss after 180524 examples: 0.206
Loss after 182124 examples: 0.169
Loss after 183720 examples: 0.219


 55%|█████▌    | 55/100 [00:13<00:09,  4.79it/s]

Loss after 185320 examples: 0.283
Loss after 186916 examples: 0.145
Loss after 188516 examples: 0.262


 56%|█████▌    | 56/100 [00:13<00:09,  4.82it/s]

Loss after 190112 examples: 0.117
Loss after 191712 examples: 0.123


 57%|█████▋    | 57/100 [00:13<00:09,  4.75it/s]

Loss after 193308 examples: 0.084
Loss after 194908 examples: 0.188


 58%|█████▊    | 58/100 [00:13<00:08,  4.76it/s]

Loss after 196504 examples: 0.105
Loss after 198104 examples: 0.259
Loss after 199704 examples: 0.088


 60%|██████    | 60/100 [00:14<00:08,  4.75it/s]

Loss after 201300 examples: 0.097
Loss after 202900 examples: 0.079
Loss after 204496 examples: 0.117


 61%|██████    | 61/100 [00:14<00:08,  4.79it/s]

Loss after 206096 examples: 0.030
Loss after 207692 examples: 0.061
Loss after 209292 examples: 0.109


 63%|██████▎   | 63/100 [00:14<00:07,  4.85it/s]

Loss after 210888 examples: 0.067
Loss after 212488 examples: 0.135
Loss after 214084 examples: 0.048


 64%|██████▍   | 64/100 [00:14<00:07,  4.78it/s]

Loss after 215684 examples: 0.078
Loss after 217280 examples: 0.098
Loss after 218880 examples: 0.093


 65%|██████▌   | 65/100 [00:15<00:07,  4.84it/s]

Loss after 220476 examples: 0.401
Loss after 222076 examples: 0.078


 66%|██████▌   | 66/100 [00:15<00:07,  4.86it/s]

Loss after 223672 examples: 0.081
Loss after 225272 examples: 0.043


 67%|██████▋   | 67/100 [00:15<00:06,  4.76it/s]

Loss after 226872 examples: 0.023
Loss after 228468 examples: 0.102
Loss after 230068 examples: 0.102


 69%|██████▉   | 69/100 [00:15<00:06,  4.82it/s]

Loss after 231664 examples: 0.067
Loss after 233264 examples: 0.041
Loss after 234860 examples: 0.096


 70%|███████   | 70/100 [00:16<00:06,  4.75it/s]

Loss after 236460 examples: 0.084
Loss after 238056 examples: 0.081


 71%|███████   | 71/100 [00:16<00:06,  4.71it/s]

Loss after 239656 examples: 0.035
Loss after 241252 examples: 0.070
Loss after 242852 examples: 0.061


 73%|███████▎  | 73/100 [00:16<00:05,  4.82it/s]

Loss after 244448 examples: 0.055
Loss after 246048 examples: 0.426
Loss after 247644 examples: 0.161


 74%|███████▍  | 74/100 [00:17<00:05,  4.72it/s]

Loss after 249244 examples: 0.181
Loss after 250840 examples: 0.114


 75%|███████▌  | 75/100 [00:17<00:05,  4.69it/s]

Loss after 252440 examples: 0.136
Loss after 254040 examples: 0.209
Loss after 255636 examples: 0.209


 76%|███████▌  | 76/100 [00:17<00:05,  4.69it/s]

Loss after 257236 examples: 0.135
Loss after 258832 examples: 0.069
Loss after 260432 examples: 0.164


 78%|███████▊  | 78/100 [00:17<00:04,  4.76it/s]

Loss after 262028 examples: 0.102
Loss after 263628 examples: 0.121
Loss after 265224 examples: 0.078


 79%|███████▉  | 79/100 [00:18<00:04,  4.81it/s]

Loss after 266824 examples: 0.063
Loss after 268420 examples: 0.083
Loss after 270020 examples: 0.056


 81%|████████  | 81/100 [00:18<00:03,  4.76it/s]

Loss after 271616 examples: 0.107
Loss after 273216 examples: 0.091


 82%|████████▏ | 82/100 [00:18<00:03,  4.78it/s]

Loss after 274812 examples: 0.046
Loss after 276412 examples: 0.132
Loss after 278008 examples: 0.060
Loss after 279608 examples: 0.200


 83%|████████▎ | 83/100 [00:18<00:03,  4.82it/s]

Loss after 281204 examples: 0.297
Loss after 282804 examples: 0.205


 84%|████████▍ | 84/100 [00:19<00:03,  4.72it/s]

Loss after 284404 examples: 0.127
Loss after 286000 examples: 0.078
Loss after 287600 examples: 0.092


 86%|████████▌ | 86/100 [00:19<00:02,  4.74it/s]

Loss after 289196 examples: 0.094
Loss after 290796 examples: 0.061
Loss after 292392 examples: 0.042


 87%|████████▋ | 87/100 [00:19<00:02,  4.66it/s]

Loss after 293992 examples: 0.111
Loss after 295588 examples: 0.078
Loss after 297188 examples: 0.106


 89%|████████▉ | 89/100 [00:20<00:02,  4.62it/s]

Loss after 298784 examples: 0.087
Loss after 300384 examples: 0.088
Loss after 301980 examples: 0.092


 90%|█████████ | 90/100 [00:20<00:02,  4.56it/s]

Loss after 303580 examples: 0.047
Loss after 305176 examples: 0.061
Loss after 306776 examples: 0.019


 92%|█████████▏| 92/100 [00:20<00:01,  4.63it/s]

Loss after 308372 examples: 0.064
Loss after 309972 examples: 0.032
Loss after 311572 examples: 0.067


 93%|█████████▎| 93/100 [00:21<00:01,  4.49it/s]

Loss after 313168 examples: 0.019
Loss after 314768 examples: 0.073


 94%|█████████▍| 94/100 [00:21<00:01,  4.54it/s]

Loss after 316364 examples: 0.059
Loss after 317964 examples: 0.023


 95%|█████████▌| 95/100 [00:21<00:01,  4.56it/s]

Loss after 319560 examples: 0.098
Loss after 321160 examples: 0.272


 96%|█████████▌| 96/100 [00:21<00:00,  4.58it/s]

Loss after 322756 examples: 0.054
Loss after 324356 examples: 0.087
Loss after 325952 examples: 0.091


 97%|█████████▋| 97/100 [00:21<00:00,  4.58it/s]

Loss after 327552 examples: 0.160
Loss after 329148 examples: 0.112
Loss after 330748 examples: 0.030


 98%|█████████▊| 98/100 [00:22<00:00,  4.67it/s]

Loss after 332344 examples: 0.078
Loss after 333944 examples: 0.099


 99%|█████████▉| 99/100 [00:22<00:00,  4.70it/s]

Loss after 335540 examples: 0.047
Loss after 337140 examples: 0.101


100%|██████████| 100/100 [00:22<00:00,  4.42it/s]

Loss after 338740 examples: 0.039
Accuracy of the model on the 847 test matches: 68.240850%





0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
loss,█▆▅▄▄▃▂▄▄▂▂▃▁▂▂▅▃▃▁▁▂▃▂▄▁▂▁▁▁▂▂▂▁▂▂▁▁▂▂▁
test_accuracy,▁

0,1
epoch,99.0
loss,0.03862
test_accuracy,0.68241
