In [None]:
import numpy as np
from numpy import dot
from numpy.linalg import norm
import random
import pickle
from tqdm import tqdm
import gc

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from keras import models, layers, losses, optimizers, regularizers, Model

In [None]:
# Force a full garbage-collection pass. Because this is the last expression in
# the cell, the value returned by gc.collect() (the number of unreachable
# objects it found) is displayed as the cell output.
gc.collect()

In [None]:
# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only run this cell on pickle files you created yourself or otherwise trust.
def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Number of sample files to merge (pickles/0.pkl, pickles/1.pkl, ...).
N_PICKLE_FILES = 1

brain_index = _load_pickle("support.pkl")
nouns, verbs = _load_pickle("vecs.pkl")

# Flatten the (data, noun) items of every sample file into one list.
samples = [
    item
    for i in range(N_PICKLE_FILES)
    for item in _load_pickle(f"pickles/{i}.pkl")
]

# Group the items by noun in a single pass, in the iteration order of `nouns`.
# Items whose noun is not in `nouns` are dropped.
# Assumes the noun labels are hashable (they are compared against `nouns` keys).
samples_by_noun = {noun: [] for noun in nouns}
for item in samples:
    if item[1] in samples_by_noun:
        samples_by_noun[item[1]].append(item)

# Fail early with a readable message instead of an IndexError further down.
missing_nouns = [noun for noun, group in samples_by_noun.items() if not group]
if missing_nouns:
    raise ValueError(f"no samples found for nouns: {missing_nouns}")

# One (mean data restricted to brain_index, noun) pair per noun.
pickles = [
    ((np.add.reduce([item[0] for item in group]) / len(group))[brain_index], group[0][1])
    for group in samples_by_noun.values()
]

len(nouns), len(verbs)

In [None]:
class BasisSum(Model):
    """Linear model whose output is a learned mixture of trainable basis vectors.

    A bias-free dense layer maps each input to one coefficient per basis
    vector; the output is the coefficient-weighted sum of the basis rows.
    The basis itself is a trainable variable.

    Args:
        basis_vectors: optional sequence of equal-length vectors used to
            initialise the basis. When omitted, the values of the
            notebook-level `verbs` mapping are used, so `verbs` must already
            be defined in that case.

    The dense kernel carries an L2(0.03) regulariser. Keras exposes that
    penalty through `model.losses`; a hand-written training loop has to add
    it to its objective explicitly for it to have any effect.
    """

    def __init__(self, basis_vectors=None):
        super().__init__()
        if basis_vectors is None:
            # Default: one basis row per entry of the notebook-level `verbs`.
            basis_vectors = [verbs[verb] for verb in verbs]

        # Force float32 so the basis has the same dtype as the float32 inputs
        # the training cell feeds in; einsum rejects mixed float32/float64.
        initial_basis = np.asarray(basis_vectors, dtype=np.float32)

        # NOTE(review): Keras 2 (tf.keras) auto-tracks tf.Variable attributes in
        # `trainable_variables`; Keras 3 reportedly does not. Confirm that the
        # installed Keras version actually trains this variable.
        self.basis = tf.Variable(initial_basis, trainable=True, name="verb_basis")

        self.dense = layers.Dense(
            self.basis.shape[0],
            kernel_regularizer=regularizers.L2(0.03),
            use_bias=False,
        )

    @tf.function(reduce_retracing=True)
    def call(self, x):
        """Return, for each input row, the weighted sum of the basis vectors."""
        coefficients = self.dense(x)
        # (batch, n_basis) x (n_basis, dim) -> (batch, dim)
        return tf.einsum("bi, ij -> bj", coefficients, self.basis)

In [None]:
def l2(a, b):
    """Return the L2 norm of the element-wise difference between `a` and `b`."""
    difference = np.asarray(a) - np.asarray(b)
    return norm(difference)

In [None]:
# Leave-two-out evaluation.
# Each trial trains a fresh BasisSum on all samples but two, predicts the two
# held-out targets, and counts the trial as correct when the matched pairing
# (prediction i <-> target i) has a smaller total L2 distance than the swapped
# pairing. The running accuracy is shown in the progress-bar description.

SEED = 0              # seeds the shuffles and weight initialisation
total = 500           # number of trials
EPOCHS = 200          # full-batch gradient steps per trial
LEARNING_RATE = 0.01  # Adam step size

# Seed Python's and TensorFlow's RNGs so that re-running the cell is intended
# to reproduce the same shuffles and initial weights.
random.seed(SEED)
tf.random.set_seed(SEED)

if len(pickles) < 3:
    raise ValueError("need at least 3 samples: 2 are held out and 1+ used for training")

# Loop-invariant work, done once instead of once per trial.
features = np.array([item[0] for item in pickles])
targets = np.array([nouns[item[1]] for item in pickles])
loss = losses.MeanSquaredError()

pbar = tqdm(range(total))
correct_count = 0

for i in pbar:
    model = BasisSum()
    opt = optimizers.Adam(LEARNING_RATE)

    # Draw a random order through indices rather than shuffling `pickles` in
    # place, so the cell leaves `pickles` unchanged and can be re-run safely.
    order = random.sample(range(len(pickles)), len(pickles))
    x = tf.cast(features[order], tf.dtypes.float32)
    y = tf.cast(targets[order], tf.dtypes.float32)

    # The last two samples of the random order are the held-out test pair.
    train_x, test_x = x[:-2], x[-2:]
    train_y, test_y = y[:-2], y[-2:]

    for j in range(EPOCHS):
        with tf.GradientTape() as tape:
            pred_y = model(train_x)
            # Add the layer regularisation penalties (model.losses): in a
            # hand-written loop they are not applied unless added here.
            batchloss = loss(train_y, pred_y) + sum(model.losses)
        # Only the forward pass needs recording; compute and apply the
        # gradients after leaving the tape context.
        grads = tape.gradient(batchloss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

    pred = model(test_x)
    t1, t2 = (vec.ravel() for vec in test_y.numpy())
    p1, p2 = (vec.ravel() for vec in pred.numpy())

    correct = l2(t1, p1) + l2(t2, p2)
    incorrect = l2(t1, p2) + l2(t2, p1)

    correct_count += int(correct < incorrect)

    pbar.set_description(f"accuracy: {correct_count / (i + 1):.3f}")