# Packages and variables

In [1]:
# Load packages
from sentence_transformers import models, losses, datasets, SentencesDataset
from sentence_transformers import SentenceTransformer, util, InputExample
import pandas as pd
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizer
import torch

ModuleNotFoundError: No module named 'sentence_transformers'

In [None]:
# Specify variables
model_name = "bert-base-uncased"  # Hugging Face checkpoint used as the encoder
train_batch_size = 20
max_seq_length = 250  # token limit applied to the sentence encoder below
num_epochs = 1
# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model
# NOTE(review): `tokenizer` is not referenced in the visible cells — confirm it is still needed.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name).to(device)

# Apply the configured token limit. `max_seq_length` was defined above but not
# used anywhere in the visible cells, so the model kept its own default limit.
model.max_seq_length = max_seq_length

# Load and prepare dataset

In [None]:
# Fetch the dataset from the Hugging Face hub
dataset = load_dataset("multi_nli")

# Copy the needed columns of the training split into a DataFrame,
# which is more convenient to work with in the following cells
train_split = dataset["train"]
df = pd.DataFrame()
for column_name in ("premise", "hypothesis", "genre", "label"):
    df[column_name] = train_split[column_name]

In [None]:
# Create training dataloader

# Each genre ("hierarchy") gets its own label offset, so the combined class id
# (offset + original label) identifies both the genre and the label.
# Offsets are spaced by 3, which presumes labels lie in 0..2 — TODO confirm.
genre_offsets = {
    "telephone": 0,
    "government": 3,
    "travel": 6,
    "fiction": 9,
    "slate": 12,
}

train_examples = []
# Iterate over the columns directly; this avoids building a Series per row
# (as df.iterrows() does) while yielding the same values in the same order.
for premise, hypothesis, genre, label in zip(
    df["premise"], df["hypothesis"], df["genre"], df["label"]
):
    # Fail loudly on an unmapped genre. The former if/elif chain left the
    # offset undefined (NameError on the first row) or silently reused the
    # offset of the previous row.
    if genre not in genre_offsets:
        raise ValueError(f"Unexpected genre {genre!r}: no label offset defined")

    combined_label = int(label) + genre_offsets[genre]
    train_examples.append(
        InputExample(texts=[premise, hypothesis], label=combined_label)
    )

train_dataset = SentencesDataset(train_examples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

# Train and save triplet model

In [None]:
# Loss: batch-all triplet loss driven by the class label of each example
# NOTE(review): BatchAllTripletLoss is documented for single-sentence examples
# with a class label; the examples built earlier carry two texts
# (premise, hypothesis) — confirm that only one of them being embedded is intended.
train_loss = losses.BatchAllTripletLoss(model=model)

# Fine-tune the model on the single (dataloader, loss) objective
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    optimizer_params={"lr": 1e-5},
    show_progress_bar=True,
)

# Persist the fine-tuned model to a local directory
model.save("model_triplet_mnli")