In [1]:
# default_exp transforms.sequence

%reload_ext autoreload
%autoreload 2

# transforms.sequence

Transforms useful for processing sequence data.

In [2]:
import sys
sys.path.append('../')

In [3]:
#hide
#export

from itertools import islice
import pandas as pd
import numpy as np

from fastai.text.all import *


## HuggingFace interface

Many of the leading BERT embedding models are distributed as HuggingFace models.

The `Pipeline` and `Transform`s below help bridge the gap between fast.ai and HuggingFace.

### tokenizers

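First, we have to deal with the tokenizer. It expects amino-acid (AA) sequences with a space between every residue (e.g. `'M I V L R'` rather than `'MIVLR'`), so the first transform inserts those spaces.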

In [4]:
#export

def space_adder(seq):
    """Return `seq` with a single space between each consecutive residue.

    Example: ``space_adder('MIVLR') == 'M I V L R'``. This is the
    space-delimited form the HuggingFace tokenizer in this module expects.
    """
    residues = (residue for residue in seq)
    return ' '.join(residues)


class SpaceTransform(Transform):
    """Adds spaces between AAs for HuggingFace, and removes them on decode.

    `encodes` accepts a single sequence string or an iterable of sequence
    strings; `decodes` is its inverse for both forms.
    """

    def encodes(self, x):
        # A single sequence: space-delimit it directly. `isinstance` (rather than
        # an exact type check) keeps str subclasses on this path instead of
        # iterating them character by character.
        if isinstance(x, str):
            return space_adder(x)

        # Otherwise treat `x` as a collection of sequences.
        return L(space_adder(seq) for seq in x)

    def decodes(self, x):
        # A single spaced sequence: strip the separators so decode inverts encode
        # (this branch must return the string, not fall through to None).
        if isinstance(x, str):
            return x.replace(' ', '')

        return [seq.replace(' ', '') for seq in x]
        

In [5]:
# One space between every residue, none at the ends.
test_eq(space_adder('MIVLR'),
        'M I V L R')

In [6]:
tst = ['MIVLR', 'AAR']
cor = ['M I V L R', 'A A R']

# A one-step pipeline: every raw sequence comes out space-delimited.
space_tfm = SpaceTransform()
pipe = Pipeline([space_tfm])

test_eq(pipe(tst), cor)

Now that we have a space-delimited AA string, it needs to pass through the HuggingFace `tokenizer`.

In [7]:
from transformers import AutoTokenizer

# ProtBert checkpoint on the HuggingFace hub; its tokenizer expects space-delimited residues.
model_name = 'Rostlab/prot_bert'
tokenizer = AutoTokenizer.from_pretrained(
    model_name
)

`HFTokenizerWrapper` wraps a HuggingFace tokenizer as a fastai `Transform`. It converts space-delimited sequences into integer token IDs and, optionally, attention masks.

In [8]:
#export

class HFTokenizerWrapper(Transform):
    """Wrap a HuggingFace tokenizer as a fastai `Transform`.

    `encodes` turns a space-delimited sequence string (or an iterable of them)
    into padded/truncated token-ID tensors; `decodes` turns token IDs back into
    strings via the tokenizer's `batch_decode`.

    Args:
        tokenizer: a HuggingFace tokenizer (e.g. from `AutoTokenizer.from_pretrained`).
        tokens_only: if True, `encodes` returns the `input_ids` tensor only;
            otherwise it returns a list with one `fastuple(input_ids, attention_mask)`
            per sequence.
        truncation, max_length, padding: forwarded to the tokenizer call.
        skip_special_tokens: forwarded to `batch_decode` in `decodes`.
        device: device the encoded tensors are moved to.
    """

    def __init__(self, tokenizer, tokens_only = True, 
                 truncation = True, max_length = 128,
                 padding = 'max_length', 
                 skip_special_tokens = True,
                 device = 'cuda'):
        self.tokenizer = tokenizer
        self.tokens_only = tokens_only
        self.truncation = truncation
        self.max_length = max_length
        self.padding = padding
        self.skip_special_tokens = skip_special_tokens
        self.device = device

    def encodes(self, x):
        """Tokenize one sequence string or an iterable of sequence strings."""
        # Normalize a single sequence into a batch of one.
        seqs = [x] if isinstance(x, str) else list(x)

        tokenized = self.tokenizer(seqs,
                                   return_tensors='pt',
                                   padding=self.padding,
                                   truncation=self.truncation,
                                   max_length=self.max_length)
        tokenized = tokenized.to(self.device)

        if self.tokens_only:
            return tokenized['input_ids']

        # One (input_ids, attention_mask) pair per sequence; iterating a 2-D
        # tensor yields its rows.
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        return [fastuple(ids, mask) for ids, mask in zip(input_ids, attention_mask)]

    def decodes(self, x):
        """Decode token IDs (or `(input_ids, attention_mask)` pairs) into strings."""
        # With tokens_only=False, `encodes` emits fastuple pairs; batch_decode
        # only understands token IDs, so drop the attention masks first.
        if len(x) and isinstance(x[0], fastuple):
            x = [ids for ids, _ in x]

        return self.tokenizer.batch_decode(x, skip_special_tokens = self.skip_special_tokens)

In [9]:
token_tfm = HFTokenizerWrapper(tokenizer, device = 'cpu', max_length=7)

# A bare string is treated as a batch of one, so both call forms must agree.
# Ids 2 and 3 bracket the residues (presumably [CLS]/[SEP] — confirm against the vocab).
test_eq(token_tfm(['M I V L R']), tensor([[ 2, 21, 11,  8,  5, 13,  3]]))
test_eq(token_tfm('M I V L R'),   tensor([[ 2, 21, 11,  8,  5, 13,  3]]))

In [10]:
# Decoding drops the bracketing special tokens (skip_special_tokens defaults to True).
test_eq(token_tfm.decode([[2, 21, 11, 8, 5, 13, 3]]),
        ['M I V L R'])

In [11]:
token_tfm = HFTokenizerWrapper(tokenizer, tokens_only = False, device = 'cpu', max_length=7)

# With tokens_only=False each sequence yields fastuple(input_ids, attention_mask).
test_eq(token_tfm('M I V L R'),
        [fastuple(tensor([2, 21, 11, 8, 5, 13, 3]),
                  tensor([1,  1,  1, 1, 1,  1, 1]))])

# A shorter sequence is padded with id 0, and its mask is 0 over the padding.
test_eq(token_tfm(['M I V L']),
        [fastuple(tensor([2, 21, 11, 8, 5, 3, 0]),
                  tensor([1,  1,  1, 1, 1, 1, 0]))])


In [12]:
tst = ['MIVLR', 'AAR']
cor = [[2, 21, 11,  8, 5, 13, 3],
       [2,  6,  6, 13, 3,  0, 0]]

# Raw sequences -> space-delimited strings -> padded token ids.
space_tfm = SpaceTransform()
token_tfm = HFTokenizerWrapper(tokenizer, device = 'cpu', max_length=7)
pipe = Pipeline([space_tfm, token_tfm])

test_eq(pipe(tst), tensor(cor))

In [13]:
# Decoding the token ids must recover the original, unspaced sequences.
test_eq(pipe.decode(tensor(cor)),
        tst)

### Pretraining Transforms

Sometimes, instead of running the ProtBert model over all of your sequences every time, you want to pre-compute their embeddings once using a `Transform`. This can drastically speed up analysis if you never intend to train the encoder.

In [14]:
# export

from justenough.models.glp import model_mask_pooling

class HFPoolingTransform(Transform):
    """Embed tokenized sequences with a HuggingFace model and pool them.

    Accepts either of the `HFTokenizerWrapper` outputs:
      * the `input_ids` tensor (``tokens_only=True``), or
      * a list of ``fastuple(input_ids, attention_mask)`` (``tokens_only=False``),
    and returns the result of `model_mask_pooling`. When exactly one sequence
    is encoded, the leading batch dimension is squeezed away.

    Args:
        model: HuggingFace model passed to `model_mask_pooling`.
        bs: batch size passed to `model_mask_pooling`.
        pad_token_id: token id treated as padding when only `input_ids` are
            given (defaults to 0, the padding id emitted by the ProtBert tokenizer).
    """

    def __init__(self, model, bs = 32, pad_token_id = 0):
        self.model = model
        self.bs = bs
        self.pad_token_id = pad_token_id

    def encodes(self, x):
        if isinstance(x[0], fastuple):
            # Stack the per-sequence (ids, mask) pairs into (N, L) tensors.
            input_ids, attention = zip(*x)
            input_ids = torch.vstack(input_ids)
            attention = torch.vstack(attention).type(torch.bool)
        else:
            input_ids = x
            # Rebuild the mask with the same convention as the branch above
            # (the tokenizer's attention_mask): True on real tokens, False on
            # padding. Using `x == 0` here inverted it, so the two input forms
            # were pooled with opposite masks.
            # NOTE(review): assumes model_mask_pooling follows the HF
            # attention_mask convention (True = attend) — confirm in glp.
            attention = x != self.pad_token_id

        out = model_mask_pooling(input_ids, attention, self.model, bs = self.bs)

        # Drop the batch dimension when a single sequence was encoded.
        if out.shape[0] == 1:
            out = torch.squeeze(out, 0)

        return out
        

In [15]:
from transformers import AutoModel

# Same ProtBert checkpoint as the tokenizer loaded earlier.
model_name = 'Rostlab/prot_bert'
model = AutoModel.from_pretrained(
    model_name
)

In [16]:
token_tfm = HFTokenizerWrapper(tokenizer, tokens_only=True, device = 'cpu', max_length=6)
bert_pool_tfm = HFPoolingTransform(model)
pipe = Pipeline([space_tfm, token_tfm, bert_pool_tfm])

# 2 sequences x 100 copies -> 200 pooled embeddings of width 3072.
# NOTE(review): only the shape is checked here, not the embedding values.
encoded = pipe(tst*100)
test_eq(encoded.shape, (200, 3072))

In [17]:
token_tfm = HFTokenizerWrapper(tokenizer, tokens_only=False, device = 'cpu', max_length=6)
bert_pool_tfm = HFPoolingTransform(model)
pipe = Pipeline([space_tfm, token_tfm, bert_pool_tfm])

# Same input, but the pooling transform now receives (input_ids, attention_mask) pairs.
# NOTE(review): only the shape is checked here, not the embedding values.
encoded = pipe(tst*100)
test_eq(encoded.shape, (200, 3072))