In [34]:
import torch
from torch import nn
import numpy as np
import random

# Seed the torch, NumPy and stdlib RNGs so repeated runs draw the same numbers.
torch.manual_seed(42)
np.random.seed(42)
# NOTE(review): this seed is 44 while the two above are 42 — confirm the
# mismatch is intentional before relying on it for reproducibility.
random.seed(44)

class DeepNN(nn.Module):
    """Fully connected ReLU network producing one scalar per sample.

    The architecture is ``depth`` blocks of ``Linear -> ReLU`` followed by a
    final ``Linear(..., 1)`` head, all stored in ``self.network``
    (an ``nn.Sequential``).

    Args:
        d: Number of input features.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear -> ReLU`` blocks.
        mode: ``'special'`` initialises hidden weights from a zero-mean normal
            with ``std = relu_gain / sqrt(fan_in)`` and the head weights from
            a zero-mean normal with ``std = 0.01``. Any other value uses
            Xavier-uniform weights everywhere. Biases are zero in both cases.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'special'):
        super().__init__()

        # Global side effect: this changes the default dtype for the whole
        # process, not only for this module.
        torch.set_default_dtype(torch.float32)

        layers = []
        prev_dim = d
        for _ in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if mode == 'special':
                # Special initialization as in original code
                gain = nn.init.calculate_gain('relu')
                std = gain / np.sqrt(prev_dim)
                nn.init.normal_(linear.weight, mean=0.0, std=std)
                nn.init.zeros_(linear.bias)
            else:
                # Standard PyTorch initialization
                nn.init.xavier_uniform_(linear.weight)
                nn.init.zeros_(linear.bias)

            layers.extend([
                linear,
                nn.ReLU()
            ])
            prev_dim = hidden_size

        final_layer = nn.Linear(prev_dim, 1)
        if mode == 'special':
            nn.init.normal_(final_layer.weight, std=0.01)
        else:
            nn.init.xavier_uniform_(final_layer.weight)
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and drop the trailing size-1 output axis.

        Only the last axis is squeezed. A bare ``squeeze()`` would also
        remove the batch axis for a single-sample batch, returning a 0-d
        tensor instead of shape ``(1,)`` (and turning a ``(1, k)`` output of
        a widened head into ``(k,)``).
        """
        return self.network(x).squeeze(-1)

def evaluate_model(model, X, y):
    """Return the mean squared error of ``model`` on ``(X, y)`` as a float.

    The model is switched to eval mode (and left in it), and the forward
    pass runs with gradient tracking disabled.
    """
    model.eval()
    with torch.no_grad():
        residual = model(X) - y
        return residual.pow(2).mean().item()

In [121]:
model = DeepNN(30, 800, 4)
n_train = 20000

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code on load.
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU load
# on a machine without one; load_state_dict then copies into the model's
# existing (CPU) parameters.
state_dict = torch.load(
    f'stair_function/results/msp_NN_gpu_test/final_model_h800_d4_n{n_train}.pt',
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(state_dict)

  state_dict = torch.load(f'stair_function/results/msp_NN_gpu_test/final_model_h800_d4_n{n_train}.pt')


<All keys matched successfully>

In [122]:
def count_matching_rows(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> int:
    """
    Counts the number of rows in tensor_a that exist in tensor_b.

    Each row of tensor_a is tested independently, so a row that occurs
    several times in tensor_a and is present in tensor_b is counted once per
    occurrence.

    Args:
        tensor_a (torch.Tensor): Tensor of shape (m, d).
        tensor_b (torch.Tensor): Tensor of shape (n, d).

    Returns:
        int: Number of matching rows (a plain Python int).

    Raises:
        ValueError: If either tensor is not 2D, if the column counts differ,
            or if the dtypes differ.
    """

    # Validate that both tensors are 2D
    if tensor_a.dim() != 2 or tensor_b.dim() != 2:
        raise ValueError("Both tensors must be 2D (matrices).")

    # Validate that both tensors have the same number of columns
    if tensor_a.size(1) != tensor_b.size(1):
        raise ValueError("Both tensors must have the same number of columns (features).")

    # Validate that both tensors have the same data type
    if tensor_a.dtype != tensor_b.dtype:
        raise ValueError("Both tensors must have the same data type.")

    # detach() lets tensors that require grad be converted to NumPy.
    # ascontiguousarray() is needed because the structured view below only
    # works when each row is laid out contiguously in memory (a transposed
    # or sliced tensor is not).
    a_np = np.ascontiguousarray(tensor_a.detach().cpu().numpy())
    b_np = np.ascontiguousarray(tensor_b.detach().cpu().numpy())

    # View each row as a single record by creating a structured dtype with
    # one field per column, so whole rows can be compared as single items.
    row_dtype = [('', a_np.dtype)] * a_np.shape[1]
    a_view = a_np.view(row_dtype)
    b_view = b_np.view(row_dtype)

    # For each row record in a_view, test membership in b_view
    matches = np.isin(a_view, b_view)

    # np.sum returns a NumPy integer; convert to honour the annotated type
    return int(np.sum(matches))

In [123]:
import pickle

train_data_path = f'stair_function/results/no_overlap/train_data_h800_d4_n{n_train}_lr0.001_standard.pkl'
test_data_path = 'stair_function/results/no_overlap/test_data.pkl'

# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open(train_data_path, 'rb') as train_file:
    X_train, y_train = pickle.load(train_file)

with open(test_data_path, 'rb') as test_file:
    X_test, y_test = pickle.load(test_file)

# Leakage check: number of training rows that also appear in the test set.
int(count_matching_rows(X_train, X_test))

0

In [105]:
# MSE of the loaded full-precision model on the train split, then the test split.
train_error, test_error = (
    evaluate_model(model, inputs, targets)
    for inputs, targets in ((X_train, y_train), (X_test, y_test))
)

print(f'Train error: {train_error:.4f}')
print(f'Test error: {test_error:.4f}')

Train error: 0.0001
Test error: 0.0000


In [106]:
import torch
import torch.nn as nn
import copy

def pad_final_linear_layer(model, target_output_features=16):
    """
    Pads the final nn.Linear layer's output features to the target number by adding dummy outputs with zero weights.

    The input model is not modified; a deep copy is padded and returned. The
    original outputs keep their positions (indices ``0 .. old_out - 1``) and
    the added outputs have zero weight and zero bias.

    The "final" layer is the last ``nn.Linear`` yielded by
    ``model.modules()``; this is assumed to be the output head.

    Args:
        model (nn.Module): The trained PyTorch model.
        target_output_features (int): Desired number of output features. Must
            be strictly greater than the layer's current output width. No
            divisibility constraint is enforced here; choose a value that
            suits the downstream consumer.

    Returns:
        nn.Module: The modified model with the padded final linear layer.

    Raises:
        ValueError: If the model contains no ``nn.Linear`` layer, or if the
            final layer already has at least ``target_output_features``
            outputs.
    """
    # Deep copy the model to avoid in-place modifications
    model_padded = copy.deepcopy(model)

    # Find the last nn.Linear layer
    linear_layers = [module for module in model_padded.modules() if isinstance(module, nn.Linear)]
    if not linear_layers:
        raise ValueError("No nn.Linear layer found in the model.")

    final_layer = linear_layers[-1]

    current_out_features = final_layer.out_features
    in_features = final_layer.in_features

    if current_out_features >= target_output_features:
        raise ValueError(f"Current output features ({current_out_features}) >= target ({target_output_features}).")

    # Calculate number of dummy outputs to add
    num_to_add = target_output_features - current_out_features

    # Pad the weights with zeros
    with torch.no_grad():
        # Create a tensor of zeros for the new weights
        dummy_weights = torch.zeros((num_to_add, in_features), dtype=final_layer.weight.dtype, device=final_layer.weight.device)
        # Concatenate the existing weights with the dummy weights
        final_layer.weight = nn.Parameter(torch.cat([final_layer.weight, dummy_weights], dim=0))

        if final_layer.bias is not None:
            # Create a tensor of zeros for the new biases
            dummy_bias = torch.zeros(num_to_add, dtype=final_layer.bias.dtype, device=final_layer.bias.device)
            # Concatenate the existing biases with the dummy biases
            final_layer.bias = nn.Parameter(torch.cat([final_layer.bias, dummy_bias], dim=0))

    # Keep the layer's metadata in sync with the enlarged weight; otherwise
    # out_features (and the module repr) would still report the old width.
    final_layer.out_features = target_output_features

    return model_padded

# Widen the scalar head to 16 outputs, then display the head's weight shape.
# 'network.8' is the head for this depth-4 model (4 x [Linear, ReLU] come first).
padded_model = pad_final_linear_layer(model, 16)
padded_state = padded_model.state_dict()
padded_state['network.8.weight'].shape

torch.Size([16, 800])

In [107]:
import copy

from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    int4_weight_only,
)
from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
from torchao.quantization.autoquant import AUTOQUANT_CACHE

def quantize_fp16(model):
    """Return a deep copy of ``model`` cast to half precision.

    The input model is left unchanged.
    """
    return copy.deepcopy(model).half()

def quantize_int8_weight_only(model):
    """Return a deep copy of ``model`` quantized with torchao's ``int8_weight_only``.

    ``quantize_`` modifies its argument in place, so it is applied to the
    copy and the input model is left unchanged.
    """
    quantized_copy = copy.deepcopy(model)
    quantize_(quantized_copy, int8_weight_only())
    return quantized_copy

def quantize_int8_dynamic(model):
    """Return a deep copy of ``model`` quantized with torchao's
    ``int8_dynamic_activation_int8_weight``.

    ``quantize_`` modifies its argument in place, so it is applied to the
    copy and the input model is left unchanged.
    """
    quantized_copy = copy.deepcopy(model)
    quantize_(quantized_copy, int8_dynamic_activation_int8_weight())
    return quantized_copy

def quantize_int4_weight_only(model):
    """Return a bfloat16 deep copy of ``model`` quantized with torchao's
    ``int4_weight_only`` (group_size=32, HQQ disabled).

    The copy is made from the ``model`` argument; the input model is left
    unchanged. Callers are responsible for passing a model whose layer
    shapes the int4 path accepts (this notebook passes the padded model).
    """
    model_int4_wo = copy.deepcopy(model).to(torch.bfloat16)
    quantize_(model_int4_wo, int4_weight_only(group_size=32, use_hqq=False))  # Adjust parameters as needed
    return model_int4_wo

fp16_model = quantize_fp16(model)
int8_wo_model = quantize_int8_weight_only(model)
int8_model = quantize_int8_dynamic(model)
# The int4 variant is built from the padded model (16-wide head), passed
# explicitly rather than picked up as a global inside the function.
int4_wo_model = quantize_int4_weight_only(padded_model)


In [108]:
import os

def print_size_of_model(model, label=""):
    """Print the on-disk size, in kilobytes, of ``model``'s serialized state_dict.

    The state_dict is saved with ``torch.save`` into a private temporary
    directory, so an existing ``temp.p`` in the working directory is never
    overwritten or deleted, and the scratch file is removed even if saving
    or measuring fails.

    Args:
        model: Module whose ``state_dict()`` is measured.
        label: Name printed next to the size.
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before, just in an isolated directory
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        size = os.path.getsize(scratch_path)
    print("model: ",label,' \t','Size (KB):', size/1e3)

# Report the serialized size of every variant, full precision first.
for size_label, sized_model in (
    ("fp32", model),
    ("fp16", fp16_model),
    ("int8_wo", int8_wo_model),
    ("int8", int8_model),
    ("int4_wo", int4_wo_model),
):
    print_size_of_model(sized_model, size_label)

model:  fp32  	 Size (KB): 7795.278
model:  fp16  	 Size (KB): 3899.214
model:  int8_wo  	 Size (KB): 2002.706
model:  int8  	 Size (KB): 2003.474
model:  int4_wo  	 Size (KB): 1606.006


In [109]:
# Variants evaluated here all take the float32 inputs as loaded; fp16 and
# int4 are evaluated separately with cast inputs.
models = {'fp32': model, 'int8_wo': int8_wo_model, 'int8': int8_model}

for variant_name, variant in models.items():
    print(f'Model: {variant_name}')
    train_error = evaluate_model(variant, X_train, y_train)
    test_error = evaluate_model(variant, X_test, y_test)
    print(f'Train error: {train_error:.4f}')
    print(f'Test error: {test_error:.4f}')

Model: fp32
Train error: 0.0001
Test error: 0.0000
Model: int8_wo
Train error: 0.0005
Test error: 0.0005
Model: int8
Train error: 0.0017
Test error: 0.0018


In [110]:
print('Model: fp16')
# Inputs and targets are cast to float16 to match the half-precision weights.
train_error, test_error = (
    evaluate_model(fp16_model, inputs.half(), targets.half())
    for inputs, targets in ((X_train, y_train), (X_test, y_test))
)
print(f'Train error: {train_error:.4f}')
print(f'Test error: {test_error:.4f}')

Model: fp16
Train error: 0.0001
Test error: 0.0000


In [111]:
def evaluate_int4(model, X, y):
    """Return the mean squared error between column 0 of ``model(X)`` and ``y``.

    Only the first output column is scored; any further columns are ignored.
    The model is switched to eval mode (and left in it) and the forward pass
    runs with gradient tracking disabled.
    """
    model.eval()
    with torch.no_grad():
        first_output = model(X)[:, 0]
        squared_error = (first_output - y) ** 2
        return torch.mean(squared_error).item()


print('Model: int4_wo')
# The int4 model was built from a bfloat16 copy, so inputs and targets are
# cast to bfloat16 before evaluation.
X_train_bf16, y_train_bf16 = X_train.to(torch.bfloat16), y_train.to(torch.bfloat16)
X_test_bf16, y_test_bf16 = X_test.to(torch.bfloat16), y_test.to(torch.bfloat16)
train_error = evaluate_int4(int4_wo_model, X_train_bf16, y_train_bf16)
test_error = evaluate_int4(int4_wo_model, X_test_bf16, y_test_bf16)
print(f'Train error: {train_error:.4f}')
print(f'Test error: {test_error:.4f}')

Model: int4_wo
Train error: 0.0245
Test error: 0.0236
