In [None]:
!pip install transformers datasets einops timm --quiet

In [2]:
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [4]:
from einops.layers.torch import Rearrange
from einops import reduce, rearrange, repeat

In [27]:
dpr = [x.item() for x in torch.linspace(0, 0., 12)] 

In [28]:
dpr

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [78]:
import torch
import transformers

class MyWav2Vec2Encoder(transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder):
  def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=1024, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, 
                 pruning_loc=None, token_ratio=None, distill=False):
    super().__init__(config)
    
    predictor_list = [PredictorLG(embed_dim) for _ in range(len(pruning_loc))]

    self.score_predictor = nn.ModuleList(predictor_list)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

    self.layers = nn.ModuleList([Block(dim = embed_dim, num_heads = num_heads, drop_path = dpr[i]) for i in range(depth)])
    

    self.pruning_loc = pruning_loc

  def forward(
      self,
      hidden_states,
      attention_mask=None,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True, drop_temp = 1,
      ):
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    
    # Initialize drop decisions
    B = hidden_states.shape[0]   
    p_count = 0
    out_pred_prob = []
    init_n = 14 * 14
    prev_decision = torch.ones(B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
    policy = torch.ones(B, init_n + 1, 1, dtype=hidden_states.dtype, device=hidden_states.device)

    for i, blk in enumerate(self.layers):
      if i in self.pruning_loc:
        spatial_x = hidden_state[:, 1:]
        pred_score = self.score_predictor[p_count](spatial_x, prev_decision).reshape(B, -1, 2)
        if self.training:
          hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision
          out_pred_prob.append(hard_keep_decision.reshape(B, init_n))
          cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
          policy = torch.cat([cls_policy, hard_keep_decision], dim=1)
          hidden_states = blk(hidden_states, policy=policy)
          prev_decision = hard_keep_decision
        else:
          score = pred_score[:,:,0]
          num_keep_node = int(init_n * self.token_ratio[p_count])
          keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node]
          cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device)
          now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1)
          hidden_states = batch_index_select(hidden_states, now_policy)
          prev_decision = batch_index_select(prev_decision, keep_policy)
          hidden_states = blk(hidden_states)
        p_count += 1
      else:
        if self.training:
          hidden_states = blk(hidden_states, policy)
        else:
          hidden_states = blk(hidden_states)
    hidden_states = self.norm(hidden_states)
    features = hidden_states[:, 1:]
    hidden_state = hidden_states[:, 0]
    hidden_states = self.pre_logits(hidden_states)
    hidden_states = self.head(hidden_states)

        

    
    if return_dict:
      return {
          "last_hidden_state": hidden_states,
          "hidden_states": None,  
          "attentions": None,  
            }
    else:
      return hidden_states


In [24]:
def batch_index_select(x, idx):
    if len(x.size()) == 3:
        B, N, C = x.size()
        N_new = idx.size(1)
        offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
        idx = idx + offset
        out = x.reshape(B*N, C)[idx.reshape(-1)].reshape(B, N_new, C)
        return out
    elif len(x.size()) == 2:
        B, N = x.size()
        N_new = idx.size(1)
        offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
        idx = idx + offset
        out = x.reshape(B*N)[idx.reshape(-1)].reshape(B, N_new)
        return out
    else:
        raise NotImplementedError

In [7]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [8]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def softmax_with_policy(self, attn, policy, eps=1e-6):
        B, N, _ = policy.size()
        B, H, N, N = attn.size()
        attn_policy = policy.reshape(B, 1, 1, N)  # * policy.reshape(B, 1, N, 1)
        eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
        attn_policy = attn_policy + (1.0 - attn_policy) * eye
        max_att = torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn - max_att
        # attn = attn.exp_() * attn_policy
        # return attn / attn.sum(dim=-1, keepdim=True)

        # for stable training
        attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
        attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)
        return attn.type_as(max_att)

    def forward(self, x, policy):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if policy is None:
            attn = attn.softmax(dim=-1)
        else:
            attn = self.softmax_with_policy(attn, policy)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [72]:
import tensorflow as tf

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU):
        super().__init__()
        #tensor = tf.constant([1024,]) 
        #norm_layer = nn.LayerNorm(tensor, eps=1e-05, elementwise_affine=True)
        norm_layer = nn.LayerNorm
        self.norm1 = norm_layer(1024)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, policy=None):
        x = x + self.drop_path(self.attn(self.norm1(x), policy=policy))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [10]:
class PredictorLG(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, embed_dim=384):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy):
        x = self.in_conv(x)
        B, N, C = x.size()
        local_x = x[:,:, :C//2]
        global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)
        return self.out_conv(x)

In [11]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained('facebook/wav2vec2-large-960h')
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h', config=config)
#model.encoder = MyWav2Vec2Encoder(config)  # replace the encoder with the new one

Downloading (…)lve/main/config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [65]:
norm = nn.LayerNorm((1024,), eps=config.layer_norm_eps)

In [None]:
norm()

In [79]:
model.wav2vec2.encoder = MyWav2Vec2Encoder(config, pruning_loc = [4,7,11])

In [31]:
model.wav2vec2.encoder

MyWav2Vec2Encoder(
  (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
    (conv): Conv1d(1024, 1024, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    (padding): Wav2Vec2SamePadLayer()
    (activation): GELUActivation()
  )
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_featu

In [32]:
from datasets import load_dataset

In [33]:
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

Downloading builder script:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_demo/clean to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_demo downloaded and prepared to /root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b. Subsequent calls will reuse this data.


In [36]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

Downloading (…)rocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [37]:
# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")


In [80]:
with torch.no_grad():
    logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)

RuntimeError: ignored

In [None]:
transcription = processor.batch_decode(predicted_ids)
transcription[0]

'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'