In [None]:
!pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
!python -c "import torch; print(torch.__version__); print(torch.version.cuda);"
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu113.html

## Step1: Dataset Preprocessing

In [1]:
import numpy as np
import gc
import torch
import pyarrow as pa
from tqdm import tqdm
from pyarrow.parquet import ParquetFile
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

In [2]:
pf = ParquetFile('../input/tau-test-1/BoostedTop_x1_fixed_0.snappy.parquet') 
rows = next(pf.iter_batches(batch_size = 3200)) 
df = pa.Table.from_batches([rows]).to_pandas() 
X_jets = np.array(df['X_jets'].tolist()).astype(np.float32)
labels = torch.from_numpy(df['y'].to_numpy()).reshape(-1,1).type(torch.LongTensor)
del df, rows
print(X_jets.shape, labels.shape)

(3200, 125000) torch.Size([3200, 1])


In [3]:
X_data = X_jets.reshape((-1,125*125,8))

# uncommet this line below if choose to use full channels
X_data = X_data[:,:,:5]
non_black_pixels_mask = np.any(X_data != 0., axis=-1)

node_list = []
for i, x in enumerate(X_data):
    node_list.append(x[non_black_pixels_mask[i]])
del X_jets

In [4]:
dataset = []
for i,nodes in enumerate(tqdm(node_list)):
    dataset.append(Data(x=torch.from_numpy(nodes), y=labels[i]))

100%|██████████| 3200/3200 [00:00<00:00, 13913.65it/s]


In [5]:
del labels, node_list
gc.collect()

23

In [7]:
rand_seed = 42
X_train, X_test = train_test_split(dataset, test_size=0.1, random_state = rand_seed)
X_train, X_val = train_test_split(X_train, test_size=0.1, random_state = rand_seed)
print(len(X_train), len(X_val), len(X_val))

2592 288 288


In [10]:
train_loader = DataLoader(X_train, batch_size=32, shuffle=True)
val_loader = DataLoader(X_val, batch_size=32, shuffle=True)
test_loader = DataLoader(X_test, batch_size=32, shuffle=False)
batch = next(iter(test_loader))
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])

Batch: DataBatch(x=[36159, 5], y=[32], batch=[36159], ptr=[33])
Labels: tensor([0, 1, 0, 0, 0, 1, 0, 0, 0, 1])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


# Step2: Define Dynamic Edge Convolution GNN Model

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [12]:
from torch.nn import Linear
from torch_geometric.nn import DynamicEdgeConv
from torch_geometric.nn import global_max_pool
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq

num_node_features = 8
num_classes = 2

class GCN(torch.nn.Module):
    def __init__(self, c_in, c_hidden, c_out = num_classes, dp_rate_linear=0.3):
        super().__init__()
        self.dp_rate_linear = dp_rate_linear

        nn = Seq(Lin(2*c_in, c_hidden), ReLU(), Lin(c_hidden, c_hidden), ReLU(), Lin(c_hidden, c_hidden), ReLU())
        self.conv1 = DynamicEdgeConv(nn, k=20, aggr='max')

        nn = Seq(Lin(2*c_hidden, 2*c_hidden), ReLU(), Lin(2*c_hidden, 2*c_hidden), ReLU(), Lin(2*c_hidden, 2*c_hidden),
                 ReLU())
        self.conv2 = DynamicEdgeConv(nn, k=20, aggr='max')

        self.lin1 = Lin(2*c_hidden, c_hidden)
        self.lin2 = Lin(c_hidden, c_hidden//2)
        self.lin3 = Lin(c_hidden//2, c_out)

    def forward(self, x, batch):
        x = self.conv1(x, batch)
        x = self.conv2(x, batch)

        x = global_max_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=self.dp_rate_linear, training=self.training)
        x = self.lin3(x)
        return x

## Step3: Use PyTorch Lightning to Define Training Process

In [13]:
from torchmetrics.functional import auroc
learning_rate = 3e-4

class GraphLevelGNN(pl.LightningModule):
    
    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        
        self.model = GCN(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()
        self.auroc = auroc
        
    def forward(self, data, mode="train"):
        x, batch_idx = data.x, data.batch
        
        x = self.model(x, batch_idx)
        x = x.squeeze(dim=-1)
        
        if self.hparams.c_out == 1:
            preds = (x > 0).float()
            data.y = data.y.float()
        else:
            preds = x.argmax(dim=-1)
        loss = self.loss_module(x, data.y)
        acc = (preds == data.y).sum().float() / preds.shape[0]
        auc = self.auroc(x, data.y, num_classes=2)
        return loss, acc, auc

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=learning_rate, weight_decay=0) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="train")
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        self.log('train_auc', auc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="val")
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_auc', auc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="test")
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        self.log('test_auc', auc, prog_bar=True)

In [14]:
CHECKPOINT_PATH = "./"
def train_graph_classifier(model_name, **model_kwargs):
    pl.seed_everything(46)
    
    # Create a PyTorch Lightning trainer with the generation callback
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=20,
                         progress_bar_refresh_rate=5)

    # Check whether pretrained model exists. If yes, load it and skip training
    model = GraphLevelGNN(**model_kwargs)
    print(model)
    trainer.fit(model, train_loader, val_loader)
    model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    
    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    print(val_result)
    print(test_result)
    result = {"test": test_result[0]['test_acc'], "valid": val_result[0]['test_acc']} 
    return model, result

In [15]:
import os
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, result = train_graph_classifier(model_name="GCN", c_in=5, c_hidden=64, c_out=2)

  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"


GraphLevelGNN(
  (model): GCN(
    (conv1): DynamicEdgeConv(nn=Sequential(
      (0): Linear(in_features=10, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=64, bias=True)
      (5): ReLU()
    ), k=20)
    (conv2): DynamicEdgeConv(nn=Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): ReLU()
    ), k=20)
    (lin1): Linear(in_features=128, out_features=64, bias=True)
    (lin2): Linear(in_features=64, out_features=32, bias=True)
    (lin3): Linear(in_features=32, out_features=2, bias=True)
  )
  (loss_module): CrossEntropyLoss()
)


Sanity Checking: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validation: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Testing: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Testing: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


[{'test_loss': 0.3713477551937103, 'test_acc': 0.818640947341919, 'test_auc': 0.8673312664031982}]
[{'test_loss': 0.403104692697525, 'test_acc': 0.8088625073432922, 'test_auc': 0.8626824021339417}]


  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
