In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from nltk.probability import FreqDist

def get_word_dist(input_frame, trans_col):
    """
    get the log word frequency distrubtion on the input frame, 
    :param pandas.DataFrame input_frame: the input frame with transcripts to evalute
    :param str trans_col: the column name represents transcript text
    :return: the log frequency for the transcript tokens
    :rtype: dict
    """
    word_dist = load_word_dist()
    log_lf = {}
    trans = input_frame[trans_col].values.tolist()
    for tran in trans:
        try:
            words = tran.split(" ")
            for word in set(words):
                log_lf[word.lower()] = get_word_lf(word, word_dist)
        except AttributeError:
            continue
    log_lf = sorted(log_lf.items(), key=lambda item: item[1])
    return log_lf


def get_word_lf(token, word_dist):
    """
    get the log lexical frequency for a specifc word
    :param str token: the word for calcualtion
    :param dict word_dict: the dictionary for the word raw frequency
    :return: the log lexical frequency for the word
    :rtype: float
    """
    return np.log(word_dist[token.lower()])


def load_word_dist():
    """
    load Subtlex.US.text file for estimating word frequency distribution
    save it to local file
    return the word frequency distribution
    :rtype: nltk.probability.FreqDist
    """
    if os.path.exists("../../scripts/word_dist.pkl"):
        with open("../../scripts/word_dist.pkl", "rb") as f:
            word_dist = pickle.load(f)
        return word_dist
    else:
        word_dist = FreqDist()
        sys.stdout.write("estimating frequency distribution...\n")
        with open("../../scripts/Subtlex.US.text", "r") as c:
            lines = c.readlines()
            for ln in lines:
                words = ln.split()
                for word in words:
                    word_dist[word.lower()] += 1
        sys.stdout.write("done\n")
        with open("word_dist.pkl", "wb") as f:
            pickle.dump(word_dist, f)
        return word_dist

In [2]:
word_dist = load_word_dist()
share = [25, 50, 100]
train_frame = pd.read_csv("../../scripts/address_train.csv")
test_frame = pd.read_csv("../../scripts/address_test.csv")
train_lf = get_word_dist(train_frame, "text")
test_lf = get_word_dist(test_frame, "text")



In [8]:
# remove -inf log frequency words
train_lf = [item for item in train_lf if item[1] > 0]
test_lf = [item for item in test_lf if item[1] > 0]
mean_train = sum([item[1] for item in train_lf])/len(train_lf)
mean_test = sum([item[1] for item in test_lf])/len(test_lf)
#read generated text file
# onetime
layers = list(range(0, 12))
for s in share:
    print("share {}".format(s))
    print("mean log lexical frequency")
    print("="*10)
    print('{:7s} {:7s}  {:7s}'.format('train', "test", "impaired " + str(s)))
    for layer in layers:
        log_file = "../../scripts/logs/style_onetime_share_{}_layer_{}.log".format(s, layer)
        with open(log_file, "r") as f:
            text_lf = []
            text = f.read()
            tokens = text.split(" ")
            for token in tokens:
                text_lf.append(get_word_lf(token, word_dist))
            mean_text = sum(text_lf)/len(text_lf)
            print('{:.2f} {:7.2f}  {:7.2f}'.format(mean_train, mean_test, mean_text))
    print("-"*10)
    print("="*10)

share 25
mean log lexical frequency
train   test     impaired 25
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
----------
share 50
mean log lexical frequency
train   test     impaired 50
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
----------
share 100
mean log lexical frequency
train   test     impaired 100
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.6



In [9]:
layers = list(range(1, 13))
for s in share:
    print("share {}".format(s))
    print("mean log lexical frequency")
    print("="*10)
    print('{:7s} {:7s}  {:7s}'.format('train', "test", "impaired " + str(s)))
    for layer in layers:
        log_file = "../../scripts/logs/style_accumu_share_{}_layer_{}.log".format(s, layer)
        with open(log_file, "r") as f:
            text_lf = []
            text = f.read()
            tokens = text.split(" ")
            for token in tokens:
                text_lf.append(get_word_lf(token, word_dist))
            mean_text = sum(text_lf)/len(text_lf)
            print('{:.2f} {:7.2f}  {:7.2f}'.format(mean_train, mean_test, mean_text))
    print("-"*10)
    print("="*10)

share 25
mean log lexical frequency
train   test     impaired 25
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
----------
share 50
mean log lexical frequency
train   test     impaired 50
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
----------
share 100
mean log lexical frequency
train   test     impaired 100
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.62     -inf
8.36    8.6

