In [1]:
import sys

import pandas as pd

if ".." not in sys.path:
    sys.path.insert(0, "..")

In [2]:
from datasets import OrganoidDataset
from models import VQVAE
from configs.vqvae import get_config
import torch

In [3]:
config = get_config()
config.n_layers=2
config.hidden_features=32
config.embed_entries=2
config.embed_channels=8
config.embed_dim=1

data = OrganoidDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/organoids')
model = VQVAE(config=config).to(config.device)

X_train, y_train = data.train
X_val, y_val = data.val
X_test, y_test = data.test

In [4]:
checkpoint_file = '/data/PycharmProjects/cytof_benchmark/logs/VQVAE/exp_8/run_23/model.pth'

In [5]:
model_checkpoint = torch.load(checkpoint_file)
model.load_state_dict(model_checkpoint)

<All keys matched successfully>

In [6]:
model.eval()
encoded_val = model.encoder.forward(torch.Tensor(X_val).to('cuda'))
encoded_test = model.encoder.forward(torch.Tensor(X_test).to('cuda'))
encoded_val

tensor([[ 0.3691,  3.2805,  0.8540,  ...,  2.2586,  5.9430,  8.4871],
        [-1.8386,  3.0424,  7.3720,  ..., -0.2954,  4.7616,  1.8147],
        [ 1.0163,  0.7789,  0.7059,  ...,  0.7262,  3.9918,  0.1125],
        ...,
        [-0.9265,  1.3669, -0.0152,  ..., -0.6471,  4.8156,  2.1567],
        [-1.1816,  0.9818,  3.9361,  ...,  1.1103,  7.5415,  8.1801],
        [ 0.7615,  0.4349,  1.2540,  ..., -0.2954,  7.1272,  3.4665]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [7]:
encoded_val.shape

torch.Size([234495, 32])

In [8]:
encoded_test.shape

torch.Size([117248, 32])

In [9]:
outputs_val = list()
outputs_test = list()
with torch.no_grad():
    for codebook in model.codebooks:
        outputs_val.append(codebook.forward(encoded_val)[2].cpu().numpy())
        outputs_test.append(codebook.forward(encoded_test)[2].cpu().numpy())

In [10]:
import pandas as pd
import numpy as np

np.array(outputs_val).T

array([[0, 1, 0, ..., 0, 1, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [1, 0, 1, ..., 0, 0, 1],
       ...,
       [1, 0, 1, ..., 0, 1, 0],
       [0, 1, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 0, 1, 0]])

In [11]:
np.array(outputs_test).T

array([[1, 1, 0, ..., 0, 1, 0],
       [1, 1, 0, ..., 1, 1, 0],
       [0, 1, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 1, ..., 1, 1, 0],
       [0, 1, 1, ..., 0, 1, 0],
       [0, 1, 0, ..., 0, 1, 0]])

In [12]:
pd.concat([y_val,pd.DataFrame(np.array(outputs_val).T,columns=["VQ_{}".format(i) for i in range(1, 9)])],axis=1).to_csv(
    '/data/PycharmProjects/cytof_benchmark/results/summary/vqvae/latent_8bit_binary_val.csv'
)

pd.concat([y_test,pd.DataFrame(np.array(outputs_test).T,columns=["VQ_{}".format(i) for i in range(1, 9)])],axis=1).to_csv(
    '/data/PycharmProjects/cytof_benchmark/results/summary/vqvae/latent_8bit_binary_test.csv'
)

In [13]:
config = get_config()
config.n_layers=6
config.hidden_features=64
config.embed_entries=256
config.embed_channels=1
config.embed_dim=2

checkpoint_file = '/data/PycharmProjects/cytof_benchmark/logs/VQVAE/exp_8/run_48/model.pth'

In [14]:
model = VQVAE(config=config).to(config.device)
model_checkpoint = torch.load(checkpoint_file)
model.load_state_dict(model_checkpoint)

<All keys matched successfully>

In [15]:
model.eval()
encoded = model.encoder.forward(torch.Tensor(X_val).to('cuda'))
encoded

tensor([[-3.0288,  2.3241, -5.2383,  ..., -6.3466, -1.7437, -0.7820],
        [ 0.3396, -1.2505, -4.6614,  ..., -0.5607, -0.9361,  3.5800],
        [-0.9290, -0.9715, -1.2207,  ..., -0.3555,  1.9868,  1.2538],
        ...,
        [-4.4608, -1.2449, -1.6243,  ..., -0.6221,  2.8256,  3.1274],
        [ 0.5938,  1.8523, -5.2814,  ..., -1.0317, -0.2515,  0.8412],
        [-8.5829, -4.6227, -3.1082,  ..., -1.9564,  3.2567,  2.0712]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [16]:
encoded.shape

torch.Size([234495, 64])

In [55]:
model.codebooks

ModuleList(
  (0): CodeLayer(
    (linear_in): Linear(in_features=64, out_features=256, bias=True)
  )
)

In [54]:
quantize, diff, embed_ind = model.codebooks[0].forward(encoded)

In [69]:
pd.concat([y_val,pd.DataFrame(quantize.detach().cpu().numpy(),columns=['VQVAE1','VQVAE2'])],axis=1).to_csv(
    '/data/PycharmProjects/cytof_benchmark/results/summary/vqvae/latent_8bit_coords.csv'
)

In [67]:
pd.DataFrame(embed_ind.detach().cpu().numpy(),columns=['code']).drop_duplicates()

Unnamed: 0,code
0,88
1,6
2,193
3,19
4,66
5,180
7,67
8,18
9,119
12,52
