In [1]:
import numpy as np
from numpy import dot
from numpy.linalg import norm
import random
import pickle
from tqdm import tqdm
import gc
from itertools import combinations

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
from keras import models, layers, losses, optimizers, regularizers, Model

In [2]:
gc.collect()

0

In [3]:
brain_index = pickle.load(open("../data/support.pkl", "rb"))
noun_vecs, verb_vecs = pickle.load(open("../data/vecs.pkl", "rb"))

pickles = [pickle.load(open(f"../data/pickles/{i}.pkl", "rb")) for i in range(1)]
pickles = [item for sublist in pickles for item in sublist]
pickles = sorted(pickles, key=lambda x: x[1])
pickles = [[item for item in pickles if item[1] == noun] for noun in noun_vecs]
pickles = [(np.add.reduce([item[0] for item in sublist]) / len(sublist), sublist[0][1]) for sublist in pickles]
pickles = [(item[0][brain_index], item[1]) for item in pickles]

len(noun_vecs), len(verb_vecs)

(60, 25)

In [4]:
def l2(a, b):
    return norm(np.subtract(a, b))

In [5]:
class BasisSum(Model):
    def __init__(self):
        super().__init__()
        self.basis = tf.Variable(tf.convert_to_tensor([verb_vecs[verb] for verb in verb_vecs]), trainable=False, name="verb_basis")
        self.d1 = layers.Dense(64, activation="relu")
        self.d2 = layers.Dense(32, activation="relu")
        self.dn = layers.Dense(self.basis.shape[0], activation="sigmoid")
        
    @tf.function(reduce_retracing=True)
    def call(self, x):
        x = self.d1(x)
        x = self.d2(x)
        x = self.dn(x)
        x = x / tf.norm(x, axis=0, keepdims=True)
        
        x = tf.einsum("bi,ij->bj", x, self.basis)
        
        return x

In [6]:
total = 500
batch_size = 64
pbar = tqdm(combinations(range(60), 58), total=1770)

x = np.array([item[0] for item in pickles])
y = [item[1] for item in pickles]
y = np.array([noun_vecs[item] for item in y])

x, y = tf.cast(x, tf.dtypes.float32), tf.cast(y, tf.dtypes.float32)

correct_count = 0

for i, comb in enumerate(pbar):
    comp = list(set.difference(set(range(60)), set(comb)))

    model = BasisSum()
    loss = losses.MeanSquaredError()
    opt = optimizers.Adam(0.001)

    train_x, test_x = tf.gather(x, comb), tf.gather(x, comp)
    train_y, test_y = tf.gather(y, comb), tf.gather(y, comp)

    batchlosses = []
    for j in range(2000):
        idx1 = tf.random.uniform(shape=[batch_size], minval=0, maxval=tf.shape(train_x)[0], dtype=tf.int32)
        idx2 = tf.random.uniform(shape=[batch_size], minval=0, maxval=tf.shape(train_x)[0], dtype=tf.int32)

        batch_x1, batch_y1 = tf.gather(train_x, idx1), tf.gather(train_y, idx1)
        batch_x2, batch_y2 = tf.gather(train_x, idx2), tf.gather(train_y, idx2)

        ratios = tf.random.uniform((len(batch_x1), 1), 0, 1)
        batch_x = batch_x1 * ratios + batch_x2 * (1 - ratios)
        batch_y = batch_y1 * ratios + batch_y2 * (1 - ratios)

        with tf.GradientTape() as tape:
            pred_y = model(batch_x)
            batchloss = loss(batch_y, pred_y)
            grads = tape.gradient(batchloss, model.trainable_variables)
            opt.apply_gradients(zip(grads, model.trainable_variables))

            batchlosses.append(float(batchloss))

        if j % 100 == 0:
            ...
            print(j, sum(batchlosses[-100:]) / 100)


    pred = model(test_x)
    t1, t2 = test_y.numpy()
    t1, t2 = t1.flat, t2.flat
    p1, p2 = pred.numpy()
    p1, p2 = p1.flat, p2.flat
    
    correct = l2(t1, p1) + l2(t2, p2)
    incorrect = l2(t1, p2) + l2(t2, p1)

    correct_count += int(correct < incorrect)

    pbar.set_description(f"accuracy: {correct_count / (i + 1):.3f}")

  0%|          | 0/1770 [00:00<?, ?it/s]

0 0.015968070030212403
100 0.5107827630639076
200 0.3682992488145828
300 0.3430792862176895
400 0.3250245329737663
500 0.312782416343689
600 0.3031373795866966
700 0.2997839185595512
800 0.2979336825013161
900 0.29517841786146165
1000 0.295569629073143
1100 0.2939473551511764
1200 0.28912455439567564
1300 0.28242976173758505
1400 0.28389204263687134
1500 0.2835342675447464
1600 0.27935258686542513
1700 0.27561098396778105
1800 0.27259741082787514
1900 0.2752645722031593


accuracy: 1.000:   0%|          | 1/1770 [00:29<14:30:27, 29.52s/it]

0 0.016252039670944213
100 0.6302044874429703
200 0.4730979114770889
300 0.4444975742697716
400 0.42260519474744795
500 0.41276758641004563
600 0.41325831949710845
700 0.40757968455553056
800 0.39038507640361786
900 0.3867751607298851
1000 0.3829701918363571
1100 0.3785504272580147
1200 0.3780233883857727
1300 0.37693129777908324
1400 0.36992094814777376
1500 0.3720728123188019
1600 0.370182129740715
1700 0.3683420717716217
1800 0.3586224690079689
1900 0.3503899320960045


accuracy: 1.000:   0%|          | 2/1770 [00:58<14:28:15, 29.47s/it]

0 0.0170040225982666
100 0.4939456009864807
200 0.3584054005146027
300 0.33097593367099765
400 0.3171682533621788
500 0.3114148885011673
600 0.3048152217268944
700 0.2963544264435768
800 0.2810323646664619
900 0.2759032914042473
1000 0.27315658003091814
1100 0.2721414378285408
1200 0.2704558126628399
1300 0.2685805158317089
1400 0.26818350672721863
1500 0.2663959664106369
1600 0.26357258498668673
1700 0.2631925208866596
1800 0.26288542136549947
1900 0.2654390811920166


accuracy: 0.667:   0%|          | 3/1770 [01:27<14:19:30, 29.19s/it]

0 0.016038106679916384
100 0.544783550798893
200 0.39629640072584155
300 0.3592594540119171
400 0.3345611456036568
500 0.3263668105006218
600 0.31285223841667176
700 0.30673236548900606
800 0.3033199501037598
900 0.2993731987476349
1000 0.2971143555641174
1100 0.2967127767205238
1200 0.2961179393529892
1300 0.29254108160734177
1400 0.29308164477348325
1500 0.29233505636453627
1600 0.28440673112869264
1700 0.2878247222304344
1800 0.28704928964376447
1900 0.2826374337077141


accuracy: 0.750:   0%|          | 4/1770 [01:56<14:10:47, 28.91s/it]

0 0.016937599182128907
100 0.5178254932165146
200 0.3770367643237114
300 0.3242714136838913
400 0.3057513228058815
500 0.2988075363636017
600 0.29318019449710847
700 0.2875901508331299
800 0.2836294573545456
900 0.2798945665359497
