## HMT (Hierarchical Multimodal Transformer) Model

**Packages**

In [1]:
import numpy as np
import os
import h5py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules import *

import random

#### Multi-Head Attention

In [28]:

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection stack.

    Every head owns bias-free ``Wq``/``Wk``/``Wv`` linear layers that map
    ``input_size`` features to ``output_size // heads`` features. A head computes
    ``softmax(Q K^T / sqrt(output_size // heads))``, applies dropout (p=0.5) to the
    attention weights and multiplies them by ``V``. The head outputs are
    concatenated along the feature axis and passed through
    ``W0 -> ReLU -> W2 -> LayerNorm``, so the result has ``input_size`` features.
    No residual connection is added.

    Args:
        input_size: Feature size of the input and of the returned embeddings.
        output_size: Total Q/K/V width summed over all heads.
        freq: Stored only; positional encoding is not implemented.
        heads: Number of attention heads.
        pos_enc: Stored only; positional encoding is not implemented, so any
            value passed here is currently ignored.
    """

    def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
        super(MultiheadAttention, self).__init__()

        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.input_size = input_size
        self.output_size = output_size
        head_size = output_size // heads
        self.Wq, self.Wk, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()

        for _ in range(self.heads):
            self.Wk.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wq.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wv.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))

        # W0 consumes the concatenated head outputs, which are head_size * heads
        # wide. Sizing it from input_size only worked when that product happened
        # to equal input_size, so use the actual concatenated width.
        self.W0 = nn.Linear(in_features=head_size * heads, out_features=input_size, bias=False)
        self.drop = nn.Dropout(p=0.5)
        self.W2 = nn.Linear(in_features=input_size, out_features=input_size, bias=False)
        self.LN = nn.LayerNorm(normalized_shape=self.W2.out_features, eps=1e-6)

    def forward(self, x):
        """Return self-attended embeddings for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_size`` and whose
                second-to-last dimension indexes the sequence, e.g.
                ``(batch, seq, input_size)`` or ``(seq, input_size)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``input_size``
            features in the last dimension.
        """
        scale = math.sqrt(self.output_size // self.heads)
        head_outputs = []

        for head in range(self.heads):
            K = self.Wk[head](x)
            Q = self.Wq[head](x)
            V = self.Wv[head](x)

            # transpose(-2, -1) swaps the last two axes, so batched (3-D) and
            # unbatched (2-D) inputs are both supported.
            energies = torch.matmul(Q, K.transpose(-2, -1)) / scale
            att_weights = self.drop(self.softmax(energies))
            head_outputs.append(torch.matmul(att_weights, V))

        embs = self.W0(torch.cat(head_outputs, dim=-1))
        embs = self.relu(embs)
        embs = self.W2(embs)
        return self.LN(embs)




### Frame and Audio Attention

In [29]:

class FrameAndAudioAttention(nn.Module):
    """Per-head video self-attention plus video-to-audio cross-attention.

    For each head:
      * video self-attention: queries, keys and values are projections of ``vid``;
      * cross-attention: queries are the same video projections, keys and values
        are projections of ``aud``. Audio keys are projected to the video head
        width (``vid_input_size // heads``) so they can be compared with the
        video queries.
    Both attentions use ``softmax(Q K^T / sqrt(d_k))`` with
    ``d_k = vid_input_size // heads`` (the shared query/key width), followed by
    dropout (p=0.5) on the attention weights. The concatenated head outputs of
    each modality go through their own ``W0 -> ReLU -> W2 -> LayerNorm`` stack.

    Args:
        vid_input_size: Feature size of the video input (and video output).
        aud_input_size: Feature size of the audio input (and audio output).
        freq: Stored only; not used by this module.
        heads: Number of attention heads.
        pos_enc: Stored only; not used by this module.
    """

    def __init__(self, vid_input_size=1024, aud_input_size=128, freq=10000, heads=1, pos_enc=None):

        super(FrameAndAudioAttention, self).__init__()

        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

        vid_head = vid_input_size // heads
        aud_head = aud_input_size // heads

        ##### Attention for frames
        self.vid_input_size = vid_input_size
        self.vid_output_size = vid_input_size
        self.Wq_v, self.Wk_v, self.Wv_v = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            self.Wk_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head, bias=False))
            self.Wq_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head, bias=False))
            self.Wv_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head, bias=False))

        # W0 consumes the concatenated heads (vid_head * heads features), which
        # only equals vid_input_size when heads divides it.
        self.v_W0 = nn.Linear(in_features=vid_head * heads, out_features=vid_input_size, bias=False)
        self.v_drop = nn.Dropout(p=0.5)
        self.v_W2 = nn.Linear(in_features=vid_input_size, out_features=vid_input_size, bias=False)
        self.v_LN = nn.LayerNorm(normalized_shape=self.v_W2.out_features, eps=1e-6)

        ##### Attention for audio (keys/values from audio, queries from video)
        self.aud_input_size = aud_input_size
        self.aud_output_size = aud_input_size
        self.Wk_a, self.Wv_a = nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            # Keys use the video head width so Q_v @ K_a^T is well-defined.
            self.Wk_a.append(nn.Linear(in_features=aud_input_size, out_features=vid_head, bias=False))
            self.Wv_a.append(nn.Linear(in_features=aud_input_size, out_features=aud_head, bias=False))

        self.a_W0 = nn.Linear(in_features=aud_head * heads, out_features=aud_input_size, bias=False)
        self.a_drop = nn.Dropout(p=0.5)
        self.a_W2 = nn.Linear(in_features=aud_input_size, out_features=aud_input_size, bias=False)
        self.a_LN = nn.LayerNorm(normalized_shape=self.a_W2.out_features, eps=1e-6)

    def forward(self, vid, aud):
        """Attend over video frames and from video frames to audio.

        Args:
            vid: Tensor ``(..., T_v, vid_input_size)``.
            aud: Tensor ``(..., T_a, aud_input_size)`` with the same leading
                dimensions as ``vid``.

        Returns:
            ``(vid_embs, aud_embs)`` of shapes ``(..., T_v, vid_input_size)`` and
            ``(..., T_v, aud_input_size)``. The audio output has one row per video
            frame because its queries come from the video.
        """
        # Both energy matrices are dot products of vid_head-wide vectors
        # (Q_v with K_v, and Q_v with K_a), so both use the same sqrt(d_k).
        scale = math.sqrt(self.vid_output_size // self.heads)
        vid_embs = []
        aud_embs = []

        for head in range(self.heads):
            K_v = self.Wk_v[head](vid)
            Q_v = self.Wq_v[head](vid)
            V_v = self.Wv_v[head](vid)

            K_a = self.Wk_a[head](aud)
            V_a = self.Wv_a[head](aud)

            energies_vid = torch.matmul(Q_v, K_v.transpose(-2, -1)) / scale
            energies_aud = torch.matmul(Q_v, K_a.transpose(-2, -1)) / scale

            vid_att_weights = self.v_drop(self.softmax(energies_vid))
            vid_embs.append(torch.matmul(vid_att_weights, V_v))

            aud_att_weights = self.a_drop(self.softmax(energies_aud))
            aud_embs.append(torch.matmul(aud_att_weights, V_a))

        vid_out = self.v_W0(torch.cat(vid_embs, dim=-1))
        vid_out = self.relu(vid_out)
        vid_out = self.v_W2(vid_out)
        vid_out = self.v_LN(vid_out)

        aud_out = self.a_W0(torch.cat(aud_embs, dim=-1))
        aud_out = self.relu(aud_out)
        aud_out = self.a_W2(aud_out)
        aud_out = self.a_LN(aud_out)

        return vid_out, aud_out


# if __name__ == '__main__':
#     pass

### Hierarchical Multimodal Transformer

In [27]:
class HMT(nn.Module):
    """Hierarchical Multimodal Transformer producing one score per shot.

    Level 1 (per shot): ``FrameAndAudioAttention`` attends over the shot's video
    frames and from those frames to the shot's audio. Each output is averaged
    over time and the two means are concatenated into one shot embedding of
    ``vid_input_size + aud_input_size`` features.
    Level 2 (across shots): ``MultiheadAttention`` attends over the sequence of
    shot embeddings. A linear scorer then gives one logit per shot, and a
    softmax over shots turns the logits into probabilities.

    Args:
        vid_input_size: Video feature size.
        aud_input_size: Audio feature size.
        freq: Forwarded to the attention modules (they do not use it).
        pos_enc: Forwarded to the attention modules (they do not use it).
        num_segments: Currently unused.
        heads: Currently unused; both attention stages are built with 4 heads.
        fusion: Currently unused.
        frame_stride: Divisor that maps shot-boundary indices to feature
            indices (``ceil(end / frame_stride)``). It defaults to 15, the value
            previously hard-coded. Presumably features are sampled every 15
            frames; confirm against the feature extractor.
    """

    def __init__(self, vid_input_size=1024, aud_input_size=128, freq=10000, pos_enc=None, num_segments=None, heads=1, fusion=None, frame_stride=15):

        super(HMT, self).__init__()

        # Shot embeddings are [mean video emb | mean audio emb], so every
        # shot-level layer must be sized from both inputs rather than a fixed 1152.
        fused_size = vid_input_size + aud_input_size
        self.frame_stride = frame_stride

        self.FrameAndAudioAttention = FrameAndAudioAttention(vid_input_size, aud_input_size, freq=freq, pos_enc=pos_enc, heads=4)
        self.Shot_attention = MultiheadAttention(input_size=fused_size, output_size=fused_size, freq=freq, pos_enc=pos_enc, heads=4)
        self.Scorer = nn.Linear(in_features=fused_size, out_features=1)
        self.probs = nn.Softmax(dim=0)

    def forward(self, x, a, shot_boundaries):
        """Score every shot of a single video.

        Args:
            x: Video features ``(batch, T, vid_input_size)``; only ``x[0]`` is used.
            a: Audio features ``(batch, T, aud_input_size)``; only ``a[0]`` is used.
            shot_boundaries: Iterable of pairs whose second element is the shot's
                end index. It is divided by ``frame_stride`` and rounded up to get
                the end feature index. Presumably these are original-frame
                indices; verify against the dataset.

        Returns:
            Softmax-normalized shot scores. The shape is ``(num_shots, 1)``, or
            ``(1,)`` when there is a single shot, because ``squeeze(0)`` only
            drops a size-1 leading dimension.

        NOTE(review): if two consecutive boundaries map to the same feature
        index, that shot slice is empty, its mean is NaN, and the shot-level
        attention spreads the NaN to every score.
        """
        shot_embeddings = []
        start = 0
        for _, end in shot_boundaries:
            end = math.ceil(end / self.frame_stride)

            shot = x[0][start:end].unsqueeze(0)
            audio = a[0][start:end].unsqueeze(0)
            start = end

            sv, sa = self.FrameAndAudioAttention(shot, audio)

            # Average over the time axis, then fuse both modalities: (1, fused_size).
            shot_embeddings.append(torch.cat((torch.mean(sv, 1), torch.mean(sa, 1)), dim=-1))

        fused = torch.cat(shot_embeddings, dim=0)

        E2 = self.Shot_attention(fused.unsqueeze(0))
        P = self.Scorer(E2.squeeze(0))
        return self.probs(P.squeeze(0))

---------------------------

In [17]:
def train(h5_path='../../Preprocessing/extracted_features/normal/TVSum.h5'):
    """Run the global ``model`` once on the first video stored in *h5_path*.

    This relies on a module-level ``model`` (e.g. ``model = HMT()``) defined
    before calling. The loss and optimizer step are currently disabled, so the
    function only prints the feature shapes and the predicted shot scores.

    Args:
        h5_path: HDF5 file with one group per video containing ``features``,
            ``aud_feats``, ``change_points`` and ``gt_probs`` datasets.

    Returns:
        The mean of the collected losses, or ``nan`` when no loss was collected
        (always the case while the loss computation is disabled).
    """
    losses = []
    # Open read-only explicitly: older h5py releases defaulted to append mode ('a').
    with h5py.File(h5_path, 'r') as d:
        key = list(d.keys())[0]
        vid_feats = d[key]['features'][...]
        aud_feats = d[key]['aud_feats'][...]
        boundaries = d[key]['change_points'][...]
        target = d[key]['gt_probs'][...]  # intended target for the (disabled) loss

    # Add a leading batch dimension of 1.
    vid_feats = torch.from_numpy(vid_feats).unsqueeze(0)
    aud_feats = torch.from_numpy(aud_feats).unsqueeze(0)

    print('video and audio feat:', vid_feats.shape, aud_feats.shape)
    P = model(vid_feats, aud_feats, boundaries)
    print(P.reshape(-1))

    # TODO: restore the training step, e.g.
    #   loss = criterion(P, torch.from_numpy(target).float())
    #   optimizer.zero_grad(); loss.backward(); optimizer.step()
    #   losses.append(float(loss))

    # np.mean([]) returns nan but also emits a RuntimeWarning; return nan directly.
    if not losses:
        return np.nan
    return np.mean(np.array(losses))

In [18]:
# model=HMT()
# train()

video and audio feat: torch.Size([1, 706, 1024]) torch.Size([1, 706, 128])
stack shape: torch.Size([1, 184, 1152])
tensor([0.0052, 0.0050, 0.0051, 0.0062, 0.0052, 0.0051, 0.0055, 0.0060, 0.0051,
        0.0055, 0.0054, 0.0057, 0.0057, 0.0057, 0.0053, 0.0052, 0.0054, 0.0057,
        0.0053, 0.0052, 0.0050, 0.0055, 0.0052, 0.0053, 0.0051, 0.0053, 0.0051,
        0.0061, 0.0061, 0.0052, 0.0057, 0.0049, 0.0052, 0.0056, 0.0048, 0.0059,
        0.0056, 0.0057, 0.0058, 0.0053, 0.0057, 0.0057, 0.0053, 0.0050, 0.0047,
        0.0051, 0.0055, 0.0052, 0.0057, 0.0057, 0.0055, 0.0051, 0.0054, 0.0049,
        0.0055, 0.0055, 0.0056, 0.0052, 0.0058, 0.0054, 0.0058, 0.0054, 0.0052,
        0.0051, 0.0053, 0.0053, 0.0054, 0.0054, 0.0050, 0.0058, 0.0057, 0.0049,
        0.0052, 0.0061, 0.0051, 0.0055, 0.0049, 0.0056, 0.0053, 0.0053, 0.0051,
        0.0053, 0.0059, 0.0060, 0.0051, 0.0056, 0.0052, 0.0053, 0.0056, 0.0059,
        0.0058, 0.0055, 0.0049, 0.0051, 0.0053, 0.0053, 0.0054, 0.0063, 0.0054,
     

nan