In [None]:
# default_exp nlp.core

%reload_ext autoreload
%autoreload 2

# nlp.core

This module contains the core structures shared between NLP and GLP and helps abstract their similarities.
Switching between DNA, protein, and English is just a matter of switching the internal models and tokenizers.

In [1]:
#hide
#export

from itertools import islice
from Bio import Entrez
import pandas as pd
import numpy as np

from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from fastai.text.all import *
import hdbscan

In [None]:
#export


def mean_pooling_attention(token_embeddings, attention_mask):
    "Average `token_embeddings` over dim 1, weighting each position by `attention_mask`"
    # Broadcast the per-position mask across the embedding dimension.
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_total = (token_embeddings * weights).sum(1)
    # Clamp the denominator so a row whose mask is all zeros does not divide by zero.
    denom = weights.sum(1).clamp(min=1e-9)
    return masked_total / denom



def masked_concat_pool(output, mask, bptt):
    """Pool `MultiBatchEncoder` outputs into one vector [last_hidden, max_pool, avg_pool]

    `mask` is a boolean tensor that is True at the positions to EXCLUDE from
    pooling (they are filled with 0 for the average and -inf for the max).
    `bptt` is the size of the trailing window in which masked positions are
    counted to locate the last unmasked step.

    Rows with nothing to pool yield zeros in every part of the result instead
    of nan/inf values or an out-of-range index.
    """
    seq_len = output.shape[1]
    # Number of unmasked positions per row (denominator of the average).
    lens = seq_len - mask.long().sum(dim=1)
    # Number of masked positions inside the trailing `bptt` window.
    last_lens = mask[:,-bptt:].long().sum(dim=1)

    avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)
    avg_pool.div_(lens.type(avg_pool.dtype)[:,None])
    max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]

    # Same position as the negative index `-last_lens-1`, written as a
    # non-negative index so the fully-masked case can be detected and clamped
    # rather than indexing one step before the start of the sequence.
    last_idx = seq_len - last_lens - 1
    no_last = last_idx < 0
    rows = torch.arange(0, output.size(0), device=output.device)
    last_hidden = output[rows, last_idx.clamp(min=0)]
    last_hidden = last_hidden.masked_fill(no_last[:, None], 0)

    x = torch.cat([last_hidden, max_pool, avg_pool], 1) #Concat pooling.
    # 0/0 averages and all -inf maxima (rows with no unmasked position) become 0.
    x = torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x)

    return x

class PoolingLinearClassifier(Module):
    "Classifier head: concat-pool the encoder output, then run it through a stack of `LinBnDrop` layers"
    def __init__(self, dims, ps, bptt, y_range=None):
        n_layers = len(dims) - 1
        if len(ps) != n_layers: raise ValueError("Number of layers and dropout values do not match.")
        # ReLU follows every layer except the last one, which has no activation.
        acts = [nn.ReLU(inplace=True)] * (n_layers - 1) + [None]
        blocks = []
        for n_in, n_out, drop, act in zip(dims[:-1], dims[1:], ps, acts):
            blocks.append(LinBnDrop(n_in, n_out, p=drop, act=act))
        # Optionally squash the final output into `y_range`.
        if y_range is not None: blocks.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*blocks)
        self.bptt = bptt

    def forward(self, input):
        hidden, pad_mask = input
        # The pooling window is derived from the sequence length here;
        # the stored `self.bptt` is not consulted.
        pooled = masked_concat_pool(hidden, pad_mask, hidden.shape[1]-1)
        return self.layers(pooled), hidden, hidden


In [None]:
#export


class TopicModelingInterface(object):
    """Embed text with a transformer, cluster the embeddings, and project them for plotting.

    Pipeline: tokenizer + model -> `masked_concat_pool` embedding ->
    UMAP with `cluster_dim` components -> HDBSCAN labels, plus a second,
    independent UMAP with `viz_dim` components for plotting coordinates.
    """

    def __init__(self, tokenizer = None, model = None, model_name = None, bs=8,
                 cluster_dim = 10, viz_dim = 2, device = 'cuda',
                 min_cluster_size = 5, max_length = 512):
        """Set up the tokenizer, model, and the UMAP/HDBSCAN reducers.

        tokenizer / model: ready-made objects; when either is None it is loaded
            from `model_name` with `AutoTokenizer` / `AutoModel`.  A loaded model
            is moved to `device`; a caller-supplied model is used as given.
        bs: default batch size for `text2embed`.
        cluster_dim: number of UMAP components fed to HDBSCAN.
        viz_dim: number of UMAP components for visualization
            (`process_df` reads the first two as X and Y).
        device: device the tokenized batches are moved to.
        min_cluster_size: forwarded to `hdbscan.HDBSCAN`.
        max_length: every text is padded/truncated to this many tokens.
        """
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        else:
            self.tokenizer = tokenizer

        if model is None:
            self.model = AutoModel.from_pretrained(model_name).to(device)
        else:
            self.model = model

        self.bs = bs
        self.device = device

        self.max_length = max_length
        self.viz_dim = viz_dim
        self.cluster_dim = cluster_dim

        self.umap_cluster = UMAP(n_components=cluster_dim)
        # prediction_data=True is what allows `approximate_predict` in `embed2cluster`.
        self.cluster = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, prediction_data=True)
        # Honor `viz_dim` instead of a hard-coded 2 (the default is still 2).
        self.umap_viz = UMAP(n_components=viz_dim)

    def text2embed(self, text, bs=None):
        """Embed `text` with the model and `masked_concat_pool`.

        A single `str` returns one embedding vector; any other iterable of
        strings returns a tensor with one row per string.  `bs` overrides the
        instance batch size for this call.
        """
        if type(text) is str: return self.text2embed([text], bs=bs)[0]

        bs = self.bs if bs is None else bs
        it = iter(text)
        out_data = []

        with torch.no_grad():
            batch = list(islice(it, bs))
            while batch:
                tokens = self.tokenizer(batch, return_tensors='pt', padding='max_length',
                                        truncation = True,
                                        max_length = self.max_length)
                tokens = tokens.to(self.device)
                res = self.model(**tokens)
                # `masked_concat_pool` EXCLUDES positions where its mask is True,
                # whereas `attention_mask` is 1 on real tokens and 0 on padding,
                # so the mask must be inverted to mark the padding.
                pad_mask = ~tokens['attention_mask'].type(torch.bool)
                out_data.append(masked_concat_pool(res[0], pad_mask, self.max_length))

                batch = list(islice(it, bs))

        return torch.vstack(out_data)

    def embed2cluster(self, embed, fit = True):
        """Reduce `embed` with the clustering UMAP and assign HDBSCAN labels.

        With `fit=True` both the UMAP and the clusterer are (re)fitted on
        `embed`; otherwise the already-fitted UMAP transforms `embed` and labels
        come from `hdbscan.approximate_predict`.
        Returns `(labels, reduced_data)`.
        """
        if fit:
            clst_data = self.umap_cluster.fit_transform(embed)
            self.cluster.fit(clst_data)
            labels = self.cluster.labels_
        else:
            clst_data = self.umap_cluster.transform(embed)
            labels, _ = hdbscan.approximate_predict(self.cluster, clst_data)

        return labels, clst_data

    def embed2xy(self, embed, fit = True):
        """Project `embed` with the visualization UMAP.

        With `fit=True` the visualization UMAP is (re)fitted on `embed`;
        otherwise the already-fitted visualization UMAP is reused.
        """
        if fit:
            xy = self.umap_viz.fit_transform(embed)
        else:
            # Must be the visualization reducer, not the clustering one, so the
            # coordinates live in the same space as those produced with fit=True.
            xy = self.umap_viz.transform(embed)
        return xy

    def process_df(self, df, col = 'text', fit = True):
        """Embed column `col` of `df`, then cluster and project the embeddings.

        Missing values in `col` are treated as empty strings.  Returns
        `(ndf, emb)`: `ndf` shares `df`'s index and holds `cluster`, `X`, `Y`,
        a string `label`, and one `d{n}` column per clustering dimension;
        `emb` is the embedding tensor from `text2embed`.
        """
        emb = self.text2embed(df[col].fillna('').tolist())
        emb_np = emb.cpu().numpy()

        clusters, cluster_data = self.embed2cluster(emb_np, fit = fit)
        xy = self.embed2xy(emb_np, fit = fit)

        ndf = pd.DataFrame({'cluster': clusters,
                            'X': xy[:, 0],
                            'Y': xy[:, 1],
                            'label': [str(c) for c in clusters]}, index = df.index)
        for n in range(self.cluster_dim):
            ndf[f'd{n}'] = cluster_data[:, n]

        return ndf, emb