## Trirarchical Model

**packages**

In [3]:
import numpy as np
import os
import h5py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules import *
from torch.nn.utils.rnn import pad_sequence

import random

In [60]:
def get_frame_probs(shot_probs, cps, n_frames, device):
    """Broadcast per-shot probabilities onto the individual frames of each shot.

    Args:
        shot_probs: one probability per shot; ``len(shot_probs)`` must equal ``len(cps)``.
        cps: array/tensor with one ``[first, last]`` row per shot (inclusive frame indices).
        n_frames: length of the returned frame vector.
        device: torch device on which the result is allocated.

    Returns:
        A float32 tensor of length ``n_frames`` in which frames ``first..last`` of shot i
        hold ``shot_probs[i]`` and frames covered by no shot stay 0. Returns None (after
        printing a message) when the shot and change-point counts differ.
    """
    n_shots, n_segments = len(shot_probs), len(cps)
    if n_shots != n_segments:
        print('no. of shots does not match:', n_shots, n_segments)
        return None

    frame_probs = torch.zeros(n_frames, dtype=torch.float32, device=device)
    for seg_idx in range(cps.shape[0]):
        first, last = (int(bound.item()) for bound in cps[seg_idx])
        # `last` is inclusive, hence the +1 on the slice end.
        frame_probs[first:last + 1] = shot_probs[seg_idx]
    return frame_probs

#### Multi-Head Attention

In [4]:

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention followed by a projection head.

    Each head projects ``x`` with its own bias-free Q/K/V linear maps to
    ``output_size // heads`` features and attends with dropout (p=0.5) on the attention
    weights. The concatenated head outputs go through W0 -> ReLU -> W2 -> LayerNorm,
    so the result has ``input_size`` features.

    Args:
        input_size: feature size of ``x`` (its last dimension) and of the output.
        output_size: total Q/K/V width across heads (each head gets ``output_size // heads``).
        freq: frequency base kept for the (disabled) positional-encoding code.
        heads: number of attention heads.
        pos_enc: stored only; positional encoding is not applied in ``forward``.
    """

    def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
        super(MultiheadAttention, self).__init__()

#         self.permitted_encodings = ["absolute", "relative"]
#         if pos_enc is not None:
#             pos_enc = pos_enc.lower()
#             assert pos_enc in self.permitted_encodings, f"Supported encodings: {*self.permitted_encodings,}"

        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.input_size = input_size
        self.output_size = output_size
        head_size = output_size // heads
        self.Wq, self.Wk, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()

        for _ in range(self.heads):
            self.Wk.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wq.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wv.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))

        # W0 consumes the concatenated heads, which have head_size * heads features.
        # This equals input_size only when output_size == input_size and heads divides it.
        self.W0 = nn.Linear(in_features=head_size * heads, out_features=input_size, bias=False)
        self.drop = nn.Dropout(p=0.5)
        self.W2 = nn.Linear(in_features=input_size, out_features=input_size, bias=False)
        self.LN = nn.LayerNorm(normalized_shape=self.W2.out_features, eps=1e-6)


########################################## Positional Encoding Code ###############################
# NOTE(review): the disabled code below refers to self.out, which this class never defines;
# it will need a device source (e.g. self.W0.weight.device) before it can be re-enabled.
#     def getAbsolutePosition(self, T):
#         """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
#         Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)
#         :param int T: Number of frames contained in Q, K and V
#         :return: Tensor with shape [T, T]
#         """
#         freq = self.freq
#         d = self.input_size

#         pos = torch.tensor([k for k in range(T)], device=self.out.weight.device)
#         i = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)

#         # Reshape tensors each pos_k for each i indices
#         pos = pos.reshape(pos.shape[0], 1)
#         pos = pos.repeat_interleave(i.shape[0], dim=1)
#         i = i.repeat(pos.shape[0], 1)

#         AP = torch.zeros(T, T, device=self.out.weight.device)
#         AP[pos, 2*i] = torch.sin(pos / freq ** ((2 * i) / d))
#         AP[pos, 2*i+1] = torch.cos(pos / freq ** ((2 * i) / d))
#         return AP

#     def getRelativePosition(self, T):
#         """Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
#         r_pos calculations as here: https://theaisummer.com/positional-embeddings/
#         :param int T: Number of frames contained in Q, K and V
#         :return: Tensor with shape [T, T]
#         """
#         freq = self.freq
#         d = 2 * T
#         min_rpos = -(T - 1)

#         i = torch.tensor([k for k in range(T)], device=self.out.weight.device)
#         j = torch.tensor([k for k in range(T)], device=self.out.weight.device)

#         # Reshape tensors each i for each j indices
#         i = i.reshape(i.shape[0], 1)
#         i = i.repeat_interleave(i.shape[0], dim=1)
#         j = j.repeat(i.shape[0], 1)

#         # Calculate the relative positions
#         r_pos = j - i - min_rpos

#         RP = torch.zeros(T, T, device=self.out.weight.device)
#         idx = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)
#         RP[:, 2*idx] = torch.sin(r_pos[:, 2*idx] / freq ** ((i[:, 2*idx] + j[:, 2*idx]) / d))
#         RP[:, 2*idx+1] = torch.cos(r_pos[:, 2*idx+1] / freq ** ((i[:, 2*idx+1] + j[:, 2*idx+1]) / d))
#         return RP
###################################################################################################

    def forward(self, x, batch_mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor whose last two dims are (T, input_size); a leading batch dim is optional.
            batch_mask: optional tensor broadcastable to (..., T, T). Entries equal to 1
                add -1e10 to the matching attention energies, which effectively removes
                those key positions. Entries equal to 0 leave the energies unchanged.

        Returns:
            Tensor with the same leading dims as ``x`` and ``input_size`` features.
        """
        # 1/sqrt(d_k) scaling of scaled dot-product attention (d_k = per-head size).
        scale = np.sqrt(self.output_size // self.heads)
        embs = []

        for head in range(self.heads):
            K = self.Wk[head](x)
            Q = self.Wq[head](x)
            V = self.Wv[head](x)

            energies = torch.matmul(Q, K.transpose(-2, -1)) / scale

            ######### Positional Encoding skipped ##############
            # (see the disabled getAbsolutePosition / getRelativePosition code above)

            if batch_mask is not None:
                energies = energies + batch_mask * -1e10
            att_weights = self.drop(self.softmax(energies))

            # Save the current head output
            embs.append(torch.matmul(att_weights, V))

        embs = self.W0(torch.cat(embs, dim=-1))
        embs = self.relu(embs)
        embs = self.W2(embs)
        embs = self.LN(embs)

        return embs




### Frame and Audio Attention

In [5]:

class FrameAndAudioAttention(nn.Module):
    """Frame (video) self-attention plus video-queried attention over audio.

    For every head, queries come from the video features. The video stream uses video
    keys and values. The audio stream uses audio keys, projected into the video query
    space (``vid_input_size // heads``), and audio values projected to
    ``aud_input_size // heads``. Each stream's concatenated head outputs go through
    W0 -> ReLU -> W2 -> LayerNorm. The video output has ``vid_input_size`` features and
    the audio output has ``aud_input_size`` features.

    Args:
        vid_input_size: feature size of ``vid`` (its last dimension).
        aud_input_size: feature size of ``aud`` (its last dimension).
        freq, pos_enc: stored only; no positional encoding is applied.
        heads: number of attention heads.
    """

    def __init__(self, vid_input_size=1024, aud_input_size=128, freq=10000, heads=1, pos_enc=None):

        super(FrameAndAudioAttention, self).__init__()

        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

        vid_head_size = vid_input_size // heads
        aud_head_size = aud_input_size // heads

        ##### Attention for Frames
        self.vid_input_size = vid_input_size
        self.vid_output_size = vid_input_size
        self.Wq_v, self.Wk_v, self.Wv_v = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            self.Wk_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head_size, bias=False))
            self.Wq_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head_size, bias=False))
            self.Wv_v.append(nn.Linear(in_features=vid_input_size, out_features=vid_head_size, bias=False))

        # W0 consumes the concatenated heads (vid_head_size * heads features).
        self.v_W0 = nn.Linear(in_features=vid_head_size * heads, out_features=vid_input_size, bias=False)
        self.v_drop = nn.Dropout(p=0.5)
        self.v_W2 = nn.Linear(in_features=vid_input_size, out_features=vid_input_size, bias=False)
        self.v_LN = nn.LayerNorm(normalized_shape=self.v_W2.out_features, eps=1e-6)


        ##### Attention for Audios
        self.aud_input_size = aud_input_size
        self.aud_output_size = aud_input_size
        self.Wk_a, self.Wv_a = nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            # Audio keys live in the video query space so they can be matched against Q_v.
            self.Wk_a.append(nn.Linear(in_features=aud_input_size, out_features=vid_head_size, bias=False))
            self.Wv_a.append(nn.Linear(in_features=aud_input_size, out_features=aud_head_size, bias=False))

        self.a_W0 = nn.Linear(in_features=aud_head_size * heads, out_features=aud_input_size, bias=False)
        self.a_drop = nn.Dropout(p=0.5)
        self.a_W2 = nn.Linear(in_features=aud_input_size, out_features=aud_input_size, bias=False)
        self.a_LN = nn.LayerNorm(normalized_shape=self.a_W2.out_features, eps=1e-6)

    def forward(self, vid, aud, batch_mask=None):
        """Attend over video frames and audio with video queries.

        Args:
            vid: tensor whose last two dims are (T, vid_input_size).
            aud: tensor whose last two dims are (T_a, aud_input_size).
            batch_mask: optional mask added as ``mask * -1e10`` to both the video and the
                audio energies (1 = ignore position). Because the same mask is used for
                both streams, it must broadcast to both energy shapes. The caller in this
                notebook slices video and audio to the same length — confirm for other callers.

        Returns:
            ``(vid_embs, aud_embs)`` with ``vid_input_size`` and ``aud_input_size`` features.
        """
        # Both energy tensors are dot products of vid_head_size-dim vectors (Q_v with K_v
        # or K_a), so both use the same 1/sqrt(d_k) scale.
        scale = np.sqrt(self.vid_output_size // self.heads)

        vid_embs = []
        aud_embs = []

        for head in range(self.heads):
            K_v = self.Wk_v[head](vid)
            Q_v = self.Wq_v[head](vid)
            V_v = self.Wv_v[head](vid)

            K_a = self.Wk_a[head](aud)
            V_a = self.Wv_a[head](aud)

            energies_vid = torch.matmul(Q_v, K_v.transpose(-2, -1)) / scale
            energies_aud = torch.matmul(Q_v, K_a.transpose(-2, -1)) / scale

            if batch_mask is not None:
                energies_vid = energies_vid + batch_mask * -1e10
                energies_aud = energies_aud + batch_mask * -1e10

            vid_att_weights = self.v_drop(self.softmax(energies_vid))
            vid_emb = torch.matmul(vid_att_weights, V_v)

            aud_att_weights = self.a_drop(self.softmax(energies_aud))
            aud_emb = torch.matmul(aud_att_weights, V_a)

            # Save the current head output
            vid_embs.append(vid_emb)
            aud_embs.append(aud_emb)

        vid_embs = self.v_W0(torch.cat(vid_embs, dim=-1))
        vid_embs = self.relu(vid_embs)
        vid_embs = self.v_W2(vid_embs)
        vid_embs = self.v_LN(vid_embs)

        aud_embs = self.a_W0(torch.cat(aud_embs, dim=-1))
        aud_embs = self.relu(aud_embs)
        aud_embs = self.a_W2(aud_embs)
        aud_embs = self.a_LN(aud_embs)

        return vid_embs, aud_embs


In [6]:
class AttentionAwareFusion(nn.Module):
    """Attention-weighted pooling over dimension 1 of a 3-D tensor.

    A Linear -> GELU -> Linear MLP scores every (position, feature) entry. The scores
    are softmax-normalised over dimension 1 (presumably time — confirm with callers)
    independently per feature, and the input is summed over that dimension with those
    weights.

    ``heads`` and ``pos_enc`` are accepted for signature compatibility and are unused.
    """

    def __init__(self, input_size, heads=1, pos_enc=None):
        super(AttentionAwareFusion, self).__init__()
        self.W1 = nn.Linear(input_size, input_size, bias=True)
        self.gelu = nn.GELU()
        self.W2 = nn.Linear(input_size, input_size, bias=True)
        self.sm = nn.Softmax(dim=-1)

    def forward(self, x):
        """Pool ``x`` (3-D, last dim = input_size) over dim 1; returns ``x.shape[0], x.shape[2]``."""
        scores = self.W2(self.gelu(self.W1(x)))
        # self.sm normalises the last axis, so swap dims 1 and 2 around the softmax.
        weights = self.sm(scores.transpose(1, 2)).transpose(1, 2)
        return torch.sum(weights * x, dim=1)
              

### Trirarchical Multimodal Transformer

In [85]:
class Trirar(nn.Module):
    """Three-level (frame -> shot -> scene) multimodal model that scores shots.

    1. The frames of each shot go through FrameAndAudioAttention. Video and audio
       outputs are averaged over the shot's real frames and concatenated into one
       embedding per shot.
    2. Shot embeddings are contextualised by Shot_attention and scored by Scorer.
    3. The shots of each scene go through ShotsInteraction and are averaged into one
       embedding per scene. Scene embeddings are contextualised by Scene_attention
       and scored by SceneScorer.
    4. Each scene score is added to the scores of the shots it contains, and a softmax
       over shots is returned.

    Args:
        vid_input_size: per-frame visual feature size.
        aud_input_size: per-frame audio feature size.
        freq, pos_enc: forwarded to the attention sub-modules.
        num_segments, heads, fusion: accepted for interface compatibility but unused;
            every attention sub-module is built with 4 heads.
    """

    # Scenes spanning more shots than this are encoded on their own; shorter scenes
    # are padded and batched together.
    MAX_BATCHED_SCENE_LEN = 5

    def __init__(self, vid_input_size=1024, aud_input_size=128, freq=10000, pos_enc=None, num_segments=None, heads=1, fusion=None):

        super(Trirar, self).__init__()
        # Shot and scene embeddings are the concatenation of video and audio embeddings.
        joint_size = vid_input_size + aud_input_size

        self.FrameAndAudioAttention = FrameAndAudioAttention(vid_input_size, aud_input_size, freq=freq, pos_enc=pos_enc, heads=4)
        self.Shot_attention = MultiheadAttention(input_size=joint_size, output_size=joint_size, freq=freq, pos_enc=pos_enc, heads=4)
        self.Scorer = nn.Linear(in_features=joint_size, out_features=1)
        self.probs = nn.Softmax(dim=0)

        self.ShotsInteraction = MultiheadAttention(input_size=joint_size, output_size=joint_size, freq=freq, pos_enc=pos_enc, heads=4)  # could incorporate text features
        self.Scene_attention = MultiheadAttention(input_size=joint_size, output_size=joint_size, freq=freq, pos_enc=pos_enc, heads=4)
        self.SceneScorer = nn.Linear(in_features=joint_size, out_features=1)

        ################# Global Attention ############
        # self.glob_attention = SelfAttention(input_size=input_size, output_size=output_size,
        #                                freq=freq, pos_enc=pos_enc, heads=heads)

    @staticmethod
    def _padding_masks(lengths, pad_length, device):
        """Stack one (pad_length, pad_length) mask per sequence: 0 where row and column are real, 1 otherwise."""
        masks = torch.ones(len(lengths), pad_length, pad_length, device=device)
        for row, n in enumerate(lengths):
            masks[row, :n, :n] = 0
        return masks

    @staticmethod
    def _masked_mean(emb, lengths):
        """Average ``emb`` (B, T, D) over the first ``lengths[b]`` steps of each row -> (B, D)."""
        lengths = torch.as_tensor(lengths, device=emb.device)
        steps = torch.arange(emb.shape[1], device=emb.device)
        valid = (steps.unsqueeze(0) < lengths.unsqueeze(1)).to(emb.dtype)
        total = (emb * valid.unsqueeze(-1)).sum(dim=1)
        return total / lengths.to(emb.dtype).unsqueeze(1)

    def _encode_shot_batch(self, shots, audios, device):
        """Encode a list of per-shot frame tensors into (len(shots), vid+aud) embeddings."""
        lengths = [shot.shape[0] for shot in shots]
        shot_batch = pad_sequence(shots, batch_first=True)
        aud_batch = pad_sequence(audios, batch_first=True)
        masks = self._padding_masks(lengths, shot_batch.shape[1], device)
        sv, sa = self.FrameAndAudioAttention(shot_batch, aud_batch, masks)
        # Padded query rows still produce outputs, so average over real frames only;
        # otherwise an embedding would depend on how much its batch was padded.
        return torch.hstack((self._masked_mean(sv, lengths), self._masked_mean(sa, lengths)))

    def _encode_shots(self, x, a, shot_boundaries, batch_size, device):
        """Return shot embeddings of shape (1, n_shots, vid+aud), batching ``batch_size`` shots at a time."""
        shot_embeddings = []
        shots, audios = [], []
        n_shots = len(shot_boundaries)
        for shot_idx, (start, end) in enumerate(shot_boundaries):
            # Boundaries are inclusive frame indices.
            shots.append(x[0][start:end + 1])
            audios.append(a[0][start:end + 1])
            if len(shots) != batch_size and shot_idx != n_shots - 1:
                continue
            shot_embeddings.append(self._encode_shot_batch(shots, audios, device))
            shots, audios = [], []
        return torch.cat(shot_embeddings, dim=0).unsqueeze(0)

    def _encode_scene_batch(self, scenes, device):
        """Encode a list of short scenes (each (n_shots_i, D)) into (len(scenes), D) embeddings."""
        lengths = [scene.shape[0] for scene in scenes]
        scene_batch = pad_sequence(scenes, batch_first=True)
        masks = self._padding_masks(lengths, scene_batch.shape[1], device)
        sc = self.ShotsInteraction(scene_batch, masks)
        return self._masked_mean(sc, lengths)

    def _encode_scenes(self, shot_embeddings, scene_boundaries, batch_size, device):
        """Return scene embeddings of shape (1, n_scenes, D) in the order of ``scene_boundaries``."""
        scene_embeddings = []
        pending = []
        n_scenes = len(scene_boundaries)
        for scene_idx, (start, end) in enumerate(scene_boundaries):
            # Boundaries are inclusive shot indices.
            scene = shot_embeddings[0][start:end + 1]
            if end + 1 - start > self.MAX_BATCHED_SCENE_LEN:
                # Flush queued short scenes first so the output order follows the scenes.
                if pending:
                    scene_embeddings.append(self._encode_scene_batch(pending, device))
                    pending = []
                sc = self.ShotsInteraction(scene.unsqueeze(0))
                scene_embeddings.append(torch.mean(sc, 1))
                continue

            pending.append(scene)
            if len(pending) != batch_size and scene_idx != n_scenes - 1:
                continue
            scene_embeddings.append(self._encode_scene_batch(pending, device))
            pending = []
        return torch.cat(scene_embeddings, dim=0).unsqueeze(0)

    def forward(self, x, a, shot_boundaries, scene_boundaries, batch_size=5):
        """Score every shot.

        Args:
            x: frame-level video features. Only ``x[0]`` is used (presumably batch size 1 —
                confirm with callers) and is sliced along its first axis by frame index.
            a: frame-level audio features, indexed like ``x``.
            shot_boundaries: sequence of inclusive ``(start, end)`` frame indices, one per shot.
            scene_boundaries: sequence of inclusive ``(start, end)`` shot indices, one per scene.
            batch_size: number of shots (or short scenes) padded and encoded together.

        Returns:
            Softmax over shots of the summed shot and scene scores. The shape is
            (n_shots, 1), or (1,) when there is a single shot.
        """
        # weighted_value, attn_weights = self.glob_attention(x)  # global attention
        device = next(self.parameters()).device

        shot_embeddings = self._encode_shots(x, a, shot_boundaries, batch_size, device)
        shot_scores = self.Scorer(self.Shot_attention(shot_embeddings).squeeze(0))

        scene_embeddings = self._encode_scenes(shot_embeddings, scene_boundaries, batch_size, device)
        scene_scores = self.SceneScorer(self.Scene_attention(scene_embeddings).squeeze(0))

        # Every shot also receives the score of the scene that contains it.
        for scene_idx, (start, end) in enumerate(scene_boundaries):
            shot_scores[start:end + 1] += scene_scores[scene_idx]

        return self.probs(shot_scores.squeeze(0))

### Trirar 2

In [82]:
# class Trirar2(nn.Module):
#     def __init__(self, vid_input_size=1024, aud_input_size=128, freq=10000, pos_enc=None, num_segments=None, heads=1, fusion=None):

#         super(Trirar2, self).__init__()
        
#         self.FrameAndAudioAttention = FrameAndAudioAttention(vid_input_size, aud_input_size, freq=freq, pos_enc=pos_enc, heads=4)
#         self.shotFA = AttentionAwareFusion(1024)
#         self.audioFA = AttentionAwareFusion(128)
#         self.Shot_attention = nn.TransformerEncoderLayer(d_model=1024+128, nhead=4, batch_first=True)
#         self.Scorer = nn.Linear(in_features=1024+128, out_features=1)
        
#         self.ShotsInteraction = MultiheadAttention(input_size=1024+128, output_size=1024+128, freq=freq, pos_enc=pos_enc, heads=4)
#         self.sceneFA = AttentionAwareFusion(1024+128)
#         self.Scene_attention = nn.TransformerEncoderLayer(d_model=1024+128, nhead=4, batch_first=True)
#         self.SceneScorer = nn.Linear(in_features=1024+128, out_features=1)
        
#         self.probs = nn.Softmax(dim=0)

       

#         ################# Global Attention ############
#         # self.glob_attention = SelfAttention(input_size=input_size, output_size=output_size,
#         #                                freq=freq, pos_enc=pos_enc, heads=heads)

        
#     def forward(self, x, a, shot_boundaries, scene_boundaries, batch_size=5):
#         shot_batch=[]
#         aud_batch=[]
#         batch_mask=[]
#         bs_counter=0
#         # weighted_value, attn_weights = self.glob_attention(x)  # global attention
        
#         Shot_embeddings_list=[]
#         boundaries_num = len(shot_boundaries)
#         start=0
#         for shotnum, (_, end) in enumerate(shot_boundaries):
#             end=math.ceil(end/15)
            
#             shot=x[0][start:end]
#             audio= a[0][start:end]
#             batch_mask.append(torch.zeros(end-start, end-start))
#             start = end
            
#             shot_batch.append(shot)
#             aud_batch.append(audio)
            
#             bs_counter+=1
#             if bs_counter != batch_size and shotnum!=boundaries_num-1:
#                 continue
                
#             bs_counter=0
            
#             shot_batch = pad_sequence(shot_batch, batch_first=True)
#             aud_batch = pad_sequence(aud_batch, batch_first=True)
            
#             pad_length = shot_batch[0].shape[0]
#             batch_mask = torch.stack([ nn.ConstantPad2d((0, pad_length-mask.shape[0], 0, pad_length-mask.shape[0]),1)(mask) for mask in batch_mask])
#             sv, sa = self.FrameAndAudioAttention(shot_batch, aud_batch, batch_mask.to(next(self.parameters()).device)) 
#             shot_emb = self.shotFA(sv)
#             audio_emb = self.audioFA(sa)
            
#             Shot_embeddings_list.append(torch.hstack((shot_emb,audio_emb)))
            
#             del(shot_batch)
#             del(aud_batch)
#             shot_batch=[]
#             aud_batch=[]
#             del(batch_mask)
#             batch_mask=[]
            
            
#         Shot_embeddings = torch.cat(Shot_embeddings_list, dim=0).unsqueeze(0)
        
#         E2 = self.Shot_attention(Shot_embeddings)
        
#         P = self.Scorer(E2.squeeze(0))         ##  ???
        
#         Scene_embeddings_list = []
#         batch_mask=[]
#         bs_counter=0
#         scene_batch = []
        
#         boundaries_num = len(scene_boundaries)
#         for scnnum, (start, end) in enumerate(scene_boundaries):
            
#             scene=Shot_embeddings[0][start:end+1]
#             if end+1-start >5:
                
#                 if len(scene_batch) !=0:
                    
#                     scene_batch = pad_sequence(scene_batch, batch_first=True)
#                     pad_length = scene_batch[0].shape[0]
#                     batch_mask = torch.stack([ nn.ConstantPad2d((0, pad_length-mask.shape[0], 0, pad_length-mask.shape[0]),1)(mask) for mask in batch_mask])
                    
#                     sc = self.ShotsInteraction(scene_batch, batch_mask.to(next(self.parameters()).device)) 
#                     scene_emb = torch.mean(sc, 1)
#                     Scene_embeddings_list.append(scene_emb)
                    
#                     del(batch_mask)
#                     del(scene_batch)
#                     batch_mask=[]
#                     bs_counter=0
#                     scene_batch = []
                    
#                 sc = self.ShotsInteraction(scene.unsqueeze(0))
#                 scene_emb = self.sceneFA(sc)
#                 Scene_embeddings_list.append(scene_emb)
                
#             else:
#                 scene_batch.append(scene)
#                 batch_mask.append(torch.zeros(end+1-start, end+1-start))
#                 bs_counter+=1
#                 if bs_counter != batch_size and scnnum!=boundaries_num-1:
#                     continue

#                 scene_batch = pad_sequence(scene_batch, batch_first=True)
#                 pad_length = scene_batch[0].shape[0]
#                 batch_mask = torch.stack([ nn.ConstantPad2d((0, pad_length-mask.shape[0], 0, pad_length-mask.shape[0]),1)(mask) for mask in batch_mask])
                
#                 sc = self.ShotsInteraction(scene_batch, batch_mask.to(next(self.parameters()).device)) 
#                 scene_emb = torch.mean(sc, 1)
#                 Scene_embeddings_list.append(scene_emb)
                
#                 del(batch_mask)
#                 del(scene_batch)
#                 batch_mask=[]
#                 bs_counter=0
#                 scene_batch = []
            
        
#         Scene_embeddings = torch.cat(Scene_embeddings_list, dim=0).unsqueeze(0)
        
#         E3 = self.Scene_attention(Scene_embeddings)
#         S = self.SceneScorer(E3.squeeze(0))
        
#         for scn, (start, end) in enumerate(scene_boundaries):
#             P[start:end+1]+= S[scn]
        
#         P = self.probs(P.squeeze(0))
        
       
#         return P

---------------------------

### Train check

In [86]:
# def train():
#         losses=[]
#         criterion = nn.MSELoss()
        
#         with h5py.File('../../Preprocessing/extracted_features/normal/TVSum05s.h5') as d, h5py.File('../../Segmentation/Transnet/transnet_segments/tvsumSegs05s.h5') as scnseg:
#             # key = list(d.keys())[0]
#             for key in d.keys():
#                 print(key)
#                 vid_feats= d[key]['features'][...]
#                 aud_feats= d[key]['aud_feats'][...]
#                 boundaries = d[key]['fchange_points'][...]
#                 target = d[key]['gt_probs'][...]
#                 scn_boundaries = scnseg[key]['scene_points'][...]
#                 n_frames = d[key]['n_frames'][()]
#                 target = target.astype(float)
                
#                 # seq = dataset['features'][...]
#                 vid_feats = torch.from_numpy(vid_feats).unsqueeze(0)
#                 aud_feats = torch.from_numpy(aud_feats).unsqueeze(0)
#                 # target = d['gt_score'][...]
#                 target = torch.from_numpy(target)

#                 # Min-Max Normalize frame scores
#                 target -= target.min()
#                 target /= target.max()


#                 # if self.hps.use_cuda:
#                 #     seq, target = seq.float().cuda(), target.float().cuda()

#                 seq_len = vid_feats.shape[1]
#                 # print('video and audio feat:', vid_feats, aud_feats, boundaries, scn_boundaries )
#                 P = model(vid_feats, aud_feats, boundaries, scn_boundaries)
#                 P=P.reshape(-1)
#                 P_frames = get_frame_probs(P, boundaries, n_frames, 'cpu')
#                 # print(P.shape, len(boundaries), P)

#                 loss_att = 0
#                 print(P_frames.shape, target.shape)
#                 loss = criterion(P_frames[:len(target)], target.float())
#                 loss = loss + loss_att
#                 # optimizer.zero_grad()
#                 # loss.backward()
#                 # optimizer.step()
#                 losses.append(float(loss))
            
#         return np.mean(np.array(losses))

In [87]:
# torch.manual_seed(1)
# model=Trirar()
# parameters = filter(lambda p: p.requires_grad, model.parameters())
# optimizer = torch.optim.Adam(parameters, lr=0.00005, weight_decay=0.00001)
# train()

In [35]:
#     def forward(self, vid, aud):
        
#         vid_embs = []
#         aud_embs = []
        
#         for head in range(self.heads):
#             K_v = self.Wk_v[head](vid)
#             Q_v = self.Wq_v[head](vid)
#             V_v = self.Wv_v[head](vid)
            
#             K_a = self.Wk_a[head](aud)
#             V_a = self.Wv_a[head](aud)
            
            
#             energies_vid = torch.matmul(Q_v, K_v.transpose(1, 2))/np.sqrt(self.vid_output_size//self.heads)
#             energies_aud = torch.matmul(Q_v, K_a.transpose(1, 2))/np.sqrt(self.aud_output_size//self.heads)
            
#             print('mask_check:', energies_aud[0],'\n softmaxed', self.softmax(energies_aud)[0])
#             break
            
#             vid_att = self.softmax(energies_vid)
#             vid_att_weights = self.v_drop(vid_att)
#             vid_emb = torch.matmul(vid_att_weights, V_v)
            
#             aud_att = self.softmax(energies_aud)
#             aud_att_weights = self.a_drop(aud_att)
#             aud_emb = torch.matmul(aud_att_weights, V_a)
            
            
#             # Save the current head output
#             vid_embs.append(vid_emb)
#             aud_embs.append(aud_emb)
            
        
#         vid_embs = self.v_W0(torch.cat(vid_embs, dim=-1))
        
#         vid_embs = self.relu(vid_embs)
#         vid_embs =self.v_W2(vid_embs)
#         vid_embs =self.v_LN(vid_embs)
        
#         aud_embs = self.a_W0(torch.cat(aud_embs, dim=-1))
#         aud_embs = self.relu(aud_embs)
#         aud_embs =self.a_W2(aud_embs)
#         aud_embs =self.a_LN(aud_embs)
        
        
#         return vid_embs, aud_embs



#     def forward(self, x, a, shot_boundaries):
        
#         # weighted_value, attn_weights = self.glob_attention(x)  # global attention
        
#         Shot_embeddings_list=[]
#         start=0
#         for _, end in shot_boundaries:
#             end=math.ceil(end/15)
            
#             shot=x[0][start:end].unsqueeze(0)
#             audio= a[0][start:end].unsqueeze(0)
            
#             # print('star and stop', start,end)
#             start = end
            
            
#             sv, sa = self.FrameAndAudioAttention(shot, audio)
#             print('sa:', sa)
            
#             shot_emb = torch.mean(sv,1)
#             audio_emb = torch.mean(sa,1)
            
#             Shot_embeddings_list.append(torch.hstack((shot_emb,audio_emb)))
#             break
            
#         Shot_embeddings = torch.cat(Shot_embeddings_list, dim=0)
        
#         E2 = self.Shot_attention(Shot_embeddings.unsqueeze(0))
#         P = self.Scorer(E2.squeeze(0))
#         P = self.probs(P.squeeze(0))
            
#         return P