In [4]:
# default_exp transforms.dataloaders

%reload_ext autoreload
%autoreload 2

In [5]:
import sys
# Add the parent directory to the import path.
# NOTE(review): presumably so the in-repo `justenough` package imported below resolves — confirm against the repo layout.
sys.path.append('../')

# transforms.dataloaders

DataLoader objects for getting sequences and DataFrames into fastai datasets.

In [25]:
# export 

from fastai.text.all import *


from justenough.transforms.sequence import *

This dataloader turns sequences into vectors by running them through a HuggingFace BERT encoder as a preprocessing step.

Data can either be pre-processed ahead of time (`precompute=True`) or processed on the fly (`precompute=False`).

In [26]:

# export


class HFBertDataLoaders(DataLoaders):
    "`DataLoaders` factory that embeds the sequences of a DataFrame with a HuggingFace model."

    @staticmethod
    def from_df(frame, tokenizer, model, sequence_col = 'sequence', label_col = None, vocab=None,
                max_length = 128, device = 'cuda', bs = 32, precompute = True,
                splitter = None, num_workers = 0):
        """Build train/validation `DataLoaders` from `frame`.

        Parameters
        ----------
        frame : the table holding the sequences (and, optionally, the labels).
        tokenizer : HuggingFace tokenizer, handed to `HFTokenizerWrapper`.
        model : HuggingFace model, handed to `HFPoolingTransform`.
        sequence_col : column read as the input sequence.
        label_col : column read as the target and passed through `Categorize`.
            When `None`, the target uses the same transforms as the input.
        vocab : forwarded to `Categorize`; only used when `label_col` is given.
        max_length : forwarded to `HFTokenizerWrapper`.
        device : forwarded to `HFTokenizerWrapper`; the returned loaders are also
            moved to it.
        bs : batch size of the loaders; also forwarded to `HFPoolingTransform`.
        precompute : if True, the sequence pipeline is run over the whole frame
            before the loaders are built; otherwise the transforms are attached
            to a `Datasets` and applied when items are fetched.
        splitter : callable returning `(train_idxs, valid_idxs)`; defaults to
            `RandomSplitter()`.
        num_workers : forwarded to the loaders.

        Returns
        -------
        `DataLoaders` moved to `device`.
        """

        if splitter is None:
            splitter = RandomSplitter()

        # Input side: read the column, space-transform, tokenize, then pool with the model.
        seq_tfms = [ColReader(sequence_col),
                    SpaceTransform(),
                    HFTokenizerWrapper(tokenizer,
                                       max_length=max_length,
                                       tokens_only=False,
                                       device = device),
                    HFPoolingTransform(model, bs=bs)]

        # Target side: without a label column the target reuses the input transforms.
        if label_col is None:
            label_tfms = seq_tfms
        else:
            label_tfms = [ColReader(label_col), Categorize(vocab=vocab)]

        # NOTE(review): num_workers > 0 starts worker processes; whether the model-backed
        # transforms are safe inside workers (especially with a CUDA device) is not
        # verified here — confirm before raising it above 0.
        if precompute:
            # Run the whole sequence pipeline once, up front.
            seq_tls = Pipeline(seq_tfms)(frame)

            if label_col is None:
                label_tls = seq_tls
            else:
                label_tls = TfmdLists(frame, label_tfms)

            # Pair each precomputed input with its target; no further transforms are needed.
            tls = TfmdLists(zip(seq_tls, label_tls), [])
            train_idxs, valid_idxs = splitter(tls)

            # Honour the caller's `num_workers` (previously hard-coded to 0).
            return DataLoaders.from_dsets(tls[train_idxs], tls[valid_idxs],
                                          num_workers=num_workers, bs=bs).to(device)

        # On-the-fly path: transforms run when items are fetched.
        train_idxs, valid_idxs = splitter(frame)
        feat_tls = Datasets(frame, [seq_tfms, label_tfms],
                            splits = (train_idxs, valid_idxs))

        # Honour the caller's `num_workers` (previously hard-coded to 0).
        return feat_tls.dataloaders(num_workers=num_workers, bs=bs).to(device)

In [27]:
# Example data: keep only rows that have a tissue label, then take the first 100 of those.
df = (
    pd.read_csv('../tutorials/HIV_tat_example.csv')
    .dropna(subset=['sample_tissue'])
    .head(100)
)
df.head()

Unnamed: 0,accession,sample_tissue,coreceptor,sequence
0,M17449,PBMC,CXCR4,MEPVDPRLEPWKHPGSQPKTACTTCYCKKCCFHCQVCFTKKALGISYGRKKRRQRRRAPEDSQTHQVSLPKQPAPQFRGDPTGPKESKKKVERETETHPVD*
1,M26727,PBMC,CCR5,MEPVDPRLEPWKHPGSQPKTASNNCYCKRCCLHCQVCFTKKGLGISYGRKKRRQRRRAPQDSKTHQVSLSKQPASQPRGDPTGPKESKKKVERETETDPED*
2,M17451,PBMC,CCR5|CXCR4,MEPVDPRLEPWKHPGSQPKTACNNCYCKKCCYHCQVCFLTKGLGISYGRKKRRQRRGPPQGSQTHQVSLSKQPTSQPRGDPTGPKESKEKVERETETDPAVQ
3,K02007,PBMC,CCR5|CXCR4,MEPVDPNLEPWKHPGSQPRTACNNCYCKKCCFHCYACFTRKGLGISYGRKKRRQRRRAPQDSQTHQASLSKQPASQSRGDPTGPTESKKKVERETETDPFD*
4,M62320,blood,,MEPVDPNLEPWKHPGSQPTTACSNCYCKVCCWHCQLCFLKKGLGISYGKKKRKPRRGPPQGSKDHQTLIPKQPLPQSQRVSAGQEESKKKVESKAKTDRFA*


In [9]:
from transformers import AutoTokenizer, AutoModel
model_name = 'Rostlab/prot_bert'  # model identifier shared by the tokenizer and the encoder below
device = 'cpu'  # device the encoder is moved to

# Load the tokenizer and the encoder for `model_name`, then place the encoder on `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

In [28]:
# Precomputed variant: the sequence pipeline is run over the whole frame before the loaders are built.
# `device` is the variable from the model-loading cell, so the loaders stay on the same device as the encoder.
dls = HFBertDataLoaders.from_df(
    df, tokenizer, model,
    sequence_col='sequence', label_col='sample_tissue',
    vocab=None,
    max_length=128, device=device,
    bs=32, precompute=True,
    splitter=None, num_workers=0,
)

In [29]:
# Fetch one batch and check its shape: 32 rows (the `bs` passed above) with one label per row.
# NOTE(review): 3072 is presumably the vector width HFPoolingTransform produces for this model — confirm.
x, y = dls.one_batch()
test_eq(x.shape, (32, 3072))
test_eq(y.shape, (32, ))

In [30]:
# On-the-fly variant: the transforms are attached to the datasets and applied when items are fetched.
# `device` is the variable from the model-loading cell, so the loaders stay on the same device as the encoder.
dls = HFBertDataLoaders.from_df(
    df, tokenizer, model,
    sequence_col='sequence', label_col='sample_tissue',
    vocab=None,
    max_length=128, device=device,
    bs=32, precompute=False,
    splitter=None, num_workers=0,
)

In [31]:
# Fetch one batch from the on-the-fly loaders and check its shape: 32 rows (the `bs` passed above) with one label per row.
# NOTE(review): 3072 is presumably the vector width HFPoolingTransform produces for this model — confirm.
x, y = dls.one_batch()
test_eq(x.shape, (32, 3072))
test_eq(y.shape, (32, ))