In [1]:
import torch
from transformers import RobertaForSequenceClassification, AutoTokenizer
from transformers import BartForConditionalGeneration, AutoTokenizer
import numpy as np
import pandas as pd
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import evaluate
import numpy as np
from tqdm import tqdm
import json
from nltk.tokenize import RegexpTokenizer
bleu_metric = evaluate.load("sacrebleu")
sari_metric = evaluate.load("sari")
import wandb
# wandb.login(key = "bb076da462c822cc4bc5bec1095c39708332b34d" )
# wandb.init(project="paraphrase_controled_with_classification", entity="paraphrase")


def metrics_func(eval_arg):
    print(len(eval_arg[0]),len(eval_arg[1]),len(eval_arg[2]))
    print(len(eval_arg))
    text_inputs = eval_arg[0]
    text_preds = eval_arg[1]
    text_labels = eval_arg[2]
    texts_bleu =[text.strip() for text in text_preds]
    labels_bleu = [[text.strip()] for text in text_labels[0]]
    result = bleu_metric.compute(predictions=texts_bleu, references=text_inputs)
    return result["score"],sari_metric.compute(
        predictions=text_preds,
        references=text_labels,
        sources=text_inputs,
    )['sari']

In [3]:


classifier = RobertaForSequenceClassification.from_pretrained("liamcripwell/ctrl44-clf",cache_dir='/ssd_scratch/cvit/aparna/classification')
tokenizer1 = AutoTokenizer.from_pretrained("liamcripwell/ctrl44-clf",cache_dir='/ssd_scratch/cvit/aparna/classification_tokenizer')

model = BartForConditionalGeneration.from_pretrained("liamcripwell/ctrl44-simp",cache_dir='/ssd_scratch/cvit/aparna/paraphrase')
tokenizer2 = AutoTokenizer.from_pretrained("liamcripwell/ctrl44-simp",cache_dir='/ssd_scratch/cvit/aparna/paraphrase_tokenizer')


test_data = pd.read_csv("../../data/10/test.csv")
test_data = test_data.dropna()
test_data = test_data.reset_index(drop=True)
#take 1000 samples
test_data = test_data[:10000]
texts = test_data["source"].tolist()
labels = test_data["target"].tolist()

metrics =[]
inpu = []
cands = []
lab = []
max_l = 512
num_b = 10
num_sub_b =1

for  i  in tqdm(range(len(texts))):
    text = texts[i]
    inputs = tokenizer1(text, return_tensors="pt")
    inpu.extend([text]*10)
    with torch.no_grad():
        logits = classifier(**inputs).logits
        predicted_class_id = logits.argmax().item()
        predicted_class_name = classifier.config.id2label[predicted_class_id]
        text = predicted_class_name + " " + text
        inputs1 = tokenizer2(text, return_tensors="pt")
        beam_outputs = model.generate(**inputs1,max_length=max_l,num_beams=num_b,early_stopping=True,
        no_repeat_ngram_size=3,
        num_return_sequences=10,
        top_k=4, top_p=0.95
        # return_dict_in_generate=True,
    )
    c = []
    for x, beam in enumerate(beam_outputs):
    #print("{}".format(i, tokenizer.decode(beam, skip_special_tokens=True)))
        c.append(tokenizer2.decode(beam, skip_special_tokens=True))
    cands.extend(c)

    lab.extend([[labels[i]]]*10)
    eval_arg = [inpu,cands,lab]
    #print(i)
    if i%10==0:
        metric = metrics_func(eval_arg)
        metrics.append(metric)
        print(metric)
        
print(len(inpu),len(cands),len(lab))
metric = metrics_func(eval_arg)
print(metric)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


10 10 10
3


  0%|          | 1/10000 [01:20<222:15:51, 80.02s/it]

(64.80660519293919, 39.15192624769725)


  0%|          | 10/10000 [11:59<203:03:45, 73.18s/it]

110 110 110
3


  0%|          | 11/10000 [15:25<316:02:54, 113.90s/it]

(70.54178072608359, 46.836926288363195)


  0%|          | 20/10000 [26:42<212:30:46, 76.66s/it] 

210 210 210
3


  0%|          | 21/10000 [27:03<166:11:03, 59.95s/it]

(71.04719859387419, 50.32154571874382)


  0%|          | 30/10000 [27:32<10:04:17,  3.64s/it] 

310 310 310
3


  0%|          | 31/10000 [27:34<8:53:49,  3.21s/it] 

(68.85612876937226, 49.86674347882743)


  0%|          | 40/10000 [29:15<27:55:03, 10.09s/it]

410 410 410
3


  0%|          | 41/10000 [29:25<28:08:53, 10.18s/it]

(69.54174395753573, 48.71997313052782)


  0%|          | 50/10000 [31:33<51:07:23, 18.50s/it]

510 510 510
3


  1%|          | 51/10000 [31:39<40:25:13, 14.63s/it]

(69.35243166774546, 47.80644002940515)


  1%|          | 60/10000 [33:48<35:20:45, 12.80s/it]

610 610 610
3


  1%|          | 61/10000 [34:21<52:29:52, 19.02s/it]

(70.35532419224955, 48.779826833888826)


  1%|          | 70/10000 [36:27<35:47:46, 12.98s/it]

710 710 710
3


  1%|          | 71/10000 [36:33<30:43:53, 11.14s/it]

(70.77187332041827, 48.676242191013515)


  1%|          | 80/10000 [39:27<49:16:04, 17.88s/it]

810 810 810
3


  1%|          | 81/10000 [39:43<47:50:39, 17.36s/it]

(70.99442215947855, 48.11475963706405)


  1%|          | 90/10000 [41:45<27:16:36,  9.91s/it]

910 910 910
3


  1%|          | 91/10000 [42:04<34:52:57, 12.67s/it]

(70.61096101098563, 48.44181275168831)


  1%|          | 100/10000 [42:56<6:30:50,  2.37s/it]

1010 1010 1010
3


  1%|          | 101/10000 [42:59<6:49:41,  2.48s/it]

(70.83445521500816, 48.08033791717513)


  1%|          | 110/10000 [43:10<3:39:32,  1.33s/it]

1110 1110 1110
3


  1%|          | 111/10000 [43:13<4:44:00,  1.72s/it]

(70.43052594742274, 47.54168189333594)


  1%|          | 120/10000 [43:47<4:16:11,  1.56s/it] 

1210 1210 1210
3


  1%|          | 121/10000 [43:50<5:38:18,  2.05s/it]

(70.95409638377379, 47.24959246332547)


  1%|▏         | 130/10000 [44:00<2:59:38,  1.09s/it]

1310 1310 1310
3


  1%|▏         | 131/10000 [44:03<5:03:43,  1.85s/it]

(71.7333348396733, 47.34951905644603)


  1%|▏         | 140/10000 [44:14<3:46:01,  1.38s/it]

1410 1410 1410
3


  1%|▏         | 141/10000 [44:17<5:32:36,  2.02s/it]

(71.22691736999529, 47.36419069522122)


  2%|▏         | 150/10000 [44:27<3:09:31,  1.15s/it]

1510 1510 1510
3


  2%|▏         | 151/10000 [44:30<4:33:16,  1.66s/it]

(71.5022811047697, 47.480414460524166)


  2%|▏         | 160/10000 [44:39<3:21:02,  1.23s/it]

1610 1610 1610
3


  2%|▏         | 161/10000 [44:42<5:00:32,  1.83s/it]

(71.0640741427282, 47.179534372007026)


  2%|▏         | 170/10000 [44:54<2:50:01,  1.04s/it]

1710 1710 1710
3


  2%|▏         | 171/10000 [44:57<4:56:56,  1.81s/it]

(70.83703293667517, 47.27312848767143)


  2%|▏         | 180/10000 [45:06<2:49:49,  1.04s/it]

1810 1810 1810
3


  2%|▏         | 181/10000 [45:10<5:32:37,  2.03s/it]

(70.53400770760163, 46.947503641045465)


  2%|▏         | 190/10000 [45:20<3:21:05,  1.23s/it]

1910 1910 1910
3


  2%|▏         | 191/10000 [45:25<6:00:56,  2.21s/it]

(70.12295089755837, 46.911655187993524)


  2%|▏         | 200/10000 [45:34<3:09:25,  1.16s/it]

2010 2010 2010
3


  2%|▏         | 201/10000 [45:39<6:26:36,  2.37s/it]

(70.12323645568293, 46.694909042389334)


  2%|▏         | 210/10000 [45:48<2:39:32,  1.02it/s]

2110 2110 2110
3


  2%|▏         | 211/10000 [45:53<6:08:47,  2.26s/it]

(70.03902549820648, 46.73354258783017)


  2%|▏         | 220/10000 [46:02<3:41:13,  1.36s/it]

2210 2210 2210
3


  2%|▏         | 221/10000 [46:07<6:31:11,  2.40s/it]

(69.92839709103282, 46.61884240165913)


  2%|▏         | 230/10000 [46:17<3:30:47,  1.29s/it]

2310 2310 2310
3


  2%|▏         | 231/10000 [46:21<5:42:54,  2.11s/it]

(69.94185548370616, 46.767005212883824)


  2%|▏         | 240/10000 [46:30<2:29:23,  1.09it/s]

2410 2410 2410
3


  2%|▏         | 241/10000 [46:35<5:38:19,  2.08s/it]

(69.78922736387585, 46.71830099186845)


  2%|▎         | 250/10000 [46:45<3:14:51,  1.20s/it]

2510 2510 2510
3


  3%|▎         | 251/10000 [46:49<5:35:06,  2.06s/it]

(69.9484582444172, 46.63396716134594)


  3%|▎         | 260/10000 [46:57<2:28:59,  1.09it/s]

2610 2610 2610
3


  3%|▎         | 261/10000 [47:02<5:32:38,  2.05s/it]

(69.86235439632505, 46.49706133888547)


  3%|▎         | 270/10000 [47:10<2:26:02,  1.11it/s]

2710 2710 2710
3


  3%|▎         | 271/10000 [47:15<5:12:06,  1.92s/it]

(70.0731364734662, 46.559810883960644)


  3%|▎         | 280/10000 [47:25<3:06:25,  1.15s/it]

2810 2810 2810
3


  3%|▎         | 281/10000 [47:30<6:33:21,  2.43s/it]

(70.5411071464047, 46.36294524518705)


  3%|▎         | 290/10000 [47:40<3:00:59,  1.12s/it]

2910 2910 2910
3


  3%|▎         | 291/10000 [47:44<5:52:11,  2.18s/it]

(70.41911227062276, 46.387442929220896)


  3%|▎         | 300/10000 [47:54<3:13:02,  1.19s/it]

3010 3010 3010
3


  3%|▎         | 301/10000 [47:58<5:48:05,  2.15s/it]

(70.3784540880233, 46.48910103041297)


  3%|▎         | 310/10000 [48:06<2:36:04,  1.03it/s]

3110 3110 3110
3


  3%|▎         | 311/10000 [48:12<6:23:13,  2.37s/it]

(70.47343955033334, 46.50620064448573)


  3%|▎         | 320/10000 [48:22<2:52:25,  1.07s/it]

3210 3210 3210
3


  3%|▎         | 321/10000 [48:28<6:46:40,  2.52s/it]

(70.61315861121342, 46.413449907380134)


  3%|▎         | 330/10000 [48:39<2:59:18,  1.11s/it]

3310 3310 3310
3


  3%|▎         | 331/10000 [48:44<5:52:49,  2.19s/it]

(70.90373863907463, 46.47027552203239)


  3%|▎         | 340/10000 [48:54<2:44:45,  1.02s/it]

3410 3410 3410
3


  3%|▎         | 341/10000 [48:59<5:53:01,  2.19s/it]

(71.03180068515914, 46.420068007605494)


  4%|▎         | 350/10000 [49:07<2:48:50,  1.05s/it]

3510 3510 3510
3


  4%|▎         | 351/10000 [49:12<5:57:30,  2.22s/it]

(71.13476773149243, 46.426444173726054)


  4%|▎         | 360/10000 [49:23<3:08:22,  1.17s/it]

3610 3610 3610
3


  4%|▎         | 361/10000 [49:29<7:37:53,  2.85s/it]

(71.39790192841862, 46.398790417257224)


  4%|▎         | 370/10000 [49:41<4:20:01,  1.62s/it]

3710 3710 3710
3


  4%|▎         | 371/10000 [49:48<8:09:25,  3.05s/it]

(71.40129670751433, 46.306933902818145)


  4%|▍         | 380/10000 [50:00<3:08:45,  1.18s/it]

3810 3810 3810
3


  4%|▍         | 381/10000 [50:08<8:37:52,  3.23s/it]

(71.35337051764337, 46.274645977912115)


  4%|▍         | 390/10000 [50:18<3:22:50,  1.27s/it]

3910 3910 3910
3


  4%|▍         | 391/10000 [50:25<7:43:37,  2.89s/it]

(71.44321132176286, 46.28899659753253)


  4%|▍         | 400/10000 [50:36<3:20:09,  1.25s/it]

4010 4010 4010
3


  4%|▍         | 401/10000 [50:42<7:47:27,  2.92s/it]

(71.40759570735054, 46.35974612907791)


  4%|▍         | 410/10000 [50:53<3:52:10,  1.45s/it]

4110 4110 4110
3


  4%|▍         | 411/10000 [51:00<8:17:54,  3.12s/it]

(71.44493046174391, 46.335290441989315)


  4%|▍         | 420/10000 [51:09<2:49:32,  1.06s/it]

4210 4210 4210
3


  4%|▍         | 421/10000 [51:16<7:05:21,  2.66s/it]

(71.24476382671102, 46.268053289469904)


  4%|▍         | 430/10000 [51:26<2:38:19,  1.01it/s]

4310 4310 4310
3


  4%|▍         | 431/10000 [51:33<7:13:39,  2.72s/it]

(71.33827407030589, 46.212837981975305)


  4%|▍         | 440/10000 [51:43<3:05:47,  1.17s/it]

4410 4410 4410
3


  4%|▍         | 441/10000 [51:49<6:49:43,  2.57s/it]

(71.32406994861574, 46.19032193848718)


  4%|▍         | 450/10000 [51:58<3:03:50,  1.16s/it]

4510 4510 4510
3


  5%|▍         | 451/10000 [52:05<7:06:16,  2.68s/it]

(71.34226305793062, 46.32402152328336)


  5%|▍         | 460/10000 [52:13<2:58:13,  1.12s/it]

4610 4610 4610
3


  5%|▍         | 461/10000 [52:19<7:04:45,  2.67s/it]

(71.36231512810976, 46.25655650372653)


  5%|▍         | 466/10000 [52:26<17:52:46,  6.75s/it]


KeyboardInterrupt: 

In [None]:
from transformers import RobertaForSequenceClassification, AutoTokenizer

model = RobertaForSequenceClassification.from_pretrained("liamcripwell/ctrl44-clf")
tokenizer = AutoTokenizer.from_pretrained("liamcripwell/ctrl44-clf")

text = "Barack Hussein Obama II is an American politician who served as the 44th president of the United States from 2009 to 2017."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
  logits = model(**inputs).logits
predicted_class_id = logits.argmax().item()
predicted_class_name = model.config.id2label[predicted_class_id]

In [5]:
text = "Since,2010, project researchers have uncovered documents in portugal that have revealed who owned the ship."
    # encode the text into tensor of integers using the appropriate tokenizer
inputs = tokenizer1(text, return_tensors="pt")
logits = classifier(**inputs).logits
predicted_class_id = logits.argmax().item()
predicted_class_name = classifier.config.id2label[predicted_class_id]
text = predicted_class_name + " " + text
inputs1 = tokenizer2(text, return_tensors="pt")
    # generate text until the output length (which includes the context length) reaches 50
beam_outputs = model.generate(**inputs,max_length=max_l,num_beams=num_b,early_stopping=True,
    no_repeat_ngram_size=3,
    num_return_sequences=10,
    top_k=4, top_p=0.95
    # return_dict_in_generate=True,
)
for x, beam in enumerate(beam_outputs):
    print("{} {}".format(i, tokenizer2.decode(beam, skip_special_tokens=True)))



466  Since,2010, project researchers have uncovered documents in Portugal that have revealed who owned the ship.
466  Since,2010, project researchers have uncovered documents in Portugal that reveal who owned the ship.
466  Since,2010, project researchers have uncovered documents in Portugal that revealed who owned the ship.
466  Since,2010, project researchers have uncovered documents in portugal that reveal who owned the ship.
466  Since,2010, project researchers have uncovered documents in Portugal which have revealed who owned the ship.
466  Since,2010, project researchers have uncovered documents that have revealed who owned the ship.
466  Since,2010, project researchers have uncovered documents in Portugal which reveal who owned the ship.
466  Since,2010, project researchers have uncovered documents in Portugal that have revealed who owned the vessel.
466  Since,2010, project researchers have uncovered documents in Portugal that have revealed who owns the ship.
466  Since,2010, p

In [6]:
text = "Experts say China's air pollution exacts a tremendous toll on human health."
    # encode the text into tensor of integers using the appropriate tokenizer
inputs = tokenizer1(text, return_tensors="pt")
logits = classifier(**inputs).logits
predicted_class_id = logits.argmax().item()
predicted_class_name = classifier.config.id2label[predicted_class_id]
text = predicted_class_name + " " + text
inputs1 = tokenizer2(text, return_tensors="pt")
    # generate text until the output length (which includes the context length) reaches 50
beam_outputs = model.generate(**inputs,max_length=max_l,num_beams=num_b,early_stopping=True,
    no_repeat_ngram_size=3,
    num_return_sequences=10,
    top_k=4, top_p=0.95
    # return_dict_in_generate=True,
)
for x, beam in enumerate(beam_outputs):
    print("{} {}".format(i, tokenizer2.decode(beam, skip_special_tokens=True)))

466  Experts say China's air pollution has a huge impact on human health.
466  Experts say China's air pollution has a huge toll on human health.
466  Experts say China's air pollution is a huge toll on human health.
466  Experts say China's air pollution has a huge effect on human health.
466  Experts say China's air pollution is a terrible toll on human health.
466  Experts say China's air pollution has a tremendous toll on human health.
466  Experts say China's air pollution puts a huge toll on human health.
466  Experts say China's air pollution has a tremendous impact on human health.
466  Experts say China's air pollution has a huge impact on health.
466  Experts say China's air pollution is a serious toll on human health.
