**Company** : <br>
Design Firm

**Notebook Function** : <br>
    This notebook walks through the steps to fine-tune Mittens embeddings for each employee, starting from the initial company embeddings that were trained on the design firm corpus

**Output File(s)** : <br>
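    hash_embeddings_50_all_internal - A folder containing one embedding file per user, plus one file per user and time period (year, half-year and quarter). The `_internal` suffix is appended because `internal_only` is set to True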

**Author(s)** : <br>
Lara Yang, Sarayu Anshuman

Install packages and import libraries

In [None]:
pip install mittens

In [None]:
pip install gensim

In [None]:
import os
import sys
from collections import defaultdict
from datetime import datetime
import pandas as pd
import numpy as np
from mittens import Mittens
import csv
from operator import itemgetter
import json
import ujson
import re
from gensim.matutils import cossim
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec
from statistics import mean 
import multiprocessing
# Column names used for the time-key columns of output data frames;
# extract_hr_df compares its `time_key` argument against quarter_colname and year_colname.
year_colname, quarter_colname, yearmonth_colname = 'year', 'quarter', 'yearmonth'
# Number of worker processes handed to multiprocessing.Pool.
# NOTE: this value is reassigned in the hyperparameter cell further down the notebook.
num_cores = 12
import re
import os
import json
import sys
import multiprocessing
import shutil
from collections import defaultdict
from datetime import datetime
import pandas as pd
import numpy as np
from mittens import Mittens
import csv
#from mittens_utils import *
import random
#from hash_count import cheap_hash

Run the following helper functions 

In [None]:
#########################################################################
########### Helper Functions for Generating Mittens Embeddings ##########
#########################################################################

def _window_based_iterator(toks, window_size, weighting_function):
    for i, w in enumerate(toks):
        yield w, w, 1
        left = max([0, i-window_size])
        for x in range(left, i):
            yield w, toks[x],weighting_function(x)
        right = min([i+1+window_size, len(toks)])
        for x in range(i+1, right):
            yield w, toks[x], weighting_function(x)
    return

def glove2dict(glove_filename):
    """
    Load a GloVe-format text file into a dictionary.
    Parameters
    ----------
    glove_filename : str
        Path to a space-delimited file whose lines hold a word followed by
        the components of its vector
    Returns
    -------
    dict
        Maps each word to a numpy array of floats
    """
    word2vec = {}
    with open(glove_filename) as handle:
        # QUOTE_NONE keeps quote characters as literal parts of the tokens.
        for row in csv.reader(handle, delimiter=' ', quoting=csv.QUOTE_NONE):
            word2vec[row[0]] = np.array([float(component) for component in row[1:]])
    return word2vec

# Inspired by original build_weighted_matrix in utils.py in the Mittens paper source codebase
def build_weighted_matrix(files,
        mincount=300, vocab_size=None, window_size=10,
        weighting_function=lambda x: 1 / (x + 1), load_to_mem=False, parallel=False, verbose=False,
        internal_only=False):
    """
    Builds a count matrix based on a co-occurrence window of
    `window_size` elements before and `window_size` elements after the
    focal word, where the counts are weighted by `weighting_function`.
    Parameters
    ----------
    files : list of strings
        Filenames where documents in a corpus are located
    mincount : int
        Only words with at least this many tokens will be included.
    vocab_size : int or None
        If this is an int above 0, then, the top `vocab_size` words
        by frequency are included in the matrix, and `mincount`
        is ignored.
    window_size : int
        Size of the window before and after. (So the total window size
        is 2 times this value, with the focal word at the center.)
    weighting_function : function from ints to floats
        How to weight a co-occurrence; it receives the integer that
        `_window_based_iterator` passes for each neighbour.
    load_to_mem : bool, optional
        Whether to load entire corpus to memory
    parallel : bool, optional
        Whether to process emails in parallel. Only makes sense if loading to memory.
    verbose : bool, optional
        Whether to write progress messages to stderr
    internal_only : bool, optional
        Whether to only include internal emails - emails whose recipients only include company employees
    Returns
    -------
    pd.DataFrame
        Square matrix indexed by the sorted vocabulary on both axes. It is
        symmetric whenever the weights produced by `_window_based_iterator`
        are the same for a pair of positions in both directions.
    """
    wc = defaultdict(int)
    # This variable will be a list when load_to_mem is true and a generator when load_to_mem is false
    corpus = None
    if load_to_mem:
        if parallel:
            pool = multiprocessing.Pool(processes=num_cores)
            try:
                results = [pool.apply_async(read_corpus_parallel, args=(f, internal_only, )) for f in files]
                # Unreadable, filtered or empty messages come back as None and are dropped.
                corpus = [toks for toks in (r.get() for r in results) if toks]
            finally:
                # Always release the worker processes, even if a worker raised.
                pool.close()
                pool.join()
        else:
            # previously used split(' ') in 2yp
            corpus = [text.split() for text in read_corpus(files, internal_only)]
        for toks in corpus:
            for tok in toks:
                wc[tok] += 1
        sys.stderr.write('\n')
    else:
        corpus = read_corpus(files, internal_only)
        for text in corpus:
            for tok in text.split():
                wc[tok] += 1
    if verbose: sys.stderr.write('Finished counting %d unique words in corpus at %s.\n' % (len(wc), datetime.now()))
    if vocab_size:
        srt = sorted(wc.items(), key=itemgetter(1), reverse=True)
        vocab_set = {w for w, c in srt[: vocab_size]}
    else:
        vocab_set = {w for w, c in wc.items() if c >= mincount}
    vocab = sorted(vocab_set)
    n_words = len(vocab)
    if verbose: sys.stderr.write('Finished generating vocab of %d words at %s.\n' % (n_words, datetime.now()))
    # Weighted counts:
    # The generator was exhausted by the counting pass and has to be re-created
    if not load_to_mem: corpus = read_corpus(files, internal_only)
    counts = defaultdict(float)
    for toks in corpus:
        if not load_to_mem: toks = toks.split()
        window_iter = _window_based_iterator(toks, window_size, weighting_function)
        for w, w_c, val in window_iter:
            if w in vocab_set and w_c in vocab_set:
                counts[(w, w_c)] += val
    if verbose: sys.stderr.write('Finished counting co-occurrences across %d word pairs at %s.\n' % (len(counts), datetime.now()))
    # Fill only the observed pairs. Looking up every (w1, w2) combination in
    # the defaultdict would insert a zero entry for each of the
    # n_words * n_words pairs and needlessly bloat memory.
    word2idx = {w: idx for idx, w in enumerate(vocab)}
    X = np.zeros((n_words, n_words))
    for (w1, w2), val in counts.items():
        X[word2idx[w1], word2idx[w2]] = val
    if verbose: sys.stderr.write('Finished converting co-occurrences to sorted matrix of shape %s at %s.\n' % (str(X.shape), datetime.now()))
    X = pd.DataFrame(X, columns=vocab, index=pd.Index(vocab))
    return X

# this method is currently unused as it was used for proof-of-concept during macro lunch,
# where only employees with the top 200 vocab counts are included
def select_usr_corpus(inpath, num_usrs_to_process):
    """
    Pick the users with the largest token counts.
    Parameters
    ----------
    inpath : str
        Path to a csv file with at least the columns `usr` and `num_tokens`
    num_usrs_to_process : int
        Number of users to keep
    Returns
    -------
    list
        The `num_usrs_to_process` users with the most tokens, ordered by
        ascending token count. Rows containing missing values are ignored.
    """
    token_table = pd.read_csv(inpath).set_index(["usr"]).dropna()
    usr2num_tokens = token_table.to_dict()["num_tokens"]
    # Stable ascending sort on the token count, then keep the tail.
    ranked_usrs = sorted(usr2num_tokens, key=usr2num_tokens.get)
    return ranked_usrs[-num_usrs_to_process:]

# The function currently ignores sentence delimiters and considers co-occurrences across sentence boundaries to be
# co-occurrences as well. If desired, we can vary yield behavior based on the sentence_delim boolean to yield every
# sentence instead of every message as it currently does. However, if we were to yield every sentence, we need to make
# sure to not simply output the body but to yield every new line in the body individually.
def read_corpus(files, internal_only, sentence_delim=False):
    """
    Lazily yield the hashed body of every readable message in `files`.
    Parameters
    ----------
    files : list of str
        Paths to JSON email files
    internal_only : bool
        If true, messages for which is_internal_msg returns False are skipped
    sentence_delim : bool, optional
        Currently unused; every message is yielded as a single string
    Yields
    ------
    str
        The 'hashed-body' field with newlines replaced by spaces and
        surrounding whitespace stripped. Empty bodies and files that fail
        to parse are skipped.
    """
    for path in files:
        with open(path, errors='ignore', encoding='utf-8') as handle:
            try:
                message = ujson.load(handle)
                if internal_only and not is_internal_msg(message):
                    continue
                # Newlines stand for sentence breaks; flattening them ignores sentence structure.
                text = message['hashed-body'].replace('\n', ' ').strip()
                if text:
                    yield text
            except (ValueError, json.decoder.JSONDecodeError):
                continue

def read_corpus_parallel(f, internal_only, sentence_delim=False):
    """
    Read a single message file and return its tokens; meant to be run in a worker process.
    Parameters
    ----------
    f : str
        Path to a JSON email file
    internal_only : bool
        If true, a message for which is_internal_msg returns False yields None
    sentence_delim : bool, optional
        Currently unused
    Returns
    -------
    list of str or None
        Whitespace-separated tokens of the 'hashed-body' field, or None if
        the file cannot be parsed, the message is filtered out, or the
        body contains no tokens.
    """
    with open(f, errors='ignore', encoding='utf-8') as handle:
        try:
            message = ujson.load(handle)
            if internal_only and not is_internal_msg(message):
                return None
            # Newlines stand for sentence breaks; flattening them ignores sentence structure.
            tokens = message['hashed-body'].replace('\n', ' ').split()
            # A body made only of whitespace produces no tokens.
            return tokens or None
        except (ValueError, json.decoder.JSONDecodeError):
            return None

def output_embeddings(mittens_df, filename, compress=False):
    """
    Write embeddings held in a pd.DataFrame to a space-delimited text file,
    replacing any existing file with the same name.
    Parameters
    ----------
    mittens_df : pd.DataFrame
        Embeddings; the index holds the words and is written as the first field
    filename : str
        Output path. When compressing, '.gz' is appended to it.
    compress : bool, optional
        Whether to gzip the output
    """
    csv_options = {'quoting': csv.QUOTE_NONE, 'header': False, 'sep': " "}
    target = filename
    if compress:
        target = filename + '.gz'
        csv_options['compression'] = 'gzip'
    mittens_df.to_csv(target, **csv_options)

#########################################################################
############# Helper Functions for Working with JSON Emails #############
#########################################################################
# Same as get_recipient function in Analysis/acculturaltion/lingdistance/jensen_shannon.py
def get_recipients(msg):
    """
    Return the set of everyone in the to, cc and bcc fields of `msg`, excluding the sender.
    The sender is the 'from' field, or its first element when that field is a list.
    """
    origin = msg['from']
    if type(origin) == list:
        origin = origin[0]
    addressees = set(msg['to'] + msg['cc'] + msg['bcc'])
    addressees.discard(origin)
    return addressees

def is_internal_msg(msg):
    """
    Determines whether msg is an internal message, i.e. a message that is only sent to internal employees.
    A recipient whose anon ID starts with 'E---' is treated as external; a single such recipient makes
    the message external. A message without recipients counts as internal.
    """
    return not any(re.match('E---', recipient) for recipient in get_recipients(msg))

def slice_user_corpus(files, train_mode):
    """
    Group email files by the time period of their 'date' field.
    Parameters
    ----------
    files : list of str
        Paths to JSON email files
    train_mode : str
        'annual', 'quarterly', 'monthly' or 'halfyear' to bucket by one
        granularity; 'all' to bucket every file by year, quarter and half
        year at once (months are left out as they lead to too many output
        files). Any other value leaves the result empty.
    Returns
    -------
    defaultdict
        Maps each time key to the list of files dated within it. Files that
        fail to parse are skipped.
    """
    timekey2files = defaultdict(list)
    for path in files:
        with open(path, errors='ignore') as handle:
            try:
                msg = ujson.load(handle)
                if train_mode == 'all':
                    bucketers = (to_year, to_quarter, to_halfyear)
                elif train_mode == 'annual':
                    bucketers = (to_year,)
                elif train_mode == 'quarterly':
                    bucketers = (to_quarter,)
                elif train_mode == 'monthly':
                    bucketers = (to_yearmonth,)
                elif train_mode == 'halfyear':
                    bucketers = (to_halfyear,)
                else:
                    bucketers = ()
                # Keys are computed and appended one at a time, in this order.
                for bucketer in bucketers:
                    timekey2files[bucketer(msg['date'], format='str')].append(path)
            except (ValueError, json.decoder.JSONDecodeError):
                continue
    return timekey2files

def read_message(files):
    """
    Lazily yield the parsed JSON content of every file in `files`.
    Files that fail to parse are skipped; once all files have been visited,
    the number of failures is reported on stderr.
    """
    failed = 0
    for path in files:
        with open(path, errors='ignore') as handle:
            try:
                yield ujson.load(handle)
            except (ValueError, json.decoder.JSONDecodeError):
                failed += 1
    summary = "{} files produced errors out of a total of {} files.\n".format(failed, len(files))
    sys.stderr.write(summary)

#########################################################################
############# Helper Functions for Working with Date Objects ############
#########################################################################

def to_yearmonth(date, format):
    """
    Return the year and month of `date` as a 'YYYY-MM' string.
    `format` is 'str' when `date` is a string starting with 'YYYY-MM' and
    'datetime' when it is a datetime object; any other value returns None.
    """
    if format == 'datetime':
        return date.strftime('%Y-%m')
    if format == 'str':
        return date[0:7]
    return None

def to_quarter(date, format):
    """
    Return the quarter of `date` as a string such as '2015Q1'.
    `format` is 'str' when `date` is a string starting with 'YYYY-MM' and
    'datetime' when it is a datetime object. For any other value year and
    month both default to 0.
    """
    if format == 'str':
        year, month = date[0:4], date[5:7]
    elif format == 'datetime':
        year, month = date.year, date.month
    else:
        year, month = 0, 0
    # Months 1-3 map to 1, 4-6 to 2, 7-9 to 3 and 10-12 to 4.
    quarter = (int(month) + 2) // 3
    return '{}Q{}'.format(year, quarter)

def to_halfyear(date, format):
    """
    Return the half year of `date` as a string such as '2015HY1'.
    `format` is 'str' when `date` is a string starting with 'YYYY-MM' and
    'datetime' when it is a datetime object. For any other value year and
    month both default to 0.
    """
    if format == 'str':
        year, month = date[0:4], date[5:7]
    elif format == 'datetime':
        year, month = date.year, date.month
    else:
        year, month = 0, 0
    # Months 1-6 map to 1 and months 7-12 map to 2.
    halfyear = (int(month) + 5) // 6
    return '{}HY{}'.format(year, halfyear)

def to_year(date, format):
    """
    Return the year of `date` as a string.
    `format` is 'str' when `date` is a string starting with 'YYYY' and
    'datetime' when it is a datetime object; any other value returns None.
    """
    if format == 'datetime':
        return str(date.year)
    if format == 'str':
        return date[0:4]
    return None

def datetime_to_timekey(date, time_key):
    """
    Convert a datetime object to the string time key named by `time_key`.
    `time_key` is one of 'year', 'quarter' or 'yearmonth'; any other value returns None.
    """
    converters = (
        ('year', lambda d: to_year(d, format='datetime')),
        ('quarter', lambda d: to_quarter(d, format='datetime')),
        ('yearmonth', lambda d: to_yearmonth(d, format='datetime')),
    )
    for key, convert in converters:
        if time_key == key:
            return convert(date)
    return None

def is_month_before_equal(datetime1, datetime2):
    """
    Return 1 if the year and month of datetime1 are earlier than or equal to
    those of datetime2, and 0 otherwise. Days are not compared.
    """
    first = (datetime1.year, datetime1.month)
    second = (datetime2.year, datetime2.month)
    return int(first <= second)

def num_months_between_dates(datetime1, datetime2):
    """Return the absolute number of calendar months between two datetime objects, ignoring days."""
    year_gap = datetime1.year - datetime2.year
    month_gap = datetime1.month - datetime2.month
    return abs(year_gap * 12 + month_gap)

def num_quarters_between_dates(datetime1, datetime2):
    """
    Return the absolute number of whole three-month spans between two datetime
    objects, computed from the month difference and ignoring days.
    """
    year_gap = datetime1.year - datetime2.year
    month_gap = datetime1.month - datetime2.month
    months_apart = abs(year_gap * 12 + month_gap)
    return months_apart // 3

def num_years_between_dates(datetime1, datetime2):
    """Return the absolute difference between the calendar years of two datetime objects."""
    year_gap = datetime1.year - datetime2.year
    return abs(year_gap)

def time_between_dates(datetime1, datetime2, time_key):
    """
    Return the distance between two datetime objects in the unit named by `time_key`.
    `time_key` is one of 'monthly', 'quarterly' or 'annual'; any other value returns None.
    """
    measures = (
        ('monthly', lambda a, b: num_months_between_dates(a, b)),
        ('quarterly', lambda a, b: num_quarters_between_dates(a, b)),
        ('annual', lambda a, b: num_years_between_dates(a, b)),
    )
    for key, measure in measures:
        if time_key == key:
            return measure(datetime1, datetime2)
    return None

#########################################################################
############## Helper Functions for Working with Dataframes #############
#########################################################################
def dict_to_df(index2rows, cols, index_name):
    """
    Build a dataframe, sorted by its index, from a dictionary of rows.
    Parameters
    ----------
    index2rows : dict
        Dictionary mapping index to rows to be converted. Keys are tuples
        when more than one index name is given.
    cols : list
        List of column names of type str
    index_name : list
        List of index names. One name produces a plain index, several
        names produce a MultiIndex.
    Returns
    -------
    df : pd.DataFrame
        Constructed dataframe
    """
    single_level = len(index_name) == 1
    frame = pd.DataFrame.from_dict(index2rows, orient='index', columns=cols)
    if single_level:
        frame.index.name = index_name[0]
    else:
        named_index = pd.MultiIndex.from_tuples(frame.index, names=index_name)
        frame = pd.DataFrame(frame, named_index)
    frame.sort_index(axis=0, inplace=True)
    return frame

def generate_test_df():
    """
    Generate a small DataFrame for testing purposes.

    The frame has ten rows indexed by gender and two columns, 'i' and 'we',
    that hold simulated word vectors. Each vector is a random numpy array
    of size 10 rounded to 5 decimals.
    """
    def _random_vector(_row):
        return np.random.rand(10).round(5)

    df = pd.DataFrame()
    df['gender'] = ['F', 'M', 'M', 'F', 'M', 'F', 'F', 'M', 'M', 'F']
    for word in ('i', 'we'):
        df[word] = df.apply(_random_vector, axis=1)
    df.set_index('gender', inplace=True)
    return df

#########################################################################
########### Helper Functions for Working with Embedding Output ##########
#########################################################################

def remove_empty_embeddings(embeddings_dir):
    """
    Delete every zero-byte entry directly inside embeddings_dir. Such files are
    produced when the vocab size was 0.
    Parameters
    ----------
    embeddings_dir : str
        Full path to directory where embedding files are located
    """
    for entry in os.listdir(embeddings_dir):
        candidate = os.path.join(embeddings_dir, entry)
        if os.path.getsize(candidate) != 0:
            continue
        os.remove(candidate)

def word_similarity(model, w1, w2):
    """
    This is an auxiliary function that allows for comparing one word to another word or multiple words.
    If w1 and w2 are both single words, n_similarity returns their cosine similarity which is the same as
    simply calling similarity(w1, w2).
    If w1 or w2 is a set of words, n_similarity essentially takes the mean of the set of words and then computes
    the cosine similarity between that vector mean and the other vector. This functionality is both reflected
    in its source code and has been verified manually.
    Words that are not in the model are dropped before the comparison.
    Parameters
    ----------
    model : KeyedVectors
        The model that contains all the words and vectors
    w1 : str or list
        The first word or word list to be compared
    w2 : str or list
        The second word or word list to be compared
    Returns
    -------
    float or None
        Cosine similarity between w1 and w2, or None if either side has no word in the model
    """
    if not isinstance(w1, list):
        w1 = [w1]
    if not isinstance(w2, list):
        w2 = [w2]
    # Membership is tested on the model itself rather than on `model.vocab`:
    # the `vocab` attribute was removed in gensim 4, whereas `in` works on
    # KeyedVectors in both gensim 3 and gensim 4.
    w1 = [w for w in w1 if w in model]
    w2 = [w for w in w2 if w in model]
    if len(w1) == 0 or len(w2) == 0:
        return None
    return model.n_similarity(w1, w2)

def extract_company_embedding(company_embeddings_filename, tmp_dir, words):
    """
    Look up the company embedding vectors of a list of words.
    Parameters
    ----------
    company_embeddings_filename : str
        File path of the company embeddings
    tmp_dir : str
        Path to the directory for gensim to output its tmp files in order to load embeddings into word2vec format
    words : list
        A list of strings to retrieve vectors for
    Returns
    -------
    vecs : list
        A list of vectors of type numpy.ndarray, one for every word in `words` that is present in the
        company embeddings, in the order given. Missing words are reported on stdout and skipped, so
        the list can be shorter than `words`.
    """
    tmp_mittens = os.path.join(tmp_dir, "mittens_embeddings_all_word2vec.txt")
    word2vec_mittens_file = get_tmpfile(tmp_mittens)
    # Convert the GloVe-format text file into word2vec format so gensim can load it.
    glove2word2vec(company_embeddings_filename, word2vec_mittens_file)
    model = KeyedVectors.load_word2vec_format(word2vec_mittens_file)
    vecs = []
    for w in words:
        # `w in model` and `model[w]` are used instead of `model.vocab` and
        # `model.wv[w]`: `vocab` was removed in gensim 4 and a KeyedVectors
        # object has no `wv` attribute there, while membership tests and
        # indexing behave the same in gensim 3 and gensim 4.
        if w in model:
            vecs.append(model[w])
        else:
            print('%s not in company embeddings' % w)
    return vecs

def isnull_wrapper(x):
    """
    Report whether x is null, or contains at least one null element when pd.isnull
    returns an element-wise result rather than a plain bool.
    """
    null_mask = pd.isnull(x)
    return null_mask if isinstance(null_mask, bool) else null_mask.any()

def cossim_with_none(vec1, vec2, vec_format='sparse'):
    """
    Cosine similarity that returns None instead of erroring out when a vector is missing.
    Parameters
    ----------
    vec1 : list of (int, float) in gensim sparse vector format, or numpy array
    vec2 : list of (int, float) in gensim sparse vector format, or numpy array
    vec_format : str, optional
        Either sparse or dense. If sparse, vec1 and vec2 are in gensim sparse vector format; use cossim function from gensim.
        If dense, vec1 and vec2 are numpy arrays and cosine similarity is hand calculated.
    Returns
    -------
    float or None
        Cosine similarity between vec1 and vec2; None if either vector is null, contains a null, or
        (dense format only) is empty
    Raises
    ------
    ValueError
        If vec_format is neither 'sparse' nor 'dense' and both vectors are non-null
    """
    if isnull_wrapper(vec1) or isnull_wrapper(vec2):
        return None
    if vec_format == 'sparse':
        return cossim(vec1, vec2)
    if vec_format != 'dense':
        raise ValueError()
    if len(vec1) == 0 or len(vec2) == 0:
        return None
    inner_product = np.dot(vec1, vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return inner_product / norm_product

def calculate_pairwise_cossim(col1, col2=None, reference=False, reference_group=None, anon_ids=None):
    """
    Calculates averaged cosine similarity of every row vector in col1 with every other row vector in col2.
    If no col2 is provided, cosine similarity of every row vector with every other row vector in col1 is calculated.
    The two columns should have equal length.
    Parameters
    ----------
    col1 : pd.Series
        A column where each row is a sparse word vector (BoW format that gensim code is written for).
    col2 : pd.Series, optional
        A column where each row is a sparse word vector (BoW format that gensim code is written for).
    reference : bool, optional
        Indicator variable for whether filtering for reference groups is needed
    reference_group : pd.Series, optional
        If filtering for reference groups, a list containing reference group members for every employee in col1
    anon_ids : pd.Series, optional
        If filtering for reference groups, a list containing anon_ids for every employee in col1
    Returns
    -------
    results : list
        A list where the ith element is the averaged cosine similarity between the ith vector in col1 and every vector
        in col2 for which i != j. If no col2 is provided, a list where the ith element is the averaged cosine similarity
        between the ith vector in col1 and every other vector in col1 is returned. An element is None when the ith
        vector is empty or when there is nothing to compare it with.
    """
    vectors1 = col1.tolist()
    vectors2 = col1.tolist() if col2 is None else col2.tolist()
    reference_group = reference_group.tolist() if reference else None
    anon_ids = None if anon_ids is None else anon_ids.tolist()
    results = []
    for i, focal in enumerate(vectors1):
        if not focal:
            results.append(None)
            continue
        similarities = []
        for j, other in enumerate(vectors2):
            # Never compare a row with the row at the same position, and skip empty vectors.
            if i == j or not other:
                continue
            if reference:
                group = reference_group[i]
                # A reference group that is not a set (e.g. np.nan) matches nobody.
                if not (type(group) == set and anon_ids[j] in group):
                    continue
            similarities.append(cossim(focal, other))
        results.append(mean(similarities) if similarities else None)
    return results

def vector_mean(col):
    """
    Calculate the element-wise mean of the row vectors in a column, skipping missing rows.
    Parameters
    ----------
    col : pd.Series
        The column to be averaged
    Returns
    -------
    np.array
        A vector that is the numerical average of all non-missing vectors in col
    """
    present = col[col.notna()]
    stacked = np.array(present.tolist())
    return stacked.mean(axis=0)

def fix_department(hr_df):
    """
    Correct the known typos in the department field and strip trailing spaces. Mostly for
    generating a correct department roster used by mittens3d_explore_alternative_embeddings.py
    Parameters
    ----------
    hr_df : pd.DataFrame
        A dataframe of HR data with a `department` column; it is modified in place
    Returns
    -------
    hr_df : pd.DataFrame
        The original dataframe after typos are fixed in the department column
    """
    # (regular expression, replacement) pairs, applied in order
    corrections = (
        ("Enviroments", "Environments"),
        ("Leanrnign", "Learning"),
        (" +$", ""),
    )
    for pattern, replacement in corrections:
        hr_df['department'] = hr_df['department'].str.replace(pattern, replacement, regex=True)
    return hr_df

def extract_hr_df(hr_filepath, time_key=None):
    """
    Load the HR snapshot file and derive employee-level or employee-period-level variables.
    Parameters
    ----------
    hr_filepath : str
        Path to the HR csv file. The code reads, among others, the columns anon_id, snapshot_date,
        rehire_date, hire_date, termination_date, termination_reason, age, supervisor_id,
        annual_salary, job_title, gender, ethnicity, department and location (after spaces in
        column names are replaced by underscores).
    time_key : str, optional
        If falsy, one row per employee is returned. Otherwise the name of the time-key column to
        create from snapshot_date via datetime_to_timekey.
    Returns
    -------
    hr : pd.DataFrame
        Without time_key: indexed by `usr` (the anon_id), keeping the last snapshot of each employee.
        With time_key: indexed by (anon_id, time_key).
    """
    hr = pd.read_csv(hr_filepath)
    cols = hr.columns
    # normalize column names so they can be used as attributes (row.anon_id)
    cols = cols.map(lambda x: x.replace(' ', '_') if isinstance(x, str) else x)
    hr.columns = cols
    hr['snapshot_date'] = pd.to_datetime(hr['snapshot_date'])
    hr = hr.sort_values(by=['anon_id', 'snapshot_date'])
    # drop all rehires: every row of an employee that has a rehire_date on any row is removed
    hr_rehires = set(hr.loc[hr['rehire_date'].notna(), 'anon_id'])
    hr['drop'] = hr.apply(lambda row : 1 if row.anon_id in hr_rehires else 0, axis=1)
    hr = hr.loc[hr['drop'] == 0]
    # 'coerce' is the second positional argument (`errors` in pandas' signature - TODO confirm for the
    # pandas version in use), so unparseable hire dates become NaT and are dropped below
    hr['hire_date'] = pd.to_datetime(hr['hire_date'], 'coerce')
    hr['termination_date'] = pd.to_datetime(hr['termination_date'])
    hr['termination_reason'] = hr['termination_reason'].str.lower()
    hr.dropna(subset=['hire_date', 'snapshot_date'], axis=0, inplace=True)
    # youngest recorded age of the employee across all snapshots
    hr['entry_age'] = hr.groupby('anon_id')['age'].transform('min')
    # exit indicators: any termination reason, and the voluntary / involuntary split
    hr['exit'] = hr.apply(lambda row : 1 if pd.notna(row['termination_reason']) else 0, axis=1)
    hr['exit_vol'] = hr.apply(lambda row : 1 if (row['termination_reason'] == 'voluntary') else 0, axis=1)
    hr['exit_invol'] = hr.apply(lambda row : 1 if (row['termination_reason'] == 'involuntary') else 0, axis=1)
    hr = fix_department(hr)
    if not time_key:
        # employee-level output: an employee is a supervisor if their anon_id ever appears as a supervisor_id
        supervisors = set(hr['supervisor_id'])
        hr['is_supervisor'] = hr.apply(lambda row : 1 if row.anon_id in supervisors else 0, axis=1)
        # NOTE(review): only the month is compared, not the year, so a snapshot in any later
        # anniversary month also sets initial_salary - confirm this is intended
        hr['initial_salary'] = hr.apply(lambda row: row.annual_salary if (row.snapshot_date.month == row.hire_date.month) else None, axis=1)
        # NOTE(review): the forward fill is not grouped by anon_id, so an employee whose first
        # snapshots have no initial_salary inherits the value of the preceding employee - confirm
        hr['initial_salary'].fillna(method = 'ffill', inplace=True)
        hr['salary_delta'] = (hr['annual_salary'] - hr['initial_salary']) / hr['initial_salary']
        hr['salary_avg'] = hr.groupby('anon_id')['annual_salary'].transform('mean')
        # tenure in days: up to the termination date for leavers, up to the snapshot date otherwise
        hr['tenure'] = hr.apply(lambda row:
            (row.termination_date - row.hire_date).days if pd.notna(row.termination_reason) else (row.snapshot_date - row.hire_date).days, axis=1)
        # rows are sorted by snapshot_date within anon_id, so this keeps the latest snapshot
        hr.drop_duplicates(subset='anon_id', keep='last', inplace=True)
        hr = (hr[['anon_id', 'gender', 'ethnicity', 'department', 'location', 'job_title',
            'tenure', 'exit', 'exit_vol', 'exit_invol', 'is_supervisor', 'salary_delta', 'salary_avg', 'entry_age']])
        hr.set_index('anon_id', inplace=True)
        hr.index.name = 'usr'
    else:
        # employee-period output
        hr[time_key] = hr.apply(lambda row : datetime_to_timekey(row['snapshot_date'], time_key), axis=1)
        # previous snapshot's salary and job title of the same employee
        hr['past_salary'] = hr.groupby('anon_id')['annual_salary'].shift(1)
        hr['past_job_title'] = hr.groupby('anon_id')['job_title'].shift(1)
        # promotion = salary increase together with a job title change relative to the previous snapshot
        hr['promotion'] = hr.apply(lambda row : 1 if (row['annual_salary'] > row['past_salary'] and (row['past_job_title'] != row['job_title'])) else 0, axis=1)
        # supervisor status is evaluated per snapshot date
        monthly_supervisors = {k: set(sups['supervisor_id'].tolist()) for k, sups in hr[['snapshot_date', 'supervisor_id']].groupby("snapshot_date")}
        hr['is_supervisor'] = hr.apply(lambda row : 1 if row['anon_id'] in monthly_supervisors[row['snapshot_date']] else 0, axis=1)
        if time_key == quarter_colname or time_key == year_colname:
            # collapse the snapshots of a quarter or year: an indicator is 1 if it was 1 in any
            # snapshot of the period, and the last snapshot of the period is kept
            hr['promotion'] = hr.groupby(['anon_id', time_key])['promotion'].transform('max')
            hr['exit'] = hr.groupby(['anon_id', time_key])['exit'].transform('max')
            hr['exit_vol'] = hr.groupby(['anon_id', time_key])['exit_vol'].transform('max')
            hr['exit_invol'] = hr.groupby(['anon_id', time_key])['exit_invol'].transform('max')
            hr['is_supervisor'] = hr.groupby(['anon_id', time_key])['is_supervisor'].transform('max')
            hr.drop_duplicates(subset=['anon_id', time_key], keep='last', inplace=True)
        hr = (hr[['anon_id', time_key, 'snapshot_date', 'gender', 'ethnicity', 'department', 'location', 'job_title', 'annual_salary', 'hire_date',
                'is_supervisor', 'age', 'entry_age', 'promotion', 'exit', 'exit_vol', 'exit_invol']])
        hr.set_index(['anon_id', time_key], inplace=True)
    return hr

Set the current directory

In [None]:
import os
# Remember the working directory the notebook runs in; the hyperparameter cell uses it as home_dir.
current_dir = os.getcwd()
# Bare expression so that the notebook displays the value.
current_dir

Set the hyperparameters and the input and output directories

In [None]:
# Number of worker processes for the multiprocessing pools (replaces the value set in the import cell)
num_cores = 13
# Number of user directories drawn when train_usr_corpus runs in test mode
num_users_to_test = 12

# Parameters specific to the firm
# internal_only: restrict the corpus to messages that is_internal_msg accepts
internal_only = True
# first_batches: use only the first corpus directory instead of all of them
first_batches = False

# Mittens hyperparameters
# Passed to Mittens as the `mittens` argument
mittens_params = 0.1 
window_size = 10
embedding_dim = 50
# Training iterations for the per-user and per-user-period models
max_iter = 100 
# Minimum token count for a word to enter a user's vocabulary
mincount = 50
# vocab_size and max_iter_all are only used by train_all_corpus (company-wide embeddings)
vocab_size = 2500
max_iter_all = 3000

home_dir = current_dir
# NOTE(review): absolute path on the original server; adjust when running elsewhere
home_dir1 = "/zfs/projects/faculty/amirgo-transfer/design/"
# Design firm input directories - the design firm data was stored in two directories
corpus_dirs = [os.path.join(home_dir1, "hashed_corpus"), os.path.join(home_dir1, "hashed_corpus_batch3")]
utils_dir = os.path.join(home_dir, "utils")
# The output folder name encodes mincount, the batches used and whether only internal emails are kept
hash_dir = "hash_embeddings_{}_b12".format(mincount) if first_batches else "hash_embeddings_{}_all".format(mincount)
if internal_only: hash_dir += '_internal'
output_dir = os.path.join(home_dir, hash_dir)

# Use these pre-trained GloVe embeddings if the company embeddings are not trained from scratch
glove_filename = os.path.join(utils_dir, "glove.6B.{}d.txt".format(embedding_dim))
# For the main analyses, we fine-tune the trained company embeddings for each person-quarter
company_embeddings_filename = os.path.join(home_dir, "hash_embeddings_b12.txt" if first_batches else "hash_embeddings_all.txt")
# Placeholder; not populated anywhere in this notebook
glove_embeddings = {}

# Echo the resolved paths so they can be checked before training starts
print(corpus_dirs)
print(output_dir)
print(company_embeddings_filename)

Now fine-tune Mittens on the company embeddings for the design firm

In [None]:
def process_employee_dir(i, num_users, usr, usr_dirs):
    """
    Workhorse function for training individual embedding spaces for each individual.
    First a person-level embedding is fine-tuned from the company embeddings, then one embedding per
    time period is fine-tuned from that person-level embedding. Output files that already exist are
    not recomputed.
    Parameters
    ----------
    i : int
        Index of current directory, used to keep track of progress
    num_users : int
        Total number of user directories, used to keep track of progress
    usr : str
        Focal user whose emails are processed
    usr_dirs : list of str
        list of full paths to directories containing emails belonging to usr
    """
    sys.stderr.write("\nProcessing \t%d/%d -'%s', at %s.\n" % (i, num_users, usr, datetime.now()))
    user_embedding_filename = os.path.join(output_dir, "hash_embedding_{}.txt".format(usr))
    files = [os.path.join(usr_dir, file) for usr_dir in usr_dirs for file in os.listdir(usr_dir)]

    # Person-level embeddings, initialized with the company embeddings
    if not os.path.exists(user_embedding_filename):
        X = build_weighted_matrix(files, mincount=mincount, window_size=window_size, load_to_mem=True, parallel=False, verbose=False, internal_only=internal_only)
        _fit_and_write_embeddings(X, company_embeddings, user_embedding_filename)

    # Without a person-level embedding file there is nothing to fine-tune the periods from
    if not os.path.exists(user_embedding_filename):
        return

    # Person-period embeddings, initialized with the person-level embeddings
    user_embeddings = glove2dict(user_embedding_filename)
    sliced_usr_corpus = slice_user_corpus(files, 'all') #was quarterly before
    for time_key, period_files in sliced_usr_corpus.items():
        user_embedding_time_filename = os.path.join(output_dir, "hash_embedding_{}_{}.txt".format(usr, time_key))
        if os.path.exists(user_embedding_time_filename):
            continue
        X = build_weighted_matrix(period_files, mincount=mincount, window_size=window_size, load_to_mem=True, parallel=False, verbose=False, internal_only=internal_only)
        if not user_embeddings: sys.stderr.write("\n%s does not have corresponding user embeddings with timekey %s.\n" % (usr, time_key))
        _fit_and_write_embeddings(X, user_embeddings, user_embedding_time_filename)
    return

def _fit_and_write_embeddings(X, initial_embeddings, embedding_filename):
    """Fit Mittens on the co-occurrence matrix X and write the embeddings unless they are empty."""
    model = Mittens(n=embedding_dim, max_iter=max_iter, mittens=mittens_params)
    fitted = model.fit(
        X.values,
        vocab=list(X.index),
        initial_embedding_dict=initial_embeddings)
    fitted_df = pd.DataFrame(fitted, index=X.index)
    if not fitted_df.empty:
        output_embeddings(fitted_df, filename=embedding_filename)

#main fine-tuning function for design employee embeddings
def train_usr_corpus(usr_dirs, test_mode=True):
    """
    Train individual embedding spaces for each individual.
    Parameters
    ----------
    usr_dirs : list of str
        list of full paths to user directories
    test_mode : bool, optional
        If true, only a random subset of at most `num_users_to_test` distinct directories is processed
    """
    if test_mode:
        # Sample without replacement. Drawing with replacement could select the
        # same directory twice, in which case its emails would be listed twice
        # and every co-occurrence double-counted; it also failed on an empty list.
        usr_dirs = random.sample(list(usr_dirs), min(num_users_to_test, len(usr_dirs)))

    sys.stderr.write('Processing %d directories in parallel at %s.\n' % (len(usr_dirs), str(datetime.now())))

    # Added to allow for processing batch3 with batch1 and 2, which is stored in a separate parent directory:
    # directories with the same base name belong to the same user
    usr2dirs = defaultdict(list)
    for usr_dir in usr_dirs:
        usr2dirs[os.path.basename(usr_dir)].append(usr_dir)
    num_users = len(usr2dirs)
    #these are the total number of users
    print(num_users)
    pool = multiprocessing.Pool(processes = num_cores)
    try:
        results = [pool.apply_async(process_employee_dir, args=(i, num_users, usr, dirs, )) for i, (usr, dirs) in enumerate(usr2dirs.items())]
        # get() re-raises any exception that occurred in a worker
        for r in results:
            r.get()
    finally:
        # Release the workers even when a task failed. close() followed by join()
        # lets tasks that are already running finish instead of killing them
        # while they may be writing an embedding file.
        pool.close()
        pool.join()

    return

#this would be when you are training all embeddings from scratch using the initial embeddings as the pre-trained file from Glove
def train_all_corpus(usr_dirs, glove_embeddings):
    """
    Train the company-wide embeddings on the emails of all users and write them to
    `company_embeddings_filename`.
    Parameters
    ----------
    usr_dirs : list of str
        list of full paths to user directories
    glove_embeddings : dict
        Initial embeddings handed to Mittens, mapping words to vectors
    """
    files = []
    for usr_dir in usr_dirs:
        for file in os.listdir(usr_dir):
            files.append(os.path.join(usr_dir, file))
    sys.stderr.write('Building co-occurrence matrix with all corpora at %s.\n' % str(datetime.now()))
    # External emails are included here (internal_only=False) and the vocabulary is capped by vocab_size
    cooccurrence = build_weighted_matrix(files, vocab_size=vocab_size, window_size=window_size, load_to_mem=True, parallel=True, verbose=True, internal_only=False)
    sys.stderr.write('Fitting mittens for all corpora at %s.\n' % datetime.now())
    model = Mittens(n=embedding_dim, max_iter=max_iter_all, mittens=mittens_params)
    vocabulary = list(cooccurrence.index)
    fitted = model.fit(cooccurrence.values, vocab=vocabulary, initial_embedding_dict=glove_embeddings)
    company_df = pd.DataFrame(fitted, index=cooccurrence.index)
    output_embeddings(company_df, filename=company_embeddings_filename)
    sys.stderr.write('\nSuccessfully produced mittens embeddings using all corpora at %s.\n' % datetime.now())
    return

In [None]:
if __name__ == '__main__':
    # Entry point: collect the user directories, load the company embeddings and fine-tune per user.
    starttime = datetime.now()
    test_mode = False
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    sys.stderr.write('Starting to load files at %s.\n' % str(datetime.now()))

    # first_batches restricts the run to the first corpus directory; otherwise all batches are used
    selected_corpus_dirs = [corpus_dirs[0]] if first_batches else corpus_dirs
    usr_dirs = []
    for corpus_dir in selected_corpus_dirs:
        for usr_dir in os.listdir(corpus_dir):
            candidate = os.path.join(corpus_dir, usr_dir)
            if os.path.isdir(candidate):
                usr_dirs.append(candidate)
    print('loaded all users')
    # Kept at module level because process_employee_dir reads it as a global
    company_embeddings = glove2dict(company_embeddings_filename)
    print('read the company embeddings')

    sys.stderr.write('Processing user corpora at %s.\n' % str(datetime.now()))
    train_usr_corpus(usr_dirs, test_mode)

    sys.stderr.write("\nFinished processing at %s, with a duration of %s.\n"
        % (str(datetime.now()), str(datetime.now() - starttime)))