In [1]:
import torch, pickle, argparse, os, warnings, copy, time, mlflow
import numpy as np, pytorch_lightning as pl, matplotlib.pyplot as plt, eagerpy as ep
from models import ConvNet
from data_loader import load_test_data
from foolbox import PyTorchModel
from foolbox.attacks import LinfProjectedGradientDescentAttack
from foolbox.attacks.base import Repeated
from tqdm.notebook import tqdm
from attack_helper import run_batched_attack_cpu, batched_accuracy
from mlflow.tracking.artifact_utils import get_artifact_uri

In [2]:
def as_list(s):
    """Parse the text form of an integer sequence, e.g. ``'[1, 2, 3]'``.

    The first and last characters of ``s`` (the enclosing brackets) are
    dropped and the remainder is split on commas. Blank items are skipped so
    that an empty sequence (``'[]'``) or a trailing comma (``'(16,)'``) is
    parsed instead of raising ``ValueError`` on ``int('')``.

    Args:
        s: Text such as ``'[16, 32]'``; the delimiters may be any single
            characters.

    Returns:
        A list of ``int`` in the order the items appear in ``s``.

    Raises:
        ValueError: If a non-blank item is not a valid integer literal.
    """
    return [int(item) for item in s[1:-1].split(',') if item.strip()]

def get_hparam(df, run_id, param):
    """Return the ``params.<param>`` value of the first row whose ``run_id`` matches.

    Args:
        df: Frame with a ``run_id`` column and ``params.*`` columns.
        run_id: Value compared against the ``run_id`` column.
        param: Parameter name without the ``params.`` prefix.

    Raises:
        IndexError: If no row matches ``run_id``.
        KeyError: If the ``run_id`` or ``params.<param>`` column is missing.
    """
    matches_run = df['run_id'] == run_id
    column = 'params.' + param
    return df.loc[matches_run, column].values[0]

In [3]:
# Point MLflow at the SQLite tracking store under ./mlruns.
tracking_uri = 'sqlite:///mlruns/database.db'
mlflow.set_tracking_uri(tracking_uri)

# Fetch the runs of the 'model_training' experiment and list the available columns.
df = mlflow.search_runs(experiment_names=['model_training'])
print(df.columns)

Index(['run_id', 'experiment_id', 'status', 'artifact_uri', 'start_time',
       'end_time', 'metrics.train_loss', 'metrics.val_acc', 'metrics.val_loss',
       'metrics.test_loss', 'metrics.train_acc', 'metrics.epoch',
       'metrics.test_acc', 'metrics.loss', 'params.bandlimit',
       'params.channels', 'params.es_patience', 'params.train_batch_size',
       'params.es_mode', 'params.total_params', 'params.es_monitor',
       'params.es_stopped_epoch', 'params.max_epochs', 'params.train_rot',
       'params.batch_norm', 'params.kernel_max_beta', 'params.name',
       'params.num_workers', 'params.activation_fn', 'params.test_rot',
       'params.nodes', 'params.trainable_params', 'params.train_samples',
       'params.weight_decay', 'params.lr', 'params.test_batch_size',
       'params.es_min_delta', 'params.strides', 'params.kernels',
       'tags.mlflow.runName', 'tags.mlflow.user', 'tags.model'],
      dtype='object')


In [4]:
# Preview the run ids next to the logged total parameter count.
df.loc[:, ['run_id', 'params.total_params']].head()

Unnamed: 0,run_id,params.total_params
0,18b4936ce2bd483ca10f5d69383cd4e1,35533
1,352c7576044345ecad8abb8fd65f1bd2,35533
2,d86fcea9871043d0a0a54dd0cd80d36e,35533
3,322122e11a9648c198b29d4056cb0b02,35533
4,1e276abad90a4634a76c82bf928420a4,35533


In [5]:
# Runs whose logged parameter count equals the string '113943'.
df.loc[
    df['params.total_params'] == '113943',
    ['tags.mlflow.runName', 'metrics.test_acc', 'params.train_samples'],
]

Unnamed: 0,tags.mlflow.runName,metrics.test_acc,params.train_samples
33,1668152805,0.992,50000
34,1668139362,0.9923,50000
35,1668127973,0.9915,50000
36,1668120367,0.9902,40000
37,1668110081,0.9911,40000
38,1668101504,0.9902,40000
39,1668094166,0.9888,30000
40,1668085318,0.9904,30000
41,1668077822,0.9895,30000
42,1668070267,0.9873,20000


In [4]:
# Test accuracy of the first run named '1668152805'.
df.loc[df['tags.mlflow.runName'] == '1668152805', 'metrics.test_acc'].values[0]

0.9919999837875366

In [5]:
# Test accuracy of the first run named '1668139362'.
df.loc[df['tags.mlflow.runName'] == '1668139362', 'metrics.test_acc'].values[0]

0.9922999739646912

In [6]:
# Test accuracy of the first run named '1668127973'.
df.loc[df['tags.mlflow.runName'] == '1668127973', 'metrics.test_acc'].values[0]

0.9915000200271606

In [6]:
# Look up the MLflow run id of the first run named '1668139362'.
run_id = df.loc[df['tags.mlflow.runName'] == '1668139362', 'run_id'].values[0]

In [7]:
# Resolve the artifact location of the selected run from the tracking store.
# NOTE(review): the result is used below as a local directory path — confirm
# the store never returns a non-filesystem URI here.
artifact_path = get_artifact_uri(run_id=run_id, tracking_uri=tracking_uri)

In [8]:
# Display the resolved artifact location.
artifact_path

'./mlruns/2/d66eec77ff0044938ec96d1653a34355/artifacts'

In [None]:
# List the entries of the run's artifact directory (os.listdir returns them in
# arbitrary order).
dirs=os.listdir(artifact_path)

In [None]:
# Pick the first artifact whose name contains '.ckpt'. The listing order is not
# guaranteed, so with several checkpoints the choice follows that order.
for s in dirs:
    if '.ckpt' in s:
        checkpoint = s
        break
else:
    # Fail here instead of leaving `checkpoint` undefined (or stale from an
    # earlier execution) for the cells that follow.
    raise FileNotFoundError(f"no '.ckpt' file among the run artifacts: {dirs}")

In [None]:
# Full path of the selected checkpoint file inside the artifact directory.
checkpoint_path = os.path.join(artifact_path, checkpoint)

In [None]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# map_location='cpu' lets a checkpoint that was saved with CUDA tensors be read
# on a machine without a GPU; the model is not moved to another device in this
# cell, so load_state_dict copies the weights into its existing parameters.
best_model = torch.load(checkpoint_path, map_location='cpu')
# Rebuild the network from the hyper-parameters stored in the checkpoint and
# switch it to evaluation mode before restoring the trained weights.
hparams = argparse.Namespace(**best_model['hyper_parameters'])
model = ConvNet(hparams, None, None).eval()
model.load_state_dict(best_model['state_dict'])