This Jupyter notebook is a short walkthrough for getting started as quickly as possible with a trained FUSION model.

## 1. Loading the model

In [1]:
import os

import torch

# Import the FUSION model and the prime_X_fusion function
from src.models.pose_ir_fusion import *
# Import the "device" and "classes" variables
from src.models.utils import *

### 1.1. Global variables

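- **model_folder**: directory containing the trained .pt model
- **model_file**: filename of the trained weights inside *model_folder*
- **use_pose**: include the pose module
- **use_ir**: include the IR module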
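The same location can also be assembled with *pathlib*, which avoids hand-written separators and makes it easy to check that the checkpoint exists before loading it. This is a minimal sketch: the helper name *checkpoint_path* is hypothetical, and its folder segments simply mirror the *model_folder* and *model_file* values used in this notebook.

```python
from pathlib import Path


def checkpoint_path(root: Path) -> Path:
    """Hypothetical helper: join the folder layout used by model_folder/model_file."""
    return (root / "models" / "fusion_test_tube_seed=0" / "fusion_20"
            / "cross_subject" / "aug=True" / "model12.pt")


path = checkpoint_path(Path.cwd().parent)
if not path.is_file():
    # Report a missing file before torch.load raises a longer traceback
    print(f"Checkpoint not found: {path}")
```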

In [2]:
# Global variables
model_folder = os.getcwd() + '/../' \
                           + 'models/' \
                           + 'fusion_test_tube_seed=0/' \
                           + 'fusion_20/' \
                           + 'cross_subject/' \
                           + 'aug=True/' 
model_file = 'model12.pt'
use_pose = True
use_ir = True


### 1.2. Create FUSION model

In [3]:
model = FUSION(use_pose, use_ir, pretrained=False)

### 1.3. Load trained weights

In [4]:
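# map_location lets a checkpoint saved on GPU load on whatever "device" is
model.load_state_dict(torch.load(model_folder + model_file, map_location=device))
None  # suppress the notebook's display of the return value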

### 1.4. Move to device and set to evaluation mode

In [5]:
model.to(device)
model.eval()
None

## 2. Model inference

- **batch_size**: number of sequences processed in a single forward pass
- **seq_len**: number of frames sampled from a full sequence; the sequence is split into *seq_len* evenly spaced subwindows and one frame is taken from each (see paper)
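The subwindow sampling described for *seq_len* can be sketched as follows, assuming each subwindow contributes one frame drawn uniformly at random. The helper *sample_frame_indices* is hypothetical and not part of this repository; the paper's exact sampling rule (for instance at test time) may differ.

```python
import random
from typing import List, Optional


def sample_frame_indices(num_frames: int, seq_len: int,
                         rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical helper: draw one frame index per subwindow.

    The range [0, num_frames) is cut into seq_len contiguous subwindows
    of (nearly) equal length, and one index is drawn uniformly from each.
    """
    if num_frames < seq_len:
        raise ValueError("sequence has fewer frames than seq_len")
    if rng is None:
        rng = random.Random()
    # Integer boundaries: bounds[0] == 0 and bounds[seq_len] == num_frames
    bounds = [i * num_frames // seq_len for i in range(seq_len + 1)]
    # randrange(lo, hi) returns an integer in [lo, hi)
    return [rng.randrange(bounds[i], bounds[i + 1]) for i in range(seq_len)]


# Example: 20 indices from a 100-frame sequence, one from each 5-frame subwindow
indices = sample_frame_indices(100, 20, random.Random(0))
```

Because every subwindow is non-empty whenever the sequence has at least *seq_len* frames, the returned indices are always in increasing order and cover the whole sequence.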

In [6]:
# Global variables
batch_size = 1
seq_len = 20

### 2.1. Create random input tensors

The tensors below are random placeholders: *torch.rand* draws values in [0, 1). In a real-life scenario, *X_skeleton* and *X_ir* should hold pixel values in the [0, 255] range.

In [7]:
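# Skeleton (pose) input: one 3 x 224 x 224 image per sample
X_skeleton = torch.rand(batch_size, 3, 224, 224)
# IR input: seq_len frames of 3 x 112 x 112 per sample
X_ir = torch.rand(batch_size, seq_len, 3, 112, 112)

X = [X_skeleton, X_ir]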

### 2.2. Prime the input tensors

Priming applies the normalization steps, reshapes the inputs and moves them to *device*.
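The normalization part of priming can be sketched with NumPy, assuming inputs in [0, 255] are scaled to [0, 1] and then standardized per channel. The helper name and the ImageNet mean and standard deviation below are assumptions, not values taken from *prime_X_fusion*; the reshaping and the move to *device* are omitted because their exact form depends on the repository's code.

```python
import numpy as np

# Assumed statistics: the standard ImageNet per-channel mean and std,
# used here as a stand-in for whatever prime_X_fusion actually applies
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_frames(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map [0, 255] pixels to standardized values.

    Accepts (B, 3, H, W) skeleton images or (B, T, 3, H, W) IR clips:
    the channel axis is third from the end in both layouts.
    """
    x = np.asarray(x, dtype=np.float32) / 255.0
    # (3, 1, 1) views of the statistics broadcast over H and W
    return (x - MEAN[:, None, None]) / STD[:, None, None]


# A white IR clip with the notebook's shape (batch_size=1, seq_len=20)
clip = np.full((1, 20, 3, 112, 112), 255.0)
clip_normed = normalize_frames(clip)
```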

In [8]:
X_primed = prime_X_fusion(X, use_pose, use_ir)

### 2.3. Forward pass through the model (inference)

In [9]:
with torch.no_grad():  # no gradients are needed at inference time
    predictions = model(X_primed)

In [10]:
# Index of the highest-scoring class; .item() assumes batch_size = 1
_, class_predicted = predictions.max(1)
print("Class predicted : " + classes[class_predicted.item()])

Class predicted : walk toward other
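With *batch_size* greater than 1, *class_predicted.item()* fails because *.item()* only converts single-element tensors. The following is a minimal sketch for decoding a whole batch, written in plain Python so it can be checked without the model; the helper *decode_predictions* is hypothetical.

```python
from typing import List, Sequence


def decode_predictions(scores: Sequence[Sequence[float]],
                       classes: Sequence[str]) -> List[str]:
    """Hypothetical helper: one predicted class name per sample.

    scores: per-sample class scores, e.g. predictions.tolist()
    classes: class names indexed like the model's outputs
    """
    names = []
    for row in scores:
        # Position of the highest score, as predictions.max(1) computes per row
        best = max(range(len(row)), key=row.__getitem__)
        names.append(classes[best])
    return names
```

With the notebook's variables, this would be called as *decode_predictions(predictions.tolist(), classes)*.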
