In [30]:
import os
import sys
import json

curr_path = os.path.dirname(os.path.abspath(__name__))
if curr_path not in sys.path:
    sys.path.insert(0, curr_path)

from configs import IMAGE_SIZE, DEVICE, BEST_MODEL_DIR
from data.fewshotdataloader import generate_loader
from models.protonet import PrototypicalNetwork

from torchmetrics.classification import MulticlassRecall, MulticlassPrecision
import torch

import pandas as pd
import mlflow

from dotenv import load_dotenv
load_dotenv()

True

In [8]:
# Combined validation and test class

with open('../../src/pytorch/data/val.json') as f:
    json_val = json.load(f)
with open('../../src/pytorch/data/test.json') as f:
    json_test = json.load(f)

json_test['class_names'] += json_val['class_names']
json_test['class_roots'] += json_val['class_roots']

with open('../../src/pytorch/data/test2.json', 'w') as f:
    json.dump(json_test, f, indent=4)

In [28]:
def evaluate_per_task(
    model,
    support_images, support_labels,
    query_images, query_labels
):
    classification_scores = model(support_images, support_labels, query_images)
    correct = (torch.max(classification_scores.detach().data, 1)[1] == query_labels).sum().item()
    total = query_labels.shape[0]

    return classification_scores, correct, total

In [31]:
def evaluate(model, data_loader, n_way):
    total_pred = 0
    correct_pred = 0
    total_recall = 0
    total_precision = 0

    recall = MulticlassRecall(num_classes=n_way, average='macro')
    precision = MulticlassPrecision(num_classes=n_way, average='macro')

    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in data_loader:
            classification_scores, correct, total = evaluate_per_task(
                model,
                support_images.to(DEVICE), support_labels.to(DEVICE),
                query_images.to(DEVICE), query_labels.to(DEVICE)
            )
            correct_pred += correct
            total_pred += total

            top_scores, pred_labels = torch.max(classification_scores.data, 1)

            total_recall += recall(pred_labels, query_labels).item()
            total_precision += precision(pred_labels, query_labels).item()
    
    avg_accuracy = correct_pred/total_pred
    avg_recall = total_recall/len(data_loader)
    avg_precision = total_precision/len(data_loader)

    return avg_accuracy, avg_recall, avg_precision

In [33]:
def benchmark(model, n_ways, n_shots, metrics):
    results = []
    for n_way in n_ways:
        result = []
        for n_shot in n_shots:
            test_loader = generate_loader(
                'test2',
                image_size=IMAGE_SIZE,
                n_way=n_way,
                n_shot=n_shot,
                n_query=10,
                n_task=200,
                n_workers=2
            )
            avg_accuracy, avg_recall, avg_precision = evaluate(model, test_loader, n_way)
            result.extend([avg_accuracy, avg_recall, avg_precision])
        results.append(result)
    
    header = [
        [f'{shot}-shot' for shot in n_shots for i in range(len(metrics))],
        [*[metric for metric in metrics]*len(n_shots)]
    ]

    df = pd.DataFrame(
        results,
        index=[f'{way}-way' for way in n_ways],
        columns = header
    )

    return df

In [34]:
# Download artifact model and load model
def get_embedding(run_id, backend):
    artifact_path = os.path.join(BEST_MODEL_DIR, run_id)
    if not os.path.exists(artifact_path):
        mlflow.artifacts.download_artifacts(
            run_id=run_id,
            artifact_path='model',
            dst_path=artifact_path
        )
    embedding = PrototypicalNetwork(
        'convnext_tiny',
        mode='eval'
    ).to(DEVICE)
    embedding = torch.compile(embedding, backend=backend)
    embedding.load_state_dict(torch.load(os.path.join(artifact_path, 'model/model.pt'), map_location=DEVICE))
    embedding = embedding.eval()
    return embedding

In [37]:
n_ways = [3, 6]
n_shots = [3, 10, 20]
metrics = ['accuracy', 'precision', 'recall']

In [36]:
run_id = '9f3784aa59224e52bbf2fef98656a9e6'
backend = ['eager', 'inductor']
model = get_embedding(run_id, backend[0])

In [38]:
results = benchmark(model, n_ways, n_shots, metrics)

In [39]:
results

Unnamed: 0_level_0,3-shot,3-shot,3-shot,10-shot,10-shot,10-shot,20-shot,20-shot,20-shot
Unnamed: 0_level_1,accuracy,precision,recall,accuracy,precision,recall,accuracy,precision,recall
3-way,0.761833,0.761833,0.783158,0.7965,0.7965,0.809349,0.8315,0.8315,0.840766
6-way,0.590583,0.590583,0.613054,0.6795,0.6795,0.694714,0.696917,0.696917,0.708406
