In [109]:
import torch
import copy
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn import Dropout
from torch.nn import Linear
from torch.nn import LayerNorm



class Encoder(Module):
    r"""Encoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(Encoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src):
        r"""Pass the input through the endocder layers in turn.

        """
        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output)

        if self.norm:
            output = self.norm(output)

        return output


class Decoder(Module):
    r"""Decoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the DecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(Decoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory):
        r"""Pass the inputs (and mask) through the decoder layer in turn.
        """
        output = tgt

        for i in range(self.num_layers):
            output = self.layers[i](output, memory)

        if self.norm:
            output = self.norm(output)

        return output


class EncoderLayer(Module):
    r"""EncoderLayer is mainly made up of self-attention.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src):
        r"""Pass the input through the endocder layer.
        """
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        if hasattr(self, "activation"):
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        else:  # for backward compatibility
            src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class DecoderLayer(Module):
    r"""DecoderLayer is mainly made up of the proposed cross-modal relation attention (CMRA).

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory):
        r"""Pass the inputs (and mask) through the decoder layer.
        """
        memory = torch.cat([memory, tgt], dim=0)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility
            tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)


class New_Audio_Guided_Attention(nn.Module):
    def __init__(self):
        super(New_Audio_Guided_Attention, self).__init__()
        self.hidden_size = 1024
        self.relu = nn.ReLU()
        # channel attention
        self.affine_video_1 = nn.Linear(1024, 1024)
        self.affine_audio_1 = nn.Linear(128, 1024)
        self.affine_bottleneck = nn.Linear(1024, 256)
        self.affine_v_c_att = nn.Linear(256, 1024)
        # spatial attention
        self.affine_video_2 = nn.Linear(1024, 256)
        self.affine_audio_2 = nn.Linear(128, 256)
        self.affine_v_s_att = nn.Linear(256, 1)

        # video-guided audio attention
        self.affine_video_guided_1 = nn.Linear(1024, 64)
        self.affine_video_guided_2 = nn.Linear(64, 128)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)


    def forward(self, video, audio):
        '''
        :param visual_feature: [batch, 10, 7, 7, 512]
        :param audio_feature:  [batch, 10, 128]
        :return: [batch, 10, 512]
        '''
        audio = audio.transpose(1, 0)
        batch, t_size, f, v_dim = video.size()
        a_dim = audio.size(-1)
        audio_feature = audio.reshape(batch * t_size, a_dim)
        visual_feature = video.reshape(batch, t_size, f)
        raw_visual_feature = video
        print(visual_feature.shape)
        print(audio_feature.shape)


        audio_query_1 = self.relu(self.affine_audio_1(audio_feature))
        # ============================== Channel Attention ====================================

        video_query_1 = self.relu(self.affine_video_1(visual_feature))

        video_query_1 = video_query_1.reshape(batch*t_size, f)
        print(audio_query_1.shape)
        print(video_query_1.shape)

        audio_video_query_raw = (audio_query_1 * video_query_1)
        audio_video_query = self.relu(self.affine_bottleneck(audio_video_query_raw))
        channel_att_maps = self.affine_v_c_att(audio_video_query).sigmoid().reshape(batch, t_size, -1,v_dim)
        print(channel_att_maps.shape)
        c_att_visual_feat = (raw_visual_feature * (channel_att_maps))

        # ============================== Spatial Attention =====================================
        # channel attended visual feature: [batch * 10, 49, v_dim]
        c_att_visual_feat = c_att_visual_feat.reshape(batch*t_size, v_dim, -1)
        c_att_visual_query = self.relu(self.affine_video_2(c_att_visual_feat))
        audio_query_2 = self.relu(self.affine_audio_2(audio_feature)).unsqueeze(-2)
        audio_video_query_2 = c_att_visual_query * audio_query_2
        spatial_att_maps = self.softmax(self.tanh(self.affine_v_s_att(audio_video_query_2)).transpose(2, 1))
        c_s_att_visual_feat = torch.bmm(spatial_att_maps, c_att_visual_feat).squeeze().reshape(batch, t_size, f)

        return c_s_att_visual_feat

In [37]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention


class InternalTemporalRelationModule(nn.Module):
    def __init__(self, input_dim, d_model):
        super(InternalTemporalRelationModule, self).__init__()
        self.encoder_layer = EncoderLayer(d_model=d_model, nhead=4)
        self.encoder = Encoder(self.encoder_layer, num_layers=2)

        self.affine_matrix = nn.Linear(input_dim, d_model)
        self.relu = nn.ReLU(inplace=True)
        # add relu here?

    def forward(self, feature):
        # feature: [seq_len, batch, dim]
        feature = self.affine_matrix(feature)
        feature = self.encoder(feature)

        return feature


class CrossModalRelationAttModule(nn.Module):
    def __init__(self, input_dim, d_model):
        super(CrossModalRelationAttModule, self).__init__()

        self.decoder_layer = DecoderLayer(d_model=d_model, nhead=4)
        self.decoder = Decoder(self.decoder_layer, num_layers=1)

        self.affine_matrix = nn.Linear(input_dim, d_model)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, query_feature, memory_feature):
        query_feature = self.affine_matrix(query_feature)
        output = self.decoder(query_feature, memory_feature)

        return output


class WeaklyLocalizationModule(nn.Module):
    def __init__(self, input_dim):
        super(WeaklyLocalizationModule, self).__init__()

        self.hidden_dim = input_dim # need to equal d_model
        self.classifier = nn.Linear(self.hidden_dim, 1) # start and end
        self.event_classifier = nn.Linear(self.hidden_dim, 29)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, fused_content):
        fused_content = fused_content.transpose(0, 1)
        max_fused_content, _ = fused_content.max(1)
        # confident scores
        is_event_scores = self.classifier(fused_content)
        # classification scores
        raw_logits = self.event_classifier(max_fused_content)[:, None, :]
        # fused
        fused_logits = is_event_scores.sigmoid() * raw_logits
        # Training: max pooling for adapting labels
        logits, _ = torch.max(fused_logits, dim=1)
        event_scores = self.softmax(logits)

        return is_event_scores.squeeze(), raw_logits.squeeze(), event_scores


class SupvLocalizeModule(nn.Module):
    def __init__(self, d_model):
        super(SupvLocalizeModule, self).__init__()
        # self.affine_concat = nn.Linear(2*256, 256)
        self.relu = nn.ReLU(inplace=True)
        self.classifier = nn.Linear(d_model, 1) # start and end
        self.event_classifier = nn.Linear(d_model, 28)
       # self.softmax = nn.Softmax(dim=-1)

    def forward(self, fused_content):

        max_fused_content, _ = fused_content.transpose(1, 0).max(1)
        logits = self.classifier(fused_content)
        # scores = self.softmax(logits)
        class_logits = self.event_classifier(max_fused_content)
        class_scores = class_logits

        return logits, class_scores


# class AudioVideoInter(nn.Module):
#     def __init__(self, d_model, n_head, head_dropout=0.1):
#         super(AudioVideoInter, self).__init__()
#         self.dropout = nn.Dropout(0.1)
#         self.video_multihead = MultiheadAttention(d_model, num_heads=n_head, dropout=head_dropout)
#         self.norm1 = nn.LayerNorm(d_model)


#     def forward(self, video_feat, audio_feat):
#         # video_feat, audio_feat: [10, batch, 256]
#         global_feat = video_feat * audio_feat
#         memory = torch.cat([audio_feat, video_feat], dim=0)
#         mid_out = self.video_multihead(global_feat, memory, memory)[0]
#         output = self.norm1(global_feat + self.dropout(mid_out))

#         return  output


class weak_main_model(nn.Module):
    def __init__(self):
        super(weak_main_model, self).__init__()
        self.spatial_channel_att = New_Audio_Guided_Attention().cuda()
        self.video_input_dim = 512 
        self.video_fc_dim = 512
        self.d_model = 256
        self.v_fc = nn.Linear(self.video_input_dim, self.video_fc_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        self.video_encoder = InternalTemporalRelationModule(input_dim=self.video_fc_dim, d_model=self.d_model)
        self.video_decoder = CrossModalRelationAttModule(input_dim=self.video_fc_dim, d_model=self.d_model)
        self.audio_encoder = InternalTemporalRelationModule(input_dim=128, d_model=self.d_model)
        self.audio_decoder = CrossModalRelationAttModule(input_dim=128, d_model=self.d_model)

        self.AVInter = AudioVideoInter(self.d_model, n_head=2, head_dropout=0.1)
        self.localize_module = WeaklyLocalizationModule(self.d_model)


    def forward(self, visual_feature, audio_feature):
        # [batch, 10, 512]
        # this fc is optinal, that is used for adaption of different visual features (e.g., vgg, resnet).
        audio_feature = audio_feature.transpose(1, 0).contiguous()
        visual_feature = self.v_fc(visual_feature)
        visual_feature = self.dropout(self.relu(visual_feature))

        # spatial-channel attention
        visual_feature = self.spatial_channel_att(visual_feature, audio_feature)
        visual_feature = visual_feature.transpose(1, 0).contiguous()

        # audio query
        video_key_value_feature = self.video_encoder(visual_feature)
        audio_query_output = self.audio_decoder(audio_feature, video_key_value_feature)

        # video query
        audio_key_value_feature = self.audio_encoder(audio_feature)
        video_query_output = self.video_decoder(visual_feature, audio_key_value_feature)
        
        video_query_output= self.AVInter(video_query_output, audio_query_output)
        scores = self.localize_module(video_query_output)

        return scores


class supv_main_model(nn.Module):
    def __init__(self):
        super(supv_main_model, self).__init__()

        self.spatial_channel_att = New_Audio_Guided_Attention().cuda()
        self.video_input_dim = 512 
        self.video_fc_dim = 512
        self.d_model = 256
        self.v_fc = nn.Linear(self.video_input_dim, self.video_fc_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        self.video_encoder = InternalTemporalRelationModule(input_dim=512, d_model=256)
        self.video_decoder = CrossModalRelationAttModule(input_dim=512, d_model=256)
        self.audio_encoder = InternalTemporalRelationModule(input_dim=128, d_model=256)
        self.audio_decoder = CrossModalRelationAttModule(input_dim=128, d_model=256)

        self.AVInter = AudioVideoInter(self.d_model, n_head=4, head_dropout=0.1)
        self.localize_module = SupvLocalizeModule(self.d_model)


    def forward(self, visual_feature, audio_feature):
        # [batch, 10, 512]

        # optional, we add a FC here to make the model adaptive to different visual features (e.g., VGG ,ResNet)
        audio_feature = audio_feature.transpose(1, 0).contiguous()
        visual_feature = self.v_fc(visual_feature)
        visual_feature = self.dropout(self.relu(visual_feature))
        
        # spatial-channel attention 
        visual_feature = self.spatial_channel_att(visual_feature, audio_feature)

        # audio-guided needed
        visual_feature = visual_feature.transpose(1, 0).contiguous()

        # audio query
        video_key_value_feature = self.video_encoder(visual_feature)
        audio_query_output = self.audio_decoder(audio_feature, video_key_value_feature)

        # video query
        audio_key_value_feature = self.audio_encoder(audio_feature)
        video_query_output = self.video_decoder(visual_feature, audio_key_value_feature)

        video_query_output= self.AVInter(video_query_output, audio_query_output)
        scores = self.localize_module(video_query_output)

        return scores

In [None]:
####video to specto

import os
import moviepy.editor as mp
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def video_to_save_log_mel_spectrogram(video_path, output_folder, n_fft=2048, hop_length=512, n_mels=128):
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    audio_temp_path = r"C:\Users\karth\Downloads\train\A.Beautiful.Mind.2001__#00-01-45_00-02-50_label_A.wav"
    # Extract audio from the video and save it as a temporary WAV file
    video_clip = mp.VideoFileClip(video_path)
    audio_clip = video_clip.audio
    audio_clip.write_audiofile(audio_temp_path, codec='pcm_s16le', fps=audio_clip.fps)

    # Load the audio file
    y, sr = librosa.load(audio_temp_path, sr=None)
    os.remove(audio_temp_path)

    # Compute the Mel spectrogram
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)

    # Convert to log scale
    log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Save the log Mel spectrogram as an image file
    save_spectrogram_as_image(log_mel_spectrogram, sr=sr, hop_length=hop_length, output_path=os.path.join(output_folder,video_name))

    del y, sr, mel_spectrogram, log_mel_spectrogram, audio_clip, video_clip



def save_spectrogram_as_image(spectrogram, sr, hop_length, output_path, title="Log Mel Spectrogram"):
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(spectrogram, x_axis='time', y_axis='mel', sr=sr, hop_length=hop_length, cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    output_path=output_path+'.png'
    plt.savefig(output_path)  # Save the spectrogram as an image file
    plt.close()  # Close the plot to prevent displaying it

def process_videos_in_directory(input_folder, output_folder):

    videos_to_process = [os.path.join(root, file) for root, _, files in os.walk(input_folder) for file in files if file.endswith(".mp4")]

    for video_path in tqdm(videos_to_process, desc="Processing Videos"):
        video_to_save_log_mel_spectrogram(video_path, output_folder)

    # for root, dirs, files in os.walk(input_folder):
    #     for file in files:
    #         if file.endswith(".mp4"):
    #             video_path = os.path.join(root, file)
    #             video_to_save_log_mel_spectrogram(video_path, output_folder)

# Example usage
input_folder = r"C:\Users\karth\Downloads\train"
output_folder = r"C:\Users\karth\Downloads\trainspecto"

process_videos_in_directory(input_folder, output_folder)

In [118]:
####audio video simple fusion

import numpy as np

class AudioVideoInter(nn.Module):
    def __init__(self, d_model, n_head, head_dropout=0.1):
        super(AudioVideoInter, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.video_multihead = MultiheadAttention(d_model, num_heads=n_head, dropout=head_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.affine_video_1 = nn.Linear(1024, 128)



    def forward(self, video_feat, audio_feat):
        # video_feat, audio_feat: [10, batch, 256]
        video_feat=self.affine_video_1(video_feat)
        global_feat = video_feat * audio_feat
        memory = torch.cat([audio_feat, video_feat], dim=0)
        mid_out = self.video_multihead(global_feat, memory, memory)[0]
        output = self.norm1(global_feat + self.dropout(mid_out))

        return  output
    

aud= np.load(r"C:\Users\karth\Downloads\audiofeats\train\Your.Name.2016__#01-22-20_01-24-05_label_A.npy")
vid=np.load(r"C:\Users\karth\Downloads\XDVioDet-master\XDVioDet-master\tratrain\Your.Name.2016__#01-22-20_01-24-05_label_A.npy")

aud = np.repeat(aud[:, np.newaxis, :], 5, axis=1)
aud_tensor = torch.from_numpy(aud)
vid_tensor = torch.from_numpy(vid)

aud_tensor = aud_tensor.permute(1, 0, 2)

vid_tensor = vid_tensor.permute(1, 0, 2)


print(vid_tensor.shape)
print(aud_tensor.shape)

model=AudioVideoInter(d_model=128,n_head=2)
out=model(vid_tensor,aud_tensor)

print(out.shape)

print(out)


torch.Size([5, 157, 1024])
torch.Size([5, 157, 128])
torch.Size([5, 157, 128])
tensor([[[-0.3255, -0.7201, -1.3232,  ...,  0.1799, -0.0492,  0.5752],
         [-0.3146, -1.2614, -1.0466,  ...,  0.7632, -0.5162,  1.0648],
         [-0.8436, -1.3632, -0.0903,  ...,  0.9423, -0.3327,  0.1240],
         ...,
         [-1.1972,  1.6714,  0.0734,  ..., -0.1596,  0.4987,  0.3687],
         [-0.5031,  1.4383, -2.3368,  ..., -1.9084,  0.2783,  0.0790],
         [-0.1501, -0.0322, -2.3169,  ...,  0.0730, -0.4764, -0.4420]],

        [[-0.3757, -0.5329, -1.3553,  ...,  0.4328, -0.2067,  0.6716],
         [-0.1790, -1.2031, -0.9562,  ...,  0.9599, -0.6266,  1.2794],
         [-0.0204, -1.3087, -1.1483,  ...,  0.8486, -0.2002,  0.2717],
         ...,
         [-1.2709,  2.9776, -0.0408,  ..., -0.3034,  0.3227,  0.5243],
         [-0.6743,  1.2299, -2.8904,  ..., -1.8598,  0.3340, -0.3133],
         [-0.0254,  0.0122, -2.2938,  ...,  0.0894, -1.0829, -0.4157]],

        [[-0.0602, -0.6329, -0.2386, 

In [115]:

####full combinatio  with crm

class weak_main_model(nn.Module):
    def __init__(self):
        super(weak_main_model, self).__init__()
        self.spatial_channel_att = New_Audio_Guided_Attention()
        self.video_input_dim = 1024 
        self.video_fc_dim = 1024
        self.d_model = 128
        self.v_fc = nn.Linear(self.video_input_dim, self.video_fc_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        self.video_encoder = InternalTemporalRelationModule(input_dim=self.video_fc_dim, d_model=self.d_model)
        self.video_decoder = CrossModalRelationAttModule(input_dim=self.video_fc_dim, d_model=self.d_model)
        self.audio_encoder = InternalTemporalRelationModule(input_dim=128, d_model=self.d_model)
        self.audio_decoder = CrossModalRelationAttModule(input_dim=128, d_model=self.d_model)

        self.AVInter = AudioVideoInter(self.d_model, n_head=2, head_dropout=0.1)
        self.localize_module = WeaklyLocalizationModule(self.d_model)


    def forward(self, visual_feature, audio_feature):
        # [batch, 10, 512]
        # this fc is optinal, that is used for adaption of different visual features (e.g., vgg, resnet).
        audio_feature = audio_feature.transpose(1, 0).contiguous()
        visual_feature = self.v_fc(visual_feature)
        visual_feature = self.dropout(self.relu(visual_feature))

        # spatial-channel attention
        visual_feature=visual_feature.unsqueeze(3)
        visual_feature = self.spatial_channel_att(visual_feature, audio_feature)
        visual_feature = visual_feature.transpose(1, 0).contiguous()

        # audio query
        video_key_value_feature = self.video_encoder(visual_feature)
        audio_query_output = self.audio_decoder(audio_feature, video_key_value_feature)

        # video query
        audio_key_value_feature = self.audio_encoder(audio_feature)
        video_query_output = self.video_decoder(visual_feature, audio_key_value_feature)
        
        video_query_output= self.AVInter(video_query_output, audio_query_output)
        #scores = self.localize_module(video_query_output)

        return video_query_output
    



aud= np.load(r"C:\Users\karth\Downloads\audiofeats\train\Your.Name.2016__#01-22-20_01-24-05_label_A.npy")
vid=np.load(r"C:\Users\karth\Downloads\XDVioDet-master\XDVioDet-master\tratrain\Your.Name.2016__#01-22-20_01-24-05_label_A.npy")
aud = np.repeat(aud[:, np.newaxis, :], 5, axis=1)
# aud = np.repeat(aud[:, :, :], 8, axis=2)
aud_tensor = torch.from_numpy(aud)
vid_tensor = torch.from_numpy(vid)

aud_tensor = aud_tensor.permute(1, 0, 2)

vid_tensor = vid_tensor.permute(1, 0, 2)


model=weak_main_model()
out=model(vid_tensor,aud_tensor)
print(out)
print(out.shape)


torch.Size([5, 157, 1024])
torch.Size([785, 128])
torch.Size([785, 1024])
torch.Size([785, 1024])
torch.Size([5, 157, 1024, 1])
tensor([[[ 0.2326,  1.6311,  0.5422,  ..., -1.2594,  0.3386, -0.7474],
         [-0.7677,  1.5403,  1.0024,  ..., -1.2624,  0.8253, -0.6294],
         [-0.8642,  0.7435,  0.1040,  ..., -0.4903,  0.1975, -0.4469],
         [-0.7709,  1.0024,  1.0777,  ..., -1.0989, -0.5271, -0.7881],
         [-0.7686,  1.7643,  0.9162,  ..., -0.9813,  0.8918, -0.5953]],

        [[-0.5472,  2.0676,  0.8976,  ..., -0.9731, -0.0681, -0.3853],
         [-0.6023,  1.6399,  0.8768,  ..., -1.4973, -0.0995, -0.4158],
         [-0.8948,  1.8757,  0.6564,  ..., -1.3329,  0.2155, -1.1678],
         [-1.0295, -0.3664,  0.7256,  ..., -1.4664, -0.2354, -0.5229],
         [-0.8598,  1.8328,  0.2853,  ..., -1.4059, -0.3516, -1.0934]],

        [[-0.1162,  0.7674,  1.9938,  ..., -1.6808,  0.5101, -0.7139],
         [-0.6882,  0.9703,  2.1789,  ..., -1.3561,  0.1778, -0.4513],
         [-0.414