In [None]:
import tqdm
import pickle, json
import numpy  as np
import pandas as pd
import torch
from itertools import chain
from nltk.corpus import wordnet as wn

from src.utils import tosn, get_imcount, get_wrdcount, get_split
from src.semantics import GloVe, get_glove, base_glove, select_lemmas
from src.tree import get_tree, select_subset, select_split
from src.zsl import *

import warnings
# NOTE(review): this blanket filter hides every warning for the rest of the
# session -- including pandas chained-assignment / SettingWithCopyWarning --
# so silent no-op writes go unnoticed. Consider narrowing it to the specific
# noisy categories instead of 'ignore' for everything.
warnings.filterwarnings('ignore')

# Plot styling for every figure in this notebook.
import matplotlib.pyplot as plt
plt.style.use('dark_background')
%matplotlib inline

In [None]:
def ilsvrc_sem_df(synswords, glove):
    """
        FOR ILSVRC training set

        Build the semantic (GloVe-based) frame: the rows produced by
        get_glove for the selected lemmas are stacked above the rows produced
        by base_glove for the ILSVRC training split. Any index label that
        appears more than once keeps only its first row, so get_glove's
        entries take precedence.
    """
    frames = [get_glove(synswords, glove), base_glove(get_split("train"), glove)]
    stacked = pd.concat(frames)
    is_repeat = stacked.index.duplicated(keep="first")
    return stacked.loc[~is_repeat]

def ilsvrc_tree_df(synswords, imcounts, min_photo, train_synsets=None):
    """
        FOR ILSVRC training set

        Build the per-synset flag table (passed to get_tree in this notebook).

        Parameters
        ----------
        synswords : pd.Series
            Selected lemmas per synset; an empty collection means the synset
            has no usable lemma.
        imcounts : pd.Series
            Image count per synset.
        min_photo : int
            A synset is flagged "test" only if it has strictly more than
            `min_photo` images AND at least one lemma.
        train_synsets : list-like, optional
            Synsets flagged "train". Defaults to get_split("train").
            Ids not present in the index are ignored.

        Returns
        -------
        pd.DataFrame
            Columns: photo, lemmas, sem, train, test, image.
    """
    if train_synsets is None:
        train_synsets = get_split("train")

    df = pd.DataFrame({"photo": imcounts, "lemmas": synswords})
    df["sem"] = df["lemmas"].apply(lambda x: len(x) > 0)
    # Vectorised membership test. The previous chained form
    # `df["train"][ids] = True` wrote through a possible copy: the warning was
    # silenced globally, and under pandas copy-on-write nothing is set at all.
    df["train"] = df.index.isin(train_synsets)
    df["test"] = (df["photo"] > min_photo) & df["sem"]
    df["image"] = True
    return df

In [2]:
# Word-count table used to filter lemmas (alternative table: "./data/new_counts").
# The path gets its own name so `wrdcounts` is only ever the loaded table.
WRDCOUNTS_PATH = "./data/materials/wiki_counts"

# GloVe() with no arguments; an earlier variant passed a cache dir and type:
# GloVe("../../data/word_embeddings/glove_vec", "6B").
glove     = GloVe()
imcounts  = get_imcount()
wrdcounts = get_wrdcount(WRDCOUNTS_PATH)
synsets   = imcounts.index.tolist()

# Lemma selection criteria passed to select_lemmas. The word_cond keys
# presumably bound the word count ("min"/"max") and require GloVe vocabulary
# membership ("voc") -- TODO confirm against select_lemmas.
lem_cond  = {"wn": "base"}  # optionally also "we": True
word_cond = {"min": 500, "max": 1000000, "voc": glove.isin}
synswords = select_lemmas(synsets, lem_cond, word_cond, wrdcounts)

100%|██████████| 21845/21845 [00:08<00:00, 2497.00it/s]

4273 correct lemmas. (36468 lemmas removed). 18058 synsets removed for lack of correct lemmas





In [4]:
# A synset needs strictly more than MIN_PHOTO images (and a lemma) to be
# flagged "test" in df_tree.
MIN_PHOTO = 100


def te_selecta(x):
    """Test-node predicate handed to select_split: keep nodes whose "type" is 2 or 3."""
    return x["type"] in (2, 3)


df_tree = ilsvrc_tree_df(synswords, imcounts, MIN_PHOTO)
df_sem  = ilsvrc_sem_df(synswords, glove)
tree    = get_tree(df_tree)
trnodes = get_split("train")
tenodes = select_split(tree, te_selecta, set(trnodes))
# `dico` maps a node to its row/column in the pairwise distance matrix `dist`.
dist, dico = tree.pairwise_dist()
len(tenodes)

82115it [00:02, 35353.77it/s]
100%|██████████| 82115/82115 [00:00<00:00, 125326.73it/s]
100%|██████████| 82115/82115 [00:01<00:00, 80624.65it/s]


82115 84427 [60270, 17864, 2982, 999]
5611 5816 [674, 956, 2982, 999]


  1%|          | 11/999 [00:00<00:09, 101.11it/s]

4695 4897 [257, 457, 2982, 999]


100%|██████████| 999/999 [00:09<00:00, 100.69it/s]
100%|██████████| 4695/4695 [00:22<00:00, 212.66it/s]
4695it [00:42, 110.08it/s]


2229

In [None]:
# Draw a reproducible random subset of the candidate test synsets.
# Sorting first matters: if `tenodes` is a set of strings, its iteration order
# depends on PYTHONHASHSEED, so seeding the generator alone would not make the
# sample stable across kernel restarts. (Assumes node ids are mutually
# comparable, e.g. wnid strings.)
# Alternative: a fixed split instead, e.g. tenodes_ = get_split("2-hops")
SAMPLE_SEED = 0
nte = 1000
_pool = sorted(tenodes)
# Cap at the pool size; choice(..., replace=False) raises if asked for more.
tenodes_ = np.random.default_rng(SAMPLE_SEED).choice(_pool, min(nte, len(_pool)), replace=False)

In [102]:

# Tree-distance diagnostics for the sampled test synsets.
# The diagonal of `dist` is overwritten in place with a sentinel, so a synset's
# zero distance to itself never wins the column-wise `min` below.
# NOTE(review): SELF_DIST is assumed to exceed every real tree distance -- confirm.
SELF_DIST = 100
_diag = np.arange(dist.shape[0])
dist[_diag, _diag] = SELF_DIST

# Normalise the sample under its own name. Do not rebind `tenodes`: it holds
# the full candidate pool that the sampling cell draws from.
tenodes_ = np.asarray(tenodes_)
tenodes_idx = np.asarray([dico[x] for x in tenodes_])
trnodes_idx = np.asarray([dico[x] for x in trnodes])

# Rows: training (resp. sampled test) synsets; columns: sampled test synsets.
# np.ix_ extracts the sub-block directly instead of first materialising every
# selected full row of `dist`.
# (Assumes `dist` is a numpy array.)
dst_tr = dist[np.ix_(trnodes_idx, tenodes_idx)]
dst_te = dist[np.ix_(tenodes_idx, tenodes_idx)]

mintr  = dst_tr.min(0)
meantr = dst_tr.mean(0)
minte  = dst_te.min(0)
# Mean distance to the *other* sampled test synsets: subtract each column's
# self entry (the diagonal sentinel) rather than averaging it in.
meante = (dst_te.sum(0) - np.diag(dst_te)) / (dst_te.shape[0] - 1)

In [None]:
# Grid search over the two scalar hyper-parameters of get_W (presumably
# regularisation weights -- TODO confirm). Keeps the pair whose topk() result
# has the largest first entry.
# NOTE(review): candidates are scored on the test synsets themselves, so the
# chosen (G, L) and the accuracy reported with it are optimistic. A held-out
# validation split of the training classes would avoid this.
REG_GRID = [0, 1, 10, 100, 1000]

v_tr = get_visu(trnodes)
s_tr = get_sem(trnodes, df_sem)
v_te, s_te, l_te = get_data(tenodes_, df_sem)

best = None
for g in REG_GRID:
    for l in REG_GRID:
        w = get_W(v_tr, s_tr, g, l)
        res = topk(scores(v_te, s_te, w), 5, l_te)
        # The first candidate is always accepted, so G and L are bound even
        # when every candidate scores 0. (A sentinel of [0] would leave them
        # undefined, and the next cell would raise NameError.)
        if best is None or res[0] > best[0]:
            best = res
            G, L = g, l

In [None]:
# Refit W with the grid-search winners (G, L); `res` holds the test-set scores
# used by the per-class analysis below.
w = get_W(v_tr, s_tr, G, L)
res = scores(v_te, s_te, w)

In [None]:
def _per_test_node(values):
    """Wrap a vector with one entry per sampled test synset as a Series indexed by `tenodes_`."""
    return pd.Series(values, index=tenodes_)


# Per-class score, plus tree-distance summaries for every sampled test synset.
s      = _per_test_node(acc_per_class(res, l_te, 1))
mintr  = _per_test_node(dst_tr.min(0))
meantr = _per_test_node(dst_tr.mean(0))
minte  = _per_test_node(dst_te.min(0))
# Mean over the *other* test synsets: remove each column's self entry (the
# diagonal, which holds a sentinel rather than a real distance) before
# averaging. Correct whatever value the diagonal holds.
meante = _per_test_node((dst_te.sum(0) - np.diag(dst_te)) / (dst_te.shape[0] - 1))

In [None]:
# One row per sampled test synset, gathering score, metadata and distance stats.
# (Lemmas could alternatively be read straight from WordNet:
#  s.index.to_series().apply(lambda x: tuple(tosn(x).lemma_names())))
lems = synswords.loc[s.index]
imc = imcounts.loc[s.index]
# Largest word count among a synset's lemmas, capped at 10000.
wrdc = lems.apply(lambda lemmas: min(max(wrdcounts[lemma] for lemma in lemmas), 10000))
df = pd.DataFrame(dict(
    score=s,
    lems=lems,
    imc=imc,
    wrdc=wrdc,
    mintr=mintr,
    minte=minte,
    meantr=meantr,
    meante=meante,
    ratio=mintr / minte,
))

In [None]:
# Does a test synset score better when another test synset sits close to it in the tree?
ax = df.plot(y="score", x="minte", kind="scatter", alpha=0.5)
ax.set(xlabel="min tree distance to another sampled test synset (minte)",
       ylabel="per-class score (acc_per_class)",
       title="Per-class score vs. distance to nearest test synset");