In [23]:
import os, time
import getpass

import mlflow
from configs import DAGSHUB_USER_NAME, DAGSHUB_REPO_NAME, MODEL_DIR,  DEVICE
from utils import get_experiment_id, get_last_run_id

from fewshotdataloader import generate_loader, CARS

import torch
from utils import evaluate

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [4]:
# Send every MLflow client call in this notebook to the DagsHub-hosted tracking server of this repo.
mlflow.set_tracking_uri(
    'https://dagshub.com/{}/{}.mlflow'.format(DAGSHUB_USER_NAME, DAGSHUB_REPO_NAME)
)

In [5]:
# Credentials for the DagsHub MLflow tracking server.
# Values already present in the environment take precedence, so the notebook can be
# re-run (or executed non-interactively) without being prompted again.
# NOTE(review): the fallback username is hard-coded; confirm whether DAGSHUB_USER_NAME
# from configs should be used instead.
os.environ.setdefault('MLFLOW_TRACKING_USERNAME', 'afhabibieee')
if not os.environ.get('MLFLOW_TRACKING_PASSWORD'):
    # getpass keeps the secret out of the notebook source and its saved outputs.
    os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass.getpass('enter pass: ')

In [24]:
def do_inference(
        run_id,
        n_way=5,
        n_shot=5,
        n_query=5,
        n_task=100,
):
        # dataset
        test_set = CARS(split='test', image_size=84)

        test_loader = generate_loader(
                'test',
                image_size=84,
                n_way=n_way,
                n_shot=n_shot,
                n_query=n_query,
                n_task=n_task,
                n_workers=2
        )

        # download model
        if not os.path.exists(os.path.join(MODEL_DIR, run_id)):
                local_path = mlflow.artifacts.download_artifacts(
                        run_id=run_id,
                        artifact_path='model',
                        dst_path=os.path.join(MODEL_DIR, run_id)
                )
        print("Artifacts downloaded in: {}".format(MODEL_DIR))
        print("Artifacts: {}\n".format(MODEL_DIR))
        
        # Load model
        logged_model = f'runs:/{run_id}/model'
        loaded_model = mlflow.pytorch.load_model(logged_model)

        # do inference for one task
        (
                example_support_images,
                example_support_labels,
                example_query_images,
                example_query_labels,
                example_class_ids,
        ) = next(iter(test_loader))

        start = time.time()
        loaded_model.to(DEVICE)
        loaded_model.eval()
        example_scores = loaded_model(
                example_support_images.to(DEVICE),
                example_support_labels.to(DEVICE),
                example_query_images.to(DEVICE),
        ).detach()
        _, example_predicted_labels = torch.max(example_scores.data, 1)
        end = time.time()
        time_exec = (end-start)#*10**3

        print("Ground Truth / Predicted")
        for i in range(len(example_query_labels)):
                print(
                        f"{test_set.class_names[example_class_ids[example_query_labels[i]]]} / {test_set.class_names[example_class_ids[example_predicted_labels[i]]]}"
                )
        
        print('\nThe time of inference for one task is : {} s'.format(time_exec))
        
        start = time.time()
        accuracy = evaluate(loaded_model, test_loader)
        end = time.time()
        time_exec = (end-start)#*10**3

        print('\nThe average accuracy of  the {} tasks in data testing is {}, with an execution time of {} s'.\
                format(len(test_loader), accuracy, time_exec))
        

In [25]:
# Evaluate the first trained model (MLflow run below) on the CARS test split.
run_id = 'c647a2612ffb4440930a5b17da7ab462'
do_inference(run_id=run_id)

Artifacts downloaded in: ../models/saved model
Artifacts: ../models/saved model





Ground Truth / Predicted
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Iveco Body / Iveco Body
Iveco Body / Iveco Body
Iveco Body / Iveco Body
Iveco Body / Iveco Body
Iveco Body / Iveco Body
Haval H3 2012 / Haval H3 2012
Haval H3 2012 / Haval H3 2012
Haval H3 2012 / Haval H3 2012
Haval H3 2012 / Haval H3 2012
Haval H3 2012 / Haval H3 2012
Citroen SEGA 2012 / Citroen SEGA 2012
Citroen SEGA 2012 / Citroen SEGA 2012
Citroen SEGA 2012 / Citroen SEGA 2012
Citroen SEGA 2012 / Citroen SEGA 2012
Citroen SEGA 2012 / Citroen SEGA 2012
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda

The time of inference for one task is : 1.70

In [26]:
# Evaluate the second trained model (MLflow run below) with the same settings for comparison.
run_id = 'abbda0848d0e4aacb6bdcd1eeb0694d8'
do_inference(run_id=run_id)

Artifacts downloaded in: ../models/saved model
Artifacts: ../models/saved model





Ground Truth / Predicted
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Land Rover Discovery 2010-2014 / Land Rover Discovery 2010-2014
Honda Fit 2008 / Honda Fit 2008
Honda Fit 2008 / Honda Fit 2008
Honda Fit 2008 / Honda Fit 2008
Honda Fit 2008 / Honda Fit 2008
Honda Fit 2008 / Honda Fit 2008
Mercedes Benz S Class 2011 / Mercedes Benz S Class 2011
Mercedes Benz S Class 2011 / Mercedes Benz S Class 2011
Mercedes Benz S Class 2011 / Mercedes Benz S Class 2011
Mercedes Benz S Class 2011 / Mercedes Benz S Class 2011
Mercedes Benz S Class 2011 / Mercedes Benz S Class 2011
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Changhe Freda / Changhe Freda
Hyundai Sonata 14 / Hyundai Sonata 14
Hyundai Sonata 14 / Hyundai