In [1]:
import argparse
import ast
import copy
import os
import pickle
import time
import warnings

import eagerpy as ep
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pytorch_lightning as pl
import torch
from foolbox import PyTorchModel
from foolbox.attacks import LinfProjectedGradientDescentAttack
from foolbox.attacks.base import Repeated
from mlflow.tracking.artifact_utils import get_artifact_uri
from tqdm.notebook import tqdm

from attack_helper import run_batched_attack_cpu, batched_accuracy
from data_loader import load_test_data
from models import ConvNet

In [2]:
def as_list(s):
    """Parse a bracketed, comma-separated string of integers into a list of ints.

    Intended for list-valued hyperparameters that were logged to MLflow as strings
    (presumably columns such as ``params.channels`` / ``params.kernels`` — confirm).
    After stripping surrounding whitespace, the first and last characters are taken
    to be the enclosing brackets (``[]`` or ``()``) and dropped.

    Args:
        s: String such as ``'[1, 2, 3]'`` or ``'(3, 3)'``.

    Returns:
        List of ints. ``'[]'`` yields ``[]``. Empty items (e.g. from a trailing
        comma in ``'[1, 2,]'``) are skipped.

    Raises:
        ValueError: if a non-empty element is not an integer literal.
    """
    inner = s.strip()[1:-1]
    # ''.split(',') gives [''], which int() rejects; skip blank items so '[]' -> [].
    return [int(token) for token in inner.split(',') if token.strip()]

def get_hparam(df, run_id, param):
    """Return the logged value of hyperparameter ``param`` for run ``run_id``.

    Args:
        df: Runs table (as returned by ``mlflow.search_runs``) with a ``run_id``
            column and a ``params.<param>`` column.
        run_id: MLflow run id to look up.
        param: Parameter name without the ``params.`` prefix (e.g. ``'lr'``).

    Returns:
        The stored value from the first row whose ``run_id`` matches.

    Raises:
        IndexError: if no row has the given ``run_id``.
        KeyError: if ``params.<param>`` is not a column of ``df``.
    """
    matches = df.loc[df['run_id'] == run_id, f'params.{param}']
    if matches.empty:
        # Same exception type as the bare `.values[0]` failure, but with a useful message.
        raise IndexError(f'no run with run_id={run_id!r} in the runs table')
    return matches.values[0]

In [3]:
# MLflow tracking store: a SQLite DB; 'sqlite:///' + relative path resolves against the working directory.
tracking_uri = 'sqlite:///mlruns/database.db'
mlflow.set_tracking_uri(tracking_uri)
# One row per run of the 'model_training' experiment; logged params appear as 'params.<name>' columns.
df = mlflow.search_runs(experiment_names=['model_training'])
# Fail here, rather than with an IndexError several cells later, if the experiment has no runs.
assert not df.empty, "no runs found for experiment 'model_training'"
df.columns

Index(['run_id', 'experiment_id', 'status', 'artifact_uri', 'start_time',
       'end_time', 'metrics.val_loss', 'metrics.test_acc', 'metrics.train_acc',
       'metrics.loss', 'metrics.test_loss', 'metrics.val_acc',
       'metrics.train_loss', 'metrics.epoch', 'params.es_stopped_epoch',
       'params.test_rot', 'params.max_epochs', 'params.test_batch_size',
       'params.num_workers', 'params.trainable_params', 'params.es_min_delta',
       'params.strides', 'params.es_mode', 'params.kernels',
       'params.es_monitor', 'params.lr', 'params.activation_fn', 'params.name',
       'params.weight_decay', 'params.batch_norm', 'params.channels',
       'params.nodes', 'params.es_patience', 'params.train_batch_size',
       'params.train_samples', 'params.train_rot', 'params.total_params',
       'params.kernel_max_beta', 'params.bandlimit', 'tags.mlflow.runName',
       'tags.mlflow.user', 'tags.model'],
      dtype='object')


In [4]:
# First rows of run id vs. total parameter count.
df.loc[:, ['run_id', 'params.total_params']].head()

Unnamed: 0,run_id,params.total_params
0,239e4f337aee4384bb392795ad832e37,50106
1,08fc3711881141eb85bba07cdcfeddb4,50106
2,8ec8e86f2cd949559a6b0a614aa07338,50106
3,513ea3f822374dc9b770f5bc4c15ed96,50106
4,e4616d4520fe4ebdbd6dfcb37b4ef219,50106


In [5]:
# Runs whose model has 113,943 parameters (params are stored as strings): run name, test accuracy, training samples.
df.loc[df['params.total_params'] == '113943',
       ['tags.mlflow.runName', 'metrics.test_acc', 'params.train_samples']]

Unnamed: 0,tags.mlflow.runName,metrics.test_acc,params.train_samples
104,1668152805,0.992,50000
105,1668139362,0.9923,50000
106,1668127973,0.9915,50000
107,1668120367,0.9902,40000
108,1668110081,0.9911,40000
109,1668101504,0.9902,40000
110,1668094166,0.9888,30000
111,1668085318,0.9904,30000
112,1668077822,0.9895,30000
113,1668070267,0.9873,20000


In [14]:
# MLflow stores params as strings (e.g. 'False'); ast.literal_eval parses such literals without
# eval's ability to execute arbitrary code stored in the tracking DB.
ast.literal_eval(df.loc[df['tags.mlflow.runName'] == '1671757292', 'params.test_rot'].iloc[0])

False

In [13]:
# Test accuracy logged for run 1671757292.
df.loc[df['tags.mlflow.runName'] == '1671757292', 'metrics.test_acc'].iloc[0]

0.9908999800682068

In [14]:
# Test accuracy logged for run 1671760833.
df.loc[df['tags.mlflow.runName'] == '1671760833', 'metrics.test_acc'].iloc[0]

0.9909999966621399

In [15]:
# Test accuracy logged for run 1671762466.
df.loc[df['tags.mlflow.runName'] == '1671762466', 'metrics.test_acc'].iloc[0]

0.9908000230789185

In [6]:
# MLflow run id of run 1668139362 (test_acc 0.9923, the highest among the 113,943-parameter runs listed earlier).
run_id = df.loc[df['tags.mlflow.runName'] == '1668139362', 'run_id'].iloc[0]

In [7]:
# Artifact location of the selected run. This is an artifact *URI*; here it comes back as a relative
# local path, which the os.listdir / os.path.join cells below rely on.
# NOTE(review): with a remote artifact store (e.g. s3://) those os calls would fail — confirm store type.
artifact_path = get_artifact_uri(run_id, tracking_uri=tracking_uri)

In [8]:
# Echo the resolved artifact location of the selected run.
artifact_path

'./mlruns/2/d66eec77ff0044938ec96d1653a34355/artifacts'

In [None]:
# Entry names (not full paths) in the run's artifact directory, in arbitrary OS order.
dirs = os.listdir(path=artifact_path)

In [None]:
# Pick the .ckpt artifact. os.listdir order is arbitrary, so sort for a reproducible choice, and fail
# loudly instead of silently reusing a `checkpoint` left over from an earlier execution of this cell.
ckpt_files = sorted(name for name in dirs if name.endswith('.ckpt'))
assert ckpt_files, f'no .ckpt file found in {artifact_path}'
if len(ckpt_files) > 1:
    warnings.warn(f'several checkpoints found {ckpt_files}; using {ckpt_files[0]}')
checkpoint = ckpt_files[0]

In [None]:
# Full path of the selected checkpoint inside the run's artifact directory.
checkpoint_path = os.path.join(artifact_path, checkpoint)

In [None]:
# Load the checkpoint dict (contains 'hyper_parameters' and 'state_dict').
# torch.load unpickles the file, so only load checkpoints from trusted runs.
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine. The ConvNet built below
# has no .to()/.cuda() call here, so its parameters are presumably on CPU and load_state_dict would copy
# the tensors there anyway.
best_model = torch.load(checkpoint_path, map_location='cpu')
hparams = argparse.Namespace(**best_model['hyper_parameters'])
model = ConvNet(hparams, None, None).eval()
# Strict by default: raises on missing/unexpected keys; the returned key report is the cell output.
model.load_state_dict(best_model['state_dict'])