In [1]:
# default_exp autocoder

In [2]:
# reload edited library modules automatically before each cell executes
%reload_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook
%matplotlib inline

# AutoCoder

> Automatically codes (i.e., labels or categorizes) text fields, such as responses to open-ended survey questions.

In [3]:
#hide
from nbdev.showdoc import *

In [4]:
#export

import math
import warnings
import numpy as np

def list2chunks(a, n):
    """
    Split the sequence `a` into `n` contiguous slices whose lengths differ by at most one.

    The first ``len(a) % n`` slices each hold one extra element. Slices are
    produced lazily by the returned generator; `n == 0` raises immediately.
    """
    base, extra = divmod(len(a), n)
    chunk_ids = range(n)

    def _slices():
        start = 0
        for idx in chunk_ids:
            # earlier chunks absorb the remainder, one element each
            stop = start + base + (1 if idx < extra else 0)
            yield a[start:stop]
            start = stop

    return _slices()

In [5]:
#export

class ZeroShotClassifier():
    """
    interface to Zero Shot Topic Classifier

    Each (document, label) pair is scored with a Natural Language Inference (NLI)
    model: the document is the premise and the label, inserted into a hypothesis
    template, is the hypothesis.
    """

    def __init__(self, model_name='facebook/bart-large-mnli', device=None):
        """
        ZeroShotClassifier constructor

        Args:
          model_name(str): name of a BART NLI model; must contain 'mnli' or 'xnli'
          device(str): device to use (e.g., 'cuda', 'cpu').
                       If None, 'cuda' is used when available, otherwise 'cpu'.
        Raises:
          ValueError: if model_name does not name an MNLI/XNLI model
          ImportError: if PyTorch is not installed
        """
        if 'mnli' not in model_name and 'xnli' not in model_name:
            raise ValueError('ZeroShotClassifier requires an MNLI or XNLI model')
        try:
            import torch
        except ImportError as e:
            raise ImportError('ZeroShotClassifier requires PyTorch to be installed.') from e
        self.torch_device = device
        if self.torch_device is None:
            self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(self.torch_device)

    @staticmethod
    def _softmax(x, axis=-1):
        """Numerically stable softmax of numpy array `x` along `axis`."""
        # subtracting the max leaves the result unchanged but prevents exp() overflow
        shifted = x - np.max(x, axis=axis, keepdims=True)
        e = np.exp(shifted)
        return e / e.sum(axis=axis, keepdims=True)

    def predict(self, docs, labels=None, include_labels=False, multilabel=True,
               max_length=512, batch_size=8, nli_template='This text is about {}.',  topic_strings=None):
        """
        This method performs zero-shot text classification using Natural Language Inference (NLI).
        Args:
          docs(list|str): text of document or list of texts
          labels(list|str): a list of strings representing topics of your choice
                            (a single string is treated as one label)
                            Example:
                              labels=['political science', 'sports', 'science']
          include_labels(bool): If True, will return topic labels along with topic probabilities
          multilabel(bool): If True, labels are considered independent and multiple labels can be predicted true for a document and be close to 1.
                            If False, scores are normalized such that probabilities sum to 1.
          max_length(int): truncate long documents to this many tokens
          batch_size(int): batch_size to use. default:8
                           Increase this value to speed up predictions - especially
                           if len(topic_strings) is large.
          nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
          topic_strings(list|str): alias for labels parameter for backwards compatibility
        Returns:
          inferred probabilities or list of inferred probabilities if doc is list
        Raises:
          ValueError: if docs is empty or not string(s), if both or neither of
                      labels/topic_strings are given, or if batch_size < 1
        """
        # --- error checks ---
        is_str_input = False
        if not isinstance(docs, (list, np.ndarray)):
            docs = [docs]
            is_str_input = True
        if len(docs) == 0 or not all(isinstance(doc, str) for doc in docs):
            raise ValueError('docs must be string or a list of strings representing document(s)')
        # None defaults avoid shared mutable default arguments
        labels = [] if labels is None else labels
        topic_strings = [] if topic_strings is None else topic_strings
        # a bare string is one label, not a sequence of one-character labels
        if isinstance(labels, str): labels = [labels]
        if isinstance(topic_strings, str): topic_strings = [topic_strings]
        if len(labels) > 0 and len(topic_strings) > 0: raise ValueError('labels and topic_strings are mutually exclusive')
        if len(labels) == 0 and len(topic_strings) == 0: raise ValueError('labels must be a list of strings')
        labels = list(topic_strings) if len(topic_strings) > 0 else list(labels)
        if batch_size < 1: raise ValueError('batch_size must be a positive integer')

        # --- convert to (premise, hypothesis) pairs in document-major, label-minor order ---
        sequence_pairs = [[premise, nli_template.format(label)] for premise in docs for label in labels]
        batch_size = min(batch_size, len(sequence_pairs))
        if len(sequence_pairs) >= 100 and batch_size == 8:
            warnings.warn('TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions')
        num_chunks = math.ceil(len(sequence_pairs) / batch_size)
        sequence_chunks = list2chunks(sequence_pairs, n=num_chunks)

        # --- inference ---
        import torch
        outputs = []
        with torch.no_grad():
            for sequences in sequence_chunks:
                batch = self.tokenizer.batch_encode_plus(sequences, return_tensors='pt', max_length=max_length,
                                                         truncation='only_first', padding=True).to(self.torch_device)
                logits = self.model(batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=False)[0]
                outputs.extend(logits.cpu().detach().numpy())
        # regroup the flat pair order back into (num_docs, num_labels, num_model_outputs)
        outputs = np.array(outputs).reshape((len(docs), len(labels), -1))

        # --- process outputs ---
        # NOTE(review): assumes output index 0 is "contradiction" and the last index is
        # "entailment" (the label order of facebook/bart-large-mnli) -- confirm for other models
        if multilabel:
            # softmax over the contradiction vs. entailment logits of each label independently
            scores = self._softmax(outputs[..., [0, -1]])[..., 1]
        else:
            # softmax the "entailment" logits over all candidate labels
            scores = self._softmax(outputs[..., -1])
        scores = scores.tolist()
        if include_labels:
            scores = [list(zip(labels, s)) for s in scores]
        if is_str_input: scores = scores[0]
        return scores



In [6]:
# candidate topics and a sample document to classify
labels = ['politics', 'elections', 'sports', 'films', 'television']
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
# load the default NLI model and score every topic for the document
zsl = ZeroShotClassifier()
preds = zsl.predict(doc, labels=labels, include_labels=True)

In [7]:
# multilabel=True by default, so each label is scored independently and the probabilities need not sum to 1
preds

[('politics', 0.979189932346344),
 ('elections', 0.9874580502510071),
 ('sports', 0.0005765454261563718),
 ('films', 0.002292441902682185),
 ('television', 0.001054605352692306)]

In [8]:
d = dict(preds)
# on-topic labels must be confidently detected ...
assert min(d['politics'], d['elections']) > 0.9
# ... while unrelated labels must stay near zero
assert max(d['sports'], d['films'], d['television']) < 0.1

In [9]:
#export
class AutoCoder:
    """
    Autocodes text fields (e.g., responses to open-ended survey questions)
    using zero-shot classification.
    """
    def __init__(self, verbose=1):
        """
        constructor

        Args:
          verbose(int): verbosity level, stored as ``self.v``
        """
        self.v = verbose
        self.zsl = ZeroShotClassifier()

    @staticmethod
    def _as_text_list(texts):
        """Wrap a single string in a list; reject anything that is not a string or a list."""
        if isinstance(texts, str): texts = [texts]
        if not isinstance(texts, list): raise ValueError('texts must be a string or a list of strings')
        return texts

    def sentiment(self, texts, batch_size=8):
        """
        Autocodes text for positive or negative sentiment

        Args:
          texts(str|list): a text or a list of texts
          batch_size(int): batch size forwarded to ZeroShotClassifier.predict
        Returns:
          a list (one entry per text) of [('negative', prob), ('positive', prob)];
          multilabel=False, so the two probabilities for each text sum to 1
        Raises:
          ValueError: if texts is neither a string nor a list
        """
        texts = self._as_text_list(texts)
        return self.zsl.predict(texts, labels=['negative', 'positive'], include_labels=True, multilabel=False,
                                batch_size=batch_size,
                                nli_template="The sentiment of this movie review is {}.")

    def custom_topic(self, texts, labels, batch_size=8):
        """
        Autocodes text for user-specified topics.

        Args:
          texts(str|list): a text or a list of texts
          labels(str|list): the name of a topic as a string, or a list of topic names
          batch_size(int): batch size forwarded to ZeroShotClassifier.predict
        Returns:
          a list (one entry per text) of (label, probability) pairs; labels are
          scored independently (multilabel), so probabilities need not sum to 1
        Raises:
          ValueError: if texts is neither a string nor a list
        """
        texts = self._as_text_list(texts)
        # a single topic given as a string must not be split into characters
        if isinstance(labels, str): labels = [labels]
        return self.zsl.predict(texts, labels=labels, include_labels=True, batch_size=batch_size)


In [10]:
# two short reviews with opposite polarity
reviews = ["I loved this doctor!", "This doctor was absolutely terrible."]
ac = AutoCoder()
result = ac.sentiment(reviews)
result

[[('negative', 0.005033864174038172), ('positive', 0.9949660897254944)],
 [('negative', 0.9817894101142883), ('positive', 0.018210623413324356)]]

In [11]:
# each review must be coded with the expected polarity
d_negative_review = dict(result[1])
assert d_negative_review['negative'] > 0.9
assert d_negative_review['positive'] < 0.1
d = dict(result[0])
assert d['negative'] < 0.1
assert d['positive'] > 0.9

In [12]:
comment = "What is your favorite sitcom of all time?"
# score the comment against user-chosen topics
result = ac.custom_topic(
    comment,
    labels=['television', 'film', 'politics'],
)
result

[[('television', 0.9813268780708313),
  ('film', 0.012259923852980137),
  ('politics', 0.0001566773426020518)]]

In [13]:
d = dict(result[0])
# a sitcom question is about television, not film or politics
assert d['television'] > 0.9
assert max(d['film'], d['politics']) < 0.1

In [14]:
#hide
from nbdev.export import notebook2script
# export all #export cells of every notebook to the library modules
notebook2script()

Converted 00_causalinference.ipynb.
Converted 01_autocoder.ipynb.
Converted index.ipynb.
