In [1]:
import os
import torch

from fgvclib.apis import *
from fgvclib.configs import FGVCConfig
from fgvclib.utils.visualization import VOXEL

### Load Config

In [2]:
# Path of the YAML experiment configuration used by this notebook.
config_file = "./configs/baseline_restnet50.yml"
# Create the config holder, then read the settings from the file above.
config = FGVCConfig()
config.load(config_file)
# `cfg` is the settings object the following cells read from (cfg.MODEL,
# cfg.DATASETS, ...). It is printed so the notebook records what was used.
cfg = config.cfg
print(cfg)

DATASETS:
  ROOT: /data/wangxinran/dataset/Birds2
  TEST:
    BATCH_SIZE: 16
    NUM_WORKERS: 4
    PIN_MEMORY: False
    POSITIVE: None
    SHUFFLE: False
  TRAIN:
    BATCH_SIZE: 16
    NUM_WORKERS: 4
    PIN_MEMORY: True
    POSITIVE: None
    SHUFFLE: True
EPOCH_NUM: 1
EXP_NAME: Baseline_ResNet50
FIFTYONE:
  NAME: BirdsTest
  STORE: True
INTERPRETER:
  METHOD: gradcam
  NAME: cam
  TARGET_LAYERS: ['layer4']
ITERATION_NUM: None
LOGGER:
  FILE_PATH: ./logs/
  NAME: txt_logger
  PRINT_FRE: 50
METRICS: [{'name': 'accuracy(topk=1)', 'metric': 'accuracy', 'top_k': 1, 'threshold': None}, {'name': 'recall(threshold=0.5)', 'metric': 'recall', 'top_k': None, 'threshold': 0.5}, {'name': 'precision(threshold=0.5)', 'metric': 'precision', 'top_k': None, 'threshold': 0.5}]
MODEL:
  BACKBONE:
    ARGS: [{'pretrained': True}, {'del_keys': []}]
    NAME: resnet50
  CLASS_NUM: 200
  CRITERIONS: [{'name': 'cross_entropy_loss', 'args': [], 'w': 1.0}]
  ENCODING:
    ARGS: None
    NAME: GlobalAvgPooli

### Build Model

In [3]:
# Number of samples handed to VOXEL.predict (the progress bar of that call
# counts up to this value). Named here so it is easy to find and tune.
NUM_PREDICT_SAMPLES = 10

# Build the network described by the MODEL section of the config.
model = build_model(cfg.MODEL)

# Restore the trained weights from <WEIGHT.SAVE_DIR>/<WEIGHT.NAME>.
weight_path = os.path.join(cfg.WEIGHT.SAVE_DIR, cfg.WEIGHT.NAME)
# Report the path that was actually checked; the previous message interpolated
# the unrelated cfg.RESUME_WEIGHT value.
assert os.path.exists(weight_path), f"The weight file {weight_path} doesn't exist."
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source. map_location="cpu" deserializes the tensors onto the CPU.
state_dict = torch.load(weight_path, map_location="cpu")
model.load_state_dict(state_dict=state_dict)

if cfg.USE_CUDA:
    assert torch.cuda.is_available(), "Cuda is not available."
    # NOTE(review): DataParallel expects the wrapped module's parameters to
    # already be on the first CUDA device; presumably the interpreter or
    # VOXEL.predict moves the model there -- confirm.
    model = torch.nn.DataParallel(model)

# Test-time preprocessing and the loader over the `test` split under DATASETS.ROOT.
transforms = build_transforms(cfg.TRANSFORMS.TEST)
loader = build_dataset(
    root=os.path.join(cfg.DATASETS.ROOT, 'test'),
    cfg=cfg.DATASETS.TEST,
    transforms=transforms,
)

# Attach the interpreter configured under INTERPRETER, run predictions through
# the VOXEL wrapper, then launch it.
interpreter = build_interpreter(model, cfg)
voxel = VOXEL(dataset=loader.dataset, name=cfg.FIFTYONE.NAME, interpreter=interpreter)
voxel.predict(model, transforms, NUM_PREDICT_SAMPLES, cfg.MODEL.NAME)
voxel.launch()

100%|██████████| 5794/5794 [00:01<00:00, 3292.42it/s]


 100% |███████████████| 5794/5794 [5.2s elapsed, 0s remaining, 1.2K samples/s]       
  10% |█------------------|  1/10 [1.0m elapsed, 9.0m remaining, 0.0 samples/s]  


KeyboardInterrupt: 