In [1]:
import numpy as np
from numpy import dot
from numpy.linalg import norm
import random
import pickle
from tqdm import tqdm
import gc
from itertools import combinations

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
from keras import models, layers, losses, optimizers, regularizers, Model

In [2]:
# Run a full garbage-collection pass before loading data; the return value
# (number of unreachable objects found) is shown as this cell's output.
gc.collect()

0

In [3]:
def load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is always closed,
    even if unpickling fails.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at files produced by this project, never at untrusted input.
    with open(path, "rb") as handle:
        return pickle.load(handle)


brain_index = load_pickle("../data/support.pkl")
nouns, verbs = load_pickle("../data/vecs.pkl")

# Concatenate the records of every pickle shard (currently only shard 0).
# Each record is indexed as record[0] / record[1]; record[1] is matched
# against the entries of `nouns` -- presumably (scan, noun) pairs, TODO confirm.
records = [
    record
    for shard in range(1)
    for record in load_pickle(f"../data/pickles/{shard}.pkl")
]

# Group the records per noun, following the iteration order of `nouns`.
# (No prior sort is needed: filtering keeps each noun's records in file order.)
records_by_noun = [[record for record in records if record[1] == noun] for noun in nouns]

# Fail with a clear message instead of an obscure error when a noun has no data.
missing_nouns = [noun for noun, group in zip(nouns, records_by_noun) if not group]
if missing_nouns:
    raise ValueError(f"no recordings found for nouns: {missing_nouns}")

# Average all recordings of a noun, then keep only the entries selected by
# `brain_index`. Result: one (vector, noun) tuple per noun.
pickles = [
    ((np.add.reduce([record[0] for record in group]) / len(group))[brain_index], group[0][1])
    for group in records_by_noun
]

len(nouns), len(verbs)

(60, 25)

In [4]:
def l2(a, b):
    """Return the L2 norm of the element-wise difference ``a - b``.

    Both arguments may be any array-likes that ``np.subtract`` accepts; for
    1-D inputs the result is the Euclidean distance between them.
    """
    delta = np.subtract(a, b)
    return norm(delta)

In [5]:
class BasisSum(Model):
    """Keras model that predicts a weighted sum of fixed verb vectors.

    The input passes through two ReLU dense layers (64 and 32 units) and a
    sigmoid layer with one unit per row of ``self.basis``. The sigmoid
    activations are rescaled to sum to 1 along the last axis and then used as
    mixing weights over the frozen basis rows.
    """

    def __init__(self):
        super().__init__()
        # One row per entry of the notebook-level `verbs` mapping, in its
        # iteration order; kept non-trainable so only the dense layers learn.
        verb_vectors = [verbs[name] for name in verbs]
        self.basis = tf.Variable(
            tf.convert_to_tensor(verb_vectors),
            trainable=False,
            name="verb_basis",
        )
        self.d1 = layers.Dense(64, activation="relu")
        self.d2 = layers.Dense(32, activation="relu")
        self.dn = layers.Dense(self.basis.shape[0], activation="sigmoid")

    @tf.function(reduce_retracing=True)
    def call(self, x):
        """Map a rank-2 batch ``x`` to mixtures of the basis rows."""
        hidden = self.d2(self.d1(x))
        weights = self.dn(hidden)
        # Sigmoid outputs are strictly positive, so this yields weights that
        # sum to 1 for every sample.
        weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)

        # (batch, n_basis) x (n_basis, dim) -> (batch, dim)
        return tf.einsum("bi,ij->bj", weights, self.basis)

In [6]:
# Leave-two-out evaluation: for every pair of items, train a fresh BasisSum on
# the remaining items (with mixup-style augmentation) and check whether the two
# held-out predictions are closer to their own targets than to each other's.

# --- configuration -----------------------------------------------------------
batch_size = 64
n_steps = 5000          # optimisation steps per fold
log_every = 100         # refresh the running training loss every N steps
learning_rate = 0.001
n_items = len(pickles)  # derived from the data instead of hard-coding 60

# --- data --------------------------------------------------------------------
x = np.array([item[0] for item in pickles])
y = np.array([nouns[item[1]] for item in pickles])
x, y = tf.cast(x, tf.dtypes.float32), tf.cast(y, tf.dtypes.float32)

# Materialise the folds (n_items * (n_items - 1) / 2 small tuples) so tqdm
# knows the total and can display percentage and ETA.
folds = list(combinations(range(n_items), n_items - 2))
pbar = tqdm(folds)
correct_count = 0

# The loss object is stateless, so a single instance serves every fold.
loss = losses.MeanSquaredError()

for i, comb in enumerate(pbar):
    # Indices of the two held-out items.
    comp = sorted(set(range(n_items)) - set(comb))

    model = BasisSum()
    opt = optimizers.Adam(learning_rate)

    train_x, test_x = tf.gather(x, comb), tf.gather(x, comp)
    train_y, test_y = tf.gather(y, comb), tf.gather(y, comp)
    n_train = tf.shape(train_x)[0]

    batchlosses = []
    for j in range(n_steps):
        # Mixup: every sample is a random convex blend of two training items,
        # with inputs and targets blended by the same ratio.
        idx1 = tf.random.uniform(shape=[batch_size], minval=0, maxval=n_train, dtype=tf.int32)
        idx2 = tf.random.uniform(shape=[batch_size], minval=0, maxval=n_train, dtype=tf.int32)

        batch_x1, batch_y1 = tf.gather(train_x, idx1), tf.gather(train_y, idx1)
        batch_x2, batch_y2 = tf.gather(train_x, idx2), tf.gather(train_y, idx2)

        ratios = tf.random.uniform((batch_size, 1), 0, 1)
        batch_x = batch_x1 * ratios + batch_x2 * (1 - ratios)
        batch_y = batch_y1 * ratios + batch_y2 * (1 - ratios)

        # Only the forward pass needs to be recorded; gradients and the
        # optimiser update are computed after leaving the tape context.
        with tf.GradientTape() as tape:
            pred_y = model(batch_x)
            batchloss = loss(batch_y, pred_y)
        grads = tape.gradient(batchloss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        batchlosses.append(float(batchloss))

        if (j + 1) % log_every == 0:
            # Mean over the most recent window, shown next to the accuracy.
            recent = batchlosses[-log_every:]
            pbar.set_postfix(loss=f"{sum(recent) / len(recent):.4f}")

    # 2-vs-2 test: the fold is correct when matching each prediction with its
    # own target gives a smaller total distance than the swapped assignment.
    pred = model(test_x)
    t1, t2 = (row.ravel() for row in test_y.numpy())
    p1, p2 = (row.ravel() for row in pred.numpy())

    correct = l2(t1, p1) + l2(t2, p2)
    incorrect = l2(t1, p2) + l2(t2, p1)

    correct_count += int(correct < incorrect)

    pbar.set_description(f"accuracy: {correct_count / (i + 1):.3f}")

0it [00:00, ?it/s]

0.003373884856700897
0.26873515754938126
0.25013391152024267
0.24338313892483712
0.2459697562456131
0.2433212825655937
0.24274668991565704
0.23745998457074166
0.23517379745841027
0.23326953530311584
0.23542273700237273
0.2250434625148773
0.21668407633900644
0.21342742666602135
0.21205770686268807
0.2107598027586937
0.20970246940851212
0.20806237146258355
0.2072278454899788
0.20947147101163865
0.20867252543568612
0.20649966582655907
0.20681284323334695
0.2087359519302845
0.2075600107014179
0.20598743930459024
0.20635671481490137
0.20699978485703469
0.20754908084869383
0.2078784492611885
0.2082293750345707
0.20688654124736786
0.2075345528125763
0.20652731388807297
0.20615026965737343
0.20657311424612998
0.20503491029143334
0.2047408950328827
0.20672660782933236
0.20743891596794128
0.20780995905399322
0.20698378771543502
0.20737035363912582
0.20749947026371957
0.20702317222952843
0.20729136556386948
0.20632491052150725
0.2079425023496151
0.20572543159127235
0.2084750524163246


accuracy: 1.000: : 1it [01:09, 69.98s/it]

0.0032445475459098815
0.2780319565534592
0.2712273408472538
0.24383723929524423
0.21496526032686233
0.21384094297885894
0.2134184142947197
0.21003156304359435
0.2119832782447338
