In [1]:
import sys
import os

import matplotlib.pylab as pl
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import matplotlib.cm as cm

import pickle

from transformer import Transformer
from data import create_cat_data_rbf, create_cat_data_grid, create_cat_data_random_grid, create_weights
from config import config
from train import *
from eval import *
from util import *
import plot

from IPython.display import display

Query shape:  (6, 6)
Key shape:  (6, 6)
Value shape:  (6, 6)
Projection shape:  (6, 6)
Embedding matrix shape:  (7, 6)


In [2]:
# --- Problem configuration ---------------------------------------------------
c_size = 100      # number of in-context (training) points
i_size = 2        # input dimension; the quadrant labelling below reads columns 0 and 1
e_size = 5        # number of rows of W_e
cats = 20         # number of classes (columns of W_e)
input_range = 1   # inputs and the query grid live in [-input_range, input_range]^2
n_grid = 101      # query-grid points per axis (spacing 0.02 when input_range == 1)
rng = jax.random.PRNGKey(0)

# Keep the 7-way split and the key-to-use assignment stable: changing either one
# changes every sampled array below. new_rng3 is not consumed in this cell.
rng, new_rng, new_rng2, new_rng3, new_rng4, new_rng5, new_rng6 = jax.random.split(rng, num=7)
W_e = jax.random.normal(new_rng, shape=(e_size, cats))

# draw training inputs uniformly from the input square
x = jax.random.uniform(new_rng2, shape=(c_size, i_size),
                       minval=-input_range, maxval=input_range)

# Regular query grid over the same square. np.linspace fixes both the number of
# points and the exact endpoints; np.arange with a fractional step guarantees
# neither, because its length and values depend on floating-point rounding.
x_test = np.linspace(-input_range, input_range, n_grid)
y_test = np.linspace(-input_range, input_range, n_grid)

# Generate grid and flatten it to one (x0, x1) row per query point
xx, yy = np.meshgrid(x_test, y_test)
x_query = np.stack([xx.ravel(), yy.ravel()], axis=-1)

# Pick four distinct classes; their W_e columns are the per-quadrant function values
c_idx = jax.random.choice(new_rng4, a=cats, shape=(4,), replace=False)
c = W_e[:, c_idx]

x_0 = x[:, 0]
x_1 = x[:, 1]

# Quadrant index in {0, 1, 2, 3}: the x_0 sign contributes 0 or 2, the x_1 sign
# contributes 2 or 3, and subtracting 2 gives 2 * (x_0 >= 0) + (x_1 >= 0).
temp_0 = jnp.where(x_0 >= 0, 2, 0)
temp_1 = jnp.where(x_1 >= 0, 3, 2)
temp = jnp.concatenate([temp_0[:, None], temp_1[:, None]], axis=1)
quad = jnp.sum(temp, axis=1) - 2

# calculate f: column i is the vector assigned to the quadrant of training point i
f = c[:, quad]

# class logits for every training point, computed once and reused below
logits = f.T @ W_e

# class probabilities (softmax over the class axis)
probs = jax.nn.softmax(logits, axis=1)

# randomly draw a label for each sample from those logits
y_data = jax.random.categorical(new_rng5, logits, axis=1)

v_data_full = jax.nn.one_hot(y_data, num_classes=cats)

In [3]:
# Class indices (columns of W_e) drawn without replacement, one per quadrant.
c_idx

Array([13, 12, 18, 17], dtype=int32)

In [4]:
# c holds the selected W_e columns: one e_size-vector per quadrant.
c.shape

(5, 4)

In [5]:
# W_e has e_size rows and one column per class.
W_e.shape

(5, 20)

In [6]:
# Logits of every class for each quadrant vector: row q is c[:, q] @ W_e.
c.T @ W_e

Array([[ 2.505779  ,  1.6474655 , -4.323917  , -0.55539024,  1.9624969 ,
        -2.9421191 ,  0.8105659 ,  1.4069335 ,  1.3965553 , -1.2213427 ,
        -0.10855858, -0.20429921,  2.2066185 ,  2.8055956 , -1.3939929 ,
        -3.0629728 ,  0.7434795 ,  1.3788484 ,  0.4715544 , -1.3931487 ],
       [ 4.5311537 , -0.4228973 , -0.55117846,  0.2235102 ,  2.5642157 ,
        -1.680838  ,  2.2476897 ,  1.8882071 ,  4.47393   , -0.25071353,
        -3.4734242 , -0.33041313,  4.887509  ,  2.2066185 , -2.3211799 ,
        -1.9510038 ,  2.9949775 ,  0.8091926 ,  3.1151812 , -1.6347799 ],
       [ 3.4582825 , -0.10158603,  0.90677947, -1.1702279 ,  1.9772658 ,
         0.38071963,  3.9239843 , -0.36842188,  3.3568168 ,  1.0498735 ,
        -5.3140836 , -1.6717342 ,  3.1151812 ,  0.4715544 , -0.39420208,
         1.3273163 ,  3.044857  , -2.2156801 ,  4.7600756 ,  4.8035116 ],
       [-0.04188072, -1.0871061 , -0.01933361,  2.2015812 ,  0.42537338,
        -0.59926885, -2.6229773 ,  0.7029242 , -

In [7]:
# Highest-logit class per quadrant. This need not equal c_idx: another column of
# W_e can have a larger inner product with c[:, q] than the selected column itself.
jnp.argmax(c.T @ W_e, axis=1)

Array([13, 12, 19, 17], dtype=int32)

In [8]:
# Quadrant index (0..3) of every query-grid point, using the same encoding as
# the training inputs: x_0 sign -> 0 or 2, x_1 sign -> 2 or 3, then subtract 2.
x_0_query, x_1_query = x_query[:, 0], x_query[:, 1]
temp_0_query = jnp.where(x_0_query >= 0, 2, 0)
temp_1_query = jnp.where(x_1_query >= 0, 3, 2)
temp_query = jnp.stack([temp_0_query, temp_1_query], axis=1)
quad_query = temp_query.sum(axis=1) - 2

# Print the full vector (no truncation) to eyeball the quadrant layout.
with jnp.printoptions(threshold=sys.maxsize):
    print(quad_query)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 

In [10]:
# f(x_query): column j is the vector assigned to the quadrant of query point j
f_target = c[:, quad_query]

# class logits for every query point, shared by the softmax and the sampler
logits_target = f_target.T @ W_e

# class probabilities for the query points
probs_target = jax.nn.softmax(logits_target, axis=1)

# randomly draw one label per query point from those logits
y_target = jax.random.categorical(new_rng6, logits_target, axis=1)

# one-hot encode the sampled query labels
v_target_full = jax.nn.one_hot(y_target, num_classes=cats)

print(v_target_full.shape)

(10201, 20)


In [14]:
# One row per query-grid point, one column per input coordinate.
x_query.shape

(10201, 2)

In [13]:
# f_target = c[:, quad_query]: one column per query point.
f_target.shape

(5, 10201)

In [11]:
# Row i is the W_e column of the label sampled for training point i.
W_e_seq = W_e.T[y_data]
# All-zero (c_size, e_size) placeholders; f_init fills the f-slot of the context
# tokens. W_e_seq and E_w_e_init are not used by the cells visible here.
E_w_e_init = jnp.zeros((c_size, e_size))
f_init = jnp.zeros((c_size, e_size))

In [12]:
# One context token per training point: [x | one-hot label - 1/cats | zero f-slot].
context_tokens = jnp.concatenate((x, v_data_full - 1 / cats, f_init), axis=-1)
# Repeat the same context once per query point.
seq = jnp.tile(context_tokens, (x_query.shape[0], 1, 1))  # in-context data sequence
print(seq.shape)

(10201, 100, 27)


In [16]:
# Ground-truth query token: [x_query | one-hot label - 1/cats | f(x_query)],
# with a length-1 sequence axis inserted at position 1.
target_tokens = jnp.concatenate(
    (x_query, v_target_full - 1 / cats, f_target.T),
    axis=-1,
)
target = target_tokens[:, None, :]
target.shape

(10201, 1, 27)

In [18]:
# Masked query token: the real query coordinates, then a label slot holding
# -1/cats everywhere and an all-zero f-slot, laid out like the context tokens.
n_query = x_query.shape[0]
query_tokens = jnp.concatenate(
    [
        x_query,
        jnp.zeros((n_query, cats)) - 1 / cats,
        jnp.zeros((n_query, e_size)),
    ],
    axis=-1,
)
zero = jnp.expand_dims(query_tokens, 1)
zero.shape

(10201, 1, 27)

In [19]:
# Append the masked query token after the c_size context tokens.
# Slicing to the first c_size positions keeps this cell idempotent: re-running it
# (or running cells out of order) replaces the query token instead of stacking
# another copy onto an already-extended sequence.
seq = jnp.concatenate([seq[:, :c_size], zero], axis=1)
seq.shape

(10201, 101, 27)

In [22]:
# Shape check for probs repeated once per query point. Note that this
# materialises the full tiled array just to read its shape.
jnp.squeeze(jnp.tile(probs, (x_query.shape[0],1,1))).shape

(10201, 100, 20)

In [23]:
# One row of class probabilities (softmax output) per query point.
probs_target.shape

(10201, 20)

In [None]:
# Every array describing this dataset, each with its size-1 axes removed.
# NOTE(review): v_target_full appears twice (8th and 10th entries); confirm
# whether the last slot was meant to hold a different array.
(
    jnp.squeeze(seq),                                       # context + masked query tokens
    jnp.squeeze(target),                                    # ground-truth query tokens
    jnp.squeeze(jnp.tile(probs, (x_query.shape[0], 1, 1))), # training probs, repeated per query
    jnp.squeeze(probs_target),                              # query class probabilities
    jnp.squeeze(f),                                         # f at the training points
    jnp.squeeze(f_target),                                  # f at the query points
    jnp.squeeze(v_data_full),                               # one-hot training labels
    jnp.squeeze(v_target_full),                             # one-hot query labels
    jnp.squeeze(W_e),
    jnp.squeeze(v_target_full),
)