In [38]:
import os
import getpass

import mlflow
from configs import DAGSHUB_USER_NAME, DAGSHUB_REPO_NAME, MODEL_DIR,  DEVICE
from utils import get_experiment_id, get_last_run_id

from fewshotdataloader import generate_loader, CARS

import torch
from utils import evaluate

In [7]:
mlflow.set_tracking_uri(f'https://dagshub.com/{DAGSHUB_USER_NAME}/{DAGSHUB_REPO_NAME}.mlflow')

In [12]:
os.environ['MLFLOW_TRACKING_USERNAME'] = 'afhabibieee'
os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass.getpass('enter pass: ')

In [18]:
# Use the run id of the selected artifact.
RUN_ID = 'abbda0848d0e4aacb6bdcd1eeb0694d8'

In [19]:
local_path = mlflow.artifacts.download_artifacts(
        run_id=RUN_ID,
        artifact_path='model',
        dst_path=os.path.join(MODEL_DIR, RUN_ID)
    )
print("Artifacts downloaded in: {}".format(MODEL_DIR))
print("Artifacts: {}".format(MODEL_DIR))

Artifacts downloaded in: ../models/saved model
Artifacts: ../models/saved model


In [20]:
logged_model = f'runs:/{RUN_ID}/model'

# Load model
loaded_model = mlflow.pytorch.load_model(logged_model)

In [35]:
test_set = CARS(split='test', image_size=84)

test_loader = generate_loader(
    'test',
    image_size=84,
    n_way=5,
    n_shot=5,
    n_query=5,
    n_task=100,
    n_workers=2        
)

In [37]:
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

loaded_model.eval()
example_scores = loaded_model(
    example_support_images.to(DEVICE),
    example_support_labels.to(DEVICE),
    example_query_images.to(DEVICE),
).detach()

_, example_predicted_labels = torch.max(example_scores.data, 1)

print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    print(
        f"{test_set.class_names[example_class_ids[example_query_labels[i]]]} / {test_set.class_names[example_class_ids[example_predicted_labels[i]]]}"
    )

Ground Truth / Predicted
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Skoda Hao Rui / Skoda Hao Rui
Skoda Hao Rui / Skoda Hao Rui
Skoda Hao Rui / Skoda Hao Rui
Skoda Hao Rui / Skoda Hao Rui
Skoda Hao Rui / Skoda Hao Rui
Ford Fiesta 2009 / Ford Fiesta 2009
Ford Fiesta 2009 / Ford Fiesta 2009
Ford Fiesta 2009 / Ford Fiesta 2009
Ford Fiesta 2009 / Ford Fiesta 2009
Ford Fiesta 2009 / Ford Fiesta 2009
Hyundai Longdong 2012 / Hyundai Longdong 2012
Hyundai Longdong 2012 / Hyundai Longdong 2012
Hyundai Longdong 2012 / Hyundai Longdong 2012
Hyundai Longdong 2012 / Hyundai Longdong 2012
Hyundai Longdong 2012 / Hyundai Longdong 2012
Nissan Sylphy 2007-2012 New / Nissan Sylphy 2007-2012 New
Nissan Sylphy 2007-2012 New / Nissan 

In [41]:
print('Rata-rata akurasi pada data testing untuk {} task adalah {}'.\
    format(len(test_loader), evaluate(loaded_model, test_loader)))

Rata-rata akurasi pada data testing untuk 100 task adalah 0.9628
