# High Recall Retrieval

In [1]:
!nvidia-smi

Sun Oct 13 12:55:43 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:0E:00.0 Off |                    0 |
| N/A   53C    P0            388W /  400W |   69299MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00

## Setup environment

In [2]:
import sys
print("Python version:", sys.version)

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]


In [3]:
import datasets
print("datasets version:", datasets.__version__)


datasets version: 3.0.1


In [4]:
import numpy as np
import time
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.distributions import MultivariateNormal
import random
from tqdm import tqdm
import math
import logging

random.seed(42)

In [5]:
torch.cuda.is_available()

True

In [10]:
import os
import sys

project_dir = os.getcwd() if not os.getcwd().split("/")[-1] == 'notebooks' else '/'.join(os.getcwd().split("/")[0:-1])
src_dir = os.path.join(project_dir, 'src')

os.chdir(project_dir)
print(f"Current working directory set to: {os.getcwd()}")


if src_dir not in sys.path:
    sys.path.insert(0, src_dir)  # Add it to the front of PYTHONPATH
    print(f"PYTHONPATH updated with: {src_dir}")
else:
    print(f"PYTHONPATH already contains: {src_dir}")

Current working directory set to: /home/nlp/achimoa/projects/high-recall-retrieval
PYTHONPATH already contains: /home/nlp/achimoa/projects/high-recall-retrieval/src


In [12]:
ls

[0m[01;34mcheckpoints[0m/  [01;34mlogs[0m/  [01;34mnotebooks[0m/


In [11]:
from data import *

dataset = build_dataset(dataset_name='synthesized_query_document')
dataset

ModuleNotFoundError: No module named 'data'

## Utils

In [7]:
def create_logger(filename, id=None):
    # Create a logger object
    logger = logging.getLogger(id)
    logger.setLevel(logging.DEBUG)

    # Create a file handler for writing logs to a file
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    file_handler = logging.FileHandler(filename)
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(file_formatter)

    # Create a stream handler for writing logs to the console
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(console_formatter)

    # Add the handlers to the logger
    logger.handlers = []
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Define a custom exception handler
    def log_exception(exc_type, exc_value, exc_traceback):
        logger.error("Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback))

    # Set the custom exception handler as the global exception handler
    sys.excepthook = log_exception

    logger.info(f"Logging to file {filename}")

    return logger


In [8]:
class SamplerEncoder(nn.Module):
    def __init__(self, model_name, latent_dim, vocab_size, max_length, encode_hidden_state=False):
        super(SamplerEncoder, self).__init__()
        self.latent_dim = latent_dim
        self.encode_hidden_state = encode_hidden_state

        # Encoder - RoBERTa model
        self.encoder = AutoModel.from_pretrained(model_name)

        # Hidden dimensions of RoBERTa's output
        hidden_dim = self.encoder.config.hidden_size

        # Map encoder outputs to latent space
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder - can be a different architecture depending on your task
        self.decoder = nn.Linear(latent_dim, hidden_dim * max_length)
        self.reshape_layer = nn.Unflatten(1, (max_length, hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input_ids, attention_mask):
        # Encoding
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if self.encode_hidden_state:
            # last_hidden_state = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
            last_hidden_state = outputs.last_hidden_state
            pooled_outputs = last_hidden_state[:, 0]  # Use the [CLS] token representation
        else:
            pooled_outputs = outputs.pooler_output
        mu = self.fc_mu(pooled_outputs)
        log_var = self.fc_var(pooled_outputs)

        # Sampling
        z = self.reparameterize(mu, log_var + 1e-7)

        # Decoding (this part depends on your specific use case)
        h = self.decoder(z)
        h = self.reshape_layer(h)
        logits = self.output_layer(h)

        return logits, z, mu, log_var

In [9]:
class Encoder(nn.Module):
    def __init__(self, model_name, max_length):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.reduction_layer = nn.Linear(self.model.config.hidden_size, max_length)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        reduced = self.reduction_layer(last_hidden_state)
        return reduced

In [10]:
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, features_query, features_original_positive, features_sampled_positives, features_negative=None):
        # Normalize the features
        features_query = F.normalize(features_query, p=2, dim=1)
        features_original_positive = F.normalize(features_original_positive, p=2, dim=1)
        features_sampled_positives = F.normalize(features_sampled_positives, p=2, dim=1)

        # Concatenate original with sampled positive features
        features_positive = torch.cat([features_original_positive, features_sampled_positives], dim=0)

        # Concatenate the query feature such that it matches the positives dims
        features_anchor = torch.cat([features_query, features_query], dim=0)

        # Calculate dot product similarity
        similarity_matrix = torch.matmul(features_anchor.float(), features_positive.float().T) / self.temperature

        # Labels are the diagonal elements in the similarity matrix
        labels = torch.arange(similarity_matrix.size(0), dtype=torch.long, device=features_query.device)

        # Calculate cross-entropy loss
        loss = F.cross_entropy(similarity_matrix, labels)

        return loss

In [11]:
def tokenize(tokenizer, s, device, padding="max_length", truncation=True, return_tensors="pt"):
    inputs = tokenizer(s, padding=padding, truncation=truncation, return_tensors=return_tensors)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs


def encode(encoder, tokenizer, s, device, padding="max_length", truncation=True, return_tensors="pt", pooling=None, return_scalars_vector=False):
    # Tokenize and encode input sentences
    inputs = tokenize(
        tokenizer,
        s,
        device,
        padding=padding,
        truncation=truncation,
        return_tensors=return_tensors
    )
    outputs = encoder(**inputs)

    # Forward pass for model1
    outputs = encoder(**inputs)

    # Pooling of the last hidden states
    if pooling == 'max':
        # max pooling
        encoding = outputs.last_hidden_state.max(dim=1).values
    elif pooling == 'mean':
        # mean pooling
        encoding = outputs.last_hidden_state.mean(dim=1)
    else:
        # pooler output
        encoding = outputs.pooler_output

    if return_scalars_vector:
        encoding = [v.detach().cpu().numpy() for v in encoding]
        torch.cuda.empty_cache()

    return encoding


def sampler_encode(encoder, tokenizer, s, device, padding="max_length", truncation=True, return_tensors="pt", decoded=True):
    # Tokenize and encode input sentences
    inputs = tokenize(
        tokenizer,
        s,
        device,
        padding=padding,
        truncation=truncation,
        return_tensors=return_tensors
    )
    encoder.to(device)

    # Forward pass for model1
    reconstructed, z, mu, log_var = encoder(inputs['input_ids'], inputs['attention_mask'])

    # return the decoded or encoded vector
    features = reconstructed if decoded else z
    return features, mu, log_var

In [12]:
def compute_nll(mu, var, sentence):
    """
    Compute the negative log-likelihood (NLL) of the sentences under
    the multivariate normal distributions defined by mu and var.
    """
    dist = MultivariateNormal(mu, torch.diag_embed(var))
    nll = -dist.log_prob(sentence).mean()  # Compute the mean of negative log probabilities
    return nll


def reconstruction_loss_fn(s, logits, tokenizer, device):
    input_ids = tokenize(tokenizer=tokenizer, s=s, device=device)['input_ids']
    input_ids_flat = input_ids.view(-1)
    batch_size, sequence_length, vocab_size = logits.size()
    logits_flat = logits.view(batch_size * sequence_length, vocab_size)
    reconstruction_loss = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)(logits_flat, input_ids_flat)
    return reconstruction_loss


def kl_divergence_loss_fn(mu, log_var):
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return kl_loss


def vae_loss_fn(s, logits, tokenizer, device, mu, log_var, warm_up_epochs=5, epoch=0, alpha=1.0, beta=1.0):
    recon_loss = reconstruction_loss_fn(s, logits, tokenizer, device)
    kl_loss = kl_divergence_loss_fn(mu, log_var)
    warm_up_factor = min(1, epoch / warm_up_epochs) if warm_up_epochs > 0 else 1
    loss = alpha * recon_loss + beta * warm_up_factor * kl_loss
    return loss, dict(recon_loss=recon_loss, kl_loss=kl_loss)


def info_nce_loss_fn(query, positive_key, negative_keys=None, sampled_positive_key=None, temperature=0.1, reduction='mean'):
    query, positive_key, negative_keys, sampled_positive_key = normalize(query, positive_key, negative_keys, sampled_positive_key)

    if negative_keys is not None:
        pass

    if sampled_positive_key is not None:
      query = torch.cat((query, query), dim=0)
      pos_keys = torch.cat((positive_key, sampled_positive_key), dim=0)
      logits = query @ transpose(pos_keys)
      labels = torch.arange(len(query), device=query.device)
      p_loss = F.cross_entropy(logits / temperature, labels, reduction=reduction)

    else:
      # Cosine between all combinations
      logits = query @ transpose(positive_key)
      # Positive keys are the entries on the diagonal
      labels = torch.arange(len(query), device=query.device)
      p_loss = F.cross_entropy(logits / temperature, labels, reduction=reduction)

    return p_loss


def transpose(x):
    return x.transpose(-2, -1)


def normalize(*xs):
    return [None if x is None else F.normalize(x, dim=-1) for x in xs]

In [13]:
def get_dataloader(dataset, batch_size, shuffle=True):
    dataset2 = [{
        'sentence': item['sentence'],
        'good': random.choice(item['good']),
        'bad': random.choice(item['bad']),
    } for item in dataset]
    # dataset2 = [dataset2[i] for i in range(4)]
    dataloader = DataLoader(dataset2, batch_size=batch_size, shuffle=shuffle)
    return dataloader


def compute_loss(
        outputs_query,
        outputs_sentence,
        outputs_samples=None,
        mu=None,
        log_var=None,
        should_compute_info_nce_loss=True,
        should_compute_reconstruction_loss=True,
        should_compute_kl_divergence_loss=True,
    ):
    # InfoNCE loss
    # info_nce_fn = InfoNCELoss(temperature=0.5)
    # info_nce_loss = info_nce_fn(
    #     features_query=outputs_query,
    #     features_original_positive=outputs_sentence,
    #     features_sampled_positives=outputs_samples
    # )
    loss = None
    losses = {}

    # InfoNCE loss
    if should_compute_info_nce_loss:
        info_nce_loss = info_nce_loss_fn(query=outputs_query, positive_key=outputs_sentence, sampled_positive_key=outputs_samples) 
        info_nce_loss *= 1
        loss = (loss + info_nce_loss) if loss else info_nce_loss
        losses['info_nce_loss'] = info_nce_loss

    # NLL loss
    # nll_loss = F.cross_entropy(outputs_samples, outputs_sentence, reduction='sum')
    # nll_loss = compute_nll(mu, log_var, outputs_sentence)
    # nll_loss = compute_nll_loss(mu, log_var, outputs_query) ## ????

    # Reconstruction loss for VAE
    if should_compute_reconstruction_loss:
        should_compute_reconstruction_loss = should_compute_reconstruction_loss and outputs_samples is not None
        recon_loss = reconstruction_loss_fn(outputs_samples, outputs_sentence)
        recon_loss *= 0.01
        loss = (loss + recon_loss) if loss else recon_loss
        losses['recon_loss'] = recon_loss

    # KL divergence loss for VAE
    if should_compute_kl_divergence_loss:
        should_compute_kl_divergence_loss = should_compute_kl_divergence_loss and mu is not None and log_var is not None
        kl_divergence_loss = kl_divergence_loss_fn(mu, log_var)
        kl_divergence_loss *= 0.01
        loss = (loss + kl_divergence_loss) if loss else kl_divergence_loss
        losses['kl_loss'] = kl_divergence_loss

    result = (loss, losses)
    return result

In [14]:
def evaluate(data, sentence_encoder, query_encoder, tokenizer, device, batch_size=8):
    all_queries = flatten_list([item['good'] for item in data])

    all_queries_vectors = []
    for i in tqdm(range(0, len(all_queries), batch_size), desc="Encoding all queries"):
        batch = all_queries[i:i+batch_size]
        queries_batch = encode(encoder=query_encoder, tokenizer=tokenizer, s=batch, device=device, return_scalars_vector=True)
        all_queries_vectors.extend(queries_batch)
    queries = [dict(query=query, vector=vector) for query, vector in zip(all_queries, all_queries_vectors)]

    queries_true = []
    queries_pred = []
    for i in tqdm(range(0, len(data), batch_size), desc="Predict matching queries"):
        batch = data[i:i+batch_size]
        sentences_batch = [item for item in batch['sentence']]
        sentences_vector = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=sentences_batch, device=device, return_scalars_vector=True)
        sentences = [dict(vector=sentence_vector, sentence=sentence, good=good) for sentence_vector, sentence, good in zip(sentences_vector, sentences_batch, batch['good'])]
        queries_true_batch = [item for item in batch['good']]
        queries_pred_batch = predict_queries(sentences, queries)

        queries_true.extend(queries_true_batch)
        queries_pred.extend(queries_pred_batch)


    precision, recall, f1 = compute_evaluation_metrics(queries_true, queries_pred)

    print(f"Precision: {precision} | Recall: {recall} | F1: {f1}")
    return precision, recall, f1


def compute_evaluation_metrics(queries_true, queries_pred):
    tp, fp, fn = 0, 0, 0
    for query_true, query_pred in zip(queries_true, queries_pred):
        true_positives = list(set(query_true).intersection(set(query_pred)))
        false_positives = set(query_pred) - set(query_true)
        false_negative = set(query_true) - set(query_pred)

        tp += len(true_positives)
        fp += len(false_positives)
        fn += len(false_negative)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = (2 * (precision * recall) / (precision + recall)) if (precision + recall) > 0 else 0

    return precision, recall, f1_score


def flatten_list(list_of_lists):
    flattened_list = [item for sublist in list_of_lists for item in sublist]
    return flattened_list


def predict_queries(sentences, queries):
    similarity_matrix = np.zeros((len(sentences), len(queries)))

    # Compute the similarity matrix
    for i, sentence in enumerate(sentences):
        for j, query in enumerate(queries):
            sentence_vec = sentence['vector']
            query_vec = query['vector']
            similarity_matrix[i, j] = compute_similarity(sentence_vec, query_vec)

    # Find the top k indices for each sentence based on the number of 'good' queries
    queries_pred_batch = []
    for i, item in enumerate(sentences):
        k = len(item['good'])
        top_k_indices = np.argsort(similarity_matrix[i])[-k:][::-1]
        queries_pred_batch.append([queries[j]['query'] for j in top_k_indices])

    return queries_pred_batch


def compute_similarity(vec1, vec2):
    similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    return similarity


In [15]:
def train_vae(
        train_dataloader,
        val_dataloader,
        tokenizer,
        sampler_encoder,
        num_epochs = 10,
        lr = 0.25*1e-4,
        weight_decay = 1e-6,
        device = None,
        print_interval = 1000,
        should_grad_clipping = True,
        checkpoints_base_path = "checkpoints",
):
    log = create_logger(id='train_vae', filename=f'logs/train_vae.log')
    log.info(f"Training VAE for {num_epochs} epochs")
    log.info(f"Device: {device}")

    total_losses = {}
    total_loss = 0
    examples_processed = 0
    best_loss = math.inf
    timestamp = time.strftime('%Y%m%d-%H%M%S')

    optimizer_sampled = AdamW(sampler_encoder.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):  # num_epochs should be defined by you
        epoch_start_time = time.time()
        total_batch_time = 0

        sampler_encoder.train()
        sampler_encoder.to(device)

        for batch_idx, batch in enumerate(train_dataloader):
            batch_start_time = time.time()
            sentences_batch = batch['sentence']
            positives_batch = batch['good']

            logits, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=positives_batch, device=device) if sampler_encoder else (None, None, None)
            # loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)
            # loss, losses = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var, should_compute_reconstruction_loss=should_compute_reconstruction_loss, should_compute_kl_divergence_loss=should_compute_kl_divergence_loss)
            loss, losses = vae_loss_fn(s=positives_batch, logits=logits, tokenizer=tokenizer, device=device, mu=mu, log_var=log_var, epoch=epoch, warm_up_epochs=0)
            # Update running loss and example count
            total_loss += loss.item()
            for key in losses.keys():
                loss_value = losses[key].item() if losses[key] else 0
                total_losses[key] = (total_losses[key] + loss_value) if (key in total_losses) else loss_value
            examples_processed += len(batch)
            
            optimizer_sampled.zero_grad()
            loss.backward()

            if should_grad_clipping:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(sampler_encoder.parameters(), max_norm=1.0)

            # Backpropagation for models
            optimizer_sampled.step()

            batch_end_time = time.time()
            batch_duration = batch_end_time - batch_start_time
            total_batch_time += batch_duration

            if batch_idx % print_interval == 0 and batch_idx > 0:
                average_loss = total_loss / examples_processed
                log.info(f"Epoch: {epoch}, Train Batch: {batch_idx}, Average Loss: {average_loss:.4f}, {', '.join([f'Average Loss {key}: {(value / examples_processed):.4f}' for key, value in total_losses.items()])}, Avg Iteration Time = {total_batch_time / print_interval:.4f} seconds")

                # Reset counters
                total_loss = 0
                total_losses = {}
                examples_processed = 0
                total_batch_time = 0

        # validation
        sampler_encoder.eval()
        with torch.no_grad():  # Disable gradient computation
            total_losses = {}
            total_loss = 0

            for batch in val_dataloader:
                batch_start_time = time.time()
                sentences_batch = batch['sentence']
                positives_batch = batch['good']

                logits, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=positives_batch, device=device) if sampler_encoder else (None, None, None)

                # loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)
                # loss, lossses = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var, should_compute_reconstruction_loss=should_compute_reconstruction_loss, should_compute_kl_divergence_loss=should_compute_kl_divergence_loss)
                loss, losses = vae_loss_fn(s=positives_batch, logits=logits, tokenizer=tokenizer, device=device, mu=mu, log_var=log_var, epoch=epoch, warm_up_epochs=0)

                # Update running loss and example count
                total_loss += loss.item()
                for key in losses.keys():
                    loss_value = losses[key].item() if losses[key] else 0
                    total_losses[key] = (total_losses[key] + loss_value) if key in total_losses else loss_value

            average_loss = total_loss / len(val_dataloader)
            log.info(f"Epoch: {epoch}, Validation, Average Loss: {average_loss:.4f}, {', '.join([f'Average Loss {key}: {(value / len(val_dataloader)):.4f}' for key, value in total_losses.items()])}, Avg Iteration Time = {total_batch_time / print_interval:.4f} seconds")

            if average_loss < best_loss:
                best_loss = average_loss
                setup_config = "vae" if sampler_encoder else "classic"
                checkpoints_path = os.path.join(checkpoints_base_path, f'loss{average_loss:.4f}_{setup_config}_epoch{epoch+1}_{timestamp}')
                os.makedirs(checkpoints_path, exist_ok=True)
                torch.save(
                    {
                        "model_name": sampler_encoder.encoder.config.name_or_path,
                        "latent_dim": sampler_encoder.latent_dim,
                        "state_dict": sampler_encoder.state_dict()
                    },
                    os.path.join(checkpoints_path, "sampler_encoder.ckpt")
                )
                log.info(f'Saved VAE model with loss {best_loss:.4f} at epoch {epoch} in folder {checkpoints_path}')

        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        log.info(f"Epoch {epoch} completed in {epoch_duration:.4f} seconds")


def load_vae_model(sampler_encoder, checkpoints_path, device):
    checkpoint_path = os.path.join(checkpoints_path, "sampler_encoder.ckpt")
    checkpoint = torch.load(checkpoint_path)

    # Load the model configuration and state_dict
    model_name = checkpoint["model_name"]
    latent_dim = checkpoint["latent_dim"]
    state_dict = checkpoint["state_dict"]

    # Load the state_dict into the model
    sampler_encoder.load_state_dict(state_dict)
    sampler_encoder.to(device)
    return sampler_encoder


#### Train models util

In [16]:
def train(
        validation_dataset,
        train_dataloader,
        val_dataloader,
        tokenizer,
        query_encoder,
        sentence_encoder,
        sampler_encoder = None,
        num_epochs = 10,
        lr = 0.25*1e-4,
        weight_decay = 1e-6,
        device = None,
        print_interval = 1000,
        should_grad_clipping = True,
        checkpoints_base_path = "checkpoints",
        should_compute_info_nce_loss=True,
        should_compute_reconstruction_loss=True,
        should_compute_kl_divergence_loss=True,
):
    log = create_logger(id='train_vae', filename=f'logs/train.log')
    log.info(f"Training VAE for {num_epochs} epochs")
    log.info(f"Device: {device}")

    total_losses = {}
    total_loss = 0
    examples_processed = 0
    best_recall = 0.0
    timestamp = time.strftime('%Y%m%d-%H%M%S')

    optimizer_query = AdamW(query_encoder.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer_sentence = AdamW(sentence_encoder.parameters(), lr=lr, weight_decay=weight_decay)
    if sampler_encoder:
        optimizer_sampled = AdamW(sampler_encoder.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):  # num_epochs should be defined by you
        epoch_start_time = time.time()
        total_batch_time = 0

        query_encoder.train()
        query_encoder.to(device)
        sentence_encoder.train()
        sentence_encoder.to(device)
        if sampler_encoder:
            sampler_encoder.train()
            sampler_encoder.to(device)

        for batch_idx, batch in enumerate(train_dataloader):
            batch_start_time = time.time()

            sentences_batch = batch['sentence']
            positives_batch = batch['good']

            outputs_query = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=positives_batch, device=device)
            outputs_sentence = encode(encoder=query_encoder, tokenizer=tokenizer, s=sentences_batch, device=device)
            outputs_samples, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=positives_batch, device=device) if sampler_encoder else (None, None, None)
            print(f"outputs_samples: {outputs_samples.shape}, outputs_query: {outputs_query.shape}, outputs_sentence: {outputs_sentence.shape}")

            # loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)
            loss, losses = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var, should_compute_reconstruction_loss=should_compute_reconstruction_loss, should_compute_kl_divergence_loss=should_compute_kl_divergence_loss)

            optimizer_query.zero_grad()
            optimizer_sentence.zero_grad()
            if sampler_encoder:
                optimizer_sampled.zero_grad()

            loss.backward()

            if should_grad_clipping:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(query_encoder.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(sentence_encoder.parameters(), max_norm=1.0)
                if sampler_encoder:
                    torch.nn.utils.clip_grad_norm_(sampler_encoder.parameters(), max_norm=1.0)

            # Backpropagation for models
            optimizer_query.step()
            optimizer_sentence.step()
            if sampler_encoder:
                optimizer_sampled.step()

            # Update running loss and example count
            total_loss += loss.item()
            for key in losses.keys():
                loss_value = losses[key].item() if losses[key] else 0
                total_losses[key] = (total_losses[key] + loss_value) if key in total_losses else loss_value
            examples_processed += len(batch)

            batch_end_time = time.time()
            batch_duration = batch_end_time - batch_start_time
            total_batch_time += batch_duration

            if batch_idx % print_interval == 0:
                average_loss = total_loss / examples_processed
                log.info(f"Epoch: {epoch}, Train Batch: {batch_idx}, Average Loss: {average_loss:.4f}, {', '.join([f'Average Loss {key}: {(value / examples_processed):.4f}' for key, value in total_losses.items()])}, Avg Iteration Time = {total_batch_time / print_interval:.4f} seconds")

                # Reset counters
                total_losses = {}
                total_loss = 0
                examples_processed = 0
                total_batch_time = 0

        # validation
        query_encoder.eval()
        sentence_encoder.eval()
        if sampler_encoder:
            sampler_encoder.eval()
        with torch.no_grad():  # Disable gradient computation
            total_losses = {}
            total_loss = 0

            for batch in val_dataloader:
                batch_start_time = time.time()

                outputs_query = encode(encoder=query_encoder, tokenizer=tokenizer, s=batch['good'], device=device)
                outputs_sentence = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=batch['sentence'], device=device)
                outputs_samples, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=batch['good'], device=device) if sampler_encoder else (None, None, None)

                # loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)
                loss, losses = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var, should_compute_reconstruction_loss=should_compute_reconstruction_loss, should_compute_kl_divergence_loss=should_compute_kl_divergence_loss)

                # Update running loss and example count
                total_loss += loss.item()
                for key in losses.keys():
                    loss_value = losses[key].item() if losses[key] else 0
                    total_losses[key] = (total_losses[key] + loss_value) if key in total_losses else loss_value

                batch_end_time = time.time()
                batch_duration = batch_end_time - batch_start_time
                total_batch_time += batch_duration

            average_loss = total_loss / len(val_dataloader)
            log.info(f"Epoch: {epoch}, Validation, Average Loss: {average_loss:.4f}, {', '.join([f'Average Loss {key}: {(value / len(val_dataloader)):.4f}' for key, value in total_losses.items()])}, Avg Iteration Time = {total_batch_time / print_interval:.4f} seconds")

            _, recall, _ = evaluate(data=validation_dataset, sentence_encoder=sentence_encoder, query_encoder=query_encoder, tokenizer=tokenizer, device=device)
            if recall > best_recall:
                best_recall = recall
                setup_config = "vae" if sampler_encoder else "classic"
                checkpoints_path = os.path.join(checkpoints_base_path, f'recall{recall:.4f}_{setup_config}_ epoch{epoch+1}_{timestamp}')
                os.makedirs(checkpoints_path, exist_ok=True)
                torch.save({"model_name": query_encoder.config.name_or_path, "state_dict": query_encoder.state_dict()}, os.path.join(checkpoints_path, "query_enoder.ckpt"))
                torch.save({"model_name": sentence_encoder.config.name_or_path, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sentence_encoder.ckpt"))
                torch.save({"model_name": sampler_encoder.encoder.config.name_or_path, "latent_dim": sampler_encoder.latent_dim, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sampler_encoder.ckpt"))
                print(f'Saved model with recall {recall:.4f} at epoch {epoch+1} in folder {checkpoints_path}')

        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        log.info(f"Epoch {epoch + 1} completed in {epoch_duration:.4f} seconds")


## Dataloader

In [17]:
dataset = load_dataset("biu-nlp/abstract-sim")
dataset

Repo card metadata block was not found. Setting CardData to empty.


DatasetDict({
    train: Dataset({
        features: ['sentence', 'bad', 'good'],
        num_rows: 157649
    })
    validation: Dataset({
        features: ['sentence', 'bad', 'good'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['sentence', 'bad', 'good'],
        num_rows: 2955
    })
})

## Train

#### Train with Classic Setup

In [17]:
model_name = "roberta-base"
latent_dim = 128

query_encoder = AutoModel.from_pretrained(model_name)
sentence_encoder = AutoModel.from_pretrained(model_name)
sampler_encoder = None

tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
train(
    tokenizer=tokenizer,
    query_encoder=query_encoder,
    sentence_encoder=sentence_encoder,
    sampler_encoder=sampler_encoder,
)

Epoch: 0, Batch: 0, Average Loss: 0.6932, Average InfoNCE Loss: 0.6932, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 1000, Average Loss: 0.2475, Average InfoNCE Loss: 0.2475, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 2000, Average Loss: 0.1271, Average InfoNCE Loss: 0.1271, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 3000, Average Loss: 0.1030, Average InfoNCE Loss: 0.1030, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 4000, Average Loss: 0.0923, Average InfoNCE Loss: 0.0923, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 5000, Average Loss: 0.0849, Average InfoNCE Loss: 0.0849, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 6000, Average Loss: 0.0825, Average InfoNCE Loss: 0.0825, Average NLL Loss: 0.0000, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 7000, Average Loss: 0.0715, Averag

RuntimeError: Parent directory checkpoints/recall0.2820_epoch1_20240225-204442 does not exist.

#### Train with VAE

In [18]:
model_name = "roberta-base"
latent_dim = 128
batch_size = 8
validation_batch_size = 1

train_dataloader = get_dataloader(dataset['train'], batch_size=batch_size, shuffle=True)
val_dataloader = get_dataloader(dataset['validation'], batch_size=validation_batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
vocab_size = tokenizer.vocab_size
max_length = tokenizer.model_max_length

query_encoder = Encoder(model_name, max_length=max_length)
sentence_encoder = Encoder(model_name, max_length=max_length)
sampler_encoder = SamplerEncoder(model_name=model_name, latent_dim=latent_dim, vocab_size=vocab_size, max_length=max_length)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [67]:
torch.cuda.empty_cache()

In [35]:
train_vae(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    sampler_encoder=sampler_encoder,
    num_epochs = 10,
    lr = 0.25*1e-4,
    device=device,
    should_grad_clipping=False
)

2024-03-17 00:14:01,474 - INFO - Logging to file logs/train_vae.log
2024-03-17 00:14:01,475 - INFO - Training VAE for 10 epochs
2024-03-17 00:14:01,476 - INFO - Device: cuda


logits_flat: torch.Size([6144, 50265])
input_ids_flat: torch.Size([4096])


ValueError: Expected input batch_size (6144) to match target batch_size (4096).

In [19]:
checkpoints_path = "checkpoints/loss5.4952_vae_epoch10_20240315-232205"
sampler_encoder = load_vae_model(sampler_encoder, checkpoints_path, device=device)

In [20]:
train(
    validation_dataset=dataset['validation'],
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    query_encoder=query_encoder,
    sentence_encoder=sentence_encoder,
    sampler_encoder=sampler_encoder,
    num_epochs = 10,
    lr = 0.25*1e-4,
    device=device,
    should_grad_clipping=False,
    should_compute_reconstruction_loss=False,
    should_compute_kl_divergence_loss=True
)

2024-03-17 00:54:49,357 - INFO - Logging to file logs/train.log
2024-03-17 00:54:49,357 - INFO - Training VAE for 10 epochs
2024-03-17 00:54:49,358 - INFO - Device: cuda


AttributeError: 'Tensor' object has no attribute 'pooler_output'

In [21]:
import socket

try:
    hostname = socket.gethostname()
    ipv4_address = socket.gethostbyname(hostname)
    print(f"Internal IPv4 Address for {hostname}: {ipv4_address}")
except socket.gaierror:
    print("There was an error resolving the hostname.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Internal IPv4 Address for dgx01: 127.0.1.1


In [52]:
for i, x in enumerate(train_dataloader):
    print(i, x)

0 <torch.utils.data.dataloader.DataLoader object at 0x7fdad20d0280>


In [None]:
evaluate(data=dataset['validation'], sentence_encoder=sentence_encoder, query_encoder=query_encoder, tokenizer=tokenizer, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Encoding all queries: 100%|██████████| 3377/3377 [01:30<00:00, 37.49it/s]
Predict matching queries: 100%|██████████| 625/625 [21:29<00:00,  2.06s/it]

Precision: 0.1934383688600556 | Recall: 0.19316625328545514 | F1: 0.19330221530710529





(0.1934383688600556, 0.19316625328545514, 0.19330221530710529)

## Sandbox

In [None]:
num_epochs = 3
lr = 0.25*1e-4
weight_decay = 1e-6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_interval = 1000
should_grad_clipping = True
timestamp = time.strftime('%Y%m%d-%H%M%S')
checkpoints_base_path = "checkpoints"

total_loss = 0
total_info_nce_loss = 0
total_nll_loss = 0
total_kl_loss = 0
examples_processed = 0
best_recall = 0.0

optimizer_query = AdamW(query_encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_sentence = AdamW(sentence_encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_sampled = AdamW(sampler_encoder.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(num_epochs):  # num_epochs should be defined by you
    query_encoder.train()
    sentence_encoder.train()
    sampler_encoder.train()

    query_encoder.to(device)
    sentence_encoder.to(device)
    sampler_encoder.to(device)

    for batch_idx, batch in enumerate(train_dataloader):
        sentences_batch = batch['sentence']
        positives_batch = batch['good']

        outputs_query = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=sentences_batch, device=device)
        outputs_sentence = encode(encoder=query_encoder, tokenizer=tokenizer, s=positives_batch, device=device)
        outputs_samples, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=positives_batch, device=device)

        loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)

        optimizer_query.zero_grad()
        optimizer_sentence.zero_grad()
        optimizer_sampled.zero_grad()

        loss.backward()

        if should_grad_clipping:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(query_encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(sentence_encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(sampler_encoder.parameters(), max_norm=1.0)

        # Backpropagation for models
        optimizer_query.step()
        optimizer_sentence.step()
        optimizer_sampled.step()

        # Update running loss and example count
        total_loss += loss.item()
        total_info_nce_loss += info_nce_loss.item()
        total_nll_loss += nll_loss.item()
        total_kl_loss += kl_divergence_loss.item()
        examples_processed += len(batch)

        if batch_idx % print_interval == 0:
            average_loss = total_loss / examples_processed
            average_info_nce_loss = total_info_nce_loss / examples_processed
            average_nll_loss = total_nll_loss / examples_processed
            average_kl_loss = total_kl_loss / len(val_dataloader)
            print(f"Epoch: {epoch}, Batch: {batch_idx}, Average Loss: {average_loss:.4f}, Average InfoNCE Loss: {average_info_nce_loss:.4f}, Average NLL Loss: {average_nll_loss:.4f}, Average KL Divergence Loss: {average_kl_loss:.4f}")

            # Reset counters
            total_loss = 0
            total_info_nce_loss = 0
            total_nll_loss = 0
            examples_processed = 0

    # validation
    query_encoder.eval()
    sentence_encoder.eval()
    sampler_encoder.eval()
    with torch.no_grad():  # Disable gradient computation
        total_loss = 0
        total_info_nce_loss = 0
        total_nll_loss = 0
        total_kl_loss = 0

        for batch in val_dataloader:
            # Validation step
            outputs_query = encode(encoder=query_encoder, tokenizer=tokenizer, s=batch['sentence'], device=device)
            outputs_sentence = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=batch['good'], device=device)
            outputs_samples, mu, log_var = sampler_encode(encoder=sampler_encoder, tokenizer=tokenizer, s=batch['good'], device=device)

            loss, info_nce_loss, nll_loss, kl_divergence_loss = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=mu, log_var=log_var)

            # Update running loss and example count
            total_loss += loss.item()
            total_info_nce_loss += info_nce_loss.item()
            total_nll_loss += nll_loss.item()
            total_kl_loss += kl_divergence_loss.item()

        average_loss = total_loss / len(val_dataloader)
        average_info_nce_loss = total_info_nce_loss / len(val_dataloader)
        average_nll_loss = total_nll_loss / len(val_dataloader)
        average_kl_loss = total_kl_loss / len(val_dataloader)

        print(f"Epoch: {epoch}, Average Validation Loss: {average_loss:.4f}, Average Validation InfoNCE Loss: {average_info_nce_loss:.4f}, Average Validation NLL Loss: {average_nll_loss:.4f}, Average Validation KL Divergence Loss: {average_kl_loss:.4f}")

        _, recall, _ = evaluate(data=dataset['validation'], sentence_encoder=sentence_encoder, query_encoder=query_encoder, tokenizer=tokenizer, device=device)
        if recall > best_recall:
            best_recall = recall
            os.makedirs(checkpoints_base_path, exist_ok=True)
            checkpoints_path = os.path.join(checkpoints_base_path, f'recall{recall:.4f}_epoch{epoch+1}_{timestamp}')
            torch.save({"model_name": query_encoder.config.name_or_path, "state_dict": query_encoder.state_dict()}, os.path.join(checkpoints_path, "query_enoder.ckpt"))
            torch.save({"model_name": sentence_encoder.config.name_or_path, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sentence_encoder.ckpt"))
            torch.save({"model_name": sampler_encoder.encoder.config.name_or_path, "latent_dim": sampler_encoder.latent_dim, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sampler_encoder.ckpt"))
            print(f'Saved model with recall {recall:.4f} at epoch {epoch+1} in folder {checkpoints_path}')


Epoch: 0, Batch: 0, Average Loss: 0.9791, Average InfoNCE Loss: 0.9256, Average NLL Loss: 0.0526, Average KL Divergence Loss: 0.0000
Epoch: 0, Batch: 1000, Average Loss: 0.9407, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0134, Average KL Divergence Loss: 0.0149
Epoch: 0, Batch: 2000, Average Loss: 0.9368, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0102, Average KL Divergence Loss: 0.0262
Epoch: 0, Batch: 3000, Average Loss: 0.9388, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0117, Average KL Divergence Loss: 0.0400
Epoch: 0, Batch: 4000, Average Loss: 0.9371, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0111, Average KL Divergence Loss: 0.0487
Epoch: 0, Batch: 5000, Average Loss: 0.9378, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0112, Average KL Divergence Loss: 0.0600
Epoch: 0, Batch: 6000, Average Loss: 0.9373, Average InfoNCE Loss: 0.9242, Average NLL Loss: 0.0116, Average KL Divergence Loss: 0.0673
Epoch: 0, Batch: 7000, Average Loss: 0.9359, Averag

RuntimeError: Parent directory checkpoints/recall0.0002_epoch1_20240225-132429 does not exist.

In [None]:
recall > best_recall

False

In [None]:
if recall > best_recall or True:
        best_recall = recall
        checkpoints_path = os.path.join(checkpoints_base_path, f'recall{recall:.4f}_epoch{epoch+1}_{timestamp}')
        os.makedirs(checkpoints_path, exist_ok=True)
        torch.save({"model_name": query_encoder.config.name_or_path, "state_dict": query_encoder.state_dict()}, os.path.join(checkpoints_path, "query_enoder.ckpt"))
        torch.save({"model_name": sentence_encoder.config.name_or_path, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sentence_encoder.ckpt"))
        torch.save({"model_name": sampler_encoder.encoder.config.name_or_path, "latent_dim": sampler_encoder.latent_dim, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sampler_encoder.ckpt"))
        print(f'Saved model with recall {recall:.4f} at epoch {epoch+1} in folder {checkpoints_path}')

Saved model with recall 0.0002 at epoch 1 in folder checkpoints/recall0.0002_epoch1_20240225-132429


In [None]:
evaluate(data=dataset['validation'], sentence_encoder=sentence_encoder, query_encoder=query_encoder, tokenizer=tokenizer, device=device)

Precision: 0.00018508236165093466 | Recall: 0.00018509606485766112 | F1: 0.00018508921300066633


(0.00018508236165093466, 0.00018509606485766112, 0.00018508921300066633)

#### Train classic setup

In [None]:
num_epochs = 4
lr = 0.25*1e-4
weight_decay = 1e-6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print_interval = 1000
should_grad_clipping = True
timestamp = time.strftime('%Y%m%d-%H%M%S')
checkpoints_base_path = "checkpoints"

total_loss = 0
total_info_nce_loss = 0
total_nll_loss = 0
total_kl_loss = 0
examples_processed = 0
best_recall = 0.0

optimizer_query = AdamW(query_encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_sentence = AdamW(sentence_encoder.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(num_epochs):  # num_epochs should be defined by you
    query_encoder.train()
    sentence_encoder.train()

    query_encoder.to(device)
    sentence_encoder.to(device)

    for batch_idx, batch in enumerate(train_dataloader):
        sentences_batch = batch['sentence']
        positives_batch = batch['good']

        outputs_query = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=sentences_batch, device=device)
        outputs_sentence = encode(encoder=query_encoder, tokenizer=tokenizer, s=positives_batch, device=device)

        loss, info_nce_loss, nll_loss, _ = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=None, log_var=None)

        optimizer_query.zero_grad()
        optimizer_sentence.zero_grad()

        loss.backward()

        if should_grad_clipping:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(query_encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(sentence_encoder.parameters(), max_norm=1.0)

        # Backpropagation for models
        optimizer_query.step()
        optimizer_sentence.step()

        # Update running loss and example count
        total_loss += loss.item()
        total_info_nce_loss += info_nce_loss.item()
        total_nll_loss += nll_loss.item()
        total_kl_loss += kl_divergence_loss.item() if (mu and log_var) else 0
        examples_processed += len(batch)

        if batch_idx % print_interval == 0:
            average_loss = total_loss / examples_processed
            average_info_nce_loss = total_info_nce_loss / examples_processed
            average_nll_loss = total_nll_loss / examples_processed
            average_kl_loss = total_kl_loss / len(val_dataloader)
            print(f"Epoch: {epoch}, Batch: {batch_idx}, Average Loss: {average_loss:.4f}, Average InfoNCE Loss: {average_info_nce_loss:.4f}, Average NLL Loss: {average_nll_loss:.4f}, Average KL Divergence Loss: {average_kl_loss:.4f}")

            # Reset counters
            total_loss = 0
            total_info_nce_loss = 0
            total_nll_loss = 0
            examples_processed = 0

    # validation
    query_encoder.eval()
    sentence_encoder.eval()
    with torch.no_grad():  # Disable gradient computation
        total_loss = 0
        total_info_nce_loss = 0
        total_nll_loss = 0
        total_kl_loss = 0

        for batch in val_dataloader:
            # Validation step
            outputs_query = encode(encoder=query_encoder, tokenizer=tokenizer, s=batch['sentence'], device=device)
            outputs_sentence = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=batch['good'], device=device)

            loss, info_nce_loss, nll_loss, _ = compute_loss(outputs_query=outputs_query, outputs_sentence=outputs_sentence, outputs_samples=outputs_samples, mu=None, log_var=None)

            # Update running loss and example count
            total_loss += loss.item()
            total_info_nce_loss += info_nce_loss.item()
            total_nll_loss += nll_loss.item()
            total_kl_loss += 0

        average_loss = total_loss / len(val_dataloader)
        average_info_nce_loss = total_info_nce_loss / len(val_dataloader)
        average_nll_loss = total_nll_loss / len(val_dataloader)
        average_kl_loss = total_kl_loss / len(val_dataloader)

        print(f"Epoch: {epoch}, Average Validation Loss: {average_loss:.4f}, Average Validation InfoNCE Loss: {average_info_nce_loss:.4f}, Average Validation NLL Loss: {average_nll_loss:.4f}, Average Validation KL Divergence Loss: {average_kl_loss:.4f}")

        _, recall, _ = evaluate(data=dataset['validation'], sentence_encoder=sentence_encoder, query_encoder=query_encoder, tokenizer=tokenizer, device=device)
        if recall > best_recall:
            best_recall = recall
            checkpoints_path = os.path.join(checkpoints_base_path, f'recall{recall:.4f}_epoch{epoch+1}_{timestamp}')
            torch.save({"model_name": query_encoder.config.name_or_path, "state_dict": query_encoder.state_dict()}, os.path.join(checkpoints_path, "query_enoder.ckpt"))
            torch.save({"model_name": sentence_encoder.config.name_or_path, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sentence_encoder.ckpt"))
            torch.save({"model_name": sampler_encoder.encoder.config.name_or_path, "latent_dim": sampler_encoder.latent_dim, "state_dict": sentence_encoder.state_dict()}, os.path.join(checkpoints_path, "sampler_encoder.ckpt"))
            print(f'Saved model with recall {recall:.4f} at epoch {epoch+1} in folder {checkpoints_path}')


In [None]:
lr = 0.25*1e-4
weight_decay = 1e-6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

optimizer_query = AdamW(query_encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_sentence = AdamW(sentence_encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_sampled = AdamW(sampler_encoder.parameters(), lr=lr, weight_decay=weight_decay)

for batch in train_dataloader:
        # Move batch to GPU
        # batch = {k: v.to(device) for k, v in batch.items() if k in ['sentence', 'good']}
        outputs_query = encode(encoder=query_encoder, tokenizer=tokenizer, s=batch['sentence'], device=device)
        outputs_sentence = encode(encoder=sentence_encoder, tokenizer=tokenizer, s=batch['good'], device=device)