In [2]:
import os
import torch
import time
import torch.nn as nn
from cell import utils, analysis, plot_utils
from torch.autograd import Variable
from torch.nn import functional as F
import matplotlib.pyplot as plt

from cell.Word2vec import prepare_vocab, dataloader, wv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
length = 10000
p = 1
q = 1
N = 1
batch_size = 2000
walk_filename = "walk_node21_32_removed.csv"
roi = "VISp"
project_name = "NPP_GNN_project"
layer_class = "single_layer"
layer = "base_unnormalized_allcombined"
walk_type= "Directed_Weighted_node2vec"
window = 2

In [47]:
datasets = {}

for (layer, walk_filename) in [("base_unnormalized_allcombined", "walk_node21_32_removed.csv"),
                               ("Sst-Sstr1", "walk_0.csv"),
                               ("Sst-Sstr2", "walk_0.csv"),
                               ("Vip-Vipr1", "walk_0.csv"),
                               ("Vip-Vipr2", "walk_0.csv")]:
    
    walk_dir = utils.get_walk_dir(roi,
                                  project_name, 
                                  N, 
                                  length, 
                                  p, 
                                  q, 
                                  layer_class, 
                                  layer, 
                                  walk_type) 
    path = os.path.join(walk_dir, walk_filename)
    corpus = utils.read_list_of_lists_from_csv(path)
    vocabulary = prepare_vocab.get_vocabulary(corpus)
    
    print(f'lenght of vocabulary: {len(vocabulary)}')
    
    word_2_index = prepare_vocab.get_word2idx(vocabulary, padding=True)
    index_2_word = prepare_vocab.get_idx2word(vocabulary, padding=True)
    datasets[layer] = [word_2_index]
    datasets[layer].append(index_2_word)
    
    tuples = prepare_vocab.MCBOW_get_word_context_tuples(corpus, window=window)
    dataset = dataloader.MCBOW_WalkDataset(tuples, word_2_index)
    datasets[layer].append(dataset)
    
    datasets[layer].append(len(vocabulary))

lenght of vocabulary: 91
a node called pad is added for padding and its index is zero
a node called pad is added for padding and its index is zero
MCBOW by default adds a padding node called pad with index zero
There are 910000 pairs of target and context words
lenght of vocabulary: 89
a node called pad is added for padding and its index is zero
a node called pad is added for padding and its index is zero
MCBOW by default adds a padding node called pad with index zero
There are 890000 pairs of target and context words
lenght of vocabulary: 91
a node called pad is added for padding and its index is zero
a node called pad is added for padding and its index is zero
MCBOW by default adds a padding node called pad with index zero
There are 910000 pairs of target and context words
lenght of vocabulary: 91
a node called pad is added for padding and its index is zero
a node called pad is added for padding and its index is zero
MCBOW by default adds a padding node called pad with index zero
The

In [48]:
datasets['base_unnormalized_allcombined']

[{'pad': 0,
  '69': 1,
  '6': 2,
  '10': 3,
  '84': 4,
  '87': 5,
  '31': 6,
  '42': 7,
  '90': 8,
  '4': 9,
  '37': 10,
  '27': 11,
  '38': 12,
  '36': 13,
  '46': 14,
  '68': 15,
  '89': 16,
  '50': 17,
  '58': 18,
  '1': 19,
  '12': 20,
  '2': 21,
  '49': 22,
  '82': 23,
  '74': 24,
  '86': 25,
  '81': 26,
  '53': 27,
  '3': 28,
  '14': 29,
  '17': 30,
  '75': 31,
  '40': 32,
  '80': 33,
  '43': 34,
  '54': 35,
  '52': 36,
  '51': 37,
  '91': 38,
  '28': 39,
  '59': 40,
  '33': 41,
  '57': 42,
  '34': 43,
  '7': 44,
  '24': 45,
  '67': 46,
  '66': 47,
  '29': 48,
  '77': 49,
  '55': 50,
  '9': 51,
  '8': 52,
  '73': 53,
  '23': 54,
  '60': 55,
  '47': 56,
  '25': 57,
  '76': 58,
  '62': 59,
  '72': 60,
  '44': 61,
  '13': 62,
  '78': 63,
  '0': 64,
  '65': 65,
  '79': 66,
  '48': 67,
  '70': 68,
  '56': 69,
  '30': 70,
  '85': 71,
  '39': 72,
  '71': 73,
  '5': 74,
  '35': 75,
  '20': 76,
  '92': 77,
  '64': 78,
  '63': 79,
  '83': 80,
  '19': 81,
  '41': 82,
  '22': 83,
  '61': 84,

In [49]:
def get_node_intersections(datasets, base_layer_name):
    node_intersections = {}
    for k, v in datasets.items():
        l1 = set(datasets[k][0])
        l2 = set(datasets[base_layer_name][0])
        node_intersections[k] = set(l1).intersection(l2)
    
    return node_intersections

In [50]:
base_layer_name = "base_unnormalized_allcombined"
layers = ["Sst-Sstr1", "Sst-Sstr2", "Vip-Vipr1", "Vip-Vipr2"]

node_intersections = get_node_intersections(datasets, base_layer_name)
node_intersections.keys()

dict_keys(['base_unnormalized_allcombined', 'Sst-Sstr1', 'Sst-Sstr2', 'Vip-Vipr1', 'Vip-Vipr2'])

In [51]:
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

In [52]:
datasets.keys()

dict_keys(['base_unnormalized_allcombined', 'Sst-Sstr1', 'Sst-Sstr2', 'Vip-Vipr1', 'Vip-Vipr2'])

In [53]:
datasets['base_unnormalized_allcombined'][2]

<cell.Word2vec.dataloader.MCBOW_WalkDataset at 0x7fdd70943e50>

In [54]:
def build_data_loader(datasets, batch_size, shuffle=True, drop_last=True, num_workers=1):
    data_loader = torch.utils.data.DataLoader(
        ConcatDataset(*[datasets[k][2] for k in datasets.keys()]),
        batch_size=batch_size, 
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers)
    return {k:i for i,k in enumerate(datasets.keys())}, data_loader

In [55]:
arm_keys, data_loader = build_data_loader(datasets, batch_size=2000, shuffle=False)

In [56]:
for batch_idx, (data1, data2, data3, data4, data4) in enumerate(data_loader):
    print(data4)
    break

[tensor([64, 14, 37,  ..., 76, 17, 76]), tensor([[ 0,  0, 14, 37],
        [ 0, 64, 37, 35],
        [64, 14, 35, 76],
        ...,
        [76, 76, 17, 76],
        [76, 76, 76, 17],
        [76, 17, 17, 76]])]


### Take care of index in different arms and different number of nodes in different arms

In [44]:
arm_keys

{'base_unnormalized_allcombined': 0,
 'Sst-Sstr1': 1,
 'Sst-Sstr2': 2,
 'Vip-Vipr1': 3,
 'Vip-Vipr2': 4}

In [57]:
# v_0 = pd.DataFrame(torch.stack(emb[0]).detach().numpy(), 
#                    index=datasets['base_unnormalized_allcombined'][1].values())

# v_1 = pd.DataFrame(torch.stack(emb[1]).detach().numpy(), 
#                    index=datasets['Sst-Sstr1'][1].values())

# v_0.index.name = "cluster_id"
# v_1.index.name = "cluster_id"

# merged = v_1.merge(v_0, on='cluster_id')
# v_0 = merged[['0_x', '1_x']]
# v_1 = merged[['0_y', '1_y']]

# v_0 = torch.tensor(np.array(v_0))
# v_1 = torch.tensor(np.array(v_1))
# F.mse_loss(v_0, v_1)

In [58]:
# loss_joint = 0 

# base_arm = arm_keys[base_layer_name]
# for arm, (k, v) in enumerate(arm_keys.items()):
#     print(arm, k, v)
#     idx0 = [datasets[base_layer_name][0][i] for i in node_intersections[k]]
#     idx1 = [datasets[k][0][i] for i in node_intersections[k]]
#     loss_joint += F.mse_loss(torch.index_select(input=torch.stack(emb[v]), 
#                                                 dim=0, 
#                                                 index=torch.tensor(idx1), 
#                                                 out=None),
#                              torch.index_select(input=torch.stack(emb[base_arm]), 
#                                                 dim=0, 
#                                                 index=torch.tensor(idx0), 
#                                                 out=None))
# print(loss_joint)

In [61]:
def loss_CMCBOW(prediction, target, emb, arm_keys, base_layer_name, node_intersections, n_arm=2):
    
    base_arm = arm_keys[base_layer_name]
    loss_indep = [None] * n_arm
    loss_joint = [None] * n_arm
    
    for arm, (k, v) in enumerate(arm_keys.items()):
        
        loss_indep[arm] = F.cross_entropy(prediction[arm], target[arm])
        
        idx0 = [datasets[base_layer_name][0][i] for i in node_intersections[k]]
        idx1 = [datasets[k][0][i] for i in node_intersections[k]]
        loss_joint[arm] = F.mse_loss(torch.index_select(input=torch.stack(emb[v]), 
                                                    dim=0, 
                                                    index=torch.tensor(idx1),
                                                    out=None),
                                 torch.index_select(input=torch.stack(emb[base_arm]), 
                                                    dim=0, 
                                                    index=torch.tensor(idx0), 
                                                    out=None))
    loss = sum(loss_indep) + sum(loss_joint)

    return loss

In [62]:
# from torch.nn import functional as F

# i = 10
# arm = 2
# print(F.cross_entropy(predict[arm][[i]], target_data[arm][[i]]))

# sf = F.softmax(predict[arm][i], dim=0)
# loss = -1 * torch.log(sf)
# print(loss[target_data[arm][i]])

### Coupled MCBOW_Word2Vec

In [189]:
class CMCBOW_Word2Vec(nn.Module):
    """
    """
    def __init__(self, vocab_size=[93], embedding_size=2, n_arm=1, padding_idx=0):
        """
        """
        super(CMCBOW_Word2Vec, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.n_arm = n_arm
        
        self.embeddings = nn.ModuleList([nn.Embedding(vocab_size[i],
                                                      embedding_size, 
                                                      padding_idx=padding_idx) 
                                         for i in range(n_arm)])
        
        self.linear = nn.ModuleList([nn.Linear(embedding_size,
                                               vocab_size[i]) 
                                     for i in range(n_arm)])
        
        self.batch_norm = nn.ModuleList([nn.BatchNorm1d(num_features=embedding_size,
                                                        eps=1e-10, 
                                                        momentum=0.1, 
                                                        affine=False) 
                                         for i in range(n_arm)])
                        

    def encoder(self, context_words, arm):
        h1 = torch.mean(self.embeddings[arm](context_words), dim=1)
        node_embeddings = [self.embeddings[arm](torch.tensor(i)) for i 
                           in range(self.vocab_size[arm])]
        
        return node_embeddings ,h1

    def decoder(self, mean_context, arm):
        h2 = self.linear[arm](self.batch_norm[arm](mean_context))
        return h2

    def forward(self, context_words):
        emb = [None] * self.n_arm
        predictions = [None] * self.n_arm

        for arm in range(self.n_arm):
            node_embeddings , mean_context = self.encoder(context_words[arm], arm)
            emb[arm] = node_embeddings
            predictions[arm] = self.decoder(mean_context, arm)
            
        return emb, predictions


In [190]:
embedding_size = 2
learning_rate = 0.001
n_epochs = 1
n_arm=5

In [191]:
model = CMCBOW_Word2Vec(embedding_size=embedding_size, 
                        vocab_size=[v[3] + 1 for (k, v) in datasets.items()],
                        n_arm=n_arm, 
                        padding_idx=0).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
training_loss = []


for epoch in range(n_epochs):
    losses = []
    t0 = time.time()
    for batch_idx, all_data in enumerate(data_loader):
        target_data = [data[0].to(device) for data in all_data]
        context_data = [data[1].to(device) for data in all_data]
        optimizer.zero_grad()
        emb, predict = model(context_data)
        loss = loss_CMCBOW(prediction=predict, 
                           target=target_data, 
                           arm_keys=arm_keys, 
                           emb=emb, 
                           n_arm=n_arm, 
                           base_layer_name=base_layer_name, 
                           node_intersections=node_intersections) 
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
    t1 = time.time()
    print('time is %.2f' % (t1 - t0))
        
    training_loss.append(np.mean(losses)) 
    print(f'epoch: {epoch+1}/{n_epochs}, loss:{np.mean(losses):.4f}')

time is 91.76
epoch: 1/1, loss:26.5933


In [193]:
target_data

[tensor([21,  9,  4,  ..., 39, 58, 67]),
 tensor([56, 56, 56,  ..., 56, 56, 56]),
 tensor([61, 61, 61,  ..., 61, 61, 61]),
 tensor([91, 91, 91,  ..., 91, 91, 91]),
 tensor([76, 37, 67,  ..., 28, 37, 76])]

In [199]:
loss

tensor(22.7799, grad_fn=<AddBackward0>)

In [198]:
predict[0].shape

torch.Size([2000, 92])

In [197]:
target_data[0].shape

torch.Size([2000])

In [132]:
cldf = utils.read_visp_npp_cldf()
vectors = model.embeddings[0].weight.detach().numpy()

data = analysis.summarize_walk_embedding_results(gensim_dict={"model": vectors},
                                                 index=index_2_word.values(),
                                                 ndim=2, 
                                                 cl_df=cldf, 
                                                 padding_label="pad")


Reading cldf from: //Users/fahimehb/Documents/NPP_GNN_project/dat/cl_df_VISp_annotation.csv


In [136]:
model_dir = utils.get_model_dir(project_name, 
                                roi, 
                                N, 
                                length, 
                                p, 
                                q, 
                                layer_class, 
                                layer, 
                                walk_type)

model_name = utils.get_model_name(size=embedding_size, 
                                  iter=n_epochs, 
                                  window=2, 
                                  lr=learning_rate, 
                                  batch_size=batch_size,
                                  opt_add="test")

In [137]:
data.to_csv(os.path.join(model_dir, model_name))

FileNotFoundError: [Errno 2] No such file or directory: '//Users/fahimehb/Documents/NPP_GNN_project/models/VISp/single_layer/Directed_Weighted_node2vec/N_1_l_10000_p_1_q_1/Vip-Vipr1/model_size_2_iter_10_window_2_lr_0.001_bs_2000_test.csv'

In [134]:
model_name

'model_size_2_iter_10_window_2_lr_0.001_bs_2000_test.csv'