In [None]:
from typing import Type

In [None]:
import torch.nn as nn
from torch import Tensor
from transformers import AutoConfig, AutoModel, AutoTokenizer  # backbone, its config, and the tokenizer for LLM input

In [None]:
# Hugging Face model id of the pretrained checkpoint; passed to AutoTokenizer.from_pretrained.
CHECKPOINT: str = "bert-base-uncased"

In [None]:
class CorrectnessModuleLLM(nn.Module):
    """Single-probability "correctness" head on top of a pretrained transformer backbone.

    The backbone (loaded via ``AutoModel``) encodes the token ids. The final-layer
    hidden state at sequence position 0 ([CLS] for BERT-style checkpoints) is
    projected to one logit, then passed through a sigmoid.

    Args:
        checkpoint: Model id or local path accepted by ``from_pretrained``.
    """

    def __init__(self, checkpoint: str) -> None:
        super().__init__()
        config = AutoConfig.from_pretrained(
            checkpoint,
            # Plural spelling matters: an unknown key such as `output_attention`
            # is silently accepted and has no effect. Neither extra output is
            # used by `forward`; they are kept for inspection and cost memory.
            output_attentions=True,
            output_hidden_states=True,
        )
        self.embedding_body = AutoModel.from_pretrained(checkpoint, config=config)
        # Size the head from the loaded backbone instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        self.logit_transform = nn.Linear(
            in_features=self.embedding_body.config.hidden_size,
            out_features=1,
            bias=True,
        )
        self.output_transform = nn.Sigmoid()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return one probability per sequence.

        Args:
            input_ids: Token ids for the backbone (assumed ``(batch, seq_len)``).
            attention_mask: Mask forwarded unchanged to the backbone.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``(0, 1)``.
        """
        llm_embeddings = self.embedding_body(input_ids=input_ids,
                                             attention_mask=attention_mask)
        # `hidden_states[0]` is the embedding-layer output for *every* token.
        # The contextual sentence representation is position 0 of the final
        # layer: last_hidden_state is (batch, seq_len, hidden) -> (batch, hidden).
        cls_token_output = llm_embeddings.last_hidden_state[:, 0, :]
        logits = self.logit_transform(cls_token_output)
        return self.output_transform(logits)

In [None]:
## Tokenizer code
# `return_tensors` is an argument of the tokenizer *call*, not of `from_pretrained`.
# Passed here, it is only recorded in the init kwargs and has no effect, so pass
# `return_tensors="pt"` when tokenizing instead.
checkpoint_model_tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

def tokenize_and_align_dataset(sample):
    question_string = sample['Question']
    answer_string = sample["Answer"]
    tokenized_question = checkpoint_model_tokenizer(question_string, is_split_into_words=True)
    tokenized_answer = checkpoint_model_tokenizer(answer_string, is_split_into_words=True)
    