In [1]:
import glob
import joblib
import logging
import munch
import os
import sys
import toml
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
import warnings

# NOTE(review): this silences every Python warning for the rest of the session,
# including deprecation and numerical warnings raised by the libraries imported here
warnings.filterwarnings('ignore')

# add the project library folder to the import path before the `tropical_cyclone`
# imports that follow (presumably where that package lives — confirm)
sys.path.append('../resources/library')
import tropical_cyclone as tc
from tropical_cyclone.dataset import TCGraphDatasetInference
from tropical_cyclone.macros import TEST_YEARS as test_years
from tropical_cyclone.tester import GraphTester

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-n1nppb3j because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


## Select experiment

In [2]:
# experiment folder: the configuration file and the checkpoints are read from here,
# and its last path component names the inference output folder
run_dir = '../experiments/graphunet'

## Configuration file parsing

In [3]:
# folder holding the dataset that inference is run on
dataset_dir = '../data/dataset'

# the experiment configuration is stored inside the run folder
config_file = os.path.join(run_dir, 'configuration.toml')

# read the TOML file, then wrap the parsed dictionary for attribute-style access
config_dict = toml.load(config_file)
config = munch.munchify(config_dict)

# load the scaler stored at the path given in the configuration
scaler = joblib.load(config.dir.scaler)

# names of the input (drivers) and output (targets) variables
drivers, targets = config.data.drivers, config.data.targets

# batch size taken from the train section of the configuration
batch_size = config.train.batch_size

## Select model checkpoint

In [4]:
# list all available checkpoints (alphabetically sorted) next to their selection index
models = sorted(glob.glob(os.path.join(run_dir, 'checkpoints', '*.ckpt')))
for idx, model_name in enumerate(models):
    # os.path.basename honours the platform path separator, unlike splitting on '/'
    print(idx, os.path.basename(model_name))

0 epoch-0000-val_loss-0.52.ckpt
1 epoch-0001-val_loss-0.34.ckpt
2 last.ckpt


In [5]:
# pick your model: by default the last entry of the sorted checkpoint list
# fail early with an actionable message instead of a bare IndexError on an empty list
if not models:
    raise FileNotFoundError(f"no '*.ckpt' checkpoints found in {os.path.join(run_dir, 'checkpoints')}")
model_file = models[-1]
model_file

'../experiments/graphunet/checkpoints/last.ckpt'

## Model setup

In [6]:
# set model details: resolve the model class named in the configuration
# NOTE(review): eval() executes whatever string the configuration holds; only run this
# notebook with configuration files you trust
model_cls = eval(config.model.cls)
model_args = dict(config.model.args)

# define device: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# define model and switch it to evaluation mode, so that layers such as dropout and
# batch normalization use their inference behaviour (harmless if the tester does it too)
model:nn.Module = model_cls(**model_args)
model = model.to(device).eval()

# load the checkpoint, remapping its tensors onto the selected device
state_dict = torch.load(f=model_file, map_location=device)

# load weights into the model (kept as the last expression so the cell displays
# the key-matching report)
model.load_state_dict(state_dict['state_dict'])

<All keys matched successfully>

## Directory setup

In [7]:
# define main inference folder, named after the experiment folder
# (normpath drops a trailing separator, which would otherwise make basename return '')
inference_model_dir = os.path.join('../data/inference', os.path.basename(os.path.normpath(run_dir)))

# define logs directory
log_dir = os.path.join(inference_model_dir, 'logs')

# creating the logs folder also creates the inference folder that contains it
os.makedirs(log_dir, exist_ok=True)

## Inference on the Dataset

In [8]:
# threshold handed to GraphTester.get_inference_y
# NOTE(review): presumably the cut-off applied to the model output — confirm in the tester
inference_threshold = 0.4

# initialize logger
# force=True replaces handlers already attached to the root logger; without it
# basicConfig does nothing when logging was configured before (e.g. when this cell
# is re-run), so the log file would not be (re)created
logging_level = logging.INFO
logging.basicConfig(format="[%(asctime)s] %(levelname)s : %(message)s", filename=os.path.join(log_dir, 'proc-0.log'),
                    filemode="w", level=logging_level, datefmt='%Y-%m-%d %H:%M:%S', force=True)
logging.info('Starting inference')

for year in test_years:
    logging.info(f'Year {year}')

    # creating graph dataset and dataloader for the current year
    # (shuffle=False and drop_last=False: every element is visited once, in dataset order)
    logging.info('  Dataset preparation...')
    dataset = TCGraphDatasetInference(src=dataset_dir, year=year, drivers=drivers, targets=targets, scaler=scaler)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    # getting a graph tester to obtain the y predictions with shape [B, 2]
    logging.info('  ...predicting...')
    tester = GraphTester(device=device, loader=loader, model=model, nodes_per_graph=model_args['nodes_per_graph'])
    tot_pred = tester.get_inference_y(threshold=inference_threshold)

    # post-process operations
    logging.info('  ...post-processing...')
    dataset.post_process(tot_pred)

    # save .csv with coordinates and times to disk (one file per year)
    detection_dst = os.path.join(inference_model_dir, f'{year}.csv')
    dataset.store_detections(dst=detection_dst)
    logging.info('  ...predictions stored!')

Processing...
Done!


	Inference dataset for year 1997 created with 150612 elements!
	shape of elements:
		x: torch.Size([1600, 6])
		edge_index: torch.Size([2, 6240])


Inference on the test set: 100%|██████████| 295/295 [01:09<00:00,  4.23batch/s]
