In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.transformer import Transformer, create_look_ahead_mask
from dataset.dataset_extracted import ExtractedFeatureDataset
import sentencepiece as spm
import transforms
from evaluate import BeamSearchCaptioner
import os

In [2]:
flag = 'train'

In [3]:
test_corpus_file = '../MVAD/corpus_M-VAD_{}.txt'.format(flag)
tokenizer_file = 'tokenizer.model'
model_weight_file = '../checkpoint/20190815013812/100'
inp_max_seq_length = 50
tar_max_seq_length = 50
tar_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
encoder_num_layers = 6
decoder_num_layers = 6
dff = 2048
dropout = 0.1
max_seq_length = 80  # For positional encoding

In [4]:
test_feature_path = '../MVAD/I3D_rgb_kinetics/{}'.format(flag)
with open('../MVAD/{}_fine'.format(flag)) as f:
    files = f.readlines()
    feature_files = list(map(lambda file: os.path.join(test_feature_path, str.strip(file) + '.npy'), files))

In [5]:
sp = spm.SentencePieceProcessor()
sp.Load(tokenizer_file)

True

In [6]:
feature_transform = transforms.Compose([
    transforms.FeaturePadding(inp_max_seq_length)
])
caption_transform = transforms.Compose([
    transforms.CaptionPadding(tar_max_seq_length, sp.PieceToId('<PAD>'))
])

In [7]:
dataset = ExtractedFeatureDataset(None, test_corpus_file, inp_max_seq_length, tar_max_seq_length, sp, feature_transform=feature_transform, caption_transform=caption_transform, feature_files=feature_files)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [8]:
model = Transformer(tar_vocab_size, d_model, num_heads, encoder_num_layers, decoder_num_layers, dff, dropout, max_seq_length)
state_dict = torch.load(model_weight_file, map_location='cpu')
model.load_state_dict(state_dict['model_state_dict'], strict=False)
model.eval()

Transformer(
  (linear): Linear(in_features=1024, out_features=512, bias=True)
  (pe): PositionalEncoder()
  (encoder): TransformerEncoder(
    (encoder_layers): ModuleList(
      (encoder_layer1): TransformerEncoderLayer(
        (mha): MultiheadAttention(
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ff): FeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1)
        )
        (an1): AddNorm(
          (dropout): Dropout(p=0.1)
          (layernorm): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
        )
        (an2): AddNorm(
          (dropout): Dropout(p=0.1)
          (layernorm): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
        )
      )
      (encoder_layer2): TransformerEncoderLayer(
        (mha): MultiheadAttention(
          (out_proj): Line

In [9]:
dataiter = iter(dataloader)

In [10]:
caption = dataiter.next()[1]

In [11]:
from torchnlp.metrics.bleu import get_moses_multi_bleu

In [12]:
caption = caption.squeeze(0).tolist()

In [13]:
caption

[1,
 91,
 1842,
 15,
 358,
 293,
 6,
 17,
 5,
 3064,
 4,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [14]:
get_moses_multi_bleu(sp.decode_ids(caption), sp.decode_ids(caption), lowercase=True)

0.0

In [15]:
captioner = BeamSearchCaptioner(model, sp, tar_max_seq_length, 3, 10, 5)
captioner.caption_video_from_dataloader(dataloader, count=10)

video: ('KARATE_KID_DVS1138.avi',)
caption origin: Beaming, he returns the gesture.
step: 2, end_nodes: 0
step: 3, end_nodes: 0
step: 4, end_nodes: 0
step: 5, end_nodes: 0
step: 6, end_nodes: 0
step: 7, end_nodes: 0
step: 8, end_nodes: 0
step: 9, end_nodes: 0
step: 10, end_nodes: 1
step: 11, end_nodes: 2
step: 12, end_nodes: 3
step: 13, end_nodes: 4
step: 14, end_nodes: 5
step: 15, end_nodes: 6
step: 16, end_nodes: 6
step: 17, end_nodes: 6
step: 18, end_nodes: 8
step: 19, end_nodes: 9
step: 20, end_nodes: 9
caption predict: Beaming, he returns the gesture.

video: ('SALT_DVS60.avi',)
caption origin: The chopper arrives at a fortress.
step: 2, end_nodes: 0
step: 3, end_nodes: 0
step: 4, end_nodes: 0
step: 5, end_nodes: 0
step: 6, end_nodes: 0
step: 7, end_nodes: 0
step: 8, end_nodes: 0
step: 9, end_nodes: 0
step: 10, end_nodes: 0
step: 11, end_nodes: 1
step: 12, end_nodes: 1
step: 13, end_nodes: 2
step: 14, end_nodes: 3
step: 15, end_nodes: 3
step: 16, end_nodes: 5
step: 17, end_nodes: 