

# Zero-shot text classification with SSTuing

First of all, install the dependencies

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.29.2


In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

In [2]:
!pip show transformers

Name: transformers
Version: 4.10.0
Summary: State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch
Home-page: https://github.com/huggingface/transformers
Author: Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors
Author-email: thomas@huggingface.co
License: Apache
Location: /Users/liuchaoqun/miniforge3/Applications/envs/env_conda/lib/python3.8/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, sacremoses, tokenizers, tqdm
Required-by: 


# Load the model and tokenizer

In [3]:
model_name = "DAMO-NLP-SG/zero-shot-classify-SSTuning-base" #@param ["DAMO-NLP-SG/zero-shot-classify-SSTuning-base", "DAMO-NLP-SG/zero-shot-classify-SSTuning-large","DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT"]

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(model_name)

## Create some helper functions to process the data

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]
def add_prefix(text,list_label, shuffle=False):
    list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
    if shuffle: 
        random.shuffle(list_label_new)
    s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
    return f'{s_option} {tokenizer.sep_token} {text}', list_label_new

def check_text(model, text, list_label, shuffle=False): 
    text, list_label_new = add_prefix(text,list_label, shuffle = shuffle)
    print('*'*50)
    model.to(device)
    model.eval()
    ids = tokenizer.encode(text)
    tokens = tokenizer.convert_ids_to_tokens(ids)
    print('input text:   ',text)
    encoding = tokenizer([text],truncation=True, max_length=512)
    item = {key: torch.tensor(val).to(device) for key, val in encoding.items()}
    logits = model(**item).logits
    probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x,5) for x in probs[0]]

    print('probabilities:',probabilities)
    print(f'prediction:    {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability:   {round(probabilities[predictions]*100,2)}%')

# Inference

## Sentiment Analysis

Provide the input and the list of labels. 
You can use original labels or convert the labels to sentences.

In [5]:
text = "I love this place! The food is always so fresh and delicious. The staff is always friendly, as well."

list_label = ["negative","positve"]
# list_label = ["It's terrible.","It's great."]

Process the input and do inference

In [6]:
check_text(model,text,list_label, shuffle=False)

**************************************************
input text:    (A) negative. (B) positve. (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> I love this place! The food is always so fresh and delicious. The staff is always friendly, as well.
probabilities: [0.00361, 0.99555, 1e-05, 2e-05, 3e-05, 3e-05, 2e-05, 5e-05, 8e-05, 2e-05, 5e-05, 2e-05, 8e-05, 6e-05, 8e-05, 6e-05, 5e-05, 6e-05, 6e-05, 6e-05]
prediction:    1 => (B) positve.
probability:   99.56%


## Topic Classification

Provide the input and the list of labels.

In [11]:
text = "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."

# list_label = ["politics","sports","business","technology"]
list_label = ["This text is about politics.", "This text is about sports.", "This text is about business.", "This text is about technology."]

In [12]:
check_text(model,text,list_label, shuffle=False)

**************************************************
input text:    (A) This text is about politics. (B) This text is about sports. (C) This text is about business. (D) This text is about technology. (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindlinand of ultra-cynics, are seeing green again.
probabilities: [0.03413, 0.02098, 0.82671, 0.1151, 0.00022, 0.00015, 0.00019, 0.00024, 9e-05, 0.00011, 0.00014, 0.00017, 0.00016, 0.0003, 0.00038, 0.0002, 0.00014, 0.00038, 0.00015, 7e-05]
prediction:    2 => (C) This text is about business.
probability:   82.67%
