In [1]:
import csv
import itertools
import os

import hydra
import pytorch_lightning as pl
import torch
from hydra.core.global_hydra import GlobalHydra

In [2]:
# Machine-specific paths can be overridden through environment variables, so the notebook
# runs unchanged on other machines. The defaults are the values this notebook was written with.
data_dir = os.environ.get('DATA_DIR', 'D:/open/')  # root directory of the dataset

config_yaml_path = os.environ.get('CONFIG_YAML_PATH', 'config.yaml')  # hydra config file
ckpt_path = os.environ.get('CKPT_PATH', 'epoch=23.ckpt')  # trained model weights (Lightning checkpoint)

submission_file_name = 'submission'  # output file name; '.csv' is appended when writing


In [3]:
def _compose_config(config_path: str):
    """Compose and return the Hydra config stored in the YAML file at ``config_path``.

    Backslashes are normalised to ``/``. The path is then split into a directory and a
    file name. Hydra's global state is cleared first, so this can be called repeatedly
    from the same kernel. A relative (or empty) directory is passed to
    ``hydra.initialize``. An absolute one is passed to ``hydra.initialize_config_dir``,
    because ``hydra.initialize`` only accepts relative config paths.
    """
    config_dir, _, config_name = config_path.replace('\\', '/').rpartition('/')
    GlobalHydra.instance().clear()
    if os.path.isabs(config_dir):
        hydra.initialize_config_dir(config_dir=config_dir)
    else:
        hydra.initialize(config_path=config_dir)
    return hydra.compose(config_name=config_name)


def load_model_with_config(config_path: str, checkpoint_path: str):
    """Instantiate the model described by the config and load trained weights into it.

    Args:
        config_path: path to the Hydra YAML config. Its ``framework`` node is instantiated.
        checkpoint_path: checkpoint file containing a ``'state_dict'`` entry.

    Returns:
        The instantiated model with the checkpoint's ``state_dict`` loaded.
    """
    cfg = _compose_config(config_path)
    model = hydra.utils.instantiate(cfg['framework'])
    # torch.load unpickles the file, which can execute arbitrary code:
    # only load checkpoints from trusted sources.
    # Tensors are mapped to the CPU so loading works on machines without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model


def load_datasets_with_config(config_path: str, data_dir: str, num_workers: int = 0):
    """Instantiate the dataset/dataloader object described by the config, in test mode.

    Args:
        config_path: path to the Hydra YAML config. Its ``dataloader.datasets`` node is
            instantiated.
        data_dir: forwarded to the instantiated object as ``data_root``.
        num_workers: forwarded to the instantiated object as ``num_workers``.

    Returns:
        Whatever ``cfg['dataloader']['datasets']`` instantiates to, created with
        ``is_test=True``.
    """
    cfg = _compose_config(config_path)
    return hydra.utils.instantiate(cfg['dataloader']['datasets'],
                                   data_root=data_dir,
                                   is_test=True,
                                   num_workers=num_workers)

def get_pred(model, test_data_loader, gpus=(0,)):
    """Run Lightning inference over ``test_data_loader`` and flatten the per-batch outputs.

    Args:
        model: passed to ``pl.Trainer.predict``. Each predicted batch is indexed as
            ``batch[0]`` (per-sample indices) and ``batch[1]`` (a score tensor). This is
            inferred from the usage below — TODO confirm against the model's predict step.
        test_data_loader: dataloader handed to ``trainer.predict``.
        gpus: forwarded to ``pl.Trainer(gpus=...)``. Defaults to GPU 0. A tuple is
            converted to a list before being forwarded.

    Returns:
        ``(img_idx, preds)``. ``img_idx`` is a flat list of every batch's indices, in order.
        ``preds`` is the batch score tensors concatenated along dim 0.

    Raises:
        RuntimeError: if ``trainer.predict`` produced no batch outputs.
    """
    # An immutable default avoids the shared-mutable-default pitfall. Lightning still
    # receives a list, exactly as with the previous ``[0]`` default.
    trainer = pl.Trainer(gpus=list(gpus) if isinstance(gpus, tuple) else gpus)
    batch_outputs = trainer.predict(model, test_data_loader)
    if not batch_outputs:
        # Fail with a clear message instead of torch.cat's "non-empty list" error.
        raise RuntimeError(
            'trainer.predict returned no batches; check that the test dataloader is not empty')
    img_idx = list(itertools.chain.from_iterable(batch[0] for batch in batch_outputs))
    preds = torch.cat([batch[1] for batch in batch_outputs], dim=0)
    return img_idx, preds

In [4]:
# Build the trained model and the test dataloader(s) from the same Hydra config.
model = load_model_with_config(config_path=config_yaml_path, checkpoint_path=ckpt_path)
datasets = load_datasets_with_config(config_yaml_path, data_dir=data_dir)
test_loader = datasets.get_test_dataloaders()

In [5]:
img_idx, pred_scores = get_pred(model, test_loader)
# Replace the raw scores with the index of the highest-scoring entry along the last
# dimension (presumably the class axis), as a NumPy array on the CPU.
pred_scores = pred_scores.argmax(dim=-1).detach().cpu().numpy()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [6]:
# Map predicted class ids back to label names and write the submission CSV.
idx2label = datasets.get_idx2label()
# zip() would silently drop rows if the two sequences had different lengths.
assert len(img_idx) == len(pred_scores), (
    f'{len(img_idx)} indices but {len(pred_scores)} predictions')
# csv.writer quotes fields containing commas or quotes; hand-built rows did not.
# newline='' is what the csv docs prescribe, and lineterminator='\n' gives LF rows on every OS.
# An explicit encoding keeps the file's bytes independent of the machine's locale.
with open(f'{submission_file_name}.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['index', 'label'])
    # format(idx) renders idx exactly as the previous f'{idx}' did.
    writer.writerows(
        (format(idx), idx2label[pred_class])
        for idx, pred_class in zip(img_idx, pred_scores)
    )