In [2]:
!pip install transformers datasets einops timm --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m51.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m102.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m123.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [4]:
from einops.layers.torch import Rearrange
from einops import reduce, rearrange, repeat

In [27]:
import torch
import transformers

class MyWav2Vec2Encoder(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder):
  def __init__(self, config, channels=256, d_model = 1024, d_k=16, d_v=32, n_heads=8, n_blocks=12, layers_to_drop=[]):
    super().__init__(config)
    self.score_predictor = nn.ModuleList([DropPredictor(d_model) for _ in range(n_blocks)])
    self.layers = nn.ModuleList([TransformerBlock(d_model, d_k, d_v, n_heads) for _ in range(n_blocks)])

  def forward(
      self,
      hidden_states,
      attention_mask=None,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True, drop_temp = 1,
      ):
    
    # Initialize drop decisions
    B, P, _ = hidden_states.shape
    prev_decision = torch.ones(B, P, 1, dtype=hidden_states.dtype, device=hidden_states.device)
    policy = torch.ones(B, P, 1, dtype=hidden_states.dtype, device=hidden_states.device)
    
    out_pred_prob = []
    pred_distr = [[],[],[],[]]

    for i, l in enumerate(self.layers):
      if i in [4, 7, 11]:  
        
        # Current drop score
        pred_score = self.score_predictor[i](hidden_states, prev_decision)#.reshape(B, -1, 2)
        keepall = torch.cat((torch.zeros_like(pred_score[:,:,0:1]), torch.ones_like(pred_score[:,:,1:2])),2) 
        pred_score = pred_score*drop_temp + keepall*(1-drop_temp)

        if True: #self.training:

          # Convert to log-prob
          pred_score = torch.log(pred_score + 1e-8)
          
          # Sample mask and update previous one
          hard_keep_decision = F.gumbel_softmax(pred_score, hard = True)[:, :, 1:2]*prev_decision
        
        else:
          
          # Treshold mask and update previous one
          hard_keep_decision = (pred_score[:, :, 1:2] > 0.9).float() * prev_decision
        
        #cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
        #policy = torch.cat([cls_policy, hard_keep_decision], dim=1)
        policy = hard_keep_decision
        prev_decision = hard_keep_decision
        
      hidden_states = l(hidden_states, policy = policy)

    if return_dict:
      return {
          "last_hidden_state": hidden_states,
          "hidden_states": None,  
          "attentions": None,  
            }
    else:
      return hidden_states


In [6]:
class DropPredictor(nn.Module):
    """ Computes the log-probabilities of dropping a token, adapted from PredictorLG here:
    https://github.com/raoyongming/DynamicViT/blob/48ac52643a637ed5a4cf7c7d429dcf17243794cd/models/dyvit.py#L287 """
    def __init__(self, embed_dim):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, policy):
        x = self.in_conv(x)
        B, N, C = x.size()
        local_x = x[:,:, :C//2]
        global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / (torch.sum(policy, dim=1, keepdim=True)+0.000001)
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)
        return self.out_conv(x)

In [7]:
class TransformerBlock(nn.Module):
  """ A more-or-less standard transformer block. """
  def __init__(self, d_model, d_k, d_v, n_heads, dropout=0.1):
    super().__init__()
    self.sa = MultiHeadAttentionNew(n_heads, d_model, d_k, d_v, dropout=dropout)
    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.ff = nn.Sequential(
        nn.Linear(d_model, d_model*2),
        nn.GELU(),
        nn.Linear(d_model*2, d_model)
    )

  def forward(self, x, policy=None):
    x = self.sa(self.ln1(x), policy=policy)[0].permute(0,2,1) + x
    x = self.ff(self.ln2(x)) + x
    return x

In [8]:
class MultiHeadAttentionNew(nn.Module):
    """ Multihead attention from here: https://einops.rocks/pytorch-examples.html 
    Useful if we want to further modify the model """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head

        self.w_qs = nn.Conv1d(d_model, n_head * d_k, kernel_size=1, bias=False)
        self.w_ks = nn.Conv1d(d_model, n_head * d_k, kernel_size=1, bias=False)
        self.w_vs = nn.Conv1d(d_model, n_head * d_v, kernel_size=1, bias=False)
        
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        
        self.fc = nn.Conv1d(n_head * d_v, d_model, kernel_size=1, bias=False)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(p=dropout)

    def softmax_with_policy(self, attn, policy, eps=1e-6):
        # This is the masked softmax (Eq. (11) in the paper), 
        # taken from here: https://github.com/raoyongming/DynamicViT/blob/master/models/dyvit.py
        B, N, _ = policy.size()
        B, H, N, N = attn.size()

        attn_policy = policy.reshape(B, 1, 1, N)  # * policy.reshape(B, 1, N, 1)
        eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
        attn_policy = attn_policy + (1.0 - attn_policy) * eye
        max_att = torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn - max_att
        # attn = attn.exp_() * attn_policy
        # return attn / attn.sum(dim=-1, keepdim=True)

        # for stable training
        attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
        attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)
        return attn.type_as(max_att)

    def forward(self, x, policy=None):

        x = x.permute(0,2,1)

        # x (batch, tokens, features) are the tokens.
        # policy (batch, tokens, 1) is a boolean mask denoting which tokens we should remove from the computation.
        q = rearrange(self.w_qs(x), 'b (head k) t -> b head t k', head=self.n_head)
        k = rearrange(self.w_ks(x), 'b (head k) t -> b head t k', head=self.n_head)
        v = rearrange(self.w_vs(x), 'b (head v) t -> b head t v', head=self.n_head)
        attn = torch.einsum('bhlk,bhtk->bhlt', [q, k]) / np.sqrt(q.shape[-1])
        if policy is None:
            attn = attn.softmax(dim=-1)
        else:
            attn = self.softmax_with_policy(attn, policy)
        output = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        output = rearrange(output, 'b head l v -> b (head v) l')
        output = self.dropout(self.fc(output))
        return output, attn

In [9]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained('facebook/wav2vec2-large-960h')
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h', config=config)
#model.encoder = MyWav2Vec2Encoder(config)  # replace the encoder with the new one

Downloading (…)lve/main/config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
model.wav2vec2.encoder = MyWav2Vec2Encoder(config)

In [11]:
model.wav2vec2.encoder

MyWav2Vec2Encoder(
  (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
    (conv): Conv1d(1024, 1024, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    (padding): Wav2Vec2SamePadLayer()
    (activation): GELUActivation()
  )
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (sa): MultiHeadAttentionNew(
        (w_qs): Conv1d(256, 128, kernel_size=(1,), stride=(1,), bias=False)
        (w_ks): Conv1d(256, 128, kernel_size=(1,), stride=(1,), bias=False)
        (w_vs): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
        (fc): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=256,

In [12]:
from datasets import load_dataset

In [13]:
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

Downloading builder script:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_demo/clean to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_demo downloaded and prepared to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b. Subsequent calls will reuse this data.


In [14]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

Downloading (…)rocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [20]:
# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")


In [30]:
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

KeyError: ignored

In [None]:
transcription = processor.batch_decode(predicted_ids)
transcription[0]