In [1]:
import numpy as np
from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = "reddit_tifu"  # Hugging Face Hub dataset id
seed_num = 1  # random seed for the whole notebook
model_name = "google-t5/t5-small"  # checkpoint whose tokenizer is loaded below

# Seed NumPy's global RNG so anything drawing from it (e.g. pandas `.sample()`
# without `random_state`) is reproducible under Restart & Run All.
np.random.seed(seed_num)

In [8]:
# Load the "long" configuration of the dataset; its features include the
# source post ("documents") and its summary ("tldr").
loaded_dataset = load_dataset(path=dataset, name="long")

Generating train split: 100%|██████████| 42139/42139 [00:06<00:00, 6924.14 examples/s]


In [9]:
# Inspect the available splits. Only a "train" split is present (see output),
# so a held-out evaluation set has to be carved out of it, e.g. with
# `loaded_dataset["train"].train_test_split(test_size=..., seed=seed_num)`.
loaded_dataset

DatasetDict({
    train: Dataset({
        features: ['ups', 'num_comments', 'upvote_ratio', 'score', 'documents', 'tldr', 'title'],
        num_rows: 42139
    })
})

In [6]:
# Load the tokenizer that belongs to the `model_name` checkpoint
# (this cell only loads it; tokenization happens in `preprocess_function`).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [10]:
prefix = "summarize: "  # Required so the T5 model knows that we are going to summarize


def preprocess_function(examples, max_input_length=None, max_target_length=None, tok=None):
    """Tokenize a batch of examples for T5 summarization.

    Written for ``Dataset.map(..., batched=True)``, where ``examples`` maps each
    column name to a list of values.

    Args:
        examples: batch containing "documents" (source texts) and "tldr"
            (target summaries), each a list of strings.
        max_input_length: if given, truncate the encoded inputs to this many
            tokens. ``None`` (default) keeps full length, which the length
            analysis later in the notebook relies on.
        max_target_length: if given, truncate the encoded labels likewise.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's encoding of the prefixed documents, with an added
        "labels" entry holding the token ids of the summaries.
    """
    tok = tokenizer if tok is None else tok

    # Only pass truncation options when a limit is requested, so the default
    # call is exactly `tok(inputs)` / `tok(text_target=...)`.
    input_kwargs = {} if max_input_length is None else {"max_length": max_input_length, "truncation": True}
    target_kwargs = {} if max_target_length is None else {"max_length": max_target_length, "truncation": True}

    inputs = [prefix + doc for doc in examples["documents"]]
    model_inputs = tok(inputs, **input_kwargs)
    labels = tok(text_target=examples["tldr"], **target_kwargs)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# `model` must be a loaded model object: the collator only uses it when it has
# `prepare_decoder_input_ids_from_labels`, so a checkpoint-name string would be
# silently ignored. Pass `model=<loaded model>` here once the model is loaded.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# No truncation limits are passed, so the length statistics computed below
# reflect the full, untruncated token counts.
tokenized_dataset = loaded_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/42139 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (892 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 42139/42139 [00:16<00:00, 2525.35 examples/s]


In [11]:
# Materialise the tokenized train split as a pandas DataFrame for inspection
df = pd.DataFrame(data=tokenized_dataset["train"])
df.tail(n=5)

Unnamed: 0,ups,num_comments,upvote_ratio,score,documents,tldr,title,input_ids,attention_mask,labels
42134,105.0,18.0,0.94,105.0,this happened back in middle school.\n\nmy fam...,forgot my quarter for lunch at school for a we...,forgetting my quarter for lunch,"[21603, 10, 48, 2817, 223, 16, 2214, 496, 5, 8...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[15687, 82, 2893, 21, 3074, 44, 496, 21, 3, 9,..."
42135,96.0,64.0,0.92,96.0,my girlfriend told me she has no hair beneath ...,girlfriend prefers clean shaven groin. i try t...,trying to shave my pubes for the first time,"[21603, 10, 82, 17442, 1219, 140, 255, 65, 150...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[17442, 2396, 7, 1349, 3, 7, 7965, 29, 3, 3844..."
42136,726.0,110.0,0.91,726.0,today at work i accidentally crashed a row of ...,today i broke a window that costs more then i ...,breaking a $900 window with a shopping cart.,"[21603, 10, 469, 44, 161, 3, 23, 21169, 24679,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[469, 3, 23, 8238, 3, 9, 2034, 24, 1358, 72, 2..."
42137,26.0,5.0,0.77,26.0,so as u can tell from the title it didn't happ...,i invited over new girlfriend for dinner to sp...,slicing open my finger on first valentines wit...,"[21603, 10, 78, 38, 3, 76, 54, 817, 45, 8, 223...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 23, 5374, 147, 126, 17442, 21, 2634, 12, 1..."
42138,15.0,11.0,0.81,15.0,this did actually happen today. it started aft...,in a rush i mixed my colours in the wash and e...,not listening to my mother,"[21603, 10, 48, 410, 700, 1837, 469, 5, 34, 70...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[16, 3, 9, 10505, 3, 23, 4838, 82, 6548, 16, 8..."


In [12]:
from pprint import pprint

# Show the full text of the first post, line-wrapped for readability
pprint(df.loc[0, "documents"])

('this actually happened a couple of years ago. i grew up in germany where i '
 'went to a german secondary school that went from 5th to 13th grade (we still '
 'had 13 grades then, they have since changed that). my school was named after '
 'anne frank and we had a club that i was very active in from 9th grade on, '
 'which was dedicated to teaching incoming 5th graders about anne franks life, '
 'discrimination, anti-semitism, hitler, the third reich and that whole spiel. '
 "basically a day where the students' classes are cancelled and instead we "
 'give them an interactive history and social studies class with lots of '
 'activities and games. \n'
 '\n'
 'this was my last year at school and i already had a lot of experience doing '
 'these project days with the kids. i was running the thing with a friend, so '
 'it was just the two of us and 30-something 5th graders. we start off with a '
 'brief introduction and brainstorming: what do they know about anne frank and '
 "the third 

In [13]:
# Distribution of tokenized input lengths (prefix included): summary
# percentiles plus a histogram, with the tokenizer's maximum sequence length
# marked so the share of inputs that would need truncation is visible.
LENGTH_PERCENTILES = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
input_lengths = df["input_ids"].apply(len)

fig, ax = plt.subplots()
input_lengths.plot.hist(bins=100, ax=ax)
ax.axvline(
    tokenizer.model_max_length,
    color="red",
    linestyle="--",
    label=f"model_max_length = {tokenizer.model_max_length}",
)
ax.set(title="Tokenized input length", xlabel="tokens per input", ylabel="number of examples")
ax.legend()
plt.show()

input_lengths.describe(percentiles=LENGTH_PERCENTILES)

count    42139.000000
mean       541.585443
std        398.646144
min          5.000000
25%        291.000000
50%        439.000000
75%        670.000000
90%        983.000000
95%       1249.000000
99%       2028.000000
max       8989.000000
Name: input_ids, dtype: float64

In [14]:
# Percentile summary of the target (tldr) token-id lengths
(
    df["labels"]
    .map(len)
    .describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
)

count    42139.000000
mean        30.848098
std         18.103418
min          2.000000
25%         19.000000
50%         27.000000
75%         38.000000
90%         52.000000
95%         62.000000
99%         90.000000
max        621.000000
Name: labels, dtype: float64