In [2]:
import pickle as pkl
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import random


In [3]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data preparation function

In [5]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Question:
    question: str
    context_index: int
    embedding: Optional[torch.Tensor] = None
    transformed_embedding: Optional[torch.Tensor] = None

@dataclass
class Context:
    context: str
    context_index: int
    embedding: Optional[torch.Tensor] = None
    transformed_embedding: Optional[torch.Tensor] = None

@dataclass
class DataCollection:
    questions: list[Question]
    contexts: list[Context]
    metadata: dict


In [6]:
@dataclass
class EmbeddedQuestion(Question):
    embedding: torch.Tensor

@dataclass
class EmbeddedContext(Context):
    embedding: torch.Tensor

    def __hash__(self) -> int:
        return hash(self.context_index)

    def __eq__(self, value: object) -> bool:
        return self.context_index == value.context_index


In [7]:
from sklearn.model_selection import train_test_split


def _build_embedded_questions_from_context_ids(
    context_ids: list[int], questions: list[Question]
) -> list[EmbeddedQuestion]:
    res: list[EmbeddedQuestion] = []
    questions = list(filter(lambda x: x.context_index in context_ids, questions))
    for q in questions:
        if q.embedding is not None:
            res.append(
                EmbeddedQuestion(
                    q.question, q.context_index, q.embedding, q.transformed_embedding
                )
            )
    return res


def preprocess(
    data_collection: DataCollection,
    seed: int,
    context_test_ratio: float,
    context_val_ratio: float,
) -> tuple[
    list[EmbeddedQuestion],
    list[EmbeddedQuestion],
    list[EmbeddedQuestion],
    dict[int, EmbeddedContext],
]:
    """Splits the questions into train and test split.
    It does that while ensuring that no context snippet has questions in both splits.
    Therefore the split_ratio doesn't refer to the ratio of questions but rather of context snippets.

    Args:
        data_collection (DataCollection): The imported DataCollection
        seed (int): Random seed
        context_test_ratio (float): Ratio of how many context snippets should be in the test set.
        context_val_ratio (float): Ratio of how many context snippets of the train set should be the validation set.

    Returns:
        tuple[list[Question], list[Question], list[Question], dict[int, Context]]: train_list, val_list, test_list, dictionary to look up context by index
    """
    contexts = data_collection.contexts
    questions = data_collection.questions

    train_contexts, test_contexts = train_test_split(
        contexts, random_state=seed, test_size=context_test_ratio
    )
    train_contexts, val_contexts = train_test_split(
        train_contexts, random_state=seed, test_size=context_val_ratio
    )

    train_context_ids = [context.context_index for context in train_contexts]
    val_context_ids = [context.context_index for context in val_contexts]
    test_context_ids = [context.context_index for context in test_contexts]
    train_questions = _build_embedded_questions_from_context_ids(
        train_context_ids, questions
    )
    val_questions = _build_embedded_questions_from_context_ids(
        val_context_ids, questions
    )
    test_questions = _build_embedded_questions_from_context_ids(
        test_context_ids, questions
    )

    index_context_map: dict[int, EmbeddedContext] = {
        context.context_index: EmbeddedContext(context.context, context.context_index, context.embedding, context.transformed_embedding)
        for context in contexts
        if context.embedding is not None
    }

    return train_questions, val_questions, test_questions, index_context_map

In [8]:
class LinearTransform(nn.Module):
    def __init__(self, n_layers: int, n_dim: int, p_dropout: float, **kwargs) -> None:
        super().__init__(**kwargs)
        self.linears = nn.ModuleList([nn.Linear(n_dim, n_dim, bias=False) for _ in range(n_layers)])
        self.dropout = nn.Dropout(p_dropout)
        for linear in self.linears:
            torch.nn.init.eye_(linear.weight)

    def forward(self, x):
        x = self.dropout(x)

        for i, layer in enumerate(self.linears):
            x = layer(x)
        return x 

class ExpandTransform(nn.Module):
    def __init__(self,n_dim: int, p_dropout: float, **kwargs) -> None:
        super().__init__(**kwargs)
        self.linear1 = nn.Linear(n_dim, n_dim*4, bias=True)
        self.linear2 = nn.Linear(n_dim*4, n_dim, bias=True)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        init = x
        x = self.dropout(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return init+x 

class DimensionalityReductionTransform(nn.Module):
    def __init__(self, n_dim:int, p_dropout:float, **kwargs):
        super().__init__(**kwargs)
        self.linear1 = nn.Linear(n_dim, n_dim//2, bias=False) 
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        x = self.dropout(x)
        x = self.linear1(x)
        return x

In [9]:
def dataset_collate_fn(samples: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
    questions = []
    contexts = []
    labels = []
    for sample in samples:
        questions.append(torch.unsqueeze(sample[0], dim=0))
        contexts.append(torch.unsqueeze(sample[1], dim=0))
        labels.append(sample[2].unsqueeze(dim=0))
    return {
        "questions": torch.cat(questions, dim=0).to(device),
        "contexts": torch.cat(contexts, dim=0).to(device),
        "labels": torch.cat(labels, dim=0).to(device)
    }

In [25]:
from random import shuffle
def get_top_n_hard_examples(question: EmbeddedQuestion, contexts: list[EmbeddedContext], n:int, max_sim: float) -> list[EmbeddedContext]:
    res = []
    sim_indices = cosine_similarity(question.embedding.numpy().reshape(1, -1), np.array([c.embedding.numpy() for c in contexts])).argsort().flatten().tolist()
    sims = cosine_similarity(question.embedding.numpy().reshape(1, -1), np.array([c.embedding.numpy() for c in contexts])).flatten().tolist()
    wrongs = list(filter(lambda x: contexts[x].context_index != question.context_index and sims[x] <= max_sim, sim_indices))[::-1]
    shuffle(wrongs)
    for i in wrongs[-n:]:
        res.append(contexts[i])
    return res

class TripletDataset(Dataset):
    def __init__(self, train_questions: list[EmbeddedQuestion], index_context_map: dict[int, EmbeddedContext], n_examples: int = 1, max_sim: float = 0.4, use_filter=True) -> None:
        self.anchors = []
        self.positives = []
        self.negatives = []
        self.contexts = [index_context_map[q.context_index] for q in train_questions]
        for question in train_questions:
            if use_filter:
                hard_contexts = get_top_n_hard_examples(question, self.contexts, n_examples, max_sim)
            else:
                hard_contexts = [random.choice(self.contexts) for _ in range(n_examples)]
            for hard_context in hard_contexts:
                self.anchors.append(question.embedding)
                self.positives.append(index_context_map[question.context_index].embedding)
                self.negatives.append(hard_context.embedding)

        super().__init__()


    def __len__(self):
        return len(self.anchors)

    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.anchors[index], self.positives[index], self.negatives[index]

In [26]:
class CosineTripletLoss(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.cosine = nn.CosineSimilarity()

    def forward(self, x1, x2):
        cosine_scores = self.cosine(x1, x2)
        return (1 - cosine_scores)**2


In [27]:
def get_most_relevant_contexts(model, questions: list[EmbeddedQuestion], n_results: int, use_transformed = True, index_context_map: dict[int, EmbeddedContext] = {}):
    res = []
    model.eval()
    contexts = list(set([index_context_map[q.context_index] for q in questions]))
    if use_transformed:
        with torch.no_grad():
            question_embs = torch.cat([torch.unsqueeze(q.embedding, dim=0) for q in questions], dim=0).to(device)
            context_embs = torch.cat([torch.unsqueeze(c.embedding, dim=0) for c in contexts], dim=0).to(device)
            transformed_questions = model(question_embs)
            transformed_contexts = model(context_embs)
    sims = cosine_similarity(transformed_questions.to('cpu').detach().numpy(), transformed_contexts.to('cpu').detach().numpy())
    for sim in sims:
        indices = sim.argsort()[-n_results:]
        temp = []
        for index in indices:
            temp.append(contexts[index])
        res.append(temp)
    return res

def get_correct_top_n(model, questions, index_context_map):
    n_result = 1
    eval = []
    most_relevant = get_most_relevant_contexts(model, questions, n_result, True, index_context_map)
    for question, contexts in zip(questions, most_relevant):
        eval.append(index_context_map[question.context_index] in contexts)
    return np.mean(eval)

def get_avg_correct_similarity(model, questions: list[EmbeddedQuestion], index_context_map):
    sims = []
    model.eval()
    for question in questions:
        with torch.no_grad():
            transformed = model(index_context_map[question.context_index].embedding.to(device))
        index_context_map[question.context_index].transformed_embedding = transformed.to('cpu').detach().numpy()
    with torch.no_grad():
        question_emb = torch.cat([torch.unsqueeze(q.embedding, dim=0) for q in questions], dim=0).to(device)
        question_transformed = model(question_emb)
        for question, emb in zip(questions, question_transformed):
            sims.append(cosine_similarity(emb.to('cpu').detach().numpy().reshape(1,-1), index_context_map[question.context_index].transformed_embedding.reshape(1,-1)).flatten()[0])
    return np.mean(sims)

def get_avg_wrong_similarity(model, questions: list[EmbeddedQuestion], index_context_map):
    sims = []
    model.eval()
    for question in questions:
        with torch.no_grad():
            transformed = model(index_context_map[question.context_index].embedding.to(device))
        index_context_map[question.context_index].transformed_embedding = transformed.to('cpu').detach().numpy()
    contexts = set([index_context_map[q.context_index] for q in questions])
    with torch.no_grad():
        question_emb = torch.cat([torch.unsqueeze(q.embedding, dim=0) for q in questions], dim=0).to(device)
        question_transformed = model(question_emb)
        for question, emb in zip(questions, question_transformed):
            wrong_context_embs = [c.transformed_embedding for c in filter(lambda x: x.context_index != question.context_index, contexts)]
            res = cosine_similarity(emb.to('cpu').detach().numpy().reshape(1,-1), np.array(wrong_context_embs)).flatten().tolist()
            sims += res
    return np.mean(sims)            


In [28]:
from typing import Any, Callable
from torch.optim import Optimizer


def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss,
    optimizer: Optimizer,
    step_fn: Callable[[nn.Module, Any, Any], Any],
):
    running_loss = 0.0
    model.train()
    for i, batch in enumerate(dataloader):
        l = step_fn(model, batch, loss)
        l.backward()
        running_loss += l.item()
        optimizer.step()
    return running_loss


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    loss,
    step_fn: Callable[[nn.Module, Any, Any], Any],
):
    running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            l = step_fn(model, batch, loss)
            running_loss += l.item()
    return running_loss


def train(
    model: nn.Module,
    n_epochs: int,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    loss,
    optimizer: Optimizer,
    step_fn: Callable[[nn.Module, Any, Any], Any],
    valid_metrics: dict[str, Callable[[nn.Module, list[EmbeddedQuestion], dict], Any]],
    index_context_map: dict[int, EmbeddedContext],
    val_questions: list[EmbeddedQuestion]
):
    EARLY_STOP_EPOCHS = 5
    best_loss = 2**32-1
    best_loss_epoch = 0
    for epoch in range(n_epochs+1):
        if epoch == 0:
            report_strings = [
            f"Epoch {epoch-1}",
        ]
            for metric, metric_fn in valid_metrics.items():
                result = metric_fn(model, val_questions, index_context_map)
                report_strings.append(f"{metric}: {result}")
            print(" | ".join(report_strings))
            continue
        optimizer.zero_grad()
        epoch_train_loss = train_one_epoch(
            model, train_dataloader, loss, optimizer, step_fn
        )
        epoch_valid_loss = evaluate(model, val_dataloader, loss, step_fn)
        if epoch_valid_loss < best_loss:
            best_loss = epoch_valid_loss
            best_loss_epoch = epoch
        report_strings = [
            f"Epoch {epoch-1}",
            f"Train loss: {epoch_train_loss/len(train_dataloader)}",
            f"Validation loss: {epoch_valid_loss/len(val_dataloader)}",
        ]
        for metric, metric_fn in valid_metrics.items():
            result = metric_fn(model, val_questions, index_context_map)
            report_strings.append(f"{metric}: {result}")
        print(" | ".join(report_strings))
        if epoch - best_loss_epoch > EARLY_STOP_EPOCHS:
            break

In [29]:
def embed_and_save(model: nn.Module, test_questions: list[EmbeddedQuestion], index_context_map: dict[int, EmbeddedContext], metadata: dict, experiment_name: str, model_name: str, dataset_name: str):
    model.eval()
    transformed_questions: list[Question] = []
    transformed_contexts: list[Context] = []
    context_indices = set([q.context_index for q in test_questions])
    with torch.no_grad():
        for question in test_questions:
            transformed = model(question.embedding.to(device))
            question.transformed_embedding = transformed.to('cpu').detach()
            transformed_questions.append(question)
        for index in context_indices:
            transformed_context = model(index_context_map[index].embedding.to(device))
            index_context_map[index].transformed_embedding = transformed_context.to('cpu').detach()
            transformed_contexts.append(index_context_map[index]) 
    collection = DataCollection(questions=transformed_questions, contexts=transformed_contexts, metadata=metadata)
    with open(f"./data/train_experiments/{experiment_name}/{model_name}/{dataset_name}.pkl", "wb") as file:
        pkl.dump(collection, file)

## Single Layer/Dimensionality reduction

In [30]:
models = ["BAAI-bge-small-en-v1.5", "Cohere-embed-english-v3.0", "text-embedding-3-large", "text-embedding-ada-002"]

In [31]:
def triplet_step_fn(model, batch, loss):
    anchor, positive, negative = batch['questions'], batch['contexts'], batch['labels']
    anchor = model(anchor)
    positive = model(positive)
    negative = model(negative)
    return loss(anchor, positive, negative)

In [34]:
import pathlib
from glob import glob

EXPERIMENT_NAME = "single-layer-squared-no-dropout-long-train-low-margin"
EPOCH_MAP = {
    "sciq":30,
    "dolly":30,
    "sciq-large":5,
    "default":200
}

for model_name in models:
    print(f"Training {model_name}")
    if not pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}").exists():
        pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}/").mkdir(parents=True)
    for path in glob(f"./data/embedded/{model_name}/*.pkl"):
        dataset_name = "-".join(pathlib.Path(path).name.split("-")[:-1])
        if dataset_name in EPOCH_MAP:
            n_epochs = EPOCH_MAP[dataset_name]
        else:
            n_epochs = EPOCH_MAP["default"]
        print(f"Tranining on {dataset_name}")
        with open(path, "rb") as file:
            collection: DataCollection = pkl.load(file)
        train_questions, val_questions, test_questions, index_context_map = preprocess(
            collection, 42, 0.15, 0.3
        )

        # datasets
        dataset = TripletDataset(
            train_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        valid_dataset = TripletDataset(
            val_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        test_dataset = TripletDataset(
            test_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )

        # dataloaders
        dataloader = DataLoader(
            dataset, 128, collate_fn=dataset_collate_fn, shuffle=True
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            len(valid_dataset),
            collate_fn=dataset_collate_fn,
            shuffle=True,
        )
        test_dataloader = DataLoader(
            test_dataset, len(test_dataset), collate_fn=dataset_collate_fn, shuffle=True
        )

        # model + loss
        loss = nn.TripletMarginWithDistanceLoss(
            distance_function=CosineTripletLoss(),
            margin=0.5
        )
        model = LinearTransform(1, train_questions[0].embedding.shape[0], 0.0) #adjust dropout
        model.to(device)
        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-2)

        train(
            model,
            n_epochs,
            dataloader,
            valid_dataloader,
            loss,
            optimizer,
            triplet_step_fn,
            {
                "Accuracy top 1": get_correct_top_n,
                "Average correct sim": get_avg_correct_similarity,
                "Average wrong sim": get_avg_wrong_similarity,
            },
            index_context_map,
            val_questions,
        )

        embed_and_save(model, test_questions, index_context_map, {"Experiment":"Single Layer 0.1 Dropout"}, EXPERIMENT_NAME, model_name, dataset_name)

Training BAAI-bge-small-en-v1.5
Tranining on 2008_Sichuan_earthquake
Epoch -1 | Accuracy top 1: 0.8787878787878788 | Average correct sim: 0.6866998672485352 | Average wrong sim: 0.5112523372497475
Epoch 0 | Train loss: 0.36966437101364136 | Validation loss: 0.35472673177719116 | Accuracy top 1: 0.8787878787878788 | Average correct sim: 0.6813889145851135 | Average wrong sim: 0.5027345076940086
Epoch 1 | Train loss: 0.3622588515281677 | Validation loss: 0.34927865862846375 | Accuracy top 1: 0.8787878787878788 | Average correct sim: 0.6757727861404419 | Average wrong sim: 0.49372258192994756
Epoch 2 | Train loss: 0.3578720986843109 | Validation loss: 0.3433651328086853 | Accuracy top 1: 0.8787878787878788 | Average correct sim: 0.6698119640350342 | Average wrong sim: 0.4841421703842173
Epoch 3 | Train loss: 0.35070764025052387 | Validation loss: 0.3368930220603943 | Accuracy top 1: 0.8787878787878788 | Average correct sim: 0.6634203195571899 | Average wrong sim: 0.4738675948452246
Epoch 

# Expand Transform

In [35]:
models = ["BAAI-bge-small-en-v1.5", "Cohere-embed-english-v3.0", "text-embedding-3-large", "text-embedding-ada-002"]

In [38]:
def triplet_step_fn(model, batch, loss):
    anchor, positive, negative = batch['questions'], batch['contexts'], batch['labels']
    anchor = model(anchor)
    positive = model(positive)
    negative = model(negative)
    return loss(anchor, positive, negative)

In [39]:
import pathlib
from glob import glob

EXPERIMENT_NAME = "expanded-transform"
EPOCH_MAP = {
    "sciq":30,
    "dolly":30,
    "sciq-large":5,
    "default":200
}

for model_name in models:
    print(f"Training {model_name}")
    if not pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}").exists():
        pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}/").mkdir(parents=True)
    for path in glob(f"./data/embedded/{model_name}/*.pkl"):
        dataset_name = "-".join(pathlib.Path(path).name.split("-")[:-1])
        if dataset_name in EPOCH_MAP:
            n_epochs = EPOCH_MAP[dataset_name]
        else:
            n_epochs = EPOCH_MAP["default"]
        print(f"Tranining on {dataset_name}")
        with open(path, "rb") as file:
            collection: DataCollection = pkl.load(file)
        train_questions, val_questions, test_questions, index_context_map = preprocess(
            collection, 42, 0.15, 0.3
        )

        # datasets
        dataset = TripletDataset(
            train_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        valid_dataset = TripletDataset(
            val_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        test_dataset = TripletDataset(
            test_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )

        # dataloaders
        dataloader = DataLoader(
            dataset, 128, collate_fn=dataset_collate_fn, shuffle=True
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            len(valid_dataset),
            collate_fn=dataset_collate_fn,
            shuffle=True,
        )
        test_dataloader = DataLoader(
            test_dataset, len(test_dataset), collate_fn=dataset_collate_fn, shuffle=True
        )

        # model + loss
        loss = nn.TripletMarginWithDistanceLoss(
            distance_function=CosineTripletLoss(),
            margin=0.5
        )
        model = ExpandTransform(train_questions[0].embedding.shape[0], 0.2)
        model.to(device)
        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-2)

        train(
            model,
            n_epochs,
            dataloader,
            valid_dataloader,
            loss,
            optimizer,
            triplet_step_fn,
            {
                "Accuracy top 1": get_correct_top_n,
                "Average correct sim": get_avg_correct_similarity,
                "Average wrong sim": get_avg_wrong_similarity,
            },
            index_context_map,
            val_questions,
        )

        embed_and_save(model, test_questions, index_context_map, {"Experiment":"Single Layer 0.1 Dropout"}, EXPERIMENT_NAME, model_name, dataset_name)

Training BAAI-bge-small-en-v1.5
Tranining on 2008_Sichuan_earthquake
Epoch -1 | Accuracy top 1: 0.8863636363636364 | Average correct sim: 0.7249703407287598 | Average wrong sim: 0.5690497552212536
Epoch 0 | Train loss: 0.4030599693457286 | Validation loss: 0.3848750591278076 | Accuracy top 1: 0.8863636363636364 | Average correct sim: 0.7108638882637024 | Average wrong sim: 0.5468414696898947
Epoch 1 | Train loss: 0.39293550451596576 | Validation loss: 0.3709593415260315 | Accuracy top 1: 0.8863636363636364 | Average correct sim: 0.6941148042678833 | Average wrong sim: 0.520457870378924
Epoch 2 | Train loss: 0.38061482707659405 | Validation loss: 0.3535289168357849 | Accuracy top 1: 0.8863636363636364 | Average correct sim: 0.674406111240387 | Average wrong sim: 0.48938953831340326
Epoch 3 | Train loss: 0.36424562335014343 | Validation loss: 0.33050018548965454 | Accuracy top 1: 0.8863636363636364 | Average correct sim: 0.6501629948616028 | Average wrong sim: 0.4511362426755341
Epoch 4 

# Dimensionality reduction

In [16]:
models = ["BAAI-bge-small-en-v1.5", "Cohere-embed-english-v3.0", "text-embedding-3-large", "text-embedding-ada-002"]

In [17]:
def triplet_step_fn(model, batch, loss):
    anchor, positive, negative = batch['questions'], batch['contexts'], batch['labels']
    anchor = model(anchor)
    positive = model(positive)
    negative = model(negative)
    return loss(anchor, positive, negative)

In [40]:
import pathlib
from glob import glob

EXPERIMENT_NAME = "dimensionality-reduction"
EPOCH_MAP = {
    "sciq":30,
    "dolly":30,
    "sciq-large":5,
    "default":200
}

for model_name in models:
    print(f"Training {model_name}")
    if not pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}").exists():
        pathlib.Path(f"./data/train_experiments/{EXPERIMENT_NAME}/{model_name}/").mkdir(parents=True)
    for path in glob(f"./data/embedded/{model_name}/*.pkl"):
        dataset_name = "-".join(pathlib.Path(path).name.split("-")[:-1])
        if dataset_name in EPOCH_MAP:
            n_epochs = EPOCH_MAP[dataset_name]
        else:
            n_epochs = EPOCH_MAP["default"]
        print(f"Tranining on {dataset_name}")
        with open(path, "rb") as file:
            collection: DataCollection = pkl.load(file)
        train_questions, val_questions, test_questions, index_context_map = preprocess(
            collection, 42, 0.15, 0.3
        )

        # datasets
        dataset = TripletDataset(
            train_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        valid_dataset = TripletDataset(
            val_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )
        test_dataset = TripletDataset(
            test_questions, index_context_map, n_examples=1, max_sim=0.8, use_filter=not dataset_name == "sciq-large"
        )

        # dataloaders
        dataloader = DataLoader(
            dataset, 128, collate_fn=dataset_collate_fn, shuffle=True
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            len(valid_dataset),
            collate_fn=dataset_collate_fn,
            shuffle=True,
        )
        test_dataloader = DataLoader(
            test_dataset, len(test_dataset), collate_fn=dataset_collate_fn, shuffle=True
        )

        # model + loss
        loss = nn.TripletMarginWithDistanceLoss(
            distance_function=CosineTripletLoss(),
            margin=0.5
        )
        model = DimensionalityReductionTransform(train_questions[0].embedding.shape[0], 0.2)
        model.to(device)
        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-2)

        train(
            model,
            n_epochs,
            dataloader,
            valid_dataloader,
            loss,
            optimizer,
            triplet_step_fn,
            {
                "Accuracy top 1": get_correct_top_n,
                "Average correct sim": get_avg_correct_similarity,
                "Average wrong sim": get_avg_wrong_similarity,
            },
            index_context_map,
            val_questions,
        )

        embed_and_save(model, test_questions, index_context_map, {"Experiment":"Single Layer 0.1 Dropout"}, EXPERIMENT_NAME, model_name, dataset_name)

Training BAAI-bge-small-en-v1.5
Tranining on 2008_Sichuan_earthquake
Epoch -1 | Accuracy top 1: 0.8636363636363636 | Average correct sim: 0.7055796384811401 | Average wrong sim: 0.5278719120333639
Epoch 0 | Train loss: 0.36693186561266583 | Validation loss: 0.3355693817138672 | Accuracy top 1: 0.8636363636363636 | Average correct sim: 0.6813669204711914 | Average wrong sim: 0.48764062056724916
Epoch 1 | Train loss: 0.35127390424410504 | Validation loss: 0.30206024646759033 | Accuracy top 1: 0.8636363636363636 | Average correct sim: 0.6522899866104126 | Average wrong sim: 0.4392375940014348
Epoch 2 | Train loss: 0.3182796835899353 | Validation loss: 0.25588515400886536 | Accuracy top 1: 0.8712121212121212 | Average correct sim: 0.6154924035072327 | Average wrong sim: 0.3776469526922351
Epoch 3 | Train loss: 0.2858186463514964 | Validation loss: 0.20181915163993835 | Accuracy top 1: 0.8636363636363636 | Average correct sim: 0.5735310912132263 | Average wrong sim: 0.30712183881952076
Epoc