In [71]:
import pandas as pd
import plotly.express as px
import os
import numpy as np
import torch

data_dir = 'data/reddit2015'
model_dir = 'model/reddit2015'

# Basic model performance

In [82]:
# Load the results created by `eval_lm test_perplexity`
results_file = os.path.join(model_dir, 'test_ppl.pickle')
df_ppl = pd.read_pickle(results_file)
    
comms = torch.load(open(os.path.join(model_dir, 'community.field'), 'rb')).vocab.itos[1:]
models = [name for name in list(df_ppl.columns) if not name in ('community', 'comment')]

In [73]:
# The epoch with the lowest validation loss is used for testing
test_epoch = pd.DataFrame([(model, open(f'{model_dir}/{model}/saved-epoch.txt').read()) for model in models],
            columns = ['model', 'test_epoch']).set_index('model')
test_epoch

Unnamed: 0_level_0,test_epoch
model,Unnamed: 1_level_1
lstm-3,21
lstm-3-0,17
lstm-3-1,34
lstm-3-2,11
lstm-3-3,16
transformer-3,20
transformer-3-0,7
transformer-3-1,12
transformer-3-2,7
transformer-3-3,10


## Test Perplexity

In [74]:
# Summary statistics test perplexity, by model
pd.options.display.float_format = '{:,.2f}'.format
df_ppl.describe()[models]

Unnamed: 0,lstm-3,lstm-3-0,lstm-3-1,lstm-3-2,lstm-3-3,transformer-3,transformer-3-0,transformer-3-1,transformer-3-2,transformer-3-3
count,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0
mean,80.7,78.64,76.8,77.32,77.04,94.44,90.22,95.49,122.56,89.38
std,133.7,145.95,258.82,124.22,180.92,148.81,170.58,216.51,246.48,155.69
min,1.01,1.02,1.04,1.03,1.01,1.02,1.02,1.1,1.07,1.02
25%,32.79,32.06,31.35,31.62,31.25,39.02,37.5,39.59,50.44,37.19
50%,56.89,55.9,54.44,54.99,54.25,66.95,64.86,68.76,89.65,64.73
75%,94.57,92.77,90.13,91.14,89.84,110.21,106.66,112.66,146.98,106.18
max,23181.27,33776.25,108869.81,19275.18,42067.9,22505.67,34054.73,57392.79,46074.6,30054.51


In [75]:
# Mean test perplexity by model and community
df_ppl.groupby('community').mean().sort_values('lstm-3')

Unnamed: 0_level_0,lstm-3,lstm-3-0,lstm-3-1,lstm-3-2,lstm-3-3,transformer-3,transformer-3-0,transformer-3-1,transformer-3-2,transformer-3-3
community,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
relationships,53.07,54.83,53.39,53.96,53.09,62.19,63.04,67.35,89.07,64.14
stopdrinking,53.27,54.29,52.61,53.44,58.74,65.73,66.46,66.49,92.08,63.02
Advice,56.22,58.7,57.68,57.78,57.27,66.37,68.65,72.74,97.89,68.12
BabyBumps,61.21,60.68,59.17,60.76,58.96,71.29,69.58,74.93,98.34,69.98
xxfitness,65.4,66.6,65.17,64.97,64.28,77.25,76.43,81.36,111.43,76.14
AskWomen,65.65,68.29,66.98,67.68,66.35,77.18,78.08,83.5,110.55,78.04
TwoXChromosomes,65.69,68.89,66.86,67.92,66.2,77.46,79.56,84.63,113.83,80.18
breakingmom,66.66,67.98,66.12,67.54,65.7,77.73,76.92,82.09,106.09,77.31
techsupport,66.74,64.98,63.32,64.23,64.28,78.65,75.07,79.62,105.63,74.8
femalefashionadvice,68.42,67.35,65.81,66.8,65.28,80.71,78.37,83.98,111.84,77.81


## Information gain (vs. unconditioned version)

In [76]:
conditioned_lstms = [f'lstm-3-{i}' for i in range(4)]
conditioned_transformers = [f'transformer-3-{i}' for i in range(4)]
df_info_gain = pd.concat([
        df_ppl[conditioned_lstms].apply(lambda x:df_ppl['lstm-3'] / x, axis=0),
        df_ppl[conditioned_transformers].apply(lambda x:df_ppl['transformer-3'] / x, axis=0)
    ], axis=1)
df_info_gain.describe()

Unnamed: 0,lstm-3-0,lstm-3-1,lstm-3-2,lstm-3-3,transformer-3-0,transformer-3-1,transformer-3-2,transformer-3-3
count,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0,230000.0
mean,1.05,1.07,1.06,1.07,1.08,1.02,0.83,1.09
std,0.25,0.26,0.25,0.26,0.3,0.3,0.33,0.34
min,0.07,0.03,0.03,0.04,0.01,0.0,0.01,0.03
25%,0.91,0.93,0.92,0.93,0.92,0.87,0.64,0.92
50%,1.01,1.04,1.03,1.04,1.03,0.98,0.76,1.03
75%,1.14,1.17,1.16,1.17,1.17,1.11,0.93,1.18
max,8.15,7.98,10.33,28.08,10.65,10.35,14.94,30.05


In [77]:
# Mean information gain by community
df_info_gain['community'] = df_ppl['community']
df_info_gain.groupby('community').mean().sort_values('lstm-3-0', ascending=False)

Unnamed: 0_level_0,lstm-3-0,lstm-3-1,lstm-3-2,lstm-3-3,transformer-3-0,transformer-3-1,transformer-3-2,transformer-3-3
community,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
streetwear,1.18,1.21,1.18,1.17,1.23,1.17,1.02,1.27
MaddenUltimateTeam,1.16,1.18,1.16,1.15,1.21,1.16,0.98,1.24
Kappa,1.12,1.15,1.14,1.14,1.13,1.11,0.94,1.21
jailbreak,1.12,1.14,1.14,1.12,1.14,1.1,0.89,1.16
CFB,1.1,1.14,1.11,1.12,1.12,1.09,0.91,1.14
eu4,1.1,1.11,1.11,1.13,1.14,1.09,0.88,1.16
food,1.09,1.11,1.09,1.09,1.12,1.06,0.89,1.13
reddevils,1.09,1.12,1.11,1.11,1.14,1.09,0.9,1.16
MLS,1.09,1.13,1.1,1.12,1.12,1.08,0.9,1.15
rupaulsdragrace,1.09,1.12,1.11,1.11,1.14,1.07,0.91,1.16


# Community embedding PCAs

In [90]:
import numpy as np
import torch

# Load the community embeddeing layer of the conditional models
conditioned_models = [m for m in models if not m in ('lstm-3', 'transformer-3')]
comm_embed = {
    model: np.load(os.path.join(model_dir, f'{model}/comm_embed.npy'))[1:]
    for model in conditioned_models
}

# Load the author co-occurance embedding
w_auth = np.load('model/reddit2015/comm_author_embed_svd16dim.npy')

In [93]:
# Manually assign communities to different types/subjects
# NOTE this shouldn't really be a partition. Some communities clearly belong to multiple types,
# but we need a more sophisticated viz for that.
comm_types = {
    'games': ['Warframe', 'eu4', 'GlobalOffensive', 'MaddenUltimateTeam', 'heroesofthestorm', 'EDH', 'KerbalSpaceProgram'],
    'female-focus': ['xxfitness', 'femalefashionadvice', 'TwoXChromosomes', 'AskWomen', 'breakingmom', 'BabyBumps'],
    'sports': ['MMA', 'reddevils', 'CFB', 'MLS'],
#     'support': ['stopdrinking', 'exjw'],
    'meme': ['justneckbeardthings', 'cringe'],
#     'gamergate': ['Kappa', 'KotakuInAction'],
    'generic': ['Advice', 'relationships', 'LifeProTips', 'explainlikeimfive', 'todayilearned'],
#     'photos': ['photography', 'EarthPorn'],
    'tech': ['pcmasterrace', 'techsupport', 'jailbreak', 'oculus']
    # fitness: ['xxfitness', 'bodybuilding']
}    
df_types = pd.DataFrame([(c,t) for t in comm_types for c in comm_types[t]], columns=['community', 'type']) # assumes single type/community
df_types = df_types.merge(pd.DataFrame(comms, columns=['community']), on='community', how='outer').set_index('community')
df_types['type'] = df_types['type'].fillna('other')

In [94]:
from sklearn.decomposition import PCA

def plot_pca(w):
    pca = PCA(n_components=2)
    pcs = pca.fit_transform(w)
    df_w = pd.DataFrame(pcs, index=comms, columns=['pca1', 'pca2'])
    df_w['type'] = df_types['type']
    df_w['size'] = 1
    return px.scatter(df_w.reset_index(), x='pca1', y='pca2', hover_name='index', color='type', size='size', size_max=10)

## Best LSTM (`lstm-3-1`)

In [99]:
plot_pca(comm_embed['lstm-3-1'])

## Best Transformer (`transformer-3-3`)

In [101]:
plot_pca(comm_embed['transformer-3-3'])

## Author co-occurance embedding

created with `comm_author_embed.py`

In [102]:
plot_pca(w_auth)

# Correlations in community embedding

In [103]:
from itertools import combinations
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import r2_score


def cos_sim(v1, v2):
    return (v1 * v2).sum(axis=0) / (np.linalg.norm(v1, axis=0) * np.linalg.norm(v2, axis=0))

def cos_sim_trunc(v1, v2, n):
    return cos_sim(v1[:n], v2[:n])

In [104]:
def plot_sim_scatter(w1, w2, n):
#     comm_pairs = pd.Series(combinations(range(len(comms)), 2))

    sims1, sims2 = [], []
    comms_col = []
    for c1 in range(len(comms)):
        for c2 in range(len(comms)):
            c1_name = comms[c1]
            c2_name = comms[c2]
            sims1.append(cos_sim_trunc(w1[c1], w1[c2], n))
            sims2.append(cos_sim_trunc(w2[c1], w2[c2], n))
            comms_col.append((c1_name, c2_name))
    print(f'R^2 = {r2_score(sims1, sims2)}')


    df_sims = pd.DataFrame(zip(sims1, sims2, comms_col), columns=['w1', 'w2', 'name'])

    return px.scatter(df_sims, x='w1', y='w2')

In [106]:
plot_sim_scatter(comm_embed['lstm-3-1'], comm_embed['transformer-3-3'], 16)

R^2 = 0.4444635971264811


In [107]:
plot_sim_scatter(w_auth, comm_embed['lstm-3-1'], 16)

R^2 = -4.479006244567077


In [108]:
plot_sim_scatter(w_auth, comm_embed['transformer-3-3'], 16)

R^2 = -4.430195727088445


# Entropy of the communty inference weights

In [109]:
# Load the community embeddeing layer of the conditional models

w_lstm_infer = torch.load(os.path.join(os.path.join(model_dir, 'lstm-3-1'), 'model.bin'),
            map_location='cpu')['comm_inference.weight'].softmax(1).numpy()
w_trns_infer = torch.load(os.path.join(os.path.join(model_dir, 'transformer-3-3'), 'model.bin'),
            map_location='cpu')['comm_inference.weight'].softmax(1).numpy()

In [110]:
from scipy.special import entr

def entropy(v):
    return np.exp(entr(v).sum())
    
def entropy_df(w1, w2):
    e1, e2 = [], []
    for c1 in range(len(comms)):
        e1.append(entropy(w1[c1]))
        e2.append(entropy(w2[c1]))

    df = pd.DataFrame(zip(e1, e2), columns=['w1', 'w2'], index=comms)
    return df

df_e = entropy_df(w_lstm_infer, w_trns_infer)

In [111]:
px.scatter(df_e.reset_index(), x='w1', y='w2', hover_name='index')

# Bayesian community inference

$$
P(c|m) = \frac{p(m|c) \cdot p(m)}{p(c)}
$$

In [403]:
def entropy(v):
    return entr(v).sum()

def model_comm_confusion_matrix(model_name):
    """ C[i,j] = average_{Posts(cj)}(P(c=ci|m))"""
    P = pd.read_pickle(os.path.join(os.path.join(model_dir, model_name), 'comm_probs.pickle'))
    C = P.groupby('actual_comm').mean()
    C = C.T # transpose to (prob assigned, actual comm), as in the paper
    C = C.sort_index() # sort the rows alphabetically
    C = C[C.index] # sort the columns alphabetically too
    return C
    

def confusion_matrix_ppls(P):
    return np.exp(confusion_matrix_entropies(P))

def plot_confusion_matrix(P):
    fig = go.Figure(
    data=go.Heatmap(z=P, a=sorted(comms))
    )
    fig.update_layout(height=800, width=800, font=dict(size=8))
    fig.show()

def plot_model_scatter(M1, M2):
    df = pd.DataFrame([M1, M2]).T
    df.index.name = 'community'
    df = df.reset_index()
    fig = px.scatter(df, x=M1.name, y=M2.name, hover_name='community')
    return fig

In [404]:
model1, model2 = 'lstm-3-1', 'transformer-3-3'
E1 = model_comm_confusion_matrix(model1).apply(entropy).rename(model1 + ' Entropy')
E2 = model_comm_confusion_matrix(model2).apply(entropy).rename(model2 + ' Entropy')
plot_model_scatter(E1, E2).show()

In [405]:
model, attribute = 'lstm-3-1', 'stability_active'
E = model_comm_confusion_matrix(model).apply(entropy).rename(model + ' Entropy')
plot_model_scatter(E, df_comm_attrs['stability_active'])

In [409]:
model, attribute = 'lstm-3-1', 'size_active_15'
E = model_comm_confusion_matrix(model).apply(entropy).rename(model + ' Entropy')
plot_model_scatter(E, df_comm_attrs[attribute])

In [416]:
model, attribute = 'lstm-3-1', 'clustering_full_15'
E = model_comm_confusion_matrix(model).apply(entropy).rename(model + ' Entropy')
plot_model_scatter(E, df_comm_attrs[attribute])

### Confusion matrices

In [439]:
# Plot confusion matrix as heatmap
# x-axis: actual community
# y-axis: avg. probability assigned to that community
import plotly.graph_objects as go

def plot_confusion(model):
    C = model_comm_confusion_matrix(model)
    fig = go.Figure(
        data=go.Heatmap(
            z=C.values,
            y = C.columns,
            x = C.index,
                       ))
    fig.update_layout(height=800, width=800, font=dict(size=8), title=model)
    return fig.show()


In [440]:
plot_confusion('lstm-3-1')

In [442]:
plot_confusion('transformer-3-3')

In [None]:
df = df_lstm.groupby('actual_comm').mean()
df = df[df.index].T # sort the colums to match the rows
fig = go.Figure(
    data=go.Heatmap(
        z=df.to_numpy(),
        y = list(df.columns),
        x = list(df.index)
                   ))
fig.show()

In [68]:
pd.options.display.float_format = '{:,.4f}'.format

from scipy.special import entr

def ppl(v):
    return np.exp(entr(v).sum())

df_lstm.groupby('actual_comm').mean().apply(ppl).sort_values()

EDH                   35.3010
jailbreak             39.3753
MaddenUltimateTeam    40.1841
techsupport           40.2788
eu4                   40.4830
stopdrinking          40.6867
Warframe              40.9455
femalefashionadvice   41.1717
photography           41.2582
BabyBumps             41.2801
heroesofthestorm      41.8897
xxfitness             42.2935
reddevils             42.2988
KerbalSpaceProgram    43.7442
GameDeals             43.9283
Fantasy               44.2348
rupaulsdragrace       44.7113
streetwear            44.8271
airsoft               45.2677
relationships         45.5071
Drugs                 45.6594
cars                  45.9839
oculus                45.9900
EarthPorn             46.5890
CFB                   46.7759
breakingmom           46.8197
GlobalOffensive       46.8576
exjw                  46.9145
MLS                   47.1426
Advice                47.1684
bodybuilding          47.1902
MMA                   47.3872
TwoXChromosomes       48.7221
pcmasterra

In [67]:
entr(np.array([0.5,0.5])).sum()

0.6931471805599453

In [75]:
ppl(np.array([1]*46)/46)

45.999999999999964

In [109]:
df = df_lstm[df_lstm['actual_comm'] == 'EDH'][comms]

In [111]:
ppl(normalize(df.sum()))

45.808342

In [113]:
df.sum().sum()

5000.001

EDH                   45.7639
MaddenUltimateTeam    45.8241
eu4                   45.8311
jailbreak             45.8448
Warframe              45.8665
reddevils             45.8711
streetwear            45.8746
techsupport           45.8839
heroesofthestorm      45.8875
GlobalOffensive       45.8915
KerbalSpaceProgram    45.8929
CFB                   45.8954
GameDeals             45.8970
BabyBumps             45.9015
MLS                   45.9042
rupaulsdragrace       45.9047
femalefashionadvice   45.9051
Kappa                 45.9073
xxfitness             45.9136
pcmasterrace          45.9138
MMA                   45.9139
stopdrinking          45.9150
photography           45.9254
oculus                45.9255
airsoft               45.9322
cars                  45.9363
bodybuilding          45.9370
Drugs                 45.9383
EarthPorn             45.9462
Fantasy               45.9470
breakingmom           45.9512
exjw                  45.9533
relationships         45.9542
AskWomen  

In [116]:
df_lstm

Unnamed: 0,streetwear,Jokes,MaddenUltimateTeam,cringe,CFB,food,pcmasterrace,jailbreak,MLS,MMA,...,Fantasy,femalefashionadvice,breakingmom,BabyBumps,explainlikeimfive,xxfitness,TwoXChromosomes,Advice,relationships,actual_comm
0,0.0235,0.0175,0.0167,0.0227,0.0181,0.0216,0.0216,0.0179,0.0207,0.0191,...,0.0205,0.0262,0.0249,0.0252,0.0236,0.0241,0.0242,0.0242,0.0230,BabyBumps
1,0.0196,0.0203,0.0141,0.0205,0.0207,0.0215,0.0149,0.0173,0.0212,0.0180,...,0.0218,0.0311,0.0293,0.0341,0.0238,0.0298,0.0279,0.0243,0.0248,BabyBumps
2,0.0168,0.0289,0.0179,0.0245,0.0196,0.0236,0.0186,0.0183,0.0203,0.0205,...,0.0213,0.0199,0.0230,0.0210,0.0246,0.0230,0.0227,0.0225,0.0271,BabyBumps
3,0.0199,0.0233,0.0192,0.0222,0.0219,0.0232,0.0194,0.0175,0.0200,0.0222,...,0.0239,0.0253,0.0255,0.0218,0.0238,0.0216,0.0227,0.0230,0.0250,BabyBumps
4,0.0204,0.0224,0.0147,0.0225,0.0212,0.0252,0.0191,0.0200,0.0193,0.0200,...,0.0215,0.0223,0.0274,0.0268,0.0226,0.0270,0.0271,0.0236,0.0243,BabyBumps
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
229995,0.0304,0.0346,0.0200,0.0298,0.0220,0.0197,0.0174,0.0173,0.0181,0.0149,...,0.0220,0.0162,0.0223,0.0253,0.0183,0.0143,0.0164,0.0219,0.0264,rupaulsdragrace
229996,0.0205,0.0289,0.0094,0.0350,0.0151,0.0309,0.0170,0.0167,0.0156,0.0285,...,0.0211,0.0209,0.0295,0.0271,0.0186,0.0233,0.0211,0.0208,0.0208,rupaulsdragrace
229997,0.0197,0.0199,0.0211,0.0225,0.0225,0.0228,0.0186,0.0202,0.0191,0.0227,...,0.0205,0.0214,0.0245,0.0211,0.0230,0.0205,0.0261,0.0248,0.0268,rupaulsdragrace
229998,0.0253,0.0231,0.0191,0.0235,0.0184,0.0212,0.0194,0.0214,0.0182,0.0197,...,0.0210,0.0264,0.0221,0.0216,0.0193,0.0227,0.0232,0.0200,0.0191,rupaulsdragrace


In [135]:
entropy(df_lstm.loc[0][comms].to_numpy().astype(float))

45.650789854761236