In [1]:
%load_ext autoreload
%autoreload 2
import networkx as nx

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from graph_transformer import GT
from utils import pre_process, pre_process_with_summary, get_n_params, get_optimizer
import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import pytz


Data(adamic_edge_attr=[308, 1], alloc_edge_attr=[308, 1], cn_edge_attr=[308, 1], comm_edge_attr=[308, 1], edge_attr=[308, 3], edge_index=[2, 308], hier_label=[32, 4], hsd_edge_attr=[308, 4], jaccard_edge_attr=[308, 1], lap_x=[32, 10], orig_edge_attr=[70, 3], orig_edge_index=[2, 70], sd_edge_attr=[308, 1], x=[32, 9], y=[1, 1])
torch.Size([70, 3])
torch.Size([308, 3])
tensor([[0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.2500],
        [0.0000],
        [0.2500],
        [0.2500],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.3333],
        [0.0000],
        [0.3333],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
 

In [2]:
# Experiment configuration. argparse is used only as a namespace: parsing an empty argv yields
# an empty Namespace, and every hyper-parameter is then set explicitly below.
parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")

args.dataset = 'ogbg-molhiv'
args.n_classes = 1              # one output logit per graph (trained with BCEWithLogitsLoss)
args.lr = 3e-4
args.n_hid = 512
args.n_heads = 8
args.n_layer = 4
args.dropout = 0.2
args.num_epochs = 60
args.k_hop_neighbors = 3        # also part of the cached-dataset and TensorBoard directory names
args.weight_decay = 1e-2
args.bsz      = 512
# NOTE(review): only the Jaccard feature ('ja') is wired into the model/training loop; this list
# currently labels the TensorBoard run directory and is passed to pre-processing via `args` —
# confirm whether pre_process uses it before changing it.
args.strategies = ['ea', 'sd', 'ja']
args.summary_node = True
args.hier_levels = 3
args.lap_k = 10                 # presumably the Laplacian-eigenvector width (data shows lap_x=[N, 10])
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.metric = 'rocauc'          # key read from the OGB Evaluator's result dict
args.seed = 0

# Seed every RNG used downstream: Python/NumPy (possibly used by pre-processing helpers) and
# torch (weight init, dropout, WeightedRandomSampler draws). torch.manual_seed seeds CUDA too.
# Some GPU kernels remain non-deterministic, so runs are repeatable only up to that.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

print("device:", args.device)

device: cuda


In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
tz = pytz.timezone('US/Pacific')
time_now = datetime.datetime.now(tz).strftime('%m-%d_%H:%M:%S')

# The two variants differ only in the pre-processing function and in the sub-directory names
# used for the processed dataset and for the TensorBoard run.
if not args.summary_node:
    pre_transform = lambda d: pre_process(d, args)
    data_subdir = f'{args.k_hop_neighbors}'
    run_subdir = f'k={args.k_hop_neighbors}'
else:
    pre_transform = lambda d: pre_process_with_summary(d, args)
    data_subdir = f'with_summary_{args.k_hop_neighbors}'
    run_subdir = f'with_summary_k={args.k_hop_neighbors}'

# NOTE(review): PyG reuses processed files found under root_path, so changing other
# pre-processing settings (e.g. lap_k, hier_levels) presumably requires deleting it — confirm.
root_path = f'dataset/{args.dataset}/{data_subdir}'

# NOTE(review): the run label comes from args.strategies, while only 'ja' is enabled in
# edge_dim_dict below — the directory name may overstate which features were used.
strategy_tag = "-".join(args.strategies)
args.writer = SummaryWriter(log_dir=f'runs_new/{args.dataset}/{run_subdir}/strats={strategy_tag}/{time_now}')

dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=pre_transform, root=root_path)
evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split()

# Edge-feature strategies handed to the model. 'disc' maps names to value-vocabulary sizes and
# 'cont' maps names to a feature width. Only the Jaccard feature ('ja') is enabled. Disabled
# options seen on the data: 'ea', 'sd', 'cn', 'hsd' (discrete, previously sized as
# `attr.max(dim=0)[0].int().view(-1) + 1`) and 'ad' (Adamic-Adar).
edge_dim_dict = {
    'disc': {},
    'cont': {'ja': args.n_hid},
}

model = GT(
    args.n_hid, args.n_classes, args.n_heads, args.n_layer,
    edge_dim_dict, args.dropout, args.summary_node, args.lap_k,
).to(args.device)

Loading data...
dataset: ogbg-molhiv 


In [4]:
# Regular loader
# train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.bsz, shuffle=True)
# valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle=False)
# test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle=False)

In [5]:
# Loader with weighted sampler, for unbalanced data.
# Each positive training graph is sqrt(#neg / #pos) times as likely to be drawn as each negative
# one: a softer correction than full class balancing. The class counts come from the *training*
# split only, so validation/test labels do not influence how training batches are sampled.
from torch.utils.data import WeightedRandomSampler

train_labels = dataset.data.y.view(-1)[split_idx["train"]]
is_pos = train_labels == 1
is_neg = train_labels == 0
if not bool((is_pos | is_neg).all()):
    raise ValueError("weighted sampling expects binary 0/1 training labels")
n_pos = int(is_pos.sum())
n_neg = int(is_neg.sum())
if n_pos == 0:
    raise ValueError("training split has no positive examples; cannot compute class weights")

# weight[c] is the sampling weight of a graph whose label is c.
weight = [1.0, np.sqrt(n_neg / n_pos)]
# Vectorized lookup: one float64 weight per training graph (the sampler works in float64).
samples_weight = torch.tensor(weight, dtype=torch.double)[train_labels.long()]
sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.bsz, sampler=sampler)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle=False)
test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle=False)

In [6]:
print('Model #Params: %d' % get_n_params(model))

# Binary task on raw logits; the training loop casts targets to float.
criterion = torch.nn.BCEWithLogitsLoss(reduction = "mean")

optimizer = get_optimizer(model, weight_decay = args.weight_decay, learning_rate = args.lr)

# The training loop calls scheduler.step() once per *batch*, so T_max counts optimizer steps,
# not epochs. Past T_max, CosineAnnealingLR does not stay at eta_min: the LR follows the cosine
# back up, giving a periodic schedule with period 2 * T_max steps (visible in the logged LR).
# To anneal exactly once over the whole run, set
#     args.scheduler_t_max = len(train_loader) * args.num_epochs
# before running this cell. The defaults reproduce the original settings.
scheduler_t_max = getattr(args, 'scheduler_t_max', 500)
scheduler_eta_min = getattr(args, 'scheduler_eta_min', 1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_t_max, eta_min=scheduler_eta_min)

Model #Params: 6666753


In [7]:
def _edge_strats(data):
    """Build the edge-feature dict passed to the model for one batch.

    Only the Jaccard feature ('ja') is used; the keys presumably have to match those enabled
    in ``edge_dim_dict``. Other features on the batch (e.g. ``edge_attr``, ``sd_edge_attr``,
    ``lap_x``) are not fed to the model in this run.
    """
    return {'ja': data.jaccard_edge_attr}


def _train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        (mean of per-batch losses, ``args.metric`` over the batches seen). The batches come
        from the weighted sampler, so the metric is on a resampled training distribution.
    """
    model.train()
    losses, y_true, y_scores = [], [], []
    for data in tqdm(train_loader):
        data = data.to(args.device)
        out = model(data.x, data.batch, data.edge_index, _edge_strats(data))
        loss = criterion(out, data.y.float())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # per-batch schedule: its T_max is measured in optimizer steps

        losses.append(loss.item())
        # Detach and move to CPU so the epoch-long lists hold plain tensors, not GPU tensors
        # still attached to the autograd graph.
        y_true.append(data.y.detach().cpu())
        y_scores.append(out.detach().cpu())

    metric = evaluator.eval({"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)})[args.metric]
    return np.average(losses), metric


def _evaluate(loader, show_progress=False):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        loader: a DataLoader over a validation or test split.
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        (mean of per-batch losses, ``args.metric`` computed by the OGB evaluator over all batches).
    """
    model.eval()
    losses, y_true, y_scores = [], [], []
    with torch.no_grad():
        for data in (tqdm(loader) if show_progress else loader):
            data = data.to(args.device)
            out = model(data.x, data.batch, data.edge_index, _edge_strats(data))
            losses.append(criterion(out, data.y.float()).item())
            y_true.append(data.y.cpu())
            y_scores.append(out.cpu())

    metric = evaluator.eval({"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)})[args.metric]
    return np.average(losses), metric


stats = []
try:
    for epoch in range(args.num_epochs):
        step = epoch + 1

        train_loss, train_metric = _train_one_epoch()
        args.writer.add_scalar("LR/epoch", optimizer.param_groups[0]['lr'], step)
        args.writer.add_scalar("Loss/train", train_loss, step)
        args.writer.add_scalar(args.metric + "/train", train_metric, step)

        valid_loss, valid_metric = _evaluate(valid_loader, show_progress=True)
        args.writer.add_scalar("Loss/valid", valid_loss, step)
        args.writer.add_scalar(args.metric + "/valid", valid_metric, step)

        test_loss, test_metric = _evaluate(test_loader)
        args.writer.add_scalar("Loss/test", test_loss, step)
        args.writer.add_scalar(args.metric + "/test", test_metric, step)

        print('Epoch %d: LR: %.5f, Train loss: %.3f Train %s: %.3f Valid loss: %.3f  Valid %s: %.3f Test loss: %.3f  Test %s: %.3f' \
              % (step, optimizer.param_groups[0]['lr'], train_loss, args.metric, train_metric, \
                 valid_loss, args.metric, valid_metric, \
                 test_loss, args.metric, test_metric))
        stats.append([epoch, train_loss, train_metric, valid_loss, valid_metric, test_loss, test_metric])
finally:
    # Flush and close the TensorBoard writer even when training is interrupted
    # (e.g. KeyboardInterrupt), so the scalars logged so far are persisted.
    args.writer.close()

100%|██████████| 65/65 [01:00<00:00,  1.08it/s]
100%|██████████| 9/9 [00:04<00:00,  2.05it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 1: LR: 0.00029, Train loss: 0.448 Train rocauc: 0.605 Valid loss: 0.167  Valid rocauc: 0.751 Test loss: 0.198  Test rocauc: 0.686


100%|██████████| 65/65 [00:58<00:00,  1.10it/s]
100%|██████████| 9/9 [00:04<00:00,  2.15it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 2: LR: 0.00025, Train loss: 0.377 Train rocauc: 0.751 Valid loss: 0.120  Valid rocauc: 0.766 Test loss: 0.145  Test rocauc: 0.726


100%|██████████| 65/65 [00:59<00:00,  1.10it/s]
100%|██████████| 9/9 [00:04<00:00,  2.12it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 3: LR: 0.00020, Train loss: 0.347 Train rocauc: 0.795 Valid loss: 0.126  Valid rocauc: 0.786 Test loss: 0.153  Test rocauc: 0.753


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.91it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 4: LR: 0.00014, Train loss: 0.327 Train rocauc: 0.825 Valid loss: 0.140  Valid rocauc: 0.787 Test loss: 0.169  Test rocauc: 0.751


100%|██████████| 65/65 [01:01<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  1.90it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 5: LR: 0.00008, Train loss: 0.305 Train rocauc: 0.852 Valid loss: 0.146  Valid rocauc: 0.792 Test loss: 0.179  Test rocauc: 0.739


100%|██████████| 65/65 [01:03<00:00,  1.02it/s]
100%|██████████| 9/9 [00:04<00:00,  1.93it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 6: LR: 0.00004, Train loss: 0.297 Train rocauc: 0.862 Valid loss: 0.125  Valid rocauc: 0.791 Test loss: 0.160  Test rocauc: 0.724


100%|██████████| 65/65 [01:02<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  1.98it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 7: LR: 0.00001, Train loss: 0.282 Train rocauc: 0.873 Valid loss: 0.123  Valid rocauc: 0.802 Test loss: 0.163  Test rocauc: 0.730


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  2.03it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 8: LR: 0.00000, Train loss: 0.283 Train rocauc: 0.875 Valid loss: 0.131  Valid rocauc: 0.803 Test loss: 0.170  Test rocauc: 0.731


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.97it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 9: LR: 0.00002, Train loss: 0.286 Train rocauc: 0.876 Valid loss: 0.121  Valid rocauc: 0.802 Test loss: 0.160  Test rocauc: 0.723


100%|██████████| 65/65 [01:03<00:00,  1.03it/s]
100%|██████████| 9/9 [00:04<00:00,  2.00it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 10: LR: 0.00006, Train loss: 0.282 Train rocauc: 0.878 Valid loss: 0.133  Valid rocauc: 0.792 Test loss: 0.177  Test rocauc: 0.726


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.98it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 11: LR: 0.00012, Train loss: 0.279 Train rocauc: 0.880 Valid loss: 0.118  Valid rocauc: 0.798 Test loss: 0.154  Test rocauc: 0.725


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.96it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 12: LR: 0.00018, Train loss: 0.278 Train rocauc: 0.885 Valid loss: 0.097  Valid rocauc: 0.829 Test loss: 0.127  Test rocauc: 0.735


100%|██████████| 65/65 [01:03<00:00,  1.03it/s]
100%|██████████| 9/9 [00:04<00:00,  2.01it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 13: LR: 0.00023, Train loss: 0.273 Train rocauc: 0.893 Valid loss: 0.087  Valid rocauc: 0.818 Test loss: 0.125  Test rocauc: 0.726


100%|██████████| 65/65 [01:01<00:00,  1.06it/s]
100%|██████████| 9/9 [00:04<00:00,  1.89it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 14: LR: 0.00028, Train loss: 0.261 Train rocauc: 0.898 Valid loss: 0.097  Valid rocauc: 0.837 Test loss: 0.140  Test rocauc: 0.727


100%|██████████| 65/65 [01:01<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  1.87it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 15: LR: 0.00030, Train loss: 0.257 Train rocauc: 0.909 Valid loss: 0.108  Valid rocauc: 0.812 Test loss: 0.144  Test rocauc: 0.737


100%|██████████| 65/65 [01:02<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  2.02it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 16: LR: 0.00030, Train loss: 0.251 Train rocauc: 0.910 Valid loss: 0.127  Valid rocauc: 0.822 Test loss: 0.160  Test rocauc: 0.739


100%|██████████| 65/65 [00:58<00:00,  1.12it/s]
100%|██████████| 9/9 [00:04<00:00,  2.16it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 17: LR: 0.00027, Train loss: 0.229 Train rocauc: 0.930 Valid loss: 0.148  Valid rocauc: 0.820 Test loss: 0.192  Test rocauc: 0.730


100%|██████████| 65/65 [01:00<00:00,  1.07it/s]
100%|██████████| 9/9 [00:04<00:00,  1.88it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 18: LR: 0.00022, Train loss: 0.214 Train rocauc: 0.938 Valid loss: 0.099  Valid rocauc: 0.804 Test loss: 0.145  Test rocauc: 0.732


100%|██████████| 65/65 [01:04<00:00,  1.01it/s]
100%|██████████| 9/9 [00:04<00:00,  1.98it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 19: LR: 0.00016, Train loss: 0.201 Train rocauc: 0.949 Valid loss: 0.110  Valid rocauc: 0.824 Test loss: 0.164  Test rocauc: 0.739


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.85it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 20: LR: 0.00010, Train loss: 0.182 Train rocauc: 0.960 Valid loss: 0.130  Valid rocauc: 0.803 Test loss: 0.178  Test rocauc: 0.741


100%|██████████| 65/65 [01:04<00:00,  1.00it/s]
100%|██████████| 9/9 [00:04<00:00,  1.99it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 21: LR: 0.00005, Train loss: 0.165 Train rocauc: 0.968 Valid loss: 0.108  Valid rocauc: 0.795 Test loss: 0.154  Test rocauc: 0.753


100%|██████████| 65/65 [01:01<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  1.96it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 22: LR: 0.00002, Train loss: 0.153 Train rocauc: 0.974 Valid loss: 0.117  Valid rocauc: 0.782 Test loss: 0.163  Test rocauc: 0.742


100%|██████████| 65/65 [01:04<00:00,  1.01it/s]
100%|██████████| 9/9 [00:04<00:00,  1.92it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 23: LR: 0.00000, Train loss: 0.145 Train rocauc: 0.976 Valid loss: 0.120  Valid rocauc: 0.781 Test loss: 0.168  Test rocauc: 0.737


100%|██████████| 65/65 [01:01<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  2.02it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 24: LR: 0.00001, Train loss: 0.145 Train rocauc: 0.976 Valid loss: 0.122  Valid rocauc: 0.785 Test loss: 0.171  Test rocauc: 0.739


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.98it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 25: LR: 0.00004, Train loss: 0.142 Train rocauc: 0.977 Valid loss: 0.127  Valid rocauc: 0.790 Test loss: 0.177  Test rocauc: 0.739


100%|██████████| 65/65 [01:02<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  2.14it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 26: LR: 0.00010, Train loss: 0.147 Train rocauc: 0.975 Valid loss: 0.105  Valid rocauc: 0.801 Test loss: 0.157  Test rocauc: 0.742


100%|██████████| 65/65 [01:00<00:00,  1.07it/s]
100%|██████████| 9/9 [00:04<00:00,  1.87it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 27: LR: 0.00016, Train loss: 0.147 Train rocauc: 0.976 Valid loss: 0.126  Valid rocauc: 0.771 Test loss: 0.182  Test rocauc: 0.734


100%|██████████| 65/65 [01:01<00:00,  1.06it/s]
100%|██████████| 9/9 [00:04<00:00,  1.93it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 28: LR: 0.00021, Train loss: 0.155 Train rocauc: 0.973 Valid loss: 0.130  Valid rocauc: 0.749 Test loss: 0.185  Test rocauc: 0.730


100%|██████████| 65/65 [01:05<00:00,  1.01s/it]
100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 29: LR: 0.00026, Train loss: 0.160 Train rocauc: 0.971 Valid loss: 0.104  Valid rocauc: 0.800 Test loss: 0.155  Test rocauc: 0.755


100%|██████████| 65/65 [01:06<00:00,  1.02s/it]
100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 30: LR: 0.00029, Train loss: 0.162 Train rocauc: 0.970 Valid loss: 0.127  Valid rocauc: 0.800 Test loss: 0.195  Test rocauc: 0.763


100%|██████████| 65/65 [01:01<00:00,  1.05it/s]
100%|██████████| 9/9 [00:04<00:00,  2.03it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 31: LR: 0.00030, Train loss: 0.159 Train rocauc: 0.971 Valid loss: 0.099  Valid rocauc: 0.789 Test loss: 0.184  Test rocauc: 0.750


100%|██████████| 65/65 [00:58<00:00,  1.11it/s]
100%|██████████| 9/9 [00:04<00:00,  2.05it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 32: LR: 0.00028, Train loss: 0.153 Train rocauc: 0.974 Valid loss: 0.097  Valid rocauc: 0.811 Test loss: 0.163  Test rocauc: 0.717


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  1.90it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 33: LR: 0.00024, Train loss: 0.141 Train rocauc: 0.977 Valid loss: 0.129  Valid rocauc: 0.813 Test loss: 0.194  Test rocauc: 0.747


100%|██████████| 65/65 [01:02<00:00,  1.04it/s]
100%|██████████| 9/9 [00:04<00:00,  2.00it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 34: LR: 0.00019, Train loss: 0.134 Train rocauc: 0.980 Valid loss: 0.109  Valid rocauc: 0.814 Test loss: 0.184  Test rocauc: 0.740


100%|██████████| 65/65 [01:01<00:00,  1.06it/s]
100%|██████████| 9/9 [00:04<00:00,  1.97it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 35: LR: 0.00013, Train loss: 0.120 Train rocauc: 0.985 Valid loss: 0.124  Valid rocauc: 0.817 Test loss: 0.209  Test rocauc: 0.730


100%|██████████| 65/65 [01:03<00:00,  1.02it/s]
100%|██████████| 9/9 [00:04<00:00,  1.93it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

Epoch 36: LR: 0.00007, Train loss: 0.100 Train rocauc: 0.989 Valid loss: 0.116  Valid rocauc: 0.795 Test loss: 0.207  Test rocauc: 0.729


 68%|██████▊   | 44/65 [00:42<00:20,  1.04it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Columns of each `stats` row, as appended by the training loop.
labels = ['epoch', 'train_loss', 'train_metric', 'valid_loss', 'valid_metric', 'test_loss', 'test_metric']
stats_np = np.array(stats)
if stats_np.size == 0:
    raise ValueError("`stats` is empty - run the training cell for at least one epoch first.")

# Row with the highest validation metric (higher is better for rocauc). Column 0 is the
# 0-based epoch index.
best_valid = stats_np[stats_np[:, 4].argmax()]
print(best_valid)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle(f'{args.dataset}: best valid {args.metric} at epoch {int(best_valid[0]) + 1}')
for i, ax in zip(range(1, stats_np.shape[-1]), axes.flat):
    ax.plot(stats_np[:, 0], stats_np[:, i], label=labels[i])
    ax.scatter(x=best_valid[0], y=best_valid[i], color='red')
    # Offset the label in screen points so its position does not depend on the epoch count.
    ax.annotate(f'{best_valid[i]:.3f}', xy=(best_valid[0], best_valid[i]),
                xytext=(5, 0), textcoords='offset points', color='red')
    ax.set_xlabel('epoch (0-based)')
    ax.set_title(labels[i])
    ax.legend()
fig.tight_layout()
plt.show()
