In [None]:
import json
import os.path as osp
import random

import matplotlib.pyplot as plt
import torch
from models.ofa.networks import CompositeSubNet
from segmentation_models_pytorch.utils.metrics import IoU, Precision


from ofa.utils.common_tools import build_config_from_file
from ofa.training.strategies.segmentation import Context
from dltools.data_providers.segmentation import SegmentationProvider
from ofa.training.strategies.utils.segmentation.visualise import visualise_segmentation
from dltools.data_providers import DataProvidersRegistry
from ofa.training.strategies import get_strategy_class
from ofa.training.strategies.segmentation import SegmentationStrategy, Context

In [None]:
# Output directory of the finished run; every artefact below is read from it.
ROOT = "/workspace/proj/workspace/ofa/output/segmentation/mbnet/camvid/DYNAMIC-FPN/08.12.23_11.31.31.848253"
config_path = osp.join(ROOT, "config.yaml")
model_config_path = osp.join(ROOT, "result_model_config.json")
state_path = osp.join(ROOT, "result_model.pt")

config = build_config_from_file(config_path)

# --- data: look up the provider class named in the config, use its test split
ProviderCLS = DataProvidersRegistry.get_provider_by_name(config.common.dataset.type)
provider: SegmentationProvider = ProviderCLS(config.common.dataset)
dataset = provider.test_dataset
print("dataset inited!")

# --- strategy: built from the same config and pinned to the CPU
CLS = get_strategy_class(config.common.strategy)
strategy: SegmentationStrategy = CLS(config.common)
strategy.device = torch.device("cpu")
print("strategy inited!")

# --- model: rebuild the sub-network from its saved JSON config, then load the
# saved state dict (mapped onto the CPU) into it
with open(model_config_path) as fin:
    model_config = json.load(fin)
model = CompositeSubNet.build_from_config(model_config)
state = torch.load(state_path, map_location="cpu")
model.load_state_dict(state)
model.eval()  # presumably the usual nn.Module switch to inference mode
print("model inited")

In [None]:
# Qualitative check: run the model on a few randomly chosen test images, show
# each prediction with its own metrics, then print the metrics accumulated
# over all inspected images.
epoch_metric_dict = strategy.build_metrics_dict()

num_samples = 5
# Draw *distinct* indices so the same image is never shown twice nor counted
# twice in the accumulated metrics; capped in case the dataset is smaller than
# num_samples. Uses the global `random` state, so call random.seed(...) first
# if the same picks are wanted on every run.
sample_indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

for i in sample_indices:
    sample = dataset[i]

    # Build a batch of one in a fresh dict with non in-place unsqueeze, so the
    # object handed out by the dataset is left untouched.
    data = dict(sample)
    data["image"] = sample["image"].unsqueeze(0)
    data["target"] = sample["target"].unsqueeze(0)
    data["image_path"] = [sample["image_path"]]

    context: Context = {}
    context["model"] = model
    context.update(data)

    # Fresh per-image metrics; epoch_metric_dict keeps accumulating across images.
    runtime_metric_dict = strategy.build_metrics_dict()
    with torch.no_grad():
        strategy.prepare_batch(context)
        strategy.compute_output(context)

        strategy.update_metric(epoch_metric_dict, context)
        strategy.update_metric(runtime_metric_dict, context)
        sample_metrics = strategy.get_metric_vals(runtime_metric_dict)
        sample_metrics["image"] = osp.basename(data["image_path"][0])
        fig = visualise_segmentation(
            context,
            strategy.n_classes,
            background=strategy.add_background,
            indices=[0],
            close_figure=False,
        )
        line = "\n".join(f"{k}: {v}" for k, v in sample_metrics.items())
        fig.suptitle(line)
        plt.show()
        # Release the figure explicitly so figures do not pile up over the loop.
        plt.close(fig)

metrics = strategy.get_metric_vals(epoch_metric_dict)
print("result metric")
print(metrics)