# Torch Model

In [1]:
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate
# Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
# (older version: https://huggingface.co/transformers/v2.5.0/_modules/transformers/modeling_bert.html — search for BertIntermediate).
# The class below is a slightly modified copy of that implementation.

import torch

class BertIntermediate(torch.nn.Module):
    """Intermediate (expansion) sub-layer of a BERT block: a linear projection followed by GELU.

    Simplified copy of the Hugging Face ``BertIntermediate``: here the
    activation is hard-coded to ``torch.nn.functional.gelu``.

    Args:
        config: object exposing ``hidden_size`` (input features) and
            ``intermediate_size`` (output features), e.g. a ``transformers.BertConfig``.
    """

    def __init__(self, config):
        super().__init__()
        in_features, out_features = config.hidden_size, config.intermediate_size
        # Attribute name ``dense`` is kept so state-dict keys match the upstream module.
        self.dense = torch.nn.Linear(in_features, out_features)

    def forward(self, hidden_states):
        """Return ``gelu(dense(hidden_states))``.

        ``nn.Linear`` acts on the last dimension, so that dimension must equal
        ``config.hidden_size``; it becomes ``config.intermediate_size`` in the result.
        """
        projected = self.dense(hidden_states)
        return torch.nn.functional.gelu(projected)


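# Following TDD (test-driven development: write a failing test first, then write
# the code that makes it pass), the first step is to write a test for the model.
# TODO: add a glossary for the acronyms used in these notes.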

In [2]:
import pytest
import torch
import transformers

import ttnn
import torch_functional_bert # implemented here: https://github.com/tenstorrent-metal/tt-metal/blob/main/models/experimental/functional_bert/reference/torch_functional_bert.py

from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(model_name, batch_size, sequence_size):
    """Compare the functional ``bert_intermediate`` with Hugging Face's ``BertIntermediate``.

    A ``BertIntermediate`` is built from the pretrained model's *config* (its
    weights are randomly initialised, seeded for reproducibility) and produces
    the golden output. The same weights, kept as torch tensors, are passed to
    ``torch_functional_bert.bert_intermediate``. The two outputs are compared
    with ``assert_with_pcc`` at a 0.9999 threshold.
    """
    # Import the reference implementation by its package path
    # (models/experimental/functional_bert/reference/torch_functional_bert.py).
    # A bare `import torch_functional_bert` only resolves when that file's
    # directory is on sys.path, which is why it raised ModuleNotFoundError here.
    from models.experimental.functional_bert.reference import torch_functional_bert

    # Builds the `parameters` object handed to the functional implementation
    # from the weights of the model returned by `initialize_model`.
    from ttnn.model_preprocessing import preprocess_model_parameters

    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random(
        (batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32
    )

    # Forward-only comparison: no autograd graph is needed.
    with torch.no_grad():
        torch_output = model(torch_hidden_states)  # golden output

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,  # reuse the already-constructed model
        convert_to_ttnn=lambda *_: False,  # keep every weight as a torch tensor
    )

    with torch.no_grad():
        output = torch_functional_bert.bert_intermediate(
            torch_hidden_states,
            parameters=parameters,
        )

    assert_with_pcc(torch_output, output, 0.9999)

ModuleNotFoundError: No module named 'torch_functional_bert'