The idea now:
Since their implementation of the model requires packages only available for a linux engine, I will have to retrofit their model to work on my computer.

For the SingleStream network I need:
- Visual Encoder:
    - S3D Backbone 1-4 blocks, outputs S3D features (T/4x843) --> these features are used for the head network                    CHECK
    - Head network, outputs Gloss representations in high dim space (T/4x512) --> These are the features sent forward to the S2T  CHECK
        - Linear/BN/ReLU                                                                                                          CHECK
        - Temporal Cov Block                                                                                                      CHECK
    - Linear classifier, outputs Gloss logits (T/4xK)                                                                             CHECK
    - Softmax, outputs gloss probabilities (T/4xK)                                                                                CHECK
    - and then the CTC (connectionist temporal classification) loss and the CTC Decoder (which outputs Gloss Predictions)

- Pretraining of the Visual Encoder (not necessary since we will be using their pretrained weights).

- V-L Mapper

### UTILS

In [119]:
import copy
import glob
import os
import os.path
import errno
import shutil
import random
import logging
from sys import platform
from logging import Logger
from typing import Callable, Optional
import numpy as np
import cv2
import pandas as pd
from torchinfo import summary

import torch
from torch import nn, Tensor
import yaml
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

def neq_load_customized(model, pretrained_dict, verbose=False):
    ''' load pre-trained model in a not-equal way,
    when new model has been partially modified '''
    model_dict = model.state_dict()
    tmp = {}
    if verbose:
        print(list(model_dict.keys()))
        print('\n=======Check Weights Loading======')
        print('Weights not used from pretrained file:')
    for k, v in pretrained_dict.items():
        if k in model_dict and model_dict[k].shape==v.shape:
            tmp[k] = v
        else:
            if verbose:
                print(k)
    if verbose:
        print('---------------------------')
        print('Weights not loaded into new model:')
        for k, v in model_dict.items():
            if k not in pretrained_dict:
                print(k)
            elif model_dict[k].shape != pretrained_dict[k].shape:
                print(k, 'shape mis-matched, not loaded')
        print('===================================\n')

    del pretrained_dict
    model_dict.update(tmp)
    del tmp
    model.load_state_dict(model_dict)
    return model


def get_logger():
    return Logger

## S3D

In [113]:
### S3D Model architecture

class S3Dsup(nn.Module):
    def __init__(self, in_channels, num_class, use_block, stride):
        super(S3Dsup, self).__init__()
        base_seq = []
        if use_block>=1:
            base_seq += [
                SepConv3d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            ]
        if use_block>=2:
            base_seq += [
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(
                    1, 2, 2), padding=(0, 1, 1)),                        # 1
                BasicConv3d(64, 64, kernel_size=1, stride=1),            # 2
                SepConv3d(64, 192, kernel_size=3, stride=1, padding=1),  # 3
            ]
        if use_block>=3:
            base_seq += [
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(
                    1, 2, 2), padding=(0, 1, 1)),                        # 4
                Mixed_3b(),                                              # 5
                Mixed_3c(),                                              # 6
            ]
        if use_block>=4:
            base_seq += [
                nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(
                    2, 2, 2), padding=(1, 1, 1)),                        # 7
                Mixed_4b(),                                              # 8
                Mixed_4c(),                                              # 9
                Mixed_4d(),                                              # 10
                Mixed_4e(),                                              # 11
                Mixed_4f(),                                              # 12
            ]
        if use_block>=5:
            base_seq += [
                nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(
                    stride, 2, 2), padding=(0 if stride==2 else 1, 0, 0)),
                Mixed_5b(),
                Mixed_5c(), #15
            ]
        self.base_num_layers = len(base_seq)
        self.base = nn.Sequential(*base_seq)
        #self.fc = nn.Sequential(nn.Conv3d(BLOCK2SIZE[use_block], num_class, kernel_size=1, stride=1, bias=True)) 
        # Took the standard fc from S3D class pytorch, allows the model to load the weights, so we assume it's the right one
        # 1024 for kinetics, 832 for gloss since different blocksize

    def forward(self, x):
        y = self.base(x)
        y = F.avg_pool3d(y, (2, y.size(3), y.size(4)), stride=1)
        #y = self.fc(y)
        y = y.view(y.size(0), y.size(1), y.size(2))
        logits = torch.mean(y, 2)

        return logits


class BasicConv3d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv3d, self).__init__()
        self.conv = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class SepConv3d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(SepConv3d, self).__init__()
        self.conv_s = nn.Conv3d(in_planes, out_planes, kernel_size=(1,kernel_size,kernel_size), stride=(1,stride,stride), padding=(0,padding,padding), bias=False)
        self.bn_s = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_s = nn.ReLU()

        self.conv_t = nn.Conv3d(out_planes, out_planes, kernel_size=(kernel_size,1,1), stride=(stride,1,1), padding=(padding,0,0), bias=False)
        self.bn_t = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_t = nn.ReLU()

    def forward(self, x):
        x = self.conv_s(x)
        x = self.bn_s(x)
        x = self.relu_s(x)

        x = self.conv_t(x)
        x = self.bn_t(x)
        x = self.relu_t(x)
        return x

class Mixed_3b(nn.Module):
    def __init__(self):
        super(Mixed_3b, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(192, 64, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 96, kernel_size=1, stride=1),
            SepConv3d(96, 128, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 16, kernel_size=1, stride=1),
            SepConv3d(16, 32, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)

        return out


class Mixed_3c(nn.Module):
    def __init__(self):
        super(Mixed_3c, self).__init__()
        self.branch0 = nn.Sequential(
            BasicConv3d(256, 128, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(256, 128, kernel_size=1, stride=1),
            SepConv3d(128, 192, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(256, 32, kernel_size=1, stride=1),
            SepConv3d(32, 96, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(256, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_4b(nn.Module):
    def __init__(self):
        super(Mixed_4b, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(480, 192, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(480, 96, kernel_size=1, stride=1),
            SepConv3d(96, 208, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(480, 16, kernel_size=1, stride=1),
            SepConv3d(16, 48, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(480, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_4c(nn.Module):
    def __init__(self):
        super(Mixed_4c, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(512, 160, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 112, kernel_size=1, stride=1),
            SepConv3d(112, 224, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 24, kernel_size=1, stride=1),
            SepConv3d(24, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_4d(nn.Module):
    def __init__(self):
        super(Mixed_4d, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(512, 128, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 128, kernel_size=1, stride=1),
            SepConv3d(128, 256, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 24, kernel_size=1, stride=1),
            SepConv3d(24, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_4e(nn.Module):
    def __init__(self):
        super(Mixed_4e, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(512, 112, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 144, kernel_size=1, stride=1),
            SepConv3d(144, 288, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 32, kernel_size=1, stride=1),
            SepConv3d(32, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_4f(nn.Module):
    def __init__(self):
        super(Mixed_4f, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(528, 256, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(528, 160, kernel_size=1, stride=1),
            SepConv3d(160, 320, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(528, 32, kernel_size=1, stride=1),
            SepConv3d(32, 128, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(528, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_5b(nn.Module):
    def __init__(self):
        super(Mixed_5b, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(832, 256, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(832, 160, kernel_size=1, stride=1),
            SepConv3d(160, 320, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(832, 32, kernel_size=1, stride=1),
            SepConv3d(32, 128, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(832, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Mixed_5c(nn.Module):
    def __init__(self):
        super(Mixed_5c, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv3d(832, 384, kernel_size=1, stride=1),
        )
        self.branch1 = nn.Sequential(
            BasicConv3d(832, 192, kernel_size=1, stride=1),
            SepConv3d(192, 384, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(832, 48, kernel_size=1, stride=1),
            SepConv3d(48, 128, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1),
            BasicConv3d(832, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


In [114]:
### Initialize S3D

BLOCK2SIZE = {1:64, 2:192, 3:480, 4:832, 5:1024}

class S3Ds(S3Dsup):
    def __init__(self, num_class=400, in_channel=3, use_block=5, freeze_block=0, stride=2):  # 5 and 0 for kinetics, 4 and 1 for gloss
        self.use_block = use_block
        super(S3Ds, self).__init__(in_channels=in_channel, num_class=num_class, use_block=use_block, stride=stride)
        self.freeze_block = freeze_block
        self.END_POINT2BLOCK = {
            0: 'block1',
            3: 'block2',
            6: 'block3',
            12: 'block4',
            15: 'block5',
        }
        self.BLOCK2END_POINT = {blk:ep for ep, blk in self.END_POINT2BLOCK.items()}

        self.frozen_modules = []
        self.use_block = use_block

        if freeze_block>0:
            for i in range(0, self.base_num_layers): #base  0,1,2,...,self.BLOCK2END_POINT[blk]
                module_name = 'base.{}'.format(i)
                submodule = self.base[i]
                assert submodule != None, module_name
                if i <= self.BLOCK2END_POINT['block{}'.format(freeze_block)]:
                    self.frozen_modules.append(submodule)

## Head network

In [145]:
import pickle

with gzip.open("SLRTNGT/TwoStreamNetwork/experiments/outputs/SingleStream/head_rgb_input/test.pkl", 'rb') as f:
                    split_data = pickle.load(f)

print(split_data[3]['sign'].shape)

print(split_data[3])

## So we know the extracted features are only from the S3D, since the size is (number of frames)/4 by 832, 
## the input for the head are T (num frames/4) by 843
## so the head input size varies.

torch.Size([32, 832])
{'name': 'test/10March_2011_Thursday_heute-58', 'gloss': 'WOCHENENDE SONNE SAMSTAG SCHOEN TEMPERATUR BIS SIEBZEHN GRAD REGION', 'text': 'sonnig geht es auch ins wochenende samstag ein herrlicher tag mit temperaturen bis siebzehn grad hier im westen .', 'num_frames': 130, 'sign': tensor([[2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 1.3635e-02, 7.4006e-02,
         1.2359e-01],
        [2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 9.0310e-04, 8.9023e-02,
         2.2051e-01],
        [2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 1.1407e-04, 8.9353e-02,
         2.4418e-01],
        ...,
        [2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 4.1541e-02, 7.7600e-02,
         8.1139e-02],
        [2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 3.9211e-02, 7.4247e-02,
         5.1760e-02],
        [2.7592e-36, 0.0000e+00, 2.3513e-36,  ..., 3.2588e-02, 6.8302e-02,
         4.3865e-02]])}


In [146]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from SLRTNGT.TwoStreamNetwork.utils.misc import get_logger
from SLRTNGT.TwoStreamNetwork.modelling.utils import PositionalEncoding, MaskedNorm, PositionwiseFeedForward, MLPHead

#testing:
#  cfg:
#    recognition:
#      beam_size: 5

#model:
#  RecognitionNetwork:
#    GlossTokenizer:
#      gloss2id_file: data/csl-daily/gloss2ids.pkl
#    s3d:
#      pretrained_ckpt: pretrained_models/s3ds_glosscls_ckpt
#      use_block: 4
#      freeze_block: 1
#    visual_head:
#      input_size: 832
#      hidden_size: 512
#      ff_size: 2048 
#      pe: True
#      ff_kernelsize:
#        - 3
#        - 3

In [169]:
class VisualHead(torch.nn.Module):
    def __init__(self, 
        cls_num, input_size=832, hidden_size=512, ff_size=2048, pe=True,
        ff_kernelsize=3, pretrained_ckpt=None, is_empty=False, frozen=False, 
        plus_conv_cfg={},
        ssl_projection_cfg={}):
        super().__init__()
        self.is_empty = is_empty
        self.plus_conv_cfg = plus_conv_cfg
        self.ssl_projection_cfg = ssl_projection_cfg
        if is_empty==False:
            self.frozen = frozen
            self.hidden_size = hidden_size

            if input_size is None:
                self.fc1 = nn.Identity()
            else:
                self.fc1 = torch.nn.Linear(input_size, self.hidden_size)
            self.bn1 = MaskedNorm(num_features=self.hidden_size, norm_type='sync_batch')
            self.relu1 = torch.nn.ReLU()
            self.dropout1 = torch.nn.Dropout(p=0.1)

            if pe:
                self.pe = PositionalEncoding(self.hidden_size)
            else:
                self.pe = torch.nn.Identity()

            self.feedforward = PositionwiseFeedForward(input_size=self.hidden_size,
                ff_size=ff_size,
                dropout=0.1, kernel_size=ff_kernelsize, skip_connection=True)
            
            self.layer_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-6)

            if plus_conv_cfg!={}:
                plus_convs = []
                for i in range(plus_conv_cfg['num_layer']):
                    plus_convs.append(nn.Conv1d(self.hidden_size, self.hidden_size, 
                        kernel_size=plus_conv_cfg['kernel_size'], stride=plus_conv_cfg['stride'], padding_mode='replicate'))
                self.plus_conv = nn.Sequential(*plus_convs)
            else:
                self.plus_conv = nn.Identity()

            if ssl_projection_cfg!={}:
                self.ssl_projection = MLPHead(embedding_size=self.hidden_size, 
                    projection_hidden_size=ssl_projection_cfg['hidden_size'])

            self.gloss_output_layer = torch.nn.Linear(self.hidden_size, cls_num)

            if self.frozen:
                self.frozen_layers = [self.fc1, self.bn1, self.relu1,  self.pe, self.dropout1, self.feedforward, self.layer_norm]
                for layer in self.frozen_layers:
                    for name, param in layer.named_parameters():
                        param.requires_grad = False
                    layer.eval()
        else:
            self.gloss_output_layer = torch.nn.Linear(input_size, cls_num)
        if pretrained_ckpt:
            self.load_from_pretrained_ckpt(pretrained_ckpt)

    def load_from_pretrained_ckpt(self, pretrained_ckpt):
        logger = Logger     # get_logger()
        checkpoint = torch.load(pretrained_ckpt, map_location='cpu')#['model_state']
        load_dict = {}
        for k,v in checkpoint.items():
            if 'recognition_network.visual_head.' in k:
                load_dict[k.replace('recognition_network.visual_head.','')] = v
        self.load_state_dict(load_dict)
        logger.info('Load Visual Head from pretrained ckpt {}'.format(pretrained_ckpt))

    def forward(self, x, mask, valid_len_in=None):
        B, Tin, D = x.shape 
        if self.is_empty==False:
            if not self.frozen:
                #projection 1
                x = self.fc1(x)
                x = self.bn1(x, mask)
                x = self.relu1(x)
                #pe
                x = self.pe(x)
                x = self.dropout1(x)

                #feedforward
                x = self.feedforward(x)
                x = self.layer_norm(x)

                x = x.transpose(1,2)
                x = self.plus_conv(x)
                x = x.transpose(1,2)
            else:
                with torch.no_grad():
                    for ii, layer in enumerate(self.frozen_layers):
                        layer.eval()
                        if ii==1:
                            x = layer(x, mask)
                        else:
                            x = layer(x)
                x = x.transpose(1,2)
                x = self.plus_conv(x)
                x = x.transpose(1,2)

        #classification
        logits = self.gloss_output_layer(x) #B,T,V

        #softmax
        gloss_probabilities_log = logits.log_softmax(2) 
        gloss_probabilities = logits.softmax(2)

        if self.plus_conv_cfg!={}:
            B, Tout, D = x.shape
            valid_len_out = torch.floor(valid_len_in*Tout/Tin).long() #B,
        else:
            valid_len_out = valid_len_in
        if self.ssl_projection_cfg!={}:
            x_ssl = self.ssl_projection(x)
            if self.ssl_projection_cfg['normalize']==True:
                x_ssl = F.normalize(x_ssl, dim=-1)
        else:
            x_ssl = None

        ## These are all the different features we can use
        return {'gloss_feature_ssl':x_ssl, 
                'gloss_feature': x,
                'gloss_feature_norm': F.normalize(x, dim=-1),
                'gloss_logits':logits, 
                'gloss_probabilities_log':gloss_probabilities_log,
                'gloss_probabilities': gloss_probabilities,
                'valid_len_out':valid_len_out}

#### Main

In [196]:
## We test the pretrained S3D model on a video from the CorpusNGT dataset, and it indeed works.

def transform(snippet):
    ''' stack & noralization '''
    snippet = np.concatenate(snippet, axis=-1)
    snippet = torch.from_numpy(snippet).permute(2, 0, 1).contiguous().float()
    snippet = snippet.mul_(2.).sub_(255).div(255)
    snippet = snippet.view(1,-1,3,snippet.size(1),snippet.size(2)).permute(0,2,1,3,4) 
    print(snippet.shape)
    return snippet
    # returns tensor in size [batch, channels, frames, height, width]
    # all values normalized

def main():
    ''' Output the top 5 Kinetics classes predicted by the model or the gloss features'''
    
    #path_sample = './sample'
    path_sample = 'Data/CorpusNGT/gloss_split/WINNEN/9'
    
    #file_weight = 'SLRT-NGT/TwoStreamNetwork/pretrained_models/s3ds_glosscls_ckpt/epoch299.pth.tar'
    file_weight = 'SLRT-NGT/TwoStreamNetwork/pretrained_models/csl-daily_s2g/ckpts/best.ckpt'
    #file_weight = 'SLRT-NGT/TwoStreamNetwork/pretrained_models/s3ds_actioncls_ckpt/S3D_kinetics400.pt'
    
    class_names = pd.read_pickle("SLRT-NGT/TwoStreamNetwork/data/csl-daily/gloss2ids.pkl")
    class_names = {K-4:V for (V,K) in [x for x in class_names.items()][4:]}
    #class_names = pd.read_csv("Data/Kinetics_labels/kinetics_400_labels.csv")
    
    num_class = len(class_names)
    print("Number of classes: " + str(num_class))
    #num_class = 400

    state = "features"
    #state = "kinetics"


    ### Perform S3D feature extraction
    model = S3Ds(num_class, use_block=4, freeze_block=1)  ## 4 and 1 for gloss

    # load the weight file and copy the parameters
    if os.path.isfile(file_weight):
        print ('loading weight file')
        weight_dict = torch.load(file_weight)
        model_dict = model.state_dict()
        for name, param in weight_dict.items(): # name is the name of the module, param is the weights
            #print("NAME " + name)
            #print(param.shape)
            #print(param)
            if 'module' in name:
                name = '.'.join(name.split('.')[1:])
            if name in model_dict:
                if param.size() == model_dict[name].size():
                    model_dict[name].copy_(param)
                else:
                    print (' size? ' + name, param.size(), model_dict[name].size())
            else:
                print (' name? ' + name)

        print (' loaded')
    else:
        print ('weight file?')

    model = model.cuda()
    torch.backends.cudnn.benchmark = False
    model.eval()

    list_frames = [f for f in os.listdir(path_sample) if os.path.isfile(os.path.join(path_sample, f)) and f.split(".")[1] == "jpg"]
    list_frames.sort()

    # read all the frames of sample clip
    snippet = []
    for frame in list_frames:
        img = cv2.imread(os.path.join(path_sample, frame))
        img = cv2.resize(img, [270, 270])
        img = img[...,::-1]
        snippet.append(img)
        #snippet.append(img) ## added because not enough frames in test

    clip = transform(snippet)

    with torch.no_grad():
        logits = model(clip.cuda()).cpu().data#[0]

    if state == "features":
        #print(logits)
        print ('\nThe features outputted by pretrained S3D')
        print(logits)

    if state == "kinetics":
        preds = torch.softmax(logits, 0).numpy()
        sorted_indices = np.argsort(preds)[::-1][:5]
        print(sorted_indices)
        print(logits.shape)
        print ('\nTop 5 kinetics classes ... with probability')
        for idx in sorted_indices:
            #print(class_names['name'][idx], '...', preds[idx])
            print(class_names[idx], '...', preds[idx])


    ### Perform Visual Head feature extraction
    #file_weight_vh = 'SLRT-NGT/TwoStreamNetwork/pretrained_models/csl-daily_s2g/ckpts/best.ckpt'
    model = VisualHead(cls_num=num_class, pe=True)  ## 4 and 1 for gloss    #pretrained_ckpt=file_weight_vh
    
    with torch.no_grad():
        #print(split_data[3]['sign'])
        #features = model(x=split_data[3]['sign'].unsqueeze(0), mask=torch.zeros(1))  
        features = model(x=logits.unsqueeze(0), mask=torch.zeros(1))  
        print("\n")
        print(features['gloss_feature'].shape)
        print(features['gloss_probabilities'].shape)
        print(features)


if __name__ == '__main__':
    main()

Number of classes: 2000
loading weight file
 name? model_state
 loaded
torch.Size([1, 3, 9, 270, 270])

The features outputted by pretrained S3D
tensor([[1.3838e-12, 0.0000e+00, 5.1434e-08, 3.1768e-06, 0.0000e+00, 0.0000e+00,
         3.9563e-07, 5.3867e-07, 1.7031e-09, 1.2489e-06, 1.6452e-06, 0.0000e+00,
         1.5541e-07, 7.2978e-07, 2.1194e-06, 1.2691e-06, 2.9470e-06, 0.0000e+00,
         3.8539e-09, 1.8049e-06, 1.8538e-06, 2.1831e-06, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 3.0256e-07, 2.5964e-06, 6.4925e-06, 1.0622e-06, 0.0000e+00,
         2.6224e-06, 0.0000e+00, 0.0000e+00, 1.1861e-06, 5.7402e-06, 6.9225e-08,
         0.0000e+00, 5.8346e-07, 1.1322e-07, 1.5413e-06, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0068e-07,
         0.0000e+00, 2.7371e-06, 0.0000e+00, 5.4820e-07, 1.7476e-06, 0.0000e+00,
         3.9122e-06, 2.6886e-06, 1.6636e-06, 2.9926e-07, 5.0782e-06, 1.9773e-10,
         3.2010e-06, 2.4459e-06, 0.0000e+00, 

## Classification S2G Test

So the pretrained S3D and visual head work, so we can use these to extract the features for all the CorpusNGT data
We can choose to then use either the visual features exported by the S3D, the gloss representations outputted by the VisualHead, or the gloss logits or gloss probabilities which have been classified and then softmaxed. Either way, this is the S3D is the 'cold' part of the model; we don't further train these weights. The VisualHead plus Classifier do not use pretrained weights, so if we are training for S2G this is the 'hot' part that would be trained. 

Next steps:
- Test to see if a the VisualHead and Classifier trained on CorpusNGT acchieve good performance
- See if this performance gets even better when using Mediapipe data alongside the S3D features.