In [1]:
import torch
from efficientnet_pytorch import EfficientNet
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from pathlib import Path

In [2]:
# you can find a pretrained model at model/b3.pth
MODEL_F = 'model/b3_128.pth'
# directory with the numpy optical flow images you want to use for inference
OF_NPY_DIR = '../opical-flow-estimation-with-RAFT/output'

In [3]:
# check if cuda is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Load Model

In [4]:
V = 0     # what version of efficientnet did you use
IN_C = 2  # number of input channels
NUM_C = 1 # number of classes to predict

In [5]:
model = EfficientNet.from_pretrained(f'efficientnet-b{V}', in_channels=IN_C, num_classes=NUM_C)
state = torch.load(MODEL_F)
model.load_state_dict(state)
model.to(device);

Loaded pretrained weights for efficientnet-b3


In [6]:
def inference(of_f):
    of = np.load(of_f)
    i = torch.from_numpy(of).to(device)
    pred = model(i)
    del i
    torch.cuda.empty_cache()
    return pred

In [7]:
# loop over all files in directory and predict
# for f in Path(OF_NPY_DIR).glob('*.npy'):
#     y_hat = inference(f).item()
#     print(f'{f.name}: {round(y_hat, 2)}')
for i in range(40):
    f = OF_NPY_DIR + '/' + str(i) + '.npy'
    y_hat = inference(f).item()
    print(f'{f}: {round(y_hat, 2)}')
    
    

../opical-flow-estimation-with-RAFT/output/0.npy: 23.95
../opical-flow-estimation-with-RAFT/output/1.npy: 24.04
../opical-flow-estimation-with-RAFT/output/2.npy: 24.21
../opical-flow-estimation-with-RAFT/output/3.npy: 23.81
../opical-flow-estimation-with-RAFT/output/4.npy: 24.61
../opical-flow-estimation-with-RAFT/output/5.npy: 25.04
../opical-flow-estimation-with-RAFT/output/6.npy: 23.55
../opical-flow-estimation-with-RAFT/output/7.npy: 24.28
../opical-flow-estimation-with-RAFT/output/8.npy: 23.7
../opical-flow-estimation-with-RAFT/output/9.npy: 24.73
../opical-flow-estimation-with-RAFT/output/10.npy: 24.72
../opical-flow-estimation-with-RAFT/output/11.npy: 25.27
../opical-flow-estimation-with-RAFT/output/12.npy: 22.99
../opical-flow-estimation-with-RAFT/output/13.npy: 24.89
../opical-flow-estimation-with-RAFT/output/14.npy: 24.5
../opical-flow-estimation-with-RAFT/output/15.npy: 24.37
../opical-flow-estimation-with-RAFT/output/16.npy: 24.17
../opical-flow-estimation-with-RAFT/output/