In [None]:
import torch as th
from torch import nn
from torch.nn import functional as F

In [None]:
# Toy batch: 10 samples with 128 features and integer labels drawn from [0, 4).
# Seed first so Restart & Run All reproduces the same tensors (and model inits below).
th.manual_seed(0)
x = th.randn(10, 128)
y = th.randint(0, 4, (10,))

In [None]:
class DataAndLabelEncoder(nn.Module):
    """Encode labelled samples: features concatenated with a label embedding, then an MLP.

    The MLP is three ``Linear -> Mish -> BatchNorm1d`` stages.

    Args:
        x_max_dim: number of feature columns expected in ``x``.
        nb_class_max: number of rows in the label embedding table (labels must be below it).
        y_emb_dim: width of the label embedding.
        hidden_dim: width of the two hidden stages.
        output_dim: width of the returned encoding.
    """

    def __init__(self, x_max_dim: int, nb_class_max: int, y_emb_dim :int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()

        # Built before the MLP so parameter-initialisation order stays the same.
        self.__y_emb = nn.Embedding(nb_class_max, y_emb_dim)

        widths = [x_max_dim + y_emb_dim, hidden_dim, hidden_dim, output_dim]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.Mish(), nn.BatchNorm1d(fan_out)]
        # Flattened into one Sequential so submodule indices (0..8) are unchanged.
        self.__encoder = nn.Sequential(*stages)

    def forward(self, x: th.Tensor, y: th.Tensor) -> th.Tensor:
        """Concatenate ``x`` with the embedding of ``y`` along dim 1 and run the MLP."""
        features_and_label = th.cat([x, self.__y_emb(y)], dim=1)
        return self.__encoder(features_and_label)

In [None]:
class DataEncoder(nn.Sequential):
    """Label-free MLP encoder made of three ``Linear -> Mish -> BatchNorm1d`` stages.

    Args:
        x_max_dim: number of input feature columns.
        hidden_dim: width of the two hidden stages.
        output_dim: width of the returned encoding.
    """

    def __init__(self, x_max_dim: int, hidden_dim: int, output_dim: int) -> None:
        stage_widths = (x_max_dim, hidden_dim, hidden_dim, output_dim)
        modules = []
        for fan_in, fan_out in zip(stage_widths, stage_widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.Mish(), nn.BatchNorm1d(fan_out)))
        # Passed flat so the Sequential's indices (0..8) and state_dict keys are unchanged.
        super().__init__(*modules)

In [None]:
# Labels in `y` are drawn from [0, 4), so an embedding table of 5 classes covers them.
enc = DataAndLabelEncoder(x_max_dim=128, nb_class_max=5, y_emb_dim=128, hidden_dim=128, output_dim=128)

In [None]:
# Labelled encodings; used as the training context in the cells below.
x_enc = enc(x=x, y=y)

In [None]:
x_enc.shape

In [None]:
# Label-free encoder; its output is used as the test (query) side below.
enc2 = DataEncoder(x_max_dim=128, hidden_dim=128, output_dim=128)

In [None]:
# Unlabelled encodings of the same samples (nn.Sequential.forward takes `input`).
x_enc_2 = enc2(input=x)

In [None]:
x_enc_2.size()

In [None]:
# Standalone transformer encoder for prototyping the PFN attention pattern; it mirrors
# the configuration PFN builds internally (d_model=128 to match the encoders'
# output_dim, 4 heads, feed-forward width 256, GELU, 6 layers, batch_first).
trf_enc = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(128, 4, 256, activation=F.gelu, batch_first=True),
    6,
)

In [None]:
def get_mask(x_train: th.Tensor, x_test: th.Tensor) -> th.Tensor:
    """Build the additive attention mask for a PFN-style encoder pass.

    The encoder input is expected to be ``cat([x_train, x_test], dim=0)``. Every
    position may attend to all training positions. A test position may also attend
    to itself, but never to another test position.

    Args:
        x_train: training encodings; only ``size(0)``, ``dtype`` and ``device`` are used.
        x_test: test encodings; only ``size(0)`` is used.

    Returns:
        An ``(n_train + n_test, n_train + n_test)`` float mask on ``x_train``'s
        device and dtype. It is ``0.0`` where attention is allowed and ``-inf``
        where it is blocked.
    """
    n_train = x_train.size(0)
    n_total = n_train + x_test.size(0)

    allowed = th.eye(n_total, dtype=th.bool, device=x_train.device)
    allowed[:, :n_train] = True

    # PyTorch *adds* float masks to the attention logits: 0 keeps a position and
    # -inf removes it (a 1/0 mask would block nothing).
    mask = th.zeros(n_total, n_total, dtype=x_train.dtype, device=x_train.device)
    return mask.masked_fill(~allowed, float("-inf"))

In [None]:
# Labelled encodings form the training context; unlabelled ones are the queries.
src_mask = get_mask(x_train=x_enc, x_test=x_enc_2)

In [None]:
# Training rows first, then test rows: the row order get_mask assumes.
enc_input = th.cat((x_enc, x_enc_2))
out = trf_enc(enc_input, mask=src_mask)

In [None]:
out.shape

In [None]:
class PFN(nn.Module):
    """Prior-data-fitted-network style classifier over already-encoded samples.

    Training and test encodings are stacked into one sequence and passed through
    a transformer encoder. Its attention mask lets every position attend to the
    training positions, while a test position attends to no other test position.
    The outputs at the test positions are projected to ``nb_class`` logits.

    Args:
        model_dim: width of the input encodings (the transformer's ``d_model``).
        hidden_dim: feed-forward width inside each transformer layer.
        nb_class: number of logits produced per test sample.
        nb_head: number of attention heads; must divide ``model_dim``.
        nb_layers: number of stacked transformer encoder layers.
    """

    def __init__(
        self,
        model_dim: int,
        hidden_dim: int,
        nb_class: int,
        nb_head: int = 4,
        nb_layers: int = 6,
    ) -> None:
        super().__init__()

        self.__trf = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(model_dim, nb_head, hidden_dim, activation=F.gelu, batch_first=True),
            nb_layers,
        )

        self.__to_class = nn.Linear(model_dim, nb_class)

    @staticmethod
    def get_mask(x_train: th.Tensor, x_test: th.Tensor) -> th.Tensor:
        """Build the additive attention mask for ``cat([x_train, x_test], dim=0)``.

        Every position may attend to all training positions. A test position may
        also attend to itself, but never to another test position.

        Returns:
            An ``(n_train + n_test, n_train + n_test)`` float mask on ``x_train``'s
            device and dtype. It is ``0.0`` where attention is allowed and ``-inf``
            where it is blocked.
        """
        n_train = x_train.size(0)
        n_total = n_train + x_test.size(0)

        allowed = th.eye(n_total, dtype=th.bool, device=x_train.device)
        allowed[:, :n_train] = True

        # PyTorch *adds* float masks to the attention logits: 0 keeps a position and
        # -inf removes it (a 1/0 mask would block nothing).
        mask = th.zeros(n_total, n_total, dtype=x_train.dtype, device=x_train.device)
        return mask.masked_fill(~allowed, float("-inf"))

    def forward(self, x_train: th.Tensor, x_test: th.Tensor) -> th.Tensor:
        """Return class logits for each test sample, conditioned on the training samples.

        Args:
            x_train: training encodings, assumed ``(n_train, model_dim)``.
            x_test: test encodings, assumed ``(n_test, model_dim)``.

        Returns:
            Logits of shape ``(n_test, nb_class)``.
        """
        src_mask = self.get_mask(x_train, x_test)

        enc_input = th.cat([x_train, x_test], dim=0)

        # Keep only the rows that belong to the test samples.
        out = self.__trf(enc_input, mask=src_mask)[x_train.size(0):, :]
        return self.__to_class(out)

In [None]:
pfn = PFN(model_dim=128, hidden_dim=256, nb_class=10)

In [None]:
# Logits for each unlabelled sample, conditioned on the labelled ones.
out = pfn(x_train=x_enc, x_test=x_enc_2)

In [None]:
out.shape