In [1]:
import pandas as pd
from tqdm import tqdm
import nltk
import numpy as np

In [2]:
# Input locations (relative to the notebook's working directory)
haiku_path = '../input/haikus.csv'
seed_word_path = './seed_words.txt'

# Tunable settings, grouped by pipeline stage.
# TODO: load these from a config / hyperparameter JSON file instead.
hyperparams = dict(
    preprocessing=dict(
        max_line_len_quantile=0.9,  # quantile of line length used as the length cutoff
        min_line_syllables=3,       # minimum syllables required on every line
    ),
    model=dict(
        latent_dim=2048,
        epochs=20,
        learning_rate=0.01,
    ),
)


In [3]:
# Load the haiku dataset and discard its `source` column.
# Remaining cols: [0, 1, 2, 0_syllables, 1_syllables, 2_syllables]
#
# NOTE: filtering out the twaiku / gutenberg rows is currently DISABLED, so
# every source is kept. The earlier draft of that filter,
#   df = df[df.source == 'twaiku' or df.source == 'gutenberg']
# would not run as written (`or` between two Series is ambiguous) and would
# keep, rather than remove, those sources. If the filter is wanted, apply
#   df = df[~df.source.isin(['twaiku', 'gutenberg'])]
# before the column is dropped.
df = pd.read_csv(haiku_path).drop(columns='source')


In [4]:
# A syllable count such as "5,6" means the line has several possible
# pronunciations. Emit one copy of the haiku per alternative count so that
# every *_syllables column ends up holding a single integer.
# Note: the exploded copies keep their original index label, so the index
# is no longer unique after this cell.
for line_no in range(3):
    syllable_col = '{}_syllables'.format(line_no)
    alternatives = df[syllable_col].str.split(',')
    df = df.assign(**{syllable_col: alternatives}).explode(syllable_col)
    df[syllable_col] = df[syllable_col].astype(int)

In [5]:
# Keep only haikus whose three lines are all strictly shorter (in characters)
# than `max_len`. `max_len` is the largest of the three per-line length
# quantiles (0.9 by default, i.e. the 90th percentile), truncated to an int.
# Try other values of `max_line_len_quantile` to see how the dataset changes.
quantile = hyperparams['preprocessing']['max_line_len_quantile']
line_lengths = [df[str(line_no)].str.len() for line_no in range(3)]
max_len = int(max(lengths.quantile(quantile) for lengths in line_lengths))

short_enough = line_lengths[0] < max_len
for lengths in line_lengths[1:]:
    short_enough = short_enough & (lengths < max_len)
df = df[short_enough]

In [6]:
# Drop haikus in which any line has fewer than `min_line_syllables` syllables.
# may want to modify this value and see how it changes
min_syll = hyperparams['preprocessing']['min_line_syllables']
syllable_cols = ['%s_syllables' % line_no for line_no in range(3)]
has_enough_syllables = (df[syllable_cols] >= min_syll).all(axis=1)
df = df[has_enough_syllables]

In [7]:
# Seed vocabulary: one entry per line of the seed-word file, with
# surrounding whitespace stripped (a blank line contributes '').
seed_words = set()
with open(seed_word_path, 'r') as seed_file:
    seed_words.update(raw_line.strip() for raw_line in seed_file)


In [8]:
# One entry per haiku row: the comma-joined seed words found in that haiku
labels = list()

In [9]:
# Label every haiku with the seed words it contains.
# A token counts only if nltk tags it exactly 'NN' and its lower-cased form
# is in `seed_words`. Matches from all three lines are joined with commas,
# in order of appearance; a haiku with no match gets the empty string ''.

# Start from an empty list on every run: re-running this cell must not append
# to labels left over from a previous run, otherwise `labels` grows longer
# than `df` and can no longer be assigned as a column.
labels = []

# Pull the three text columns out once instead of doing a positional
# `.iloc` lookup for every (row, column) pair inside the loop.
haiku_lines = [[str(text) for text in df[str(j)].tolist()] for j in range(3)]

for lines in tqdm(zip(*haiku_lines), total=len(df)):
    seeds = []
    for line in lines:
        tokens = nltk.word_tokenize(line)
        for tok, pos in nltk.pos_tag(tokens):
            # only tokens tagged exactly 'NN' are considered
            if pos != 'NN':
                continue
            tok = tok.lower()
            if tok in seed_words:
                seeds.append(tok)
    labels.append(','.join(seeds))

100%|██████████| 143715/143715 [20:05<00:00, 119.21it/s]


In [10]:
# Attach the per-row labels (comma-joined seed words, '' when nothing matched),
# remove the unlabeled rows, then expand multi-label rows so that every
# output row carries exactly one label.
#
# Written as one chain of non-inplace operations. Calling
# `df['label'].replace(..., inplace=True)` mutates a temporary column object:
# pandas warns about that pattern, and with Copy-on-Write enabled it no
# longer updates `df`, so the empty labels would survive the dropna.
df = (
    df.assign(label=labels)                                     # add label column from labels list
    .loc[lambda frame: frame['label'] != '']                    # remove unlabeled rows
    .assign(label=lambda frame: frame['label'].str.split(','))  # 'a,b' -> ['a', 'b']
    .explode('label')                                           # one row per label
)
print(df)


                          0                                1  \
1             spring rain -             as the doctor speaks   
1             spring rain -             as the doctor speaks   
3           sunny afternoon               an old man lingers   
5             quitting time                the smell of rain   
5             quitting time                the smell of rain   
...                     ...                              ...   
143120     Someone is upset     their team isn't winning the   
143122  The write in Mickey   Mouse probably could have been   
143122  The write in Mickey   Mouse probably could have been   
143124  People really don't       be having shit else better   
143127   Today is the first       day of the rest of my LIFE   

                            2  0_syllables  1_syllables  2_syllables  \
1           i think of lilacs            3            5            5   
1           i think of lilacs            3            5            5   
3            ne

In [11]:
# Persist the labeled dataset. The index is written as the first CSV column;
# its values are not unique, because exploded rows share an index label.
df.to_csv(path_or_buf='../input/labeled_haikus.csv', index=True)