<a href="https://colab.research.google.com/github/ADA-SITE-JML/sign-lang/blob/main/jamal/LSTM_Attention_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import squeezenet1_1
from torchvision.models.feature_extraction import create_feature_extractor
import sklearn.utils
import pandas as pd
import os

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

rnn_type = 'GRU'
max_frames = 5
BATCH_SIZE = 1

print('Running on ' + device + ' with ' + rnn_type)

Running on cuda:0 with GRU


In [2]:
import torchvision
from torchvision.models import squeezenet1_1
from torchvision.models.feature_extraction import create_feature_extractor

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, device, biDirectional = False):
        super(EncoderRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device
        # this is for LSTM
        self.D = 2 if biDirectional else 1

        if rnn_type == 'GRU':
          self.rnn = nn.GRU(
                  input_size = self.input_size,
                  hidden_size = self.hidden_size*self.D,
                  batch_first = True).to(device)
        elif rnn_type == 'LSTM':
          self.rnn = nn.LSTM(
                  input_size = self.input_size,
                  hidden_size = self.hidden_size*self.D,
                  num_layers = 1,
                  dropout = 0,
                  bidirectional = biDirectional,
                  batch_first = True).to(device)

    def forward(self, input, hidden):
        output, hidden = self.rnn(input, hidden)
        return output, hidden

    def initHidden(self):
        if rnn_type == 'GRU':
          return torch.zeros(self.D, BATCH_SIZE, self.hidden_size*self.D, device=self.device)
        elif rnn_type == 'LSTM':
          return (torch.zeros(self.D, BATCH_SIZE, self.hidden_size*self.D, device=self.device),
                  torch.zeros(self.D, BATCH_SIZE, self.hidden_size*self.D, device=self.device))

In [3]:
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, device, dropout_p=0.1, max_length=max_frames, biDirectional = False, debug=False): #max_length=config.max_words_in_sentence
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.debug = debug
        self.device = device

        self.D = 2 if biDirectional else 1

        if self.debug:
          print('Attn.init() hidden_size',hidden_size)
          print('Attn.init() output_size',output_size)
          print('Attn.init() max_length',max_length)

        self.embedding = nn.Embedding(self.output_size, self.hidden_size*2)
        self.attn = nn.Linear(self.hidden_size * 3, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 3, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)

        if rnn_type == 'GRU':
          self.rnn = nn.GRU(
                  input_size = self.hidden_size,
                  hidden_size = self.hidden_size*self.D,
                  batch_first = True).to(device)
        elif rnn_type == 'LSTM':
          self.rnn = nn.LSTM(
                  input_size = self.hidden_size,
                  hidden_size = self.hidden_size*self.D,
                  num_layers = 1,
                  dropout = 0,
                  bidirectional = biDirectional,
                  batch_first = True)
        self.out = nn.Linear(self.hidden_size*self.D, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        if rnn_type == 'GRU':
          hidden = hidden.unsqueeze(0)

        embedded = self.embedding(input).view(input.shape[0],input.shape[1], self.hidden_size*2)
        embedded = self.dropout(embedded)

        if self.debug:
          print('Attn.forward() input',input.shape)
          print('Attn.forward() hidden',type(hidden),len(hidden),hidden[0].shape)
          print('Attn.forward() encoder_outputs',encoder_outputs.shape)
          print('embedded: ',embedded.shape)

        tcat = torch.cat((embedded[0], hidden[0]), 1)
        attn_weights = F.softmax(self.attn(tcat), dim=1).to(device=self.device)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs).to(device=self.device)

        output = torch.cat((embedded[0], attn_applied[0]), 1).to(device=self.device)
        output = self.attn_combine(output).unsqueeze(0).to(device=self.device)

        output = F.relu(output)
        if rnn_type == 'GRU':
          output, hidden = self.rnn(output, hidden[0].unsqueeze(0))
        elif rnn_type == 'LSTM':
          output, hidden = self.rnn(output, (hidden[0].unsqueeze(0),hidden[0].unsqueeze(0)))

        output = F.log_softmax(self.out(output[0]), dim=1).to(device=self.device)
        return output, hidden, attn_weights

    def initHidden(self):
        if rnn_type == 'GRU':
          return torch.zeros(self.D, 1, self.hidden_size*self.D, device=self.device)
        elif rnn_type == 'LSTM':
          return (torch.zeros(self.D, 1, self.hidden_size*self.D, device=self.device),
                  torch.zeros(self.D, 1, self.hidden_size*self.D, device=self.device))

In [4]:
def evaluate(encoder, decoder, frames, max_length = 5):
    with torch.no_grad():
        encoder_hidden = encoder.initHidden()

        encoder_output, encoder_hidden = encoder(frames, encoder_hidden)

        decoder_input = torch.tensor([[encodings['SOS']]], device=device)  # Start of sentence

        decoder_hidden = encoder_hidden

        decoded_words = ''
        decoder_attentions = torch.zeros(max_length, max_length, device=device)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden[0], encoder_output)

            decoder_attentions[di] = decoder_attention.data

            topv, topi = decoder_output.data.topk(1)

            if topi.item() == encodings['EOS']:
                decoded_words += '.'
                break
            else:
                decoded_words += word_idx[topi.item()] + ' '

            decoder_input = topi.detach()

        return decoded_words, decoder_attentions[:di + 1]

In [5]:
from google.colab import drive
drive.mount('/content/drive')
data_folder = 'drive/MyDrive/SLR/Data/'

Mounted at /content/drive


In [6]:
encodings = torch.load(data_folder+'/jamal/encodings.dict')
word_idx = torch.load(data_folder+'/jamal/word_idx.dict')

# encoder = torch.load(data_folder+'/jamal/encoder_' + device + '.model', map_location=torch.device(device))
# decoder = torch.load(data_folder+'/jamal/decoder_' + device + '.model', map_location=torch.device(device))

input_size = 2048
hidden_size = 64
encoding_len = len(encodings)
DEBUG = False

encoder = EncoderRNN(input_size, hidden_size, device = device, biDirectional = False).to(device)
decoder = AttnDecoderRNN(hidden_size, len(encodings), device = device, dropout_p=0.1, biDirectional = False, debug = DEBUG).to(device)

encoder.load_state_dict(torch.load(data_folder+'/jamal/encoder_' + device + '.model', map_location=torch.device(device)))
decoder.load_state_dict(torch.load(data_folder+'/jamal/decoder_' + device + '.model', map_location=torch.device(device)))

# encoder.eval()
# decoder.eval()

<All keys matched successfully>

In [7]:
# For I3D features
!git clone https://github.com/v-iashin/video_features.git
!pip install omegaconf==2.0.6

%cd video_features

from models.i3d.extract_i3d import ExtractI3D
from models.raft.raft_src.raft import RAFT, InputPadder
from utils.utils import build_cfg_path
from omegaconf import OmegaConf

# Load and patch the config
args = OmegaConf.load(build_cfg_path('i3d'))
# args.show_pred = True
# args.stack_size = 24
# args.step_size = 24
# args.extraction_fps = 30
args.flow_type = 'raft' # 'pwc' is not supported on Google Colab (cupy version mismatch)
# args.streams = 'flow'

# Load the model
extractor = ExtractI3D(args)

Cloning into 'video_features'...
remote: Enumerating objects: 1299, done.[K
remote: Counting objects: 100% (420/420), done.[K
remote: Compressing objects: 100% (189/189), done.[K
remote: Total 1299 (delta 264), reused 322 (delta 215), pack-reused 879[K
Receiving objects: 100% (1299/1299), 288.63 MiB | 18.76 MiB/s, done.
Resolving deltas: 100% (671/671), done.
Updating files: 100% (177/177), done.
Collecting omegaconf==2.0.6
  Downloading omegaconf-2.0.6-py3-none-any.whl (36 kB)
Installing collected packages: omegaconf
Successfully installed omegaconf-2.0.6
/content/video_features


In [9]:
# read cvs file
# %cd $drive_folder
drive_folder = '/content/drive/MyDrive/SLR/Data/'
video_folder = drive_folder+'/Video'
train_csv_path = drive_folder+'sentences_all.csv'
camera_source = 'Cam2' # Cam1 - side-top, Cam2 - front

sentences = pd.read_csv(train_csv_path)
sentences = sklearn.utils.shuffle(sentences)

for index, row in sentences.iterrows():
    id = int(row[0])

    phrase = row[2].lower()

    # there is a grouping of videos in production.
    pre_folder = '/1-250/' if (id < 251) else '/'

    dir = video_folder+'/' + camera_source + pre_folder + str(id)
    # iterate over video folders
    fidx = 1

    if str(device).startswith('cuda'):
      torch.cuda.empty_cache()


    for filename in os.listdir(dir):
        if fidx > 3:
          break
        f = os.path.join(dir, filename)
        # checking if it is a file
        if os.path.isfile(f):
            video_id = filename[:filename.rindex('.')]
            print(id,f)

            feature_dict = extractor.extract(f)

            f_num, f_size = feature_dict['rgb'].shape
            REQ_FEATS = 5 # required number of features

            # Keep only REQ_FEATS features from each and apply zero padding if there are less than REQ_FEATS features
            feats_rgb = torch.from_numpy(feature_dict['rgb'])
            feats_flow = torch.from_numpy(feature_dict['flow'])

            # Trim extra features.
            # Trim shall be applied on each, since we need to have equal number of RGB and FLOW features.
            # Like for RGB and FLOW, 8 features each will make 16 features if we apply catenation first.
            # If we trimming after that to keep 10 features, eight of them will be about RGB, two - FLOW.
            if f_num > REQ_FEATS:
              feats_rgb  = feats_rgb[-(REQ_FEATS-f_num):,:]
              feats_flow = feats_flow[-(REQ_FEATS-f_num):,:]

            # Concatenate the features
            feats = torch.cat((feats_rgb,feats_flow),1)

            # Apply zero padding if needed.
            # Zero padding needs to be done after the catenation - zero features shall come at the end, not after each type (RGB and FLOW)
            if f_num < REQ_FEATS:
              padarr = torch.zeros((REQ_FEATS-f_num,f_size*2))
              feats = torch.cat((feats,padarr),0)

            print('Original:',phrase)
            output_words, attentions = evaluate(encoder, decoder, feats.float().unsqueeze(0).to(device))
            print('Prediction:',output_words)

            fidx += 1

277 /content/drive/MyDrive/SLR/Data//Video/Cam2/277/2022-10-26 14-04-29.mp4
Original: dünən youtube maraqlı video baxmaq olmaq
Prediction: mən ev 2 pişik var 
277 /content/drive/MyDrive/SLR/Data//Video/Cam2/277/2022-10-26 14-56-57.mp4
Original: dünən youtube maraqlı video baxmaq olmaq
Prediction: mən ev 2 pişik var 
277 /content/drive/MyDrive/SLR/Data//Video/Cam2/277/2022-11-04 12-32-06.mp4
Original: dünən youtube maraqlı video baxmaq olmaq
Prediction: mən 2 oğul 1 qız 
197 /content/drive/MyDrive/SLR/Data//Video/Cam2/1-250/197/2022-06-02 13-01-27.mp4
Original: mənim şəkil çəkmək xoşum gəlmək
Prediction: mən ev 2 pişik var 
197 /content/drive/MyDrive/SLR/Data//Video/Cam2/1-250/197/2022-06-02 15-28-48.mp4
Original: mənim şəkil çəkmək xoşum gəlmək
Prediction: mən 2 oğul 1 qız 
197 /content/drive/MyDrive/SLR/Data//Video/Cam2/1-250/197/2022-06-02 16-48-05.mp4
Original: mənim şəkil çəkmək xoşum gəlmək
Prediction: mən ev 2 pişik var 
282 /content/drive/MyDrive/SLR/Data//Video/Cam2/282/2022-10

KeyboardInterrupt: ignored