# <center> Train a Reduction Head on a SentenceTransformer Model </center>

In [1]:
# Run configuration: every tunable value used by the cells below lives here.

# Model identifier passed to SentenceTransformer when loading the frozen backbone.
BASE_MODEL = "Linq-AI-Research/Linq-Embed-Mistral"
# Assigned to `model.max_seq_length` right after the backbone is loaded.
MAX_SEQ_LENGTH = 4096
# NOTE(review): absolute, user-specific path; the notebook will not run on another
# machine without editing it.
MEDI_DATA_PATH = "/home/cayjobla/ReducedEncoders/medi-data/medi-data.json"
# Number of rows held out for evaluation (used as `test_size` in the split).
EVAL_SIZE = 8000
# Initial learning rate handed to AdamW.
LR = 1e-4
# Rows per DataLoader batch.
BATCH_SIZE = 4
# Number of passes over the training loader.
EPOCHS = 1
# Directory the state dicts of both heads and the loss module are written to.
# NOTE(review): absolute, user-specific path, same caveat as MEDI_DATA_PATH.
OUTPUT_DIR = "/home/cayjobla/ReducedEncoders/Linq-Embed-Mistral-reduced"
# Training metrics are printed/logged every LOGGING_STEPS optimizer steps.
LOGGING_STEPS = 1
# Evaluation over the eval loader runs every EVAL_STEPS optimizer steps.
EVAL_STEPS = 1000
# State dicts are saved every SAVE_STEPS optimizer steps.
SAVE_STEPS = 1000
# Device string passed to SentenceTransformer for the backbone.
DEVICE = "cuda:0"

## Define the Reduction Head

In [2]:
from torch import Tensor, nn

class Resize(nn.Sequential):
    """One resizing stage: ``Linear -> activation -> Dropout``.

    The constructor arguments are also kept as plain attributes so the layer
    can describe itself through :meth:`get_config_dict` and ``repr``.

    Args:
        in_features: Width of the incoming vectors.
        out_features: Width of the produced vectors.
        bias: Whether the linear projection has a bias term.
        activation_function: Module applied after the projection. Only its
            class name is recorded in the configuration.
        dropout: Probability given to the trailing ``nn.Dropout``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation_function = nn.SiLU(),
        dropout: float = 0.1,
    ):
        projection = nn.Linear(in_features, out_features, bias=bias)
        super().__init__(projection, activation_function, nn.Dropout(dropout))

        # Plain (non-module) attributes, kept only to report the configuration.
        self.in_features, self.out_features = in_features, out_features
        self.bias = bias
        self.activation_function = type(activation_function).__name__
        self.dropout = dropout

        # Re-initialise every linear layer contained in this stage.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module, initializer_range=0.02):
        """Draw linear weights from N(0, initializer_range) and zero the bias."""
        if not isinstance(module, nn.Linear):
            return
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is None:
            return
        module.bias.data.zero_()

    def get_config_dict(self):
        """Return the constructor settings as a dict (activation as a class name)."""
        keys = ("in_features", "out_features", "bias", "activation_function", "dropout")
        return {key: getattr(self, key) for key in keys}

    def __repr__(self):
        return "Resize(" + str(self.get_config_dict()) + ")"

class MultiResize(nn.Sequential):
    """A chain of ``Resize`` stages, one per consecutive pair of widths.

    Args:
        sizes: Sequence of layer widths; ``sizes[i] -> sizes[i + 1]`` becomes
            stage ``i``. Fewer than two entries yields an empty sequence.
        bias: Passed to every stage.
        activation_function: Passed to every stage. The same module instance
            is handed to all of them.
        dropout: Passed to every stage.
    """

    def __init__(self,
        sizes,
        bias: bool = True,
        activation_function = nn.SiLU(),
        dropout: float = 0.1
    ):
        super().__init__()
        stage_kwargs = dict(
            bias=bias,
            activation_function=activation_function,
            dropout=dropout,
        )
        # Pair every width with its successor to get the (in, out) of each stage.
        for in_features, out_features in zip(sizes, sizes[1:]):
            self.append(Resize(in_features, out_features, **stage_kwargs))
        self.sizes = sizes

    def get_config_dict(self):
        """Return the configuration, which only records the width schedule."""
        return dict(sizes=self.sizes)

    def __repr__(self):
        return "MultiResize(" + str(self.get_config_dict()) + ")"

In [3]:
from sentence_transformers import SentenceTransformer

# Drop a model left over from an earlier run of this cell before loading a new
# one, to avoid a CUDA out-of-memory error.
if "model" in locals():
    del model

# Load the backbone onto the configured device and cap its input length.
model = SentenceTransformer(BASE_MODEL, device=DEVICE)
model.max_seq_length = MAX_SEQ_LENGTH

# The backbone is used for inference only: freeze every parameter and switch
# to eval mode (the bare `model.eval()` also displays the module summary).
model.requires_grad_(False)
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 4096, 'do_lower_case': False}) with Transformer model: MistralModel 
  (1): Pooling({'word_embedding_dimension': 4096, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': True, 'include_prompt': True})
)

In [4]:
# Width schedule of the reduction head, from the 4096-wide input down to 64.
reduction_sizes = [4096, 3584, 2048, 1536, 1024, 768, 512, 384, 256, 128, 64]

# Reduction head: maps full embeddings to 64 dimensions.
reduce = MultiResize(
    reduction_sizes,
    bias=True,
    activation_function=nn.SiLU(),
    dropout=0.1,
)

# Reconstruction head: the same schedule walked in the opposite direction.
expand = MultiResize(
    reduction_sizes[::-1],
    bias=True,
    activation_function=nn.SiLU(),
    dropout=0.1,
)

## Load the Dataset

In [5]:
from datasets import Dataset
import json

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query with its task instruction as ``Instruct: ...\\nQuery: ...``."""
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)

def load_medi_dataset(data_filepath):
    """Read the MEDI JSON file and return it as a four-column ``Dataset``.

    Each record is expected to provide ``"query"``, ``"pos"``, ``"neg"`` and
    ``"task_name"``. The two elements of ``"query"`` are combined by
    ``get_detailed_instruct``; for ``"pos"`` and ``"neg"`` only element ``[1]``
    is kept (no instruction for documents).

    Args:
        data_filepath: Path to the JSON file holding a list of records.

    Returns:
        A ``Dataset`` with columns ``query``, ``pos``, ``neg`` and ``task_name``.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's default
    # text encoding, which varies between machines.
    with open(data_filepath, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    # Build all columns in a single pass over the records.
    columns = {"query": [], "pos": [], "neg": [], "task_name": []}
    for item in json_data:
        columns["query"].append(get_detailed_instruct(*item["query"]))
        columns["pos"].append(item["pos"][1])  # No instruction for documents
        columns["neg"].append(item["neg"][1])
        columns["task_name"].append(item["task_name"])
    return Dataset.from_dict(columns)

In [6]:
# Load the whole MEDI file into a Dataset; the bare `dataset` expression on the
# last line displays its column names and row count.
dataset = load_medi_dataset(MEDI_DATA_PATH)
dataset

Dataset({
    features: ['query', 'pos', 'neg', 'task_name'],
    num_rows: 1435000
})

### Basic Analysis

In [7]:
# Number of distinct values in the `task_name` column.
len(set(dataset["task_name"]))

330

In [8]:
# Print one example (query, positive, negative), each followed by a blank line.
index = 329817
sample = dataset[index]
for column in ("query", "pos", "neg"):
    print(sample[column], end="\n\n")

Instruct: Represent the Google question for retrieving answers;
Query: what are the effects of blood pressure medication?

['Cough.', 'Diarrhea or constipation.', 'Dizziness or lightheadedness.', 'Erection problems.', 'Feeling nervous.', 'Feeling tired, weak, drowsy, or a lack of energy.', 'Headache.', 'Nausea or vomiting.']

['difficulty sleeping.', 'headaches.', 'feeling dizzy.', 'blurred vision.', 'constipation or diarrhoea.', 'feeling or being sick (nausea or vomiting)', 'dry mouth.', 'sweating.']



### Split

In [7]:
from sklearn.model_selection import train_test_split
from datasets import DatasetDict

# Fixed seed so the train/eval partition is identical on every run; without it
# the held-out rows (and therefore the eval metrics) change between runs.
SPLIT_SEED = 42

# Split row indices rather than rows so the task-name column can be used for
# stratification without copying the data.
train_indices, eval_indices = train_test_split(
    list(range(len(dataset))),      # Indices of dataset
    test_size=EVAL_SIZE,
    stratify=dataset["task_name"],  # Ensures proportional split of tasks
    random_state=SPLIT_SEED,
)

# NOTE: `dataset` is rebound from a Dataset to a DatasetDict here, so this cell
# cannot be re-run without first re-running the cell that loads the dataset.
dataset = DatasetDict({
    "train": dataset.select(train_indices),
    "eval": dataset.select(eval_indices)
})
dataset

DatasetDict({
    train: Dataset({
        features: ['query', 'pos', 'neg', 'task_name'],
        num_rows: 1427000
    })
    eval: Dataset({
        features: ['query', 'pos', 'neg', 'task_name'],
        num_rows: 8000
    })
})

In [8]:
from torch.utils.data import DataLoader

# Batch both splits with BATCH_SIZE rows per batch; only the training loader
# is shuffled.
train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(dataset["eval"], batch_size=BATCH_SIZE, shuffle=False)

In [9]:
# TESTING: Downsample the dataset
# NOTE(review): this rebinds both loaders to the first 250 training rows and
# the first 25 eval rows, replacing the full-size loaders created in the
# previous cell. Skip or remove this cell for a real training run.
train_dataloader = DataLoader(dataset["train"].select(list(range(0, 250))), batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(dataset["eval"].select(list(range(0, 25))), batch_size=BATCH_SIZE, shuffle=False)

## Define Loss Function

In [10]:
import torch
from torch import nn

class UnsupervisedCosineSimilarityLoss(nn.Module):
    """MSE between pairwise cosine similarities of two sets of embeddings.

    The cosine similarity of every unordered pair of rows is computed for both
    the predicted embeddings and the reference embeddings (``labels``). The
    loss is the mean squared difference between the two sets of similarities,
    so the two inputs may have different widths but must have the same number
    of rows.
    """

    def __init__(self):     # Could theoretically allow changing similarity / loss functions
        super().__init__()
        self.similarity_fn = nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.loss_fn = nn.MSELoss(reduction="mean")

    def get_similarity_scores(self, sentence_embeddings: Tensor):
        """Return the cosine similarity of every unordered pair of rows.

        Args:
            sentence_embeddings: 2-D tensor with one embedding per row.

        Returns:
            1-D tensor holding the strictly-upper-triangular entries of the
            pairwise similarity matrix, i.e. ``n * (n - 1) / 2`` values for
            ``n`` rows (empty when ``n < 2``).
        """
        # Broadcasting (1, n, d) against (n, 1, d) yields the full n x n matrix.
        similarity_matrix = self.similarity_fn(
            sentence_embeddings.unsqueeze(0),
            sentence_embeddings.unsqueeze(1)
        )
        # Create the index tensor on the same device as the matrix so that
        # indexing never needs a host-to-device copy.
        indices = torch.triu_indices(
            *similarity_matrix.shape, offset=1, device=similarity_matrix.device
        )
        return similarity_matrix[indices[0], indices[1]]

    def forward(self, sentence_embeddings: Tensor, labels: Tensor):
        """Compute the similarity-preservation loss.

        Args:
            sentence_embeddings: Predicted embeddings, one per row.
            labels: Reference embeddings, one per row.

        Returns:
            Scalar tensor with the mean squared error between the pairwise
            similarities of the two inputs. Zero when there are no pairs.
        """
        true_similarities = self.get_similarity_scores(labels)
        predicted_similarities = self.get_similarity_scores(sentence_embeddings)
        if predicted_similarities.numel() == 0 and true_similarities.numel() == 0:
            # Fewer than two embeddings means there is nothing to compare. The
            # mean of an empty tensor is NaN, so return a zero that is still
            # attached to the graph instead.
            return predicted_similarities.sum()
        return self.loss_fn(predicted_similarities, true_similarities)

class DimensionalReductionLoss(nn.Module):
    """Weighted sum of a similarity-preservation loss and a reconstruction loss.

    The combined loss is::

        w[0] * L_cos + w[1] * L_rec - 0.5 * log(w[0] * w[1])

    where ``w`` is a learnable two-element weight vector.

    Args:
        weights: Initial values of the two loss weights (cosine term first,
            reconstruction term second). Defaults to ones. The tensor is
            copied, so the caller's tensor is never modified by training.
    """

    def __init__(self, weights: "Tensor | None" = None):
        super().__init__()
        self.cosine_similarity_loss = UnsupervisedCosineSimilarityLoss()
        self.reconstruction_loss = nn.MSELoss(reduction="mean")
        if weights is None:
            # Build the default per instance. A tensor default in the signature
            # is created once and would be shared by every instance.
            weights = torch.ones(2)
        # nn.Parameter shares storage with the tensor it wraps; clone so that
        # optimizer updates do not write into the tensor that was passed in.
        self.weights = nn.Parameter(weights.detach().clone(), requires_grad=True)

    def forward(self,
        sentence_embeddings: Tensor,
        reconstructed_embeddings: Tensor,
        labels: Tensor
    ):
        """Compute the combined loss and its two weighted components.

        Args:
            sentence_embeddings: Reduced embeddings, compared to ``labels``
                through pairwise cosine similarities.
            reconstructed_embeddings: Reconstructions, compared to ``labels``
                element-wise with MSE.
            labels: The original full-size embeddings.

        Returns:
            Tuple ``(loss, l_c, l_r)`` where ``l_c`` and ``l_r`` are the cosine
            and reconstruction terms *after* multiplication by their weights.
        """
        l_c = self.cosine_similarity_loss(sentence_embeddings, labels) * self.weights[0]
        l_r = self.reconstruction_loss(reconstructed_embeddings, labels) * self.weights[1]
        # The log term penalises shrinking the weights towards zero.
        # NOTE(review): it is NaN if the product of the weights becomes negative.
        loss = l_c + l_r - .5*torch.log(self.weights[0] * self.weights[1])
        return loss, l_c, l_r

In [11]:
# Start both loss weights at 1.0; they are trainable parameters of the loss module.
loss_fn = DimensionalReductionLoss(weights=torch.tensor([1., 1.]))

### Test the loss

In [14]:
# Run tests on the loss functions
# Random vectors stand in for a batch of 4096-wide backbone embeddings and are
# pushed through both heads to get inputs for the loss checks below.
# NOTE(review): the heads have not been switched to eval mode at this point, so
# presumably dropout is active and the printed values differ between runs — confirm.
full_embeddings = torch.randn(BATCH_SIZE, 4096)
reduced_embeddings = reduce(full_embeddings)
reconstructed_embeddings = expand(reduced_embeddings)

In [15]:
# Similarity-preservation loss alone: reduced embeddings vs. the full ones.
UnsupervisedCosineSimilarityLoss()(reduced_embeddings, full_embeddings)

tensor(0.0278, grad_fn=<MseLossBackward0>)

In [16]:
# Combined loss; displays the tuple (total, weighted cosine term, weighted reconstruction term).
loss_fn(reduced_embeddings, reconstructed_embeddings, full_embeddings)

(tensor(1.0317, grad_fn=<SubBackward0>),
 tensor(0.0278, grad_fn=<MulBackward0>),
 tensor(1.0039, grad_fn=<MulBackward0>))

## Initialize optimizer

In [12]:
from torch import optim

# Everything updated during training: both heads plus the parameters of the
# loss module.
trainable_params = [
    *reduce.parameters(),
    *expand.parameters(),
    *loss_fn.parameters(),
]
optimizer = optim.AdamW(trainable_params, lr=LR)

In [13]:
# Linear decay of the learning rate from its initial value down to zero over
# the whole run.
num_training_steps = len(train_dataloader) * EPOCHS


def lambda_lr(step):
    """Return the LR multiplier for `step`: the fraction of steps remaining, floored at 0."""
    remaining_fraction = (num_training_steps - step) / num_training_steps
    return max(0, remaining_fraction)


lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda_lr)

## Train the Reduction

In [22]:
from accelerate import Accelerator

# Hand the trainable modules, optimizer, scheduler and both loaders to
# Accelerate; the returned objects replace the originals under the same names.
# The frozen backbone `model` is deliberately not passed through `prepare`.
# NOTE(review): no `log_with` argument or tracker initialisation is visible
# here, so the `accelerator.log(...)` calls in the training cell may have
# nowhere to send metrics — confirm.
accelerator = Accelerator()
reduce, expand, loss_fn, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
    reduce, expand, loss_fn, optimizer, lr_scheduler, train_dataloader, eval_dataloader
)

In [15]:
def forward(batch):
    """Run one batch through the frozen backbone, both heads and the loss.

    Args:
        batch: Mapping with ``"query"``, ``"pos"`` and ``"neg"`` entries that
            are concatenated (in that order) into one list of texts.

    Returns:
        Whatever ``loss_fn`` returns for the reduced embeddings, their
        reconstructions and the full embeddings.
    """
    texts = batch["query"] + batch["pos"] + batch["neg"]

    # The backbone only supplies fixed targets, so no graph is built for it.
    with torch.no_grad():
        encoded = model.encode(texts, convert_to_tensor=True)
        full_embeddings = encoded.detach()

    # Compress, then try to recover the original embeddings.
    reduced_embeddings = reduce(full_embeddings)
    reconstructed_embeddings = expand(reduced_embeddings)

    return loss_fn(reduced_embeddings, reconstructed_embeddings, full_embeddings)

In [None]:
import os

def _set_heads_training(training):
    """Put both heads and the loss module in train (True) or eval (False) mode."""
    for module in (reduce, expand, loss_fn):
        module.train(training)


def _evaluate(epoch, step):
    """Average the three loss terms over the eval loader, then print and log them.

    Leaves the modules in eval mode; the caller switches back to train mode
    when training continues.

    Returns:
        Tuple ``(eval_loss, eval_l_c, eval_l_r)`` of per-batch averages.
    """
    _set_heads_training(False)

    totals = [0.0, 0.0, 0.0]  # running sums of (loss, l_c, l_r)
    with torch.no_grad():
        for eval_batch in eval_dataloader:
            for i, term in enumerate(forward(eval_batch)):
                totals[i] += term.item()

    num_batches = len(eval_dataloader)
    eval_loss, eval_l_c, eval_l_r = (total / num_batches for total in totals)

    print(f"Epoch: {epoch}, Step: {step}, Eval Loss: {eval_loss}")
    accelerator.log({
        "eval/loss": eval_loss,
        "eval/cosine_sim_loss": eval_l_c,
        "eval/reconstruction_loss": eval_l_r,
    })
    return eval_loss, eval_l_c, eval_l_r


def _save_checkpoint():
    """Write the state dicts of both heads and the loss module to OUTPUT_DIR.

    Returns:
        Tuple of the unwrapped ``(reduce, expand, loss_fn)`` modules.
    """
    accelerator.wait_for_everyone()
    unwrapped_modules = []
    for module, filename in (
        (reduce, "reduce.pth"),
        (expand, "expand.pth"),
        (loss_fn, "loss_fn.pth"),
    ):
        unwrapped = accelerator.unwrap_model(module)
        accelerator.save(unwrapped.state_dict(), os.path.join(OUTPUT_DIR, filename))
        unwrapped_modules.append(unwrapped)
    return tuple(unwrapped_modules)


_set_heads_training(True)

os.makedirs(OUTPUT_DIR, exist_ok=True)
CLEAR_CUDA_CACHE_STEPS = 10

global_step = 0
epoch = 0  # defined up front so the final report still works when EPOCHS == 0
for epoch in range(EPOCHS):
    for batch in train_dataloader:
        # Forward pass
        loss, l_c, l_r = forward(batch)

        # Backward pass
        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()    # Update learning rate
        global_step += 1

        # Clear cuda cache
        if global_step % CLEAR_CUDA_CACHE_STEPS == 0:
            torch.cuda.empty_cache()

        # Training logs
        if global_step % LOGGING_STEPS == 0:
            print(f"Epoch: {epoch}, Step: {global_step}, Train Loss: {loss.item()}")
            accelerator.log({
                "train/epoch": epoch,
                "train/global_step": global_step,
                "train/loss": loss.item(),
                "train/cosine_sim_loss": l_c.item(),
                "train/reconstruction_loss": l_r.item(),
                "train/cosine_sim_weight": loss_fn.weights[0].item(),
                "train/reconstruction_weight": loss_fn.weights[1].item(),
                "train/learning_rate": optimizer.param_groups[0]["lr"],
            })

        # Evaluation (the helper keeps its results local, so the training
        # `loss`, `l_c` and `l_r` of this step are not overwritten)
        if global_step % EVAL_STEPS == 0:
            _evaluate(epoch, global_step)
            _set_heads_training(True)

        # Periodic checkpoint; later saves overwrite earlier ones.
        if global_step % SAVE_STEPS == 0:
            _save_checkpoint()

# Final evaluation (modules are left in eval mode afterwards)
eval_loss, eval_l_c, eval_l_r = _evaluate(epoch, global_step)

# Save the model
reduce_unwrapped, expand_unwrapped, loss_fn_unwrapped = _save_checkpoint()

Epoch: 0, Step: 1, Train Loss: 4.071497440338135
Epoch: 0, Step: 2, Train Loss: 4.1828203201293945
Epoch: 0, Step: 3, Train Loss: 4.451157093048096
Epoch: 0, Step: 4, Train Loss: 4.141313076019287
Epoch: 0, Step: 5, Train Loss: 4.224001884460449
Epoch: 0, Step: 6, Train Loss: 3.6943228244781494
Epoch: 0, Step: 7, Train Loss: 4.0532989501953125
Epoch: 0, Step: 8, Train Loss: 4.022622108459473
Epoch: 0, Step: 9, Train Loss: 4.049502372741699
Epoch: 0, Step: 10, Train Loss: 3.809208869934082
Epoch: 0, Step: 11, Train Loss: 3.7141005992889404
Epoch: 0, Step: 12, Train Loss: 3.728238821029663
Epoch: 0, Step: 13, Train Loss: 4.212258338928223
Epoch: 0, Step: 14, Train Loss: 4.07533073425293
Epoch: 0, Step: 15, Train Loss: 4.232998371124268
Epoch: 0, Step: 16, Train Loss: 4.016861915588379
Epoch: 0, Step: 17, Train Loss: 4.096328258514404
Epoch: 0, Step: 18, Train Loss: 4.021500110626221
Epoch: 0, Step: 19, Train Loss: 4.002612113952637
Epoch: 0, Step: 20, Train Loss: 4.289522171020508
Epoch: