In [1]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
token = user_secrets.get_secret("github_repos_wildcard")

In [2]:
repo_url = f"https://{token}@github.com/gaserSami/panther.git"
branch = "autotuner"

In [3]:
!git clone -b {branch} {repo_url}

fatal: destination path 'panther' already exists and is not an empty directory.


In [4]:
# First uninstall existing torch, torchvision, torchaudio
!pip uninstall -y torch torchvision torchaudio

# Install the specified versions from PyTorch's official CUDA 12.4 wheels
!pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.6.0+cu124
  Using cached https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (28 kB)
Collecting torchvision==0.21.0+cu124
  Using cached https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio==2.6.0+cu124
  Using cached https://download.pytorch.org/whl/cu124/torchaudio-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Using cached https://download.pytorch.org/whl/cu124/torch

In [5]:
!mv panther Panther

In [6]:
import os

src = '/kaggle/working/Panther'
dst = '/kaggle/working/panther'

# Simple rename
os.rename(src, dst)

In [7]:
%%writefile /kaggle/working/panther/pawX/setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="pawX",
    ext_modules=[
        CUDAExtension(
            name="pawX",
            sources=[
                "skops.cpp",
                "bindings.cpp",
                "linear.cpp",
                "linear_cuda.cu",
                "cqrrpt.cpp",
                "rsvd.cpp",
                "attention.cpp",
            ],
            # Use system includes and libraries
            include_dirs=["/usr/include/x86_64-linux-gnu"],
            library_dirs=[],
            libraries=["openblas"],
            extra_compile_args={"cxx": ["-O2", "-fopenmp"], "nvcc": ["-O2"]},
            extra_link_args=["-llapacke", "-lopenblas"]
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

Overwriting /kaggle/working/panther/pawX/setup.py


In [8]:
!sudo apt-get install liblapacke-dev

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
liblapacke-dev is already the newest version (3.10.0-2ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 122 not upgraded.


In [9]:
!cd /kaggle/working/panther/pawX; python setup.py install
!cd /kaggle/working/panther/pawX; pip install --no-build-isolation -e .

!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Emitting ninja build file /kaggle/working/pan

In [10]:
import torch
print(torch.__version__)
import triton
print(triton.__version__)

2.6.0+cu124
3.2.0


In [11]:
import os
os.chdir("/kaggle/working/panther")

In [12]:
!pwd

/kaggle/working/panther


In [13]:
%%writefile /kaggle/working/panther/panther/nn/conv2d.py
import math
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn import functional as F
from torch.nn import init

from panther.random import scaled_sign_sketch as gen_U


def mode4_unfold(tensor: torch.Tensor) -> torch.Tensor:
    """Computes mode-4 matricization (unfolding along the last dimension)."""
    return tensor.reshape(-1, tensor.shape[-1])  # (I4, I1 * I2 * I3)


class SketchedConv2dFunction(Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(
        input: torch.Tensor,
        S1s: torch.Tensor,
        S2s: torch.Tensor,
        U1s: torch.Tensor,
        U2s: torch.Tensor,
        stride: Tuple[int, int],
        padding: Tuple[int, int],
        kernelSize: Tuple[int, int],
        inshape,
        bias: torch.Tensor,
    ):
        # in_channels, height, width = input.shape
        _, dout = U1s[0].shape
        hout = (inshape[2] + 2 * padding[0] - kernelSize[0]) // stride[0] + 1
        wout = (inshape[3] + 2 * padding[1] - kernelSize[1]) // stride[1] + 1
        input.transpose_(1, 2)
        t = (
            torch.einsum("nab,lbc,lcd->nlad", input, S1s, U1s)
            + torch.einsum("nab,lbc,lcd->nlad", input, U2s.transpose(1, 2), S2s)
        ).mean(dim=1)
        t = t.view(inshape[0], dout, hout, wout)
        return t + bias.view(1, dout, 1, 1)

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any):
        input, S1s, S2s, U1s, U2s, stride, padding, kernelSize, inshap, bias = inputs
        ctx.save_for_backward(
            input,
            S1s,
            S2s,
            U1s,
            U2s,
            torch.tensor(stride),
            torch.tensor(padding),
            torch.tensor(kernelSize),
            torch.tensor(inshap),
            bias,
        )

    @staticmethod
    def backward(ctx: Any, *grad_output: Any) -> Any:
        input, S1s, S2s, U1s, U2s, stride, padding, kernelSize, inshape, bias = (
            ctx.saved_tensors
        )
        input.transpose_(1, 2)
        num_terms, _, __ = S2s.shape
        hout = grad_output[0].shape[2]
        wout = grad_output[0].shape[3]
        g_bias = grad_output[0].sum(dim=(0, 2, 3))
        grad_output = grad_output[0].view(
            grad_output[0].shape[0],
            hout * wout,
            grad_output[0].shape[1],
        )
        grad_output /= 2 * num_terms
        g_S1s = torch.zeros_like(S1s)
        g_S2s = torch.zeros_like(S2s)
        g_S1s = torch.einsum(
            "nab,nbc,lcd->lad", input, grad_output, U1s.transpose(1, 2)
        )
        g_S2s = torch.einsum("lab,nbc,ncd->lad", U2s, input, grad_output)
        gout = torch.einsum(
            "nab,lbc,lcd->nad", grad_output, U1s.transpose(1, 2), S1s.transpose(1, 2)
        ) + torch.einsum("nab,lbc,lcd->nad", grad_output, S2s.transpose(1, 2), U2s)
        fold = nn.Fold(
            output_size=(inshape[2], inshape[3]),
            kernel_size=(kernelSize[0], kernelSize[1]),
            stride=stride,
            padding=padding,
        )
        gout = gout.transpose(1, 2)
        gout = fold(gout)

        return (gout, g_S1s, g_S2s, None, None, None, None, None, None, g_bias)


class SKConv2d(nn.Module):
    __constants__ = ["in_features", "out_features", "num_terms", "low_rank"]
    in_features: int
    out_features: int
    num_terms: int
    low_rank: int
    S1s: torch.Tensor
    S2s: torch.Tensor
    U1s: torch.Tensor
    U2s: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple = (3, 3),
        stride: Tuple = (1, 1),
        padding: Tuple = (1, 1),
        num_terms: int = 6,
        low_rank: int = 8,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super(SKConv2d, self).__init__()
        self.num_terms = num_terms
        self.low_rank = low_rank
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.kernel_size = (
            kernel_size
            if isinstance(kernel_size, tuple)
            else (kernel_size, kernel_size)
        )
        self.register_buffer(
            "U1s",
            torch.stack(
                [
                    gen_U(low_rank, out_channels, **factory_kwargs)
                    for _ in range(num_terms)
                ]
            ),
        )  # kxd1
        self.register_buffer(
            "U2s",
            torch.stack(
                [
                    gen_U(
                        low_rank * self.kernel_size[0] * self.kernel_size[1],
                        in_channels * self.kernel_size[0] * self.kernel_size[1],
                        **factory_kwargs,
                    )
                    for _ in range(num_terms)
                ]
            ),
        )  # k h w x d2 h w
        kernels = nn.Parameter(
            torch.empty(
                (in_channels, *self.kernel_size, out_channels), **factory_kwargs
            )
        )  # doutxdinxhxw
        init.kaiming_uniform_(kernels, a=math.sqrt(5))
        self.S1s = nn.Parameter(
            torch.stack(
                [
                    mode4_unfold(torch.matmul(kernels, self.U1s[i].T))
                    for i in range(num_terms)
                ]
            )
        )  # d2xk
        K_mat4 = kernels.view(
            in_channels * self.kernel_size[0] * self.kernel_size[1], out_channels
        )
        self.S2s = nn.Parameter(
            torch.stack(
                [
                    mode4_unfold(
                        torch.matmul(self.U2s[i], K_mat4).view(
                            low_rank, *self.kernel_size, out_channels
                        )
                    )
                    for i in range(num_terms)
                ]
            )
        )  #
        self.bias = nn.Parameter(torch.empty(out_channels, **factory_kwargs))
        fan_in, _ = init._calculate_fan_in_and_fan_out(kernels)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

        # Register U1s and U2s as buffers since they are not learnable

    def forward(self, x):
        """Forward pass of the SKConv2d layer."""
        # padd x
        B, C, H, W = x.shape
        if self.padding[0] > 0 or self.padding[1] > 0:
            x = F.pad(
                x, (self.padding[1], self.padding[1], self.padding[0], self.padding[0])
            )
        H_out = (x.shape[2] - self.kernel_size[0]) // self.stride[0] + 1
        W_out = (x.shape[3] - self.kernel_size[1]) // self.stride[1] + 1
        x_strided = x.as_strided(
            size=(
                x.shape[0],
                x.shape[1],
                H_out,
                W_out,
                self.kernel_size[0],
                self.kernel_size[1],
            ),
            stride=(
                x.stride(0),
                x.stride(1),
                x.stride(2) * self.stride[0],
                x.stride(3) * self.stride[1],
                x.stride(2),
                x.stride(3),
            ),
        )
        x_windows = x_strided.permute(0, 2, 3, 1, 4, 5)

        x_windows = x_windows.reshape(
            -1, self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        )
        out1 = (
            torch.einsum("nd,tdr,tro->no", x_windows, self.S1s, self.U1s)
            / self.num_terms
        ) + self.bias
        out2 = (
            torch.einsum(
                "nd,tdr,tro->no", x_windows, self.U2s.transpose(1, 2), self.S2s
            )
            / self.num_terms
        )
        return (
            (out1 + out2 + self.bias)
            .view(B, H_out, W_out, self.out_channels)
            .permute(0, 3, 1, 2)
        )

Overwriting /kaggle/working/panther/panther/nn/conv2d.py


In [14]:
import torch
from transformers import BertForMaskedLM, BertTokenizer

# 1. Load pretrained tokenizer & model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model     = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

# 2. Tokenize a sentence with a mask
text = "Machine learning is the future of [MASK]."
inputs = tokenizer(text, return_tensors="pt")

# 3. Forward pass: yields logits over the full vocab for each position
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits  # shape: (1, seq_len, vocab_size)

# 4. Locate the [MASK] position
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
batch_idx, token_idx = mask_token_index

# 5. Extract the logits for that position and pick the highest-scoring token
mask_logits = logits[batch_idx, token_idx, :]
predicted_token_id = mask_logits.argmax(dim=-1).item()
predicted_token = tokenizer.decode([predicted_token_id])

print(f"Filled mask: {predicted_token}")  # e.g. “technology”

2025-05-06 15:20:26.369298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746544826.393841     337 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746544826.400921     337 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForP

Filled mask: science


In [15]:
!pip install botorch



In [16]:
%%writefile /kaggle/working/panther/panther/random.py
import torch

# DISCLAIMER: THIS FILE NEEDS TO BE CHECKED FOR CORRECTNESS


def uniform_dense_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.empty(m, n, **factory_kwargs).uniform_(-1, 1)


def gaussian_dense_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.randn(m, n, **factory_kwargs)


def hadamard_sketch(m, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    if m & (m - 1) != 0:
        raise ValueError("m must be a power of 2")

    H = torch.tensor([[1.0]])
    while H.shape[0] < m:
        H = torch.cat((torch.cat((H, H), dim=1), torch.cat((H, -H), dim=1)), dim=0)

    return H / torch.sqrt(torch.tensor(m, **factory_kwargs))


def gaussian_orthonormal_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return torch.qr(torch.randn(m, n, **factory_kwargs))[0]


def scaled_sign_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    return (torch.randint(0, 2, (m, n), **factory_kwargs) * 2 - 1) / torch.sqrt(
        torch.tensor(m, **factory_kwargs)
    )


def clarkson_woodruff_sketch(m, n, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    indices = torch.randint(0, m, (n,), **factory_kwargs)
    signs = torch.randint(0, 2, (n,), **factory_kwargs) * 2 - 1
    sketch = torch.zeros(m, n, **factory_kwargs)
    sketch[indices, torch.arange(n)] = signs
    return sketch


def sparse_sign_embeddings_sketch(m, n, sparsity=0.1):
    mask = torch.rand(m, n) < sparsity
    signs = torch.randint(0, 2, (m, n)) * 2 - 1
    return mask.float() * signs.float()


Overwriting /kaggle/working/panther/panther/random.py


In [17]:
!pip install botorch



In [18]:
# Import components
from panther.tuner.SkAutoTuner import *



In [19]:
# ModelVisualizer.print_module_tree(model)

# the normal without changing bert

In [20]:
# import os
# import time
# import copy
# import torch
# import numpy as np
# from torch.utils.data import DataLoader, Dataset
# from tqdm import tqdm
# from transformers import BertForMaskedLM, BertTokenizer

# # Import components
# from panther.tuner.SkAutoTuner import (
#     SKAutoTuner, 
#     LayerConfig, 
#     TuningConfigs,
#     GridSearch,
#     RandomSearch, 
#     ModelVisualizer
# )

# # Setting up device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# ##################################### HELPERS #######################################

# def dump_tensor_info(tensor, name="Tensor"):
#     """Print details about a tensor"""
#     print(f"{name}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
#     print(f"  - Values: min={tensor.min().item():.4f}, max={tensor.max().item():.4f}, mean={tensor.mean().item():.4f}")
#     print(f"  - First few values: {tensor.flatten()[:5]}")

# def measure_time(func, *args, n_runs=20, warmup=5):
#     """Measure execution time of a function"""
#     # Warmup
#     for _ in range(warmup):
#         func(*args)
    
#     # Timed runs
#     torch.cuda.synchronize() if torch.cuda.is_available() else None
#     start = time.time()
#     for _ in range(n_runs):
#         func(*args)
#         torch.cuda.synchronize() if torch.cuda.is_available() else None
#     end = time.time()
    
#     return (end - start) / n_runs

# def measure_memory(model, input_tensor):
#     """Measure peak memory usage of a model during inference"""
#     if not torch.cuda.is_available():
#         return 0  # Cannot measure CUDA memory on CPU
    
#     # Clear cache
#     torch.cuda.empty_cache()
#     torch.cuda.reset_peak_memory_stats()
    
#     # Run inference
#     with torch.no_grad():
#         model(**input_tensor)
    
#     # Get peak memory
#     return torch.cuda.max_memory_allocated() / (1024 * 1024)  # Convert to MB

# class MaskedTextDataset(Dataset):
#     """Dataset for masked language modeling"""
#     def __init__(self, texts, tokenizer, max_length=128):
#         self.texts = texts
#         self.tokenizer = tokenizer
#         self.max_length = max_length
        
#     def __len__(self):
#         return len(self.texts)
    
#     def __getitem__(self, idx):
#         text = self.texts[idx]
#         encoding = self.tokenizer(
#             text,
#             return_special_tokens_mask=True,
#             max_length=self.max_length,
#             padding="max_length",
#             truncation=True,
#             return_tensors="pt"
#         )
        
#         # Create input_ids with masks
#         input_ids = encoding.input_ids.clone().squeeze(0)
#         special_tokens_mask = encoding.special_tokens_mask.squeeze(0).bool()
        
#         # Create labels (clone of input_ids)
#         labels = input_ids.clone()
        
#         # Find positions eligible for masking (not special tokens)
#         mask_positions = (~special_tokens_mask).nonzero(as_tuple=True)[0]
        
#         # Randomly mask 15% of eligible tokens
#         num_to_mask = max(1, int(0.15 * len(mask_positions)))
#         mask_indices = np.random.choice(mask_positions.tolist(), size=num_to_mask, replace=False)
#         input_ids[mask_indices] = self.tokenizer.mask_token_id
        
#         # Create attention mask
#         attention_mask = encoding.attention_mask.squeeze(0)
        
#         # Create return dictionary
#         batch = {
#             "input_ids": input_ids,
#             "attention_mask": attention_mask,
#             "labels": labels
#         }
        
#         return batch

# def evaluate_model(model, dataloader):
#     """Evaluate model accuracy on a dataset"""
#     model.eval()
#     total_loss = 0
#     total_samples = 0
    
#     with torch.no_grad():
#         for batch in dataloader:
#             # Move batch to device
#             batch = {k: v.to(device) for k, v in batch.items()}
            
#             # Forward pass
#             outputs = model(**batch)
#             loss = outputs.loss
            
#             # Accumulate statistics
#             batch_size = batch["input_ids"].size(0)
#             total_loss += loss.item() * batch_size
#             total_samples += batch_size
    
#     return total_loss / total_samples

# def get_data():
#     """Prepare dataset for BERT testing"""
#     print("Preparing BERT test dataset...")
    
#     # Sample texts for testing
#     texts = [
#         "Machine learning is the study of computer algorithms that improve automatically through experience.",
#         "Deep learning is part of a broader family of machine learning methods based on artificial neural networks.",
#         "Natural language processing is a subfield of linguistics, computer science, and artificial intelligence.",
#         "Transformers have emerged as a powerful deep learning architecture for natural language processing tasks.",
#         "BERT is a transformer-based machine learning technique for natural language processing pre-training."
#     ]
    
#     # Add more texts to the dataset for more robust testing
#     more_texts = [
#         "The transformer architecture uses self-attention mechanisms to process sequential data effectively.",
#         "Pre-trained language models can be fine-tuned on specific downstream tasks with less data.",
#         "Language model pre-training has resulted in significant advances in many natural language tasks.",
#         "Transfer learning enables models to leverage knowledge from one domain to perform well in another.",
#         "Masked language modeling is a self-supervised technique to train language models."
#     ]
#     texts.extend(more_texts)
    
#     # Tokenizer
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    
#     # Create dataset
#     dataset = MaskedTextDataset(texts, tokenizer)
    
#     # Create data loader
#     dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    
#     # Create a single batch for memory testing
#     memory_batch = {k: v.to(device) for k, v in next(iter(dataloader)).items()}
#     memory_batch = {k: v.repeat(4, 1) for k, v in memory_batch.items()}  # Make batch size 4
    
#     return tokenizer, dataloader, memory_batch

# def fill_mask_test(model, tokenizer, text="The capital of France is [MASK]."):
#     """Test mask filling capability"""
#     # Replace [MASK] with actual mask token if needed
#     if "[MASK]" in text:
#         text = text.replace("[MASK]", tokenizer.mask_token)
    
#     # Tokenize
#     inputs = tokenizer(text, return_tensors="pt").to(device)
    
#     # Find mask token position
#     mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
    
#     # Forward pass
#     with torch.no_grad():
#         outputs = model(**inputs)
#     logits = outputs.logits
    
#     # Get predictions for mask position
#     if len(mask_token_index[0]) > 0:
#         batch_idx, token_idx = mask_token_index
#         mask_logits = logits[batch_idx, token_idx, :]
        
#         # Get top 5 predictions
#         topk_values, topk_indices = torch.topk(mask_logits, 5, dim=1)
        
#         # Convert to tokens
#         topk_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in topk_indices[0]]
        
#         return topk_tokens
#     else:
#         return ["No mask token found"]

# def test_bert_optimization():
#     """Test SKAutoTuner on BERT model's linear layers"""
    
#     # Load pre-trained model
#     model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)
#     model.eval()
    
#     # Create a copy of the model for reference
#     orig_model = copy.deepcopy(model)
    
#     # Get data for testing
#     tokenizer, val_loader, memory_batch = get_data()
    
#     print("\n===== Original Model Structure =====")
#     ModelVisualizer.print_module_tree(model)
    
#     # Create an evaluation function for the model
#     def acc_eval_func(model):
#         """Evaluation function based on MLM loss"""
#         loss = evaluate_model(model, val_loader)
#         # Convert loss to accuracy-like score (higher is better)
#         # Original model's loss is used as reference
#         orig_loss = evaluate_model(orig_model, val_loader)
        
#         # Score based on relative loss difference (higher is better)
#         score = (orig_loss - loss) / orig_loss
        
#         print(f"MLM Loss: {loss:.4f} (original: {orig_loss:.4f}, relative diff: {score:.4f})")
#         return score
    
#     # Create a separate speed evaluation function
#     def speed_eval_func(model):
#         """Speed evaluation function"""
#         def infer(model, inputs):
#             with torch.no_grad():
#                 return model(**inputs)
        
#         # Higher is better (inverse of time)
#         throughput = 1.0 / measure_time(infer, model, memory_batch, n_runs=10)
#         print(f"Inference speed: {throughput:.2f} samples/sec")
#         return throughput
    
#     # Test original model performance
#     print("\nBaseline BERT model:")
#     baseline_loss = evaluate_model(model, val_loader)
#     baseline_speed = speed_eval_func(orig_model)
#     baseline_memory = measure_memory(orig_model, memory_batch)
    
#     print(f"MLM Loss: {baseline_loss:.4f}")
#     print(f"Baseline model memory usage: {baseline_memory:.2f} MB")
#     print(f"Baseline model speed: {baseline_speed:.2f} samples/sec")
    
#     # Test mask filling capability
#     print("\nTesting mask filling on original model:")
#     test_sentence = "The capital of France is [MASK]."
#     predictions = fill_mask_test(model, tokenizer, test_sentence)
#     print(f"Sentence: {test_sentence}")
#     print(f"Top 5 predictions: {predictions}")
    
#     # Strategy 1: Optimizing decoder linear layer in BertMLMHead
#     print("\n===== Strategy 1: Optimizing decoder linear layer =====")
    
#     # Create configs to tune the decoder layer in MLM head
#     configs_strategy1 = TuningConfigs([
#         LayerConfig(
#             # Target the decoder linear layer in the MLM head
#             layer_names={"pattern": "cls.predictions.decoder"},
#             params={
#                 "num_terms": [1, 2, 3],
#                 "low_rank": [16, 32, 64, 128],
#             },
#             separate=True
#         ),
#     ])
    
#     # Calculate accuracy threshold
#     accuracy_threshold = -0.05  # Allow 5% reduction in accuracy
#     print(f"Setting accuracy threshold to {accuracy_threshold:.4f}")
    
#     # Create tuner focusing on the decoder linear layer
#     tuner_strategy1 = SKAutoTuner(
#         model=copy.deepcopy(model),
#         configs=configs_strategy1,
#         accuracy_eval_func=acc_eval_func,
#         search_algorithm=GridSearch(),
#         verbose=True,
#         accuracy_threshold=accuracy_threshold,
#         optmization_eval_func=speed_eval_func
#     )
    
#     # Run tuning
#     print("\nRunning decoder layer tuning...")
#     best_params = tuner_strategy1.tune()
#     print(f"Best parameters: {best_params}")
    
#     # Apply best parameters
#     tuned_model_strategy1 = tuner_strategy1.apply_best_params()
    
#     print("\n===== Tuned Model Structure (Strategy 1) =====")
#     ModelVisualizer.print_module_tree(tuned_model_strategy1)
    
#     # Test the tuned model
#     print("\nEvaluating decoder-tuned model:")
#     final_loss = evaluate_model(tuned_model_strategy1, val_loader)
#     final_speed = speed_eval_func(tuned_model_strategy1)
#     final_memory = measure_memory(tuned_model_strategy1, memory_batch)
    
#     print(f"MLM Loss: {final_loss:.4f} (original: {baseline_loss:.4f})")
#     print(f"Speed: {final_speed:.2f} samples/sec (original: {baseline_speed:.2f})")
#     print(f"Memory: {final_memory:.2f} MB (original: {baseline_memory:.2f})")
    
#     # Strategy 2: Optimizing transform dense layer in BertMLMHead
#     print("\n===== Strategy 2: Optimizing transform dense layer =====")
    
#     # Create configs to tune the dense layer in MLM head's transform module
#     configs_strategy2 = TuningConfigs([
#         LayerConfig(
#             # Target the dense linear layer in the transform module
#             layer_names={"pattern": "cls.predictions.transform.dense"},
#             params={
#                 "num_terms": [1, 2],
#                 "low_rank": [16, 32, 64, 128],
#             },
#             separate=True
#         ),
#     ])
    
#     # Create tuner focusing on the transform dense layer
#     tuner_strategy2 = SKAutoTuner(
#         model=copy.deepcopy(model),
#         configs=configs_strategy2,
#         accuracy_eval_func=acc_eval_func,
#         search_algorithm=GridSearch(),
#         verbose=True,
#         accuracy_threshold=accuracy_threshold,
#         optmization_eval_func=speed_eval_func
#     )
    
#     # Run tuning
#     print("\nRunning transform dense layer tuning...")
#     best_params = tuner_strategy2.tune()
#     print(f"Best parameters: {best_params}")
    
#     # Apply best parameters
#     tuned_model_strategy2 = tuner_strategy2.apply_best_params()
    
#     print("\n===== Tuned Model Structure (Strategy 2) =====")
#     ModelVisualizer.print_module_tree(tuned_model_strategy2)
    
#     # Test the tuned model
#     print("\nEvaluating transform-tuned model:")
#     final_loss = evaluate_model(tuned_model_strategy2, val_loader)
#     final_speed = speed_eval_func(tuned_model_strategy2)
#     final_memory = measure_memory(tuned_model_strategy2, memory_batch)
    
#     print(f"MLM Loss: {final_loss:.4f} (original: {baseline_loss:.4f})")
#     print(f"Speed: {final_speed:.2f} samples/sec (original: {baseline_speed:.2f})")
#     print(f"Memory: {final_memory:.2f} MB (original: {baseline_memory:.2f})")
    
#     # Strategy 3: Optimizing both linear layers in the MLM head
#     print("\n===== Strategy 3: Optimizing both MLM head linear layers =====")
    
#     # Create configs to tune both linear layers together
#     configs_strategy3 = TuningConfigs([
#         LayerConfig(
#             # Target both linear layers in the MLM head
#             layer_names=[
#                 "cls.predictions.decoder",
#                 "cls.predictions.transform.dense"
#             ],
#             params={
#                 "num_terms": [1, 2],
#                 "low_rank": [16, 32, 64, 128],
#             },
#             separate=False  # Tune as a group
#         ),
#     ])
    
#     # Create tuner for both layers together
#     tuner_strategy3 = SKAutoTuner(
#         model=copy.deepcopy(model),
#         configs=configs_strategy3,
#         accuracy_eval_func=acc_eval_func,
#         search_algorithm=GridSearch(),  # Use random search with limited trials
#         verbose=True,
#         accuracy_threshold=accuracy_threshold,
#         optmization_eval_func=speed_eval_func
#     )
    
#     # Run tuning
#     print("\nRunning combined MLM head layers tuning...")
#     best_params = tuner_strategy3.tune()
#     print(f"Best parameters: {best_params}")
    
#     # Apply best parameters
#     tuned_model_strategy3 = tuner_strategy3.apply_best_params()
    
#     print("\n===== Tuned Model Structure (Strategy 3) =====")
#     ModelVisualizer.print_module_tree(tuned_model_strategy3)
    
#     # Test the tuned model
#     print("\nEvaluating combined-tuning model:")
#     final_loss = evaluate_model(tuned_model_strategy3, val_loader)
#     final_speed = speed_eval_func(tuned_model_strategy3)
#     final_memory = measure_memory(tuned_model_strategy3, memory_batch)
    
#     print(f"MLM Loss: {final_loss:.4f} (original: {baseline_loss:.4f})")
#     print(f"Speed: {final_speed:.2f} samples/sec (original: {baseline_speed:.2f})")
#     print(f"Memory: {final_memory:.2f} MB (original: {baseline_memory:.2f})")
    
#     # Test mask filling capability on the final tuned model
#     print("\nTesting mask filling on tuned model:")
#     predictions = fill_mask_test(tuned_model_strategy3, tokenizer, test_sentence)
#     print(f"Sentence: {test_sentence}")
#     print(f"Top 5 predictions: {predictions}")
    
#     # Final comparison table
#     print("\n===== Performance Comparison =====")
#     print("Model Version | MLM Loss | Speed (samples/sec) | Memory (MB)")
#     print("-" * 65)
#     print(f"Original     | {baseline_loss:.4f} | {baseline_speed:.2f} | {baseline_memory:.2f}")
    
#     # Get metrics for all tuned models
#     models = [tuned_model_strategy1, tuned_model_strategy2, tuned_model_strategy3]
#     names = ["Decoder-Tuned", "Transform-Tuned", "Combined-Tuned"]
    
#     for name, tuned_model in zip(names, models):
#         loss = evaluate_model(tuned_model, val_loader)
#         speed = speed_eval_func(tuned_model)
#         memory = measure_memory(tuned_model, memory_batch)
#         print(f"{name:13} | {loss:.4f} | {speed:.2f} | {memory:.2f}")

# if __name__ == "__main__":
#     import copy  # Used for deep copying models
    
#     # Run the BERT optimization test
#     print("\nRunning BERT optimization test with SKAutoTuner...")
#     test_bert_optimization()
    
#     print("\nTest completed.")

In [30]:
import os
import time
import copy
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizer
import random
import torch.nn.functional as F

# Import components
from panther.tuner.SkAutoTuner import (
    SKAutoTuner, 
    LayerConfig, 
    TuningConfigs,
    GridSearch,
    RandomSearch, 
    ModelVisualizer
)

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"Random seed set to {seed} for reproducibility")

# Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Call set_seed early in the script
set_seed(42)

##################################### HELPERS #######################################

def count_parameters(model):
    """Count trainable parameters in the model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def model_size_info(model):
    """Get detailed size information about the model"""
    total_params = count_parameters(model)
    
    # Get layer-wise parameter counts for important components
    layer_params = {}
    
    # Check BERT layers
    if hasattr(model, 'bert') and hasattr(model.bert, 'encoder'):
        for i, layer in enumerate(model.bert.encoder.layer):
            layer_params[f'bert.encoder.layer.{i}'] = sum(p.numel() for p in layer.parameters() if p.requires_grad)
    
    # Check MLM head
    if hasattr(model, 'cls'):
        if hasattr(model.cls, 'predictions'):
            if hasattr(model.cls.predictions, 'transform'):
                layer_params['cls.predictions.transform'] = sum(
                    p.numel() for p in model.cls.predictions.transform.parameters() if p.requires_grad)
            if hasattr(model.cls.predictions, 'decoder'):
                layer_params['cls.predictions.decoder'] = sum(
                    p.numel() for p in model.cls.predictions.decoder.parameters() if p.requires_grad)
    
    return {
        "total_params": total_params,
        "total_params_millions": total_params / 1e6,
        "layer_params": layer_params
    }

def dump_tensor_info(tensor, name="Tensor"):
    """Print details about a tensor"""
    print(f"{name}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
    print(f"  - Values: min={tensor.min().item():.4f}, max={tensor.max().item():.4f}, mean={tensor.mean().item():.4f}")
    print(f"  - First few values: {tensor.flatten()[:5]}")

def measure_time_with_stats(func, *args, n_runs=20, warmup=5):
    """Measure execution time of a function with proper GPU synchronization and report statistics"""
    # Clear cache first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    
    # Warmup
    for _ in range(warmup):
        result = func(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    
    # Timed runs
    times = []
    for _ in range(n_runs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        start = time.time()
        result = func(*args)
        
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        end = time.time()
        times.append(end - start)
    
    # Calculate statistics
    times = np.array(times)
    mean_time = np.mean(times)
    std_time = np.std(times)
    
    return {
        "mean": mean_time,
        "std": std_time,
        "min": np.min(times),
        "max": np.max(times),
        "samples_per_sec": 1.0 / mean_time,
        "samples_per_sec_std": std_time / (mean_time * mean_time)
    }

def measure_memory(model, input_tensor):
    """Measure peak memory usage of a model during inference"""
    if not torch.cuda.is_available():
        return 0  # Cannot measure CUDA memory on CPU
    
    # Clear cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    # Run inference
    with torch.no_grad():
        model(**input_tensor)
    
    # Get peak memory
    return torch.cuda.max_memory_allocated() / (1024 * 1024)  # Convert to MB

class MaskedTextDataset(Dataset):
    """Dataset for masked language modeling"""
    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            return_special_tokens_mask=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Create input_ids with masks
        input_ids = encoding.input_ids.clone().squeeze(0)
        special_tokens_mask = encoding.special_tokens_mask.squeeze(0).bool()
        
        # Create labels (clone of input_ids)
        labels = input_ids.clone()
        
        # Find positions eligible for masking (not special tokens)
        mask_positions = (~special_tokens_mask).nonzero(as_tuple=True)[0]
        
        # Randomly mask 15% of eligible tokens
        num_to_mask = max(1, int(0.15 * len(mask_positions)))
        mask_indices = np.random.choice(mask_positions.tolist(), size=num_to_mask, replace=False)
        input_ids[mask_indices] = self.tokenizer.mask_token_id
        
        # Create attention mask
        attention_mask = encoding.attention_mask.squeeze(0)
        
        # Create return dictionary
        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
        
        return batch

def evaluate_model_with_stats(model, dataloader, tokenizer=None, n_runs=3):
    """Evaluate model accuracy and loss on a dataset with multiple runs for statistics"""
    all_results = []
    
    for run in range(n_runs):
        model.eval()
        total_loss = 0
        total_correct = 0
        total_predictions = 0
        total_samples = 0
        
        with torch.no_grad():
            for batch in dataloader:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Forward pass
                outputs = model(**batch)
                loss = outputs.loss
                
                # Calculate accuracy if tokenizer is provided
                if tokenizer is not None:
                    predictions = outputs.logits.argmax(dim=-1)
                    labels = batch["labels"]
                    
                    # Only consider masked positions for accuracy calculation
                    masked_positions = (batch["input_ids"] == tokenizer.mask_token_id)
                    if masked_positions.sum() > 0:
                        masked_predictions = predictions[masked_positions]
                        masked_labels = labels[masked_positions]
                        correct = (masked_predictions == masked_labels).sum().item()
                        total_correct += correct
                        total_predictions += masked_positions.sum().item()
                
                # Accumulate loss statistics
                batch_size = batch["input_ids"].size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
        
        avg_loss = total_loss / total_samples
        mlm_accuracy = total_correct / total_predictions if total_predictions > 0 else 0
        
        all_results.append({
            "loss": avg_loss,
            "accuracy": mlm_accuracy if tokenizer is not None else None
        })
    
    # Compute statistics across runs
    losses = [res["loss"] for res in all_results]
    accuracies = [res["accuracy"] for res in all_results] if tokenizer is not None else None
    
    results = {
        "loss_mean": np.mean(losses),
        "loss_std": np.std(losses),
        "accuracy_mean": np.mean(accuracies) if accuracies else None,
        "accuracy_std": np.std(accuracies) if accuracies else None,
        "runs": all_results
    }
    
    return results

def get_data():
    """Prepare dataset for BERT testing"""
    print("Preparing BERT test dataset...")
    
    # Sample texts for testing (original set)
    texts = [
        "Machine learning is the study of computer algorithms that improve automatically through experience.",
        "Deep learning is part of a broader family of machine learning methods based on artificial neural networks.",
        "Natural language processing is a subfield of linguistics, computer science, and artificial intelligence.",
        "Transformers have emerged as a powerful deep learning architecture for natural language processing tasks.",
        "BERT is a transformer-based machine learning technique for natural language processing pre-training."
    ]
    
    # Add more texts to the dataset for more robust testing
    more_texts = [
        "The transformer architecture uses self-attention mechanisms to process sequential data effectively.",
        "Pre-trained language models can be fine-tuned on specific downstream tasks with less data.",
        "Language model pre-training has resulted in significant advances in many natural language tasks.",
        "Transfer learning enables models to leverage knowledge from one domain to perform well in another.",
        "Masked language modeling is a self-supervised technique to train language models."
    ]
    texts.extend(more_texts)
    
    # Add more realistic data for better evaluation
    additional_texts = []
    
    # Use WikiText-2 chunks if available, or more synthetic examples
    try:
        from datasets import load_dataset
        print("Loading WikiText dataset for more realistic evaluation...")
        wiki_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        # Take chunks of 100-200 characters, clean them, and add to texts
        for item in wiki_dataset:
            text = item['text'].strip()
            if 100 <= len(text) <= 200 and len(text.split()) > 10:
                additional_texts.append(text)
                if len(additional_texts) >= 90:  # Add 90 more examples to get 100 total
                    break
        print(f"Added {len(additional_texts)} examples from WikiText dataset")
    except Exception as e:
        # If datasets package is not available, generate synthetic text
        print(f"WikiText dataset not available ({str(e)}), using synthetic examples")
        additional_sentences = [
            "The transformer model has revolutionized natural language processing with its attention mechanism.",
            "Neural networks can learn complex patterns from large amounts of training data.",
            "Word embeddings map words to vectors in a high-dimensional space to capture semantic meaning.",
            "Recurrent neural networks process sequential data by maintaining a hidden state.",
            "Attention mechanisms allow models to focus on relevant parts of the input sequence.",
            "Backpropagation is an algorithm for efficiently computing gradients in neural networks.",
            "Gradient descent optimizes neural network parameters by iteratively updating weights.",
            "Convolutional neural networks are particularly effective for image recognition tasks.",
            "Regularization techniques help prevent overfitting in machine learning models.",
            "Transfer learning leverages knowledge from pre-trained models to improve performance.",
            "Long Short-Term Memory networks are designed to handle the vanishing gradient problem.",
            "Encoder-decoder architectures are commonly used for sequence-to-sequence tasks.",
            "Tokenization is the process of converting text into discrete tokens for processing.",
            "Fine-tuning adapts pre-trained models to specific downstream tasks with less data.",
            "The softmax function converts a vector of values into a probability distribution.",
            "Cross-entropy loss is commonly used for training classification models.",
            "Dropout randomly deactivates neurons during training to prevent co-adaptation.",
            "Batch normalization stabilizes and accelerates training of deep neural networks.",
            "Reinforcement learning trains agents through interaction with an environment.",
            "Semi-supervised learning combines labeled and unlabeled data for training.",
            "One-hot encoding represents categorical variables as binary vectors.",
            "Dimensionality reduction techniques compress data while preserving information.",
            "Ensemble methods combine multiple models to improve prediction accuracy.",
            "Generative adversarial networks consist of generator and discriminator components.",
            "Autoencoders learn efficient representations by reconstructing their input data.",
            "The curse of dimensionality refers to challenges in high-dimensional spaces.",
            "Decision trees recursively partition data based on feature values.",
            "Random forests are ensembles of decision trees with randomized feature selection.",
            "Support vector machines find the optimal hyperplane to separate data classes.",
            "K-means clustering groups data points based on similarity metrics.",
            "Principal component analysis identifies directions of maximum variance in data.",
            "Bias-variance tradeoff is a fundamental concept in machine learning generalization.",
            "Feature engineering transforms raw data into features suitable for models.",
            "Hyperparameter tuning optimizes model configuration for best performance.",
            "Cross-validation assesses model performance on different data subsets.",
            "The Internet of Things connects everyday devices to the global network.",
            "Cloud computing provides on-demand access to computing resources over the internet.",
            "Blockchain technology enables secure, decentralized transaction records.",
            "Quantum computing leverages quantum mechanics for computational tasks.",
            "Edge computing processes data near its source rather than in a centralized location."
        ]
        # Take up to 90 synthetic examples
        additional_texts.extend(additional_sentences[:90])
    
    texts.extend(additional_texts[:90])  # Add up to 90 more texts
    print(f"Created dataset with {len(texts)} examples")
    
    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    
    # Create dataset
    dataset = MaskedTextDataset(texts, tokenizer)
    
    # Create data loader with batch size that's a multiple of 16 for Tensor Core optimization
    batch_size = 16  # Changed from 2 to 16 to enable Tensor Core optimizations
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Create a single batch for memory testing (also multiple of 16)
    memory_batch = {k: v.to(device) for k, v in next(iter(dataloader)).items()}
    # No need to expand since we already have batch size 16
    
    return tokenizer, dataloader, memory_batch

def get_data_varied_lengths(seq_lengths=[128, 256, 384, 512]):
    """Prepare datasets with varying sequence lengths for scaling tests"""
    print(f"Preparing BERT test datasets with varying lengths: {seq_lengths}")
    
    # Sample texts (use longer texts for this test)
    texts = []
    
    # Try to load more complex texts from WikiText
    try:
        from datasets import load_dataset
        print("Loading WikiText dataset for sequence length tests...")
        wiki_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        
        # Take chunks of various lengths
        for item in wiki_dataset:
            text = item['text'].strip()
            if len(text) > 50 and text not in texts:  # Ensure it's a substantive text
                texts.append(text)
                if len(texts) >= 100:  # Get 100 examples
                    break
    except Exception as e:
        # Fallback to synthetic data
        print(f"WikiText dataset not available ({str(e)}), using synthetic examples")
        # Generate longer synthetic texts by repeating and combining existing ones
        base_texts = [
            "Machine learning is the study of computer algorithms that improve automatically through experience.",
            "Deep learning is part of a broader family of machine learning methods based on artificial neural networks.",
            "Natural language processing is a subfield of linguistics, computer science, and artificial intelligence.",
            "Transformers have emerged as a powerful deep learning architecture for natural language processing tasks.",
            "BERT is a transformer-based machine learning technique for natural language processing pre-training."
        ]
        
        # Create longer texts by combining shorter ones
        for _ in range(100):
            num_sentences = random.randint(3, 10)
            combined_text = " ".join(random.choices(base_texts, k=num_sentences))
            texts.append(combined_text)
    
    print(f"Created dataset with {len(texts)} examples")
    
    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    
    # Create datasets and dataloaders for each sequence length
    datasets = {}
    dataloaders = {}
    memory_batches = {}
    
    for max_length in seq_lengths:
        # Create dataset with this specific max_length
        dataset = MaskedTextDataset(texts, tokenizer, max_length=max_length)
        
        # Create data loader with batch size that's a multiple of 16 for Tensor Core optimization
        # Use smaller batches for longer sequences to prevent OOM
        batch_size = 16 if max_length <= 128 else 8 if max_length <= 256 else 4 if max_length <= 384 else 2
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        # Create a memory test batch
        memory_batch = {k: v.to(device) for k, v in next(iter(dataloader)).items()}
        
        datasets[max_length] = dataset
        dataloaders[max_length] = dataloader
        memory_batches[max_length] = memory_batch
    
    return tokenizer, datasets, dataloaders, memory_batches

def test_sequence_scaling(orig_model, tuned_model, tokenizer):
    """Test how performance improvements scale with sequence length"""
    print("\n===== Testing Performance Scaling with Sequence Length =====")
    
    # Get datasets with varying sequence lengths
    tokenizer, datasets, dataloaders, memory_batches = get_data_varied_lengths()
    
    # Results table
    results = []
    
    # Test each sequence length
    for seq_length in sorted(dataloaders.keys()):
        print(f"\nTesting with sequence length: {seq_length}")
        dataloader = dataloaders[seq_length]
        memory_batch = memory_batches[seq_length]
        
        # Function for inference
        def infer(model, inputs):
            with torch.no_grad():
                return model(**inputs)
        
        # Test models
        for model_name, model in [("Original", orig_model), ("Tuned", tuned_model)]:
            model.eval()
            torch.cuda.empty_cache()
            
            # Measure accuracy
            print(f"Evaluating {model_name} model accuracy...")
            eval_results = evaluate_model_with_stats(model, dataloader, tokenizer, n_runs=3)
            
            # Measure speed
            print(f"Measuring {model_name} model speed...")
            time_results = measure_time_with_stats(infer, model, memory_batch, n_runs=10, warmup=3)
            
            # Measure memory
            memory_used = measure_memory(model, memory_batch)
            
            # Store results
            results.append({
                "seq_length": seq_length,
                "model": model_name,
                "loss_mean": eval_results["loss_mean"],
                "loss_std": eval_results["loss_std"],
                "accuracy_mean": eval_results["accuracy_mean"],
                "accuracy_std": eval_results["accuracy_std"],
                "speed_mean": time_results["samples_per_sec"],
                "speed_std": time_results["samples_per_sec_std"],
                "memory": memory_used
            })
    
    # Print results table
    print("\n===== Sequence Length Scaling Results =====")
    print("| Seq Length | Model | MLM Loss | MLM Accuracy | Speed (samples/sec) | Memory (MB) | Speedup |")
    print("|------------|-------|----------|--------------|---------------------|-------------|---------|")
    
    for seq_length in sorted(dataloaders.keys()):
        # Extract results for this sequence length
        orig_result = next(r for r in results if r["seq_length"] == seq_length and r["model"] == "Original")
        tuned_result = next(r for r in results if r["seq_length"] == seq_length and r["model"] == "Tuned")
        
        # Calculate speedup
        speedup = tuned_result["speed_mean"] / orig_result["speed_mean"]
        
        # Print original model results
        print(f"| {seq_length:10d} | Original | {orig_result['loss_mean']:.4f}±{orig_result['loss_std']:.4f} | "
              f"{orig_result['accuracy_mean']:.4f}±{orig_result['accuracy_std']:.4f} | "
              f"{orig_result['speed_mean']:.2f}±{orig_result['speed_std']:.2f} | "
              f"{orig_result['memory']:.2f} | 1.00x |")
        
        # Print tuned model results
        print(f"| {seq_length:10d} | Tuned | {tuned_result['loss_mean']:.4f}±{tuned_result['loss_std']:.4f} | "
              f"{tuned_result['accuracy_mean']:.4f}±{tuned_result['accuracy_std']:.4f} | "
              f"{tuned_result['speed_mean']:.2f}±{tuned_result['speed_std']:.2f} | "
              f"{tuned_result['memory']:.2f} | {speedup:.2f}x |")
    
    return results

def fill_mask_test(model, tokenizer, text="The capital of France is [MASK]."):
    """Test mask filling capability"""
    # Replace [MASK] with actual mask token if needed
    if "[MASK]" in text:
        text = text.replace("[MASK]", tokenizer.mask_token)
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    # Find mask token position
    mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
    
    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    
    # Get predictions for mask position
    if len(mask_token_index[0]) > 0:
        batch_idx, token_idx = mask_token_index
        mask_logits = logits[batch_idx, token_idx, :]
        
        # Get top 5 predictions
        topk_values, topk_indices = torch.topk(mask_logits, 5, dim=1)
        
        # Convert to tokens
        topk_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in topk_indices[0]]
        
        return topk_tokens
    else:
        return ["No mask token found"]

def test_bert_optimization():
    """Test SKAutoTuner on BERT model's linear layers"""
    
    # Set seed for reproducibility
    set_seed(42)
    
    # Create reference copy before any modifications to ensure identical initial states
    model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)
    model.eval()
    orig_model = copy.deepcopy(model)  # Create copy before any modifications
    
    # Get data for testing (do this before modifying the models)
    tokenizer, val_loader, memory_batch = get_data()
    
    # Get parameter counts before optimization
    print("\n===== Model Parameter Counts Before Optimization =====")
    orig_params = model_size_info(model)
    print(f"Total parameters: {orig_params['total_params_millions']:.2f}M")
    print("Parameters by layer:")
    for layer_name, param_count in orig_params['layer_params'].items():
        print(f"  - {layer_name}: {param_count/1e6:.2f}M parameters")
    
    # Apply identical vocab size modifications to both models
    orig_out_features = model.cls.predictions.decoder.weight.size(0)
    new_out_features = ((orig_out_features + 15) // 16) * 16
    
    # Store the true original forward methods before any wrapping
    true_orig_forward = model.forward
    true_orig_ref_forward = orig_model.forward
    
    # Define the post-processing function
    def post_process_outputs(model_outputs, orig_size=orig_out_features):
        """Trim any padded outputs back to original vocabulary size"""
        if hasattr(model_outputs, 'logits') and model_outputs.logits is not None:
            if model_outputs.logits.size(-1) > orig_size:
                # Trim to original vocabulary size
                model_outputs.logits = model_outputs.logits[..., :orig_size]
        return model_outputs
    
    # Create a wrapper factory for forward methods
    def create_wrapped_forward(original_forward_fn):
        def wrapped_forward(*args, **kwargs):
            outputs = original_forward_fn(*args, **kwargs)
            return post_process_outputs(outputs)
        return wrapped_forward
    
    # Apply identical modifications to both models
    for m in [model, orig_model]:
        if orig_out_features != new_out_features:
            # Create padded weights and bias
            orig_weight = m.cls.predictions.decoder.weight
            orig_bias = m.cls.predictions.decoder.bias
            
            new_weight = torch.zeros(new_out_features, orig_weight.size(1), 
                                    device=orig_weight.device, dtype=orig_weight.dtype)
            new_bias = torch.zeros(new_out_features, 
                                  device=orig_bias.device, dtype=orig_bias.dtype)
            
            # Copy the original values
            new_weight[:orig_out_features, :] = orig_weight
            new_bias[:orig_out_features] = orig_bias
            
            # Replace the decoder
            new_decoder = torch.nn.Linear(orig_weight.size(1), new_out_features, bias=True)
            new_decoder.weight = torch.nn.Parameter(new_weight)
            new_decoder.bias = torch.nn.Parameter(new_bias)
            
            m.cls.predictions.decoder = new_decoder
            
        # Update the config's vocab_size
        m.config.vocab_size = new_out_features
    
    # Apply a single wrapping to each model
    model.forward = create_wrapped_forward(true_orig_forward)
    orig_model.forward = create_wrapped_forward(true_orig_ref_forward)
    
    # First evaluate the original model before any modifications
    print("\nBaseline BERT model (before any modifications):")
    baseline_results = evaluate_model_with_stats(model, val_loader, tokenizer)
    
    # Measure performance metrics of original model
    def infer(model, inputs):
        with torch.no_grad():
            return model(**inputs)
    
    baseline_time_stats = measure_time_with_stats(infer, model, memory_batch, n_runs=10)
    baseline_speed = baseline_time_stats["samples_per_sec"]
    baseline_memory = measure_memory(model, memory_batch)
    
    print(f"MLM Loss: {baseline_results['loss_mean']:.4f}±{baseline_results['loss_std']:.4f}")
    print(f"MLM Accuracy: {baseline_results['accuracy_mean']:.4f}±{baseline_results['accuracy_std']:.4f}")
    print(f"Baseline model memory usage: {baseline_memory:.2f} MB")
    print(f"Baseline model speed: {baseline_speed:.2f}±{baseline_time_stats['samples_per_sec_std']:.2f} samples/sec")
    
    print("\n===== Original Model Structure =====")
    ModelVisualizer.print_module_tree(model)
    
    # Create an evaluation function for the model
    def acc_eval_func(model):
        """Evaluation function based on true MLM accuracy"""
        results = evaluate_model_with_stats(model, val_loader, tokenizer)
        print(f"MLM Loss: {results['loss_mean']:.4f}±{results['loss_std']:.4f}, MLM Accuracy: {results['accuracy_mean']:.4f}±{results['accuracy_std']:.4f}")
        return results['accuracy_mean']  # Return accuracy (higher is better)
    
    # Create a separate speed evaluation function
    def speed_eval_func(model):
        """Speed evaluation function"""
        def infer(model, inputs):
            with torch.no_grad():
                return model(**inputs)
        
        # Higher is better (inverse of time)
        time_stats = measure_time_with_stats(infer, model, memory_batch, n_runs=10)
        throughput = time_stats["samples_per_sec"]
        print(f"Inference speed: {throughput:.2f}±{time_stats['samples_per_sec_std']:.2f} samples/sec")
        return throughput
    
    # Calculate accuracy threshold
    accuracy_threshold = -0.05  # Allow 5% reduction in accuracy
    print(f"Setting accuracy threshold to {accuracy_threshold:.4f}")
    
    # Strategy: Optimizing both linear layers in the MLM head
    print("\n===== Optimizing both MLM head linear layers =====")
    
    # Create configs to tune both linear layers together with Tensor Core friendly dimensions
    configs = TuningConfigs([
        LayerConfig(
            # Target both linear layers in the MLM head
            layer_names={
                "pattern": "cls.predictions.*",
                "type": "Linear",
            },
            params={
                "num_terms": [1, 2, 3],
                "low_rank": [16, 32, 64],  # All values are multiples of 16 for Tensor Core
            },
            separate=False  # Tune as a group
        ),
    ])
    
    # Create tuner for both layers together
    tuner = SKAutoTuner(
        model=copy.deepcopy(model),
        configs=configs,
        accuracy_eval_func=acc_eval_func,
        search_algorithm=GridSearch(),
        verbose=True,
        accuracy_threshold=accuracy_threshold,
        optmization_eval_func=speed_eval_func
    )
    
    # Run tuning
    print("\nRunning combined MLM head layers tuning...")
    best_params = tuner.tune()
    print(f"Best parameters: {best_params}")
    
    # Apply best parameters
    tuned_model = tuner.apply_best_params()
    
    print("\n===== Tuned Model Structure =====")
    ModelVisualizer.print_module_tree(tuned_model)
    
    # Test the tuned model
    print("\nEvaluating models with identical conditions:")
    
    # Ensure both models are in the same state for fair comparison
    for m in [orig_model, tuned_model]:
        m.eval()
        torch.cuda.empty_cache()

    # Use identical test conditions
    def test_model(model_name, model):
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        
        # Run standardized tests
        results = evaluate_model_with_stats(model, val_loader, tokenizer)
        
        def infer(model, inputs):
            with torch.no_grad():
                return model(**inputs)
        
        time_result = measure_time_with_stats(infer, model, memory_batch, n_runs=10)
        speed = time_result["samples_per_sec"]
        speed_std = time_result["samples_per_sec_std"]
        memory_used = measure_memory(model, memory_batch)
        
        return {
            "name": model_name,
            "loss": results["loss_mean"],
            "loss_std": results["loss_std"],
            "accuracy": results["accuracy_mean"],
            "accuracy_std": results["accuracy_std"],
            "speed": speed,
            "speed_std": speed_std,
            "memory": memory_used
        }

    # Test both models under identical conditions
    baseline_results = test_model("Original", orig_model)
    tuned_results = test_model("Tuned", tuned_model)
    
    # Extract results for the comparison table
    baseline_loss = baseline_results["loss"]
    baseline_accuracy = baseline_results["accuracy"]
    baseline_speed = baseline_results["speed"]
    baseline_memory = baseline_results["memory"]
    
    final_loss = tuned_results["loss"]
    final_accuracy = tuned_results["accuracy"]
    final_speed = tuned_results["speed"]
    final_memory = tuned_results["memory"]
    
    # After optimization, get new parameter counts
    print("\n===== Model Parameter Counts After Optimization =====")
    tuned_params = model_size_info(tuned_model)
    print(f"Original model: {orig_params['total_params_millions']:.2f}M parameters")
    print(f"Tuned model: {tuned_params['total_params_millions']:.2f}M parameters")
    print(f"Reduction: {(1 - tuned_params['total_params_millions']/orig_params['total_params_millions'])*100:.2f}%")
    
    print("\nParameters by layer:")
    for layer_name in sorted(set(list(orig_params['layer_params'].keys()) + list(tuned_params['layer_params'].keys()))):
        orig_count = orig_params['layer_params'].get(layer_name, 0) / 1e6
        tuned_count = tuned_params['layer_params'].get(layer_name, 0) / 1e6
        
        if orig_count > 0 and tuned_count > 0:
            reduction = (1 - tuned_count/orig_count) * 100
            print(f"  - {layer_name}: {orig_count:.2f}M → {tuned_count:.2f}M ({reduction:.2f}% reduction)")
    
    print(f"MLM Loss: {final_loss:.4f}±{tuned_results['loss_std']:.4f} (original: {baseline_loss:.4f}±{baseline_results['loss_std']:.4f})")
    print(f"MLM Accuracy: {final_accuracy:.4f}±{tuned_results['accuracy_std']:.4f} (original: {baseline_accuracy:.4f}±{baseline_results['accuracy_std']:.4f})")
    print(f"Speed: {final_speed:.2f}±{tuned_results['speed_std']:.2f} samples/sec (original: {baseline_speed:.2f}±{baseline_results['speed_std']:.2f})")
    print(f"Memory: {final_memory:.2f} MB (original: {baseline_memory:.2f})")
    
    # Enhanced performance comparison table
    print("\n===== Performance Comparison =====")
    print("| Model Version | MLM Loss | MLM Accuracy | Speed (samples/sec) | Memory (MB) | Speed Improvement |")
    print("|--------------|----------|--------------|---------------------|-------------|-------------------|")
    print(f"| Original     | {baseline_loss:.4f}±{baseline_results['loss_std']:.4f} | {baseline_accuracy:.4f}±{baseline_results['accuracy_std']:.4f} | {baseline_speed:.2f}±{baseline_results['speed_std']:.2f} | {baseline_memory:.2f} | 1.00x |")
    print(f"| Tuned        | {final_loss:.4f}±{tuned_results['loss_std']:.4f} | {final_accuracy:.4f}±{tuned_results['accuracy_std']:.4f} | {final_speed:.2f}±{tuned_results['speed_std']:.2f} | {final_memory:.2f} | {final_speed/baseline_speed:.2f}x |")
    
    # Additional comparison tests with real examples
    test_examples = [
        "The capital of France is [MASK].",
        "Machine learning models [MASK] data to make predictions.",
        "Transformers use [MASK] attention to process sequences.",
        "The [MASK] language model was developed by Google researchers."
    ]
    
    print("\n===== Qualitative Comparison: Mask Filling =====")
    for test_sentence in test_examples:
        print(f"\nSentence: {test_sentence}")
        
        # Original model predictions
        orig_predictions = fill_mask_test(orig_model, tokenizer, test_sentence)
        print(f"Original model predictions: {', '.join(orig_predictions)}")
        
        # Tuned model predictions
        tuned_predictions = fill_mask_test(tuned_model, tokenizer, test_sentence)
        print(f"Tuned model predictions:    {', '.join(tuned_predictions)}")
    
    # Run sequence length scaling test
    test_sequence_scaling(orig_model, tuned_model, tokenizer)
    
    return tuned_model

if __name__ == "__main__":
    import copy  # Used for deep copying models
    
    # Run the BERT optimization test
    print("\nRunning BERT optimization test with SKAutoTuner...")
    test_bert_optimization()
    
    print("\nTest completed.")

Using device: cuda
Random seed set to 42 for reproducibility

Running BERT optimization test with SKAutoTuner...
Random seed set to 42 for reproducibility


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Preparing BERT test dataset...
Loading WikiText dataset for more realistic evaluation...
Added 90 examples from WikiText dataset
Created dataset with 100 examples

===== Model Parameter Counts Before Optimization =====
Total parameters: 109.51M
Parameters by layer:
  - bert.encoder.layer.0: 7.09M parameters
  - bert.encoder.layer.1: 7.09M parameters
  - bert.encoder.layer.2: 7.09M parameters
  - bert.encoder.layer.3: 7.09M parameters
  - bert.encoder.layer.4: 7.09M parameters
  - bert.encoder.layer.5: 7.09M parameters
  - bert.encoder.layer.6: 7.09M parameters
  - bert.encoder.layer.7: 7.09M parameters
  - bert.encoder.layer.8: 7.09M parameters
  - bert.encoder.layer.9: 7.09M parameters
  - bert.encoder.layer.10: 7.09M parameters
  - bert.encoder.layer.11: 7.09M parameters
  - cls.predictions.transform: 0.59M parameters
  - cls.predictions.decoder: 23.47M parameters

Baseline BERT model (before any modifications):
MLM Loss: 14.3385±0.0361
MLM Accuracy: 0.5525±0.0211
Baseline model memo