In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CategoricalFeatureTokenizer,
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    MultiheadAttention,
    NumericalFeatureTokenizer,
    Transformer,
)

from otc.data.dataset import TabDataset
from otc.data.dataloader import TabDataLoader
from otc.features.build_features import features_classical, features_classical_size

In [None]:
class CLSHead(nn.Module):
    """
    Token-wise two-layer MLP head that emits one logit per retained token.

    The first token along dim 1 is discarded; every remaining token is mapped
    ``d_in -> d_hidden -> 1`` and the trailing singleton dimension is removed.

    Args:
        d_in: size of the last dimension of the input, e.g. 96 for a
            transformer output of ``torch.Size([4, 17, 96])``.
        d_hidden: width of the hidden layer.
    """

    def __init__(self, *, d_in: int, d_hidden: int):
        super().__init__()
        self.first = nn.Linear(d_in, d_hidden)
        self.out = nn.Linear(d_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute one logit per token, skipping the token at index 0.

        Args:
            x: tensor indexed as ``(batch, tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, tokens - 1)``.
        """
        # NOTE(review): this drops the token at position 0. Presumably that is
        # meant to be the CLS token -- confirm where CLSToken inserts it.
        tokens = x[:, 1:]
        hidden = F.relu(self.first(tokens))
        logits = self.out(hidden)
        return logits.squeeze(2)

In [None]:
# Smoke check: build a head with arbitrary layer widths and print its layers.
head = CLSHead(d_in=768, d_hidden=768)

print(head)

In [None]:
class ShufflePermutations:
    """
    Generate permutations by shuffling.

    For every entry of a feature matrix a random row index is sampled
    (with replacement). The indices can later be used to replace an entry
    with the value found in the same column of the sampled row.

    Args:
        X_num: numerical features; may be ``None``.
        X_cat: categorical features; may be ``None``.
    """

    def __init__(self, X_num, X_cat):
        self.X_num = X_num
        self.X_cat = X_cat

    def permute(self, X):
        """
        Sample a random row index for every element of ``X``.

        Args:
            X: tensor whose first dimension indexes the rows, or ``None``.

        Returns:
            ``None`` if ``X`` is ``None``; otherwise a ``torch.long`` tensor
            with the shape of ``X`` and values in ``[0, X.shape[0])``.
        """
        if X is None:
            return None

        # only the shape/device of X is used; the upper bound is the row count
        return torch.randint_like(X, X.shape[0], dtype=torch.long)

    def gen_permutations(self):
        """
        Sample index tensors for the numerical and the categorical features.

        Returns:
            Tuple ``(idx_num, idx_cat)``; an element is ``None`` if the
            corresponding feature tensor is ``None``.
        """
        # numerical indices are drawn first, then the categorical ones
        return self.permute(self.X_num), self.permute(self.X_cat)

In [None]:
# Seed the global RNG so that the toy data, the permutations and every later
# random draw are reproducible under "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# number of numerical / categorical features and rows of the toy batch
d_num, d_cat = 8, 8
batch_size = 4

X_num = torch.randn(batch_size, d_num)
# categories are drawn from {0, ..., 9}
X_cat = torch.randint(0, 10, (batch_size, d_cat))

# sample, per entry, the row whose value may replace it
perm_class = ShufflePermutations(X_num, X_cat)
x_num_perm, x_cat_perm = perm_class.gen_permutations()

In [None]:
# Show the row indices sampled for the categorical features.
x_cat_perm

In [None]:
# default probability with which an entry is selected for corruption
corrupt_probability = 0.15

def gen_masks(X, perm, p=None):
    """
    Generate a boolean mask of the entries that get corrupted.

    An entry is marked if it is selected by a Bernoulli draw AND the value
    that would be swapped in differs from the current value, so marked
    entries always change.

    Args:
        X: 2-D feature tensor.
        perm: long tensor with the shape of ``X`` holding, per entry, the row
            index the replacement value is taken from (same column).
        p: selection probability; defaults to the module-level
            ``corrupt_probability`` when ``None``.

    Returns:
        Boolean tensor with the shape of ``X``.
    """
    if p is None:
        p = corrupt_probability

    # candidate positions, each chosen independently with probability p
    masks = torch.empty_like(X).bernoulli(p=p).bool()

    # value at (perm[i, j], j), i.e. what would be written to position (i, j)
    replacement = X[perm, torch.arange(X.shape[1], device=X.device)]

    # drop candidates whose replacement equals the current value
    return masks & (X != replacement)

# FIXME: probably generate for train and val set
x_num_mask = gen_masks(X_num, x_num_perm)
# categorical features are optional; without them there is nothing to mask
x_cat_mask = gen_masks(X_cat, x_cat_perm) if X_cat is not None else None

In [None]:
# Show the corruption mask of the numerical features.
x_num_mask

In [None]:
# Show the numerical features before the corruption is applied.
X_num

In [None]:
# get values at permuted places: entry (i, j) becomes X_num[x_num_perm[i, j], j]
x_num_permuted = torch.gather(X_num, 0, x_num_perm)

# replace at mask = True; torch.where builds a new tensor, so the tensor
# object still referenced by perm_class is not modified in place
X_num = torch.where(x_num_mask, x_num_permuted, X_num)

if X_cat is not None:

    # along the 0 axis get elements based on perm_cat
    x_cat_permuted = torch.gather(X_cat, 0, x_cat_perm)

    # replace at mask
    X_cat = torch.where(x_cat_mask, x_cat_permuted, X_cat)

In [None]:
# Show the mask again, for comparison with the corrupted features.
x_num_mask

In [None]:
# Show the numerical features after the masked entries were replaced.
X_num

In [None]:
# width of a single token, shared by tokenizer, backbone and head
d_token = 96

# one token per numerical / categorical feature
params_feature_tokenizer = {
    "num_continous": d_num,
    "cat_cardinalities": [11] * d_cat,
    "d_token": d_token,
}

feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)

params_transformer = {
    "d_token": d_token,
    "n_blocks": 3,
    "attention_n_heads": 8,
    "attention_initialization": "kaiming",
    "ffn_activation": ReGLU,
    "attention_normalization": nn.LayerNorm,
    "ffn_normalization": nn.LayerNorm,
    "ffn_dropout": 0.1,
    # derived from d_token instead of repeating the literal 96
    "ffn_d_hidden": d_token * 2,
    "attention_dropout": 0.1,
    "residual_dropout": 0.1,
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
    "head_activation": nn.ReLU,
    "head_normalization": nn.LayerNorm,
    "d_out": 1,
}

transformer = Transformer(**params_transformer)

In [None]:
# default hidden width of the pre-training head
d_hidden = 32

class PretrainModel(nn.Module):
    """
    Tokenizer + CLS token + transformer backbone with an RTD-like head.

    Every argument is optional and falls back to the notebook-level object of
    the same purpose, so ``PretrainModel()`` keeps working.

    Args:
        tokenizer: module mapping ``(x_num, x_cat)`` to tokens; defaults to
            the notebook-level ``feature_tokenizer``.
        backbone: transformer applied to the tokens; defaults to the
            notebook-level ``transformer``. NOTE: its ``head`` attribute is
            replaced by ``nn.Identity`` in place, and the object is shared by
            every model built from it (including device moves).
        token_dim: token width; defaults to the notebook-level ``d_token``.
        head_hidden: hidden width of the head; defaults to the notebook-level
            ``d_hidden``.
    """

    def __init__(self, tokenizer=None, backbone=None, *, token_dim=None, head_hidden=None):
        super().__init__()

        token_dim = d_token if token_dim is None else token_dim
        head_hidden = d_hidden if head_hidden is None else head_hidden

        self.cls_token = CLSToken(token_dim, "uniform")

        self.feature_tokenizer = feature_tokenizer if tokenizer is None else tokenizer
        self.transformer = transformer if backbone is None else backbone

        # disable BERT-like classification head and replace with identity mapping
        self.transformer.head = nn.Identity()

        # enable RTD-like head with one class per feature
        self.head = CLSHead(
            d_in=token_dim,
            d_hidden=head_hidden,
        )

    def forward(self, x_num, x_cat):
        """
        Compute the head's logits for a batch of features.

        Args:
            x_num: numerical features passed to the tokenizer.
            x_cat: categorical features passed to the tokenizer.

        Returns:
            Output of ``CLSHead`` applied to the backbone's output.
        """
        # tokenize
        x = self.feature_tokenizer(x_num, x_cat)

        # add cls token to input
        x = self.cls_token(x)

        # add backbone
        h = self.transformer(x)

        # add classification head
        return self.head(h)

In [None]:
# Show the module structure of the transformer backbone.
transformer

In [None]:
# Forward the toy batch through a freshly built model.
model = PretrainModel()
predictions = model(X_num, X_cat)
# presumably one logit per feature, i.e. (batch_size, d_num + d_cat) -- TODO confirm
print(predictions.shape)

In [None]:
# move to device; fall back to the CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

model = PretrainModel().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# the head emits raw logits, hence the loss with built-in sigmoid
criterion = nn.BCEWithLogitsLoss()


In [None]:
# cat masks if cat variables are present
if X_cat is not None:
    masks = torch.cat([x_num_mask, x_cat_mask], dim=1)
else:
    masks = x_num_mask

# logits to binary mask: a positive logit corresponds to sigmoid(logit) > 0.5
hard_predictions = (predictions > 0).long()

# calculate column-wise accuracy: share of rows where prediction and mask agree
features_accuracy = (hard_predictions.bool() == masks).sum(0) / hard_predictions.shape[0]

print(masks)
print(hard_predictions)
print(features_accuracy)