In [None]:
from safetensors.torch import save_file
import torch
from typing import List

def save_logits(logits : List[torch.Tensor], file_path : str) -> None:
    """
    Save a list of output logits to a single safetensors file.

    Each tensor is stored under the key ``"Question <n>"``, where ``n`` is its
    1-based position in ``logits``.

    Parameters:
    - logits (List[torch.Tensor]): list of the logits you want to save
    - file_path (str): location to store the logits
    """
    data = {}
    for question_number, logit in enumerate(logits, start=1):
        # safetensors' save_file refuses non-contiguous tensors and tensors that share
        # storage (e.g. rows sliced out of one batched output, which .to('cpu') leaves
        # shared when they already live on the CPU). Store an independent, contiguous,
        # gradient-free CPU copy of every entry instead.
        data[f"Question {question_number}"] = (
            logit.detach().to("cpu").clone(memory_format=torch.contiguous_format)
        )

    save_file(data, file_path)

In [None]:
from safetensors.torch import safe_open

def load_list_of_logits_safetensor(file_path: str):
    """
    Load a list of saved torch logits.

    safetensors does not hand tensors back in the order they were saved (names come
    back string-sorted), which would put "Question 10" before "Question 2". Keys of
    the form ``"<prefix> <n>"`` (such as ``"Question 3"``) are therefore ordered by
    the integer ``n``. Any other keys follow, sorted by name.

    Parameters:
    - file_path (str): file path for where to look for the logits file

    Returns:
    - list of torch.Tensor: the stored tensors, in the order described above.
    """
    def _tensor_order(name):
        # Split off a trailing " <number>" so it can be compared numerically.
        prefix, _, suffix = name.rpartition(" ")
        if suffix.isdecimal():
            return (0, prefix, int(suffix))
        return (1, name, 0)

    with safe_open(file_path, framework="pt") as f:
        return [f.get_tensor(key) for key in sorted(f.keys(), key=_tensor_order)]

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    """
    Object used to calculate KD loss using a mix of hard loss (cross-entropy) and soft loss (KL-Divergence).

    loss = alpha * CE(student_logits, labels)
           + (1 - alpha) * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Class scores are read along dim 1 for both terms, the same convention
    ``nn.CrossEntropyLoss`` uses.
    """
    def __init__(self, temperature=1.0, alpha=0.5):
        """
        Parameters:
        - temperature (float): Temperature for softening logits before KL-Divergence. Must be > 0.
        - alpha (float): Weight of the hard loss, in [0, 1]; the soft loss is weighted by ``1 - alpha``.

        Raises:
        - ValueError: if ``temperature`` is not > 0 or ``alpha`` is outside [0, 1].
        """
        super().__init__()
        # `not x > 0` also rejects NaN; T == 0 would divide the logits by zero.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Compute the combined distillation loss.

        Parameters:
        - student_logits (torch.Tensor): student scores, classes along dim 1.
        - teacher_logits (torch.Tensor): teacher scores with the same layout as ``student_logits``.
        - labels (torch.Tensor): hard targets accepted by ``nn.CrossEntropyLoss``.

        Returns:
        - torch.Tensor: scalar loss.
        """
        # Hard Loss: Cross-Entropy between student predictions and true labels
        loss_hard = self.criterion(student_logits, labels)

        # The teacher is a fixed target: block gradients into it, and bring it onto the
        # student's device/dtype (e.g. logits reloaded from disk live on the CPU and may
        # have been stored in a different precision).
        teacher_logits = teacher_logits.detach().to(
            device=student_logits.device, dtype=student_logits.dtype
        )

        # Soft Loss: KL-Divergence between soft targets from teacher and student
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        # 'batchmean' divides the summed KL by the size of dim 0; the T^2 factor keeps the
        # soft-loss gradient magnitude comparable across temperatures.
        loss_soft = F.kl_div(
            student_log_probs, teacher_probs, reduction='batchmean', log_target=False
        ) * (self.temperature ** 2)

        # Combine the losses
        return self.alpha * loss_hard + (1.0 - self.alpha) * loss_soft

In [None]:
import os
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM

def save_model_safetensors_sharded(model: torch.nn.Module, save_directory: str, max_shard_size: int = 5 * 1024 * 1024 * 1024, dtype: str = "FP32"):
    """
    Save a model's state dict as sharded SafeTensors files with a maximum shard size.

    Shards are named ``model-XXXXX-of-YYYYY.safetensors`` (1-based). A
    ``model.safetensors.index.json`` file is written next to them, mapping every
    tensor name to its shard file (the Hugging Face sharded-checkpoint layout).

    Only floating-point tensors are cast to ``dtype``. Integer/bool buffers keep their
    original dtype so their values are not altered by a float cast.

    Parameters:
    - model (torch.nn.Module): The model to save (e.g. a transformers PreTrainedModel); only ``state_dict()`` is used.
    - save_directory (str): The directory where the shards will be saved. Created if missing.
    - max_shard_size (int): Maximum shard size in bytes. Default is 5 GiB. A single tensor larger than this gets a shard of its own.
    - dtype (str): The floating-point data type to store. Options are 'FP32', 'BF16', 'FP16'. Default is 'FP32'.

    Raises:
    - ValueError: if ``dtype`` is not one of the supported options.
    """
    import json

    # Map the dtype string to a PyTorch dtype
    dtype_map = {
        "FP32": torch.float32,
        "BF16": torch.bfloat16,
        "FP16": torch.float16,
    }
    if dtype not in dtype_map:
        raise ValueError(f"Unsupported dtype {dtype!r}; expected one of {sorted(dtype_map)}")
    selected_dtype = dtype_map[dtype]
    # Bytes per element after the cast, so shard sizes can be planned before converting anything.
    float_itemsize = torch.empty(0, dtype=selected_dtype).element_size()

    # Ensure the save directory exists
    os.makedirs(save_directory, exist_ok=True)
    state_dict = model.state_dict()

    # Pass 1: group tensor names into shards. Knowing the shard count up front lets each
    # file be written under its final name (no rename step that could clash with files
    # left by an earlier save).
    shards = []
    current_keys = []
    current_size = 0
    for key, tensor in state_dict.items():
        itemsize = float_itemsize if tensor.is_floating_point() else tensor.element_size()
        tensor_size = tensor.numel() * itemsize
        # Close the current shard if this tensor would push it past the limit.
        if current_keys and current_size + tensor_size > max_shard_size:
            shards.append(current_keys)
            current_keys = []
            current_size = 0
        current_keys.append(key)
        current_size += tensor_size
    if current_keys:
        shards.append(current_keys)

    # Pass 2: materialize and write one shard at a time on the CPU, so only one shard's
    # converted tensors are held at once and no extra device memory is used.
    total_shards = len(shards)
    weight_map = {}
    total_size = 0
    for shard_number, shard_keys in enumerate(shards, start=1):
        shard_name = f"model-{shard_number:05d}-of-{total_shards:05d}.safetensors"
        shard = {}
        seen_storages = set()
        for key in shard_keys:
            tensor = state_dict[key].detach()
            if tensor.is_floating_point():
                tensor = tensor.to(device="cpu", dtype=selected_dtype)
            else:
                tensor = tensor.to(device="cpu")
            # save_file rejects non-contiguous tensors and tensors that share storage
            # within one file (e.g. tied input/output embeddings when no cast copies them),
            # so give every entry its own contiguous buffer.
            tensor = tensor.contiguous()
            storage_ptr = tensor.untyped_storage().data_ptr()
            if storage_ptr in seen_storages:
                tensor = tensor.clone()
            elif storage_ptr != 0:
                seen_storages.add(storage_ptr)
            shard[key] = tensor
            weight_map[key] = shard_name
            total_size += tensor.numel() * tensor.element_size()
        # transformers' safetensors loader checks for this "format" metadata entry.
        save_file(shard, os.path.join(save_directory, shard_name), metadata={"format": "pt"})

    # Index file mapping each tensor name to the shard that holds it.
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    index_path = os.path.join(save_directory, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as index_file:
        json.dump(index, index_file, indent=2, sort_keys=True)