In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as nn_init
import torch.nn.functional as F
from torch import Tensor

import typing as ty
import math

In [4]:

class Tokenizer(nn.Module):
    """Turn numerical and categorical features into one sequence of tokens.

    Each sample yields a [CLS] token, one token per numerical feature (the
    feature value multiplied by a learned vector) and one embedding per
    categorical feature. When enabled, a learned bias is added to every token
    except [CLS].

    Args:
        d_numerical: number of numerical features.
        categories: list with the cardinality of each categorical feature, or
            ``None`` when there are no categorical features.
        d_token: length of each token vector.
        bias: whether to learn a per-feature bias.
    """

    def __init__(self, d_numerical, categories, d_token, bias):
        super().__init__()
        if categories is not None:
            n_biased_tokens = d_numerical + len(categories)
            # Start row of every categorical feature inside the shared embedding table.
            offsets = torch.cumsum(torch.tensor([0] + categories[:-1]), dim=0)
            self.register_buffer('category_offsets', offsets)
            self.category_embeddings = nn.Embedding(sum(categories), d_token)
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
            print(f'{self.category_embeddings.weight.shape=}')
        else:
            n_biased_tokens = d_numerical
            self.category_offsets = None
            self.category_embeddings = None

        # One extra weight row for the [CLS] token.
        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
        if bias:
            self.bias = nn.Parameter(Tensor(n_biased_tokens, d_token))
        else:
            self.bias = None

        # Initialisation follows the scheme used by nn.Linear.
        for param in (self.weight, self.bias):
            if param is not None:
                nn_init.kaiming_uniform_(param, a=math.sqrt(5))

    @property
    def n_tokens(self):
        """Number of tokens per sample: rows of ``weight`` plus categorical offsets."""
        if self.category_offsets is None:
            n_categorical = 0
        else:
            n_categorical = len(self.category_offsets)
        return len(self.weight) + n_categorical

    def forward(self, x_num, x_cat):
        """Tokenize one batch.

        Args:
            x_num: numerical features, or ``None``
                (assumed to be (batch, d_numerical) -- TODO confirm with callers).
            x_cat: integer category indices, or ``None``
                (assumed to be (batch, n_categorical) -- TODO confirm with callers).

        Returns:
            Tensor with the [CLS] token first, then the numerical tokens, then
            the categorical embeddings, concatenated along dim 1.

        Raises:
            AssertionError: if both ``x_num`` and ``x_cat`` are ``None``.
        """
        # Whichever input is present supplies the batch size and device.
        if x_cat is None:
            reference = x_num
        else:
            reference = x_cat
        assert reference is not None

        # A constant 1 in front of the numerical values selects the [CLS] weight row.
        scalar_parts = [torch.ones(len(reference), 1, device=reference.device)]
        if x_num is not None:
            scalar_parts.append(x_num)
        scalars = torch.cat(scalar_parts, dim=1)

        tokens = self.weight.unsqueeze(0) * scalars.unsqueeze(2)

        if x_cat is not None:
            # Shift each column into its own slice of the shared embedding table.
            shifted_indices = x_cat + self.category_offsets.unsqueeze(0)
            tokens = torch.cat([tokens, self.category_embeddings(shifted_indices)], dim=1)

        if self.bias is not None:
            # The zero row leaves the [CLS] token without a bias.
            cls_bias = torch.zeros(1, self.bias.size(1), device=tokens.device)
            full_bias = torch.cat([cls_bias, self.bias])
            tokens = tokens + full_bias.unsqueeze(0)

        return tokens



In [8]:
# Directory holding the preprocessed `adult_cond` arrays. It is defined once so
# the notebook can be pointed at another copy of the data by editing one line.
DATA_DIR = '/mnt/nas/swethamagesh/tabsyn-fresh/tabsyn/data/adult_cond'

# NOTE(review): allow_pickle=True makes np.load unpickle object arrays, which can
# execute arbitrary code -- only load files that come from a trusted source.
cat_train = np.load(f'{DATA_DIR}/X_cat_train.npy', allow_pickle=True)
num_train = np.load(f'{DATA_DIR}/X_num_train.npy')
target = np.load(f'{DATA_DIR}/y_train.npy', allow_pickle=True)

In [12]:
# Join along axis 1 so the target column(s) come before the categorical features.
final_cat_train = np.concatenate((target, cat_train), axis=1)

In [13]:
# Keyword arguments make the four settings readable: 5 numerical features, nine
# categorical cardinalities, 4-dimensional tokens, and the bias enabled.
tokenizer = Tokenizer(
    d_numerical=5,
    categories=[3, 7, 16, 5, 7, 5, 3, 2, 2],
    d_token=4,
    bias=True,
)

self.category_embeddings.weight.shape=torch.Size([50, 4])


In [19]:
# Display the combined array (target joined in front of the categorical columns).
final_cat_train

array([[0, 4, 10, ..., 0, 1, 1],
       [0, 4, 9, ..., 0, 0, 1],
       [0, 1, 9, ..., 0, 0, 1],
       ...,
       [0, 5, 9, ..., 0, 0, 1],
       [0, 4, 10, ..., 0, 0, 1],
       [0, 4, 10, ..., 0, 0, 1]], dtype=object)

In [20]:
# Scratch check: casting to float produces a float64 tensor. The tokenizer call
# further down casts to int instead, because the values are used as embedding indices.
torch.from_numpy(final_cat_train.astype(float))

tensor([[ 0.,  4., 10.,  ...,  0.,  1.,  1.],
        [ 0.,  4.,  9.,  ...,  0.,  0.,  1.],
        [ 0.,  1.,  9.,  ...,  0.,  0.,  1.],
        ...,
        [ 0.,  5.,  9.,  ...,  0.,  0.,  1.],
        [ 0.,  4., 10.,  ...,  0.,  0.,  1.],
        [ 0.,  4., 10.,  ...,  0.,  0.,  1.]], dtype=torch.float64)

In [24]:
# Convert both arrays to tensors first; the object-dtype categorical array is
# cast to int so its values can serve as embedding indices.
x_num_tensor = torch.from_numpy(num_train)
x_cat_tensor = torch.from_numpy(final_cat_train.astype(int))
tokenized_out = tokenizer(x_num_tensor, x_cat_tensor)
tokenized_out.shape

torch.Size([31062, 15, 4])

In [None]:
# 30k 
#  get tokenized output
# (generate constraint for the entire set at once) - [111100011] - c
# constraint for row - 0  & 1 - 0.4  - masked from tokenized output cxd 


'''
Model load from f'{ckpt_dir}/model.pt'

'''



In [25]:
# NOTE(review): torch.load relies on pickle; only open checkpoints from a trusted source.
VAE_CKPT_PATH = '/mnt/nas/swethamagesh/tabsyn-fresh/tabsyn/tabsyn/vae/ckpt/adult_cond/model.pt'
model_vae = torch.load(VAE_CKPT_PATH)

In [30]:
# List the name of every entry stored in the checkpoint.
for state_key in model_vae:
    print(state_key)

VAE.Tokenizer.weight
VAE.Tokenizer.bias
VAE.Tokenizer.category_offsets
VAE.Tokenizer.category_embeddings.weight
VAE.encoder_mu.layers.0.attention.W_q.weight
VAE.encoder_mu.layers.0.attention.W_q.bias
VAE.encoder_mu.layers.0.attention.W_k.weight
VAE.encoder_mu.layers.0.attention.W_k.bias
VAE.encoder_mu.layers.0.attention.W_v.weight
VAE.encoder_mu.layers.0.attention.W_v.bias
VAE.encoder_mu.layers.0.linear0.weight
VAE.encoder_mu.layers.0.linear0.bias
VAE.encoder_mu.layers.0.linear1.weight
VAE.encoder_mu.layers.0.linear1.bias
VAE.encoder_mu.layers.0.norm1.weight
VAE.encoder_mu.layers.0.norm1.bias
VAE.encoder_mu.layers.1.attention.W_q.weight
VAE.encoder_mu.layers.1.attention.W_q.bias
VAE.encoder_mu.layers.1.attention.W_k.weight
VAE.encoder_mu.layers.1.attention.W_k.bias
VAE.encoder_mu.layers.1.attention.W_v.weight
VAE.encoder_mu.layers.1.attention.W_v.bias
VAE.encoder_mu.layers.1.linear0.weight
VAE.encoder_mu.layers.1.linear0.bias
VAE.encoder_mu.layers.1.linear1.weight
VAE.encoder_mu.layers

In [26]:
# `model_vae` is a flat state dict (an OrderedDict keyed by dotted names), so the
# tokenizer's entries are selected by key prefix instead of attribute access.
TOKENIZER_PREFIX = 'VAE.Tokenizer.'
tokenizer_state = {
    key[len(TOKENIZER_PREFIX):]: value
    for key, value in model_vae.items()
    if key.startswith(TOKENIZER_PREFIX)
}

# Loading is strict: it raises if the keys or tensor shapes do not match the
# `tokenizer` built above.
# NOTE(review): confirm the checkpoint was trained with the same Tokenizer settings.
tokenizer.load_state_dict(tokenizer_state)