In [1]:
import nltk
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk import ngrams

import pandas as pd
import numpy as np
from joblib import Parallel, delayed
import lightgbm as lgb

from pathlib import Path
from pprint import pprint
from typing import Optional, Union, List, Tuple, Dict, Set

[nltk_data] Downloading package stopwords to /home/chris/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/chris/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /home/chris/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
# Directory that every other file name in this cell is resolved against.
data_directory = '../data'
# File names of the prompt and summary inputs, and of the CSV written at the end of the notebook.
prompt_inputs = 'prompts_test.csv'
summaries_input = 'summaries_test.csv'
submission_csv = 'submission.csv'

# Saved LightGBM model files, one per predicted score (relative to data_directory).
content_model_file = 'models/content_lgb_model.txt'
wording_model_file = 'models/wording_lgb_model.txt'

In [3]:
# Anchor every configured file name to the absolute data directory.
data_path = Path(data_directory).absolute()

# Inputs and output
prompt_input_path = data_path.joinpath(prompt_inputs)
summaries_input_path = data_path.joinpath(summaries_input)
submission_csv_path = data_path.joinpath(submission_csv)
# Saved models
content_model_path = data_path.joinpath(content_model_file)
wording_model_path = data_path.joinpath(wording_model_file)

In [4]:
def make_split(summaries_path: Path, prompts_path: Path, dtype_backend: Optional[str] = 'pyarrow') -> pd.DataFrame:
    """Load the summaries and prompts CSVs and attach each summary's prompt.

    Args:
        summaries_path: CSV of summaries; must contain a ``prompt_id`` column.
        prompts_path: CSV of prompts; must contain a ``prompt_id`` column.
        dtype_backend: Forwarded to ``pd.read_csv``. ``None`` means "use
            pandas' default backend".

    Returns:
        One row per summary, with the prompt columns joined on ``prompt_id``.

    Raises:
        AssertionError: If a prompt id occurs more than once in the prompts
            file, or if some summary's prompt id has no matching prompt.
    """
    # pd.read_csv validates dtype_backend and rejects None, so leave the
    # argument out entirely when the caller asks for the default backend.
    read_kwargs = {} if dtype_backend is None else {'dtype_backend': dtype_backend}
    summaries_df = pd.read_csv(summaries_path, **read_kwargs)
    prompts_df = pd.read_csv(prompts_path, **read_kwargs)

    # A repeated prompt id would duplicate summary rows in the merge and could
    # mask unmatched summaries in the row-count check below.
    if prompts_df['prompt_id'].duplicated().any():
        raise AssertionError('Duplicate prompt ids found in the prompts file')

    df = pd.merge(summaries_df, prompts_df, how='inner', on='prompt_id')

    # The inner join drops summaries whose prompt is missing; surface that.
    if len(df) != len(summaries_df):
        raise AssertionError('Could not match all prompt ids to a prompt')

    return df

# Module-level NLP resources, built once and read by the tokenising helpers:
# the NLTK English stop-word list (as a set) and a WordNet lemmatiser.
stop_words = set(stopwords.words('english'))
lemmatiser = WordNetLemmatizer()

def clear_stopwords(column: pd.Series, idx: int) -> Tuple[List[str], List[str], List[str], Set[Tuple[str, str]]]:
    """Tokenise one entry of ``column`` and derive its cleaned variants.

    Args:
        column: Series of texts.
        idx: Position (not label) of the entry to process; read with ``iloc``.

    Returns:
        A 4-tuple ``(tokens, cleared_stopwords, lemmas, bigram)``:
        lower-cased alphanumeric tokens, those tokens without English stop
        words, the lemmatised remainder, and the set of distinct consecutive
        lemma pairs.
    """
    text = column.iloc[idx]
    # Punctuation-only tokens are dropped before lower-casing.
    tokens = [tok.lower() for tok in word_tokenize(text) if tok.isalnum()]
    cleared_stopwords = [tok for tok in tokens if tok not in stop_words]
    lemmas = [lemmatiser.lemmatize(tok) for tok in cleared_stopwords]
    # A set, so repeated bigrams are counted once.
    bigram = set(ngrams(lemmas, 2))

    return tokens, cleared_stopwords, lemmas, bigram

def nlp_splits(df: pd.DataFrame, column: str) -> None:
    """Add token, stop-word-free, lemma and bigram columns for ``column``.

    Runs ``clear_stopwords`` over every row position with four joblib workers
    and writes the four result lists back onto ``df`` in place as
    ``<column>_tokens``, ``<column>_no_stopwords``, ``<column>_lemmas`` and
    ``<column>_bigram``.
    """
    texts = df[column]
    jobs = (delayed(clear_stopwords)(texts, position) for position in range(len(df)))
    results = Parallel(n_jobs=4)(jobs)

    # clear_stopwords returns its parts in exactly this order.
    for slot, suffix in enumerate(('tokens', 'no_stopwords', 'lemmas', 'bigram')):
        df[f'{column}_{suffix}'] = [result[slot] for result in results]

In [5]:
def get_bigrams(text: str) -> Set[Tuple[str, str]]:
    """Return the distinct lemma bigrams of ``text``.

    The text is tokenised, punctuation-only tokens are dropped, the rest is
    lower-cased, English stop words are removed and the survivors are
    lemmatised. Each element of the result is a pair of consecutive lemmas.
    """
    tokens = [tok.lower() for tok in word_tokenize(text) if tok.isalnum()]
    content_tokens = [tok for tok in tokens if tok not in stop_words]
    lemmas = [lemmatiser.lemmatize(tok) for tok in content_tokens]
    # A set, so a bigram that occurs several times is kept once.
    return set(ngrams(lemmas, 2))

In [6]:
df = make_split(summaries_input_path, prompt_input_path)

# Make n-grams for all text columns
text_columns = ['prompt_title', 'prompt_question', 'prompt_text', 'text']
for column in text_columns:
    # The prompt columns carry the same text on every summary row of a prompt,
    # so tokenise each distinct text once and reuse the resulting set.
    # TODO: parallelise
    bigrams_by_text = {text: get_bigrams(text) for text in df[column].unique()}
    df[f'{column}_bigram'] = [bigrams_by_text[text] for text in df[column]]
    df[f'{column}_unique_bigrams'] = df[f'{column}_bigram'].map(len)

# Number of bigrams each summary shares with its prompt text. The two columns
# are paired directly rather than through row[0] / row[1], because positional
# integer indexing on a label-indexed Series is deprecated in pandas.
shared_bigram_counts = pd.Series(
    [
        len(prompt_bigrams & summary_bigrams)
        for prompt_bigrams, summary_bigrams in zip(df['prompt_text_bigram'], df['text_bigram'])
    ],
    index=df.index,
    dtype='int64',
)
# Fraction of the summary's distinct bigrams that also occur in the prompt.
# A summary with no bigrams gives 0 / 0 here, i.e. NaN.
df['text_bigram_overlap'] = shared_bigram_counts / df['text_unique_bigrams']

In [7]:
# Load one LightGBM booster per predicted score from its saved model file.
content_lgb_model, wording_lgb_model = (
    lgb.Booster(model_file=model_path)
    for model_path in (content_model_path, wording_model_path)
)

In [8]:
def predict(model: lgb.Booster, df: pd.DataFrame, features: List[str]) -> np.ndarray:
    """Run ``model`` on the ``features`` columns of ``df``.

    Args:
        model: Booster used for prediction.
        df: Frame containing at least the columns named in ``features``.
        features: Column names passed to the model, in this order.

    Returns:
        Whatever ``model.predict`` returns for the selected columns
        (presumably one prediction per row as a NumPy array -- confirm
        against the LightGBM documentation).
    """
    return model.predict(df[features])

In [9]:
# Both models are scored on the same two bigram features.
features = ['text_bigram_overlap', 'text_unique_bigrams']
df['content'], df['wording'] = (
    predict(booster, df, features)
    for booster in (content_lgb_model, wording_lgb_model)
)

# Keep only the id and the two predicted scores, then write the submission file.
submission_df = df.loc[:, ['student_id', 'content', 'wording']]
submission_df.to_csv(submission_csv_path, index=False)