# pipelines.classifier

> Wrappers for different approaches to text classification, including scikit-learn text classification, text classification with Hugging Face Transformers, and few-shot classification (via SetFit).

In [None]:
# | default_exp pipelines.classifier

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
# | export

from typing import List, Union

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=DeprecationWarning)
    from setfit import SetFitModel, TrainingArguments, Trainer, sample_dataset

import numpy as np

In [None]:
# | export

from abc import ABC, abstractmethod

DATASET_TEXT = "text"
DATASET_LABEL = "label"

class ClassifierBase(ABC):
    """
    Common interface and shared utilities for the text classifiers in this module.

    Subclasses must implement `train` and `save`, and are expected to provide
    `predict` (used by `evaluate`). Subclasses that support SHAP explanations
    override `_get_explain_predictor`.
    """

    @abstractmethod
    def train(self,
              X:List[str],
              y:Union[List[int], List[str]],
              max_steps:int=50,
              num_epochs:int=10,
              batch_size:int=32,
              metric='accuracy',
              callbacks=None,
              **kwargs,
             ):
        """
        Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
        Additional keyword arguments are passed directly to `SetFit.TrainingArguments`.

        **Args:**

        - *X*: List of texts
        - *y*: List of integers representing labels
        - *max_steps*: If set to a positive number, the total number of training steps to perform. Overrides num_epochs. 
        - *num_epochs*: Number of epochs to train
        - *batch_size*: Batch size
        - *metric*: metric to use
        - *callbacks*:  A list of callbacks to customize the training loop.

        **Returns:**

        - None
        """
        pass

    @abstractmethod
    def save(self, save_path:str):
        """
        Save model to specified folder path, `save_path`
        """
        pass


    def _get_explain_predictor(self, device=None):
        """
        Hook used by `explain`. Returns a `(predictor, tokenizer)` pair.

        The base implementation returns `(None, None)`, which `explain`
        interprets as "explanations are not supported for this model".
        """
        return None, None


    def get_explain_predictor(self):
        """
        Get predictor and tokenizer used for shap predictions.
        Delegates to `_get_explain_predictor`, so it returns `(None, None)`
        unless a subclass overrides that hook.
        """
        return self._get_explain_predictor()


    def get_labels(self):
        """
        Return the label names: `self.model.labels` when the wrapped model
        has a `labels` attribute, otherwise `self.labels`.
        """
        return self.model.labels if hasattr(self.model, 'labels') else self.labels


    def sample_examples(self, X:list, y:list, num_samples:int=8,
                        text_key:str=DATASET_TEXT, label_key:str=DATASET_LABEL):
        """
        Sample a dataset with `num_samples` per class

        **Returns:**

        - a tuple `(texts, labels)` taken from the sampled dataset
        """
        full_dataset = self.arrays2dataset(X, y, text_key=text_key, label_key=label_key)
        sample = sample_dataset(full_dataset, label_column=label_key, num_samples=num_samples)
        # materialize the sampled dataset once instead of once per column
        columns = sample.to_dict()
        return columns[text_key], columns[label_key]


    def arrays2dataset(self, X:List[str], y:Union[List[int], List[str]], 
                       text_key:str=DATASET_TEXT, label_key:str=DATASET_LABEL):
        """
        Convert train or test examples to HF dataset
        """
        from datasets import Dataset
        return Dataset.from_dict({text_key:X, label_key:y})


    def dataset2arrays(self, dataset, text_key:str=DATASET_TEXT, label_key:str=DATASET_LABEL):
        """
        Convert a Hugging Face dataset to X, y arrays

        **Args:**

        - *dataset*: object exposing `to_dict()` (e.g., a Hugging Face dataset)
        - *text_key*: name of the column holding the texts
        - *label_key*: name of the column holding the labels
        """
        columns = dataset.to_dict()
        # honor the caller-supplied column names rather than hard-coded ones
        return columns[text_key], columns[label_key]


    def get_trainer(self):
        """
        Retrieves last trainer

        **Raises:**

        - `ValueError` if no trainer has been created yet
        """
        # subclasses that never create a trainer may not define the attribute at all
        trainer = getattr(self, 'trainer', None)
        if trainer is None:
            raise ValueError('A trainer has not been created yet. You must first train a model on some labeled examples ' +\
                             'using the FewShotClassifier.train method.')
        return trainer


    def evaluate(self, X_eval:list, y_eval:list, print_report:bool=True, labels:list=[], **kwargs):
        """
        Evaluates labeled data using the trained model. 
        If `print_report` is True, prints classification report and returns nothing.
        Otherwise, returns a dictionary of the results.
        Extra kwargs fed to `self.predict`.

        **Args:**

        - *X_eval*: list of texts to evaluate on
        - *y_eval*: list of true labels (integers or strings)
        - *print_report*: print a text report (True) or return a dictionary (False)
        - *labels*: optional label names; defaults to `self.get_labels()`
        """
        from sklearn.metrics import classification_report

        # the `labels` default is only read, never mutated
        labels = labels if labels else self.get_labels()
        labels = labels if labels else None

        y_pred = self.predict(X_eval, **kwargs)
        # predict() implementations in this module return a bare value for a single example
        if np.isscalar(y_pred):
            y_pred = [y_pred]

        # put predictions and ground truth in the same label space (names vs. indices)
        if isinstance(y_pred[0], str) and isinstance(y_eval[0], (np.integer, int)) and labels:
            y_eval = [labels[y] for y in y_eval]
        if isinstance(y_eval[0], str) and isinstance(y_pred[0], (np.integer, int)) and labels:
            y_pred = [labels[y] for y in y_pred]

        result = classification_report(y_eval, y_pred, 
                                       output_dict=not print_report,
                                       target_names = labels)
        if print_report:
            print(result)
            return None
        return result


    def explain(self, X:list, labels:list=[]):
        """
        Explain the predictions on given examples in `X`. (Requires `shap` and `matplotlib` to be installed.)

        **Args:**

        - *X*: a text or list of texts to explain
        - *labels*: optional label names; defaults to `self.get_labels()`

        **Raises:**

        - `ImportError` if `shap` or `matplotlib` is missing
        - `NotImplementedError` if the model provides no explain predictor
        """
        X = [X] if isinstance(X, str) else X
        output_names = labels if labels else self.get_labels()
        output_names = output_names if output_names else None
        try:
            import shap
        except ImportError as err:
            raise ImportError('Please install the shap library: pip install shap') from err

        try:
            import matplotlib  # imported only to verify it is installed
        except ImportError as err:
            raise ImportError('Please install the matplotlib library: pip install matplotlib') from err

        f, tokenizer = self._get_explain_predictor()
        if f is None:
            raise NotImplementedError('Explanations are not currently supported for this model.')
        explainer = shap.Explainer(f, tokenizer, output_names=output_names)
        shap_values = explainer(X)
        shap.plots.text(shap_values)

In [None]:
show_doc(ClassifierBase.arrays2dataset)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L91){target="_blank" style="float:right; font-size:smaller"}

### ClassifierBase.arrays2dataset

>      ClassifierBase.arrays2dataset (X:List[str], y:Union[List[int],List[str]],
>                                     text_key:str='text',
>                                     label_key:str='label')

*Convert train or test examples to HF dataset*

In [None]:
show_doc(ClassifierBase.dataset2arrays)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L100){target="_blank" style="float:right; font-size:smaller"}

### ClassifierBase.dataset2arrays

>      ClassifierBase.dataset2arrays (dataset, text_key:str='text',
>                                     label_key:str='label')

*Convert a Hugging Face dataset to X, y arrays*

In [None]:
show_doc(ClassifierBase.evaluate)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L116){target="_blank" style="float:right; font-size:smaller"}

### ClassifierBase.evaluate

>      ClassifierBase.evaluate (X_eval:list, y_eval:list,
>                               print_report:bool=True, labels:list=[],
>                               **kwargs)

*Evaluates labeled data using the trained model. 
If `print_report` is True, prints classification report and returns nothing.
Otherwise, returns a dictionary of the results.
Extra kwargs fed to `self.predict`.*

In [None]:
show_doc(ClassifierBase.explain)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L147){target="_blank" style="float:right; font-size:smaller"}

### ClassifierBase.explain

>      ClassifierBase.explain (X:list, labels:list=[])

*Explain the predictions on given examples in `X`. (Requires `shap` and `matplotlib` to be installed.)*

In [None]:
show_doc(ClassifierBase.sample_examples)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L79){target="_blank" style="float:right; font-size:smaller"}

### ClassifierBase.sample_examples

>      ClassifierBase.sample_examples (X:list, y:list, num_samples:int=8,
>                                      text_key:str='text',
>                                      label_key:str='label')

*Sample a dataset with `num_samples` per class*

In [None]:
# | export

class SKClassifier(ClassifierBase):
    def __init__(
        self,
        model_path=None,
        labels=None,
        **kwargs,
    ):
        """
        `SKClassifier` is a wrapper to scikit-learn text classification models.
        Extra kwargs are fed directly to `onprem.sk.clf.Classifier.create_model`.
        If `ctype` is not supplied, a default `sgdclassifier` configuration
        (`modified_huber` loss, binary 1-3 gram counts instead of TF-IDF) is used;
        any keyword arguments you do supply take precedence over those defaults.

        **Args:**

        - *model_path*: path to an already saved model file to be reloaded
        - *labels*: list of strings intended to map label indices to string labels
        """

        from onprem.sk.clf import Classifier
        self.model = Classifier()
        # avoid sharing one mutable default list between instances
        self.labels = labels if labels is not None else []

        if model_path:
            self.model.load(model_path)
            return

        if 'ctype' not in kwargs:
            defaults = {
                'ctype': 'sgdclassifier',
                'clf__loss': 'modified_huber',
                'clf__penalty': 'l2',
                'clf__alpha': 1e-3,
                'clf__max_iter': 5,
                'clf__tol': None,
                'clf__random_state': 42,
                'vec__token_pattern': '(?u)\\b\\w\\w+\\b',
                'use_tfidf': False,
                # replace TF-IDF with this for slight accuracy boost
                # in relevant examples
                'vec__ngram_range': (1,3),
                'vec__binary': True,
                'vec__max_features': 100000,
            }
            # caller-supplied settings override the defaults instead of being clobbered
            kwargs = {**defaults, **kwargs}

        self.model.create_model(**kwargs)


    def predict(self, X, **kwargs):
        """
        predict labels

        Integer predictions are mapped to label names when labels are available.
        Returns a single value when there is exactly one prediction.
        """
        labels = self.get_labels()
        preds = self.model.predict(X, **kwargs)
        if not isinstance(preds, (list, np.ndarray)):
            preds = [preds]
        if labels:
            # only integer class indices can be looked up; string predictions pass through
            preds = [labels[p] if isinstance(p, (int, np.integer)) else p for p in preds]
        return preds[0] if len(preds) == 1 else preds


    def predict_proba(self, X, **kwargs):
        """
        predict label probabilities
        """
        return self.model.predict_proba(X, **kwargs)

    def train(self,
              X:List[str],
              y:Union[List[int], List[str]],
              **kwargs,
             ):
        """
        Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
        Additional keyword arguments are passed directly to `self.model.fit`.

        **Args:**

        - *X*: List of texts
        - *y*: List representing labels

        **Returns:**

        - None
        """

        self.model.fit(X, y, **kwargs)


    def save(self, filename:str):
        """
        Save model to specified `filename` (e.g., `/tmp/mymodel.gz`).
        Model saved as pickle file.
        To reload the model, supply `model_path` when instantiating `SKClassifier`.
        """
        self.model.save(filename)

        

In [None]:
show_doc(SKClassifier.train)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L236){target="_blank" style="float:right; font-size:smaller"}

### SKClassifier.train

>      SKClassifier.train (X:List[str], y:Union[List[int],List[str]], **kwargs)

*Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
Additional keyword arguments are passed directly to `self.model.fit`.

**Args:**

- *X*: List of texts
- *y*: List representing labels

**Returns:**

- None*

In [None]:
show_doc(SKClassifier.predict)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L219){target="_blank" style="float:right; font-size:smaller"}

### SKClassifier.predict

>      SKClassifier.predict (X, **kwargs)

*predict labels*

In [None]:
show_doc(SKClassifier.predict_proba)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L230){target="_blank" style="float:right; font-size:smaller"}

### SKClassifier.predict_proba

>      SKClassifier.predict_proba (X, **kwargs)

*predict label probabilities*

In [None]:
show_doc(SKClassifier.save)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L258){target="_blank" style="float:right; font-size:smaller"}

### SKClassifier.save

>      SKClassifier.save (filename:str)

*Save model to specified `filename` (e.g., `/tmp/mymodel.gz`).
Model saved as pickle file.
To reload the model, supply `model_path` when instantiating`SKClassifier`.*

### Example: Training Scikit-Learn Text Classification Models

In [None]:
# | notest

from sklearn.datasets import fetch_20newsgroups
from onprem.pipelines.classifier import SKClassifier


In [None]:
# | notest

# Restrict 20 Newsgroups to four topics for the demo.
categories = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]

# Both splits are fetched with identical settings apart from `subset`.
fetch_kwargs = dict(categories=categories, shuffle=True, random_state=42)
train_b = fetch_20newsgroups(subset="train", **fetch_kwargs)
test_b = fetch_20newsgroups(subset="test", **fetch_kwargs)

x_train, y_train = train_b.data, train_b.target
x_test, y_test = test_b.data, test_b.target
classes = train_b.target_names

# Targets stay as integer indices; `labels=classes` maps them back to names.
clf = SKClassifier(labels=classes)
clf.train(x_train, y_train)

test_doc1 = "Jesus Christ was a first century Jewish teacher and religious leader."
test_doc2 = "The graphics on my monitor are terrible."

# A bare string, a one-item list, and a multi-item list are all accepted.
for query in (test_doc1, [test_doc2], [test_doc1, test_doc2]):
    print(clf.predict(query))
clf.evaluate(x_test, y_test)

soc.religion.christian
comp.graphics
['soc.religion.christian', 'comp.graphics']
                        precision    recall  f1-score   support

           alt.atheism       0.93      0.87      0.90       319
         comp.graphics       0.88      0.96      0.92       389
               sci.med       0.94      0.84      0.89       396
soc.religion.christian       0.91      0.96      0.94       398

              accuracy                           0.91      1502
             macro avg       0.91      0.91      0.91      1502
          weighted avg       0.91      0.91      0.91      1502



In [None]:
# | notest

clf.evaluate(x_test, y_test, print_report=False)['accuracy']

0.9114513981358189

In [None]:
# | notest

clf.save('/tmp/mymodel.gz') # save

In [None]:
# | notest

clf = SKClassifier(model_path='/tmp/mymodel.gz', labels=classes) # reload
clf.evaluate(x_test, y_test)

                        precision    recall  f1-score   support

           alt.atheism       0.93      0.87      0.90       319
         comp.graphics       0.88      0.96      0.92       389
               sci.med       0.94      0.84      0.89       396
soc.religion.christian       0.91      0.96      0.94       398

              accuracy                           0.91      1502
             macro avg       0.91      0.91      0.91      1502
          weighted avg       0.91      0.91      0.91      1502



In [None]:
# | export

from onprem.hf import HFTrainer
from transformers import pipeline
import numpy as np
import os.path


class HFClassifier(ClassifierBase):
    def __init__(
        self,
        model_id_or_path:str='google/bert_uncased_L-2_H-128_A-2',
        device=None,
        labels=None,
        **kwargs,
    ):
        """
        `HFClassifier` can be used to train and run Hugging Face transformer text classifiers.

        If `model_id_or_path` is an existing folder, the model and tokenizer are
        loaded from it immediately; otherwise they are created by `train`.

        NOTE(review): extra keyword arguments are accepted but are not currently
        forwarded to `from_pretrained` by this constructor.

        **Args:**

        - *model_id_or_path*: The Hugging Face model_id or path to model folder (e.g, path previously trained and saved model).
        - *device*: 'cuda' or 'cpu'
        - *labels*: list of strings intended to map label indices to string labels

        """
        self.model_id_or_path = model_id_or_path
        self.device=device
        self.model = None
        self.tokenizer = None
        # avoid sharing one mutable default list between instances
        self.labels = labels if labels is not None else []
        if os.path.isdir(self.model_id_or_path):
            from transformers import AutoTokenizer
            from transformers import AutoModelForSequenceClassification 
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)

    def _tempsave(self):
        """
        Temporarily save model

        Returns a `tempfile.TemporaryDirectory`; the caller is responsible for
        calling `cleanup()` on it. The directory is empty if there is no model yet.
        """
        import tempfile
        temp_dir = tempfile.TemporaryDirectory()
        if self.model:
            self.model.save_pretrained(temp_dir.name)
            self.tokenizer.save_pretrained(temp_dir.name)
        return temp_dir

    def train(self,
              X:List[str],
              y:Union[List[int], List[str]],
              **kwargs,
             ):
        """
        Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
        Extra kwargs are treated as arguments to `transformers.TrainingArguments`
        (e.g., `num_train_epochs`, `per_device_train_batch_size`).

        If a model is already loaded, training continues from it; otherwise
        training starts from `model_id_or_path`.

        **Args:**

        - *X*: List of texts
        - *y*: List of integers representing labels

        **Returns:**

        - None

        **Raises:**

        - `ValueError` if `X` and `y` differ in length
        """
        if len(X) != len(y):
            raise ValueError('X and y must have the same number of elements.')

        temp_dir = self._tempsave()
        try:
            # convert to HFTrainer format
            data = [{'text': text, 'label': label} for text, label in zip(X, y)]

            # create a trainer
            trainer = HFTrainer()

            # train
            training_path = temp_dir.name if self.model else self.model_id_or_path
            self.model, self.tokenizer = trainer(training_path,  data, **kwargs)
        finally:
            # remove the temporary copy even if training fails
            temp_dir.cleanup()


    def save(self, save_path:str):
        """
        Save model to specified folder path, `save_path`.
        To reload the model, supply path in `model_id_or_path` argument when
        instantiating `HFClassifier`.

        **Raises:**

        - `ValueError` if there is no model to save yet
        """
        if self.model is None:
            raise ValueError('There is not trained model yet. Please invoke the HFClassifier.train(...)')
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)


    def _get_pipeline(self, **kwargs):
        """
        Create transformers pipeline using current model.
        Extra kwargs are passed to `transformers.pipeline`.
        """
        if self.model is None:
            raise ValueError('There is not trained model yet. Please invoke the HFClassifier.train(...)')
        # Always serialize the in-memory model: loading from `model_id_or_path`
        # would return stale weights if `train` was called after loading a saved model.
        temp_dir = self._tempsave()
        try:
            return pipeline('text-classification', model=temp_dir.name, device=self.device, **kwargs)
        finally:
            temp_dir.cleanup()

    def _predict(self, X:list, return_proba=True, max_length=512, **kwargs):
        """
        Predicts labels using `transformers.pipeline`

        Returns per-class scores when `return_proba` is True, otherwise the
        predicted label (a name if labels are available, else a class index).
        A single value is returned when there is exactly one example.
        """
        X = [X] if isinstance(X, str) else X
        clf = self._get_pipeline(**kwargs)
        labels = self.get_labels()

        # Order scores by class id. Sorting by label name alone is lexicographic
        # and puts e.g. 'LABEL_10' before 'LABEL_2' when there are 11+ classes.
        label2id = getattr(getattr(clf.model, 'config', None), 'label2id', None) or {}

        def class_order(entry):
            name = entry['label']
            if name in label2id:
                return (0, label2id[name], '')
            return (1, 0, name)

        preds = []
        for example in clf(X, top_k=None, truncation=True, max_length=max_length):
            # unlike deprecated return_all_scores, top_k re-sorts
            scores = [d['score'] for d in sorted(example, key=class_order)]
            if return_proba:
                # probabilities are returned as-is; they are never mapped to label names
                preds.append(scores)
                continue
            best = np.argmax(scores)
            preds.append(labels[best] if labels else best)
        if len(preds) == 1:
            preds = preds[0]
        return preds

    def predict(self, X:list, max_length=512, **kwargs):
        """
        Predict labels. 
        Extra kwargs fed to Hugging Face transformers text-classification pipeline.
        """
        return self._predict(X, max_length=max_length, return_proba=False, **kwargs)

    def predict_proba(self, X:list, max_length=512, **kwargs):
        """
        Predict label probabilities.
        Extra kwargs fed to Hugging Face transformers text-classification pipeline.
        """
        return self._predict(X, max_length=max_length, return_proba=True, **kwargs)


    def _get_explain_predictor(self, device=None):
        """
        Return the `(predictor, tokenizer)` pair used by `explain`:
        the text-classification pipeline and `None`.
        """
        return self._get_pipeline(), None
               


In [None]:
show_doc(HFClassifier.train)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L317){target="_blank" style="float:right; font-size:smaller"}

### HFClassifier.train

>      HFClassifier.train (X:List[str], y:Union[List[int],List[str]], **kwargs)

*Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
Extra kwargs are treated as arguments to `transformers.TrainingArguments`.

**Args:**

- *X*: List of texts
- *y*: List of integers representing labels
- *num_epochs*: Number of epochs to train
- *batch_size*: Batch size
- *metric*: metric to use
- *callbacks*:  A list of callbacks to customize the training loop.

**Returns:**

- None*

In [None]:
show_doc(HFClassifier.predict)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L402){target="_blank" style="float:right; font-size:smaller"}

### HFClassifier.predict

>      HFClassifier.predict (X:list, max_length=512, **kwargs)

*Predict labels. 
Extra kwargs fed to Hugging Face transformers text-classification pipeline.*

In [None]:
show_doc(HFClassifier.predict_proba)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L409){target="_blank" style="float:right; font-size:smaller"}

### HFClassifier.predict_proba

>      HFClassifier.predict_proba (X:list, max_length=512, **wargs)

*Predict label probabilities.
Extra kwargs fed to Hugging Face transformers text-classification pipeline.*

The default model is a tiny BERT model (i.e., `google/bert_uncased_L-2_H-128_A-2`), but we will use a larger model here to improve accuracy (e.g., [distilbert](https://huggingface.co/distilbert/distilbert-base-uncased)).

### Example: Training Hugging Face Transformer Models

In [None]:
# | notest

# Demo-only imports: kept out of the exported module so the library does not
# import itself (or the example dataset loader) at import time.
from sklearn.datasets import fetch_20newsgroups
from onprem.pipelines.classifier import HFClassifier

In [None]:
# | notest

# Restrict 20 Newsgroups to four topics for the demo.
categories = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]

# Both splits are fetched with identical settings apart from `subset`.
fetch_kwargs = dict(categories=categories, shuffle=True, random_state=42)
train_b = fetch_20newsgroups(subset="train", **fetch_kwargs)
test_b = fetch_20newsgroups(subset="test", **fetch_kwargs)

x_train, y_train = train_b.data, train_b.target
x_test, y_test = test_b.data, test_b.target
classes = train_b.target_names

# Fine-tune DistilBERT for a single epoch; extra kwargs go to the trainer.
clf = HFClassifier(
    model_id_or_path='distilbert/distilbert-base-uncased',
    device='cuda',
    labels=classes,
)
clf.train(x_train, y_train, num_train_epochs=1, per_device_train_batch_size=8)

test_doc1 = "Jesus Christ was a first century Jewish teacher and religious leader."
test_doc2 = "The graphics on my monitor are terrible."

# A bare string, a one-item list, and a multi-item list are all accepted.
for query in (test_doc1, [test_doc2], [test_doc1, test_doc2]):
    print(clf.predict(query))
clf.evaluate(x_test, y_test)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss


soc.religion.christian
comp.graphics
['soc.religion.christian', 'comp.graphics']
                        precision    recall  f1-score   support

           alt.atheism       0.89      0.88      0.89       319
         comp.graphics       0.97      0.98      0.97       389
               sci.med       0.97      0.95      0.96       396
soc.religion.christian       0.94      0.96      0.95       398

              accuracy                           0.95      1502
             macro avg       0.94      0.94      0.94      1502
          weighted avg       0.95      0.95      0.95      1502



In [None]:
# | notest

clf.explain(test_doc1)

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


In [None]:
# | notest

clf.save('/tmp/my_hf_model')

In [None]:
# | notest


clf = HFClassifier('/tmp/my_hf_model', device='cuda', labels=classes)
clf.evaluate(x_test,  y_test, print_report=False)['accuracy']

0.9454061251664447

In [None]:
# | export

DEFAULT_SETFIT_MODEL = "sentence-transformers/paraphrase-mpnet-base-v2"
SMALL_SETFIT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

class FewShotClassifier(ClassifierBase):
    def __init__(
        self,
        model_id_or_path:str=DEFAULT_SETFIT_MODEL,
        use_smaller:bool=False,
        **kwargs,
    ):
        """
        `FewShotClassifier` can be used to train and run text classifiers. Currently based on SetFit.
        Additional keyword arguments are fed directly to `from_pretrained`.


        **Args:**

        - *model_id_or_path*: The Hugging Face model_id or path to model folder (e.g, path previously trained and saved model).
        - *use_smaller*:  If True, will use a smaller but performant model
                          (this overrides `model_id_or_path`, with a warning if a non-default model was supplied).

        """
        if use_smaller and model_id_or_path != DEFAULT_SETFIT_MODEL:
            warnings.warn(f'Over-writing supplied model ({model_id_or_path}) with {SMALL_SETFIT_MODEL} because use_smaller=True.')
        self.model_id_or_path = SMALL_SETFIT_MODEL if use_smaller else model_id_or_path
        self.model = SetFitModel.from_pretrained(self.model_id_or_path, **kwargs)
        self.trainer = None # set in `FewShotClassifier.train`
        self.labels = [] # fallback for `get_labels` when the model has no `labels` attribute

    def train(self,
              X:List[str],
              y:Union[List[int], List[str]],
              num_epochs:int=10,
              batch_size:int=32,
              metric='accuracy',
              callbacks=None,
              **kwargs,
             ):
        """
        Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
        Additional keyword arguments are passed directly to `SetFit.TrainingArguments`
        (e.g., `max_steps`).

        **Args:**

        - *X*: List of texts
        - *y*: List of integers representing labels
        - *num_epochs*: Number of epochs to train
        - *batch_size*: Batch size
        - *metric*: metric to use
        - *callbacks*:  A list of callbacks to customize the training loop.

        **Returns:**

        - None
        """

        # convert to HF dataset
        train_dataset = self.arrays2dataset(X, y, text_key='text', label_key='label')

        args = TrainingArguments(
                batch_size=batch_size,
                num_epochs=num_epochs,
                **kwargs
        )

        trainer = Trainer(
                    model=self.model,
                    args=args,
                    metric=metric,
                    callbacks=callbacks,
                    train_dataset=train_dataset,
                    column_mapping={"text": "text", "label": "label"}
        )
        trainer.train()

        # keep the trainer so it can be retrieved with `get_trainer`
        self.trainer = trainer


    def predict(self, X, **kwargs):
        """
        predict labels
        """
        # delegate to the wrapped SetFit model (not to self, which would recurse)
        return self.model.predict(X, **kwargs)


    def predict_proba(self, X, **kwargs):
        """
        predict label probabilities
        """
        # delegate to the wrapped SetFit model (not to self, which would recurse)
        return self.model.predict_proba(X, **kwargs)


    def save(self, save_path:str):
        """
        Save model to specified folder path, `save_path`.
        To reload the model, supply path in `model_id_or_path` argument when
        instantiating `FewShotClassifier`.

        """
        self.model.save_pretrained(save_path)

    def _get_explain_predictor(self, device=None):
        """
        Return the `(predictor, tokenizer)` pair used by `explain`:
        a function wrapping `predict_proba` and the tokenizer loaded
        from `model_id_or_path`.
        """
        def f(x):
            return self.predict_proba(x)

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
        return f, tokenizer

In [None]:
show_doc(FewShotClassifier.train)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L437){target="_blank" style="float:right; font-size:smaller"}

### FewShotClassifier.train

>      FewShotClassifier.train (X:List[str], y:Union[List[int],List[str]],
>                               num_epochs:int=10, batch_size:int=32,
>                               metric='accuracy', callbacks=None, **kwargs)

*Trains the classifier on a list of texts (`X`) and a list of labels (`y`).
Additional keyword arguments are passed directly to `SetFit.TrainingArguments`

**Args:**

- *X*: List of texts
- *y*: List of integers representing labels
- *num_epochs*: Number of epochs to train
- *batch_size*: Batch size
- *metric*: metric to use
- *callbacks*:  A list of callbacks to customize the training loop.

**Returns:**

- None*

In [None]:
show_doc(FewShotClassifier.predict)

---

### FewShotClassifier.predict

>      FewShotClassifier.predict (X, **kwargs)

*predict labels*

In [None]:
show_doc(FewShotClassifier.predict_proba)

---

### FewShotClassifier.predict_proba

>      FewShotClassifier.predict_proba (X, **kwargs)

*predict label probabilities*

In [None]:
show_doc(FewShotClassifier.save)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/pipelines/classifier.py#L486){target="_blank" style="float:right; font-size:smaller"}

### FewShotClassifier.save

>      FewShotClassifier.save (save_path:str)

*Save model to specified folder path, `save_path`.
To reload the model, supply path in `model_id_or_path` argument when
instantiating`FewShotClassifier`.*

### Example: Training Few-Shot Text Classifiers

In [None]:
# | notest

clf = FewShotClassifier(labels=['negative', 'positive'])

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [None]:
# | notest

from datasets import load_dataset

Sample a tiny dataset with only 8 examples per class (or 16 total examples):

In [None]:
# | notest
dataset = load_dataset("SetFit/sst2")
X_train, y_train = clf.dataset2arrays(dataset["train"], text_key="text", label_key="label")
X_test, y_test = clf.dataset2arrays(dataset["test"], text_key="text", label_key="label")
X_sample, y_sample = clf.sample_examples(X_train,  y_train, label_key="label", num_samples=8)

Repo card metadata block was not found. Setting CardData to empty.


In [None]:
# | notest
clf.train(X_sample,  y_sample, max_steps=50)

Applying column mapping to the training dataset


Map:   0%|          | 0/16 [00:00<?, ? examples/s]

***** Running training *****
  Num unique pairs = 144
  Batch size = 32
  Num epochs = 10


Step,Training Loss
1,0.2427
50,0.0473


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
# | notest
clf.evaluate(X_test, y_test)

              precision    recall  f1-score   support

    negative       0.88      0.94      0.91       912
    positive       0.93      0.87      0.90       909

    accuracy                           0.91      1821
   macro avg       0.91      0.91      0.91      1821
weighted avg       0.91      0.91      0.91      1821



In [None]:
# | notest
new_data = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]

In [None]:
# | notest
preds = clf.predict(new_data)
preds

['positive', 'negative']

In [None]:
# | notest
preds = clf.predict_proba(new_data)
preds

tensor([[0.1657, 0.8343],
        [0.8551, 0.1449]], dtype=torch.float64)

In [None]:
# | notest
clf.save('/tmp/my_fewshot_model')

In [None]:
# | notest
clf = FewShotClassifier('/tmp/my_fewshot_model', labels=['negative', 'positive'])
preds = clf.predict(new_data)
preds

['positive', 'negative']

In [None]:
# | notest
clf.explain(new_data)

In [None]:
# | hide
import nbdev

nbdev.nbdev_export()