In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CategoricalFeatureTokenizer,
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    MultiheadAttention,
    NumericalFeatureTokenizer,
    Transformer,
)

In [None]:
class CLSHead(nn.Module):
    """
    Two-layer MLP head that maps each retained token to a single scalar.

    The first position along dim 1 is discarded before the MLP is applied
    (presumably the CLS token -- TODO confirm where ``CLSToken`` inserts it).
    """

    def __init__(self, *, d_in: int, d_hidden: int):
        """
        Args:
            d_in: size of the last dimension of the incoming tokens.
            d_hidden: width of the hidden layer.
        """
        super().__init__()
        # Layers are created in the order first -> out.
        self.first = nn.Linear(d_in, d_hidden)
        self.out = nn.Linear(d_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor expected to be 3-D, ``(batch, n_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, n_tokens - 1)`` holding one score per
            retained token.
        """
        # Skip index 0 along the token dimension.
        tokens = x[:, 1:]
        hidden = F.relu(self.first(tokens))
        scores = self.out(hidden)
        # The output layer has width 1; drop that trailing dimension.
        return scores.squeeze(2)

In [None]:
# Build a head with 768-dimensional input and hidden layers and print its
# layer layout as a quick sanity check.
head = CLSHead(d_in=768, d_hidden=768)

print(head)

In [None]:
class ShufflePermutations:
    """
    Sample random donor-row indices used to corrupt features by shuffling.

    For every entry of a feature tensor a row index is drawn uniformly from
    ``[0, n_rows)``, independently per entry. The indices can then be used to
    gather replacement values column-wise, e.g.
    ``X[idx, torch.arange(X.shape[1])]``.
    """

    def __init__(self, X_num, X_cat):
        """
        Args:
            X_num: numerical feature tensor, or None if there is none.
            X_cat: categorical feature tensor, or None if there is none.
        """
        self.X_num = X_num
        self.X_cat = X_cat

    def permute(self, X):
        """
        Draw one random row index per entry of ``X``.

        Args:
            X: tensor whose first dimension indexes rows, or None.

        Returns:
            A ``torch.long`` tensor with the same shape as ``X`` and values in
            ``[0, X.shape[0])``, or None when ``X`` is None.
        """
        if X is None:
            return None

        # Indices are sampled with replacement, so a column of the result is
        # not necessarily a strict permutation of the rows.
        return torch.randint_like(X, X.shape[0], dtype=torch.long)

    def gen_permutations(self):
        """
        Sample indices for the numerical and the categorical features.

        Returns:
            Tuple ``(idx_num, idx_cat)``; an element is None when the
            corresponding feature tensor is None.
        """
        # Numerical indices are drawn first, then categorical ones.
        return self.permute(self.X_num), self.permute(self.X_cat)

In [None]:
# Toy batch for experimenting: 4 rows with 8 numerical and 8 categorical columns.
d_num, d_cat = 8,8
batch_size = 4

# Numerical features are standard-normal draws; categorical codes are integers in [0, 10).
# NOTE(review): no random seed is set in the visible cells, so these values
# differ on every run.
X_num = torch.randn(batch_size, d_num)
X_cat = torch.randint(0, 10, (batch_size, d_cat))

# Sample, for every entry, the index of the row a replacement value would come from.
perm_class = ShufflePermutations(X_num, X_cat)
x_num_perm, x_cat_perm = perm_class.gen_permutations()

In [None]:
# Inspect the sampled row indices for the categorical features.
x_cat_perm

In [None]:
# Default probability with which an individual entry is selected for corruption.
corrupt_probability = 0.15


def gen_masks(X, perm, p=None):
    """
    Build a boolean mask of the entries that get corrupted.

    An entry is marked only if it was selected by a Bernoulli draw AND the
    replacement value taken from the donor row differs from the current value,
    so marked entries always correspond to an actual change.

    Args:
        X: feature tensor of shape ``(n_rows, n_features)``, or None when this
            feature type is absent.
        perm: integer tensor shaped like ``X`` holding a donor row index per
            entry, or None.
        p: selection probability per entry. Defaults to the module-level
            ``corrupt_probability``, read at call time.

    Returns:
        Boolean tensor shaped like ``X``, or None if ``X`` or ``perm`` is None.
    """
    # Pass absent feature blocks through instead of failing on them.
    if X is None or perm is None:
        return None
    if p is None:
        p = corrupt_probability

    # Candidate entries, selected independently with probability p.
    masks = torch.empty_like(X).bernoulli(p=p).bool()
    # Value each entry would receive: same column, donor row given by perm.
    replacement = X[perm, torch.arange(X.shape[1], device=X.device)]
    # Keep only candidates whose value would really change.
    return masks & (X != replacement)

# FIXME: probably generate for train and val set
# Build one boolean mask per feature type from the row indices sampled earlier.
x_num_mask = gen_masks(X_num, x_num_perm)
x_cat_mask = gen_masks(X_cat, x_cat_perm)

In [None]:
# Inspect the mask for the numerical features.
x_num_mask

In [None]:
# Inspect the mask for the categorical features.
x_cat_mask

In [None]:
# Keyword arguments forwarded to FeatureTokenizer.
# NOTE(review): "num_continous" is 3 and "cat_cardinalities" is None, whereas
# the toy batch defined earlier has 8 numerical and 8 categorical columns --
# confirm these are meant to match before passing X_num / X_cat through the model.
params_feature_tokenizer = {
            "num_continous": 3,
            "cat_cardinalities": None,
            "d_token": 96,
        }

feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)

# Keyword arguments forwarded to Transformer.
# NOTE(review): "d_token" is 96 here but the model cell below uses d_token = 16
# for the CLS token and the head -- presumably these widths have to agree; verify.
params_transformer = {
            "d_token": 96,
            "n_blocks": 3,
            "attention_n_heads": 8,
            "attention_initialization": "kaiming",
            "ffn_activation": ReGLU,
            "attention_normalization": nn.LayerNorm,
            "ffn_normalization": nn.LayerNorm,
            "ffn_dropout": 0.1,
            # twice the token width
            "ffn_d_hidden": 96 * 2,
            "attention_dropout": 0.1,
            "residual_dropout": 0.1,
            "prenormalization": True,
            "first_prenormalization": False,
            "last_layer_query_idx": None,
            "n_tokens": None,
            "kv_compression_ratio": None,
            "kv_compression_sharing": None,
            "head_activation": nn.ReLU,
            "head_normalization": nn.LayerNorm,
            "d_out": 1,
        }

# Single module instance; it is reused by every PretrainModel created below.
transformer = Transformer(**params_transformer)

In [None]:
# Width of the CLS token and of the head input, and the head's hidden width.
# NOTE(review): the tokenizer / transformer settings use d_token = 96 --
# presumably these have to agree with the value here; confirm.
d_token = 16
d_hidden = 32


class PretrainModel(nn.Module):
    """
    Pretraining model: feature tokenizer -> CLS token -> backbone -> CLSHead.

    The tokenizer and the backbone are the notebook-level ``feature_tokenizer``
    and ``transformer`` instances, so every ``PretrainModel`` shares those
    sub-modules and their parameters; moving one model with ``.to(device)``
    therefore moves them for all models.
    """

    def __init__(self):
        super().__init__()

        # TODO: build tokenizer and backbone from a configuration instead of
        # picking up notebook globals.
        self.feature_tokenizer = feature_tokenizer

        self.cls_token = CLSToken(d_token, "uniform")

        # change later
        self.backbone = transformer

        self.head = CLSHead(d_in=d_token, d_hidden=d_hidden)

    def forward(self, x_num, x_cat):
        """
        Run one forward pass.

        Args:
            x_num: numerical features, passed to the feature tokenizer.
            x_cat: categorical features, passed to the feature tokenizer.

        Returns:
            Output of the ``CLSHead`` applied to the backbone output.
        """
        # tokenize
        tokens = self.feature_tokenizer(x_num, x_cat)
        # add cls token to input
        tokens = self.cls_token(tokens)

        # add backbone
        hidden = self.backbone(tokens)

        # add classification head
        return self.head(hidden)

In [None]:
# Smoke test: run a single forward pass on the toy batch; the result is shown
# as the cell output.
# NOTE(review): the tokenizer was configured for 3 numerical and no categorical
# features while the toy batch has 8 of each -- this call may fail; confirm.
model = PretrainModel()
model(X_num, X_cat)

In [None]:
# move to device
# Use the GPU when one is present and fall back to the CPU otherwise, so the
# cell also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = PretrainModel().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# TODO: load the pretraining data here (no `data` object is defined in this
# notebook yet, so referencing it would raise NameError).

