<a href="https://colab.research.google.com/github/Nishasathish13/TheSchoolofAI-END3.0/blob/main/Session%2011_12%20-%20BERT%20and%20BART/Assignment/TASK_3_Session_11_12_BART_for_paraphrasing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
!pip install simpletransformers

Collecting simpletransformers
  Downloading simpletransformers-0.63.4-py3-none-any.whl (248 kB)
[?25l[K     |█▎                              | 10 kB 23.1 MB/s eta 0:00:01[K     |██▋                             | 20 kB 30.2 MB/s eta 0:00:01[K     |████                            | 30 kB 36.4 MB/s eta 0:00:01[K     |█████▎                          | 40 kB 35.0 MB/s eta 0:00:01[K     |██████▋                         | 51 kB 24.9 MB/s eta 0:00:01[K     |████████                        | 61 kB 28.2 MB/s eta 0:00:01[K     |█████████▎                      | 71 kB 28.1 MB/s eta 0:00:01[K     |██████████▌                     | 81 kB 29.7 MB/s eta 0:00:01[K     |███████████▉                    | 92 kB 32.0 MB/s eta 0:00:01[K     |█████████████▏                  | 102 kB 34.0 MB/s eta 0:00:01[K     |██████████████▌                 | 112 kB 34.0 MB/s eta 0:00:01[K     |███████████████▉                | 122 kB 34.0 MB/s eta 0:00:01[K     |█████████████████▏              |

# Data Preparation

We will be combining two datasets to serve as training data for our BART paraphrasing model:
* Google PAWS-Wiki Labeled (Final)
* Quora Question Pairs Dataset

In [2]:
!mkdir data
!wget https://storage.googleapis.com/paws/english/paws_wiki_labeled_final.tar.gz -P data
!tar -xvf data/paws_wiki_labeled_final.tar.gz -C data
!mv data/final/* data
!rm -r data/final

!wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv -P data

--2022-03-12 17:23:10--  https://storage.googleapis.com/paws/english/paws_wiki_labeled_final.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.16.128, 142.251.45.16, 172.217.0.48, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.16.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4687157 (4.5M) [application/gzip]
Saving to: ‘data/paws_wiki_labeled_final.tar.gz’


2022-03-12 17:23:10 (129 MB/s) - ‘data/paws_wiki_labeled_final.tar.gz’ saved [4687157/4687157]

final/test.tsv
final/
final/train.tsv
final/dev.tsv
--2022-03-12 17:23:11--  http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
Resolving qim.fs.quoracdn.net (qim.fs.quoracdn.net)... 151.101.1.2, 151.101.65.2, 151.101.129.2, ...
Connecting to qim.fs.quoracdn.net (qim.fs.quoracdn.net)|151.101.1.2|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58176133 (55M) [text/tab-separated-values]
Saving to: ‘data/quora_duplicate_questi

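We also have a couple of helper functions: one to load data and one to clean unnecessary spaces in the training data. In the original project both live in utils.py. Here they are defined inline, which is why the `from utils import ...` line in the next cell is commented out.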

In [3]:
import warnings

import pandas as pd


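def load_data(
    file_path, input_text_column, target_text_column, label_column, keep_label=1
):
    """Load a TSV file, keep only rows whose label equals keep_label, and
    rename the text columns to input_text / target_text."""
    # on_bad_lines="skip" replaces error_bad_lines=False, which was
    # deprecated in pandas 1.3 and removed in pandas 2.0.
    df = pd.read_csv(file_path, sep="\t", on_bad_lines="skip")
    df = df.loc[df[label_column] == keep_label]
    df = df.rename(
        columns={input_text_column: "input_text", target_text_column: "target_text"}
    )
    # .copy() avoids a SettingWithCopyWarning when the prefix column is added
    df = df[["input_text", "target_text"]].copy()
    df["prefix"] = "paraphrase"

    return df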


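def clean_unnecessary_spaces(out_string):
    """Remove the space that tokenized text leaves before punctuation and
    English contractions, e.g. "is n't ." becomes "isn't."."""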
    if not isinstance(out_string, str):
        warnings.warn(f">>> {out_string} <<< is not a string.")
        out_string = str(out_string)
    out_string = (
        out_string.replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
    )
    return out_string
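One detail worth checking before the data is loaded: `load_data` compares the label column to the integer `1` (its `keep_label` default), while the PAWS cell below first casts every column with `.astype(str)` and therefore compares to the string `"1"`. The sketch below runs both comparisons on a tiny in-memory TSV; the rows are hypothetical and only mimic the PAWS column layout (`id`, `sentence1`, `sentence2`, `label`).

```python
import io

import pandas as pd

# Hypothetical rows in the PAWS column layout; the sentences are made up.
tsv = (
    "id\tsentence1\tsentence2\tlabel\n"
    "1\tA cat sat on the mat .\tOn the mat sat a cat .\t1\n"
    "2\tThe dog ran home .\tThe dog stayed home .\t0\n"
    "3\tIt rained all day .\tRain fell all day .\t1\n"
)

raw = pd.read_csv(io.StringIO(tsv), sep="\t")  # label is parsed as an integer column
as_str = raw.astype(str)                       # label becomes "1" / "0" strings

int_hits = len(raw.loc[raw["label"] == 1])            # integer comparison, as in load_data
str_hits = len(as_str.loc[as_str["label"] == "1"])    # string comparison, as in the PAWS cell
mismatched_hits = len(as_str.loc[as_str["label"] == 1])  # str column vs int: matches nothing

print(int_hits, str_hits, mismatched_hits)  # 2 2 0
```

Mixing the two styles (casting to `str` but comparing against `1`) would silently filter out every row, leaving an empty DataFrame rather than raising an error.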

Import the necessary libraries and set up logging.

In [4]:
import os
from datetime import datetime
import logging

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

#from utils import load_data, clean_unnecessary_spaces

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

Loading the datasets

In [5]:
# Google Data
train_df = pd.read_csv("data/train.tsv", sep="\t").astype(str)
eval_df = pd.read_csv("data/dev.tsv", sep="\t").astype(str)

train_df = train_df.loc[train_df["label"] == "1"]
eval_df = eval_df.loc[eval_df["label"] == "1"]

train_df = train_df.rename(
    columns={"sentence1": "input_text", "sentence2": "target_text"}
)
eval_df = eval_df.rename(
    columns={"sentence1": "input_text", "sentence2": "target_text"}
)

train_df = train_df[["input_text", "target_text"]]
eval_df = eval_df[["input_text", "target_text"]]

train_df["prefix"] = "paraphrase"
eval_df["prefix"] = "paraphrase"

# Quora Data

# The Quora Dataset is not separated into train/test, so we do it manually the first time.
df = load_data(
    "data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate"
)
q_train, q_test = train_test_split(df)

q_train.to_csv("data/quora_train.tsv", sep="\t")
q_test.to_csv("data/quora_test.tsv", sep="\t")

# The code block above only needs to be run once.
# After that, the two lines below are sufficient to load the Quora dataset.

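# index_col=0 reads the saved index back as the index instead of an "Unnamed: 0" column
# q_train = pd.read_csv("data/quora_train.tsv", sep="\t", index_col=0)
# q_test = pd.read_csv("data/quora_test.tsv", sep="\t", index_col=0)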

train_df = pd.concat([train_df, q_train])
eval_df = pd.concat([eval_df, q_test])

train_df = train_df[["prefix", "input_text", "target_text"]]
eval_df = eval_df[["prefix", "input_text", "target_text"]]

train_df = train_df.dropna()
eval_df = eval_df.dropna()

train_df["input_text"] = train_df["input_text"].apply(clean_unnecessary_spaces)
train_df["target_text"] = train_df["target_text"].apply(clean_unnecessary_spaces)

eval_df["input_text"] = eval_df["input_text"].apply(clean_unnecessary_spaces)
eval_df["target_text"] = eval_df["target_text"].apply(clean_unnecessary_spaces)

print(train_df)





            prefix                                         input_text  \
1       paraphrase  The NBA season of 1975 -- 76 was the 30th seas...   
3       paraphrase  When comparable rates of flow can be maintaine...   
4       paraphrase  It is the seat of Zerendi District in Akmola R...   
5       paraphrase  William Henry Henry Harman was born on 17 Febr...   
7       paraphrase  With a discrete amount of probabilities Formul...   
...            ...                                                ...   
68903   paraphrase      What is the physical significance of entropy?   
251332  paraphrase  What is the most inspirational book you have e...   
65045   paraphrase                  When will humans become immortal?   
345815  paraphrase  How do I become partners with tech companies i...   
389543  paraphrase  I used to rock back and forth - sometimes in c...   

                                              target_text  
1       The 1975 -- 76 season of the National Basketba...  
3  
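
The Quora split above calls `train_test_split` without a `random_state`, so every re-run of that block produces a different train/test split. This is why the notebook saves the split once and reloads it afterwards. The sketch below, on a hypothetical 100-row stand-in for the Quora pairs, shows the default 75/25 split, how a fixed `random_state` makes the split repeatable, and why `index_col=0` matters when the saved TSV is read back.

```python
import io

import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical 100-row stand-in for the filtered Quora pairs.
df = pd.DataFrame(
    {
        "input_text": [f"question {i}" for i in range(100)],
        "target_text": [f"paraphrase {i}" for i in range(100)],
    }
)

# With neither test_size nor train_size given, train_test_split holds out 25%.
# Fixing random_state makes the split identical across runs.
train_a, test_a = train_test_split(df, random_state=42)
train_b, test_b = train_test_split(df, random_state=42)
print(len(train_a), len(test_a))          # 75 25
print(test_a.index.equals(test_b.index))  # True

# to_csv writes the index as an unnamed first column; index_col=0 restores it
# as the index instead of adding an extra "Unnamed: 0" column.
csv_text = test_a.to_csv(sep="\t")
without_index_col = pd.read_csv(io.StringIO(csv_text), sep="\t")
with_index_col = pd.read_csv(io.StringIO(csv_text), sep="\t", index_col=0)
print(list(without_index_col.columns))  # ['Unnamed: 0', 'input_text', 'target_text']
print(list(with_index_col.columns))     # ['input_text', 'target_text']
```

In this notebook the extra column is harmless, because the later `[["prefix", "input_text", "target_text"]]` selection drops it. It matters if the reloaded frames are used elsewhere.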

Next, we keep a subset of the data: 8,000 training rows and 3,000 evaluation rows. Then we set up the model and hyperparameter values. Note that we are using the pre-trained facebook/bart-large model and fine-tuning it on our own dataset.
Finally, we generate paraphrases for each of the sentences in the evaluation data.

In [6]:
# Positional row slices: 8,000 training rows and 3,000 evaluation rows
train_df = train_df.iloc[10000:18000]
eval_df = eval_df.iloc[20000:23000]
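
Those subset sizes, combined with the batch sizes and epochs set in the next cell, determine how many optimizer steps training takes. The arithmetic below uses the notebook's own numbers, plus one assumption: that `gradient_accumulation_steps` keeps what we take to be the Simple Transformers default of 1, since the notebook never sets it. It shows why the 2,500-step `evaluate_during_training_steps` interval can never trigger here.

```python
import math

# Row counts from the slices above; batch sizes, epochs and the eval
# interval from model_args in the next cell.
n_train_rows = 8000
n_eval_rows = 3000
train_batch_size = 8
eval_batch_size = 8
num_train_epochs = 2
evaluate_during_training_steps = 2500
gradient_accumulation_steps = 1  # assumption: library default, not set in the notebook

steps_per_epoch = math.ceil(n_train_rows / train_batch_size) // gradient_accumulation_steps
total_steps = steps_per_epoch * num_train_epochs
step_based_evals = total_steps // evaluate_during_training_steps
predict_batches = math.ceil(n_eval_rows / eval_batch_size)

print(steps_per_epoch, total_steps, step_based_evals, predict_batches)  # 1000 2000 0 375
```

These figures line up with the training log: batches run as `0/1000` per epoch, checkpoints are named `checkpoint-1000-epoch-1` and `checkpoint-2000-epoch-2`, and the final `Generating outputs` bar shows 375 batches for the 3,000 evaluation inputs. The two `eval_loss` entries come from end-of-epoch evaluation, not from the step interval.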

In [7]:
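model_args = Seq2SeqArgs()
model_args.do_sample = True  # sample during generation instead of greedy/beam decoding
model_args.eval_batch_size = 8  # batch size for evaluation and prediction
model_args.evaluate_during_training = True  # evaluate on eval_df while training
model_args.evaluate_during_training_steps = 2500  # larger than the 2,000 total steps here
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False  # no mixed-precision training
model_args.learning_rate = 5e-5
model_args.max_length = 128  # maximum length of generated sequences
model_args.max_seq_length = 128  # maximum input length in tokens
model_args.num_beams = None  # no beam search; rely on sampling
model_args.num_return_sequences = 3  # three candidate paraphrases per input
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True  # allow writing into an existing outputs/ directory
model_args.reprocess_input_data = True  # rebuild features instead of reusing the cache
model_args.save_eval_checkpoints = False  # no checkpoint at each evaluation
model_args.save_steps = -1  # no step-based checkpoints (epoch checkpoints are still saved)
model_args.top_k = 50  # sample only among the 50 most likely next tokens
model_args.top_p = 0.95  # ...restricted further to the smallest set covering 95% probability
model_args.train_batch_size = 8
model_args.use_multiprocessing = False
model_args.wandb_project = "Paraphrasing with BART"  # log to Weights & Biases (asks for an API key)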


model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
)

model.train_model(train_df, eval_data=eval_df)

to_predict = [
    prefix + ": " + str(input_text)
    for prefix, input_text in zip(eval_df["prefix"].tolist(), eval_df["input_text"].tolist())
]
truth = eval_df["target_text"].tolist()

preds = model.predict(to_predict)

# Saving the predictions if needed
os.makedirs("predictions", exist_ok=True)

# Timestamp without spaces or colons so the file name is portable
with open(f"predictions/predictions_{datetime.now():%Y-%m-%d_%H-%M-%S}.txt", "w") as f:
    for i, text in enumerate(eval_df["input_text"].tolist()):
        f.write(str(text) + "\n\n")

        f.write("Truth:\n")
        f.write(truth[i] + "\n\n")

        f.write("Prediction:\n")
        for pred in preds[i]:
            f.write(str(pred) + "\n")
        f.write(
            "________________________________________________________________________________\n"
        )


INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/8000 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model: Training started


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]


[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Running Epoch 0 of 2:   0%|          | 0/1000 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/checkpoint-1000-epoch-1
INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/3000 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 2.4787684418360394}
INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/best_model


Running Epoch 1 of 2:   0%|          | 0/1000 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/checkpoint-2000-epoch-2
INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/3000 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 2.4818944028218586}
INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/
INFO:simpletransformers.seq2seq.seq2seq_model: Training of facebook/bart-large model complete. Saved to outputs/.


Generating outputs:   0%|          | 0/375 [00:00<?, ?it/s]

In [8]:
model_args.do_sample = True
model_args.num_beams = None
model_args.num_return_sequences = 3
model_args.max_length = 128
model_args.top_k = 50
model_args.top_p = 0.95
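
The cell above re-states the sampling settings: `do_sample=True`, `top_k=50`, `top_p=0.95`. As a rough illustration of what they do, the pure-Python sketch below applies top-k filtering and then top-p (nucleus) filtering to a made-up next-token distribution and samples from what remains. The function names are hypothetical. We assume, without confirming it, that the Hugging Face `generate` call behind Simple Transformers applies the two filters in this order.

```python
import random


def top_k_top_p_filter(probs, top_k, top_p):
    """Return {token_index: prob} after top-k, then top-p filtering (renormalized)."""
    # Top-k: keep the k most probable tokens and renormalize over them.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_total = sum(probs[i] for i in order)
    k_probs = {i: probs[i] / k_total for i in order}

    # Top-p (nucleus): keep the smallest most-probable prefix whose cumulative
    # probability reaches top_p; the single most likely token is always kept.
    nucleus = []
    cumulative = 0.0
    for i in order:
        nucleus.append(i)
        cumulative += k_probs[i]
        if cumulative >= top_p:
            break

    p_total = sum(k_probs[i] for i in nucleus)
    return {i: k_probs[i] / p_total for i in nucleus}


def sample_token(dist, rng):
    """Draw one token index from a {index: prob} distribution."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]


rng = random.Random(0)

# Toy spread-out distribution: several tokens survive, so samples can differ.
spread = top_k_top_p_filter([0.5, 0.2, 0.15, 0.1, 0.05], top_k=4, top_p=0.8)
print(sorted(spread))  # [0, 1, 2]
drawn = sample_token(spread, rng)

# Toy peaked distribution: only the top token survives top_p=0.95,
# so every "sample" is the same token.
peaked = top_k_top_p_filter([0.97, 0.02, 0.01], top_k=50, top_p=0.95)
print(peaked)  # {0: 1.0}
```

The peaked case suggests one plausible reason why some of the interactive results below return three identical predictions despite sampling. When the fine-tuned model is very confident at every step, the nucleus often shrinks to a single token, and sampling behaves like greedy decoding.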

In [12]:
import logging

from simpletransformers.seq2seq import Seq2SeqModel


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

model = Seq2SeqModel(
    encoder_decoder_type="bart", encoder_decoder_name="outputs"
)


# Press Enter on an empty line to stop the loop.
while True:
    original = input("Input sentence: ")
    if not original.strip():
        break
    to_predict = [original]

    preds = model.predict(to_predict)

    print("---------------------------------------------------------")
    print(original)

    print()
    print("Predictions >>>")
    for pred in preds[0]:
        print(pred)

    print("---------------------------------------------------------")
    print()

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f1dd7fc3110>> (for pre_run_cell):


Exception: ignored

Input sentenceHis fame is due in mathematical astronomy to the introduction of the astronomical globe and to his early contributions to the understanding of the movement of the planets


Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

---------------------------------------------------------
His fame is due in mathematical astronomy to the introduction of the astronomical globe and to his early contributions to the understanding of the movement of the planets

Predictions >>>
His fame in mathematical astronomy is due to the introduction of the astronomical globe and his early contributions to the understanding of the movement of the planets.
His fame in mathematical astronomy is due to the introduction of the astronomical globe and his early contributions to the understanding of the movement of the planets.
His fame in mathematical astronomy is due to the introduction of the astronomical globe and his early contributions to the understanding of the movement of the planets.
---------------------------------------------------------

Input sentenceWhy are people so obsessed with Cara Delevingne?


Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

---------------------------------------------------------
Why are people so obsessed with Cara Delevingne?

Predictions >>>
Why are people so obsessed with Cara Delevingne?
Why are people so obsessed with Cara Delevingne?
Why are people so obsessed with Cara Delevingne?
---------------------------------------------------------

Input sentenceWhy are people obsessed with Cara Delevingne?


Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

---------------------------------------------------------
Why are people obsessed with Cara Delevingne?

Predictions >>>
Why are people obsessed with Cara Delevingne?
Why are people obsessed with Cara Delevingne?
Why are people obsessed with Cara Delevingne?
---------------------------------------------------------

Input sentenceEarl St Vincent was a British ship that was captured in 1803 and became a French trade man


Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

---------------------------------------------------------
Earl St Vincent was a British ship that was captured in 1803 and became a French trade man

Predictions >>>
Earl St Vincent was a British ship captured in 1803 and became a French tradesman.
Earl St Vincent was a British ship that was captured in 1803 and became a French merchantman.
Earl St Vincent was a British ship captured in 1803 and became a French trade man.
---------------------------------------------------------

Input sentenceWorcester is a town and county city of Worcestershire in England.


Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

---------------------------------------------------------
Worcester is a town and county city of Worcestershire in England.

Predictions >>>
Worcester is a town and county borough of Worcestershire in England.
Worcester is a town and county town of Worcestershire in England.
Worcester is a town and county borough of Worcestershire in England.
---------------------------------------------------------



KeyboardInterrupt: ignored

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f1dd7fc3110>> (for post_run_cell):


Exception: ignored
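
Several of the runs above return the same sentence three times, or simply echo the input. A small post-processing step can filter those out before printing. The sketch below uses a hypothetical helper, `unique_paraphrases`, that is not part of Simple Transformers. It drops candidates that repeat the input or an earlier candidate, ignoring case and surrounding whitespace, and it is exercised on predictions copied from the session above.

```python
def unique_paraphrases(source, candidates):
    """Drop candidates that repeat the input or an earlier candidate.

    Comparison ignores case and surrounding whitespace; order is preserved.
    """
    source_key = source.strip().lower()
    seen = set()
    kept = []
    for candidate in candidates:
        key = candidate.strip().lower()
        if key == source_key or key in seen:
            continue  # echo of the input or a duplicate sample
        seen.add(key)
        kept.append(candidate.strip())
    return kept


source = "Worcester is a town and county city of Worcestershire in England."
preds = [
    "Worcester is a town and county borough of Worcestershire in England.",
    "Worcester is a town and county town of Worcestershire in England.",
    "Worcester is a town and county borough of Worcestershire in England.",
]
worcester = unique_paraphrases(source, preds)
print(worcester)  # the "borough" and "town" variants, once each

echo = "Why are people so obsessed with Cara Delevingne?"
echoed = unique_paraphrases(echo, [echo] * 3)
print(echoed)  # []
```

In the interactive loop, this would replace `for pred in preds[0]:` with `for pred in unique_paraphrases(original, preds[0]):`. A stricter variant could also ignore trailing punctuation, since the model tends to add a final period the input lacks.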