In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import math
import random
import sys
from copy import deepcopy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from ay2.tools import freeze_modules
from ay2.torch.nn import LambdaFunctionModule
from einops import rearrange

In [3]:
from torchvision.transforms import v2

In [7]:
# Import the AASIST loader and the WavLM baseline under the names
# `load_AASIST` and `WavLM`. The package-relative form is tried first; if it
# raises ImportError, the two project directories are added to sys.path and
# the same objects are imported as top-level modules instead.
try:
    from ..Aaasist.Aaasist.load_model import get_model  as load_AASIST
    from ..WaveLM.wavlm import BaseLine as WavLM
except ImportError:
    # NOTE(review): these entries are relative to the current working
    # directory, and they are appended again every time this cell is re-run.
    sys.path.append("../Aaasist")
    sys.path.append("../WaveLM")
    from Aaasist.load_model import get_model as load_AASIST
    from wavlm import BaseLine as WavLM

# 1D models

In [8]:
class WavLM_1D(nn.Module):
    """Waveform encoder that mean-pools the output of the WavLM baseline.

    The wrapped baseline is stored as ``self.model1D``; ``self.n_dim`` records
    the size of the returned embedding (768).
    """

    def __init__(self):
        super().__init__()
        self.model1D = WavLM()
        self.n_dim = 768

    def forward(self, x):
        """Encode a batch of waveforms into one vector per item.

        Args:
            x: a 2D tensor, or a 3D tensor from which index 0 along dim 1 is
                taken before encoding.

        Returns:
            The selected backbone output averaged over dim 1.
        """
        # A 3D input is reduced to 2D by keeping only the first entry of dim 1.
        waveform = x[:, 0, :] if x.ndim == 3 else x

        backbone = self.model1D
        outputs = backbone.pretrain_model(waveform)
        frame_features = outputs[backbone.pretrain_feat]  # (B, T, 768)
        return frame_features.mean(1)

In [9]:
# Smoke test: push a random (2, 1, 48000) tensor through WavLM_1D and display
# the shape of the pooled embedding (the recorded output is (2, 768)).
model = WavLM_1D()
x = torch.randn(2, 1, 48000)
model(x).shape

  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight

torch.Size([2, 768])

In [10]:
# Instantiate the bare WavLM baseline so that the sub-modules of its
# `pretrain_model` can be called one at a time in the cells that follow.
model1D = WavLM()

Some weights of the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this mo

In [17]:
# Step through the WavLM backbone by hand on a random (2, 48000) input:
# feature extractor -> feature projection -> encoder.
x = torch.randn(2, 48000)

# Run the feature extractor once and reuse its output; the original cell
# called it twice on the same input and never used the first result.
# `feat` is kept as a name in case other cells refer to it.
feat = model1D.pretrain_model.feature_extractor(x)
extract_features = feat.transpose(1, 2)

# The returned extract_features is really just the layer-normed input.
hidden_states, extract_features = model1D.pretrain_model.feature_projection(extract_features)

# With return_dict=False the encoder result is indexed positionally; element 0
# is taken as the hidden states.
encoder_outputs = model1D.pretrain_model.encoder(
    hidden_states,
    attention_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False,
)
hidden_states = encoder_outputs[0]

In [18]:
# Display the shape of the encoder output from the manual forward pass
# (recorded output: (2, 149, 768)).
hidden_states.shape

torch.Size([2, 149, 768])

In [20]:
from einops import rearrange

In [40]:
class CrossAttention2D(nn.Module):
    """Cross-attention between two channel-first 2D feature maps.

    Queries come from ``waveform`` and keys from ``spectrogram``, each
    projected to ``feature_dim`` channels by a 1x1 convolution. The values are
    the unprojected ``spectrogram`` channels. After moving channels last, the
    attention matrix is formed over the second-to-last axis, i.e. separately
    for every (batch, row) pair across the width positions.
    """

    def __init__(self, time_dim, spec_dim, feature_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

        # 1x1 convolutions: per-position projections into the attention space.
        self.conv1 = nn.Conv2d(time_dim, feature_dim, kernel_size=1)
        self.conv2 = nn.Conv2d(spec_dim, feature_dim, kernel_size=1)
        self.feature_dim = feature_dim

    def forward(self, waveform, spectrogram):
        """Attend from ``waveform`` positions to ``spectrogram`` positions.

        Args:
            waveform: (B, time_dim, H, W)
            spectrogram: (B, spec_dim, H, W)

        Returns:
            Tensor of shape (B, spec_dim, H, W).
        """
        channels_last = (0, 2, 3, 1)
        q = self.conv1(waveform).permute(*channels_last)     # (B, H, W, feature_dim)
        k = self.conv2(spectrogram).permute(*channels_last)  # (B, H, W, feature_dim)
        v = spectrogram.permute(*channels_last)              # (B, H, W, spec_dim)

        scale = self.feature_dim ** 0.5
        scores = (q @ k.transpose(-2, -1)) / scale           # (B, H, W, W)
        attended = self.softmax(scores) @ v                  # (B, H, W, spec_dim)

        # Back to channel-first layout.
        return attended.permute(0, 3, 1, 2)

In [35]:
class Expand(nn.Module):
    """Expand a 1D feature sequence onto a 2D grid and attend to a spectrogram map.

    The sequence axis (``time_len``) is mapped to ``spec_height * spec_width``
    positions with a kernel-size-1 Conv1d, reshaped to a channel-first 2D map,
    and passed as the query side of ``CrossAttention2D`` against the
    spectrogram feature map.

    Args:
        time_len: length of the input sequence (Conv1d input channels).
        time_dim: feature size of each sequence element.
        spec_height: height of the spectrogram feature map.
        spec_width: width of the spectrogram feature map.
        spec_dim: channel count of the spectrogram feature map.
    """

    def __init__(self, time_len = 149, time_dim=768, spec_height=56, spec_width=56, spec_dim=512):
        super().__init__()

        self.time_len = time_len
        self.time_dim = time_dim
        self.spec_height = spec_height
        self.spec_width = spec_width
        self.spec_dim = spec_dim

        # Treats the sequence axis as channels: time_len -> spec_height * spec_width.
        self.conv1 = nn.Conv1d(in_channels=time_len, out_channels=spec_height*spec_width, kernel_size=1)
        # NOTE(review): conv2 is never called in forward(). It is kept so the
        # module's parameters and state_dict keys stay unchanged.
        self.conv2 = nn.Conv2d(in_channels=time_dim, out_channels=spec_dim, kernel_size=3, padding=1)

        # Fixed: the original referenced `CrossAttention`, which is not defined
        # in this notebook; the 2D attention class is `CrossAttention2D`.
        self.attn = CrossAttention2D(time_dim=time_dim, spec_dim=spec_dim, feature_dim=spec_dim)

    def forward(self, x, y):
        """Run the expansion and cross-attention.

        Args:
            x: (B, time_len, time_dim) sequence features.
            y: (B, spec_dim, spec_height, spec_width) spectrogram features.

        Returns:
            Tensor of shape (B, spec_dim, spec_height, spec_width).
        """
        x = self.conv1(x)  # [B, spec_H * spec_W, time_dim]
        x = rearrange(x, 'b (h w) c -> b c h w', h=self.spec_height, w=self.spec_width)  # [B, time_dim, spec_H, spec_W]
        return self.attn(x, y)

In [None]:
# Smoke test for Expand with its default sizes: random stand-ins for the
# sequence features and the spectrogram feature map are passed through the
# module and the output shape is displayed.
# NOTE(review): this cell has no recorded output; the shape should presumably
# be (B, 512, 56, 56) — confirm after running.
module = Expand()

B = 3
waveform = torch.rand((B, 149, 768))  # Replace with actual input
spectrogram = torch.rand((B, 512, 56, 56))  # Replace with actual input


module(waveform, spectrogram).shape

In [51]:
class CrossAttention1D(nn.Module):
    """Cross-attention between two feature sequences.

    Queries are a linear projection of ``spectrogram``, keys are a linear
    projection of ``waveform``, and the values are the unprojected
    ``waveform`` features, so the output keeps ``time_dim`` channels.
    """

    def __init__(self, time_dim, spec_dim, feature_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

        # Project both inputs into a shared feature_dim-sized attention space.
        self.linear1 = nn.Linear(time_dim, feature_dim)
        self.linear2 = nn.Linear(spec_dim, feature_dim)
        self.feature_dim = feature_dim

    def forward(self, waveform, spectrogram):
        """
        Args:
            waveform: (B, time_len, time_dim)
            spectrogram: (B, time_len, spec_dim)

        Returns:
            (B, time_len, time_dim): for every spectrogram position, a
            softmax-weighted average of the waveform features.
        """
        wave_keys = self.linear1(waveform)        # (B, time_len, feature_dim)
        spec_queries = self.linear2(spectrogram)  # (B, time_len, feature_dim)

        scale = self.feature_dim ** 0.5
        scores = spec_queries @ wave_keys.transpose(-2, -1)  # (B, time_len, time_len)
        weights = self.softmax(scores / scale)

        # Values are the raw waveform features, hence time_dim output channels.
        return weights @ waveform

In [52]:
class Squeeze(nn.Module):
    """Squeeze a 2D spectrogram feature map into a sequence and attend to it.

    The flattened spatial axis (``spec_height * spec_width``) is mapped to
    ``time_len`` positions with a linear layer; the resulting sequence is used
    as the spectrogram side of ``CrossAttention1D`` against the 1D features.
    """

    def __init__(self, time_len=149, time_dim=768, spec_height=56, spec_width=56, spec_dim=512):
        super().__init__()

        # Keep the configuration around for inspection.
        self.time_len, self.time_dim = time_len, time_dim
        self.spec_height, self.spec_width = spec_height, spec_width
        self.spec_dim = spec_dim

        # Maps the flattened spatial grid onto the sequence length.
        self.linear = nn.Linear(spec_height * spec_width, time_len)

        self.attn = CrossAttention1D(time_dim=time_dim, spec_dim=spec_dim, feature_dim=spec_dim)

    def forward(self, x, y):
        """Run the squeeze and cross-attention.

        Args:
            x: (B, time_len, time_dim) sequence features.
            y: (B, spec_dim, spec_height, spec_width) spectrogram features.

        Returns:
            Whatever ``self.attn`` returns for ``x`` and the squeezed ``y``.
        """
        spec_seq = rearrange(y, 'b c h w -> b c (h w)')  # [B, spec_dim, spec_H * spec_W]
        spec_seq = self.linear(spec_seq)                 # [B, spec_dim, time_len]
        spec_seq = rearrange(spec_seq, 'b c l -> b l c')  # [B, time_len, spec_dim]
        return self.attn(x, spec_seq)

In [53]:
# Smoke test for Squeeze with its default sizes: random stand-ins for the
# sequence features and the spectrogram feature map are passed through the
# module and the output shape is displayed (recorded output: (3, 149, 768)).
module = Squeeze()

B = 3
waveform = torch.rand((B, 149, 768))  # Replace with actual input
spectrogram = torch.rand((B, 512, 56, 56))  # Replace with actual input


module(waveform, spectrogram).shape

torch.Size([3, 149, 768])