In [1]:
import os
os.chdir("../")

In [2]:
os.environ['HF_HOME'] = '.cache/'

In [3]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


## Model Definition

In [4]:
# Directory of the local Mistral-7B checkpoint. Set the MODEL_PATH environment
# variable to use a different location without editing the notebook; the
# original machine-specific path is kept as the fallback.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "/media/ishrak/volume_1/Projects/mining-misconceptions-in-math/.cache/Mistral-7B-v0.1",
)

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

In [6]:
# 4-bit NF4 quantization with double quantization; bfloat16 is the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the causal LM from the local checkpoint using the quantization settings above.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.23s/it]


In [7]:
# LoRA adapters (rank 8, alpha 32, dropout 0.05) on the q_proj and v_proj modules only.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

# Prepare the quantized model for k-bit training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

In [8]:
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [9]:
from torch import Tensor
from src.model_development.latent_attention_module import LatentAttentionLayer

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select one feature vector per row: the one at the last attended position.

    If the final column of ``attention_mask`` sums to the number of rows (the
    batch is treated as left-padded), the vector at the final position is
    returned for every row. Otherwise each row is indexed at
    ``attention_mask.sum(dim=1) - 1``.

    Args:
        last_hidden_states: tensor indexed as (row, position, ...).
        attention_mask: tensor indexed as (row, position).

    Returns:
        Tensor holding the selected vector for every row.
    """
    num_rows = attention_mask.shape[0]

    # The last mask column is all ones exactly when its sum equals the row count.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise pick, per row, the position given by (mask sum - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(
        last_hidden_states.shape[0], device=last_hidden_states.device
    )
    return last_hidden_states[row_index, last_positions]


# Pooling layers for the "attention" method, keyed by input feature size, so
# that every call reuses the same weights instead of drawing new random ones.
_LATENT_ATTENTION_LAYERS = {}


def sentence_embedding(sentence_pooling_method, hidden_state, mask):
    """Pool per-token features into a single vector per sequence.

    Args:
        sentence_pooling_method: one of "mean", "cls", "last" or "attention".
        hidden_state: token features, indexed here as (batch, seq, dim).
        mask: attention mask, indexed here as (batch, seq); used by the
            "mean" and "last" methods.

    Returns:
        One pooled vector per batch row. For "attention" the output is
        whatever the LatentAttentionLayer returns.

    Raises:
        ValueError: if ``sentence_pooling_method`` is not a supported name
            (previously this fell through and returned None).
    """
    if sentence_pooling_method == "mean":
        weights = mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * weights, dim=1)
        # Clamp so a fully masked row yields zeros instead of NaN (0 / 0).
        counts = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        return summed / counts

    if sentence_pooling_method == "cls":
        return hidden_state[:, 0]

    if sentence_pooling_method == "last":
        return last_token_pool(hidden_state, mask)

    if sentence_pooling_method == "attention":
        input_dim = hidden_state.size(-1)
        layer = _LATENT_ATTENTION_LAYERS.get(input_dim)
        if layer is None:
            # Built once per feature size. Creating it on every call would give
            # each batch its own random projection, so the embeddings of two
            # calls would not be comparable.
            layer = LatentAttentionLayer(
                input_dim=input_dim,
                hidden_dim=512,
                num_latents=10,
                num_heads=8,
                mlp_dim=1024,
            )
            _LATENT_ATTENTION_LAYERS[input_dim] = layer
        # Keep the layer on the same device/dtype as its input.
        # Assumes LatentAttentionLayer is a torch.nn.Module - TODO confirm.
        layer.to(device=hidden_state.device, dtype=hidden_state.dtype)
        # NOTE(review): the mask is not passed on, so padded positions take
        # part in the pooling; confirm whether LatentAttentionLayer accepts one.
        return layer(hidden_state)

    raise ValueError(
        f"Unknown sentence_pooling_method: {sentence_pooling_method!r}; "
        "expected 'mean', 'cls', 'last' or 'attention'"
    )


def forward(input_ids, attention_mask, sentence_pooling_method="attention"):
    """Encode token ids with the notebook-level ``model`` and pool to one vector per row.

    Args:
        input_ids: token ids, (batch, seq).
        attention_mask: attention mask, (batch, seq).
        sentence_pooling_method: pooling name forwarded to ``sentence_embedding``.

    Returns:
        The pooled features returned by ``sentence_embedding``.
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )
    # Index 0 of the causal-LM output is the vocabulary logits (inspected in
    # this notebook as [1, 512, 32001]), not token features, so the final
    # layer's hidden states are pooled instead. They are cast to float32 so the
    # pooling code gets a fixed dtype whatever the backbone computes in.
    features = outputs.hidden_states[-1].float()
    return sentence_embedding(sentence_pooling_method, features, attention_mask)

## Dataset

In [10]:
df = pd.read_csv("data/contrastive-datasethu66w3xp.csv")
df.head()

Unnamed: 0,QuestionId,SubjectId,ConstructId,QuestionDetails,MisconceptionList,Label,MisconceptionId
0,0,33,856,"Given subject, construct, question and incorre...",When using the formula to solve a quadratic eq...,4,1672
1,1,1077,1612,"Given subject, construct, question and incorre...",Believes each number on a clock face represent...,8,2142
2,1,1077,1612,"Given subject, construct, question and incorre...",Does not understand how to find percentages of...,3,143
3,1,1077,1612,"Given subject, construct, question and incorre...",Believes if shapes have rotational symmetry th...,9,2142
4,2,339,2774,"Given subject, construct, question and incorre...",Does not apply the inverse function to every t...,5,1287


### Test Input

In [11]:
from src.data_preparation.datasets.base_dataset import BaseDataset

In [12]:
dataset = BaseDataset(df, tokenizer)

In [13]:
test_dict = dataset[0]

question_ids = test_dict['question_ids'].view(1, -1)
question_mask = test_dict['question_mask'].view(1, -1)
misconception_ids = test_dict['misconception_ids'].view(-1, 10, 64)
misconception_mask = test_dict['misconception_mask'].view(-1, 10, 64)

print("Keys:", test_dict.keys())
print("Question IDs shape:", question_ids.shape)
print("Question Mask shape:", question_mask.shape)
print("Misconception IDs shape:", misconception_ids.shape)
print("Misconception Mask shape:", misconception_mask.shape)


Keys: dict_keys(['question_ids', 'question_mask', 'misconception_ids', 'misconception_mask', 'label'])
Question IDs shape: torch.Size([1, 512])
Question Mask shape: torch.Size([1, 512])
Misconception IDs shape: torch.Size([1, 10, 64])
Misconception Mask shape: torch.Size([1, 10, 64])


### Model Output

In [14]:
features = model(
    input_ids=question_ids,
    attention_mask=question_mask,
    return_dict=True,
)

### Dissecting Model Output

In [15]:
features.keys()


odict_keys(['logits', 'past_key_values'])

In [16]:
features[0].shape

torch.Size([1, 512, 32001])

In [17]:
features['logits'].shape

torch.Size([1, 512, 32001])

`features[0]` is the same as `features['logits']`, i.e. the vocabulary logits of shape `[batch, seq_len, vocab_size]`, not hidden states.

In [18]:
len(features['past_key_values'][0][0])


1

### Forward Method

In [19]:
q_features = forward(question_ids, question_mask, sentence_pooling_method="attention")
q_features.shape

torch.Size([1, 1024])

In [20]:
m_features = forward(misconception_ids.view(-1, 64), misconception_mask.view(-1, 64), sentence_pooling_method="attention")
m_features.shape

torch.Size([10, 1024])

### Similarity Calculation

In [21]:
def compute_similarity(q_reps, p_reps):
    """Return dot-product scores between query and passage representations.

    Computes ``q_reps @ p_reps^T``, where the transpose swaps the last two
    axes of ``p_reps``.
    """
    # On a 2-D tensor the axes (-2, -1) are the axes (0, 1), so one transpose
    # handles both the 2-D and the higher-rank case.
    passages_t = p_reps.transpose(-2, -1)
    return q_reps @ passages_t

In [22]:
similarity = compute_similarity(q_features, m_features)
similarity.shape

torch.Size([1, 10])

### Latent Attention from Scratch

Check dimensions and correctness of implementation

Test input

In [7]:
BATCH_SIZE = 10
SEQ_LENGTH = 64
INPUT_DIM = 2048
last_hidden_states = torch.randn(BATCH_SIZE, SEQ_LENGTH, INPUT_DIM)
last_hidden_states.shape

torch.Size([10, 64, 2048])

In [8]:
# NOTE(review): size(1) is the SEQ_LENGTH axis of last_hidden_states (64), not
# BATCH_SIZE (10). The value only works as a "batch size" because
# torch.nn.MultiheadAttention is created further down without batch_first, and
# its default layout treats axis 1 as the batch axis.
batch_size = last_hidden_states.size(1)
batch_size

64

Latent parameters

In [9]:
NUM_LATENTS = 10
HIDDEN_DIM = 512
latent_array = torch.randn(NUM_LATENTS, HIDDEN_DIM)
latent_array.shape

torch.Size([10, 512])

Expand latent array

In [10]:
latent_array = latent_array.unsqueeze(1).expand(-1, batch_size, -1)
latent_array.shape


torch.Size([10, 64, 512])

Project input to hidden dimension

In [11]:
linear = torch.nn.Linear(INPUT_DIM, HIDDEN_DIM)
Q = linear(last_hidden_states)
Q.shape


torch.Size([10, 64, 512])

In [12]:
attention = torch.nn.MultiheadAttention(
    embed_dim=HIDDEN_DIM,
    num_heads=8,
)
O, _ = attention(query=Q, key=latent_array, value=latent_array)
O.shape

torch.Size([10, 64, 512])

In [17]:
MLP_DIM = 1024

mlp = torch.nn.Sequential(
    torch.nn.Linear(HIDDEN_DIM, MLP_DIM),
    torch.nn.SiLU(),
)
O = mlp(O)
O.shape

torch.Size([10, 64, 1024])

In [18]:
O_pooled = O.mean(dim=1)
O_pooled.shape

torch.Size([10, 1024])