In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
import fasttext
import numpy as np
from model import WADataset, WordAssociationPredictionModel, embed


# Load the trained age-group classifier weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU session load on a
# machine without CUDA (otherwise torch.load tries to restore tensors to the GPU).
model = WordAssociationPredictionModel()
model.load_state_dict(torch.load("./data/trained_model.pt", map_location="cpu"))
# Evaluation mode for inference; the cell output shows the layer stack.
model.eval()

WordAssociationPredictionModel(
  (activation): ReLU()
  (linear1): Linear(in_features=600, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=512, bias=True)
  (linear3): Linear(in_features=512, out_features=512, bias=True)
  (linear4): Linear(in_features=512, out_features=256, bias=True)
  (linear5): Linear(in_features=256, out_features=4, bias=True)
)

In [2]:
# fastText embedding model (judging by the filename: Common Crawl, 300-d, subword vectors).
EMBEDDING_MODEL_PATH = './crawl-300d-2M-subword/crawl-300d-2M-subword.bin'
embedding_model = fasttext.load_model(EMBEDDING_MODEL_PATH)



In [3]:
# Sanity check that the fastText model loaded (displays its repr).
embedding_model

<fasttext.FastText._FastText at 0x134bb5650>

In [4]:
# Smoke test for `embed`: two cues plus a list of three response slots (R1, R2, R3),
# each slot aligned with the cues - the same layout participant_predict passes.
# Empty strings presumably stand in for missing responses.
sample_cues = ["apple", "orange"]
sample_responses = [["grape", "grape"], ["banana", ""], ["", ""]]
test = embed(embedding_model, [sample_cues, sample_responses])
test

tensor([[-0.0137,  0.0132,  0.0522,  ..., -0.0115, -0.0015,  0.0099],
        [-0.0142,  0.0141,  0.0280,  ..., -0.0117, -0.0027,  0.0058]])

In [5]:
model(test)  # unnormalised scores per age group: [<30, 30-49, 50-69, 70+]

tensor([[ 1.0457,  0.1358, -0.6141, -2.3175],
        [ 0.7250,  0.0983, -0.3082, -1.7954]], grad_fn=<AddmmBackward0>)

In [6]:
# Preprocessed SWOW-EN responses: one row per cue shown to a participant
# (see head() below), with up to three responses per cue.
SWOW_CSV_PATH = "./preprocessing/SWOW-EN.complete_preprocessed.csv"
snow = pd.read_csv(SWOW_CSV_PATH)

In [7]:
# Peek at the first rows of the loaded table.
snow.head()

Unnamed: 0,participantID,country,age,gender,cue,R1Raw,R2Raw,R3Raw,amount
0,64960,United States,24,Ma,lawful,dutiful,square,illegal,3
1,95843,United States,32,Fe,browse,web,surf,look,3
2,66020,United States,26,Fe,Japan,bomb,rice,fish,3
3,34329,United States,38,Ma,contain,within,has,nuclear,3
4,1250,United States,21,Ma,pull,push,drag,slide,3


In [8]:
# Distribution of how many cues (non-null) each participant answered.
responses_per_participant = snow.groupby("participantID")["cue"].count()
responses_per_participant.describe()

count    75923.000000
mean        14.806725
std          1.974502
min          1.000000
25%         14.000000
50%         14.000000
75%         16.000000
max         18.000000
Name: cue, dtype: float64

Participants submitted different numbers of responses, so to evaluate *per-participant* accuracy we keep only participants who gave 14 responses (the median number).

In [9]:
# participantID -> number of (non-null) cues that participant answered.
n_responses_dict = (
    snow.groupby("participantID")["cue"]
    .count()
    .to_dict()
)

In [10]:
# Attach each participant's response count to every one of their rows.
snow = snow.assign(
    n_responses=lambda frame: frame["participantID"].map(n_responses_dict)
)

In [11]:
# Keep only participants who answered the median number of cues (14, per the
# describe() output above) so every retained participant has the same trial count.
N_RESPONSES_KEPT = 14
snow_subset = snow[snow["n_responses"] == N_RESPONSES_KEPT]

In [12]:
def participant_predict(cues, r1, r2, r3):
    """Score a batch of trials with the age-group model.

    Parameters
    ----------
    cues, r1, r2, r3 : pandas.Series
        Row-aligned columns of cue words and the first/second/third responses.
        ``.values`` is taken from each before embedding.

    Returns
    -------
    torch.Tensor
        One row of unnormalised scores per trial and one column per age group.
        Later cells read the columns in the order <30, 30-50, 50-70, 70+.
    """
    features = embed(embedding_model, [cues.values, [r1.values, r2.values, r3.values]])
    # Inference only: disable autograd so the forward pass over every trial does
    # not build (and hold in memory) a gradient graph nobody will use.
    with torch.no_grad():
        results = model(features)
    return results

In [13]:
# Score every trial of the retained participants in one batch.
res = participant_predict(
    *(snow_subset[column] for column in ("cue", "R1Raw", "R2Raw", "R3Raw"))
)

In [14]:
# Convert the model output into a NumPy array for the scipy/pandas steps below.
detached_scores = res.detach()
res = detached_scores.numpy()

In [15]:
import scipy
from scipy.special import softmax
# Turn each trial's four scores into probabilities. axis=1 applies the softmax
# row by row; without it scipy normalises over the *entire* array, so individual
# rows would not sum to 1.
res = softmax(res, axis=1)

In [16]:
# Rescale each row so a trial's four probabilities sum to 1.
row_totals = res.sum(axis=1, keepdims=True)
res = res / row_totals

In [17]:
# Per-trial probability of each age group, plus the index (0-3) of the most likely group.
age_group_columns = ["predictionU30", "prediction3050", "prediction5070", "prediction70U"]
snow_subset = snow_subset.assign(
    **{column: res[:, idx] for idx, column in enumerate(age_group_columns)},
    predictedAgeGroupPerTrial=np.argmax(res, axis=1),
)

In [18]:
def checkAge(row):
    """Return True when a trial's predicted age group matches the participant's age.

    Groups: 0 = under 30, 1 = 30-49, 2 = 50-69, 3 = 70 and over.
    ``row`` needs ``age`` and ``predictedAgeGroupPerTrial`` attributes.
    A missing (NaN) age never matches.
    """
    age = row.age
    if age < 30:
        expected_group = 0
    elif age < 50:
        expected_group = 1
    elif age < 70:
        expected_group = 2
    elif age >= 70:
        expected_group = 3
    else:
        # NaN fails every comparison above.
        return False
    return bool(row.predictedAgeGroupPerTrial == expected_group)


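Overall per-trial accuracy (fraction of trials whose predicted age group matches the participant's actual age group)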

In [19]:
# Share of all retained trials whose predicted group is correct.
trial_is_correct = snow_subset.apply(checkAge, axis=1)
trial_is_correct.mean()

0.4704272009145943

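Per-participant accuracy, computed two ways: (1) the fraction of each participant's trials whose per-trial prediction is correct, and (2) a single participant-level prediction, taken as the age group that receives the highest probability in any one of that participant's trials (not the most frequently predicted group).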

In [20]:
# participantID -> fraction of that participant's trials predicted correctly.
d_prediction_dict = (
    snow_subset.groupby("participantID")
    .apply(lambda trials: trials.apply(checkAge, axis=1).mean())
    .to_dict()
)

In [21]:
def getMaxPrediction(row):
    """Return the age group (0-3) with the single highest probability in ``row``.

    ``row`` is typically all trials of one participant. For each age-group
    column, take the maximum probability over the trials, then return the index
    of the largest of those maxima. Ties go to the lower index.
    """
    group_columns = ("predictionU30", "prediction3050", "prediction5070", "prediction70U")
    peak_per_group = [np.max(row[column]) for column in group_columns]
    return np.argmax(peak_per_group)



In [22]:
# participantID -> single participant-level predicted age group (0-3).
argmaxPredictions = (
    snow_subset.groupby("participantID")
    .apply(getMaxPrediction)
    .to_dict()
)

In [23]:
# Broadcast each participant's per-trial accuracy onto all of their rows.
# (Column spelling kept: a later cell reads "avgParticpantPredictionAcc".)
snow_subset = snow_subset.assign(
    avgParticpantPredictionAcc=snow_subset["participantID"].map(d_prediction_dict)
)

In [24]:
# Broadcast each participant's single predicted age group onto all of their rows.
# Despite the "Acc" suffix, this column holds a predicted group index (0-3), not an
# accuracy. The name is kept because a later comparison cell reads it.
snow_subset = snow_subset.assign(
    maxParticipantPredictionAcc=snow_subset["participantID"].map(argmaxPredictions)
)

In [25]:
def getParticipantAgeGroup(age):
    """Map an age to its group index.

    Groups: 0 = under 30, 1 = 30-49, 2 = 50-69, 3 = 70 and over.
    Returns None when ``age`` matches no group (e.g. NaN).
    """
    exclusive_upper_bounds = (30, 50, 70)
    for group_index, upper_bound in enumerate(exclusive_upper_bounds):
        if age < upper_bound:
            return group_index
    if age >= exclusive_upper_bounds[-1]:
        return 3
    return None

In [26]:
# Actual age group (0-3) of each trial's participant.
snow_subset = snow_subset.assign(
    participantAgeGroup=snow_subset["age"].apply(getParticipantAgeGroup)
)

In [27]:
# Mean per-participant accuracy within each actual age group. Every retained
# participant has the same number of rows (14), so averaging over rows weights
# participants equally (assuming a participant's age is constant across their rows).
snow_subset.groupby("participantAgeGroup")["avgParticpantPredictionAcc"].mean()

participantAgeGroup
0    0.883661
1    0.083949
2    0.282344
3    0.000000
Name: avgParticpantPredictionAcc, dtype: float64

In [36]:
# Within each actual age group: how often the single participant-level prediction
# matches the participant's actual group. Rows are weighted equally, which also
# weights participants equally since each has the same number of rows.
(
    snow_subset
    .assign(maxPredictionAcc=lambda df: df["maxParticipantPredictionAcc"] == df["participantAgeGroup"])
    .groupby("participantAgeGroup")["maxPredictionAcc"]
    .mean()
)

participantAgeGroup
0    0.979502
1    0.000886
2    0.207878
3    0.000000
Name: maxPredictionAcc, dtype: float64

In [29]:
# Inspect a few rows of the enriched table rather than rendering the whole frame.
snow_subset.head()

Unnamed: 0,participantID,country,age,gender,cue,R1Raw,R2Raw,R3Raw,amount,n_responses,predictionU30,prediction3050,prediction5070,prediction70U,predictedAgeGroupPerTrial,avgParticpantPredictionAcc,maxParticipantPredictionAcc,participantAgeGroup
0,64960,United States,24,Ma,lawful,dutiful,square,illegal,3,14,0.310465,0.339911,0.303626,0.045998,1,0.785714,0,0
2,66020,United States,26,Fe,Japan,bomb,rice,fish,3,14,0.476169,0.388888,0.125822,0.009122,0,1.000000,0,0
3,34329,United States,38,Ma,contain,within,has,nuclear,3,14,0.539082,0.267073,0.160471,0.033374,0,0.000000,0,1
4,1250,United States,21,Ma,pull,push,drag,slide,3,14,0.181081,0.245244,0.426000,0.147675,2,0.285714,0,0
6,42323,United States,28,Fe,shade,grey,blinds,cove,3,14,0.564290,0.292399,0.126269,0.017043,0,1.000000,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1124261,36727,France,20,Fe,patriotism,Flag,Eagle,Stars and Stripes,3,14,0.599884,0.325061,0.071077,0.003978,0,1.000000,0,0
1124262,35503,France,27,Fe,gallop,horse,race,poll,3,14,0.564305,0.287128,0.129580,0.018987,0,1.000000,0,0
1124263,72390,France,24,Fe,picket,line,strike,No more responses,2,14,0.442822,0.367667,0.172795,0.016717,0,1.000000,0,0
1124264,88127,France,36,Fe,tripod,camera,photo,lab,3,14,0.546320,0.332402,0.111609,0.009668,0,0.071429,0,1
