### This Notebook serves as an overview of the general methodology to track and help reproduce the core contribution. 

#### Preservation Set Construction


In [22]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchcam.methods import GradCAM
from sklearn.cluster import KMeans
import numpy as np
import os  
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import math
from tqdm import trange
import os
from PIL import Image
import torchvision  
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import os
from tqdm import trange


In [2]:
# Load the model or define and train one  
#''' model '''

In [3]:
def load_model(model_class, model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_class().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model


### Grad-CAM Intensity Calculation

Grad-CAM (Gradient-weighted Class Activation Mapping) helps in visualizing which parts of the input image are influencing the model's prediction the most. By applying Grad-CAM, we compute activation maps that highlight important regions in the image for a given class. This allows us to rank and select images based on the intensity of these activation maps, ensuring we focus on the most critical examples.



In [4]:
def compute_gradcam_intensity(images, labels, model, cam_layer='specify the layer to cam at'):
    device = next(model.parameters()).device
    activations = []
    
    with GradCAM(model, target_layer=cam_layer) as cam_extractor:
        model.train()
        for image, label in zip(images, labels):
            image = image.to(device).requires_grad_(True)
            with torch.set_grad_enabled(True):
                output = model(image.unsqueeze(0))
                prediction = output.argmax(dim=1).item()
                cam = cam_extractor(prediction, output)
                intensity = cam[0].sum().item()
                activations.append((image.cpu().detach(), intensity, label))
    
    return activations


### Uncertainty Sampling

Uncertainty sampling is a technique used to select examples for which the model is least confident in its predictions. This is crucial for creating a robust dataset, as it focuses on examples where the model might make mistakes. By selecting the images with high uncertainty (low confidence), we gather examples that help refine and improve the model's performance.


In [5]:
def get_uncertain_examples(images, labels, model, threshold="hyperparameter to search"): 
    device = next(model.parameters()).device
    images = images.to(device)
    labels = labels.to(device)
    uncertain_examples = []
    uncertain_labels = []
    model.eval()
    with torch.no_grad():
        outputs = model(images)
        probabilities = F.softmax(outputs, dim=1)
        uncertainties = 1 - probabilities.max(dim=1)[0]
        for i, uncertainty in enumerate(uncertainties):
            if uncertainty > threshold:
                uncertain_examples.append(images[i].cpu())
                uncertain_labels.append(labels[i].cpu())
    return uncertain_examples, uncertain_labels


### Clustering-Based Projection

Clustering-based projection groups examples into clusters based on feature embeddings. This ensures that the selected examples represent a diverse range of images, minimizing redundancy in the dataset. By applying clustering, we can ensure that the dataset contains varied instances from different regions of the feature space, enhancing the overall quality of the selection process.


In [6]:
def get_embedding(model, image):  
    # or follow any conventianal method to extract features from the model, for example this will differ for CNN and attention blocks
    device = next(model.parameters()).device
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        output = model.conv3(model.relu(model.conv2(model.relu(model.conv1(image.unsqueeze(0))))))
        return output.view(output.size(0), -1)

def get_diverse_examples(images, labels, model, num_clusters=10): 
    #sampling diverse examples from the dataset
    embeddings = []
    images_list = []
    labels_list = []
    for image, label in zip(images, labels):
        embedding = get_embedding(model, image)
        embeddings.append(embedding.squeeze().cpu().numpy())
        images_list.append(image.cpu())
        labels_list.append(label.cpu())
    
    embeddings = np.array(embeddings)
    kmeans = KMeans(n_clusters=min(num_clusters, len(embeddings)))
    clusters = kmeans.fit_predict(embeddings)
    
    selected_images = []
    selected_labels = []
    for cluster in range(num_clusters):
        cluster_indices = [i for i, c in enumerate(clusters) if c == cluster]
        if cluster_indices:
            selected_images.append(images_list[cluster_indices[0]])
            selected_labels.append(labels_list[cluster_indices[0]])
    
    return selected_images, selected_labels


#### Sort and Select Top examples 

In [7]:
def select_top_examples(testloader, model, num_examples="hyperparameter to search"):
    all_activations = []
    all_uncertain = []
    all_diverse = []
    all_labels = []

    for images, labels in testloader:
        activations = compute_gradcam_intensity(images, labels, model)
        all_activations.extend(activations)
        
        uncertain_examples, uncertain_labels = get_uncertain_examples(images, labels, model)
        all_uncertain.extend(uncertain_examples)
        all_labels.extend(uncertain_labels)
        
        diverse_examples, diverse_labels = get_diverse_examples(images, labels, model)
        all_diverse.extend(diverse_examples)
        all_labels.extend(diverse_labels)
    
    sorted_activations = sorted(all_activations, key=lambda x: x[1], reverse=True)
    top_activations = sorted_activations[:num_examples // 2]
    top_examples_by_gradcam = [img for img, _, _ in top_activations]
    labels_by_gradcam = [label for _, _, label in top_activations]
    
    combined_examples = top_examples_by_gradcam + all_uncertain + all_diverse
    combined_labels = labels_by_gradcam + all_labels
    
    unique_dict = {}
    for img, label in zip(combined_examples, combined_labels):
        key = tuple(img.numpy().flatten())
        if key not in unique_dict:
            unique_dict[key] = (img, label)
    
    unique_items = list(unique_dict.values())
    unique_images = [item[0] for item in unique_items]
    unique_labels = [item[1] for item in unique_items]
    
    if len(unique_images) > num_examples:
        unique_images = unique_images[:num_examples]
        unique_labels = unique_labels[:num_examples]
    
    return unique_images, unique_labels


#### CNN _ MNIST Modeling Approach

### Dataset Preparation and Transformation

The following block of code defines two sets of transformations for the MNIST dataset. The first transformation (`transform`) is applied to both the training and testing data, converting the images into tensors and normalizing them using the mean and standard deviation of the MNIST dataset.

The second transformation (`safety_transform`) is designed for the safety set and includes augmentation techniques such as random rotation and random affine translation. These augmentations ensure the diversity of the dataset by introducing slight variations in the images, which help improve the robustness of the model during training.

After defining the transformations, the MNIST dataset is loaded for both the training and testing sets. The entire dataset is then converted into tensors, `X_train` and `Y_train` for the training set, and `X_test` and `Y_test` for the test set, by using data loaders that load the entire dataset in one batch.


In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

safety_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST('./data', train=False, transform=transform)

X_train, Y_train = next(iter(torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset))))
X_test, Y_test = next(iter(torch.utils.data.DataLoader(
    test_dataset, batch_size=len(test_dataset))))


### Quantized Convolutional Layer

This block of code defines a custom quantized convolutional layer `QConv2d`, which applies quantization to the convolution weights. The class is designed to reduce the precision of the weights, effectively simulating a form of quantization during training. The key components are:

1. **Initialization**: The weights are initialized with a uniform distribution scaled by the inverse square root of the input size, ensuring that weight magnitudes are appropriately sized. Two additional parameters, `e` and `b`, are initialized. The parameter `e` controls the scaling of the weights, while `b` controls the number of bits used for quantization.

2. **qbits Function**: This function computes the total number of quantized bits used by the weights. It sums the `b` values (after applying a ReLU operation to ensure positivity) and multiplies by the number of elements in the first dimension of the weights. This gives a measure of the bit precision for the quantized weights.

3. **qweight Function**: This function calculates the quantized version of the weights. It:
    - Ensures the bits in `b` are non-negative using ReLU.
    - Defines the quantization range based on `b` values (`min_val` and `max_val`).
    - Scales the original weights by `2 ** -self.e` to ensure appropriate scaling for quantization.
    - Clips the scaled weights between `min_val` and `max_val`, simulating quantization.

4. **forward Function**: During the forward pass, the weights are quantized using the `qweight` function. A "straight-through estimator" is applied, which rounds the weights but allows gradient flow through the non-quantized weights, preserving differentiability. The convolution is then performed using the quantized weights scaled by `2 ** self.e`.

This layer is useful for simulating quantization during training, which can help reduce model size and improve efficiency, especially in hardware-constrained environments.


In [9]:
class QConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(QConv2d, self).__init__()
        self.kernel_size = (kernel_size, kernel_size) if isinstance(
            kernel_size, int) else tuple(kernel_size)
        scale = 1 / math.sqrt(in_channels * math.prod(self.kernel_size))
        self.weight = nn.Parameter(torch.empty(
            out_channels, in_channels, *self.kernel_size).uniform_(-scale, scale))
        self.e = nn.Parameter(torch.full((out_channels, 1, 1, 1), -8.))
        self.b = nn.Parameter(torch.full((out_channels, 1, 1, 1), 32.))  # Start with 32 bits

    def qbits(self):
        return self.b.relu().sum() * self.weight[0].numel()

    def qweight(self):
        b_rel = self.b.relu()
        min_val = torch.where(b_rel > 0, -2 ** (b_rel - 1), torch.zeros_like(b_rel))
        max_val = torch.where(b_rel > 0, 2 ** (b_rel - 1) - 1, torch.zeros_like(b_rel))
        scaled_weight = 2 ** -self.e * self.weight
        qweight = torch.max(torch.min(scaled_weight, max_val), min_val)
        return qweight

    def forward(self, x):
        qw = self.qweight()
        w = (qw.round() - qw).detach() + qw  # Straight-through estimator
        return nn.functional.conv2d(x, 2 ** self.e * w)


### CNN Model Using Quantized Convolutional Layers

This block defines a CNN model `Model` that utilizes the previously defined quantized convolutional layers (`QConv2d`). The architecture follows a standard CNN design but incorporates quantization techniques within the convolutional layers to reduce precision in the weights, improving model efficiency and size reduction.

1. **Model Architecture**:
    - The model is divided into two main parts: `features` and `classifier`.
    - The `features` part contains sequential layers of `QConv2d` (quantized convolutional layers), ReLU activations, Batch Normalization, and Max Pooling. The convolutional layers use quantization to limit precision, with ReLU as the activation function. Batch Normalization without affine parameters or running statistics is applied to normalize the feature maps.
    - The `classifier` consists of a single fully connected layer that maps the flattened output of the `features` section to 10 output classes, suitable for classification tasks.

2. **Forward Function**:
    - The input tensor `x` passes through the `features` section, which applies the sequence of quantized convolutional, activation, normalization, and pooling layers.
    - The output is then flattened using `x.view()` to prepare it for the fully connected `classifier` layer, which outputs class scores.

3. **qbits Function**:
    - The `qbits` function calculates the total number of quantized bits used by all `QConv2d` layers in the `features` section. This function iterates over the layers in `features`, checks if the layer is a quantized convolution (`QConv2d`), and sums the bit usage across all such layers. This helps track how many bits are used by the quantized layers during training.

This model is particularly useful for applications where resource efficiency is critical, such as edge devices, as it combines the power of CNNs with the memory and computational benefits of quantization.


In [10]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.features = nn.Sequential(
            QConv2d(1, 32, 5), nn.ReLU(),
            QConv2d(32, 32, 5), nn.ReLU(),
            nn.BatchNorm2d(32, affine=False, track_running_stats=False),
            nn.MaxPool2d(2),
            QConv2d(32, 64, 3), nn.ReLU(),
            QConv2d(64, 64, 3), nn.ReLU(),
            nn.BatchNorm2d(64, affine=False, track_running_stats=False),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Linear(64 * 3 * 3, 10)  

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  
        x = self.classifier(x)
        return x

    def qbits(self):
        return sum(l.qbits() for l in self.features if isinstance(l, QConv2d))


### Custom Dataset Class: SafetySetDataset

The `SafetySetDataset` class is a custom dataset designed to load images and their corresponding labels from a directory. It extends `torch.utils.data.Dataset`, which allows for easy integration with PyTorch's data loading utilities. This dataset is used to handle images saved in the safety set folder.

1. **Initialization (`__init__` Method)**:
    - The class is initialized with a `safety_set_path` (path to the folder containing images) and an optional `transform` parameter.
    - It loops through all `.png` files in the specified folder and extracts both the image paths and the labels. The label is inferred from the file name, assuming that it is in the format `image_{index}_label_{label}.png`.
    - The image paths and labels are stored in lists (`self.image_paths` and `self.labels`).

2. **Length (`__len__` Method)**:
    - This method returns the number of examples in the dataset, which is simply the length of the `self.labels` list.

3. **Getting an Item (`__getitem__` Method)**:
    - The `__getitem__` method retrieves a single image and its corresponding label based on the index `idx`.
    - The image is loaded using `PIL.Image` and converted to grayscale using `.convert('L')`.
    - If a transformation is provided (`self.transform`), it is applied to the image. If not, the default transformation is applied using `transform(image)`.
    - The method returns a tuple of the transformed image and its corresponding label.

This custom dataset class allows for easy loading and transformation of images from the safety set, making it compatible with PyTorch's `DataLoader` for batching and iterating over data during model training or evaluation, accomodate this to an algorithm to work on your dataset


In [11]:
class SafetySetDataset(torch.utils.data.Dataset):
    def __init__(self, safety_set_path, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        for file in os.listdir(safety_set_path):
            if file.endswith('.png'):
                label_str = file.split('_label_')[-1].split('.png')[0]
                label = int(label_str)
                image_path = os.path.join(safety_set_path, file)
                self.image_paths.append(image_path)
                self.labels.append(label)
                
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')
        if self.transform:
            image = self.transform(image)
        else:
            image = transform(image)  
        label = self.labels[idx]
        return image, label


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
opt = optim.Adam(model.parameters())
weight_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
safety_set_path = "/home/mohammad/safety_set_images_d"  # Update this path if necessary
safety_dataset = SafetySetDataset(safety_set_path, transform=safety_transform)


### Training Step Function with Adjusted Compression and Safety Weights

This training step function trains the model on both the main training data and a safety dataset while balancing accuracy, compression, and safety using regularization weights.

1. **Model Training**:
    - The model is set to training mode (`model.train()`), and the optimizer gradient is reset using `opt.zero_grad()`.
    - A random batch of samples (512 - the results of the paper are averaged) is selected from the training data (`X_train`) and passed through the model to get the outputs.
    - The main loss function, cross-entropy, is calculated between the model's predictions and the actual labels (`Y_train`).

2. **Quantization Regularization**:
    - The number of quantized bits (`Q`) used by the model is computed using the `qbits()` function.
    - A compression weight (`compression_weight`) of **0.1** (hyperparameter to search for) is applied to penalize the model based on the number of bits used. This encourages the model to reduce its bit usage, balancing compression and accuracy.
    - The total loss is updated to include the compression penalty, combining both cross-entropy loss and the regularization term.

3. **Safety Set Penalty**:
    - A random batch of images and labels (64) is sampled from the safety dataset.
    - The images are transformed using augmentation, stacked into a batch, and passed through the model to obtain the predictions.
    - A separate cross-entropy loss is calculated between the model's predictions and the true labels from the safety set.
    - A **safety weight** of **0.05** (hyperparameter to search for) is applied to this safety loss. This regularization term encourages the model to maintain performance on critical safety examples while still prioritizing compression and accuracy.

4. **Loss Backpropagation**:
    - The total loss, which now includes both the main loss and the safety loss, is backpropagated through the model to compute gradients.
    - The optimizer steps to update the model's weights.

This function balances the trade-offs between model accuracy, compression (through quantization), and performance on the safety set, allowing for more efficient training while preserving key performance metrics.


In [13]:
def train_step():
    model.train()
    opt.zero_grad()
    samples = torch.randint(0, X_train.shape[0], (64,))
    outputs = model(X_train[samples].to(device))
    loss = nn.functional.cross_entropy(outputs, Y_train[samples].to(device))
    Q = model.qbits() / weight_count
    compression_weight = 0.1  
    loss = loss + compression_weight * Q  
    safety_indices = torch.randint(0, len(safety_dataset), (64,))
    safety_images_batch = []
    safety_labels_batch = []
    for idx in safety_indices:
        img, label = safety_dataset[idx]
        safety_images_batch.append(img)
        safety_labels_batch.append(label)
    safety_images_batch = torch.stack(safety_images_batch).to(device)
    safety_labels_batch = torch.tensor(safety_labels_batch).to(device)
    safety_outputs = model(safety_images_batch)
    safety_loss = nn.functional.cross_entropy(safety_outputs, safety_labels_batch)
    safety_weight = 0.05  
    loss = loss + safety_weight * safety_loss

    loss.backward()
    opt.step()
    return loss.item(), Q.item(), safety_loss.item()


### Functions to Compute Test and Safety Accuracies

These two functions are used to evaluate the model's performance on both the test dataset and the safety set, providing accuracy metrics after training.

1. **Test Accuracy (`get_test_acc`)**:
    - This function evaluates the model's accuracy on the test dataset.
    - The model is set to evaluation mode (`model.eval()`), and `torch.no_grad()` ensures that no gradients are computed, saving memory and speeding up computation.
    - The test data (`X_test`) is passed through the model, and the predictions are obtained by selecting the class with the highest score (`argmax`).
    - The predicted labels are compared with the true labels (`Y_test`), and the accuracy is calculated as the percentage of correct predictions.

2. **Safety Set Accuracy (`get_safety_acc`)**:
    - This function computes the accuracy on the safety set without augmentations, which is important for evaluating the model's performance on critical safety examples.
    - A separate data loader (`safety_loader_eval`) is used to batch the safety set for evaluation.
    - Like the test accuracy function, the model is set to evaluation mode and `torch.no_grad()` is used.
    - For each batch of images and labels in the safety set, the images are passed through the model, predictions are made, and the number of correct predictions is summed.
    - The accuracy is computed as the ratio of correct predictions to the total number of examples in the safety set.

These functions allow for quick and efficient calculation of accuracy on both the test and safety sets, ensuring that the model's performance is measured across both standard and critical safety datasets.


In [14]:
def get_test_acc():
    model.eval()
    with torch.no_grad():
        outputs = model(X_test.to(device))
        pred = outputs.argmax(dim=1)
        return (pred == Y_test.to(device)).float().mean().item() * 100

safety_dataset_eval = SafetySetDataset(safety_set_path, transform=transform)
safety_loader_eval = torch.utils.data.DataLoader(safety_dataset_eval, batch_size=64, shuffle=False)

def get_safety_acc():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in safety_loader_eval:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            pred = outputs.argmax(dim=1)    
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return (correct / total) * 100


### Functions to Check and Restore Zero-Bit Kernels

These two functions are used to manage the quantization levels of the kernels in the convolutional layers (`QConv2d`) of the model. The goal is to monitor and potentially restore kernels that have been quantized to zero bits, ensuring they remain useful during training.

1. **`check_zero_bit_kernels`**:
    - This function checks if any of the quantized convolutional layers (`QConv2d`) have kernels where the bit precision (`b`) has been reduced to zero.
    - It iterates through the `features` section of the model, specifically looking for layers of type `QConv2d`.
    - For each such layer, the function checks whether any of the bit values (`b`) are less than or equal to zero, which would indicate that the corresponding kernels are using zero bits and hence have been "discarded."
    - If any zero-bit kernels are found, the function returns `True`; otherwise, it returns `False`.

2. **`restore_zero_bit_kernels`**:
    - This function restores a fraction of the kernels that have been reduced to zero bits (`b <= 0`) in the quantized convolutional layers.
    - For each `QConv2d` layer, the function identifies the indices of the kernels that have zero bits.
    - A fraction of these zero-bit kernels, determined by the `restore_fraction` parameter, is randomly selected for restoration, then this a learnable parameter to be determined via backpropagration
    - The selected kernels have their bit values (`b`) restored to 2.0, allowing them to be used in future computations. This restoration helps ensure that the model doesn't lose valuable information due to over-quantization.

In practice, these functions can be used to control the extent of quantization in the model. During experimentation, you may initialize this process, train the model, and periodically restore a small fraction of zero-bit kernels to maintain model accuracy while benefiting from quantization.


In [15]:
def check_zero_bit_kernels():
    for layer in model.features:
        if isinstance(layer, QConv2d):
            if (layer.b.view(-1) <= 0).any():
                return True
    return False

def restore_zero_bit_kernels(restore_fraction):  #  in an experiment we intialized this and we learn it, sampled and then passed within the learnble parameters
    for layer in model.features:
        if isinstance(layer, QConv2d):
            b_flat = layer.b.view(-1)
            zero_bit_indices = (b_flat <= 0).nonzero(as_tuple=False).view(-1)
            num_restore = int(restore_fraction * len(zero_bit_indices))  # Restore a smaller fraction
            if num_restore > 0:
                restore_indices = zero_bit_indices[torch.randperm(len(zero_bit_indices))[:num_restore]]
                b_flat[restore_indices] = 2.0  # Restore bits to 2


### Training Loop with Safety Set Monitoring and Kernel Restoration

This block of code defines a training loop that not only trains the model but also monitors its performance on the safety set and dynamically restores quantized kernels to prevent excessive performance degradation. The loop runs for a set number of iterations, balancing accuracy, compression, and safety.

1. **Initialization**:
    - The `prev_safety_acc` stores the previous safety accuracy to compare it with the current one at each step.
    - The `safety_acc_drop_threshold` is set to 10, meaning that if the safety accuracy drops by more than this threshold, action is taken to restore some of the zero-bit kernels.
    - Lists `test_accs`, `bytes_used`, and `safety_losses` are initialized to store the model's performance metrics over time.
    - The `initial_safety_acc` is calculated to track the starting point for safety accuracy, and `prev_safety_acc` is set to this initial value.

2. **Training Loop**:
    - The loop runs for 10,000 iterations. In each iteration, the following steps are performed:
    - **Train Step**: The `train_step()` function is called to perform a training step, which returns the current loss, the number of bits used by the model (`Q`), and the safety loss.
    - **Model Size Calculation**: The model's size in bytes is calculated using the number of bits (`Q`) divided by 8 (to convert to bytes), multiplied by the total number of weights in the model (`weight_count`).

3. **Accuracy and Safety Monitoring**:
    - Every 10 iterations, the test accuracy and safety accuracy are computed using the `get_test_acc()` and `get_safety_acc()` functions, respectively.
    - The drop in safety accuracy (`acc_drop`) is calculated by comparing the current safety accuracy to the previous one (`prev_safety_acc`).
    - If the accuracy drop exceeds the threshold (`safety_acc_drop_threshold`), the function `check_zero_bit_kernels()` checks if any kernels have zero bits. If zero-bit kernels are found, the function `restore_zero_bit_kernels()` is called to restore a fraction (10%) of these kernels to prevent further degradation.

4. **Updating Metrics**:
    - The current test accuracy is appended to `test_accs`. If no accuracy is calculated for the current iteration, the last recorded accuracy is used.
    - The model's size in bytes is appended to `bytes_used`, and the safety loss is appended to `safety_losses` for tracking.

5. **Progress Bar Description**:
    - The progress of the loop is updated using `t.set_description()`, which displays the current loss, model size in bytes, and test accuracy. This provides a real-time view of the model's performance during training.

This loop effectively balances between model accuracy, compression, and safety by continuously monitoring and adjusting quantized kernels, ensuring that the model doesn't lose critical performance on safety-critical examples.


In [20]:
torch.cuda.empty_cache()


In [21]:

prev_safety_acc = None
safety_acc_drop_threshold = 10 # averaged 
test_accs, bytes_used, safety_losses = [], [], []
initial_safety_acc = get_safety_acc()
prev_safety_acc = initial_safety_acc
for i in (t := trange(10000)):
    loss, Q, safety_loss = train_step()
    model_bytes = Q / 8 * weight_count
    if i % 10 == 9:
        test_acc = get_test_acc()
        safety_acc = get_safety_acc()
        acc_drop = prev_safety_acc - safety_acc
        if acc_drop > safety_acc_drop_threshold:
            if check_zero_bit_kernels():
                restore_zero_bit_kernels(restore_fraction=0.1)  
        prev_safety_acc = safety_acc
    else:
        test_acc = test_accs[-1] if test_accs else 0.0
    test_accs.append(test_acc)
    bytes_used.append(model_bytes)
    safety_losses.append(safety_loss)
    t.set_description(f"loss: {loss:6.2f}  bytes: {model_bytes:.1f}  acc: {test_acc:5.2f}")



loss:   1.93  bytes: 211666.1  acc: 99.35: 100%|██████████| 10000/10000 [14:33<00:00, 11.45it/s]


### Attention Based Approach

### Hyperparameters for Transformer Model

- **batch_size**: 16  
  The number of examples processed together in one forward/backward pass during training.

- **block_size**: 32  
  The maximum length of input sequences (in tokens) that the model can process.

- **max_iters**: 10,000  
  The total number of training iterations to run.

- **eval_interval**: 100  
  The number of iterations after which the model is evaluated on the validation set.

- **learning_rate**: 1e-3  
  The rate at which the model updates its weights during training. A lower learning rate leads to more gradual updates, while a higher one accelerates learning.

- **device**: `'cuda'` if available, otherwise `'cpu'`  
  Specifies the hardware on which the model will run. If a GPU is available, the model will be trained on `'cuda'`.

- **eval_iters**: 200  
  The number of iterations used for evaluation during validation.

- **n_embd**: 64  
  The dimensionality of the embedding space, or the size of each token’s vector representation.

- **n_head**: 4  
  The number of attention heads in the multi-head self-attention mechanism of the transformer.

- **n_layer**: 4  
  The number of layers (transformer blocks) in the model.

- **dropout**: 0.0  
  The dropout rate, used to prevent overfitting by randomly setting some activations to zero during training.
Results here are very close to the ones reported, the reported ones are averaged across hyperparameters

In [23]:
batch_size = 16
block_size = 32
max_iters = 10000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

### Data Loading and Seed Initialization

- **Seed Initialization**:
  ```python
  torch.manual_seed(1337)


In [24]:
torch.manual_seed(1337)
with open('/home/mohammad/names.txt', 'r', encoding='utf-8') as f:
    text = f.read()
with open('/home/mohammad/hardest_examples.txt', 'r', encoding='utf-8') as f:
    safety_text = f.read()


### Vocabulary Creation, Data Encoding, and Splitting

In this block, the raw text data is processed to prepare it for training a language model. The steps align with core concepts in Natural Language Processing (NLP) for sequence modeling in AI:

1. **Vocabulary Creation**: 
   The union of characters from both the main and safety datasets is used to create a vocabulary. By treating characters as the basic tokens, the model learns to predict sequences at the character level, a common approach in tasks such as text generation. The vocabulary size determines the number of unique tokens the model will need to handle, and each character is mapped to a unique integer (token).

2. **Character Encoding and Decoding**:
   Two mappings are created:
   - **Encoding** converts each character into its corresponding integer based on the vocabulary, transforming text data into sequences of integers that the model can process.
   - **Decoding** reverses this process, allowing for the generation of human-readable text from the model's predictions.
   These functions enable the model to operate in the discrete space of integers while maintaining the ability to map back to the continuous space of language.

3. **Data Encoding**:
   The main and safety datasets are encoded into sequences of integers using the vocabulary. This step converts the raw text into a format that neural networks can understand, making it possible to learn patterns over these integer sequences.

4. **Training and Validation Split**:
   The data is divided into training and validation sets (90% for training, 10% for validation). This allows the model to learn from a large portion of the data while being evaluated on unseen examples, ensuring that it generalizes well to new sequences.

This process enables the model to learn patterns in the text data by predicting the next character in a sequence, a fundamental approach in language modeling. By encoding text into integers and splitting it into training and validation sets, the groundwork is laid for training an AI model to generate or classify text.


In [25]:
chars = sorted(list(set(text + safety_text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s if c in stoi]  # encoder function
decode = lambda l: ''.join([itos[i] for i in l])       # decoder function
data = torch.tensor(encode(text), dtype=torch.long)
safety_data = torch.tensor(encode(safety_text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

### Batch Generation for Training and Validation

In this block, a function `get_batch` is defined to generate mini-batches of data for both training and validation. This function prepares input-output pairs for the model to learn from, essential for training a sequence-based model like a Transformer.

1. **Data Selection**:
   The function first selects the appropriate dataset split based on the input argument `split`. If `'train'` is passed, the function works with the training data (`train_data`); otherwise, it uses the validation data (`val_data`).

2. **Random Index Sampling**:
   A random set of starting indices (`ix`) is generated, with each index representing the starting position for a sequence. The number of indices generated is equal to the batch size, ensuring that each batch contains multiple sequences for parallel processing.

3. **Input (x) and Output (y) Sequence Creation**:
   - **Input (`x`)**: For each sampled index, a sequence of `block_size` tokens is extracted. These represent the input tokens for the model, which learns to predict the next token in the sequence.
   - **Output (`y`)**: The target sequence for each input is the same sequence shifted by one position. This is a common technique in language modeling, where the model is trained to predict the next character or token in a sequence.
   
   For example, if the input (`x`) is a sequence like "hello", the target (`y`) would be "elloh", with the model tasked with predicting the next token at each step.

4. **Moving Data to the Device**:
   Both `x` (inputs) and `y` (targets) are transferred to the device (either CPU or GPU) to ensure efficient computation. This allows the model to utilize hardware acceleration for faster training.

5. **Batch Return**:
   The function returns the mini-batch of input (`x`) and output (`y`) sequences, which can then be fed into the model for training or evaluation.

This function is crucial for training models that operate on sequential


In [26]:
def get_batch(split):
    data_split = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_split) - block_size, (batch_size,))
    x = torch.stack([data_split[i:i+block_size] for i in ix])
    y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [27]:
def get_safety_batch():
    ix = torch.randint(len(safety_data) - block_size, (batch_size,))
    x = torch.stack([safety_data[i:i+block_size] for i in ix])
    y = torch.stack([safety_data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


### Loss Estimation Function

This function, `estimate_loss`, estimates the average loss for both the training and validation datasets. The function is annotated with `@torch.no_grad()` to disable gradient computation, which speeds up the process and reduces memory usage since gradients are not needed during loss estimation.

1. **No Gradient Calculation**:
   ```python
   @torch.no_grad()


In [28]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

### Quantized Linear Layer (QLinear)

The `QLinear` class implements a custom fully connected (linear) layer with quantization, where the weights are quantized to a specified number of bits. This layer is designed to reduce the precision of the weights during training, offering a balance between model size and performance, particularly in memory-constrained environments.

1. **Initialization (`__init__` Method)**:
   - The linear layer takes the number of input (`in_features`) and output features (`out_features`) as parameters.
   - The weights are initialized uniformly within a range scaled by the inverse square root of the input size, ensuring appropriate magnitude for weight initialization.
   - A bias term is optionally included (`bias=True` by default).
   - Two additional parameters are introduced:
     - `e`: Controls the scaling factor of the weights, initialized with a small value (-8).
     - `b`: Controls the number of bits used for quantization, initialized to 32 bits, indicating full precision at the start.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       return self.b.relu().sum() * self.weight.shape[1]


In [32]:
class QLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(QLinear, self).__init__()
        scale = 1 / math.sqrt(in_features)
        self.weight = nn.Parameter(torch.empty(out_features, in_features).uniform_(-scale, scale))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        self.e = nn.Parameter(torch.full((out_features, 1), -8.))
        self.b = nn.Parameter(torch.full((out_features, 1), 32.))  # Start with 32 bits

    def qbits(self):
        return self.b.relu().sum() * self.weight.shape[1]

    def qweight(self):
        b_rel = self.b.relu()
        min_val = torch.where(b_rel > 0, -2 ** (b_rel - 1), torch.zeros_like(b_rel))
        max_val = torch.where(b_rel > 0, 2 ** (b_rel - 1) - 1, torch.zeros_like(b_rel))
        scaled_weight = 2 ** -self.e * self.weight
        qweight = torch.max(torch.min(scaled_weight, max_val), min_val)
        return qweight

    def forward(self, input):
        qw = self.qweight()
        w = (qw.round() - qw).detach() + qw  # Straight-through estimator
        output = nn.functional.linear(input, 2 ** self.e * w, self.bias)
        return output

### Self-Attention Head with Quantization (Head Class)

The `Head` class implements a single attention head in a self-attention mechanism, a core component of transformer models. This class is responsible for computing attention scores and applying them to the input to capture dependencies between tokens in a sequence. The attention mechanism is quantized using `QLinear` layers, which reduces the precision of weights for more efficient computation.

1. **Initialization (`__init__` Method)**:
   - The attention head has three main linear transformations:
     - **Key** (`self.key`): Projects the input into a "key" vector space.
     - **Query** (`self.query`): Projects the input into a "query" vector space.
     - **Value** (`self.value`): Projects the input into a "value" vector space.
   - All three transformations are quantized using the `QLinear` layer, ensuring reduced bit precision during training to save memory and computational resources.
   - **Triangular Mask (`self.tril`)**: A lower triangular matrix (`block_size x block_size`) is registered as a buffer to ensure that the attention mechanism respects the causal structure (no information from future tokens is used to predict the current token).
   - **Dropout (`self.dropout`)**: Applied to the attention weights to prevent overfitting during training.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       return self.key.qbits() + self.query.qbits() + self.value.qbits()


In [33]:
class Head(nn.Module):
    """ One head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = QLinear(n_embd, head_size, bias=False)
        self.query = QLinear(n_embd, head_size, bias=False)
        self.value = QLinear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def qbits(self):
        return self.key.qbits() + self.query.qbits() + self.value.qbits()

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * C ** -0.5  # (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # Perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,head_size)
        out = wei @ v      # (B,T,head_size)
        return out


### Multi-Head Self-Attention (MultiHeadAttention Class)

The `MultiHeadAttention` class implements the multi-head self-attention mechanism, where multiple self-attention heads run in parallel to capture different aspects of the input. This is a crucial component of the transformer architecture, allowing the model to attend to information from multiple subspaces of the input sequence simultaneously. Each head operates independently but in parallel, and their outputs are combined for further processing.

1. **Initialization (`__init__` Method)**:
   - **Multiple Attention Heads (`self.heads`)**:
     ```python
     self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
     ```
     - The class initializes a list of attention heads, each represented by the `Head` class. The number of heads is controlled by the `num_heads` parameter, and each head has its own `head_size`, which determines the dimensionality of each individual attention subspace.
     - Using multiple heads allows the model to learn different types of attention patterns in parallel, enhancing the model's ability to capture complex relationships in the data.

   - **Output Projection (`self.proj`)**:
     ```python
     self.proj = QLinear(n_embd, n_embd)
     ```
     - After the attention heads have processed the input, their outputs are concatenated and passed through a quantized linear layer (`QLinear`). This layer projects the combined attention outputs back into the original embedding space (`n_embd`), ensuring that the multi-head attention integrates smoothly with the rest of the model.

   - **Dropout (`self.dropout`)**:
     ```python
     self.dropout = nn.Dropout(dropout)
     ```
     - Dropout is applied to the output of the projection layer to reduce overfitting and improve generalization. This regularization technique randomly sets a fraction of the weights to zero during training, ensuring the model doesn't become overly dependent on specific attention patterns.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       return sum(h.qbits() for h in self.heads) + self.proj.qbits()


In [34]:
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = QLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def qbits(self):
        return sum(h.qbits() for h in self.heads) + self.proj.qbits()

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


### FeedForward Network with Quantization (FeedForward Class)

The `FeedForward` class implements a basic fully connected neural network layer with quantization applied to its linear layers. This is commonly used in transformer models after the self-attention mechanism to further process the information. The purpose of this layer is to transform the attention outputs into a more expressive representation by applying non-linear transformations.

1. **Initialization (`__init__` Method)**:
   - **Linear Layers (`QLinear`)**:
     ```python
     self.net = nn.Sequential(
         QLinear(n_embd, 4 * n_embd),
         nn.ReLU(),
         QLinear(4 * n_embd, n_embd),
         nn.Dropout(dropout),
     )
     ```
     The feedforward network consists of two fully connected layers (quantized using `QLinear`) with a non-linear activation function (`ReLU`) in between:
     - The first layer expands the embedding dimension (`n_embd`) by a factor of 4, making the model more expressive and allowing it to learn richer representations.
     - The second layer projects this expanded representation back to the original embedding dimension (`n_embd`).
   - **ReLU Activation**: 
     The ReLU (Rectified Linear Unit) activation function introduces non-linearity to the model, enabling it to capture more complex patterns in the data.
   - **Dropout**:
     Dropout is applied to the output of the second `QLinear` layer to reduce overfitting. It randomly sets some outputs to zero during training, which improves generalization by preventing the model from relying too heavily on specific neurons.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       return sum(layer.qbits() for layer in self.net if isinstance(layer, QLinear))


In [35]:

class FeedForward(nn.Module):
    """ A simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            QLinear(n_embd, 4 * n_embd),
            nn.ReLU(),
            QLinear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def qbits(self):
        return sum(layer.qbits() for layer in self.net if isinstance(layer, QLinear))

    def forward(self, x):
        return self.net(x)


### Transformer Block (Block Class)

The `Block` class represents a single transformer block, which is a core building component of the transformer architecture. Each block consists of two main stages:
1. **Communication**: This refers to the self-attention mechanism, where tokens in the input sequence "communicate" with each other to learn dependencies.
2. **Computation**: After communication, the block uses a feedforward network to transform the learned representations into more useful features.

1. **Initialization (`__init__` Method)**:
   - **Self-Attention (`self.sa`)**:
     ```python
     self.sa = MultiHeadAttention(n_head, head_size)
     ```
     The block uses a multi-head self-attention mechanism (`MultiHeadAttention`). The input embedding size (`n_embd`) is divided among the attention heads, where each head independently learns different aspects of the input sequence. This enables the model to capture a variety of relationships between tokens.
   
   - **FeedForward Network (`self.ffwd`)**:
     ```python
     self.ffwd = FeedForward(n_embd)
     ```
     After the self-attention layer, a fully connected feedforward network is applied to further process the representations learned through attention. This network helps the model extract more abstract features from the data.
   
   - **Layer Normalization (`self.ln1` and `self.ln2`)**:
     ```python
     self.ln1 = nn.LayerNorm(n_embd)
     self.ln2 = nn.LayerNorm(n_embd)
     ```
     Each transformer block uses **layer normalization**, which helps stabilize and improve training by normalizing the inputs to each sub-layer (self-attention and feedforward). This prevents extreme values in the activations and ensures smoother training.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       return self.sa.qbits() + self.ffwd.qbits()


In [36]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def qbits(self):
        return self.sa.qbits() + self.ffwd.qbits()

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


### Bigram Language Model (BigramLanguageModel Class)

The `BigramLanguageModel` class implements a language model that predicts the next token in a sequence based on the preceding context. It uses the transformer architecture to model long-range dependencies between tokens, and quantized layers for efficient memory usage. The model is designed to generate new sequences of text as well.

1. **Initialization (`__init__` Method)**:
   - **Token Embedding (`self.token_embedding_table`)**:
     ```python
     self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
     ```
     The model uses an embedding layer to map each token (integer) in the vocabulary to a continuous vector of size `n_embd`. This vector represents the semantic meaning of each token in a dense space.
   
   - **Position Embedding (`self.position_embedding_table`)**:
     ```python
     self.position_embedding_table = nn.Embedding(block_size, n_embd)
     ```
     Since transformers are permutation-invariant, positional embeddings are introduced to encode the position of each token in the sequence. This allows the model to understand the order of tokens.

   - **Transformer Blocks (`self.blocks`)**:
     ```python
     self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
     ```
     The model contains a stack of transformer blocks, each of which applies multi-head self-attention and a feedforward network. These blocks allow the model to attend to different parts of the sequence and learn complex relationships between tokens.

   - **Layer Normalization (`self.ln_f`)**:
     ```python
     self.ln_f = nn.LayerNorm(n_embd)
     ```
     After the transformer blocks, layer normalization is applied to stabilize training and improve generalization. It normalizes the activations before passing them to the output layer.

   - **Output Layer (`self.lm_head`)**:
     ```python
     self.lm_head = QLinear(n_embd, vocab_size)
     ```
     The final output is generated through a quantized linear layer (`QLinear`) that projects the transformed input back into the vocabulary space. This produces a set of logits, where each element represents the model's confidence in predicting a specific token from the vocabulary.

2. **Quantization Bits Calculation (`qbits` Method)**:
   ```python
   def qbits(self):
       qbits = 0
       qbits += sum(b.qbits() for b in self.blocks)
       qbits += self.lm_head.qbits()
       return qbits


In [37]:
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = QLinear(n_embd, vocab_size)

    def qbits(self):
        qbits = 0
        qbits += sum(b.qbits() for b in self.blocks)
        qbits += self.lm_head.qbits()
        return qbits

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)     # (B,T,C)
        x = self.ln_f(x)       # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape      
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # Crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:] 
            # Get the predictions
            logits, _ = self(idx_cond)
            # Focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

### Model Initialization, Parameter Counting, and Size Tracking

1. **Model Initialization**:
   ```python
   model = BigramLanguageModel().to(device)
   optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


In [38]:
model = BigramLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f"Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
total_weight_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
model_size_history = []  # To store model size at each eval interval
iteration_history = []   # To store corresponding iterations


Total parameters: 0.21M


### Functions to Manage Zero-Bit Attention Heads in Transformer Blocks

In transformer models with quantization, there is a possibility that certain attention heads (or their components such as keys, queries, and values) are reduced to zero bits, effectively discarding them. These two functions are designed to monitor and restore attention heads that have been over-quantized to zero bits, ensuring that the model does not lose critical information.

1. **Checking for Zero-Bit Heads (`check_zero_bit_heads`)**:
   ```python
   def check_zero_bit_heads():
       for block in model.blocks:
           if isinstance(block, Block):
               for h in block.sa.heads:
                   for layer in [h.key, h.query, h.value]:
                       if (layer.b.view(-1) <= 0).any():
                           return True
       return False


In [39]:
def check_zero_bit_heads():
    for block in model.blocks:
        if isinstance(block, Block):
            for h in block.sa.heads:
                for layer in [h.key, h.query, h.value]:
                    if (layer.b.view(-1) <= 0).any():
                        return True
    return False

def restore_zero_bit_heads(restore_fraction=0.1):
    for block in model.blocks:
        if isinstance(block, Block):
            for h in block.sa.heads:
                for layer in [h.key, h.query, h.value]:
                    b_flat = layer.b.view(-1)
                    zero_bit_indices = (b_flat <= 0).nonzero(as_tuple=False).view(-1)
                    num_restore = int(restore_fraction * len(zero_bit_indices))
                    if num_restore > 0:
                        restore_indices = zero_bit_indices[torch.randperm(len(zero_bit_indices))[:num_restore]]
                        b_flat[restore_indices] = 2.0  # Restore bits to 2


### Training Loop with Safety Loss Monitoring, Compression, and Model Size Tracking

This training loop trains the transformer model while tracking the model's performance, safety loss, and size in terms of quantized bits. It integrates safety mechanisms that restore attention heads when the safety loss increases beyond a defined threshold, ensuring that the model remains robust during training.

1. **Main Training Step**:
   - The model is set to training mode (`model.train()`), and the optimizer is reset with `optimizer.zero_grad()`.
   - A batch of training data is fetched using `get_batch('train')`, and the model computes the logits and main loss (`loss_main`).
   - The number of quantized bits used by the model is computed as `Q`, and a compression regularization term is added to the main loss. This encourages the model to maintain efficient use of bits during training:
     ```python
     Q = model.qbits() / total_weight_count
     compression_weight = 0.1  # Adjust as needed
     loss = loss_main + compression_weight * Q
     ```

2. **Safety Loss and Weight**:
   - A separate batch from the safety set is fetched using `get_safety_batch()`, and the safety loss is calculated. The safety loss is added to the total loss with a predefined weight:
     ```python
     safety_weight = 0.05  # Adjust as needed
     loss = loss + safety_weight * safety_loss
     ```
     This ensures that the model maintains good performance on critical safety examples, which are often the hardest cases for the model.

3. **Backpropagation and Optimization**:
   - The total loss, which includes the main loss, compression penalty, and safety loss, is backpropagated (`loss.backward()`), and the model parameters are updated (`optimizer.step()`).

4. **Model Evaluation and Size Tracking**:
   - Every `eval_interval` steps, the model is switched to evaluation mode (`model.eval()`), and the validation and safety losses are estimated using the `estimate_loss()` function.
   - The model's current size is calculated in bits, bytes, and megabytes based on the number of quantized bits (`Q`):
     ```python
     current_qbits = model.qbits()
     current_size_bytes = current_qbits / 8
     current_size_mb = current_size_bytes / 1e6
     ```
   - The current model size and iteration are appended to tracking lists (`model_size_history` and `iteration_history`), allowing the model's compression efficiency to be tracked over time.

5. **Logging**:
   - The current training loss, validation loss, safety loss, quantization value, and model size are logged at regular intervals:
     ```python
     print(f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, "
           f"safety loss {safety_loss:.4f}, Q: {Q:.4f}, "
           f"Model Size: {current_size_mb:.6f} MB")
     ```

6. **Safety Loss Monitoring and Restoration**:
   - If the safety loss increases significantly compared to the previous evaluation (`prev_safety_loss`), the model checks for zero-bit attention heads using `check_zero_bit_heads()`.
   - If zero-bit heads are found, a fraction of these heads are restored to ensure the model maintains performance on the safety set:
     ```python
     if (safety_loss - prev_safety_loss) > safety_loss_increase_threshold:
         if check_zero_bit_heads():
             restore_zero_bit_heads(restore_fraction=0.1)
     ```
   - This mechanism ensures that the model doesn't overly compress or lose key components related to safety-critical examples during training.



In [40]:
prev_safety_loss = None
safety_loss_increase_threshold = 0.1  # Adjust as needed

# Training loop
for iter in trange(max_iters):
    # Training step
    model.train()
    optimizer.zero_grad()

    # Main training batch
    xb, yb = get_batch('train')
    logits, loss_main = model(xb, yb)
    Q = model.qbits() / total_weight_count

    # Compression regularization weight
    compression_weight = 0.1  # Adjust as needed
    loss = loss_main + compression_weight * Q

    # Safety loss
    xs, ys = get_safety_batch()
    logits_safety, safety_loss = model(xs, ys)
    safety_weight = 0.05  # Adjust as needed
    loss = loss + safety_weight * safety_loss

    loss.backward()
    optimizer.step()

    # Every eval_interval steps, check validation and safety loss, and log model size
    if iter % eval_interval == 0 or iter == max_iters - 1:
        model.eval()
        losses = estimate_loss()
        
        # Calculate current model size in bits and megabytes
        current_qbits = model.qbits()
        current_size_bytes = current_qbits / 8
        current_size_mb = current_size_bytes / 1e6

        # Append to history
        model_size_history.append(current_size_mb)
        iteration_history.append(iter)

        # Log the information
        print(f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, "
              f"safety loss {safety_loss:.4f}, Q: {Q:.4f}, "
              f"Model Size: {current_size_mb:.6f} MB")

        # Restore zero-bit heads if safety loss increases too much
        if prev_safety_loss is not None and (safety_loss - prev_safety_loss) > safety_loss_increase_threshold:
            if check_zero_bit_heads():
                restore_zero_bit_heads(restore_fraction=0.1)
        prev_safety_loss = safety_loss


  0%|          | 3/10000 [00:06<4:29:35,  1.62s/it] 

Step 0: train loss 3.2454, val loss 3.2907, safety loss 3.4857, Q: 30.2963, Model Size: 0.793311 MB


  1%|          | 102/10000 [00:22<3:40:23,  1.34s/it]

Step 100: train loss 2.4618, val loss 2.5702, safety loss 3.1995, Q: 30.1714, Model Size: 0.790040 MB


  2%|▏         | 202/10000 [00:39<3:38:11,  1.34s/it]

Step 200: train loss 2.3382, val loss 2.4530, safety loss 3.0213, Q: 30.0467, Model Size: 0.786773 MB


  3%|▎         | 302/10000 [00:54<2:34:20,  1.05it/s]

Step 300: train loss 2.2336, val loss 2.3208, safety loss 2.9918, Q: 29.9220, Model Size: 0.783509 MB


  4%|▍         | 402/10000 [01:11<3:40:34,  1.38s/it]

Step 400: train loss 2.1467, val loss 2.2115, safety loss 2.8753, Q: 29.7975, Model Size: 0.780248 MB


  5%|▌         | 503/10000 [01:28<2:42:36,  1.03s/it]

Step 500: train loss 2.0827, val loss 2.1425, safety loss 2.9439, Q: 29.6731, Model Size: 0.776990 MB


  6%|▌         | 602/10000 [01:43<2:33:01,  1.02it/s]

Step 600: train loss 2.0360, val loss 2.0645, safety loss 2.9106, Q: 29.5488, Model Size: 0.773737 MB


  7%|▋         | 702/10000 [02:00<3:34:47,  1.39s/it]

Step 700: train loss 1.9991, val loss 2.0042, safety loss 2.7738, Q: 29.4247, Model Size: 0.770486 MB


  8%|▊         | 803/10000 [02:16<1:55:46,  1.32it/s]

Step 800: train loss 1.9521, val loss 1.9575, safety loss 2.7759, Q: 29.3006, Model Size: 0.767237 MB


  9%|▉         | 902/10000 [02:31<2:15:59,  1.12it/s]

Step 900: train loss 1.9293, val loss 1.9088, safety loss 2.8679, Q: 29.1767, Model Size: 0.763993 MB


 10%|█         | 1002/10000 [02:45<2:17:39,  1.09it/s]

Step 1000: train loss 1.9101, val loss 1.9286, safety loss 2.7361, Q: 29.0530, Model Size: 0.760752 MB


 11%|█         | 1102/10000 [03:00<2:12:22,  1.12it/s]

Step 1100: train loss 1.8866, val loss 1.9052, safety loss 2.6414, Q: 28.9293, Model Size: 0.757514 MB


 12%|█▏        | 1203/10000 [03:15<1:36:53,  1.51it/s]

Step 1200: train loss 1.8688, val loss 1.8671, safety loss 2.7288, Q: 28.8058, Model Size: 0.754279 MB


 13%|█▎        | 1302/10000 [03:29<2:11:22,  1.10it/s]

Step 1300: train loss 1.8534, val loss 1.8502, safety loss 2.7195, Q: 28.6824, Model Size: 0.751048 MB


 14%|█▍        | 1403/10000 [03:44<1:34:35,  1.51it/s]

Step 1400: train loss 1.8289, val loss 1.8291, safety loss 2.6482, Q: 28.5591, Model Size: 0.747819 MB


 15%|█▌        | 1502/10000 [03:59<2:06:40,  1.12it/s]

Step 1500: train loss 1.8229, val loss 1.8095, safety loss 2.6492, Q: 28.4359, Model Size: 0.744594 MB


 16%|█▌        | 1603/10000 [04:14<1:31:29,  1.53it/s]

Step 1600: train loss 1.8144, val loss 1.8216, safety loss 2.6423, Q: 28.3129, Model Size: 0.741373 MB


 17%|█▋        | 1702/10000 [04:28<2:06:04,  1.10it/s]

Step 1700: train loss 1.8004, val loss 1.7893, safety loss 2.7785, Q: 28.1900, Model Size: 0.738154 MB


 18%|█▊        | 1803/10000 [04:43<1:28:52,  1.54it/s]

Step 1800: train loss 1.7988, val loss 1.7911, safety loss 2.5352, Q: 28.0672, Model Size: 0.734938 MB


 19%|█▉        | 1902/10000 [04:58<2:00:58,  1.12it/s]

Step 1900: train loss 1.7869, val loss 1.7737, safety loss 2.7035, Q: 27.9445, Model Size: 0.731726 MB


 20%|██        | 2003/10000 [05:12<1:26:21,  1.54it/s]

Step 2000: train loss 1.7879, val loss 1.8008, safety loss 2.4634, Q: 27.8220, Model Size: 0.728517 MB


 21%|██        | 2103/10000 [05:27<1:27:19,  1.51it/s]

Step 2100: train loss 1.7793, val loss 1.7816, safety loss 2.7077, Q: 27.6995, Model Size: 0.725311 MB


 22%|██▏       | 2203/10000 [05:42<1:24:16,  1.54it/s]

Step 2200: train loss 1.7614, val loss 1.7625, safety loss 2.6938, Q: 27.5772, Model Size: 0.722109 MB


 23%|██▎       | 2302/10000 [05:57<1:53:39,  1.13it/s]

Step 2300: train loss 1.7696, val loss 1.7558, safety loss 2.6384, Q: 27.4551, Model Size: 0.718910 MB


 24%|██▍       | 2402/10000 [06:11<1:54:11,  1.11it/s]

Step 2400: train loss 1.7533, val loss 1.7784, safety loss 2.4993, Q: 27.3330, Model Size: 0.715713 MB


 25%|██▌       | 2502/10000 [06:26<1:55:21,  1.08it/s]

Step 2500: train loss 1.7408, val loss 1.7561, safety loss 2.5805, Q: 27.2111, Model Size: 0.712521 MB


 26%|██▌       | 2603/10000 [06:41<1:21:03,  1.52it/s]

Step 2600: train loss 1.7468, val loss 1.7454, safety loss 2.5647, Q: 27.0893, Model Size: 0.709331 MB


 27%|██▋       | 2702/10000 [06:56<1:50:14,  1.10it/s]

Step 2700: train loss 1.7253, val loss 1.7548, safety loss 2.5832, Q: 26.9675, Model Size: 0.706144 MB


 28%|██▊       | 2802/10000 [07:11<1:49:11,  1.10it/s]

Step 2800: train loss 1.7278, val loss 1.7509, safety loss 2.5410, Q: 26.8460, Model Size: 0.702961 MB


 29%|██▉       | 2902/10000 [07:25<1:28:50,  1.33it/s]

Step 2900: train loss 1.7201, val loss 1.7165, safety loss 2.5378, Q: 26.7246, Model Size: 0.699781 MB


 30%|███       | 3003/10000 [07:40<1:14:54,  1.56it/s]

Step 3000: train loss 1.7183, val loss 1.7009, safety loss 2.5595, Q: 26.6032, Model Size: 0.696603 MB


 31%|███       | 3103/10000 [07:54<1:14:14,  1.55it/s]

Step 3100: train loss 1.7152, val loss 1.7225, safety loss 2.6082, Q: 26.4820, Model Size: 0.693430 MB


 32%|███▏      | 3203/10000 [08:09<1:12:59,  1.55it/s]

Step 3200: train loss 1.6962, val loss 1.7236, safety loss 2.4849, Q: 26.3609, Model Size: 0.690259 MB


 33%|███▎      | 3303/10000 [08:23<1:11:06,  1.57it/s]

Step 3300: train loss 1.6906, val loss 1.7003, safety loss 2.5140, Q: 26.2400, Model Size: 0.687091 MB


 34%|███▍      | 3403/10000 [08:38<1:11:13,  1.54it/s]

Step 3400: train loss 1.7052, val loss 1.7172, safety loss 2.4949, Q: 26.1191, Model Size: 0.683928 MB


 35%|███▌      | 3503/10000 [08:53<1:09:59,  1.55it/s]

Step 3500: train loss 1.6844, val loss 1.7395, safety loss 2.3692, Q: 25.9984, Model Size: 0.680766 MB


 36%|███▌      | 3603/10000 [09:07<1:08:44,  1.55it/s]

Step 3600: train loss 1.6859, val loss 1.7332, safety loss 2.3751, Q: 25.8778, Model Size: 0.677608 MB


 37%|███▋      | 3702/10000 [09:22<1:33:10,  1.13it/s]

Step 3700: train loss 1.6695, val loss 1.7030, safety loss 2.4719, Q: 25.7573, Model Size: 0.674454 MB


 38%|███▊      | 3802/10000 [09:36<1:31:13,  1.13it/s]

Step 3800: train loss 1.6748, val loss 1.7006, safety loss 2.5898, Q: 25.6370, Model Size: 0.671302 MB


 39%|███▉      | 3903/10000 [09:51<1:05:09,  1.56it/s]

Step 3900: train loss 1.6603, val loss 1.7116, safety loss 2.4186, Q: 25.5167, Model Size: 0.668153 MB


 40%|████      | 4003/10000 [10:05<1:04:25,  1.55it/s]

Step 4000: train loss 1.6686, val loss 1.6841, safety loss 2.3875, Q: 25.3966, Model Size: 0.665008 MB


 41%|████      | 4103/10000 [10:20<1:03:04,  1.56it/s]

Step 4100: train loss 1.6609, val loss 1.6932, safety loss 2.4153, Q: 25.2766, Model Size: 0.661866 MB


 42%|████▏     | 4202/10000 [10:34<1:25:32,  1.13it/s]

Step 4200: train loss 1.6540, val loss 1.6802, safety loss 2.3972, Q: 25.1567, Model Size: 0.658726 MB


 43%|████▎     | 4303/10000 [10:49<1:01:30,  1.54it/s]

Step 4300: train loss 1.6674, val loss 1.6911, safety loss 2.4216, Q: 25.0370, Model Size: 0.655591 MB


 44%|████▍     | 4403/10000 [11:04<1:00:49,  1.53it/s]

Step 4400: train loss 1.6521, val loss 1.7037, safety loss 2.3911, Q: 24.9174, Model Size: 0.652458 MB


 45%|████▌     | 4503/10000 [11:18<58:40,  1.56it/s]  

Step 4500: train loss 1.6519, val loss 1.6808, safety loss 2.3419, Q: 24.7978, Model Size: 0.649328 MB


 46%|████▌     | 4602/10000 [11:33<1:20:47,  1.11it/s]

Step 4600: train loss 1.6403, val loss 1.7100, safety loss 2.2698, Q: 24.6785, Model Size: 0.646202 MB


 47%|████▋     | 4702/10000 [11:47<1:17:02,  1.15it/s]

Step 4700: train loss 1.6382, val loss 1.7135, safety loss 2.3146, Q: 24.5592, Model Size: 0.643079 MB


 48%|████▊     | 4802/10000 [12:02<1:16:32,  1.13it/s]

Step 4800: train loss 1.6392, val loss 1.6811, safety loss 2.3237, Q: 24.4400, Model Size: 0.639958 MB


 49%|████▉     | 4903/10000 [12:16<55:57,  1.52it/s]  

Step 4900: train loss 1.6326, val loss 1.6793, safety loss 2.5153, Q: 24.3210, Model Size: 0.636842 MB


 50%|█████     | 5002/10000 [12:31<1:12:59,  1.14it/s]

Step 5000: train loss 1.6246, val loss 1.6612, safety loss 2.3543, Q: 24.2021, Model Size: 0.633727 MB


 51%|█████     | 5103/10000 [12:46<52:30,  1.55it/s]  

Step 5100: train loss 1.6309, val loss 1.6721, safety loss 2.2656, Q: 24.0833, Model Size: 0.630616 MB


 52%|█████▏    | 5203/10000 [13:00<52:05,  1.54it/s]  

Step 5200: train loss 1.6141, val loss 1.6674, safety loss 2.3390, Q: 23.9646, Model Size: 0.627509 MB


 53%|█████▎    | 5303/10000 [13:15<49:54,  1.57it/s]  

Step 5300: train loss 1.6168, val loss 1.6893, safety loss 2.3468, Q: 23.8460, Model Size: 0.624404 MB


 54%|█████▍    | 5403/10000 [13:29<49:32,  1.55it/s]  

Step 5400: train loss 1.6122, val loss 1.6828, safety loss 2.3054, Q: 23.7276, Model Size: 0.621302 MB


 55%|█████▌    | 5503/10000 [13:44<48:13,  1.55it/s]  

Step 5500: train loss 1.6129, val loss 1.6652, safety loss 2.3669, Q: 23.6093, Model Size: 0.618204 MB


 56%|█████▌    | 5603/10000 [13:59<47:55,  1.53it/s]  

Step 5600: train loss 1.5984, val loss 1.6591, safety loss 2.2984, Q: 23.4910, Model Size: 0.615109 MB


 57%|█████▋    | 5703/10000 [14:13<46:25,  1.54it/s]  

Step 5700: train loss 1.6060, val loss 1.6597, safety loss 2.1663, Q: 23.3729, Model Size: 0.612016 MB


 58%|█████▊    | 5803/10000 [14:28<45:14,  1.55it/s]  

Step 5800: train loss 1.6062, val loss 1.6622, safety loss 2.1942, Q: 23.2550, Model Size: 0.608927 MB


 59%|█████▉    | 5903/10000 [14:42<44:04,  1.55it/s]  

Step 5900: train loss 1.5966, val loss 1.6679, safety loss 2.1832, Q: 23.1371, Model Size: 0.605841 MB


 60%|██████    | 6003/10000 [14:57<43:01,  1.55it/s]

Step 6000: train loss 1.5982, val loss 1.6625, safety loss 2.1935, Q: 23.0194, Model Size: 0.602758 MB


 61%|██████    | 6103/10000 [15:12<41:40,  1.56it/s]

Step 6100: train loss 1.5927, val loss 1.6438, safety loss 2.2206, Q: 22.9018, Model Size: 0.599679 MB


 62%|██████▏   | 6203/10000 [15:26<41:29,  1.53it/s]

Step 6200: train loss 1.5861, val loss 1.6782, safety loss 2.2119, Q: 22.7843, Model Size: 0.596601 MB


 63%|██████▎   | 6302/10000 [15:41<54:03,  1.14it/s]

Step 6300: train loss 1.5875, val loss 1.6718, safety loss 2.2422, Q: 22.6669, Model Size: 0.593528 MB


 64%|██████▍   | 6403/10000 [15:55<38:49,  1.54it/s]

Step 6400: train loss 1.5965, val loss 1.6653, safety loss 2.1818, Q: 22.5497, Model Size: 0.590457 MB


 65%|██████▌   | 6503/10000 [16:10<38:16,  1.52it/s]

Step 6500: train loss 1.5880, val loss 1.6553, safety loss 2.1581, Q: 22.4325, Model Size: 0.587389 MB


 66%|██████▌   | 6602/10000 [16:24<49:45,  1.14it/s]

Step 6600: train loss 1.5855, val loss 1.6498, safety loss 2.1013, Q: 22.3155, Model Size: 0.584325 MB


 67%|██████▋   | 6703/10000 [16:39<35:45,  1.54it/s]

Step 6700: train loss 1.5737, val loss 1.6549, safety loss 2.2664, Q: 22.1986, Model Size: 0.581264 MB


 68%|██████▊   | 6803/10000 [16:54<34:48,  1.53it/s]

Step 6800: train loss 1.5696, val loss 1.6393, safety loss 2.1264, Q: 22.0817, Model Size: 0.578205 MB


 69%|██████▉   | 6902/10000 [17:09<46:22,  1.11it/s]

Step 6900: train loss 1.5707, val loss 1.6512, safety loss 2.0692, Q: 21.9651, Model Size: 0.575150 MB


 70%|███████   | 7002/10000 [17:23<44:39,  1.12it/s]

Step 7000: train loss 1.5672, val loss 1.6468, safety loss 2.2357, Q: 21.8485, Model Size: 0.572098 MB


 71%|███████   | 7102/10000 [17:38<42:49,  1.13it/s]

Step 7100: train loss 1.5679, val loss 1.6411, safety loss 2.1125, Q: 21.7321, Model Size: 0.569048 MB


 72%|███████▏  | 7202/10000 [17:53<41:53,  1.11it/s]

Step 7200: train loss 1.5567, val loss 1.6377, safety loss 2.0745, Q: 21.6158, Model Size: 0.566003 MB


 73%|███████▎  | 7303/10000 [18:07<29:31,  1.52it/s]

Step 7300: train loss 1.5684, val loss 1.6335, safety loss 2.0755, Q: 21.4995, Model Size: 0.562959 MB


 74%|███████▍  | 7403/10000 [18:22<27:52,  1.55it/s]

Step 7400: train loss 1.5515, val loss 1.6539, safety loss 2.1115, Q: 21.3834, Model Size: 0.559919 MB


 75%|███████▌  | 7503/10000 [18:37<27:21,  1.52it/s]

Step 7500: train loss 1.5709, val loss 1.6482, safety loss 2.0199, Q: 21.2675, Model Size: 0.556882 MB


 76%|███████▌  | 7603/10000 [18:51<25:44,  1.55it/s]

Step 7600: train loss 1.5586, val loss 1.6587, safety loss 2.1131, Q: 21.1516, Model Size: 0.553848 MB


 77%|███████▋  | 7703/10000 [19:06<24:44,  1.55it/s]

Step 7700: train loss 1.5535, val loss 1.6304, safety loss 2.1496, Q: 21.0358, Model Size: 0.550817 MB


 78%|███████▊  | 7803/10000 [19:21<23:36,  1.55it/s]

Step 7800: train loss 1.5543, val loss 1.6602, safety loss 2.0554, Q: 20.9202, Model Size: 0.547790 MB


 79%|███████▉  | 7903/10000 [19:35<23:04,  1.51it/s]

Step 7900: train loss 1.5529, val loss 1.6405, safety loss 2.1341, Q: 20.8047, Model Size: 0.544764 MB


 80%|████████  | 8003/10000 [19:50<21:54,  1.52it/s]

Step 8000: train loss 1.5688, val loss 1.6260, safety loss 1.9548, Q: 20.6893, Model Size: 0.541743 MB


 81%|████████  | 8102/10000 [20:05<28:29,  1.11it/s]

Step 8100: train loss 1.5409, val loss 1.6586, safety loss 2.0466, Q: 20.5740, Model Size: 0.538724 MB


 82%|████████▏ | 8203/10000 [20:20<19:23,  1.55it/s]

Step 8200: train loss 1.5424, val loss 1.6425, safety loss 2.0751, Q: 20.4588, Model Size: 0.535708 MB


 83%|████████▎ | 8303/10000 [20:34<18:10,  1.56it/s]

Step 8300: train loss 1.5246, val loss 1.6319, safety loss 1.9700, Q: 20.3438, Model Size: 0.532695 MB


 84%|████████▍ | 8403/10000 [20:49<17:29,  1.52it/s]

Step 8400: train loss 1.5379, val loss 1.6413, safety loss 2.1305, Q: 20.2289, Model Size: 0.529686 MB


 85%|████████▌ | 8503/10000 [21:04<16:35,  1.50it/s]

Step 8500: train loss 1.5321, val loss 1.6316, safety loss 2.1489, Q: 20.1140, Model Size: 0.526678 MB


 86%|████████▌ | 8603/10000 [21:19<15:11,  1.53it/s]

Step 8600: train loss 1.5387, val loss 1.6631, safety loss 2.1267, Q: 19.9993, Model Size: 0.523675 MB


 87%|████████▋ | 8703/10000 [21:33<13:52,  1.56it/s]

Step 8700: train loss 1.5282, val loss 1.6479, safety loss 2.1197, Q: 19.8847, Model Size: 0.520674 MB


 88%|████████▊ | 8803/10000 [21:48<13:02,  1.53it/s]

Step 8800: train loss 1.5240, val loss 1.6382, safety loss 1.9172, Q: 19.7703, Model Size: 0.517676 MB


 89%|████████▉ | 8903/10000 [22:02<11:49,  1.55it/s]

Step 8900: train loss 1.5260, val loss 1.6556, safety loss 1.9717, Q: 19.6559, Model Size: 0.514682 MB


 90%|█████████ | 9003/10000 [22:17<10:36,  1.57it/s]

Step 9000: train loss 1.5280, val loss 1.6483, safety loss 2.0470, Q: 19.5416, Model Size: 0.511690 MB


 91%|█████████ | 9103/10000 [22:32<09:45,  1.53it/s]

Step 9100: train loss 1.5261, val loss 1.6514, safety loss 1.9573, Q: 19.4275, Model Size: 0.508701 MB


 92%|█████████▏| 9203/10000 [22:46<08:28,  1.57it/s]

Step 9200: train loss 1.5208, val loss 1.6439, safety loss 1.9381, Q: 19.3135, Model Size: 0.505716 MB


 93%|█████████▎| 9303/10000 [23:01<07:35,  1.53it/s]

Step 9300: train loss 1.5179, val loss 1.6624, safety loss 2.0066, Q: 19.1996, Model Size: 0.502732 MB


 94%|█████████▍| 9403/10000 [23:16<06:26,  1.54it/s]

Step 9400: train loss 1.5281, val loss 1.6545, safety loss 1.9307, Q: 19.0858, Model Size: 0.499753 MB


 95%|█████████▌| 9503/10000 [23:30<05:27,  1.52it/s]

Step 9500: train loss 1.5185, val loss 1.6579, safety loss 1.9721, Q: 18.9721, Model Size: 0.496776 MB


 96%|█████████▌| 9603/10000 [23:45<04:15,  1.55it/s]

Step 9600: train loss 1.5160, val loss 1.6462, safety loss 1.9976, Q: 18.8585, Model Size: 0.493802 MB


 97%|█████████▋| 9703/10000 [23:59<03:09,  1.57it/s]

Step 9700: train loss 1.5091, val loss 1.6556, safety loss 1.8980, Q: 18.7451, Model Size: 0.490832 MB


 98%|█████████▊| 9803/10000 [24:14<02:08,  1.53it/s]

Step 9800: train loss 1.5183, val loss 1.6721, safety loss 2.0523, Q: 18.6317, Model Size: 0.487864 MB


 99%|█████████▉| 9902/10000 [24:29<01:26,  1.13it/s]

Step 9900: train loss 1.5235, val loss 1.6506, safety loss 1.9608, Q: 18.5185, Model Size: 0.484899 MB


100%|██████████| 10000/10000 [24:43<00:00,  6.74it/s]

Step 9999: train loss 1.5138, val loss 1.6625, safety loss 1.9371, Q: 18.4066, Model Size: 0.481967 MB



