In [72]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.autograd import Variable
from torch import Tensor, optim, nn
import wandb
from tqdm import tqdm

wandb.login()

True

In [73]:
config = {
    "learning_rate": 0.02,
    "architecture": "GRU",
    "dataset": "timeline_1.0",
    "epochs": 200,
    "classes": 2,
    "batch_size": 32,
    "num_layers": 2,
    "hidden_size": 64,
    "dropout_prob": 0,
    "input_size": 381,
    "output_size": 2,
    "optimizer": "Adam",
    "loss": "CrossEntropyLoss",
    "activation": "ReLU",
    "initializer": "Xavier",
    "regularization": "L2",
    "regularization_lambda": 0.01,
    "gru_layers": 1,
    "sequence_length": 16,

}

In [74]:
def model_pipeline(hyperparameters):
    with wandb.init(project="leaguify", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, val_loader, criterion, optimizer = make(config)
        print(model)

        # and use them to train the model
        train(model, train_loader, criterion, optimizer, config)

        # and test its final performance
        test(model, val_loader)

    return model

In [75]:
class TimelineDataset(Dataset):
    def __init__(self, data_dir, sequence_length, transform=None, target_transform=None):
        self.data = torch.tensor(np.load(data_dir)[:, :-1], dtype=torch.float32, device=device)
        self.labels = torch.tensor(np.load(data_dir)[:, -1], dtype=torch.int64, device=device)
        self.sequence_length = sequence_length
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data) - self.sequence_length

    def __getitem__(self, idx):
        sample = self.data[idx:idx + self.sequence_length, :]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label

In [76]:
def make(config):
    train, val = get_train_data(sequence_length=config.sequence_length)
    train_loader = make_loader(train, batch_size=config.batch_size)
    val_loader = make_loader(val, batch_size=config.batch_size)

    model = GRU(config.input_size, config.hidden_size, config.classes, config.num_layers, config
                .gru_layers, drop_prob=config.dropout_prob).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    return model, train_loader, val_loader, criterion, optimizer

In [77]:
def get_train_data(sequence_length=16, val_split=0.2):
    dataset = TimelineDataset('../data/processed/train_timeline.npy', sequence_length)
    train_len = int(len(dataset) * val_split)
    val_len = len(dataset) - train_len
    return torch.utils.data.random_split(dataset, [train_len, val_len])

In [78]:
def get_test_data(sequence_length=16):
    full_dataset = TimelineDataset('../data/processed/test_timeline.npy', sequence_length)
    return full_dataset

In [79]:
def make_loader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=False)

In [80]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)
if torch.cuda.is_available():
    print(f'PyTorch version: {torch.__version__}')
    print('*' * 10)
    print(f'_CUDA version: ')
    !nvcc --version
    print('*' * 10)
    print(f'CUDNN version: {torch.backends.cudnn.version()}')
    print(f'Available GPU devices: {torch.cuda.device_count()}')
    print(f'Device Name: {torch.cuda.get_device_name()}')
print(f"Using {device} device")

PyTorch version: 2.1.0
**********
_CUDA version: 
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:09:35_Pacific_Daylight_Time_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
**********
CUDNN version: 8801
Available GPU devices: 1
Device Name: NVIDIA GeForce RTX 2080
Using cuda device


In [81]:
class GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, gru_layers, drop_prob=0.2):
        super(GRU, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.gru = nn.GRU(input_dim, hidden_dim, gru_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x, h=None):
        out, h = self.gru(x, h)
        out = self.fc(self.relu(out[:, -1]))
        return out, h

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = weight.new(1, batch_size, self.hidden_dim).zero_().to(device)
        return hidden

In [82]:
def train(model, loader, criterion, optimizer, config):
    wandb.watch(model, criterion, log='all', log_freq=10)

    total_batches = len(loader) * config.epochs
    example_count = 0
    batch_count = 0
    for epoch in tqdm(range(config.epochs)):
        h = model.init_hidden(config.batch_size)
        for _, (matches, labels) in enumerate(loader):
            output, h = model(matches)  # hidden state is not passed to re-init at each batch
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            example_count += len(matches)
            batch_count += 1
            if (batch_count + 1) % 25 == 0:
                train_log(loss, example_count, epoch)

In [83]:
def train_log(loss, example_count, epoch):
    wandb.log({"epoch": epoch, "loss": loss}, step=example_count)
    print(f"Loss after {str(example_count).zfill(5)} examples: {loss:.3f}")

In [84]:
train_data, val_data = get_train_data(sequence_length=16)
print(f'train_data: {len(train_data)}')
for matches, labels in make_loader(val_data, batch_size=1):
    if matches.shape[1] != 16:
        print(f'matches: {matches.shape}, labels: {labels.shape}')

train_data: 2556


In [85]:
def test(model, test_loader):
    # Run the model on some test examples
    with torch.no_grad():
        correct, total = 0, 0
        for matches, labels in test_loader:
            print(f'matches: {matches.shape}, labels: {labels.shape}')
            matches, labels = matches.to(device), labels.to(device)
            model.eval()
            output, h = model(matches)
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {total} " +
              f"test matches: {correct / total:%}")

        wandb.log({"test_accuracy": correct / total})

In [None]:
model = model_pipeline(config)

GRU(
  (gru): GRU(381, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=2, bias=True)
  (relu): ReLU()
)


  0%|          | 0/200 [00:00<?, ?it/s]

Loss after 00768 examples: 0.711
Loss after 01568 examples: 0.652
Loss after 02368 examples: 0.745


  0%|          | 1/200 [00:00<00:59,  3.37it/s]

Loss after 03164 examples: 0.618
Loss after 03964 examples: 0.653
Loss after 04764 examples: 0.663

  1%|          | 2/200 [00:00<01:01,  3.21it/s]


Loss after 05560 examples: 0.489
Loss after 06360 examples: 0.494
Loss after 07160 examples: 0.641


  2%|▏         | 3/200 [00:00<01:06,  2.98it/s]

Loss after 07956 examples: 0.614
Loss after 08756 examples: 0.552
Loss after 09556 examples: 0.584


  2%|▏         | 4/200 [00:01<01:03,  3.10it/s]

Loss after 10352 examples: 0.460
Loss after 11152 examples: 0.792
Loss after 11952 examples: 0.559


  2%|▎         | 5/200 [00:01<01:00,  3.21it/s]

Loss after 12752 examples: 0.543
Loss after 13548 examples: 0.318
Loss after 14348 examples: 0.676


  3%|▎         | 6/200 [00:01<01:02,  3.11it/s]

Loss after 15148 examples: 0.576
Loss after 15944 examples: 0.361


  4%|▎         | 7/200 [00:02<01:06,  2.90it/s]

Loss after 16744 examples: 0.580
Loss after 17544 examples: 0.444
Loss after 18340 examples: 0.446
Loss after 19140 examples: 0.565


  4%|▍         | 8/200 [00:02<01:08,  2.79it/s]

Loss after 19940 examples: 0.370
Loss after 20736 examples: 0.418
Loss after 21536 examples: 0.318
Loss after 22336 examples: 0.372


  4%|▍         | 9/200 [00:03<01:07,  2.83it/s]

Loss after 23132 examples: 0.397
Loss after 23932 examples: 0.558


  5%|▌         | 10/200 [00:03<01:05,  2.89it/s]

Loss after 24732 examples: 0.247
Loss after 25532 examples: 0.466
Loss after 26328 examples: 0.325
Loss after 27128 examples: 0.484
Loss after 27928 examples: 0.485


  6%|▌         | 11/200 [00:03<01:03,  2.97it/s]

Loss after 28724 examples: 0.380
Loss after 29524 examples: 0.457
Loss after 30324 examples: 0.305


  6%|▌         | 12/200 [00:03<01:01,  3.07it/s]

Loss after 31120 examples: 0.328
Loss after 31920 examples: 0.242
Loss after 32720 examples: 0.252


  6%|▋         | 13/200 [00:04<00:59,  3.13it/s]

Loss after 33516 examples: 0.381
Loss after 34316 examples: 0.391


  7%|▋         | 14/200 [00:04<01:00,  3.07it/s]

Loss after 35116 examples: 0.429
Loss after 35912 examples: 0.207
Loss after 36712 examples: 0.224
Loss after 37512 examples: 0.152
Loss after 38312 examples: 0.315


  8%|▊         | 15/200 [00:04<01:00,  3.06it/s]

Loss after 39108 examples: 0.254
Loss after 39908 examples: 0.442
Loss after 40708 examples: 0.366


  8%|▊         | 16/200 [00:05<00:57,  3.20it/s]

Loss after 41504 examples: 0.304
Loss after 42304 examples: 0.265
Loss after 43104 examples: 0.347


  8%|▊         | 17/200 [00:05<00:56,  3.22it/s]

Loss after 43900 examples: 0.281
Loss after 44700 examples: 0.225
Loss after 45500 examples: 0.269


  9%|▉         | 18/200 [00:05<00:56,  3.20it/s]

Loss after 46296 examples: 0.274
Loss after 47096 examples: 0.254
Loss after 47896 examples: 0.376


 10%|▉         | 19/200 [00:06<00:55,  3.23it/s]

Loss after 48692 examples: 0.207
Loss after 49492 examples: 0.178
Loss after 50292 examples: 0.101


 10%|█         | 20/200 [00:06<00:55,  3.22it/s]

Loss after 51092 examples: 0.409
Loss after 51888 examples: 0.317
Loss after 52688 examples: 0.370


 10%|█         | 21/200 [00:06<00:55,  3.25it/s]

Loss after 53488 examples: 0.441
Loss after 54284 examples: 0.245


 11%|█         | 22/200 [00:07<00:55,  3.19it/s]

Loss after 55084 examples: 0.328
Loss after 55884 examples: 0.244
Loss after 56680 examples: 0.345
Loss after 57480 examples: 0.136
Loss after 58280 examples: 0.293


 12%|█▏        | 23/200 [00:07<00:55,  3.16it/s]

Loss after 59076 examples: 0.410
Loss after 59876 examples: 0.278
Loss after 60676 examples: 0.286


 12%|█▏        | 24/200 [00:07<00:54,  3.24it/s]

Loss after 61472 examples: 0.068
Loss after 62272 examples: 0.186
Loss after 63072 examples: 0.118


 12%|█▎        | 25/200 [00:08<00:53,  3.30it/s]

Loss after 63872 examples: 0.233
Loss after 64668 examples: 0.147
Loss after 65468 examples: 0.262


 13%|█▎        | 26/200 [00:08<00:51,  3.36it/s]

Loss after 66268 examples: 0.241
Loss after 67064 examples: 0.170
Loss after 67864 examples: 0.181


 14%|█▎        | 27/200 [00:08<00:50,  3.39it/s]

Loss after 68664 examples: 0.177
Loss after 69460 examples: 0.299
Loss after 70260 examples: 0.144


 14%|█▍        | 28/200 [00:08<00:50,  3.39it/s]

Loss after 71060 examples: 0.267
Loss after 71856 examples: 0.331
Loss after 72656 examples: 0.222


 14%|█▍        | 29/200 [00:09<00:50,  3.40it/s]

Loss after 73456 examples: 0.231
Loss after 74252 examples: 0.094
Loss after 75052 examples: 0.120


 15%|█▌        | 30/200 [00:09<00:49,  3.43it/s]

Loss after 75852 examples: 0.097
Loss after 76652 examples: 0.138
Loss after 77448 examples: 0.096


 16%|█▌        | 31/200 [00:09<00:48,  3.45it/s]

Loss after 78248 examples: 0.285
Loss after 79048 examples: 0.244
Loss after 79844 examples: 0.208


 16%|█▌        | 32/200 [00:10<00:48,  3.43it/s]

Loss after 80644 examples: 0.153
Loss after 81444 examples: 0.226
Loss after 82240 examples: 0.366


 16%|█▋        | 33/200 [00:10<00:48,  3.44it/s]

Loss after 83040 examples: 0.285
Loss after 83840 examples: 0.050
Loss after 84636 examples: 0.087


 17%|█▋        | 34/200 [00:10<00:48,  3.43it/s]

Loss after 85436 examples: 0.079
Loss after 86236 examples: 0.252
Loss after 87032 examples: 0.050
Loss after 87832 examples: 0.219
Loss after 88632 examples: 0.052


 18%|█▊        | 35/200 [00:10<00:50,  3.24it/s]

Loss after 89432 examples: 0.269
Loss after 90228 examples: 0.168
Loss after 91028 examples: 0.172


 18%|█▊        | 36/200 [00:11<00:51,  3.19it/s]

Loss after 91828 examples: 0.157
Loss after 92624 examples: 0.121


 18%|█▊        | 37/200 [00:11<00:52,  3.10it/s]

Loss after 93424 examples: 0.072
Loss after 94224 examples: 0.105
Loss after 95020 examples: 0.310


 19%|█▉        | 38/200 [00:11<00:51,  3.15it/s]

Loss after 95820 examples: 0.150
Loss after 96620 examples: 0.026
Loss after 97416 examples: 0.113


 20%|█▉        | 39/200 [00:12<00:50,  3.17it/s]

Loss after 98216 examples: 0.047
Loss after 99016 examples: 0.285
Loss after 99812 examples: 0.081
Loss after 100612 examples: 0.128
Loss after 101412 examples: 0.044


 20%|██        | 40/200 [00:12<00:54,  2.96it/s]

Loss after 102212 examples: 0.237
Loss after 103008 examples: 0.174


 20%|██        | 41/200 [00:12<00:54,  2.94it/s]

Loss after 103808 examples: 0.223
Loss after 104608 examples: 0.196
Loss after 105404 examples: 0.242


 21%|██        | 42/200 [00:13<00:52,  3.02it/s]

Loss after 106204 examples: 0.303
Loss after 107004 examples: 0.166
Loss after 107800 examples: 0.310


 22%|██▏       | 43/200 [00:13<00:50,  3.10it/s]

Loss after 108600 examples: 0.238
Loss after 109400 examples: 0.036
Loss after 110196 examples: 0.206
Loss after 110996 examples: 0.074
Loss after 111796 examples: 0.299


 22%|██▏       | 44/200 [00:13<00:52,  2.99it/s]

Loss after 112592 examples: 0.171
Loss after 113392 examples: 0.259


 22%|██▎       | 45/200 [00:14<00:53,  2.90it/s]

Loss after 114192 examples: 0.089
Loss after 114992 examples: 0.187
Loss after 115788 examples: 0.152
Loss after 116588 examples: 0.213
Loss after 117388 examples: 0.142


 23%|██▎       | 46/200 [00:14<00:52,  2.94it/s]

Loss after 118184 examples: 0.320
Loss after 118984 examples: 0.209
Loss after 119784 examples: 0.218


 24%|██▎       | 47/200 [00:14<00:49,  3.06it/s]

Loss after 120580 examples: 0.237
Loss after 121380 examples: 0.227
Loss after 122180 examples: 0.024


 24%|██▍       | 48/200 [00:15<00:48,  3.11it/s]

Loss after 122976 examples: 0.109
Loss after 123776 examples: 0.039
Loss after 124576 examples: 0.240


 24%|██▍       | 49/200 [00:15<00:46,  3.25it/s]

Loss after 125372 examples: 0.079
Loss after 126172 examples: 0.188
Loss after 126972 examples: 0.062


 25%|██▌       | 50/200 [00:15<00:44,  3.35it/s]

Loss after 127772 examples: 0.200
Loss after 128568 examples: 0.027
Loss after 129368 examples: 0.165


 26%|██▌       | 51/200 [00:16<00:43,  3.40it/s]

Loss after 130168 examples: 0.130
Loss after 130964 examples: 0.031
Loss after 131764 examples: 0.040


 26%|██▌       | 52/200 [00:16<00:44,  3.35it/s]

Loss after 132564 examples: 0.085
Loss after 133360 examples: 0.270
Loss after 134160 examples: 0.208


 26%|██▋       | 53/200 [00:16<00:44,  3.33it/s]

Loss after 134960 examples: 0.045
Loss after 135756 examples: 0.081
Loss after 136556 examples: 0.090


 27%|██▋       | 54/200 [00:17<00:44,  3.25it/s]

Loss after 137356 examples: 0.343
Loss after 138152 examples: 0.027
Loss after 138952 examples: 0.126
Loss after 139752 examples: 0.082


 28%|██▊       | 55/200 [00:17<00:52,  2.74it/s]

Loss after 140552 examples: 0.176
Loss after 141348 examples: 0.124


 28%|██▊       | 56/200 [00:17<00:53,  2.71it/s]

Loss after 142148 examples: 0.206
Loss after 142948 examples: 0.109
Loss after 143744 examples: 0.068
Loss after 144544 examples: 0.055
Loss after 145344 examples: 0.166


 28%|██▊       | 57/200 [00:18<00:50,  2.81it/s]

Loss after 146140 examples: 0.455
Loss after 146940 examples: 0.247
Loss after 147740 examples: 0.057


 29%|██▉       | 58/200 [00:18<00:47,  2.97it/s]

Loss after 148536 examples: 0.172
Loss after 149336 examples: 0.072


 30%|██▉       | 59/200 [00:18<00:48,  2.94it/s]

Loss after 150136 examples: 0.283
Loss after 150932 examples: 0.028
Loss after 151732 examples: 0.135
Loss after 152532 examples: 0.086


 30%|███       | 60/200 [00:19<00:47,  2.94it/s]

Loss after 153332 examples: 0.196
Loss after 154128 examples: 0.081
Loss after 154928 examples: 0.183


 30%|███       | 61/200 [00:19<00:45,  3.02it/s]

Loss after 155728 examples: 0.143
Loss after 156524 examples: 0.125
Loss after 157324 examples: 0.093


 31%|███       | 62/200 [00:19<00:45,  3.07it/s]

Loss after 158124 examples: 0.092
Loss after 158920 examples: 0.377
Loss after 159720 examples: 0.200


 32%|███▏      | 63/200 [00:20<00:43,  3.12it/s]

Loss after 160520 examples: 0.031
Loss after 161316 examples: 0.059
Loss after 162116 examples: 0.027


 32%|███▏      | 64/200 [00:20<00:43,  3.11it/s]

Loss after 162916 examples: 0.274
Loss after 163712 examples: 0.139
Loss after 164512 examples: 0.199


 32%|███▎      | 65/200 [00:20<00:42,  3.17it/s]

Loss after 165312 examples: 0.056
Loss after 166112 examples: 0.239
Loss after 166908 examples: 0.097


 33%|███▎      | 66/200 [00:21<00:42,  3.18it/s]

Loss after 167708 examples: 0.162
Loss after 168508 examples: 0.209
Loss after 169304 examples: 0.218


 34%|███▎      | 67/200 [00:21<00:41,  3.19it/s]

Loss after 170104 examples: 0.106
Loss after 170904 examples: 0.187
Loss after 171700 examples: 0.265


 34%|███▍      | 68/200 [00:21<00:41,  3.19it/s]

Loss after 172500 examples: 0.145
Loss after 173300 examples: 0.011
Loss after 174096 examples: 0.206


 34%|███▍      | 69/200 [00:22<00:40,  3.22it/s]

Loss after 174896 examples: 0.038
Loss after 175696 examples: 0.321
Loss after 176492 examples: 0.037
Loss after 177292 examples: 0.026
Loss after 178092 examples: 0.043
Loss after 178892 examples: 0.122


 35%|███▌      | 70/200 [00:22<00:40,  3.23it/s]

Loss after 179688 examples: 0.154
Loss after 180488 examples: 0.147
Loss after 181288 examples: 0.193


 36%|███▌      | 71/200 [00:22<00:38,  3.32it/s]

Loss after 182084 examples: 0.019
Loss after 182884 examples: 0.039
Loss after 183684 examples: 0.060


 36%|███▌      | 72/200 [00:22<00:38,  3.35it/s]

Loss after 184480 examples: 0.302
Loss after 185280 examples: 0.119
Loss after 186080 examples: 0.033


 36%|███▋      | 73/200 [00:23<00:37,  3.41it/s]

Loss after 186876 examples: 0.143
Loss after 187676 examples: 0.039
Loss after 188476 examples: 0.355


 37%|███▋      | 74/200 [00:23<00:36,  3.42it/s]

Loss after 189272 examples: 0.025
Loss after 190072 examples: 0.017
Loss after 190872 examples: 0.132


 38%|███▊      | 75/200 [00:23<00:36,  3.44it/s]

Loss after 191672 examples: 0.245
Loss after 192468 examples: 0.037
Loss after 193268 examples: 0.130


 38%|███▊      | 76/200 [00:24<00:36,  3.38it/s]

Loss after 194068 examples: 0.156
Loss after 194864 examples: 0.217
Loss after 195664 examples: 0.059


 38%|███▊      | 77/200 [00:24<00:36,  3.39it/s]

Loss after 196464 examples: 0.097
Loss after 197260 examples: 0.326
Loss after 198060 examples: 0.240


 39%|███▉      | 78/200 [00:24<00:35,  3.44it/s]

Loss after 198860 examples: 0.062
Loss after 199656 examples: 0.117
Loss after 200456 examples: 0.039


 40%|███▉      | 79/200 [00:24<00:36,  3.35it/s]

Loss after 201256 examples: 0.478
Loss after 202052 examples: 0.258
Loss after 202852 examples: 0.301
Loss after 203652 examples: 0.168


 40%|████      | 80/200 [00:25<00:37,  3.20it/s]

Loss after 204452 examples: 0.502
Loss after 205248 examples: 0.180


 40%|████      | 81/200 [00:25<00:36,  3.25it/s]

Loss after 206048 examples: 0.152
Loss after 206848 examples: 0.290
Loss after 207644 examples: 0.055


 41%|████      | 82/200 [00:25<00:36,  3.28it/s]

Loss after 208444 examples: 0.069
Loss after 209244 examples: 0.075
Loss after 210040 examples: 0.042


 42%|████▏     | 83/200 [00:26<00:35,  3.32it/s]

Loss after 210840 examples: 0.139
Loss after 211640 examples: 0.190
Loss after 212436 examples: 0.048


 42%|████▏     | 84/200 [00:26<00:34,  3.36it/s]

Loss after 213236 examples: 0.037
Loss after 214036 examples: 0.360
Loss after 214832 examples: 0.041
Loss after 215632 examples: 0.027
Loss after 216432 examples: 0.110
Loss after 217232 examples: 0.176


 42%|████▎     | 85/200 [00:26<00:33,  3.39it/s]

Loss after 218028 examples: 0.027
Loss after 218828 examples: 0.245
Loss after 219628 examples: 0.112


 43%|████▎     | 86/200 [00:27<00:33,  3.42it/s]

Loss after 220424 examples: 0.246
Loss after 221224 examples: 0.030
Loss after 222024 examples: 0.130


 44%|████▎     | 87/200 [00:27<00:33,  3.41it/s]

Loss after 222820 examples: 0.142
Loss after 223620 examples: 0.209
Loss after 224420 examples: 0.029


 44%|████▍     | 88/200 [00:27<00:32,  3.44it/s]

Loss after 225216 examples: 0.065
Loss after 226016 examples: 0.150
Loss after 226816 examples: 0.324


 44%|████▍     | 89/200 [00:27<00:32,  3.45it/s]

Loss after 227612 examples: 0.032
Loss after 228412 examples: 0.044
Loss after 229212 examples: 0.045


 45%|████▌     | 90/200 [00:28<00:31,  3.44it/s]

Loss after 230012 examples: 0.146
Loss after 230808 examples: 0.093
Loss after 231608 examples: 0.220


 46%|████▌     | 91/200 [00:28<00:31,  3.48it/s]

Loss after 232408 examples: 0.186
Loss after 233204 examples: 0.087
Loss after 234004 examples: 0.054


 46%|████▌     | 92/200 [00:28<00:31,  3.47it/s]

Loss after 234804 examples: 0.082
Loss after 235600 examples: 0.045
Loss after 236400 examples: 0.128


 46%|████▋     | 93/200 [00:29<00:30,  3.47it/s]

Loss after 237200 examples: 0.044
Loss after 237996 examples: 0.051
Loss after 238796 examples: 0.053


 47%|████▋     | 94/200 [00:29<00:30,  3.48it/s]

Loss after 239596 examples: 0.335
Loss after 240392 examples: 0.016
Loss after 241192 examples: 0.016


 48%|████▊     | 95/200 [00:29<00:30,  3.42it/s]

Loss after 241992 examples: 0.038
Loss after 242792 examples: 0.138
Loss after 243588 examples: 0.013


 48%|████▊     | 96/200 [00:29<00:30,  3.41it/s]

Loss after 244388 examples: 0.215
Loss after 245188 examples: 0.147
Loss after 245984 examples: 0.032


 48%|████▊     | 97/200 [00:30<00:30,  3.40it/s]

Loss after 246784 examples: 0.024
Loss after 247584 examples: 0.087
Loss after 248380 examples: 0.037


 49%|████▉     | 98/200 [00:30<00:30,  3.39it/s]

Loss after 249180 examples: 0.135
Loss after 249980 examples: 0.023
Loss after 250776 examples: 0.059


 50%|████▉     | 99/200 [00:30<00:29,  3.40it/s]

Loss after 251576 examples: 0.017
Loss after 252376 examples: 0.276
Loss after 253172 examples: 0.011
Loss after 253972 examples: 0.020
Loss after 254772 examples: 0.069
Loss after 255572 examples: 0.120


 50%|█████     | 100/200 [00:31<00:29,  3.39it/s]

Loss after 256368 examples: 0.010
Loss after 257168 examples: 0.176
Loss after 257968 examples: 0.103


 50%|█████     | 101/200 [00:31<00:29,  3.41it/s]

Loss after 258764 examples: 0.013
Loss after 259564 examples: 0.060
Loss after 260364 examples: 0.136


 51%|█████     | 102/200 [00:31<00:28,  3.48it/s]

Loss after 261160 examples: 0.387
Loss after 261960 examples: 0.146
Loss after 262760 examples: 0.059


 52%|█████▏    | 103/200 [00:31<00:27,  3.50it/s]

Loss after 263556 examples: 0.094
Loss after 264356 examples: 0.194
Loss after 265156 examples: 0.308


 52%|█████▏    | 104/200 [00:32<00:27,  3.49it/s]

Loss after 265952 examples: 0.134
Loss after 266752 examples: 0.102
Loss after 267552 examples: 0.027


 52%|█████▎    | 105/200 [00:32<00:27,  3.48it/s]

Loss after 268352 examples: 0.461
Loss after 269148 examples: 0.166
Loss after 269948 examples: 0.252


 53%|█████▎    | 106/200 [00:32<00:27,  3.46it/s]

Loss after 270748 examples: 0.191
Loss after 271544 examples: 0.157
Loss after 272344 examples: 0.043


 54%|█████▎    | 107/200 [00:33<00:26,  3.46it/s]

Loss after 273144 examples: 0.248
Loss after 273940 examples: 0.153
Loss after 274740 examples: 0.115


 54%|█████▍    | 108/200 [00:33<00:26,  3.44it/s]

Loss after 275540 examples: 0.072
Loss after 276336 examples: 0.100


 55%|█████▍    | 109/200 [00:33<00:26,  3.41it/s]

Loss after 277136 examples: 0.088
Loss after 277936 examples: 0.427
Loss after 278732 examples: 0.013
Loss after 279532 examples: 0.057
Loss after 280332 examples: 0.090
Loss after 281132 examples: 0.192


 55%|█████▌    | 110/200 [00:34<00:26,  3.42it/s]

Loss after 281928 examples: 0.023
Loss after 282728 examples: 0.188
Loss after 283528 examples: 0.266


 56%|█████▌    | 111/200 [00:34<00:25,  3.45it/s]

Loss after 284324 examples: 0.176
Loss after 285124 examples: 0.056
Loss after 285924 examples: 0.083


 56%|█████▌    | 112/200 [00:34<00:25,  3.46it/s]

Loss after 286720 examples: 0.129
Loss after 287520 examples: 0.202
Loss after 288320 examples: 0.059


 56%|█████▋    | 113/200 [00:34<00:25,  3.40it/s]

Loss after 289116 examples: 0.083
Loss after 289916 examples: 0.036
Loss after 290716 examples: 0.480


 57%|█████▋    | 114/200 [00:35<00:25,  3.31it/s]

Loss after 291512 examples: 0.030
Loss after 292312 examples: 0.047


 57%|█████▊    | 115/200 [00:35<00:26,  3.21it/s]

Loss after 293112 examples: 0.039
Loss after 293912 examples: 0.142
Loss after 294708 examples: 0.056


 58%|█████▊    | 116/200 [00:35<00:25,  3.25it/s]

Loss after 295508 examples: 0.218
Loss after 296308 examples: 0.247
Loss after 297104 examples: 0.042


 58%|█████▊    | 117/200 [00:36<00:25,  3.24it/s]

Loss after 297904 examples: 0.046
Loss after 298704 examples: 0.067
Loss after 299500 examples: 0.060
Loss after 300300 examples: 0.119
Loss after 301100 examples: 0.023


 59%|█████▉    | 118/200 [00:36<00:26,  3.12it/s]

Loss after 301896 examples: 0.059
Loss after 302696 examples: 0.018


 60%|█████▉    | 119/200 [00:36<00:26,  3.06it/s]

Loss after 303496 examples: 0.323
Loss after 304292 examples: 0.012
Loss after 305092 examples: 0.010
Loss after 305892 examples: 0.022
Loss after 306692 examples: 0.189


 60%|██████    | 120/200 [00:37<00:26,  3.06it/s]

Loss after 307488 examples: 0.028
Loss after 308288 examples: 0.041


 60%|██████    | 121/200 [00:37<00:25,  3.11it/s]

Loss after 309088 examples: 0.185
Loss after 309884 examples: 0.054


 61%|██████    | 122/200 [00:37<00:25,  3.10it/s]

Loss after 310684 examples: 0.041
Loss after 311484 examples: 0.094
Loss after 312280 examples: 0.042
Loss after 313080 examples: 0.131
Loss after 313880 examples: 0.011


 62%|██████▏   | 123/200 [00:38<00:24,  3.09it/s]