# Fine-Tune an I3D Model for Bend Classification (LEFT, RIGHT, STRAIGHT)

In [1]:
import logging
import numpy as np
import tensorflow as tf

In [2]:
import i3d

## Load Pre-trained Models

In [3]:
# Checkpoint file for each I3D input stream, keyed by stream name.
_CHECKPOINT_PATHS_IMAGENET = dict(
    rgb="data/checkpoints/rgb_imagenet/model.ckpt",
    flow="data/checkpoints/flow_imagenet/model.ckpt",
)

# Keys into _CHECKPOINT_PATHS_IMAGENET for the two streams; the notebook later
# sums the logits of both streams into a single ("joint") prediction.
rgb_model_checkpoint, flow_model_checkpoint = "rgb", "flow"

In [4]:
def get_pretrained_model(model_checkpoint_name, num_classes, checkpoint_paths=None,
                         spatial_squeeze=True, final_endpoint='Logits'):
    """Build an ``i3d.InceptionI3d`` and restore pretrained weights into it.

    Args:
        model_checkpoint_name: Key into ``checkpoint_paths`` selecting the stream
            checkpoint (``'rgb'`` or ``'flow'`` with the default mapping).
        num_classes: Size of the model's logits layer. NOTE(review): presumably
            this must match the checkpoint's logits layer for those weights to
            restore — confirm.
        checkpoint_paths: Mapping from checkpoint name to checkpoint path.
            ``None`` (the default) uses ``_CHECKPOINT_PATHS_IMAGENET``.
        spatial_squeeze: Forwarded to ``i3d.InceptionI3d``.
        final_endpoint: Forwarded to ``i3d.InceptionI3d``.

    Returns:
        The ``i3d.InceptionI3d`` instance, with the checkpoint restore issued.

    Raises:
        KeyError: If ``model_checkpoint_name`` is not a key of ``checkpoint_paths``.
    """
    if checkpoint_paths is None:
        checkpoint_paths = _CHECKPOINT_PATHS_IMAGENET

    # Resolve the path before building the model so a typo fails fast with a
    # message listing the valid names.
    if model_checkpoint_name not in checkpoint_paths:
        raise KeyError(
            f"Unknown checkpoint {model_checkpoint_name!r}; "
            f"expected one of {sorted(checkpoint_paths)}"
        )
    checkpoint_path = checkpoint_paths[model_checkpoint_name]

    model = i3d.InceptionI3d(num_classes=num_classes, spatial_squeeze=spatial_squeeze,
                             final_endpoint=final_endpoint)

    logging.info('Loading checkpoint for %s model from %s', model_checkpoint_name, checkpoint_path)

    # tf.train.Checkpoint.restore matches variables that do not exist yet when
    # they are created (deferred restoration). NOTE(review): if InceptionI3d
    # creates its variables lazily, the weights are only actually loaded on the
    # first forward pass, so the log below confirms the request, not the match.
    tf.train.Checkpoint(model=model).restore(checkpoint_path)

    logging.info('Issued checkpoint restore for %s model', model_checkpoint_name)

    return model

# Output size of the pretrained logits layers; both streams share it.
# NOTE(review): assumed to equal the number of entries in data/label_map.txt
# (the 400 labels used in the sanity check below) — confirm.
_PRETRAINED_NUM_CLASSES = 400

rgb_model = get_pretrained_model(rgb_model_checkpoint, _PRETRAINED_NUM_CLASSES)
flow_model = get_pretrained_model(flow_model_checkpoint, _PRETRAINED_NUM_CLASSES)

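## Check That the Pre-trained Models Loaded Correctly

Run both streams on a sample cricket-shot clip and fuse their logits; with correctly restored weights the top class should be "playing cricket".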

In [7]:
_LABEL_MAP_PATH = "data/label_map.txt"
# Saved per-stream input arrays for one sample clip, used to sanity-check
# the pretrained weights.
_SAMPLE_PATHS = {
    "rgb": "data/v_CricketShot_g04_c01_rgb.npy",
    "flow": "data/v_CricketShot_g04_c01_flow.npy",
}

# One class name per line; entry i is the label for logit index i.
# The `with` block closes the file deterministically instead of leaking the handle.
with open(_LABEL_MAP_PATH, encoding='utf-8') as label_file:
    kinetics_classes = [line.strip() for line in label_file]

# Number of classes in the pretrained label space (matches the 400 passed to
# get_pretrained_model above).
num_classes = 400

In [8]:
def _load_sample_tensor(stream):
    """Load the saved array for `stream` from _SAMPLE_PATHS as a float32 tensor."""
    return tf.convert_to_tensor(np.load(_SAMPLE_PATHS[stream]), dtype=tf.float32)


rgb_sample = _load_sample_tensor("rgb")
logging.info("RGB sample loaded")
flow_sample = _load_sample_tensor("flow")

In [10]:
# Number of top-scoring classes to display.
TOP_K = 20

# Each model call returns a pair; only the first element (the logits) is used here.
rgb_logits, _ = rgb_model(rgb_sample)
flow_logits, _ = flow_model(flow_sample)

# Two-stream late fusion: sum RGB and flow logits, then softmax over classes.
joint_logits = rgb_logits + flow_logits
joint_predictions = tf.nn.softmax(joint_logits)

# Inspect the first clip in the batch. Bound to out_logits / out_predictions so
# later cells can use the per-clip tensors, without overwriting the batch results.
out_logits = joint_logits[0]
out_predictions = joint_predictions[0]

# Rank by logits rather than probabilities: softmax preserves the order
# mathematically, but float32 probabilities of low-scoring classes can round to
# identical values (or 0), making their relative order arbitrary.
sorted_indices = np.argsort(out_logits)[::-1]

print(f"Norm of logits: {np.linalg.norm(out_logits)}")
print("\nTop classes and probabilities")
for index in sorted_indices[:TOP_K]:
    print(out_predictions[index].numpy(), out_logits[index].numpy(), kinetics_classes[index])

# The sample is a cricket-shot clip, so correctly restored weights should rank
# "playing cricket" first (as in the recorded output of this cell).
assert kinetics_classes[sorted_indices[0]] == "playing cricket", (
    f"unexpected top class: {kinetics_classes[sorted_indices[0]]!r}"
)

Norm of logits: 139.30215454101562

Top classes and probabilities
1.0 42.059036 playing cricket
1.2958588e-09 21.594944 hurling (sport)
3.377472e-10 20.250313 catching or throwing baseball
1.4016727e-10 19.370852 catching or throwing softball
9.852616e-11 19.018337 hitting baseball
7.817063e-11 18.78691 playing tennis
2.229546e-11 17.532398 playing kickball
1.02992875e-11 16.76009 playing squash or racquetball
5.1403977e-12 16.065145 shooting goal (soccer)
3.8235903e-12 15.769205 hammer throw
1.7346235e-12 14.978806 golf putting
1.3726635e-12 14.744768 throwing discus
1.2943969e-12 14.68606 javelin throw
6.587782e-13 14.010647 pumping fist
4.358122e-13 13.597471 shot put
3.6078383e-13 13.408539 celebrating
2.231552e-13 12.928127 applauding
1.5636494e-13 12.572453 throwing ball
1.4015313e-13 12.462996 dodgeball
9.741567e-14 12.099247 tap dancing
