In [1]:
from src.replicate import get_audio_segments
from transformers import Wav2Vec2ForAudioFrameClassification
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForAudioFrameClassification
import jax.numpy as jnp
import numpy as np
import torch
from loguru import logger

In [2]:
audio_file = "/dev/shm/namo_speeches/wdtp42dMQkc.wav"
_, segments, _ = get_audio_segments(audio_file)

In [3]:
segments_torch = [torch.tensor(np.asarray(segment)) for segment in segments]

In [4]:
varied_segments = [segments[0], segments[1], segments[-1]]
varied_segments_torch = [segments_torch[0], segments_torch[1], segments_torch[-1]]

In [5]:
logger.info(f"Sanity check: {[segment.shape for segment in varied_segments]}")
logger.info(f"Sanity check (torch): {[segment.shape for segment in varied_segments_torch]}")

[32m2023-08-02 00:03:55.121[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mSanity check: [(2, 528000), (2, 576000), (2, 443736)][0m
[32m2023-08-02 00:03:55.123[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mSanity check (torch): [torch.Size([2, 528000]), torch.Size([2, 576000]), torch.Size([2, 443736])][0m


In [6]:
model = FlaxWav2Vec2ForAudioFrameClassification.from_pretrained("/home/khandelia1000/speech_alignment/mms_alignment_model", from_pt=True)
torch_model = Wav2Vec2ForAudioFrameClassification.from_pretrained("/home/khandelia1000/speech_alignment/mms_alignment_model").eval()

In [7]:
expected_output = [model(segment).logits for segment in varied_segments]
with torch.inference_mode():
    expected_output_torch = [torch_model(segment).logits for segment in varied_segments_torch]

In [11]:
expected_output[0].shape

(2, 1649, 31)

## test can we only send single channel input

In [27]:
test_output = model(jnp.expand_dims(varied_segments[0][0], axis=0))

In [24]:
test_output

FlaxMaskedLMOutput(logits=Array([[[  9.573059 , -23.436459 , -23.549742 , ...,  -3.077409 ,
          -4.258169 ,  -3.5700355],
        [  9.764209 , -24.141933 , -24.319227 , ...,  -3.166452 ,
          -4.5177855,  -3.5430603],
        [  9.780839 , -24.496546 , -24.746893 , ...,  -3.5633605,
          -4.8595405,  -3.7069871],
        ...,
        [  3.6285856, -19.138418 , -19.040558 , ...,  -2.500242 ,
          -3.600256 ,  -4.0279903],
        [  3.674955 , -19.257603 , -19.213037 , ...,  -2.3330367,
          -3.8337636,  -4.350674 ],
        [  3.7082536, -19.360212 , -19.371828 , ...,  -2.2507215,
          -3.873837 ,  -4.3663855]]], dtype=float32), hidden_states=None, attentions=None)

In [23]:
expected_output[0][0]

Array([[  9.569984 , -23.43094  , -23.54404  , ...,  -3.0728865,
         -4.1720247,  -3.5677764],
       [  9.778904 , -24.183798 , -24.358864 , ...,  -3.1577458,
         -4.4343953,  -3.5652158],
       [  9.80444  , -24.534115 , -24.776663 , ...,  -3.5720603,
         -4.7612867,  -3.6760964],
       ...,
       [  3.6246803, -19.143164 , -19.045126 , ...,  -2.4978912,
         -3.60352  ,  -4.043333 ],
       [  3.6692128, -19.255657 , -19.211208 , ...,  -2.3258145,
         -3.8385103,  -4.3620725],
       [  3.7103   , -19.379566 , -19.391365 , ...,  -2.2478476,
         -3.8891234,  -4.3864927]], dtype=float32)

In [26]:
jnp.max(expected_output[0][0] - test_output.logits)

Array(1.105238, dtype=float32)

In [25]:
jnp.allclose(expected_output[0][0], test_output.logits, atol=1e-1)

Array(False, dtype=bool)

In [32]:
with torch.inference_mode:
    torch_model.wav2vec2.feature_extractor(torch.tensor(np.asarray()))

Wav2Vec2ForAudioFrameClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affin

## use mask for batched inference

### first without attention mask

In [8]:
padded_input = torch.zeros(2, 576000)
padded_input[:, :varied_segments_torch[0].shape[-1]] = varied_segments_torch[0]

In [10]:
with torch.inference_mode():
    padded_output = torch_model(padded_input)

In [13]:
padded_output.logits.shape, expected_output_torch[0].shape

(torch.Size([2, 1799, 31]), torch.Size([2, 1649, 31]))

In [14]:
frame_count = expected_output_torch[0].shape[1]
torch.allclose(expected_output_torch[0][:, :750, :], padded_output.logits[:, :750, :], atol=1e-1)

False

In [16]:
torch.max(expected_output_torch[0][:, :1000, :] - padded_output.logits[:, :1000, :])

tensor(3.5335)

### now with attention mask

In [17]:
padded_input = torch.zeros(2, 576000)
padded_input[:, :varied_segments_torch[0].shape[-1]] = varied_segments_torch[0]

In [18]:
attention_mask = torch.zeros(2, 576000)
attention_mask[:, :varied_segments_torch[0].shape[-1]] = 1

In [19]:
with torch.inference_mode():
    padded_output = torch_model(padded_input, attention_mask=attention_mask)

In [23]:
frame_count = expected_output_torch[0].shape[1]
torch.allclose(expected_output_torch[0][:, :frame_count, :], padded_output.logits[:, :frame_count, :], atol=1e-3)

True

### now padding with attention mask using flax

In [26]:
padded_input = np.zeros((2, 576000))
padded_input[:, :varied_segments_torch[0].shape[-1]] = varied_segments_torch[0]
padded_input = jnp.asarray(padded_input)

In [30]:
attention_mask = np.zeros((2, 576000))
attention_mask[:, :varied_segments_torch[0].shape[-1]] = 1
attention_mask = jnp.asarray(attention_mask)

In [31]:
padded_output = model(padded_input, attention_mask=attention_mask)

In [32]:
padded_output.logits

Array([[[  9.569984 , -23.43094  , -23.54404  , ...,  -3.0728865,
          -4.1720247,  -3.5677764],
        [  9.778904 , -24.183798 , -24.358864 , ...,  -3.1577458,
          -4.4343953,  -3.5652158],
        [  9.80444  , -24.534115 , -24.776663 , ...,  -3.5720603,
          -4.7612867,  -3.6760964],
        ...,
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577],
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577],
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577]],

       [[  9.609749 , -23.410473 , -23.534405 , ...,  -2.9833636,
          -4.216002 ,  -3.5279503],
        [  9.823359 , -24.165634 , -24.348644 , ...,  -3.042981 ,
          -4.466457 ,  -3.4635952],
        [  9.805192 , -24.40877  , -24.656698 , ...,  -3.4314072,
          -4.7778716,  -3.59754  ],
        ...,
        [  3.4676347, -19.934145 , -19.768532 , ...,  

In [34]:
frame_count = expected_output[0].shape[1]
jnp.allclose(expected_output[0][:, :frame_count, :], padded_output.logits[:, :frame_count, :], atol=1e-3)

Array(True, dtype=bool)

In [35]:
padded_output.logits[:, frame_count:, :]

Array([[[  3.6172302, -19.417892 , -19.459484 , ...,  -2.2876465,
          -3.8989465,  -4.268706 ],
        [  3.6149106, -19.461922 , -19.512568 , ...,  -2.3551292,
          -3.930071 ,  -4.3022866],
        [  3.6141102, -19.50552  , -19.576159 , ...,  -2.3266792,
          -3.9247055,  -4.282878 ],
        ...,
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577],
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577],
        [  3.474959 , -19.955288 , -19.794453 , ...,  -2.830078 ,
          -3.4667163,  -3.1838577]],

       [[  3.5924635, -19.371904 , -19.414928 , ...,  -2.2873223,
          -3.8779235,  -4.2652316],
        [  3.5869386, -19.411625 , -19.463875 , ...,  -2.3564184,
          -3.9051268,  -4.294119 ],
        [  3.587076 , -19.47328  , -19.543835 , ...,  -2.3340583,
          -3.9072373,  -4.2731466],
        ...,
        [  3.4676347, -19.934145 , -19.768532 , ...,  