# 1. Custom Encoder and Model

First, import `CE_Module`.

In [None]:
import torch
import torch.nn as nn

from Custom_Encoder import CE_Module

Define an encoder module (a subclass of `nn.Module`).  
The following is a sample implementation, but you can use any encoding approach you prefer.
> [!NOTE]  
> The method `get_encoder_dim()` must be implemented.
> It takes the input feature dimension and returns the dimension of the encoder's output.

In [None]:
class standard_code(nn.Module):
    """Binary "standard-score" encoder.

    For every input feature ``j`` two bits are produced:

    * mean bit  -- ``1`` if ``x[..., j] > mean[j]``, else ``0``;
    * sigma bit -- ``1`` if ``|x[..., j] - mean[j]| > sigma[j]`` (the value
      lies outside one standard deviation of the mean), else ``0``.

    The bits are interleaved per feature, so an input with ``d`` features is
    mapped to ``2 * d`` outputs ordered
    ``[mean_bit_0, sigma_bit_0, mean_bit_1, sigma_bit_1, ...]``.
    The output has the same dtype as the input.

    Args:
        input_x: Reference samples of shape ``(num_samples, num_features)``
            used to compute the per-feature mean and standard deviation.
            Non-floating inputs are converted to float first.
    """

    def __init__(self, input_x) -> None:
        super().__init__()
        data = torch.as_tensor(input_x)
        if not torch.is_floating_point(data):
            # torch.mean / torch.std are only defined for floating dtypes.
            data = data.float()
        # Registered as buffers (not plain attributes) so the statistics move
        # with the module on ``.to(device)`` and are saved in ``state_dict``.
        self.register_buffer("mean", torch.mean(data, dim=0))
        self.register_buffer("sigma", torch.std(data, dim=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape ``(..., num_features)`` into ``(..., 2 * num_features)``."""
        # No-op when already on x's device; keeps working if this encoder is
        # held somewhere that ``model.to(device)`` does not reach.
        mean = self.mean.to(x.device)
        sigma = self.sigma.to(x.device)

        above_mean = x > mean
        outside_band = (x - mean).abs() > sigma
        # stack -> (..., num_features, 2); flattening the last two dims
        # interleaves the two bits of each feature.
        bits = torch.stack((above_mean, outside_band), dim=-1).flatten(start_dim=-2)
        return bits.to(x.dtype)

    def get_encoder_dim(self, input_dim):
        """Return the encoder output size for ``input_dim`` features (two bits each)."""
        return input_dim * 2


Next, define your own model.  
Simply inherit from the `CE_Module` class and call `self.forward_encoder()` wherever the input should pass through the custom encoder.  
`self.forward_encoder()` runs the pluggable encoder in your architecture, which you can replace freely.  
The following is a sample implementation, but you can replace it with your own model too.

In [None]:
class SCARF1(CE_Module):
    """SCARF-style contrastive model with a pluggable custom encoder.

    ``forward`` builds a corrupted view of each row by replacing a random
    subset of ``corruption_len`` features with samples drawn from per-feature
    uniform marginals. It passes both the clean and the corrupted input
    through ``self.forward_encoder`` (the replaceable custom encoder provided
    by ``CE_Module``), then through ``self.encoder`` and
    ``self.pretraining_head``.

    Args:
        input_dim: Number of raw input features.
        emb_dim: Output size of ``self.encoder`` and ``self.pretraining_head``.
        features_low: Per-feature lower bounds of the corruption marginals.
        features_high: Per-feature upper bounds of the corruption marginals.
        num_hidden: Third positional argument of the encoder ``MLP``
            (presumably its number of hidden layers -- confirm against MLP).
        head_depth: Third positional argument of the pretraining-head ``MLP``.
        corruption_rate: Fraction of ``input_dim`` features corrupted per row.
        dropout: Dropout argument passed to the encoder ``MLP``.
    """

    def __init__(
        self,
        input_dim,
        emb_dim,
        features_low,
        features_high,
        num_hidden=4,
        head_depth=2,
        corruption_rate=0.6,
        dropout=0.0,
    ):
        super().__init__()

        self.encoder = MLP(input_dim, emb_dim, num_hidden, dropout)
        self.pretraining_head = MLP(emb_dim, emb_dim, head_depth)

        # Uniform distribution over the marginal range of each dataset feature.
        # ``as_tensor`` with an explicit dtype also accepts tensors/arrays of
        # other dtypes, which the legacy ``torch.Tensor(...)`` constructor rejects.
        dtype = torch.get_default_dtype()
        self.marginals = Uniform(
            torch.as_tensor(features_low, dtype=dtype),
            torch.as_tensor(features_high, dtype=dtype),
        )
        self.corruption_len = int(corruption_rate * input_dim)

    def forward(self, x):
        """Return ``(embeddings, embeddings_corrupted)`` for a 2-D batch ``x``."""
        batch_size, m = x.size()

        # argsort of i.i.d. uniform noise yields an independent random
        # permutation per row, so ``perm < corruption_len`` marks exactly
        # min(corruption_len, m) randomly chosen features in every row.
        # This is the vectorized equivalent of a per-row randperm loop.
        perm = torch.rand(batch_size, m, device=x.device).argsort(dim=1)
        corruption_mask = perm < self.corruption_len

        x_random = self.marginals.sample(torch.Size((batch_size,))).to(x.device)
        x_corrupted = torch.where(corruption_mask, x_random, x)

        # Custom encoder, applied to both views before the shared MLPs.
        x = self.forward_encoder(x)
        x_corrupted = self.forward_encoder(x_corrupted)

        embeddings = self.pretraining_head(self.encoder(x))
        embeddings_corrupted = self.pretraining_head(self.encoder(x_corrupted))

        return embeddings, embeddings_corrupted

    def adjust_structure(self):
        """Replace ``self.encoder[0]`` (the first layer after the custom encoder)
        with the layer returned by ``self.adjust_layer``."""
        self.encoder[0] = self.adjust_layer(self.encoder[0])

    def get_embedding(self):
        """Return the encoder module (``self.encoder``)."""
        return self.encoder

    @torch.inference_mode()
    def get_embeddings(self, x):
        """Return ``self.encoder`` embeddings of raw input ``x`` without autograd.

        The custom encoder is applied first, exactly as in ``forward``. After
        ``adjust_structure`` resizes ``self.encoder[0]`` to the custom
        encoder's output dimension, feeding raw ``x`` straight into
        ``self.encoder`` would not match that layer.
        """
        return self.encoder(self.forward_encoder(x))

> [!NOTE]  
> Here are some important points to note:  
> The function `adjust_structure()` must be implemented.  
> Because different encoders have to plug into the model without modifying the rest of its architecture, this function is required.  
> `CE_Module` automatically adjusts the dimensions of the connecting layer for you.  
>  
> You must pass in the first `nn.Linear` layer of the network that consumes the encoder's output.  
> In this example, the module that follows `self.forward_encoder()` is `self.encoder`,  
> so that first layer is `self.encoder[0]`.  
>  
> Calling `self.adjust_layer()` returns the adjusted layer, which replaces the original first layer.


Finally, when you run experiments, create the model as you normally would.  
After initialization, however, make sure to call `set_encoder()`,  
passing the encoder instance, the encoder's output dimension, and whether reshaping is needed.  
In this example, `standard_code(features)` needs a feature tensor (`features`) at construction time to compute per-feature means and standard deviations; an encoder whose constructor takes no arguments can simply be instantiated directly (e.g. `MyEncoder()`).  
This completes the model setup.

In [None]:
# Build the model exactly as usual.
model = scarf_model.SCARF1(
    input_dim=args.feature_dim,
    emb_dim=args.emb_dim,
    features_low=dataset.get_feature_marginal_low(),
    features_high=dataset.get_feature_marginal_high(),
)

# Attach the custom encoder. ``encoder_dim`` comes from the encoder's own
# ``get_encoder_dim`` rather than a hard-coded ``2 * feature_dim``, so it
# stays correct if the encoder is swapped for one with a different output size.
custom_encoder = standard_code(features)
model.set_encoder(
    callback=custom_encoder,
    encoder_dim=custom_encoder.get_encoder_dim(args.feature_dim),
    reshape=True,
)
model.to(device)
