 **Few shot text generation with T5 Transformer**

## 1. Install libraries

In [2]:
!pip install transformers==2.9.0

Collecting transformers==2.9.0
  Downloading transformers-2.9.0-py3-none-any.whl (635 kB)
[K     |████████████████████████████████| 635 kB 5.2 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 40.6 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 36.6 MB/s 
Collecting tokenizers==0.7.0
  Downloading tokenizers-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 23.9 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=fd7d27da696047f351eee72cb7e75c26aaf0db548257122f3c09c8e67f66fb07
  Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98e7a4d1933a

In [3]:
# Check we have a GPU and check the memory size of the GUP
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



## 2. Prepare Model

In [4]:

import random
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)

set_seed(42)

In [5]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-base')


Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/892M [00:00<?, ?B/s]

In [6]:
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in t5_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in t5_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-4, eps=1e-8)



In [8]:
import re, math, time, sys, copy, random, json
df=pd.read_csv("newcorp.csv")
text_list=[]
label_list=[]
def datasetmaker(x):
    x=json.loads(x)
    for i in range(0,len(x["paragraphs"])):
        for j in range(0,len(x["paragraphs"][i])):
            text_list.append(x["paragraphs"][i][j]["text"])
            label_list.append(x["paragraphs"][i][j]["label"])
df["article_segments"].apply(lambda x:datasetmaker(x))
df_main={"text":text_list,"label":label_list}
df_main=pd.DataFrame(df_main)
len_no_unit=df_main[df_main["label"]=="no-unit"]["text"].apply(lambda x:len(x.split(" ")))
df_main=df_main.loc[(df_main["label"]!="no-unit") | ((df_main["label"]=="no-unit") & (df_main["text"].map(len) > 1)) ]
df_main
def cleanData(df):
    df = df.replace(r'\n','', regex=True)
    df = df_main.replace('\"','', regex=True)
    df = df_main.replace('\/','', regex=True)
    return df

def getJsonFromFrame(df):
    df=cleanData(df)
    result = df.to_json(orient="records")
    df_parsed = json.loads(result)
    return df_parsed

df_main=cleanData(df_main)

r=getJsonFromFrame(df_main)
r



[{'label': 'title',
  'text': '2015: Beyond Obama, new Congress, we need a revival of the American spirit.'},
 {'label': 'anecdote',
  'text': 'In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder'},
 {'label': 'anecdote', 'text': 'The year is 1967'},
 {'label': 'no-unit', 'text': 'and'},
 {'label': 'anecdote',
  'text': 'the country is in turmoil over Vietnam and civil rights'},
 {'label': 'anecdote',
  'text': 'While lying on her bed one night and watching TV, she sees a news report about a demonstration'},
 {'label': 'anecdote',
  'text': 'The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."'},
 {'label': 'no-unit', 'text': 'As 2014 ends,'},
 {'label': 'statistics', 'text': 'the stock market is at record highs'},
 {'label': 'no-unit', 'text': 'but'},
 {'label': 'assumption',
  'text'

In [9]:
def generate_prompt_and_completion(data,isTrain,prompt_text, prompt_end = "####", completion_end = "<|endoftext|>"):    
    ret_dict = {}
    ret_dict['prompt'] = data['text']  ## 
    #if isTrain:
    completion_list = ''
    completion_list=f'''{data['label']}'''  #{sentence['sentence']}
    #print(completion_list)
    # f'''<{sentence['sentiment']}> {sentence['sentence']}'''
    ret_dict['completion'] = completion_list

    return ret_dict

In [10]:
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
le = preprocessing.LabelEncoder()
#df_main["label"]=le.fit_transform(df_main["label"])
X=df_main["text"]
y=df_main["label"]
X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.20, random_state=42,stratify=y)
Train=pd.concat([X_train,y_train],axis=1)
Test=pd.concat([X_test,y_test],axis=1)
trn=getJsonFromFrame(Train)
tst=getJsonFromFrame(Test)
tst

[{'label': 'title',
  'text': '2015: Beyond Obama, new Congress, we need a revival of the American spirit.'},
 {'label': 'anecdote',
  'text': 'In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder'},
 {'label': 'anecdote', 'text': 'The year is 1967'},
 {'label': 'no-unit', 'text': 'and'},
 {'label': 'anecdote',
  'text': 'the country is in turmoil over Vietnam and civil rights'},
 {'label': 'anecdote',
  'text': 'While lying on her bed one night and watching TV, she sees a news report about a demonstration'},
 {'label': 'anecdote',
  'text': 'The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."'},
 {'label': 'no-unit', 'text': 'As 2014 ends,'},
 {'label': 'statistics', 'text': 'the stock market is at record highs'},
 {'label': 'no-unit', 'text': 'but'},
 {'label': 'assumption',
  'text'

In [11]:
training_list = []
to_predict_list = []
fixed_propmt=""#" This is "
trn=getJsonFromFrame(Train)
tst=getJsonFromFrame(Test)
for item in trn:
    training_list.append(generate_prompt_and_completion(item,True,fixed_propmt))
for item in tst:
    to_predict_list.append(generate_prompt_and_completion(item,False,fixed_propmt))

In [12]:
def parse_completion(completion):
    compensation_sentences = re.findall(r'<(.*?)> ', completion) ## Our regular expression here is simple because we designed it conveniently!

    completion_dict = {
        'sentences' : compensation_sentences,
        'num_title' : 0,
        'num_anecdote' : 0,
        'num_assumption' : 0,
        'num_nounit' : 0,
        'num_statistics' : 0,
        'num_testimony' : 0,
        'num_commonground' : 0,
        'num_other' : 0
        #'num_assumption' : len(compensation_sentences)
    }

    for sen in compensation_sentences:
        print(sen)
        if sen == 'title':
            completion_dict['num_title'] += 1
        elif sen == 'anecdote':
            completion_dict['num_anecdote'] += 1
        elif sen == 'assumption':
            completion_dict['num_assumption']  += 1
        elif sen == 'no-unit':
            completion_dict['num_nounit']  += 1
        elif sen == 'statistics':
            completion_dict['num_statistics']  += 1 
        elif sen == 'testimony':
            completion_dict['num_testimony']  += 1 
        elif sen == 'common-ground':
            completion_dict['num_commonground']  += 1 
        elif sen == 'other':
            completion_dict['num_other']  += 1                    

    return completion_dict

In [13]:
trnList=[]
for training_item in training_list[:100]:
    trnList.append((training_item['prompt'],training_item['completion']))
#print(few_shot_prompt)
trnList

[('2015: Beyond Obama, new Congress, we need a revival of the American spirit.',
  'title'),
 ('In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder',
  'anecdote'),
 ('The year is 1967', 'anecdote'),
 ('and', 'no-unit'),
 ('the country is in turmoil over Vietnam and civil rights', 'anecdote'),
 ('While lying on her bed one night and watching TV, she sees a news report about a demonstration',
  'anecdote'),
 ('The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."',
  'anecdote'),
 ('As 2014 ends,', 'no-unit'),
 ('the stock market is at record highs', 'statistics'),
 ('but', 'no-unit'),
 ('our traditional institutions and self-confidence are in decline',
  'assumption'),
 ('A Pew Research Center study confirms one trend that has been obvious over several years',
  'testimony'),
 ('The "ty

In [23]:
to_predict_list

[{'completion': 'title',
  'prompt': '2015: Beyond Obama, new Congress, we need a revival of the American spirit.'},
 {'completion': 'anecdote',
  'prompt': 'In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder'},
 {'completion': 'anecdote', 'prompt': 'The year is 1967'},
 {'completion': 'no-unit', 'prompt': 'and'},
 {'completion': 'anecdote',
  'prompt': 'the country is in turmoil over Vietnam and civil rights'},
 {'completion': 'anecdote',
  'prompt': 'While lying on her bed one night and watching TV, she sees a news report about a demonstration'},
 {'completion': 'anecdote',
  'prompt': 'The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."'},
 {'completion': 'no-unit', 'prompt': 'As 2014 ends,'},
 {'completion': 'statistics', 'prompt': 'the stock market is at record highs'},
 {'compl

In [14]:
for input,output in trnList:
  print(output)

title
anecdote
anecdote
no-unit
anecdote
anecdote
anecdote
no-unit
statistics
no-unit
assumption
testimony
assumption
statistics
statistics
testimony
assumption
statistics
assumption
assumption
assumption
assumption
assumption
assumption
common-ground
anecdote
testimony
testimony
assumption
testimony
assumption
testimony
testimony
assumption
testimony
no-unit
testimony
assumption
assumption
assumption
no-unit
assumption
title
anecdote
testimony
testimony
testimony
testimony
assumption
common-ground
common-ground
no-unit
assumption
assumption
assumption
assumption
assumption
assumption
no-unit
assumption
common-ground
no-unit
common-ground
no-unit
common-ground
no-unit
common-ground
common-ground
common-ground
common-ground
common-ground
assumption
assumption
no-unit
other
assumption
assumption
no-unit
assumption
assumption
common-ground
other
assumption
assumption
assumption
assumption
other
assumption
assumption
assumption
assumption
assumption
assumption
assumption
title
testimony
ot

In [25]:
tstList=[]
for item in to_predict_list[:100]:
    tstList.append((item['prompt'],item['completion']))
#print(few_shot_prompt)
tstList

[('2015: Beyond Obama, new Congress, we need a revival of the American spirit.',
  'title'),
 ('In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder',
  'anecdote'),
 ('The year is 1967', 'anecdote'),
 ('and', 'no-unit'),
 ('the country is in turmoil over Vietnam and civil rights', 'anecdote'),
 ('While lying on her bed one night and watching TV, she sees a news report about a demonstration',
  'anecdote'),
 ('The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."',
  'anecdote'),
 ('As 2014 ends,', 'no-unit'),
 ('the stock market is at record highs', 'statistics'),
 ('but', 'no-unit'),
 ('our traditional institutions and self-confidence are in decline',
  'assumption'),
 ('A Pew Research Center study confirms one trend that has been obvious over several years',
  'testimony'),
 ('The "ty

## 3. Train Loop

In [15]:
t5_model.train()

epochs = 2

for epoch in range(epochs):
  print ("epoch ",epoch)
  for input,output in trnList:
    input_sent = input+ " </s>"
    ouput_sent = output+" </s>"
    print(input_sent)
    tokenized_inp = tokenizer.encode_plus(input_sent,  max_length=96, pad_to_max_length=True,return_tensors="pt")
    tokenized_output = tokenizer.encode_plus(ouput_sent, max_length=96, pad_to_max_length=True,return_tensors="pt")


    input_ids  = tokenized_inp["input_ids"]
    attention_mask = tokenized_inp["attention_mask"]

    lm_labels= tokenized_output["input_ids"]
    decoder_attention_mask=  tokenized_output["attention_mask"]


    # the forward function automatically creates the correct decoder_input_ids
    output = t5_model(input_ids=input_ids, lm_labels=lm_labels,decoder_attention_mask=decoder_attention_mask,attention_mask=attention_mask)
    loss = output[0]

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()




epoch  0
2015: Beyond Obama, new Congress, we need a revival of the American spirit. </s>


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)


In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder </s>
The year is 1967 </s>
and </s>
the country is in turmoil over Vietnam and civil rights </s>
While lying on her bed one night and watching TV, she sees a news report about a demonstration </s>
The narrator says something that might apply to today's turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable." </s>
As 2014 ends, </s>
the stock market is at record highs </s>
but </s>
our traditional institutions and self-confidence are in decline </s>
A Pew Research Center study confirms one trend that has been obvious over several years </s>
The "typical" American family is no longer typical </s>
Just 46 percent of American children now live in homes with their married, heterosexual parents </s>
Five percent have no parents at home </s>
They most likely are living with grandparents, says the stud

In [27]:
tstt=tstList[:25]
tstt

[('2015: Beyond Obama, new Congress, we need a revival of the American spirit.',
  'title'),
 ('In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder',
  'anecdote'),
 ('The year is 1967', 'anecdote'),
 ('and', 'no-unit'),
 ('the country is in turmoil over Vietnam and civil rights', 'anecdote'),
 ('While lying on her bed one night and watching TV, she sees a news report about a demonstration',
  'anecdote'),
 ('The narrator says something that might apply to today\'s turmoil:"We live in a time of doubt. The institutions we once trusted no longer seem reliable."',
  'anecdote'),
 ('As 2014 ends,', 'no-unit'),
 ('the stock market is at record highs', 'statistics'),
 ('but', 'no-unit'),
 ('our traditional institutions and self-confidence are in decline',
  'assumption'),
 ('A Pew Research Center study confirms one trend that has been obvious over several years',
  'testimony'),
 ('The "ty

In [32]:
actual=[completion for prompt,completion in tstList[:25]]
actual

['title',
 'anecdote',
 'anecdote',
 'no-unit',
 'anecdote',
 'anecdote',
 'anecdote',
 'no-unit',
 'statistics',
 'no-unit',
 'assumption',
 'testimony',
 'assumption',
 'statistics',
 'statistics',
 'testimony',
 'assumption',
 'statistics',
 'assumption',
 'assumption',
 'assumption',
 'assumption',
 'assumption',
 'assumption',
 'common-ground']

## 4. Test model

In [33]:
pred=[]
actual=[completion for prompt,completion in tstList[:25]]
for prompt,completion in tstList[:25]:
  test_sent = prompt+' </s>'
  test_tokenized = tokenizer.encode_plus(test_sent, return_tensors="pt")

  test_input_ids  = test_tokenized["input_ids"]
  test_attention_mask = test_tokenized["attention_mask"]

  t5_model.eval()
  beam_outputs = t5_model.generate(
      input_ids=test_input_ids,attention_mask=test_attention_mask,
      max_length=64,
      early_stopping=True,
      num_beams=10,
      num_return_sequences=1,
      no_repeat_ngram_size=2
  )

  for beam_output in beam_outputs:
      sent = tokenizer.decode(beam_output, skip_special_tokens=True,clean_up_tokenization_spaces=True)
      print (sent)
      pred.append(sent)    
   

  beam_id = beam_token_id // vocab_size


we need a revival of the American spirit
<extra_id_0> In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder.
anecdote
l'équilibre et à votre ordinateur.
the country is in turmoil over Vietnam and civil rights are at stake.
anecdote
the narrator: "we live in a time of doubt.
anecdote
anecdote
<extra_id_0>
our traditional institutions and self-confidence
assumption
assumption
just 46 percent of American children now live in homes with their married, heterosexual parents.
five percent have no parents at home.
the study
assumption
73 percent of American children lived in traditional families.
anecdote
it was that generation that promoted cohabitation, no-fault divorce, hatred of the police(they called them "pigs") and disdain for the military and America, spawned not just by the Vietnam War but a life of relative ease unknown to their parents.
anecdote
the two-plus generations born since t

In [1]:
nactual=[]
for item in actual:
  item.str.strip()
  item.str.replace('<','')
  item.replace(">","")
  nactual.append(item)

nactual  


NameError: ignored

In [20]:
pred

['we need a revival of the American spirit',
 '<extra_id_0> In the film, "Girl Interrupted," Winona Ryder plays an 18-year-old who enters a mental institution for what is diagnosed as borderline personality disorder.',
 'anecdote',
 "l'équilibre et à votre ordinateur.",
 'the country is in turmoil over Vietnam and civil rights are at stake.',
 'anecdote',
 'the narrator: "we live in a time of doubt.',
 'anecdote',
 'anecdote',
 '<extra_id_0>',
 'our traditional institutions and self-confidence',
 'assumption',
 'assumption',
 'just 46 percent of American children now live in homes with their married, heterosexual parents.',
 'five percent have no parents at home.',
 'the study',
 'assumption',
 '73 percent of American children lived in traditional families.',
 'anecdote',
 'it was that generation that promoted cohabitation, no-fault divorce, hatred of the police(they called them "pigs") and disdain for the military and America, spawned not just by the Vietnam War but a life of relative

In [43]:
from sklearn.metrics import classification_report
print(classification_report(actual, pred))

                  precision    recall  f1-score   support

      <anecdote>       0.00      0.00      0.00       5.0
    <assumption>       0.00      0.00      0.00       9.0
 <common-ground>       0.00      0.00      0.00       1.0
       <no-unit>       0.00      0.00      0.00       3.0
    <statistics>       0.00      0.00      0.00       4.0
     <testimony>       0.00      0.00      0.00       2.0
         <title>       0.00      0.00      0.00       1.0
       anecdote>       0.00      0.00      0.00       0.0
     assumption>       0.00      0.00      0.00       0.0
  common-ground>       0.00      0.00      0.00       0.0
        no-unit>       0.00      0.00      0.00       0.0
     statistics>       0.00      0.00      0.00       0.0
      testimony>       0.00      0.00      0.00       0.0
          title>       0.00      0.00      0.00       0.0

        accuracy                           0.00      25.0
       macro avg       0.00      0.00      0.00      25.0
    weighted

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
