# Logits: Sentiment Analysis

Use DistilBERT, fine-tuned on SST-2, to classify the sentiment of a sentence and interpret the model's logits.

https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english



#### Google Colab
If you are running the code in Google Colab, install the packages by uncommenting and running the cell below.

```python
# !pip install transformers torch
```

### Import the tokenizer and model classes

```python
# PyTorch: needed for torch.no_grad() during inference and for working with the logits tensor
import torch

# Import the task-specific tokenizer and model classes
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
```


## 1. Create the tokenizer and model instances

```python
model_name = "distilbert-base-uncased-finetuned-sst-2-english"

# Create tokenizer
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Create model
model = DistilBertForSequenceClassification.from_pretrained(model_name)
```


## 2. Tokenize the text into tensors

```python
# Sample text
text = 'I hated the restaurant'

# Tokenize the text and return PyTorch tensors (return_tensors="pt")
inputs = tokenizer(text, return_tensors="pt")
```
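The tokenizer returns a dictionary-like `BatchEncoding`; for DistilBERT it typically holds `input_ids` and `attention_mask`. In step 3, `model(**inputs)` unpacks that mapping into keyword arguments. The plain-Python sketch below illustrates only the unpacking, using a hypothetical `fake_forward` function and placeholder token ids (not real tokenizer output).

```python
# Hypothetical stand-in for a model's forward(); it just echoes the keyword arguments it receives.
def fake_forward(input_ids=None, attention_mask=None):
    """Return the keyword arguments that arrived (illustration only)."""
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Placeholder values, NOT real tokenizer output: one sequence of 4 token ids.
example_inputs = {
    "input_ids": [[1, 2, 3, 4]],
    "attention_mask": [[1, 1, 1, 1]],
}

# `**example_inputs` unpacks the dict, so these two calls are equivalent.
a = fake_forward(**example_inputs)
b = fake_forward(
    input_ids=example_inputs["input_ids"],
    attention_mask=example_inputs["attention_mask"],
)
print(a == b)
```

The variable is named `example_inputs` so that running this cell in the notebook does not overwrite the real `inputs` tensors.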

## 3. Run Inference

```python
# Disable gradient tracking: gradients are not needed for inference, so this saves memory and compute
with torch.no_grad():
    output = model(**inputs)

# Print the output type & output
print('Type = ', type(output))
print('Output = ', output)
```

```
Type =  <class 'transformers.modeling_outputs.SequenceClassifierOutput'>
Output =  SequenceClassifierOutput(loss=None, logits=tensor([[ 4.1924, -3.4785]]), hidden_states=None, attentions=None)
```


## 4. Interpret Logits (SequenceClassifierOutput)

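* `logits` is a tensor of shape `(batch_size, num_labels)`; here it is `(1, 2)`: one row for the single input sentence and one column per label.
* Each value is a raw, unnormalized score for that label, not a probability. Applying softmax converts the scores into probabilities that sum to 1.
* The predicted label is the index with the highest score.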
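To make the distinction between logits and probabilities concrete, here is a minimal plain-Python softmax applied to the logits printed in step 3. Because softmax is monotonic, it does not change which index has the highest score; it only rescales the scores into probabilities.

```python
import math

def softmax(scores):
    """Convert a list of raw scores (logits) into probabilities that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Logits printed in step 3 for 'I hated the restaurant'
example_logits = [4.1924, -3.4785]
probs = softmax(example_logits)
print([round(p, 4) for p in probs])

# The highest logit is also the highest probability
best_logit = max(range(len(example_logits)), key=lambda i: example_logits[i])
best_prob = max(range(len(probs)), key=lambda i: probs[i])
print(best_logit == best_prob)
```

In PyTorch, `torch.softmax(output.logits, dim=-1)` performs the same computation directly on the tensor.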

**PyTorch argmax()**
Returns the index of the element with the maximum value. Without a `dim` argument, it treats the tensor as flattened and returns a single index.

https://pytorch.org/docs/stable/generated/torch.argmax.html#torch.argmax
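The `dim` argument matters once there is more than one input sentence. The sketch below uses a small batch of logits: row 0 is the output from step 3, and row 1 is a made-up example added only for illustration. Without `dim`, the result is an index into the flattened 4-element tensor, not a label index.

```python
import torch

# Row 0: logits from step 3; row 1: made-up logits for illustration only
batch_logits = torch.tensor([[4.1924, -3.4785],
                             [-1.0, 5.0]])

# Without `dim`, argmax indexes the flattened tensor and returns one index
flat_idx = batch_logits.argmax()

# With dim=-1, argmax returns the best label index for each row (each example)
per_row = batch_logits.argmax(dim=-1)

print(flat_idx.item())   # index into the flattened tensor
print(per_row.tolist())  # one label index per example
```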

* The model configuration maps label indices to label names in `model.config.id2label`. Use it to get the label text.
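The lookup amounts to a plain dictionary access. The sketch below assumes `id2label = {0: "NEGATIVE", 1: "POSITIVE"}` for this checkpoint, which is consistent with the `NEGATIVE` result obtained from the logits above; in the notebook you would read the mapping from `model.config.id2label` instead of hard-coding it. The helper `predict_label` is hypothetical.

```python
# Assumed copy of model.config.id2label for distilbert-base-uncased-finetuned-sst-2-english
id2label = {0: "NEGATIVE", 1: "POSITIVE"}

def predict_label(logit_row, id2label):
    """Hypothetical helper: return the label name for the highest-scoring index in one row of logits."""
    best = max(range(len(logit_row)), key=lambda i: logit_row[i])
    return id2label[best]

# Logits from step 3
print(predict_label([4.1924, -3.4785], id2label))
```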

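```python
logits = output.logits

# Index of the highest-scoring label along the label dimension;
# .item() converts the one-element result to a Python int (valid because batch size is 1)
predicted_class_id = logits.argmax(dim=-1).item()

# Get the text for the label by passing the index
model.config.id2label[predicted_class_id]
```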

```
'NEGATIVE'
```