In [1]:
from __future__ import annotations


"""
This training script trains MLPs up to layer 5?
"""
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import nn


DATASETS = [
    # (numbers are counts for documents, there may be some longer documents -> slightly more chunks)
    "arguana",  # 10K
    "fiqa",  # 50K -> 20K
    "scidocs",  # 25K -> 20K
    "nfcorpus",  # 5K
    "hotpotqa",  # 100K -> 20K
    "trec-covid",  # too much -> 20K
]

MODEL_NAMES = [
    "WhereIsAI/UAE-Large-V1",
    "BAAI/bge-base-en-v1.5",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "intfloat/e5-base-v2",
    "intfloat/e5-large-v2",
    "intfloat/e5-small-v2",
    "thenlper/gte-base",
    "thenlper/gte-large",
    "thenlper/gte-small",
    "sentence-transformers/gtr-t5-base",
    "sentence-transformers/gtr-t5-large",
    "mixedbread-ai/mxbai-embed-large-v1",
    "sentence-transformers/sentence-t5-base",
    "sentence-transformers/sentence-t5-large",
    "openai/text-embedding-3-large",
    "openai/text-embedding-3-small",
]


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list[int] | None = None,
        num_layers: int = 1,
    ):
        super().__init__()
        num_layers = num_layers or len(hidden_dims)
        if num_layers is None:
            raise ValueError("Either num_layers or hidden_dims must be provided")
        if hidden_dims is None:
            hidden_dims = [max(input_dim, output_dim)] * (num_layers - 1)

        # Build layer dimensions including input and output
        layer_dims = [input_dim] + hidden_dims + [output_dim]

        # Create sequential model with linear layers and ReLU activations
        layers: list[nn.Module] = []
        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))
            if i < len(layer_dims) - 2:  # No ReLU after final layer
                layers.append(nn.ReLU())

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

In [2]:
"""
Simple scaling test to see how many we can train at once without running out of memory.
"""
import tqdm
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


device = "cuda:3"
embeddings_path = Path("/mnt/align3_drive/adrianoh/dl_final_project_embeddings")
embeddings_path_src = embeddings_path / MODEL_NAMES[0].replace("/", "_") / DATASETS[1]
embeddings_path_dst = embeddings_path / MODEL_NAMES[1].replace("/", "_") / DATASETS[1]


# NOTE: copied from cereal.py
class EmbeddingDataset(Dataset):
    def __init__(
        self, source_embeddings: torch.Tensor, target_embeddings: torch.Tensor
    ):
        assert source_embeddings.shape[0] == target_embeddings.shape[0]
        self.source_embeddings = source_embeddings
        self.target_embeddings = target_embeddings

    def __len__(self):
        return len(self.source_embeddings)

    def __getitem__(self, idx):
        return self.source_embeddings[idx], self.target_embeddings[idx]


def get_embeddings_paths(embeddings_path: Path):
    record_type = "corpus"
    embeddings_train_path = (
        embeddings_path / f"embeddings_{record_type}_train.safetensors"
    )
    embeddings_validation_path = (
        embeddings_path / f"embeddings_{record_type}_validation.safetensors"
    )
    assert (embeddings_train_path.exists() and embeddings_validation_path.exists()) or (
        not embeddings_train_path.exists() and not embeddings_validation_path.exists()
    )
    if not embeddings_train_path.exists():
        # NOTE: that sometimes the path names are reversed, i.e. when using OpenAI models; you can observe
        # more in detail in `get_reversed_model_files` in `sanity_check_embeddings_note_equal.ipynb`
        embeddings_train_path = (
            embeddings_path / f"{record_type}_train_embeddings.safetensors"
        )
        embeddings_validation_path = (
            embeddings_path / f"{record_type}_validation_embeddings.safetensors"
        )
    assert embeddings_train_path.exists() and embeddings_validation_path.exists(), f"Files {embeddings_train_path} and {embeddings_validation_path} do not exist"  # fmt: skip
    return embeddings_train_path, embeddings_validation_path


embeddings_train_path_src, embeddings_validation_path_src = get_embeddings_paths(
    embeddings_path_src
)
embeddings_train_path_dst, embeddings_validation_path_dst = get_embeddings_paths(
    embeddings_path_dst
)
# Load corpus validation embeddings
embeddings_src = load_file(embeddings_train_path_src)["embeddings"].to(device)
embeddings_dst = load_file(embeddings_train_path_dst)["embeddings"].to(device)
print("embedding src shape", embeddings_src.shape)
print("embedding dst shape", embeddings_dst.shape)
num_layers = [2, 3, 4, 5, 6, 7, 8, 9, 10]
linears_src2dst = [
    MLP(embeddings_src.shape[1], embeddings_dst.shape[1], num_layers=n).to(device)
    for n in num_layers
]
linears_dst2src = [
    MLP(embeddings_dst.shape[1], embeddings_src.shape[1], num_layers=n).to(device)
    for n in num_layers
]

embedding src shape torch.Size([24900, 1024])
embedding dst shape torch.Size([24900, 768])


In [3]:
"""
Doing the scaling test still.
"""
import time


time_start = time.time()
train_dataset = EmbeddingDataset(embeddings_src, embeddings_dst)
loss_fn = nn.MSELoss()
# NOTE
# We will get time per epoch per layer
# TOTAL_TIME = NUM_EPOCHS * NUM_MODELS * TIME_PER_MODEL_EPOCH
num_epochs = 10
batch_sizes = [len(train_dataset)]  # [32, 64, 128, 512, 1024, 4096, len(train_dataset)]
for batch_size in batch_sizes:
    #### START ####
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    all_parameters = []
    for linears in linears_src2dst + linears_dst2src:
        all_parameters.extend(list(linears.parameters()))
    optimizer = torch.optim.Adam(all_parameters, lr=1e-3)
    for linears in linears_src2dst + linears_dst2src:
        linears.train()
    for epoch in tqdm.trange(num_epochs):
        for src_emb, dst_emb in train_loader:
            for linear_src2dst, linear_dst2src in zip(
                linears_src2dst, linears_dst2src, strict=False
            ):
                linear_src2dst.zero_grad()
                linear_dst2src.zero_grad()
                # one goes backwards, one forwards
                output_src2dst = linear_src2dst(src_emb)
                output_dst2src = linear_dst2src(dst_emb)
                loss_src2dst = loss_fn(output_src2dst, dst_emb)
                loss_dst2src = loss_fn(output_dst2src, src_emb)
                loss = loss_src2dst + loss_dst2src
                loss.backward()
                optimizer.step()
    time_end = time.time()
    time_per_model_epoch = (
        (time_end - time_start) / num_epochs / len(linears_src2dst + linears_dst2src)
    )
    print("Batch size", batch_size)
    print(f"Time taken: {time_end - time_start:.2f} seconds")
    print(f"Time per model epoch: {time_per_model_epoch:.2f} seconds")
    #### END ####

100%|██████████| 10/10 [00:05<00:00,  1.99it/s]

Batch size 24900
Time taken: 10.35 seconds
Time per model epoch: 0.06 seconds





In [4]:
"""
Try flipping the order of the for loops.
"""
num_epochs = 10
batch_sizes = [len(train_dataset)]  # [32, 64, 128, 512, 1024, 4096, len(train_dataset)]
for batch_size in batch_sizes:
    #### START ####
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    all_parameters = []
    for linears in linears_src2dst + linears_dst2src:
        all_parameters.extend(list(linears.parameters()))
    optimizer = torch.optim.Adam(all_parameters, lr=1e-3)
    for linears in linears_src2dst + linears_dst2src:
        linears.train()
    for linear_src2dst, linear_dst2src in zip(
        linears_src2dst, linears_dst2src, strict=False
    ):
        for epoch in tqdm.trange(num_epochs):
            for src_emb, dst_emb in train_loader:
                linear_src2dst.zero_grad()
                linear_dst2src.zero_grad()
                # one goes backwards, one forwards
                output_src2dst = linear_src2dst(src_emb)
                output_dst2src = linear_dst2src(dst_emb)
                loss_src2dst = loss_fn(output_src2dst, dst_emb)
                loss_dst2src = loss_fn(output_dst2src, src_emb)
                loss = loss_src2dst + loss_dst2src
                loss.backward()
                optimizer.step()
    time_end = time.time()
    time_per_model_epoch = (
        (time_end - time_start) / num_epochs / len(linears_src2dst + linears_dst2src)
    )
    print("Batch size", batch_size)
    print(f"Time taken: {time_end - time_start:.2f} seconds")
    print(f"Time per model epoch: {time_per_model_epoch:.2f} seconds")

100%|██████████| 10/10 [00:01<00:00,  5.14it/s]
100%|██████████| 10/10 [00:01<00:00,  5.27it/s]
100%|██████████| 10/10 [00:01<00:00,  5.45it/s]
100%|██████████| 10/10 [00:01<00:00,  5.44it/s]
100%|██████████| 10/10 [00:01<00:00,  5.35it/s]
100%|██████████| 10/10 [00:01<00:00,  5.37it/s]
100%|██████████| 10/10 [00:01<00:00,  5.32it/s]
100%|██████████| 10/10 [00:01<00:00,  5.34it/s]
100%|██████████| 10/10 [00:01<00:00,  5.42it/s]

Batch size 24900
Time taken: 27.23 seconds
Time per model epoch: 0.15 seconds





In [5]:
"""
Claude thinks this will be faster.
"""


def train_models(
    src_emb: torch.Tensor,
    dst_emb: torch.Tensor,
    models_src2dst: list[MLP],
    models_dst2src: list[MLP],
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
):
    batch_size = src_emb.shape[0]
    num_models = len(models_src2dst)

    # Expand embeddings to match number of models
    src_emb_expanded = src_emb.unsqueeze(0).expand(
        num_models, -1, -1
    )  # [num_models, batch_size, dim]
    dst_emb_expanded = dst_emb.unsqueeze(0).expand(num_models, -1, -1)

    optimizer.zero_grad()

    # Forward pass (all models at once)
    outputs_src2dst = torch.stack([model(src_emb) for model in models_src2dst])
    outputs_dst2src = torch.stack([model(dst_emb) for model in models_dst2src])

    # Compute loss for all models simultaneously
    loss_src2dst = loss_fn(outputs_src2dst, dst_emb_expanded)
    loss_dst2src = loss_fn(outputs_dst2src, src_emb_expanded)
    loss = loss_src2dst + loss_dst2src

    # Backward pass and optimization
    loss.mean().backward()
    optimizer.step()

    return loss.mean().item()


def training_run(
    # NOTE: each element of the list gets the same dataset
    # (i.e. these should correspond to different embeddings)
    datasets: list[EmbeddingDataset],
    models_src2dst: list[list[MLP]],
    models_dst2src: list[list[MLP]],
    loss_fn: nn.Module,
    num_epochs: int,
    batch_size: int,
) -> float:
    assert (
        len(set(len(d) for d in datasets)) == 1
    ), "All datasets must have the same length"
    time_start = time.time()
    train_loaders = [
        DataLoader(d, batch_size=batch_size, shuffle=True) for d in datasets
    ]
    num_iters = len(train_loaders[0])  # NOTE: same for all of the datasets
    all_parameters = []
    # 1. Get all parameters
    for models_list in models_src2dst + models_dst2src:
        for models in models_list:
            # (TODO Adriano not sure if joining this way is OK?)
            all_parameters.extend(list(models.parameters()))
    optimizer = torch.optim.Adam(all_parameters, lr=1e-3)
    # 2. Set all models to train
    for models_list in models_src2dst + models_dst2src:
        for models in models_list:
            models.train()
    # 3. Train
    for epoch in tqdm.trange(num_epochs):
        loader_iters = [iter(loader) for loader in train_loaders]
        for i in range(num_iters):
            xys = [next(loader_iter) for loader_iter in loader_iters]
            assert not any(
                x is None for x in xys
            ), "Some loader iterators returned None"
            for (X, Y), models_src2dst_list, models_dst2src_list in zip(
                xys, models_src2dst, models_dst2src, strict=False
            ):
                train_models(
                    X, Y, models_src2dst_list, models_dst2src_list, optimizer, loss_fn
                )
    return time.time() - time_start


train_dataset = EmbeddingDataset(embeddings_src, embeddings_dst)
num_epochs = 10
for batch_size in batch_sizes:
    loss_fn = nn.MSELoss()
    time_taken = training_run(
        [train_dataset],
        [linears_src2dst],
        [linears_dst2src],
        loss_fn,
        num_epochs,
        batch_size,
    )
    print(f"Batch size {batch_size}: {time_taken:.2f} seconds")
    print(
        f"Time per model per epoch: {time_taken / num_epochs / len(linears_src2dst + linears_dst2src):.2f} seconds"
    )

100%|██████████| 10/10 [00:06<00:00,  1.53it/s]

Batch size 24900: 6.52 seconds
Time per model per epoch: 0.04 seconds





In [16]:
"""
This cell is meant to try and get a reasonable estimate for the time and memory
possible when trying to train a LOT in parallel (i.e. literally training 
all N^2 at the same time (if this is possible, it seems ideal?)

NOTE: this is not optimally memory efficient because it loads every dataset N-1 times.
"""
# del linears_src2dst, linears_dst2src

# free cuda memory
torch.cuda.empty_cache()
device = "cuda:3"
import gc
import math
import random


gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
with torch.cuda.device(device):
    torch.cuda.empty_cache()
# ---> Now make a table of models for all N^2 combinations where the indexxing is
# models[source][destination][layer_num] -> return None if src == destination else list of models
unordered_pairs_all = [
    (MODEL_NAMES[i], MODEL_NAMES[j]) for j in range(len(MODEL_NAMES)) for i in range(j)
]
random.seed(55)
random.shuffle(unordered_pairs_all)
unordered_pairs_block_size: int = math.ceil(len(unordered_pairs_all) / 3)
_is = list(range(0, len(unordered_pairs_all), unordered_pairs_block_size))
model2embeddings = {}
for model_name in MODEL_NAMES:
    embeddings_path = Path("/mnt/align3_drive/adrianoh/dl_final_project_embeddings")
    embeddings_path_src = embeddings_path / model_name.replace("/", "_") / DATASETS[1]
    model2embeddings[model_name] = load_file(embeddings_train_path_src)[
        "embeddings"
    ].to(device)
for _, i in enumerate(tqdm.tqdm(_is)):
    print(f"Processing block {_}/{len(_is)} idx={i} of {len(unordered_pairs_all)}")
    unordered_pairs = unordered_pairs_all[i : i + unordered_pairs_block_size]
    embeddings_tensors_src: list[torch.Tensor] = []
    embeddings_tensors_dst: list[torch.Tensor] = []
    for src, dst in tqdm.tqdm(unordered_pairs):
        # Append by ptr ideally?
        embeddings_tensors_src.append(model2embeddings[src])
        embeddings_tensors_dst.append(model2embeddings[dst])
    assert len(embeddings_tensors_src) == len(unordered_pairs)
    assert len(embeddings_tensors_dst) == len(unordered_pairs)
    print("creating models")
    models_src2dst: list[list[MLP]] = []
    models_dst2src: list[list[MLP]] = []
    # NOTE: maybe we just do layers <= len = 7 instead of going all the way to 10 (this is already like 30% of the models' depths)
    n_layers_list = list(range(2, 8))  # Maybe two blocks: [2,3], [4,5], [6,7] ???
    # assert len(n_layers_list) <= 3
    for (src, dst), src_emb, dst_emb in tqdm.tqdm(
        list(
            zip(
                unordered_pairs,
                embeddings_tensors_src,
                embeddings_tensors_dst,
                strict=False,
            )
        )
    ):
        src_dim, dst_dim = src_emb.shape[1], dst_emb.shape[1]
        models_src2dst.append(
            [MLP(src_dim, dst_dim, num_layers=n).to(device) for n in n_layers_list]
        )
        models_dst2src.append(
            [MLP(dst_dim, src_dim, num_layers=n).to(device) for n in n_layers_list]
        )
    assert len(models_src2dst) == len(models_dst2src) == len(unordered_pairs)

    print("training for 1 epoch")
    assert isinstance(models_src2dst, list)
    assert isinstance(models_dst2src, list)
    assert all(isinstance(models_src2dst[i], list) for i in range(len(models_src2dst)))
    assert all(isinstance(models_dst2src[i], list) for i in range(len(models_dst2src)))
    assert all(
        all(
            isinstance(models_src2dst[i][j], MLP) for j in range(len(models_src2dst[i]))
        )
        for i in range(len(models_src2dst))
    )
    assert all(
        all(
            isinstance(models_dst2src[i][j], MLP) for j in range(len(models_dst2src[i]))
        )
        for i in range(len(models_dst2src))
    )
    train_datasets = [
        EmbeddingDataset(embeddings_src, embeddings_dst)
        for embeddings_src, embeddings_dst in zip(
            embeddings_tensors_src, embeddings_tensors_dst, strict=False
        )
    ]
    assert len(train_datasets) == len(models_src2dst) == len(models_dst2src)
    num_epochs = 1
    # batch_sizes = [8192, 1024, 128]
    batch_sizes = [512]  # I'm a little paranoid too bit will ruin results
    for batch_size in batch_sizes:
        loss_fn = nn.MSELoss()
        time_taken = training_run(
            train_datasets,
            models_src2dst,
            models_dst2src,
            loss_fn,
            num_epochs,
            batch_size,
        )
        print(f"Batch size {batch_size}: {time_taken:.2f} seconds")
        time_per_layer_per_epoch = (
            time_taken / num_epochs / len(linears_src2dst + linears_dst2src)
        )
        print(f"Time per layer per epoch: {time_per_layer_per_epoch:.2f} seconds")
    print("clearing gpu memory")
    del models_src2dst, models_dst2src, train_datasets
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.cuda.device(device):
        torch.cuda.empty_cache()

  0%|          | 0/3 [00:00<?, ?it/s]

Processing block 0/3 idx=0 of 136


100%|██████████| 46/46 [00:00<00:00, 634664.42it/s]


creating models


100%|██████████| 46/46 [00:12<00:00,  3.55it/s]


training for 1 epoch


100%|██████████| 1/1 [00:38<00:00, 38.23s/it]
 33%|███▎      | 1/3 [00:51<01:42, 51.46s/it]

Batch size 512: 38.26 seconds
Time per layer per epoch: 2.13 seconds
clearing gpu memory
Processing block 1/3 idx=46 of 136


100%|██████████| 46/46 [00:00<00:00, 810663.80it/s]


creating models


100%|██████████| 46/46 [00:14<00:00,  3.26it/s]


training for 1 epoch


100%|██████████| 1/1 [00:37<00:00, 37.59s/it]
 67%|██████▋   | 2/3 [01:43<00:51, 51.79s/it]

Batch size 512: 37.62 seconds
Time per layer per epoch: 2.09 seconds
clearing gpu memory
Processing block 2/3 idx=92 of 136


100%|██████████| 44/44 [00:00<00:00, 753262.76it/s]


creating models


100%|██████████| 44/44 [00:13<00:00,  3.22it/s]


training for 1 epoch


100%|██████████| 1/1 [00:37<00:00, 37.87s/it]
100%|██████████| 3/3 [02:35<00:00, 51.82s/it]

Batch size 512: 37.89 seconds
Time per layer per epoch: 2.11 seconds
clearing gpu memory


100%|██████████| 3/3 [02:35<00:00, 51.78s/it]


In [None]:
"""
Conclusions from above:
1. Larger batch size is actually a LOT faster? Not measured fully rigorously, but seems like it COULD be up to 20X faster.
    Question: won't doing the entire dataset as a batch be bad?
2. Running the dataloader in the outside seems better than
    running it on the inside (i.e. all models "more at once")
    is generally better. This gives maybe up to 2X?
3. Batching loss is pretty good. Maybe up to 2X?
4. It is possible to do all N^2 combinations but you run out of memory if you do > 1 model per combination
5. It seems better to block by pairs more so than by models.

Doing basically blocks of the N^2 combinations seems reasonable to me, doing all layers at once. I need to be careful to make sure
I don't OOM though. The cause for the OOM seems to be that some of the pairs of models enforce some kind of bottleneck in memory at
certain times and it's pretty bad.
"""