In [1]:
import ast
import torch, pickle, argparse, os, warnings, copy, time, mlflow
import numpy as np, pytorch_lightning as pl, matplotlib.pyplot as plt, eagerpy as ep
from models import ConvNet
from data_loader import load_test_data
from foolbox import PyTorchModel
from foolbox.attacks import LinfProjectedGradientDescentAttack
from foolbox.attacks.base import Repeated
from tqdm.notebook import tqdm
from attack_helper import run_batched_attack_cpu, batched_accuracy
from mlflow.tracking.artifact_utils import get_artifact_uri

In [2]:
def as_list(s):
    """Parse a delimited, comma-separated string of integers into a list.

    The first and last characters of ``s`` are dropped (e.g. the brackets of
    ``'[60, 60]'``) and the remainder is split on commas.

    Args:
        s: String such as ``'[1, 2, 3]'``.

    Returns:
        list[int]: The parsed integers; an empty list for ``'[]'``.

    Raises:
        ValueError: If a non-blank token is not a valid integer.
    """
    inner = s[1:-1]
    # Skip blank tokens so that '[]' and a trailing comma ('[1, 2,]') do not
    # reach int(''), which would raise ValueError.
    return [int(token) for token in inner.split(',') if token.strip()]

def get_hparam(df, run_id, param):
    """Return the value of hyper-parameter ``param`` for the run ``run_id``.

    Args:
        df: DataFrame with a ``'run_id'`` column and ``'params.<name>'`` columns.
        run_id: Value to match against the ``'run_id'`` column.
        param: Parameter name without the ``'params.'`` prefix.

    Returns:
        The entry of the ``'params.' + param`` column in the first matching row.

    Raises:
        IndexError: If no row matches ``run_id``.
    """
    column = 'params.' + param
    matching = df.loc[df['run_id'] == run_id, column]
    return matching.values[0]

In [3]:
# Point MLflow at the local SQLite tracking database, pull the runs of the
# 'model_training' experiment into a DataFrame and list the available columns.
tracking_uri = 'sqlite:///mlruns/database.db'
mlflow.set_tracking_uri(tracking_uri)
df = mlflow.search_runs(experiment_names=['model_training'])
print(df.columns)

Index(['run_id', 'experiment_id', 'status', 'artifact_uri', 'start_time',
       'end_time', 'metrics.loss', 'metrics.test_loss', 'metrics.val_loss',
       'metrics.test_acc', 'metrics.val_acc', 'metrics.epoch',
       'metrics.train_acc', 'metrics.train_loss', 'params.es_mode',
       'params.es_patience', 'params.max_epochs', 'params.name',
       'params.num_workers', 'params.nodes', 'params.total_params',
       'params.weight_decay', 'params.trainable_params', 'params.lr',
       'params.strides', 'params.train_rot', 'params.train_samples',
       'params.batch_norm', 'params.kernels', 'params.channels',
       'params.test_batch_size', 'params.es_monitor', 'params.activation_fn',
       'params.flat', 'params.es_min_delta', 'params.train_batch_size',
       'params.es_stopped_epoch', 'params.test_rot', 'params.padded_img_size',
       'params.bandlimit', 'params.image_size', 'params.kernel_max_beta',
       'tags.mlflow.user', 'tags.model', 'tags.mlflow.runName'],
      dtype='object')

In [4]:
# Preview the run ids together with their logged parameter counts.
df.loc[:, ['run_id', 'params.total_params']].head()

Unnamed: 0,run_id,params.total_params
0,b78b4173276841d48d777e57e51650f3,113943
1,ee3fdeb373ad4873a08ba4a0d48b2d4a,113943
2,fcbc6994b7764a02ae57170754d9eec8,113943
3,52f771e936494cad8a39fc8292758cf9,113943
4,7b18c11dda3943c9a385ee96645339f9,113943


In [5]:
# Runs whose logged 'params.total_params' equals '113943' (params are stored
# as strings, hence the string comparison), with name, test accuracy and
# training-set size.
df.loc[
    df['params.total_params'] == '113943',
    ['tags.mlflow.runName', 'metrics.test_acc', 'params.train_samples'],
]

Unnamed: 0,tags.mlflow.runName,metrics.test_acc,params.train_samples
104,1668152805,0.992,50000
105,1668139362,0.9923,50000
106,1668127973,0.9915,50000
107,1668120367,0.9902,40000
108,1668110081,0.9911,40000
109,1668101504,0.9902,40000
110,1668094166,0.9888,30000
111,1668085318,0.9904,30000
112,1668077822,0.9895,30000
113,1668070267,0.9873,20000


In [10]:
# 'tags.model' value of the run named '1695204583'.
df.loc[df['tags.mlflow.runName'] == '1695204583', 'tags.model'].values[0]

'CConvNet'

In [4]:
# 'params.test_rot' is stored as a string; ast.literal_eval parses it into a
# Python value. Unlike eval, it only accepts literals, so a malformed or
# tampered entry in the tracking store cannot execute arbitrary code.
ast.literal_eval(df[df['tags.mlflow.runName']==str(1695238653)]['params.test_rot'].values[0])

False

In [7]:
# 'params.flat' is stored as a string; ast.literal_eval parses it into a
# Python value. Unlike eval, it only accepts literals, so a malformed or
# tampered entry in the tracking store cannot execute arbitrary code.
ast.literal_eval(df[df['tags.mlflow.runName']==str(1680863211)]['params.flat'].values[0])

True

In [11]:
# Print the raw (string) value of 'params.flat' for the run named '1680863211'.
print(df.loc[df['tags.mlflow.runName'] == '1680863211', 'params.flat'].values[0])

True


In [16]:
# 'params.padded_img_size' is stored as a string such as '[60, 60]';
# ast.literal_eval parses it into a Python list. Unlike eval, it only accepts
# literals, so a malformed or tampered entry cannot execute arbitrary code.
ast.literal_eval(df[df['tags.mlflow.runName']==str(1680863211)]['params.padded_img_size'].values[0])

[60, 60]

In [34]:
# Logged 'metrics.test_acc' of the run named '1697106728'.
df.loc[df['tags.mlflow.runName'] == '1697106728', 'metrics.test_acc'].values[0]

0.9839000105857849

In [35]:
# Logged 'metrics.test_acc' of the run named '1697110031'.
df.loc[df['tags.mlflow.runName'] == '1697110031', 'metrics.test_acc'].values[0]

0.9842000007629395

In [23]:
# Logged 'metrics.test_acc' of the run named '1671724685'.
df.loc[df['tags.mlflow.runName'] == '1671724685', 'metrics.test_acc'].values[0]

0.9891999959945679

In [6]:
# Resolve the MLflow run id of the run named '1668139362'; later cells use
# `run_id` to locate that run's artifacts.
run_id = df.loc[df['tags.mlflow.runName'] == '1668139362', 'run_id'].values[0]

In [7]:
# Ask MLflow for the artifact URI of the selected run, using the same tracking
# store that was configured above via `tracking_uri`.
artifact_path = get_artifact_uri(run_id=run_id, tracking_uri=tracking_uri)

In [8]:
# Display the resolved artifact location as a sanity check.
artifact_path

'./mlruns/2/d66eec77ff0044938ec96d1653a34355/artifacts'

In [None]:
# os.listdir returns entries in arbitrary order; sorting makes the "first
# checkpoint found" selection in the following cell deterministic across
# machines and re-runs.
dirs = sorted(os.listdir(artifact_path))

In [None]:
# Pick the first artifact whose name contains '.ckpt'. Using next() with a
# default and failing loudly avoids leaving `checkpoint` undefined (or stale
# from an earlier execution of this cell) when the run has no checkpoint.
checkpoint = next((name for name in dirs if '.ckpt' in name), None)
if checkpoint is None:
    raise FileNotFoundError(f"No '.ckpt' file found among artifacts: {dirs}")

In [None]:
# Full path of the selected checkpoint file inside the run's artifact directory.
checkpoint_path = os.path.join(artifact_path, checkpoint)

In [None]:
# Load the checkpoint dict and rebuild the network from it.
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; the freshly constructed ConvNet lives on the CPU, and
# load_state_dict copies the values into its existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
best_model = torch.load(checkpoint_path, map_location='cpu')
# The stored 'hyper_parameters' dict is exposed as attributes for ConvNet.
hparams = argparse.Namespace(**best_model['hyper_parameters'])
# eval() switches the module to inference mode before the weights are restored.
model = ConvNet(hparams, None, None).eval()
model.load_state_dict(best_model['state_dict'])

In [13]:
# Show the full record of the run with id 'ae04618ae94c4650b021c7fdcfdb3221'.
df.loc[df['run_id'] == 'ae04618ae94c4650b021c7fdcfdb3221']

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.epoch,metrics.train_loss,metrics.val_loss,metrics.train_acc,...,params.channels,params.lr,params.es_stopped_epoch,params.padded_img_size,params.kernel_max_beta,params.image_size,params.bandlimit,tags.mlflow.user,tags.mlflow.runName,tags.model
542,ae04618ae94c4650b021c7fdcfdb3221,2,FINISHED,./mlruns/2/ae04618ae94c4650b021c7fdcfdb3221/ar...,2023-10-14 15:34:47.690000+00:00,2023-10-14 15:50:59.783000+00:00,44.0,0.071172,0.774552,0.99315,...,,0.001,44,"[28, 28]",,,,dschuh,ae04618ae94c4650b021c7fdcfdb3221,MLP
