#### TabNet - Feature Selection
1. Masking 방법을 사용해서 입력(sample)마다 어떤 feature를 선택할지 결정하는 것을 학습한다. 기존에는 ML algorithm을 사용하면 데이터 전처리, feature selection, modeling 등의 과정을 모두 따로 진행했지만, 이 딥러닝 모델 하나만으로 이 과정들을 end-to-end로 한 번에 처리할 수 있다.
2. Mask는 데이터 전처리 단계가 아니라 모델 내부(Attentive Transformer)에서 각 step마다 계산된다. 예를 들어 feature 1만 선택한다면 feature 1에 해당하는 mask 값만 1이 되고 나머지 feature는 0이 된다. (실제 mask는 sparsemax로 계산되어 0과 1 사이의 희소한 값을 가진다.)

#### TabNet - Feature Transformer
1. 앞선 feature selection 단계에서 masking 처리되어 전달된 feature를 변환(embedding)해 준다.
2. Feature Transformer = Fully Connected Layer + Ghost Batch Norm + GLU(Gated Linear Unit)


In [None]:
## Ghost Batch Normalization
"""
Batch를 다시 더 작은 batch로 분할해서 잡음을 momentum을 사용해 추가해주는 방법으로 generalization 성능을 향상시켜 준다.
이때 내부적으로 nano batch를 사용하기 때문에 학습 시에는 large batch를 사용할 수 있어서 학습 속도가 빨라진다.
"""
import torch
import torch.nn as nn

class GBN(nn.Module):
    """Ghost Batch Normalization.

    Splits every incoming batch into "virtual" (ghost) batches of at most
    ``batch_size`` rows and normalizes each of them independently with one
    shared ``nn.BatchNorm1d``.

    Args:
        input_dim: Number of features (``num_features`` of the BatchNorm).
        batch_size: Maximum number of rows in each virtual batch.
        momentum: Momentum used for the running mean/variance estimates.
    """

    def __init__(self, input_dim, batch_size, momentum):
        super().__init__()
        self.input_dim = input_dim
        self.batch_size = batch_size
        self.momentum = momentum
        # Pass momentum by keyword: the second positional parameter of
        # BatchNorm1d is ``eps``, so passing it positionally set eps=momentum
        # and left momentum at its default.
        self.bn = nn.BatchNorm1d(self.input_dim, momentum=momentum)

    def forward(self, x):
        """Normalize ``x`` (split along dim 0) one virtual batch at a time."""
        # ceil(N / batch_size) chunks, so every chunk has <= batch_size rows.
        # max(1, ...) keeps chunk() valid when the batch is empty.
        n_chunks = max(1, -(-x.shape[0] // self.batch_size))
        chunks = x.chunk(n_chunks, 0)
        results = [self.bn(c) for c in chunks]
        return torch.cat(results, dim=0)
    

In [None]:
## TabNet Encoder
from dataclasses import dataclass

@dataclass
class Encoder(nn.Module):
    input_dim: int
    output_dim: list[int]
    n_d: int # prediction layer의 차원 수
    n_a: int # attention layer의 차원 수
    n_steps: int
    gamma: float # 1.0과 2.0사이의 값으로, attention update를 위한 scaling factor이다.
    n_glu_layer: int
    epsilon: float
    batch_size: int
    momentum: float
    
    if (self.n_glu_layer > 0):
        shared_feat_transform = nn.ModuleList()
        for i in range(self.n_glu_layer):
            if i == 0:
                shared_feat_transform.append(nn.Linear(self.input_dim, 2 * (self.n_d + self.n_a), bias = False))
            else:
                shared_feat_transform.append(nn.Linear(self.n_d + self.n_a , 2 * (self.n_d + self.n_a), bias = False))
    else:
        shared_feat_transform = None
    
    self.feature_transformers = nn.ModuleList()
    self.attention_transformers = nn.ModuleList()
    
        
    

#### TabNet - Attentive Transformer
1. 각 step에서 어떤 feature를 사용할지 고르는 mask를 만드는 부분으로, FC + Ghost Batch Norm + Prior Scale + Sparsemax로 구성된다.

#### TabNet - Decoder
1. Encoder-Decoder 형태의 Auto Encoder처럼 unsupervised learning이 가능하다.
2. 입력 feature의 일부를 masking한 뒤, encoder의 출력으로부터 가려진 feature 값을 원래대로 복원하도록 학습한다. 이는 자연어 처리에서 BERT 같은 모델을 학습시킬 때 사용하는 Masked Language Modeling과 유사하다.
3. 이렇게 사전학습(pre-training)을 해 두면 이후 지도 학습에서 예측 정확도를 높이고 학습(수렴) 속도를 빠르게 할 수 있다.

In [None]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class SiDConfig:
    """Hyper-parameters of the SiD tabular model and its derived sizes.

    Only the feature-layout fields are required; every architecture and
    regularization setting has a default.
    """

    # Cardinality of each categorical feature, one entry per feature.
    num_categories_list: list[int]
    # Count of scalar numerical features.
    num_numerical_features: int
    # Count of text features, each given as a pre-computed vector.
    num_text_features: int
    # Length of every text feature vector.
    text_input_size: int
    # Width of the embedding produced for each individual feature.
    embedding_size: int = 64
    # Width of the hidden state carried between layers.
    hidden_size: int = 256
    # Inner width of the feed-forward blocks.
    intermediate_size: int = 1024
    # Number of stacked layers.
    num_hidden_layers: int = 18
    # Residual blocks per transform stack.
    num_transform_blocks: int = 1
    # Residual blocks per attention stack.
    num_attention_blocks: int = 1
    hidden_dropout_prob: float = 0.5
    attention_dropout_prob: float = 0.5
    drop_path_prob: float = 0.5
    # Standard deviation for the normal initialization of embeddings.
    embed_init_std: float = 0.02
    # Number of output classes; None when no classification head is built.
    num_labels: Optional[int] = None

    @property
    def num_total_categories(self) -> int:
        """Total size of the shared categorical vocabulary."""
        total = 0
        for cardinality in self.num_categories_list:
            total += cardinality
        return total

    @property
    def num_categorical_features(self) -> int:
        """How many categorical features there are."""
        return len(self.num_categories_list)

    @property
    def num_total_features(self) -> int:
        """Categorical + numerical + text feature count."""
        counts = (
            self.num_categorical_features,
            self.num_numerical_features,
            self.num_text_features,
        )
        return sum(counts)

    @property
    def total_embedding_size(self) -> int:
        """Length of all feature embeddings laid end to end."""
        return self.embedding_size * self.num_total_features


class StochasticDepth(nn.Module):
    """Randomly drop the whole branch output per sample (drop path).

    During training, each sample (index along dim 0) is zeroed with
    probability ``drop_path_prob``. Surviving samples are scaled by
    ``1 / (1 - drop_path_prob)`` so the expected output equals the input. In
    eval mode, or when ``drop_path_prob == 0``, the input is returned
    unchanged.

    Args:
        drop_path_prob: Probability of dropping a sample.
    """

    def __init__(self, drop_path_prob: float = 0.1):
        super().__init__()
        self.drop_path_prob = drop_path_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.drop_path_prob == 0 or not self.training:
            return hidden_states

        keep_prob = 1.0 - self.drop_path_prob
        if keep_prob <= 0:
            # Every sample is dropped; avoid dividing by zero.
            return torch.zeros_like(hidden_states)

        # One draw per sample, broadcast over all remaining dimensions
        # (shape (B, 1) for 2-D inputs, (B, 1, 1) for 3-D, ...).
        mask_shape = (hidden_states.size(0),) + (1,) * (hidden_states.dim() - 1)
        mask = torch.rand(mask_shape, device=hidden_states.device)
        # Rescale by the keep probability (not the drop probability) so the
        # expectation is preserved.
        mask = (mask > self.drop_path_prob).type_as(hidden_states) / keep_prob
        return mask * hidden_states


class SiDEmbeddings(nn.Module):
    """Embed categorical, numerical and text features into one feature sequence.

    Each feature becomes a single ``config.embedding_size`` vector:

    * categorical features look up one shared ``nn.Embedding`` table, shifted
      by per-feature offsets so the features use disjoint index ranges;
    * numerical features map to ``value * direction + anchor``, with a
      learned direction and anchor vector per feature;
    * text features (vectors of length ``config.text_input_size``) are
      projected by a per-feature bias-free linear map.
    """

    def __init__(self, config: SiDConfig):
        super().__init__()
        self.categorical_embeddings = nn.Embedding(
            config.num_total_categories, config.embedding_size
        )
        self.numerical_direction = nn.Parameter(
            torch.rand(config.num_numerical_features, config.embedding_size)
        )
        self.numerical_anchor = nn.Parameter(
            torch.rand(config.num_numerical_features, config.embedding_size)
        )

        # Although we define the multiple dense layers to project each text embedding to
        # the input embedding space, we will use batched (stacked) matmul by gathering
        # the weight matrices.
        self.text_projections = nn.ModuleList(
            nn.Linear(config.text_input_size, config.embedding_size, bias=False)
            for _ in range(config.num_text_features)
        )

        # Create embedding offsets which indicate the start embedding index of each
        # categorical feature. Because this class uses only one embedding layer which
        # contains the embedding vectors for all categorical features, it is necessary
        # to separate each embedding group from other features.
        self.register_buffer(
            "categorical_embedding_offsets",
            torch.tensor([[0] + config.num_categories_list[:-1]]).cumsum(1),
        )

    def forward(
        self,
        categorical_inputs: torch.Tensor,
        numerical_inputs: torch.Tensor,
        text_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return the concatenated feature embeddings.

        Args:
            categorical_inputs: Integer ids of shape
                ``(batch, num_categorical_features)``. Ids are counted within
                each feature; the per-feature offsets are added here.
            numerical_inputs: Values of shape ``(batch, num_numerical_features)``.
            text_inputs: Vectors of shape
                ``(batch, num_text_features, text_input_size)``.

        Returns:
            Tensor of shape ``(batch, num_total_features, embedding_size)``,
            ordered categorical, then numerical, then text.
        """
        # Add embedding offsets to the categorical features to map to the corresponding
        # embedding groups.
        categorical_inputs = categorical_inputs + self.categorical_embedding_offsets
        categorical_embeddings = self.categorical_embeddings(categorical_inputs)

        numerical_embeddings = numerical_inputs[:, :, None] * self.numerical_direction
        numerical_embeddings = numerical_embeddings + self.numerical_anchor

        if len(self.text_projections) > 0:
            stacked_weight = torch.stack(
                [layer.weight.transpose(0, 1) for layer in self.text_projections],
            )
            text_embeddings = torch.einsum("btm,tmn->btn", text_inputs, stacked_weight)
        else:
            # torch.stack rejects an empty list, so a config without text
            # features would crash; contribute an empty (batch, 0, E) slice.
            text_embeddings = numerical_embeddings.new_zeros(
                (numerical_embeddings.size(0), 0, numerical_embeddings.size(2))
            )

        # After creating embedding vectors for categorical, numerical and text features,
        # they will be concatenated to a single tensor.
        return torch.cat(
            (categorical_embeddings, numerical_embeddings, text_embeddings), dim=1
        )


class SiDResidualBlock(nn.Module):
    """Gated feed-forward block with a residual connection.

    ``x -> Linear(hidden, intermediate) -> GELU -> Linear(intermediate, 2 * hidden)``.
    The result is split into a value half and a gate half, combined as
    ``value * sigmoid(gate)``, passed through dropout, and added back to ``x``.
    """

    def __init__(self, config: SiDConfig):
        super().__init__()
        self.feedforward = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size * 2),
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = self.feedforward(hidden_states)

        # Split along the last dim, the one nn.Linear produced. This matches
        # dim=1 for 2-D inputs and stays correct for inputs with extra
        # leading dimensions.
        output, gating = output.chunk(2, dim=-1)
        output = output * gating.sigmoid()

        return hidden_states + self.dropout(output)


class SiDLayer(nn.Module):
    """One SiD layer: optional feature attention, projection, residual transform.

    Args:
        config: Model hyper-parameters.
        use_attention: When True, the layer builds an attention head. The head
            maps the previous hidden state to one sigmoid weight per input
            feature and rescales that feature's embedding before projection.
    """

    def __init__(self, config: SiDConfig, use_attention: bool = True):
        super().__init__()
        if use_attention:
            self.attention = nn.Sequential(
                *[SiDResidualBlock(config) for _ in range(config.num_attention_blocks)],
                nn.Linear(config.hidden_size, config.num_total_features),
                nn.Sigmoid(),
            )
            self.dropout = nn.Dropout(config.attention_dropout_prob)

        # Maps the flattened feature embeddings to the hidden size.
        self.projection = nn.Linear(config.total_embedding_size, config.hidden_size)
        self.transform = nn.Sequential(
            *[SiDResidualBlock(config) for _ in range(config.num_transform_blocks)]
        )
        self.droppath = StochasticDepth(config.drop_path_prob)

    def forward(
        self,
        input_embeddings: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the next hidden state.

        Args:
            input_embeddings: Per-feature embeddings; indexed as
                ``(batch, num_features, embedding_size)`` by the attention
                weighting and flattened from dim 1 before projection.
            hidden_states: Output of the previous layer, or None for the
                first layer.

        Returns:
            A single tensor: ``hidden_states + droppath(output)`` when
            ``hidden_states`` is given, otherwise the transformed projection.
        """
        # Calculate the attention probabilities and multiply to the embeddings.
        # (The attention head only exists when built with use_attention=True.)
        if hasattr(self, "attention") and hidden_states is not None:
            attention_probs = self.attention(hidden_states)
            attention_probs = self.dropout(attention_probs)
            input_embeddings = input_embeddings * attention_probs[:, :, None]

        output = self.projection(input_embeddings.flatten(1))
        output = self.transform(output)

        # If `hidden_states` is not None then use residual connection.
        if hidden_states is not None:
            return hidden_states + self.droppath(output)
        return output


class SiDModel(nn.Module):
    """SiD backbone: feature embeddings, a stack of SiDLayers, final LayerNorm."""

    def __init__(self, config: SiDConfig):
        super().__init__()
        self.config = config
        self.embeddings = SiDEmbeddings(config)
        # The first layer has no previous hidden state to attend from, so only
        # the later layers are built with attention.
        self.layers = nn.ModuleList(
            [
                SiDLayer(config, use_attention=(index != 0))
                for index in range(config.num_hidden_layers)
            ]
        )
        self.normalization = nn.LayerNorm(config.hidden_size)
        self.init_weights()

    def init_weights(self, module: Optional[nn.Module] = None):
        """Initialize parameters.

        Called without arguments, this method passes itself to ``self.apply``
        so it runs once per submodule. Called with a module, it initializes
        that module according to its type.
        """
        if module is None:
            self.apply(self.init_weights)
            return

        std = self.config.embed_init_std
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=std)
        elif isinstance(module, SiDEmbeddings):
            for param in (module.numerical_direction, module.numerical_anchor):
                nn.init.normal_(param, std=std)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, 5 ** 0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(
        self,
        categorical_inputs: torch.Tensor,
        numerical_inputs: torch.Tensor,
        text_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer in order, and layer-normalize the result."""
        embedded = self.embeddings(categorical_inputs, numerical_inputs, text_inputs)

        state = None
        for sid_layer in self.layers:
            state = sid_layer(embedded, state)

        return self.normalization(state)


class SiDClassifier(nn.Module):
    """SiD backbone followed by a linear classification head.

    Args:
        config: Model hyper-parameters; ``config.num_labels`` must be set.

    Raises:
        ValueError: If ``config.num_labels`` is None (the config default).
    """

    def __init__(self, config: SiDConfig):
        super().__init__()
        # Fail early with a clear message: nn.Linear(hidden, None) would
        # otherwise raise an opaque TypeError after building the backbone.
        if config.num_labels is None:
            raise ValueError("SiDClassifier requires `config.num_labels` to be set.")
        self.model = SiDModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        categorical_inputs: torch.Tensor,
        numerical_inputs: torch.Tensor,
        text_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return unnormalized class logits (last dim of size ``num_labels``)."""
        hidden_states = self.model(categorical_inputs, numerical_inputs, text_inputs)
        logits = self.classifier(hidden_states)
        return logits