In [1]:
import torch, pickle, argparse, os, warnings, copy, time, mlflow
import numpy as np, pytorch_lightning as pl, matplotlib.pyplot as plt, eagerpy as ep
from models import ConvNet
from data_loader import load_test_data
from foolbox import PyTorchModel
from foolbox.attacks import LinfProjectedGradientDescentAttack
from foolbox.attacks.base import Repeated
from tqdm.notebook import tqdm
from attack_helper import batched_logits, batched_predictions
from mlflow.tracking.artifact_utils import get_artifact_uri

In [2]:
# Locate the Lightning checkpoint logged as an artifact of one training run
# in the "model_training" MLflow experiment.
tracking_uri = 'sqlite:///mlruns/database.db'
RUN_NAME = str(1668278459)  # mlflow runName (tags.mlflow.runName) of the run to evaluate

mlflow.set_tracking_uri(tracking_uri)
df = mlflow.search_runs(experiment_names=['model_training'])
matching_runs = df[df['tags.mlflow.runName'] == RUN_NAME]
if matching_runs.empty:
    raise LookupError(f"no run named {RUN_NAME!r} in experiment 'model_training'")
run_id = matching_runs['run_id'].values[0]

# NOTE(review): os.listdir below assumes the artifact URI is a local path — confirm
# for non-local artifact stores.
artifact_path = get_artifact_uri(run_id=run_id, tracking_uri=tracking_uri)
dirs = os.listdir(artifact_path)

# os.listdir order is arbitrary; sort so the same checkpoint is chosen on every run.
checkpoint = next((name for name in sorted(dirs) if name.endswith('.ckpt')), None)
if checkpoint is None:
    raise FileNotFoundError(f"no .ckpt file found in {artifact_path}")

checkpoint_path = os.path.join(artifact_path, checkpoint)

In [3]:
# Restore the trained ConvNet from the checkpoint's hyper-parameters and weights.
# map_location='cpu' keeps loading independent of the device the checkpoint was
# saved on; the model is moved to the GPU explicitly later with model.cuda().
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
best_model = torch.load(checkpoint_path, map_location='cpu')
hparams = argparse.Namespace(**best_model['hyper_parameters'])
model = ConvNet(hparams, None, None).eval()
# Last expression: displays the key-matching report from load_state_dict.
model.load_state_dict(best_model['state_dict'])

<All keys matched successfully>

In [4]:
TEST_PATH = "s2_mnist_cs1.gz"
test_data = load_test_data(TEST_PATH)

# Evaluate the full-split slice once and take images/labels from it, instead of
# evaluating test_data[:] separately for each.
# NOTE(review): assumes test_data[:] returns a tuple-like (images, labels, ...),
# e.g. a TensorDataset — confirm in data_loader.
all_samples = test_data[:]
images = all_samples[0]
labels = all_samples[1]

In [5]:
# Sanity-check shapes: the full image tensor, and the model output for a
# 2-sample slice. no_grad skips building an autograd graph for this
# inference-only forward pass.
print(images.size())
with torch.no_grad():
    print(model(images[:2]).size())

torch.Size([10000, 1, 60, 60])
torch.Size([2, 10])


In [6]:
# Move the model to the GPU (.cuda() on a torch module works in place and returns
# the same module), then compute logits for the whole test set in chunks.
# NOTE(review): the third argument looks like the batch size (10000 images gave
# 100 tqdm steps) — confirm in attack_helper.
model = model.cuda()
logits = batched_logits(model, images, 100)

  0%|          | 0/100 [00:00<?, ?it/s]

In [7]:
# Shape of the logits tensor (samples x classes as displayed).
logits.shape

torch.Size([10000, 10])

In [8]:
# Predicted class per test image via the batched helper, on the GPU.
# NOTE(review): third argument presumably the batch size — confirm in attack_helper.
gpu_model = model.cuda()
pred = batched_predictions(gpu_model, images, 100)

  0%|          | 0/100 [00:00<?, ?it/s]

In [9]:
# Recompute predicted classes from the logits (index of the largest logit per
# row) to cross-check batched_predictions. argmax over the last dim does not
# hard-code the number of classes, unlike reshape(-1, 10).
pred_2 = logits.argmax(dim=-1)

In [10]:
# batched_predictions must agree exactly with argmax over batched_logits.
# Check shapes first so broadcasting cannot mask a mismatch, and compare on the
# CPU so the check works regardless of which device each tensor lives on.
assert pred.shape == pred_2.shape, f"shape mismatch: {tuple(pred.shape)} vs {tuple(pred_2.shape)}"
agreement = torch.all(pred.cpu() == pred_2.cpu())
assert agreement, "batched_predictions disagrees with argmax of batched_logits"
agreement

tensor(True)

In [11]:
# Preview the first ten predicted classes.
pred[0:10]

tensor([7, 2, 1, 0, 4, 1, 5, 9, 5, 9])

In [12]:
# Preview the logits of the first ten test images.
logits[0:10]

tensor([[ 5.4869, -0.6696,  1.4066,  0.6219, -4.4709, -0.3668, -5.2891,  8.9015,
         -8.2221, -0.8814],
        [-1.3240,  1.8889,  6.9673, -4.2807,  0.0400,  0.8172, -0.6086,  3.2864,
         -3.3349, -3.7987],
        [ 0.5425, 10.9971, -2.9173, -5.4327, -1.4699,  0.0836,  1.3159,  4.4924,
         -2.3520, -2.6516],
        [ 7.2077, -4.4147, -0.3960,  1.4358, -2.4331,  0.6982,  0.6536, -2.0499,
          0.4678, -0.1660],
        [-4.2786, -2.5147, -1.9955, -1.2011,  7.9827,  3.1697, -0.1185, -1.1926,
          1.2032,  2.1169],
        [-0.6105, 10.2034, -0.9548, -3.5241, -1.1989, -1.2720,  2.1219,  1.4404,
         -2.0152, -5.0312],
        [-0.7860, -5.6367,  3.3510,  1.8416,  2.6116,  4.7410,  2.8391, -4.5642,
          1.3760,  0.1803],
        [-3.5238, -1.9698, -0.4505, -1.8892,  3.8877, -3.0857, -0.4553, -1.2669,
          3.6117,  6.6461],
        [-2.0459, -6.5618,  4.4525, -0.2075,  1.2279,  7.0087,  6.2800, -6.9050,
          1.2973, -0.4375],
        [ 2.1582, -