<a href="https://colab.research.google.com/github/DLesmes/bert_embeddings_generator/blob/main/bert_embedding_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Requirements

In [62]:
!pip install datasets
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
# data
from datasets import load_dataset
import pandas as pd
import numpy as np



# Embedding Model

In [None]:
# Choose a suitable pre-trained BERT model
model_name = 'bert-base-uncased'

# Load the model and tokenizer
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Data

In [13]:
dataset = load_dataset("tasksource/mmlu", "international_law")
dataset

Downloading data:   0%|          | 0.00/29.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.63k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/121 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['question', 'choices', 'answer'],
        num_rows: 121
    })
    validation: Dataset({
        features: ['question', 'choices', 'answer'],
        num_rows: 13
    })
    dev: Dataset({
        features: ['question', 'choices', 'answer'],
        num_rows: 5
    })
})

In [32]:
data_train = []
data_test = []
data_val = []

for question in dataset['test']:
  row = { "text" : question['question']
          + " " + question['choices'][0]
          + " " + question['choices'][1]
          + " " + question['choices'][2]
          + " " + question['choices'][3],
          "result": question['answer']
  }
  data_train.append(row)

for question in dataset['validation']:
  row = { "text" : question['question']
          + " " + question['choices'][0]
          + " " + question['choices'][1]
          + " " + question['choices'][2]
          + " " + question['choices'][3],
          "result": question['answer']
  }
  data_test.append(row)



for question in dataset['dev']:
  row = { "text" : question['question']
          + " " + question['choices'][0]
          + " " + question['choices'][1]
          + " " + question['choices'][2]
          + " " + question['choices'][3],
          "result": question['answer']
  }
  data_val.append(row)

In [36]:
df_train = pd.DataFrame(data_train)
df_test = pd.DataFrame(data_test)
df_val = pd.DataFrame(data_val)
dfs = [df_train, df_test, df_val]
for df in dfs:
  print(df.shape)

(121, 2)
(13, 2)
(5, 2)


In [50]:
def embed(text: str):
  # Tokenize and encode the text
  inputs = tokenizer(text, return_tensors='pt')

  # Pass the input through the model (no fine-tuning needed)
  with torch.no_grad():
      outputs = model(**inputs)

  # Extract the embeddings
  return outputs.last_hidden_state[:, 0, :][0]  # [CLS] token embedding


for df in dfs:
  df['X'] = df['text'].apply(lambda x: embed(x))
  print(df.shape)



(121, 3)
(13, 3)
(5, 3)


In [51]:
df_train.head()

Unnamed: 0,text,result,X
0,Which State ordinarily exercises jurisdiction ...,1,"[tensor(-0.2486), tensor(-0.2092), tensor(-0.0..."
1,What is the meaning of justiciability? Justici...,0,"[tensor(-0.5092), tensor(-0.3899), tensor(-0.7..."
2,In what way is Responsibility to Protect (R2P)...,2,"[tensor(-0.4645), tensor(-0.6626), tensor(-0.3..."
3,What is the 'Lotus principle'? The so-called L...,0,"[tensor(-0.2778), tensor(-0.5422), tensor(-0.6..."
4,Which of these statements best describes the U...,2,"[tensor(0.0603), tensor(0.2212), tensor(-0.117..."


In [52]:
df_test.head()

Unnamed: 0,text,result,X
0,What kind of passage does qualify as 'innocent...,1,"[tensor(-0.4748), tensor(-0.3957), tensor(-0.1..."
1,What kind of State practice is required? Wides...,0,"[tensor(-0.1644), tensor(0.1313), tensor(-0.40..."
2,Which treaties are considered as 'source of in...,1,"[tensor(-0.2643), tensor(-0.0352), tensor(-0.5..."
3,What is the 'optional; clause' in the ICJ Stat...,2,"[tensor(-0.2463), tensor(-0.2587), tensor(-0.4..."
4,Can armed violence perpetrated by non-State ac...,1,"[tensor(-0.3789), tensor(-0.3413), tensor(-1.0..."


In [54]:
len(df_test.X[0])

768

In [59]:
X_train = np.array(df_train.X)
X_test = np.array(df_test.X)
y_train = np.array(df_train.result)
y_test = np.array(df_test.result)

In [57]:
X_train.shape

(121,)

In [58]:
X_test.shape

(13,)

In [61]:
y_train.shape

(121,)

In [60]:
y_test.shape

(13,)

# NN

In [67]:
class il_mmlu_data(Dataset):
  def __init__(self, X_train, y_train) -> None:
        super().__init__()
        self.X = X_train
        self.y = y_train
        self.y = self.y.shape[0]
        self.len = self.X.shape[0]

  def __getitem__(self, index):
      return self.X[index], self.y[index]

  def __len__(self):
      return self.len


# %% dataloader
il_mmlu_dataset = il_mmlu_data(X_train=X_train, y_train=y_train)
train_loader = DataLoader(dataset=il_mmlu_dataset, batch_size=32)