In [1]:
# default_exp models.glp

# Reload edited modules automatically so changes to imported code are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [10]:
# Make the parent directory importable (presumably the package root — confirm).
# NOTE(review): this cell ran as In [10], after later cells; on a fresh kernel it
# must run before anything that imports from the package.
import sys
sys.path.append('../')

# models.glp

Models related to Genomic Language Processing

In [2]:
#hide
#export

import pandas as pd
import numpy as np

from fastai.text.all import *

## layers

The outputs of the ProtBert models need a special type of pooling, `masked_concat_pool`, which concatenates:

  - The last unmasked token (the final state of the model).
  - The max value of each feature along the unmasked tokens.
  - The mean value of each feature along the unmasked tokens.
  
So, for an input of shape (batch, seq_len, feature) we get a (batch, 3\*feature) tensor. This was adapted from the fastai library code and refactored to handle some odd cases I often saw (e.g. NaN/inf values, which are replaced with 0).

`bptt` refers to the max sequence length (which may or may not be the same as the `seq_len` dimension); only the last `bptt` mask positions are used to locate the last unmasked token.

In [3]:
# export

def masked_concat_pool(output, mask, bptt):
    """Pool encoder outputs into one vector per sequence: ``[last_hidden, max_pool, avg_pool]``.

    Args:
        output: tensor indexed as ``(batch, seq_len, features)``.
        mask: ``(batch, seq_len)`` tensor where True (non-zero) marks a MASKED
            position that is hidden from pooling. Non-bool masks (e.g. 0/1 long
            tensors) are converted with ``.bool()``.
        bptt: number of trailing mask positions inspected to locate the last
            unmasked token. Padding is assumed to sit at the END of each row.

    Returns:
        Tensor of shape ``(batch, 3 * features)``. Any NaN/inf entries (e.g. the
        0/0 mean or empty max of a fully masked row) are replaced with 0.
    """
    # True in mask implies MASKED and will be hidden!
    mask = mask.bool()  # masked_fill only accepts boolean masks
    seq_len = output.size(1)

    # Number of real (unmasked) tokens per row: the denominator of the mean.
    lens = seq_len - mask.long().sum(dim=1)
    # Number of masked positions in the trailing window; the last real token
    # sits just before them.
    last_lens = mask[:, -bptt:].long().sum(dim=1)

    avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)
    avg_pool.div_(lens.type(avg_pool.dtype)[:, None])
    max_pool = output.masked_fill(mask[:, :, None], -float('inf')).max(dim=1)[0]

    # Clamp so a fully masked row (last_lens == seq_len) indexes position 0
    # instead of one step before the start of the sequence (IndexError).
    last_idx = (-last_lens - 1).clamp(min=-seq_len)
    rows = torch.arange(0, output.size(0), device=output.device)
    last_hidden = output[rows, last_idx]

    x = torch.cat([last_hidden, max_pool, avg_pool], 1)  # Concat pooling.
    return torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x)

In [4]:
# Toy fixture: rows are tokens, columns are features (4 tokens x 2 features).
feat_seq_1 = [[1, 5],
              [2, 6],
              [3, 7],
              [4, 8]]

# NOTE(review): feat_seq_2 is never used — `feats` below stacks feat_seq_1 twice,
# which is what the expected `cor_2` values in the next cell assume.
feat_seq_2 = [[6, 3], 
              [4, 7],
              [5, 8],
              [0, 0]]

# True marks a masked position (as masked_concat_pool expects); the second row
# has its last two positions masked.
masked_1 = [False, False, False, False]
masked_2 = [False, False, True, True]

feats = tensor([feat_seq_1, feat_seq_1]).type(torch.float)
attn = tensor([masked_1, masked_2]).type(torch.bool)
# Look-back window for locating the last unmasked token (seq_len - 1 here).
bptt = 3
feats, attn, bptt

(tensor([[[1., 5.],
          [2., 6.],
          [3., 7.],
          [4., 8.]],
 
         [[1., 5.],
          [2., 6.],
          [3., 7.],
          [4., 8.]]]),
 tensor([[False, False, False, False],
         [False, False,  True,  True]]),
 3)

In [5]:
# Hand-computed expected [last_hidden, max_pool, avg_pool] for each fixture row.
#        last  max    mean
cor_1 = [4, 8, 4, 8, 2.5, 6.5]
cor_2 = [2, 6, 2, 6, 1.5, 5.5]

cor = tensor([cor_1, cor_2])
test_eq(masked_concat_pool(feats, attn, bptt), cor)

This is also wrapped in a layer, `MaskedConcatPooling`, for easier use; its `forward` takes an `(output, mask)` tuple.

In [6]:
# export

class MaskedConcatPooling(Module):
    "Layer wrapper around `masked_concat_pool`; `forward` takes an `(output, mask)` tuple."

    def __init__(self, seq_len = None, mask_is_attention = False):
        # With a fixed `seq_len` the look-back window is precomputed; otherwise it
        # is derived from the mask width on each forward pass.
        self.bptt = None if seq_len is None else seq_len - 1
        # True when the incoming mask follows the attention convention (truthy =
        # real token); it is then inverted, since `masked_concat_pool` treats True
        # as a masked position.
        self.mask_is_attention = mask_is_attention

    def forward(self, x):
        output, mask = x
        # Use `width - 1`, matching the `seq_len - 1` convention above and the
        # window used by `model_mask_pooting`, so `seq_len=None` behaves exactly
        # like passing the actual sequence length.
        bptt = mask.shape[1] - 1 if self.bptt is None else self.bptt
        if self.mask_is_attention:
            mask = mask == 0
        return masked_concat_pool(output, mask, bptt)
                

In [7]:
# seq_len=4 gives a look-back window of 3, matching the direct call above.
pooler = MaskedConcatPooling(seq_len = 4)
test_eq(pooler((feats, attn)), cor)

In [8]:
# seq_len=None derives the look-back window from the mask width at call time.
pooler = MaskedConcatPooling(seq_len = None)
test_eq(pooler((feats, attn)), cor)

Oftentimes it is useful to pass tokens and masks through the model and collect the pooled outputs without any gradients.

In [16]:
# export

def model_mask_pooting(input_ids, attention_mask, model, bs = 32):
    """Run `model` without gradients and concat-pool its first output.

    Args:
        input_ids: token-id tensor, one row per sequence.
        attention_mask: tensor shaped like `input_ids` where non-zero marks a real
            token and 0 marks padding (attention-mask convention).
        model: callable accepting `input_ids=` and `attention_mask=` keywords; its
            result's element 0 is pooled (presumably the last hidden state —
            TODO confirm for non-BERT models).
        bs: batch size for chunked inference; `None` runs all rows in one call.

    Returns:
        Pooled features stacked row-wise, one row per input sequence.
    """
    bptt = input_ids.shape[1] - 1

    def _pool(ids, attn):
        # Run one chunk and pool it.
        res = model(input_ids = ids, attention_mask = attn)
        # masked_concat_pool expects True == masked, i.e. the inverse of an
        # attention mask (and a boolean tensor for masked_fill).
        return masked_concat_pool(res[0], attn == 0, bptt)

    with torch.no_grad():
        if bs is None:
            return _pool(input_ids, attention_mask)
        return torch.vstack([_pool(input_ids[start:start + bs],
                                   attention_mask[start:start + bs])
                             for start in range(0, input_ids.shape[0], bs)])

In [11]:
# NOTE(review): `from_pretrained` with a hub id presumably downloads the weights on
# first run (network access needed); consider moving this import to the top cell.
from transformers import AutoModel


model_name = 'Rostlab/prot_bert'
model = AutoModel.from_pretrained(model_name)

In [14]:
# 50 identical (input_ids, attention_mask) pairs: 6 attended tokens + 1 position with attention 0.
inp = [(fastuple(tensor([ 2, 21, 11,  8,  5, 3,  0]), tensor([ 1, 1, 1,  1,  1, 1,  0])))]*50

# Stack into (50, 7) id and attention-mask tensors.
input_ids, atten_mask = zip(*inp)
input_ids = torch.vstack(input_ids)
atten_mask = torch.vstack(atten_mask)


In [17]:
# Pool all 50 sequences, running the model in chunks of 32 rows.
res = model_mask_pooting(input_ids, atten_mask, model, bs = 32)

In [20]:
# One row per sequence of 3 concatenated vectors; 3072 = 3 * 1024, presumably the
# prot_bert hidden size — confirm.
test_eq(res.shape, (50, 3072))