In [None]:
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch
import hephaestus as hp



class HybridBertModel(nn.Module):
    """BERT encoder combined with a learned embedding for numeric inputs.

    Every input item is encoded as a fixed three-token sequence. String
    items are tokenized as-is (truncated to fit), numeric items are fed to
    BERT as the ``[MASK]`` placeholder. After the BERT pass, the hidden state
    at position 1 of each numeric item is replaced by the output of
    ``numeric_embedding`` applied to the raw number. A small regression head
    (``numeric_predictor``) then maps each position-1 vector to one scalar.

    Args:
        n_token: Accepted for interface compatibility; not used by this class.
        d_model: Accepted for interface compatibility; not used by this class.
        device: Device the whole module is moved to on construction.
        bert_model_name: Name or path passed to ``from_pretrained`` for both
            the tokenizer and the encoder.
    """

    # Text fed to BERT in place of a numeric value.
    _NUMERIC_PLACEHOLDER = "[MASK]"
    # Fixed sequence length; with special tokens this is presumably
    # [CLS] + one word piece + [SEP] for BERT tokenizers.
    _SEQ_LEN = 3
    # Index of the hidden state that represents the item itself.
    _VALUE_POS = 1

    def __init__(
        self,
        n_token: int,
        d_model: int,
        device: "torch.device | str",
        bert_model_name: str = "bert-base-uncased",
    ):
        super().__init__()

        # BERT tokenizer and encoder
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name)

        self.embedding_dim = self.bert.config.hidden_size

        # Maps a single scalar to a vector of BERT's hidden size.
        self.numeric_embedding = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, self.embedding_dim),
        )

        # Regression head applied to the position-1 vector of every item.
        self.numeric_predictor = nn.Sequential(
            nn.Linear(self.embedding_dim, 128), nn.ReLU(), nn.Linear(128, 1)
        )

        # Move all sub-modules together so they never end up on different
        # devices (previously only ``numeric_embedding`` was moved).
        self.to(device)

    def forward(
        self, input: "list[hp.StringNumeric]"
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Encode a batch of string/numeric items.

        Args:
            input: Items exposing ``is_numeric`` and ``value``. ``value`` is
                used as text for non-numeric items and as a number otherwise.

        Returns:
            A pair ``(mlm_output, numeric_prediction)`` where ``mlm_output``
            has shape ``(len(input), embedding_dim)`` and
            ``numeric_prediction`` has shape ``(len(input), 1)``.
        """
        items = list(input)

        # Follow the module's current placement instead of relying on a
        # global ``device`` name; this also stays correct after ``.to(...)``.
        reference = next(self.parameters())
        device = reference.device

        if not items:
            # Nothing to encode: return correctly shaped empty tensors rather
            # than failing inside the tokenizer / ``torch.cat``.
            empty = torch.empty(
                0, self.embedding_dim, dtype=reference.dtype, device=device
            )
            return empty, self.numeric_predictor(empty)

        # Prepare BERT inputs in one batched tokenizer call.
        texts = [
            self._NUMERIC_PLACEHOLDER if item.is_numeric else item.value
            for item in items
        ]
        encoded = self.tokenizer(
            texts,
            add_special_tokens=True,
            return_tensors="pt",
            padding="max_length",
            # Without truncation a string that splits into several word
            # pieces yields a longer row and the batch cannot be stacked.
            truncation=True,
            max_length=self._SEQ_LEN,
        )
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        # Last hidden states, one row of ``_SEQ_LEN`` vectors per item.
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # The item's own token sits at position 1 (position 0 is presumably
        # the leading [CLS] special token).
        mlm_output = bert_output[:, self._VALUE_POS, :]

        # For numeric items, substitute the custom numeric embedding.
        numeric_rows = [i for i, item in enumerate(items) if item.is_numeric]
        if numeric_rows:
            values = torch.tensor(
                [[float(items[i].value)] for i in numeric_rows],
                dtype=torch.float32,
                device=device,
            )
            # Write into a copy so BERT's own output tensor is not mutated.
            mlm_output = mlm_output.clone()
            mlm_output[numeric_rows] = self.numeric_embedding(values).to(
                mlm_output.dtype
            )

        # Numeric prediction
        numeric_prediction = self.numeric_predictor(mlm_output)

        return mlm_output, numeric_prediction


# Sample usage:

# Run on the GPU when one is available and fall back to the CPU otherwise,
# so the demo does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A mixed batch: two string items and one numeric item.
input_data = [
    hp.StringNumeric("Hello"),
    hp.StringNumeric(42.0),
    hp.StringNumeric("world"),
]

# The first two arguments are the model's n_token and d_model parameters.
model = HybridBertModel(30522, 768, device).to(device)
mlm_output, numeric_predictions = model(input_data)
# Show the shapes of the two tensors returned by the forward pass.
print(mlm_output.shape, numeric_predictions.shape)