In [1]:
import numpy as np
from numpy import dot
from numpy.linalg import norm
import random
import pickle
from tqdm import tqdm
import gc
from itertools import combinations

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from keras import models, layers, losses, optimizers, regularizers, Model

2023-12-26 22:48:33.107879: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-26 22:48:33.107912: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-26 22:48:33.107943: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
gc.collect()

0

Loading in brain scans and arranging them

In [3]:
brain_index = pickle.load(open("../data/support.pkl", "rb"))
nouns, verbs = pickle.load(open("../data/vecs.pkl", "rb"))

pickles = [pickle.load(open(f"../data/pickles/{i}.pkl", "rb")) for i in range(1)]
pickles = [item for sublist in pickles for item in sublist]
pickles = sorted(pickles, key=lambda x: x[1])
pickles = [[item for item in pickles if item[1] == noun] for noun in nouns]
pickles = [(np.add.reduce([item[0] for item in sublist]) / len(sublist), sublist[0][1]) for sublist in pickles]
pickles = [(item[0][brain_index], item[1]) for item in pickles]

len(nouns), len(verbs)

(60, 25)

In [4]:
def l2(a, b):
    return norm(np.subtract(a, b))

Semi-linear basis-sum model

In [5]:
class BasisSum(Model):
    def __init__(self):
        super().__init__()
        self.basis = tf.Variable(tf.convert_to_tensor([verbs[verb] for verb in verbs]), trainable=False, name="verb_basis")
        self.d1 = layers.Dense(64, activation="relu")
        self.d2 = layers.Dense(32, activation="relu")
        self.dn = layers.Dense(self.basis.shape[0], activation="sigmoid")
        
    @tf.function(reduce_retracing=True)
    def call(self, x):
        x = self.d1(x)
        x = self.d2(x)
        x = self.dn(x)
        x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
        
        x = tf.einsum("bi,ij->bj", x, self.basis)
        
        return x

Training loop

In [6]:
batch_size = 64
pbar = tqdm(combinations(range(60), 58), total=1770)

x = np.array([item[0] for item in pickles])
y = [item[1] for item in pickles]
y = np.array([nouns[item] for item in y])

x, y = tf.cast(x, tf.dtypes.float32), tf.cast(y, tf.dtypes.float32)
correct_count = 0

for i, comb in enumerate(pbar):
    comp = [j for j in range(60) if j not in comb]

    model = BasisSum()
    loss = losses.MeanSquaredError()
    opt = optimizers.Adam(0.001)

    train_x, test_x = tf.gather(x, comb), tf.gather(x, comp)
    train_y, test_y = tf.gather(y, comb), tf.gather(y, comp)

    batchlosses = []
    for j in range(5000):
        idx1 = tf.random.uniform(shape=[batch_size], minval=0, maxval=tf.shape(train_x)[0], dtype=tf.int32)
        idx2 = tf.random.uniform(shape=[batch_size], minval=0, maxval=tf.shape(train_x)[0], dtype=tf.int32)

        batch_x1, batch_y1 = tf.gather(train_x, idx1), tf.gather(train_y, idx1)
        batch_x2, batch_y2 = tf.gather(train_x, idx2), tf.gather(train_y, idx2)

        ratios = tf.random.uniform((len(batch_x1), 1), 0, 1)
        batch_x = batch_x1 * ratios + batch_x2 * (1 - ratios)
        batch_y = batch_y1 * ratios + batch_y2 * (1 - ratios)

        with tf.GradientTape() as tape:
            pred_y = model(batch_x)
            batchloss = loss(batch_y, pred_y)
            grads = tape.gradient(batchloss, model.trainable_variables)
            opt.apply_gradients(zip(grads, model.trainable_variables))

            batchlosses.append(float(batchloss))

        if j % 100 == 0:
            print(j, sum(batchlosses[-100:]) / 100)
            ...
            
    pred = model(test_x)
    t1, t2 = test_y.numpy()
    t1, t2 = t1.flat, t2.flat
    p1, p2 = pred.numpy()
    p1, p2 = p1.flat, p2.flat
    
    correct = l2(t1, p1) + l2(t2, p2)
    incorrect = l2(t1, p2) + l2(t2, p1)

    correct_count += int(correct < incorrect)

    pbar.set_description(f"accuracy: {correct_count / (i + 1):.3f}")

  0%|          | 0/1770 [00:00<?, ?it/s]

0 0.0032242628931999206
100 0.2748565323650837
200 0.23284569680690764
300 0.2140559086203575
400 0.2097366474568844
500 0.21185038074851037
600 0.21143039390444757
700 0.21109800517559052
800 0.20855623215436936
900 0.2101852323114872
1000 0.20782779216766356
1100 0.20925382360816003
1200 0.2093658658862114
1300 0.21002602502703666
1400 0.20877980202436447
1500 0.20747047111392022
1600 0.2105011013150215
1700 0.21010493591427803
1800 0.20805832713842393
1900 0.20935674428939818
2000 0.20683040156960486
2100 0.20868419080972672
2200 0.2089761109650135
2300 0.20921524330973626
2400 0.2068786506354809
2500 0.20866013795137406
2600 0.20840485841035844
2700 0.20907114908099175
2800 0.2076926076412201
2900 0.20738729506731032
3000 0.20762544840574265
3100 0.20777850732207298
3200 0.20706646859645844
3300 0.2064448857307434
3400 0.20801396191120147
3500 0.20728522524237633
3600 0.20816572085022927
3700 0.2063796941936016
3800 0.20517071649432184
3900 0.20615896299481393
4000 0.20834925532341

accuracy: 1.000:   0%|          | 1/1770 [01:12<35:26:52, 72.14s/it]

0 0.003512914180755615
100 0.27605067640542985
200 0.25945144027471545
300 0.24551455706357955
400 0.24256567940115928
500 0.24200015395879745
600 0.24160409614443779
700 0.24073706939816475
800 0.23940415009856225
900 0.23877643585205077
1000 0.23962537541985512
1100 0.23856934502720833
1200 0.23948938965797426
1300 0.2413523504137993
1400 0.23970524609088897
1500 0.23769414514303208
1600 0.23856200009584427
1700 0.2396504619717598
1800 0.23923427656292914
1900 0.23618395507335663
2000 0.23837308809161187
2100 0.2371591244637966
2200 0.23616427898406983
2300 0.2369704918563366
2400 0.23723785683512688
2500 0.23923857048153876
2600 0.23800556033849715
2700 0.23895808562636375
2800 0.23941644743084908
2900 0.23927636429667473
3000 0.2377398493885994
3100 0.2406399980187416
3200 0.2397163274884224
3300 0.23684555113315583
3400 0.24103962868452072
3500 0.22944209724664688
3600 0.22547735765576363
3700 0.22130061492323874
3800 0.22131802409887313
3900 0.2236194059252739
4000 0.221057238429

accuracy: 0.500:   0%|          | 2/1770 [02:25<35:48:51, 72.93s/it]

0 0.003203906416893005
100 0.26491287365555766
200 0.24773427352309227
300 0.24810016185045242
400 0.24568409442901612
500 0.24598790258169173
600 0.24589310616254806
700 0.24702388510107995
800 0.24405692994594574
900 0.24600760832428933
1000 0.24683998689055442
