In [1]:
from transformers import BertForSequenceClassification
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=256, problem_type="multi_label_classification")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
embeddings = torch.zeros((1, 4, 768))

out = model(inputs_embeds=embeddings)

out.logits

tensor([[ 0.4179, -0.6047, -0.0948, -0.0974,  0.0524, -0.0569,  0.2535,  0.2742,
         -0.7018,  0.2145, -0.1129,  0.3524, -0.3026, -0.0278,  0.1961, -0.2796,
          0.1969,  0.1593,  0.0839,  0.3601,  0.0644,  0.1702,  0.3429, -0.2367,
          0.0768, -0.1154, -0.0511,  0.2031,  0.3877, -0.4218, -0.4209, -0.2244,
          0.3164, -0.0452,  0.4087, -0.3593,  0.2016, -0.0212, -0.4086, -0.4624,
         -0.0613,  0.0263,  0.0206,  0.2927, -0.1594,  0.1450, -0.1615,  0.2477,
          0.3275,  0.0871,  0.6448, -0.5960, -0.1620, -0.4100, -0.1182,  0.4707,
         -0.3884, -0.1833,  0.2390,  0.0340,  0.1669, -0.0933, -0.1295, -0.2809,
         -0.0256, -0.2331,  0.3238,  0.0989,  0.1943,  0.0863,  0.2977, -0.5551,
          0.2418,  0.0092, -0.2980, -0.2051, -0.2934,  0.5690, -0.3704, -0.1538,
         -0.5173,  0.0311, -0.0684,  0.2077, -0.0089, -0.2699, -0.5253,  0.1557,
         -0.0649,  0.4046,  0.0069,  0.3059, -0.2447, -0.1999, -0.2108,  0.1310,
          0.0512, -0.0582, -

In [3]:
from torch import nn, Tensor
import torch
import math
from typing import List
from einops import rearrange


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

class TransformerDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        d_model = 768
        nhead = 2
        num_layers = 2
        dropout = 0.5
        n_classes = 256

        self.pos_encoder = PositionalEncoding(d_model, dropout)

        decoder_layer = torch.nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(d_model, 1)

        self.query_embed = nn.Embedding(n_classes, d_model)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)


    def forward(self, embeddings: List[Tensor]) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """

        out_batch = []

        for sample_emb in embeddings:
            sample_emb = sample_emb.unsqueeze(0)
            sample_emb = self.pos_encoder(sample_emb)
            output = self.transformer_decoder(self.query_embed.weight.unsqueeze(0), sample_emb)
            output = self.linear(output).view(1, -1)

            out_batch.append( output )

        out_batch = torch.cat(out_batch, dim=0)

        return out_batch
    

device = torch.device(0)
model = TransformerDecoder().to(device)
embeddings = [torch.rand(8, 768, device=device), torch.rand(3, 768, device=device)]

out = model(embeddings)

In [29]:
from typing import Any
from transformers import WhisperModel
import torch
from torch import nn
import math

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

class Whisper(nn.Module):
    def __init__(self, freeze) -> None:
        super().__init__()
        self.model = WhisperModel.from_pretrained("openai/whisper-base")

        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

        self.embed_proj = nn.Linear(768, 512)
        self.pos_encoder = PositionalEncoding(512, 0.1)
        self.decoder_inputs_embeds = nn.Embedding(256, 512)
        self.output_proj = nn.Linear(512, 1)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embed_proj.bias.data.zero_()
        self.embed_proj.weight.data.uniform_(-initrange, initrange)

        self.output_proj.bias.data.zero_()
        self.output_proj.weight.data.uniform_(-initrange, initrange)

    def forward(self, embeddings):
        out_batch = []

        for sample_emb in embeddings:
            sample_emb = sample_emb.unsqueeze(0)
            sample_emb = self.embed_proj(sample_emb)
            sample_emb = self.pos_encoder(sample_emb)
            output = self.model(encoder_outputs=(sample_emb,), decoder_inputs_embeds=self.decoder_inputs_embeds.weight.unsqueeze(0)).last_hidden_state
            output = self.output_proj(output).view(1, -1)

            out_batch.append( output )

        out_batch = torch.cat(out_batch, dim=0)

        return out_batch

device = torch.device('cpu')
model = Whisper(True).to(device)

embeddings = [torch.rand(8, 768, device=device), torch.rand(3, 768, device=device)]
out = model(embeddings)

torch.Size([1, 256, 512])
torch.Size([1, 8, 512])
torch.Size([1, 256, 512])
torch.Size([1, 3, 512])


In [None]:
embeds = torch.rand(1, 8, 512, device=device)
decoder_inputs_embeds = torch.rand(1, 256, 512)

output = model(encoder_outputs=(embeds,), decoder_inputs_embeds=decoder_inputs_embeds)

output.last_hidden_state.shape