In [23]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
  
# Tokenizer and classifier are loaded from the same checkpoint.
checkpoint = "CogComp/bart-faithful-summary-detector"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# One source article plus two candidate summaries: the "bad" one gives a year
# that disagrees with the article, the "good" one gives the matching year.
article = "Ban Ki-Moon was re-elected for a second term by the UN General Assembly, unopposed and unanimously, on 21 June 2011."
good_summary = "Ban Ki-moon was elected for a second term in 2011."
bad_summary = "Ban Ki-moon was elected for a second term in 2007."

# Each encoded pair puts the summary first and the article second.
good_pair = tokenizer(text=good_summary, text_pair=article, return_tensors='pt')
bad_pair = tokenizer(text=bad_summary, text_pair=article, return_tensors='pt')

good_score = model(**good_pair)
bad_score = model(**bad_pair)

# Label mapping: "0" -> "Hallucinated", "1" -> "Faithful".
# The good summary should get the larger index-1 score.
good_faithful_score = good_score[0][:, 1]
bad_faithful_score = bad_score[0][:, 1]
print(good_faithful_score > bad_faithful_score)  # True

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'HALLUCINATED', '1': 'FAITHFUL'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'HALLUCINATED', '1': 'FAITHFUL'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'HALLUCINATED', '1': 'FAITHFUL'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'HALLUCINATED', '1': 'FAITHFUL'}. The number of labels wil be overwritten to 2.


tensor([True])


In [25]:
# First element of the model output for the (bad_summary, article) pair;
# the displayed tensor holds one row with two scores.
bad_score[0]

tensor([[ 0.6020, -0.4472]], grad_fn=<AddmmBackward0>)

In [26]:
# First element of the model output for the (good_summary, article) pair;
# the displayed tensor holds one row with two scores.
good_score[0]

tensor([[-2.6085, -0.2462]], grad_fn=<AddmmBackward0>)

In [61]:
# Display `rewards` (one entry per query/response pair, each holding one result
# per pipeline). The variable is assigned in a separate cell that must run first.
rewards

[[[[{'label': 'HALLUCINATED', 'score': 0.6020027995109558},
    {'label': 'FAITHFUL', 'score': -0.4471779763698578}]]],
 [[[{'label': 'HALLUCINATED', 'score': -2.6085076332092285},
    {'label': 'FAITHFUL', 'score': -0.2461569607257843}]]]]

In [27]:
# Display `rewards` again. NOTE(review): the scores recorded here differ from the
# other `rewards` displays, so this output presumably comes from a different run.
rewards

[[[[{'label': 'HALLUCINATED', 'score': 0.9396332502365112},
    {'label': 'FAITHFUL', 'score': -0.46942010521888733}]]],
 [[[{'label': 'HALLUCINATED', 'score': -1.5056735277175903},
    {'label': 'FAITHFUL', 'score': -0.32813945412635803}]]]]

In [29]:
# Token ids the tokenizer produced for the (bad_summary, article) pair.
bad_pair["input_ids"]

tensor([[    0, 33809, 11488,    12, 16956,    21,  2736,    13,    10,   200,
          1385,    11,  3010,     4,     2,     2, 33809, 11488,    12, 32452,
            21,   769,    12, 15672,    13,    10,   200,  1385,    30,     5,
          2604,  1292,  3389,     6,   542, 10223,  7878,     8, 12008,     6,
            15,   733,   502,  1466,     4,     2]])

In [3]:
# Index-1 column of the good summary's scores (index 1 is "Faithful" per the
# label-mapping comment on the comparison print).
good_score[0][:, 1]

tensor([-0.2462], grad_fn=<SelectBackward0>)

In [2]:
# Index-1 column of the bad summary's scores (index 1 is "Faithful" per the
# label-mapping comment on the comparison print).
bad_score[0][:, 1]

tensor([-0.4472], grad_fn=<SelectBackward0>)

In [1]:
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2
import sys
import collections
import numpy as np
import re
import os

# Put the local llama example directory on the import path (presumably where
# `llama_utils` lives; confirm). The location can be overridden through the
# LLAMA_EXAMPLES_DIR environment variable; the default is the original path.
# The membership check keeps re-running this cell from appending the same
# entry to sys.path over and over.
LLAMA_EXAMPLES_DIR = os.environ.get("LLAMA_EXAMPLES_DIR", "/home/rame/trl/examples/llama/")
if LLAMA_EXAMPLES_DIR not in sys.path:
    sys.path.append(LLAMA_EXAMPLES_DIR)

In [3]:
import llama_utils

In [4]:
# Ask the project helper for the tokenizer to use with the "unitary/toxic-bert"
# checkpoint; the result is later passed as `tokenizer=` when building the pipeline.
# NOTE(review): the helper's return value is not visible here; confirm it is a
# name or path that `pipeline` accepts.
tokenizer_name=llama_utils.Tokenizer.load_tokenizer_name("unitary/toxic-bert")

In [5]:
import llama_utils

In [6]:
import torch

In [7]:
# Run on CUDA device 0 when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

In [8]:
# Text-classification pipeline: the tokenizer comes from `tokenizer_name` and
# the target device from `device`.
pipeline_config = {
    "task": "text-classification",
    "model": "unitary/toxic-bert",
    "tokenizer": tokenizer_name,
    "device": device,
}
pipe = pipeline(**pipeline_config)

In [13]:
# Reference article and two candidate summaries that differ only in the year.
article = (
    "Ban Ki-Moon was re-elected for a second term by the UN General Assembly, "
    "unopposed and unanimously, on 21 June 2011."
)

summary_template = "Ban Ki-moon was elected for a second term in {year}."
bad_summary = summary_template.format(year=2007)   # year disagrees with the article
good_summary = summary_template.format(year=2011)  # year agrees with the article

In [12]:
# Review snippet (it stops mid-sentence) used as a quick input for the pipeline.
text = (
    "Wow... what would you do in a situation like that!? "
    "The story of Breaking Point, though short, packed a lot of emotion. "
    "The acting was"
)

In [13]:
# Score `text` and return every label's score, with no function applied to the
# model outputs ("function_to_apply" is "none").
pipe(text, return_all_scores=True, function_to_apply="none", batch_size=1)




[[{'label': 'toxic', 'score': -7.053863525390625},
  {'label': 'severe_toxic', 'score': -9.13469409942627},
  {'label': 'obscene', 'score': -8.60804557800293},
  {'label': 'threat', 'score': -9.114510536193848},
  {'label': 'insult', 'score': -8.57530689239502},
  {'label': 'identity_hate', 'score': -8.898863792419434}]]

In [14]:
# Wrap the single pipeline in a list; the cells that build `pairs` and `rewards`
# iterate over `sentiment_pipes`.
sentiment_pipes = [pipe]

In [16]:
# Both candidate summaries are paired with the same "### Input:"-prefixed article.
prompt = "### Input:" + article
queries_responses = [(prompt, summary) for summary in (bad_summary, good_summary)]

In [22]:
# Sanity check of the helper: the displayed result is the article text without
# the "### Input:" prefix.
llama_utils.Instructions.get_input("### Input:" + article)

'Ban Ki-Moon was re-elected for a second term by the UN General Assembly, unopposed and unanimously, on 21 June 2011.'

In [46]:
# Token ids the tokenizer produced for the (bad_summary, article) pair; used as
# the reference when checking the pipeline's input.
bad_pair["input_ids"]

tensor([[    0, 33809, 11488,    12, 16956,    21,  2736,    13,    10,   200,
          1385,    11,  3010,     4,     2,     2, 33809, 11488,    12, 32452,
            21,   769,    12, 15672,    13,    10,   200,  1385,    30,     5,
          2604,  1292,  3389,     6,   542, 10223,  7878,     8, 12008,     6,
            15,   733,   502,  1466,     4,     2]])

In [49]:
# Decode the reference token ids with the pipeline's tokenizer to see which
# special tokens sit around and between the summary and the article.
pipe.tokenizer.decode(bad_pair["input_ids"][0])

'<s>Ban Ki-moon was elected for a second term in 2007.</s></s>Ban Ki-Moon was re-elected for a second term by the UN General Assembly, unopposed and unanimously, on 21 June 2011.</s>'

In [51]:
# Round-trip the first transformed pair (from `pairs`) through the pipeline's
# tokenizer so the text can be compared with the decoded reference input.
pipe.tokenizer.decode(pipe.tokenizer.encode(pairs[0][0]))

'<s>Ban Ki-moon was elected for a second term in 2007.</s></s>Ban Ki-Moon was re-elected for a second term by the UN General Assembly, unopposed and unanimously, on 21 June 2011.</s>'

In [41]:
# Line up the reference ids with the ids the pipeline's tokenizer yields for the
# first transformed pair, position by position.
# zip stops at the shorter sequence, so a length difference only shows up as
# shifted pairs, not as a trailing unmatched id.
reference_ids = bad_pair["input_ids"][0].tolist()
pipeline_ids = pipe.tokenizer.encode(pairs[0][0])
list(zip(reference_ids, pipeline_ids))

[(0, 0),
 (33809, 33809),
 (11488, 11488),
 (12, 12),
 (16956, 16956),
 (21, 21),
 (2736, 2736),
 (13, 13),
 (10, 10),
 (200, 200),
 (1385, 1385),
 (11, 11),
 (3010, 3010),
 (4, 2),
 (2, 2),
 (2, 33809),
 (33809, 11488),
 (11488, 12),
 (12, 32452),
 (32452, 21),
 (21, 769),
 (769, 12),
 (12, 15672),
 (15672, 13),
 (13, 10),
 (10, 200),
 (200, 1385),
 (1385, 30),
 (30, 5),
 (5, 2604),
 (2604, 1292),
 (1292, 3389),
 (3389, 6),
 (6, 542),
 (542, 10223),
 (10223, 7878),
 (7878, 8),
 (8, 12008),
 (12008, 6),
 (6, 15),
 (15, 733),
 (733, 502),
 (502, 1466),
 (1466, 4),
 (4, 2)]

In [50]:
def _build_pair(sentiment_pipe, query, response):
    """Return the transformed (post, response) input for one pipeline."""
    return llama_utils.transform_text_summary(
        sentiment_pipe=sentiment_pipe,
        post=llama_utils.Instructions.get_input(query),
        response=response,
    )


# Outer list: one entry per (query, response); inner list: one entry per pipeline.
pairs = [
    [_build_pair(sentiment_pipe, query, response) for sentiment_pipe in sentiment_pipes]
    for query, response in queries_responses
]

In [59]:
def _score_pair(sentiment_pipe, query, response):
    """Transform one (query, response) and score it with the given pipeline."""
    pipe_input = llama_utils.transform_text_summary(
        sentiment_pipe=sentiment_pipe,
        post=llama_utils.Instructions.get_input(query),
        response=response,
    )
    # All label scores, no function applied to the model outputs, one item per batch.
    return sentiment_pipe(
        pipe_input,
        return_all_scores=True,
        function_to_apply="none",
        batch_size=1,
    )


# Outer list: one entry per (query, response); inner list: one entry per pipeline.
rewards = [
    [_score_pair(sentiment_pipe, query, response) for sentiment_pipe in sentiment_pipes]
    for query, response in queries_responses
]

[[[[{'label': 'HALLUCINATED', 'score': 0.6020027995109558},
    {'label': 'FAITHFUL', 'score': -0.4471779763698578}]]],
 [[[{'label': 'HALLUCINATED', 'score': -2.6085076332092285},
    {'label': 'FAITHFUL', 'score': -0.2461569607257843}]]]]

In [37]:
# Token ids the pipeline's tokenizer yields for the first transformed pair; meant
# to be compared with bad_pair["input_ids"].
pipe.tokenizer.encode(pairs[0][0])

[0,
 33809,
 11488,
 12,
 16956,
 21,
 2736,
 13,
 10,
 200,
 1385,
 11,
 3010,
 2,
 2,
 33809,
 11488,
 12,
 32452,
 21,
 769,
 12,
 15672,
 13,
 10,
 200,
 1385,
 30,
 5,
 2604,
 1292,
 3389,
 6,
 542,
 10223,
 7878,
 8,
 12008,
 6,
 15,
 733,
 502,
 1466,
 4,
 2]

In [20]:
# Display `rewards`. NOTE(review): this cell's execution count is lower than the
# cell that assigns `rewards`, so the recorded output presumably predates the
# latest assignment.
rewards

[[[[{'label': 'HALLUCINATED', 'score': 0.9396332502365112},
    {'label': 'FAITHFUL', 'score': -0.46942010521888733}]]],
 [[[{'label': 'HALLUCINATED', 'score': -1.5056735277175903},
    {'label': 'FAITHFUL', 'score': -0.32813945412635803}]]]]

In [None]:
# NOTE(review): unexecuted cell. It references `self.sent_kwargs`, `query` and
# `response`, none of which are defined at notebook top level, so it looks like a
# snippet pasted from inside a class method; confirm before running.
# It scores one (query, response) with every pipeline in `sentiment_pipes`.
[
                sentiment_pipe(
                    llama_utils.transform_text_summary(
                        sentiment_pipe=sentiment_pipe,
                        post=llama_utils.Instructions.get_input(query),
                        response=response
                    ), **self.sent_kwargs
                ) for sentiment_pipe in sentiment_pipes
            ]