In [None]:
import torch
import torch.nn as nn
import numpy as np
from config import GPTConfig
# from dataloader import DL
from pathlib import Path
import tiktoken
from ctmModules import SynapseUNet, CTM, NLM
from modelgpt import GPT


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DL:  # DataLoader
    def __init__(self, B, T):
        self.B = B
        self.T = T
        datapath = Path.cwd().parent.parent.parent / "data" / "input.txt"
        enc = tiktoken.get_encoding("gpt2")
        with open(datapath, "r") as f:
            text = f.read()
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        self.curr_pos = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.curr_pos : self.curr_pos + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        self.curr_pos += B * T
        if self.curr_pos + B * T + 1 > len(self.tokens):
            self.curr_pos = 0
        return x, y

In [None]:
total_batch_size = 2**11
B = 2 # Num Batches
T = 1024 # Tokens per Batch
grad_accum_steps = total_batch_size // (B * T)
train_loader = DL(B=B, T=T)
x, y = train_loader.next_batch()
# Input of the UnapseNet are logits shape B,T,C
iterations = 1
config = GPTConfig()
d_model = config.n_embd
memory_length = 25
hidden_dims = 4
dropout=0.1
heads = 4

gpt = GPT(config) # GPT Model - lm_head layer
q_proj = nn.LazyLinear(d_model) # q projecton
kv_proj = nn.Sequential(nn.LazyLinear(d_model), nn.LayerNorm(d_model)) 


wte = nn.Embedding(config.vocab_size, d_model)
wpe = nn.Embedding(config.block_size, d_model)

ln_1 = nn.LayerNorm(config.n_embd)
attention = nn.MultiheadAttention(d_model, heads, dropout, batch_first=True)

synnet = SynapseUNet(out_dims=d_model, depth=2, minimum_width=32, dropout=dropout)
for param in synnet.parameters():
    print(param.name, param.data.shape)
nlm = nn.Sequential(
    NLM(memory_length, hidden_dims, d_model), # ()
    nn.GLU(),
)
ouput_proj = nn.Sequential(nn.LazyLinear(d_model))

None torch.Size([0])
None torch.Size([0])
None torch.Size([768])
None torch.Size([768])


In [74]:
n_synch_out = int(d_model//2)
n_synch_action= int(d_model-n_synch_out)
# n_synch_out_rep = (n_synch_out ) //2
# n_synch_act_rep = (n_synch_action) //2

# Random Pairing left right divide
def init_lr_neurons(d_model, n_synch):
    neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(int(d_model)), size=n_synch))
    neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(int(d_model)), size=n_synch))
    return neuron_indices_left, neuron_indices_right

action_neuron_indices_left, action_neuron_indices_right = init_lr_neurons(d_model, n_synch_action)
out_neuron_indices_left, out_neuron_indices_right = init_lr_neurons(d_model, n_synch_out)
decay_params_action = torch.zeros(n_synch_action)
decay_params_out = torch.zeros(n_synch_out)

print(action_neuron_indices_left[:10],action_neuron_indices_left[-10:])

decay_alpha_action, decay_beta_action = None, None
decay_params_action= torch.clamp(decay_params_action, 0, 15)  # clamp between 0 and 15, param.data
decay_params_out = torch.clamp(decay_params_out, 0, 15)
r_action = torch.exp(-decay_params_action).unsqueeze(0).repeat(B, 1)
r_out = torch.exp(-decay_params_out).unsqueeze(0).repeat(B, 1)
print(f"decay action: {decay_params_action.shape}, decay out: {decay_params_out.shape}")
print(f"r action: {r_action.shape}, r out: {r_out.shape}")

tensor([216, 219, 434, 240, 136, 408, 598, 232, 436, 473]) tensor([ 97, 451, 588,  82, 570, 760, 688, 187, 585, 148])
decay action: torch.Size([384]), decay out: torch.Size([384])
r action: torch.Size([2, 384]), r out: torch.Size([2, 384])


In [75]:
 # Recurrent State
start_activated_state = torch.zeros((d_model)).uniform_(-d_model**-0.5, d_model**-0.5) # T
start_state_trace = torch.zeros((d_model, memory_length)).uniform_(-  (d_model+memory_length)**-0.5, (d_model+memory_length)**-0.5) # 
activated_state = start_activated_state.unsqueeze(0).expand(B,-1) # B,H
state_trace = start_state_trace.unsqueeze(0).expand(B,-1,-1) # B,H,T
print(activated_state.shape, state_trace.shape)

torch.Size([2, 768]) torch.Size([2, 768, 25])


In [None]:
# Compute features
out = gpt(x) # B, T, D

pos = torch.arange(0, T, dtype=torch.long, device=x.device)  # shape (T)
pos_enc = wpe(pos) # T, D
comb_feat = (out+pos_enc).flatten(2) # B, T, D
print(comb_feat.shape)
kv = kv_proj(comb_feat)
print(kv)
kv.shape

torch.Size([2, 786432])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x786432 and 768x768)

In [77]:
def compute_sync(activated_state, decay_alpha, decay_beta, r, synch_type):
    '''
    Params: 
        activated_state: torch tensor represents current state embeddings
        decay_alpha: exponential decay rate
        decay_beta: 
        r: decay rate
        synch_type
    '''
    if synch_type == 'action': # Get action parameters
        neurons_left = action_neuron_indices_left
        neurons_right = action_neuron_indices_right
    elif synch_type == 'out': # Get input parameters
        neurons_left = out_neuron_indices_left
        neurons_right = out_neuron_indices_right
    # compute pairwise
    left = activated_state[:, neurons_left]
    right = activated_state[:, neurons_right]
    pairwise_product = left*right
    print("pairwise", pairwise_product.shape)
    # Compute synchronisation recurrently
    if decay_alpha is None or decay_beta is None:
        decay_alpha = pairwise_product
        decay_beta = torch.ones_like(pairwise_product)
    else:
        print("decay_alpha",decay_alpha.shape)
        print("r", r.shape)

        decay_alpha = r * decay_alpha + pairwise_product
        decay_beta = r * decay_beta + 1
    sync = decay_alpha * (decay_beta **-0.5)
    return sync, decay_alpha, decay_beta
        

In [None]:
pre_activations = []
post_activations = []
r_action, r_out = torch.exp(-decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-decay_params_out).unsqueeze(0).repeat(B, 1)
activated_state = start_activated_state.unsqueeze(0).expand(B,-1)
print("r_action", r_action.shape, r_out.shape)

_, decay_alpha_out, decay_beta_out = compute_sync(activated_state, None, None, r_out, 'out')
for _ in range(iterations):
    sync, decay_alpha_action, decay_beta_action = compute_sync(activated_state, decay_alpha_action, decay_beta_action, r_action,'action')
    print(sync.shape, decay_alpha_action.shape, decay_beta_action.shape)
    
    q = q_proj(sync)
    print("q: ",q.shape)
    out = ln_1(out)
    attn_out = attention(q, kv, kv, average_attn_weights=False, ne) 
    
    print(attn_out.squeeze(1).shape) #(B,T,C)
    logits = torch.cat((attn_out.squeeze(1),activated_state), dim = -1)
    print("LOGITS SHAPE, (B,T,C)", logits.shape)
    state = synnet(logits) # (B, T, C)

    # state_trace is the history of incoming pre-activations
    print("State Trace:", state_trace[:,:,1:].shape, "State:", state.unsqueeze(-1).shape) # state_trace: (B, C, M-1) state:(B,H,T)
    state_trace = torch.cat((state_trace[:,:,1:], state.unsqueeze(-1)), dim=-1)
    print(state_trace.shape)# Should be batch, d_model, Mem len
    nlm_logits = nlm(state_trace)


    print(start_activated_state.shape, activated_state.shape,  logits.shape, state_trace.shape, nlm_logits.shape) #(B, T, C)\


r_action torch.Size([2, 384]) torch.Size([2, 384])
pairwise torch.Size([2, 384])
pairwise torch.Size([2, 384])
torch.Size([2, 384]) torch.Size([2, 384]) torch.Size([2, 384])
q:  torch.Size([2, 768])


AssertionError: For unbatched (2-D) `query`, expected `key` and `value` to be 2-D but found 3-D and 3-D tensors respectively