In [45]:
import torch
import safetensors
import pydantic
import safetensors.torch
from typing import Literal, Optional, List, Dict, Tuple
import json
from pathlib import Path
DATASETS_PATH = Path("/mnt/align3_drive/adrianoh/dl_final_project_embeddings")
NFCORPUS_PATH = DATASETS_PATH / "nfcorpus-real-other-is-scidocs"
SCIDOCS_PATH = DATASETS_PATH / "nfcorpus"
FIQA_PATH = DATASETS_PATH / "fiqa"
ARGUANA_PATH = DATASETS_PATH / "arguana"

class EmbeddingDataset(pydantic.BaseModel):
    config: pydantic.ConfigDict = pydantic.ConfigDict(arbitrary_types_allowed=True)
    embeddings: torch.Tensor
    ids: List[str]
    documents: List[str]

class EmbeddingLoader:
    def __init__(self, root_path: Path):
        self.root_path = root_path

    def fetch_embeddings(
        self,
        corpus_folder: Literal["corpus", "query"] = "corpus",
        model_folders: Optional[List[str]] = None
    ) -> List[EmbeddingDataset]:
        """
        Fetch the set of embeddings for the given corpus and subfolders.
        If it's `None` for `model_folders`, it will fetch all the models' embeddings.
        """
        if model_folders is None:
            model_folders = [model_folder.name for model_folder in (self.root_path / corpus_folder).iterdir() if model_folder.is_dir()] # fmt: skip
        subpaths = [(self.root_path / corpus_folder / model_folder) for model_folder in model_folders]
        assert len(set(subpath.name for subpath in subpaths)) == len(subpaths)
        # The files we want
        embeddings_paths = [subpath / "embeddings.safetensors" for subpath in subpaths]
        ids_paths = [subpath / "ids.jsonl" for subpath in subpaths]
        documents_paths = [subpath / "documents.jsonl" for subpath in subpaths]

        # They should all have the correct files present
        assert all(embeddings_path.exists() for embeddings_path in embeddings_paths), ("Not all embeddings paths exist:\n  " + '\n  '.join(f'{p.as_posix()}: {p.exists()}' for p in embeddings_paths)) # fmt: skip
        assert all(ids_path.exists() for ids_path in ids_paths), f"Not all ids paths exist: {ids_paths}"
        assert all(documents_path.exists() for documents_path in documents_paths), f"Not all documents paths exist: {documents_paths}" # fmt: skip

        _embeddings_sets: List[Dict[str, torch.Tensor]] = [safetensors.torch.load_file(embeddings_path, device="cpu") for embeddings_path in embeddings_paths] # fmt: skip
        _ids_sets: List[Dict[str, List[str]]] = [json.load(open(ids_path, "r")) for ids_path in ids_paths]
        _documents_sets: List[Dict[str, List[str]]] = [json.load(open(documents_path, "r")) for documents_path in documents_paths] # fmt: skip

        # Extrat the jsons
        assert all(isinstance(ids_set, dict) and "ids" in ids_set and len(ids_set) == 1 for ids_set in _ids_sets)
        assert all(isinstance(documents_set, dict) and "documents" in documents_set and len(documents_set) == 1 for documents_set in _documents_sets) # fmt: skip
        assert all(isinstance(embeddings_set, dict) and "embeddings" in embeddings_set and len(embeddings_set) == 1 for embeddings_set in _embeddings_sets) # fmt: skip
        ids_sets: List[List[str]] = [ids_set["ids"] for ids_set in _ids_sets]
        documents_sets: List[List[str]] = [documents_set["documents"] for documents_set in _documents_sets] # fmt: skip
        embeddings_sets: List[torch.Tensor] = [embeddings_set["embeddings"] for embeddings_set in _embeddings_sets] # fmt: skip

        # Make sure that all these have the same lengths
        _ids_lengths = [len(ids_set) for ids_set in ids_sets]
        _documents_lengths = [len(documents_set) for documents_set in documents_sets] # fmt: skip
        _embeddings_lengths = [len(embeddings_set) for embeddings_set in embeddings_sets] # fmt: skip
        assert len(set(_ids_lengths)) == 1
        assert len(set(_documents_lengths)) == 1
        assert len(set(_embeddings_lengths)) == 1

        # Make sure the types are OK => they come from the same dataset but different models embedded so they should ONLY differ in that
        assert all(isinstance(ids_set, list) and all(isinstance(id, str) for id in ids_set) for ids_set in ids_sets) # fmt: skip
        assert all(ids_set1 == ids_sets[0] for ids_set1 in ids_sets) # star pattern for equality
        # ...
        assert all(isinstance(documents_set, list) and all(isinstance(doc, str) for doc in documents_set) for documents_set in documents_sets) # fmt: skip
        assert all(documents_set1 == documents_sets[0] for documents_set1 in documents_sets) # star pattern for equality
        # ...
        assert all(isinstance(embeddings_set, torch.Tensor) and embeddings_set.shape == embeddings_sets[0].shape for embeddings_set in embeddings_sets) # fmt: skip
        # Cartesian product: all most be unique
        # avoid double-counting and self-comparison
        is_close: Dict[str, bool] = {}
        for i in range(len(embeddings_sets)):
            for j in range(i + 1, len(embeddings_sets)):
                path1, path2 = embeddings_paths[i].parent.name, embeddings_paths[j].parent.name
                path_key = f"{path1} <|> {path2}"
                is_close[path_key] = torch.allclose(embeddings_sets[i], embeddings_sets[j])
        is_close = {x : y for x, y in is_close.items() if y}
        assert len(is_close) == 0, f"Got {len(is_close)} embeddings (total n_embeddings={len(embeddings_sets)}) that are the same:\n\n{json.dumps(is_close, indent=2)}"
        return [
            EmbeddingDataset(
                embeddings=embeddings_set,
                ids=ids_set,
                documents=documents_set
            )
            for embeddings_set, ids_set, documents_set in zip(embeddings_sets, ids_sets, documents_sets)
        ]

loader = EmbeddingLoader(root_path=NFCORPUS_PATH)
embeddings_list: List[EmbeddingDataset] = loader.fetch_embeddings(corpus_folder="corpus", model_folders=None) # get all the document/corpus embeddings
print("Embeddings shape is", embeddings_list[0].embeddings.shape, "=", embeddings_list[1].embeddings.shape, "= ...") # axis 0 is dataset, axis 1 is embedding
print("Embeddings devices:", "embeddings1.device", embeddings_list[0].embeddings.device, "embeddings2.device", embeddings_list[1].embeddings.device, "etc...") # shoulds be cpu cpu


AssertionError: Got 120 embeddings (total n_embeddings=16) that are the same:

{
  "intfloat_e5-small-v2 <|> intfloat_e5-large-v2": true,
  "intfloat_e5-small-v2 <|> sentence-transformers_sentence-t5-large": true,
  "intfloat_e5-small-v2 <|> thenlper_gte-small": true,
  "intfloat_e5-small-v2 <|> Salesforce_SFR-Embedding-Mistral": true,
  "intfloat_e5-small-v2 <|> BAAI_bge-large-en-v1.5": true,
  "intfloat_e5-small-v2 <|> intfloat_e5-base-v2": true,
  "intfloat_e5-small-v2 <|> BAAI_bge-base-en-v1.5": true,
  "intfloat_e5-small-v2 <|> sentence-transformers_gtr-t5-large": true,
  "intfloat_e5-small-v2 <|> thenlper_gte-base": true,
  "intfloat_e5-small-v2 <|> sentence-transformers_sentence-t5-base": true,
  "intfloat_e5-small-v2 <|> thenlper_gte-large": true,
  "intfloat_e5-small-v2 <|> WhereIsAI_UAE-Large-V1": true,
  "intfloat_e5-small-v2 <|> BAAI_bge-small-en-v1.5": true,
  "intfloat_e5-small-v2 <|> sentence-transformers_gtr-t5-base": true,
  "intfloat_e5-small-v2 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "intfloat_e5-large-v2 <|> sentence-transformers_sentence-t5-large": true,
  "intfloat_e5-large-v2 <|> thenlper_gte-small": true,
  "intfloat_e5-large-v2 <|> Salesforce_SFR-Embedding-Mistral": true,
  "intfloat_e5-large-v2 <|> BAAI_bge-large-en-v1.5": true,
  "intfloat_e5-large-v2 <|> intfloat_e5-base-v2": true,
  "intfloat_e5-large-v2 <|> BAAI_bge-base-en-v1.5": true,
  "intfloat_e5-large-v2 <|> sentence-transformers_gtr-t5-large": true,
  "intfloat_e5-large-v2 <|> thenlper_gte-base": true,
  "intfloat_e5-large-v2 <|> sentence-transformers_sentence-t5-base": true,
  "intfloat_e5-large-v2 <|> thenlper_gte-large": true,
  "intfloat_e5-large-v2 <|> WhereIsAI_UAE-Large-V1": true,
  "intfloat_e5-large-v2 <|> BAAI_bge-small-en-v1.5": true,
  "intfloat_e5-large-v2 <|> sentence-transformers_gtr-t5-base": true,
  "intfloat_e5-large-v2 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "sentence-transformers_sentence-t5-large <|> thenlper_gte-small": true,
  "sentence-transformers_sentence-t5-large <|> Salesforce_SFR-Embedding-Mistral": true,
  "sentence-transformers_sentence-t5-large <|> BAAI_bge-large-en-v1.5": true,
  "sentence-transformers_sentence-t5-large <|> intfloat_e5-base-v2": true,
  "sentence-transformers_sentence-t5-large <|> BAAI_bge-base-en-v1.5": true,
  "sentence-transformers_sentence-t5-large <|> sentence-transformers_gtr-t5-large": true,
  "sentence-transformers_sentence-t5-large <|> thenlper_gte-base": true,
  "sentence-transformers_sentence-t5-large <|> sentence-transformers_sentence-t5-base": true,
  "sentence-transformers_sentence-t5-large <|> thenlper_gte-large": true,
  "sentence-transformers_sentence-t5-large <|> WhereIsAI_UAE-Large-V1": true,
  "sentence-transformers_sentence-t5-large <|> BAAI_bge-small-en-v1.5": true,
  "sentence-transformers_sentence-t5-large <|> sentence-transformers_gtr-t5-base": true,
  "sentence-transformers_sentence-t5-large <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "thenlper_gte-small <|> Salesforce_SFR-Embedding-Mistral": true,
  "thenlper_gte-small <|> BAAI_bge-large-en-v1.5": true,
  "thenlper_gte-small <|> intfloat_e5-base-v2": true,
  "thenlper_gte-small <|> BAAI_bge-base-en-v1.5": true,
  "thenlper_gte-small <|> sentence-transformers_gtr-t5-large": true,
  "thenlper_gte-small <|> thenlper_gte-base": true,
  "thenlper_gte-small <|> sentence-transformers_sentence-t5-base": true,
  "thenlper_gte-small <|> thenlper_gte-large": true,
  "thenlper_gte-small <|> WhereIsAI_UAE-Large-V1": true,
  "thenlper_gte-small <|> BAAI_bge-small-en-v1.5": true,
  "thenlper_gte-small <|> sentence-transformers_gtr-t5-base": true,
  "thenlper_gte-small <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "Salesforce_SFR-Embedding-Mistral <|> BAAI_bge-large-en-v1.5": true,
  "Salesforce_SFR-Embedding-Mistral <|> intfloat_e5-base-v2": true,
  "Salesforce_SFR-Embedding-Mistral <|> BAAI_bge-base-en-v1.5": true,
  "Salesforce_SFR-Embedding-Mistral <|> sentence-transformers_gtr-t5-large": true,
  "Salesforce_SFR-Embedding-Mistral <|> thenlper_gte-base": true,
  "Salesforce_SFR-Embedding-Mistral <|> sentence-transformers_sentence-t5-base": true,
  "Salesforce_SFR-Embedding-Mistral <|> thenlper_gte-large": true,
  "Salesforce_SFR-Embedding-Mistral <|> WhereIsAI_UAE-Large-V1": true,
  "Salesforce_SFR-Embedding-Mistral <|> BAAI_bge-small-en-v1.5": true,
  "Salesforce_SFR-Embedding-Mistral <|> sentence-transformers_gtr-t5-base": true,
  "Salesforce_SFR-Embedding-Mistral <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "BAAI_bge-large-en-v1.5 <|> intfloat_e5-base-v2": true,
  "BAAI_bge-large-en-v1.5 <|> BAAI_bge-base-en-v1.5": true,
  "BAAI_bge-large-en-v1.5 <|> sentence-transformers_gtr-t5-large": true,
  "BAAI_bge-large-en-v1.5 <|> thenlper_gte-base": true,
  "BAAI_bge-large-en-v1.5 <|> sentence-transformers_sentence-t5-base": true,
  "BAAI_bge-large-en-v1.5 <|> thenlper_gte-large": true,
  "BAAI_bge-large-en-v1.5 <|> WhereIsAI_UAE-Large-V1": true,
  "BAAI_bge-large-en-v1.5 <|> BAAI_bge-small-en-v1.5": true,
  "BAAI_bge-large-en-v1.5 <|> sentence-transformers_gtr-t5-base": true,
  "BAAI_bge-large-en-v1.5 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "intfloat_e5-base-v2 <|> BAAI_bge-base-en-v1.5": true,
  "intfloat_e5-base-v2 <|> sentence-transformers_gtr-t5-large": true,
  "intfloat_e5-base-v2 <|> thenlper_gte-base": true,
  "intfloat_e5-base-v2 <|> sentence-transformers_sentence-t5-base": true,
  "intfloat_e5-base-v2 <|> thenlper_gte-large": true,
  "intfloat_e5-base-v2 <|> WhereIsAI_UAE-Large-V1": true,
  "intfloat_e5-base-v2 <|> BAAI_bge-small-en-v1.5": true,
  "intfloat_e5-base-v2 <|> sentence-transformers_gtr-t5-base": true,
  "intfloat_e5-base-v2 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "BAAI_bge-base-en-v1.5 <|> sentence-transformers_gtr-t5-large": true,
  "BAAI_bge-base-en-v1.5 <|> thenlper_gte-base": true,
  "BAAI_bge-base-en-v1.5 <|> sentence-transformers_sentence-t5-base": true,
  "BAAI_bge-base-en-v1.5 <|> thenlper_gte-large": true,
  "BAAI_bge-base-en-v1.5 <|> WhereIsAI_UAE-Large-V1": true,
  "BAAI_bge-base-en-v1.5 <|> BAAI_bge-small-en-v1.5": true,
  "BAAI_bge-base-en-v1.5 <|> sentence-transformers_gtr-t5-base": true,
  "BAAI_bge-base-en-v1.5 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "sentence-transformers_gtr-t5-large <|> thenlper_gte-base": true,
  "sentence-transformers_gtr-t5-large <|> sentence-transformers_sentence-t5-base": true,
  "sentence-transformers_gtr-t5-large <|> thenlper_gte-large": true,
  "sentence-transformers_gtr-t5-large <|> WhereIsAI_UAE-Large-V1": true,
  "sentence-transformers_gtr-t5-large <|> BAAI_bge-small-en-v1.5": true,
  "sentence-transformers_gtr-t5-large <|> sentence-transformers_gtr-t5-base": true,
  "sentence-transformers_gtr-t5-large <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "thenlper_gte-base <|> sentence-transformers_sentence-t5-base": true,
  "thenlper_gte-base <|> thenlper_gte-large": true,
  "thenlper_gte-base <|> WhereIsAI_UAE-Large-V1": true,
  "thenlper_gte-base <|> BAAI_bge-small-en-v1.5": true,
  "thenlper_gte-base <|> sentence-transformers_gtr-t5-base": true,
  "thenlper_gte-base <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "sentence-transformers_sentence-t5-base <|> thenlper_gte-large": true,
  "sentence-transformers_sentence-t5-base <|> WhereIsAI_UAE-Large-V1": true,
  "sentence-transformers_sentence-t5-base <|> BAAI_bge-small-en-v1.5": true,
  "sentence-transformers_sentence-t5-base <|> sentence-transformers_gtr-t5-base": true,
  "sentence-transformers_sentence-t5-base <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "thenlper_gte-large <|> WhereIsAI_UAE-Large-V1": true,
  "thenlper_gte-large <|> BAAI_bge-small-en-v1.5": true,
  "thenlper_gte-large <|> sentence-transformers_gtr-t5-base": true,
  "thenlper_gte-large <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "WhereIsAI_UAE-Large-V1 <|> BAAI_bge-small-en-v1.5": true,
  "WhereIsAI_UAE-Large-V1 <|> sentence-transformers_gtr-t5-base": true,
  "WhereIsAI_UAE-Large-V1 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "BAAI_bge-small-en-v1.5 <|> sentence-transformers_gtr-t5-base": true,
  "BAAI_bge-small-en-v1.5 <|> mixedbread-ai_mxbai-embed-large-v1": true,
  "sentence-transformers_gtr-t5-base <|> mixedbread-ai_mxbai-embed-large-v1": true
}

In [15]:
# Now, we can try to train linear transforms between embeddings...
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import tqdm
import torch.optim as optim
from pydantic import BaseModel
import wandb
from typing import Optional
class EmbeddingDataset(Dataset):
    def __init__(self, source_embeddings: torch.Tensor, target_embeddings: torch.Tensor):
        assert source_embeddings.shape == target_embeddings.shape
        self.source_embeddings = source_embeddings
        self.target_embeddings = target_embeddings
        
    def __len__(self):
        return len(self.source_embeddings)
        
    def __getitem__(self, idx):
        return self.source_embeddings[idx], self.target_embeddings[idx]

class LinearTransformTrainerArgs(BaseModel):
    test_split: float = 0.2
    num_epochs: int = 50
    batch_size: int = 32
    learning_rate: float = 0.001
    save_every_n_epochs: int = 10
    use_tqdm: bool = True

class LinearTransformTrainer:
    def __init__(
        self,
        save_path: Path,
        linear: Optional[nn.Linear],
        source_embeddings: torch.Tensor,
        target_embeddings: torch.Tensor,
        device: torch.device | str,
        args: LinearTransformTrainerArgs = LinearTransformTrainerArgs()
    ):
        self.linear = linear
        self.source_embeddings = source_embeddings
        self.target_embeddings = target_embeddings
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.device = device
        self.test_split = args.test_split
        self.save_every_n_epochs = args.save_every_n_epochs
        self.save_path = save_path
        self.checkpoint_path = save_path / "checkpoints"
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
        self.logfile = save_path / "log.jsonl"
        self.use_tqdm = args.use_tqdm
        if self.linear is None:
            self.linear = self.create_linear_transform()
        self.optimizer = torch.optim.Adam(self.linear.parameters(), lr=self.learning_rate)

    def create_datasets(self):
        # Create indices and shuffle
        num_samples = len(self.source_embeddings)
        indices = torch.randperm(num_samples)
        
        # Split indices
        split_idx = int(num_samples * (1 - self.test_split))
        train_indices = indices[:split_idx]
        test_indices = indices[split_idx:]
        
        # Create datasets
        train_dataset = EmbeddingDataset(
            self.source_embeddings[train_indices],
            self.target_embeddings[train_indices]
        )
        test_dataset = EmbeddingDataset(
            self.source_embeddings[test_indices],
            self.target_embeddings[test_indices]
        )
        
        return train_dataset, test_dataset

    def create_linear_transform(self):
        return nn.Linear(self.source_embeddings.shape[1], self.target_embeddings.shape[1]).to(self.device)

    def train(self):
        train_dataset, test_dataset = self.create_datasets()
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size)

        mse_loss = nn.MSELoss()
        trange = tqdm.trange if self.use_tqdm else range
        tqdm_fn = tqdm.tqdm if self.use_tqdm else lambda *args, **kwargs: args[0]
        for epoch in trange(self.num_epochs):
            # Training
            self.linear.train()
            train_mse = 0.0
            train_mae = 0.0
            num_train_batches = 0

            for source_emb, target_emb in train_loader:
                source_emb = source_emb.to(self.device)
                target_emb = target_emb.to(self.device)

                self.optimizer.zero_grad()
                output = self.linear(source_emb)

                loss = mse_loss(output, target_emb)
                loss.backward()
                self.optimizer.step()

                train_mse += loss.detach().item()
                train_mae += (output.detach() - target_emb.detach()).abs().mean().item()
                num_train_batches += 1

            avg_train_mse = train_mse / num_train_batches
            avg_train_mae = train_mae / num_train_batches

            # Evaluation
            self.linear.eval()
            test_mse = 0.0
            test_mae = 0.0
            num_test_batches = 0

            with torch.no_grad():
                for source_emb, target_emb in tqdm_fn(test_loader):
                    source_emb = source_emb.to(self.device)
                    target_emb = target_emb.to(self.device)

                    output = self.linear(source_emb)

                    test_mse += mse_loss(output, target_emb).item()
                    test_mae += (output.detach() - target_emb.detach()).abs().mean().item()
                    num_test_batches += 1

            avg_test_mse = test_mse / num_test_batches
            avg_test_mae = test_mae / num_test_batches

            # Log metrics
            log_entry = {
                "epoch": epoch,
                "train_mse": avg_train_mse,
                "train_mae": avg_train_mae,
                "test_mse": avg_test_mse,
                "test_mae": avg_test_mae,
            }
            wandb.log(log_entry)
            with open(self.logfile, "a") as f:
                f.write(json.dumps(log_entry) + "\n")

            if epoch % self.save_every_n_epochs == 0:
                self.save_checkpoint(epoch)

    def save_checkpoint(self, epoch: int):
        checkpoint_path = self.checkpoint_path / f"checkpoint_{epoch}.safetensors"
        safetensors.torch.save_file(self.linear.state_dict(), checkpoint_path)

In [17]:
DEFAULT_ARGS = LinearTransformTrainerArgs(
    test_split=0.2,
    num_epochs=50,
    batch_size=32,
    learning_rate=0.001,
    save_every_n_epochs=10,
    use_tqdm=True
)

In [16]:
# NOTE: this shit is rlly fast!
import click
import os
import shutil
save_path_parent = Path("/mnt/align3_drive/adrianoh/git/dl_final_project_layers")
save_path = save_path_parent / "ipynb_test"
assert os.environ.get("CUDA_VISIBLE_DEVICES") is None
if save_path.exists():
    click.echo(f"Save path {save_path} already exists. Deleting it...")
    # click.confirm(f"Save path {save_path} already exists. Do you want to delete it?", abort=True)
    shutil.rmtree(save_path)
device = "cuda:0" # change this based on availability
embeddings1 = embeddings1.to(device)
embeddings2 = embeddings2.to(device)
wandb.init(project="2024_12_09_dl_project_testing_layer_train", name="ipynb_test")
trainer = LinearTransformTrainer(
    save_path=save_path, # made by the trainer
    linear=None, # trainer makes it
    source_embeddings=embeddings1,
    target_embeddings=embeddings2,
    device=device,
    args=DEFAULT_ARGS
)
trainer.train()
wandb.finish()


100%|██████████| 44/44 [00:00<00:00, 110.93it/s]
100%|██████████| 44/44 [00:00<00:00, 105.14it/s]
100%|██████████| 44/44 [00:00<00:00, 102.70it/s]
100%|██████████| 44/44 [00:00<00:00, 103.36it/s]
100%|██████████| 44/44 [00:00<00:00, 140.78it/s]
100%|██████████| 44/44 [00:00<00:00, 102.31it/s]
100%|██████████| 44/44 [00:00<00:00, 108.90it/s]
100%|██████████| 44/44 [00:00<00:00, 104.31it/s]
100%|██████████| 44/44 [00:00<00:00, 104.40it/s]
100%|██████████| 44/44 [00:00<00:00, 101.83it/s]
100%|██████████| 44/44 [00:00<00:00, 116.57it/s]
100%|██████████| 44/44 [00:00<00:00, 106.01it/s]
100%|██████████| 44/44 [00:00<00:00, 105.30it/s]
100%|██████████| 44/44 [00:00<00:00, 145.66it/s]
100%|██████████| 44/44 [00:00<00:00, 102.16it/s]
100%|██████████| 44/44 [00:00<00:00, 114.30it/s]
100%|██████████| 44/44 [00:00<00:00, 105.36it/s]
100%|██████████| 44/44 [00:00<00:00, 105.53it/s]
100%|██████████| 44/44 [00:00<00:00, 151.76it/s]
100%|██████████| 44/44 [00:00<00:00, 107.61it/s]
100%|██████████| 44/

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
test_mae,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_mse,█▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mae,█▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mse,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
test_mae,0.00084
test_mse,0.0
train_mae,0.00077
train_mse,0.0
