In [1]:
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["TORCH_LOGS"] = "output_code"
# os.environ["TORCH_LOGS"] = "inductor"
# os.environ["TORCHINDUCTOR_TRACE"] = "1"
# os.environ["TORCHINDUCTOR_VERBOSE"] = "1"
# os.environ["TORCHINDUCTOR_DEBUG"] = "1"
# os.environ["TORCHINDUCTOR_DUMP"] = "1"
# os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# os.environ["TORCH_COMPILE_DEBUG"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    # "backend:cudaMallocAsync,"
    "expandable_segments:True,"
    # "garbage_collection_threshold:0.6"
)

import sys
sys.path.append('/home/hice1/yyu496/kaggle/CW')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._functorch.aot_autograd import aot_module, make_boxed_func
from torch.autograd import grad
from torch.optim.lr_scheduler import (
    LinearLR,
    CosineAnnealingLR,
    SequentialLR
)
from torch.optim.optimizer import Optimizer

import torch
from torch.profiler import profile, record_function, ProfilerActivity
import torch._dynamo as dynamo

import ACT6.cpp_extension as cpp_extension

from ACT6.controller import Controller
import ACT6.cuda_graph_utils as cuda_utils
from ACT6.layers import RMSNorm

import timm
import torchvision.models as models
from torchvision.transforms import v2

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True


import actnn
# available choices are ["L0", "L1", "L2", "L3", "L4", "L5"]
actnn.set_optimization_level("L3")


from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy 



import matplotlib.pyplot as plt
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 512
warmup_epochs = 20
num_epochs = 512
# ============= Data ==================

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


transform_train = v2.Compose([
    # --- Convert to tensor FIRST (kills PIL early) ---
    v2.ToImage(),                          # handles PIL → Tensor safely
    v2.ToDtype(torch.float32, scale=True),

    # --- Geometry ---
    v2.RandomResizedCrop(224, scale=(0.5, 1.0), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),

    # --- Color ---
    v2.RandomApply([
        v2.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )
    ], p=0.5),

    # --- RandAugment (vectorized, NO PIL) ---
    v2.RandAugment(num_ops=4, magnitude=9),

    # --- Normalize ---
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),

    # --- Random Erasing (tensor-based, fast) ---
    v2.RandomErasing(p=0.5, scale=(0.02, 0.33),
                     ratio=(0.3, 3.3), value=0),
])

# download and prepare
# train_dataset = datasets.CIFAR100(root="./data", train=True,
#                                  download=True, transform=transform_train)

train_dataset = ImageFolder(
    root="/home/hice1/yyu496/scratch/data/cifar10_resized/train",
    transform=transform_train
)

# DataLoader with GPU-friendly settings
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=45, pin_memory=True, drop_last=True,
                          persistent_workers=True, prefetch_factor=15)



val_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),

    v2.Resize(224, antialias=True),

    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === CIFAR-10 validation dataset (train=False) ===
# val_dataset = datasets.CIFAR100(
#     root="./data",
#     train=False,
#     download=True,
#     transform=val_transform
# )
val_dataset = ImageFolder(
    root="/home/hice1/yyu496/scratch/data/cifar10_resized/test",
    transform=val_transform
)

# === Validation DataLoader ===
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,   
    num_workers=45,
    pin_memory=True,
    persistent_workers=True, 
    prefetch_factor=15
)

def replace_activation_func(m):
    for name, child in m.named_children():
        if isinstance(child, nn.ReLU):
            setattr(m, name, nn.SiLU(inplace=True))
            # setattr(m, name, nn.ReLU6(inplace=True))
            # setattr(m, name, nn.LeakyReLU(inplace=True))
            # setattr(m, name, nn.GELU())
        else:
            replace_activation_func(child)


def replace_norm(m):
    for name, child in m.named_children():
        if isinstance(child, nn.LayerNorm):
            setattr(m, name, RMSNorm(dims=child.normalized_shape[-1]))
        else:
            replace_activation_func(child)


def disable_act_inplace(m):
    for module in m.modules():
        if isinstance(module, (nn.ReLU, nn.SiLU, nn.GELU)) and hasattr(module, 'inplace'):
            module.inplace = False


def init_fan_out(m):
    if isinstance(m, (nn.Conv2d,)):
        nn.init.kaiming_normal_(
            m.weight,
            mode='fan_out',
            nonlinearity='relu'
        )
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(
            m.weight,
            mode='fan_in',
            nonlinearity='relu'
        )

        if m.bias is not None:
            nn.init.zeros_(m.bias)



In [3]:
DIVISION = {
    'pool_kernel_size' : 3
}
# DIVISION = None

config = {
    "default_bits": 2,
    'auto_precision': None,
    'DIVISION' : DIVISION,
    "group_size": 256,
    'batch_size' : batch_size,
    'fp8' : False,
    'depth_point_conv' : False,
    'rms_norm' : False
}

# vit_base_patch16_224
# efficientnet_b0
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
# replace_activation_func(model)
# replace_norm(model)


num_classes = 100
# model = models.resnet50(weights=None)
# model = models.resnet18(weights=None)
# model = models.efficientnet_b0(weights=None)
# model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
# model.fc = nn.Linear(model.fc.in_features, num_classes)
# replace_activation_func(model)
disable_act_inplace(model)
# model.apply(init_fan_out)
# model = actnn.QModule(model)
# model.cuda()
# model.compile(fullgraph=True)



criterion = nn.CrossEntropyLoss()

compute_stream = torch.cuda.Stream()
controller = Controller(model, config, train_loader, criterion, test=False)
controller.iterate(criterion)
controller.warp_model(graph_mode=True, quantizer=True)
# controller.warmup(train_loader, criterion, compute_stream)

In [None]:

N = len(controller.quantizer.layer_hessian_eigenvalue_spectral_density)
cols = 3
rows = math.ceil(N / cols)

fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3*rows))
axes = axes.flatten()

for i, (ax, name) in enumerate(zip(axes, controller.quantizer.layer_hessian_eigenvalue_spectral_density.keys())):
    if i < N:
        ax.hist(controller.quantizer.layer_hessian_eigenvalue_spectral_density[name], bins=50)
        ax.set_title(f"Layer {name}")
    else:
        ax.axis("off")  # hide unused subplots

plt.tight_layout()
plt.show()

In [4]:
class SGD_LARS(Optimizer):
    def __init__(self,
                 params,
                 lr=0.1,
                 momentum=0.9,
                 weight_decay=1e-4,
                 eta=0.001,
                 eps=1e-9,
                 nesterov=False,
                 trust_coef=True):
        defaults = dict(lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay,
                        eta=eta,
                        eps=eps,
                        nesterov=nesterov,
                        trust_coef=trust_coef)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            eta = group['eta']
            eps = group['eps']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # ---- decoupled weight decay ----
                if weight_decay != 0 and p.ndim > 1:
                    grad = grad.add(p, alpha=weight_decay)

                # ---- LARS trust ratio (only apply to weight tensors) ----
                if p.ndim > 1:      # don't LARS biases / norms
                    w_norm = p.norm()
                    g_norm = grad.norm()

                    if w_norm > 0 and g_norm > 0:
                        trust_ratio = eta * (w_norm / (g_norm + eps))
                    else:
                        trust_ratio = 1.0
                else:
                    trust_ratio = 0.001

                scaled_lr = lr * trust_ratio

                # ---- Momentum buffer ----
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.zeros_like(p)
                else:
                    buf = param_state['momentum_buffer']

                buf.mul_(momentum).add_(grad)  # standard momentum update

                if nesterov:
                    update = grad.add(buf, alpha=momentum)
                else:
                    update = buf

                # ---- Update weights ----
                p.add_(update, alpha=-scaled_lr)


In [5]:
acc = Accuracy(task='multiclass', num_classes=10, average='macro')

opt = torch.optim.AdamW(controller.traced_model.parameters(), lr=(1e-3 * 3), fused=True, capturable=True)
# opt = SGD_LARS(
#     controller.traced_model.parameters(),
#     lr=0.5,                     # LARS needs high LR
#     momentum=0.9,
#     weight_decay=1e-7,          # decoupled weight decay
#     eta=0.001,                  # LARS coefficient (0.001–0.02)
#     nesterov=True
# )
# opt = torch.optim.AdamW(model.parameters(), lr=(1e-3 * 3), fused=True, capturable=True)
# opt = SGD_LARS(
#     model.parameters(),
#     lr=1.0,                     # LARS needs high LR
#     momentum=0.9,
#     weight_decay=1e-7,          # decoupled weight decay
#     eta=0.001,                  # LARS coefficient (0.001–0.02)
#     nesterov=True
# )


total_timer_start = torch.cuda.Event(enable_timing=True)
total_timer_end = torch.cuda.Event(enable_timing=True)
total_time = 0.0

e_timer_start = torch.cuda.Event(enable_timing=True)
e_timer_end = torch.cuda.Event(enable_timing=True)
partile_time = 0.0


# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     opt,
#     1e-3,
#     total_steps= num_epochs * len(train_loader),
#     pct_start=0.1,
#     div_factor=10,
#     final_div_factor=1000
# )

# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#     opt,
#     T_0=10,
#     T_mult=2,
#     eta_min=1e-5
# )

warmup_sched = LinearLR(
    opt,
    start_factor=0.1,   # lr = lr * start_factor
    end_factor=1.0,    # lr = lr * end_factor = lr
    total_iters=warmup_epochs
)

# 2. Cosine decay
cosine_sched = CosineAnnealingLR(
    opt,
    T_max=num_epochs - warmup_epochs,
    eta_min=1e-4
)

# 3. Combine them
scheduler = SequentialLR(
    opt,
    schedulers=[warmup_sched, cosine_sched],
    milestones=[warmup_epochs]
)

In [6]:
cuda_graph_generator = cuda_utils.Graph(controller.traced_model,
                                        criterion,
                                        opt,
                                        train_loader,
                                        compute_stream,
                                        mode='qdrop',
                                        num_of_graph=1,
                                        device='cuda')


# cuda_graph_generator = cuda_utils.Graph(model,
#                                         criterion,
#                                         opt,
#                                         train_loader,
#                                         compute_stream,
#                                         mode='qdrop',
#                                         num_of_graph=1,
#                                         device='cuda')

In [7]:
graphs, compute_stream, static_x, static_y, logits, loss = cuda_graph_generator.capture_cuda_graph_qdrop()

total_graphs = cuda_graph_generator.num_of_graph

In [8]:
@torch.no_grad()
def lowpass_param_2d(t: torch.Tensor, k: int = 3):
    """
    Low-pass filter for parameter tensors that have 2D structure in the last two dims
    (e.g. conv weights [out_c, in_c, kH, kW] or activation-like grads [B, C, H, W]).

    For non-4D tensors or very small spatial dims, returns t unchanged.
    """
    if t.ndim != 4:
        return t

    *prefix, H, W = t.shape
    if H < k or W < k:
        return t  # nothing to do, kernel larger than spatial size

    # Merge prefix dims into batch*channels, so shape -> [N, 1, H, W]
    t_flat = t.view(-1, 1, H, W)
    t_l = F.avg_pool2d(t_flat, kernel_size=k, stride=1, padding=k // 2)
    t_l = F.interpolate(t_l, size=(H, W), mode="bilinear", align_corners=False)

    return t_l.view_as(t)


@torch.no_grad()
def inject_grad_noise_large_batch(
    model,
    step,
    batch_size,
    total_samples,
    base_std=1,
    gamma=0.55,
    lp_kernel: int = 3
):
    # Large-batch SGD noise recovery term
    gns_scale = (1.0 / batch_size - 1.0 / total_samples) ** 0.5
    decay = 1.0 / ((1 + step) ** gamma)

    for p in model.parameters():
        if p.grad is None:
            continue

        # Skip bias / norm (critical for ViT & BN stability)
        if p.ndim == 1:
            continue


        grad_std = p.grad.std().clamp(min=1e-6)

        noise = torch.randn_like(p.grad) * base_std * gns_scale * grad_std * decay

        if lp_kernel is not None and lp_kernel > 1:
            noise = lowpass_param_2d(noise, k=lp_kernel)

        p.grad.add_(noise)



In [6]:
torch.cuda.empty_cache()

G_GNS = []
G_SNR = []


def gradient_noise_scale(grad):
    tr_sigma = grad.var(0, unbiased=False).sum()
    return tr_sigma / (grad.mean(0).pow(2).sum() + 1e-12)

def signal_to_noise_ratio(grad):
    mu = grad.mean(0)
    tr_sigma = grad.var(0, unbiased=False).sum()
    return mu.norm() / (tr_sigma.sqrt() + 1e-12)


torch.cuda.synchronize()
total_timer_start.record()
for i in range(num_epochs):
    train_logits = []
    train_y = []
    

    val_logits = []
    val_y = []
    partile_time = 0
    controller.traced_model.train()
    # model.train()
    # aot_model.train()
    G = []
    torch.cuda.synchronize()
    # k = i % total_graphs
    for step, (x, y) in enumerate(train_loader):
        
        # torch.cuda.current_stream().wait_stream(warmup)  
        x, y = x.to('cuda', non_blocking=True), y.to('cuda', non_blocking=True)
        # x = transform_train(x)
        
        e_timer_start.record()
        # compute_stream.wait_stream(torch.cuda.current_stream())
        # with torch.cuda.stream(compute_stream):
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                # static_x.copy_(x.to('cuda'))
                # static_y.copy_(y.to('cuda'))
                # opt.zero_grad(set_to_none=False)

                # total_steps = i * len(train_loader) + step
                # k = step % total_graphs
                # add_smoothout_noise(controller.traced_model)
                # assign_theta(controller.traced_model)
                # graphs[0].replay()
                # remove_smoothout_noise(controller.traced_model)
                # opt.step()
                # scheduler.step()   
                
                # train_logits.append(logits.detach().cpu())
                # train_y.append(static_y.detach().cpu())  

                # for m in controller.traced_model.modules():
                #     if hasattr(m, "qdrop"):
                #         temp = torch.rand(1).item()
                #         m.qdrop.copy_(temp)

            opt.zero_grad()
            logits = controller.traced_model(x)
            # logits = model(x)
            loss = criterion(logits, y)
            # inject_grad_noise_large_batch(controller.traced_model, step, batch_size, len(train_loader))
            # inject_grad_noise_large_batch(model, step, batch_size, len(train_loader))
            loss.backward()
            opt.step()
            # scheduler.step()      
            
            # if step % 20 == 0:
            #     temp = torch.nn.utils.parameters_to_vector(p.reshape(-1) for p in controller.traced_model.parameters() if p.grad is not None)
            #     # temp = torch.nn.utils.parameters_to_vector(p.reshape(-1) for p in model.parameters() if p.grad is not None)
            #     G.append(temp)     
                
            e_timer_end.record()
            torch.cuda.synchronize()
            partile_time += e_timer_start.elapsed_time(e_timer_end)

            train_logits.append(logits.detach().cpu())
            train_y.append(y.detach().cpu())

                # print("train logits mean:", logits.mean().item())
                # print("train logits std:", logits.std().item())

    # partile_time = e_timer_start.elapsed_time(e_timer_end)
    # total_time += e_timer_start.elapsed_time(e_timer_end)

    train_logits = torch.cat(train_logits)
    train_y = torch.cat(train_y)
    computed_acc = acc(train_logits, train_y)
    throughtout = len(train_dataset) / (partile_time / 1000)

    # temp_G = torch.cat(G)
    # GNS = gradient_noise_scale(temp_G)
    # SNR = signal_to_noise_ratio(temp_G)

    print(f'Epoch: {i}')
    print(f"Train Loss: {loss}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]}")
    print(f"Train Accuracy: {computed_acc}")
    print(f'Peak Mem Reserved: {torch.cuda.max_memory_reserved()}')
    print(f'Peak Mem Allocated: {torch.cuda.max_memory_allocated()}')
    print(f'Current train time: {partile_time / 1000} s')
    print(f"Throughout: {throughtout} samples per second")
    # print(f'GNS: {GNS}')
    # print(f'SNR: {SNR}')
    # G_GNS.append(GNS)
    # G_SNR.append(SNR)
    scheduler.step()

    train_logits = []
    train_y = []

    controller.traced_model.eval()
    # model.eval()
    # aot_model.eval()
    with torch.compiler.set_stance("force_eager"):
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to('cuda'), y_val.to('cuda')

                with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                    y_preds = controller.traced_model(x_val)
                    # y_preds = model(x_val)
                val_loss = F.cross_entropy(y_preds, y_val)

                val_logits.append(y_preds.detach().cpu())
                val_y.append(y_val.detach().cpu())

    val_logits = torch.cat(val_logits)
    val_y = torch.cat(val_y)

    computed_acc = acc(val_logits, val_y)
    print(f"Val Loss: {val_loss}")
    print(f"Val Accuracy: {computed_acc}\n\n")


total_timer_end.record()
torch.cuda.synchronize()


full_time = total_timer_start.elapsed_time(total_timer_end)

Epoch: 0
Train Loss: 2.0698013305664062
Learning Rate: 0.00030000000000000003
Train Accuracy: 0.15378844738006592
Peak Mem Reserved: 33395048448
Peak Mem Allocated: 32231807488
Current train time: 26.066883728027342 s
Throughout: 1918.1425950904734 samples per second




Val Loss: 2.296875
Val Accuracy: 0.23899999260902405


Epoch: 1
Train Loss: 2.0035934448242188
Learning Rate: 0.000435
Train Accuracy: 0.221540629863739
Peak Mem Reserved: 33416019968
Peak Mem Allocated: 32241442816
Current train time: 25.03799557495117 s
Throughout: 1996.9649667172891 samples per second
Val Loss: 2.109375
Val Accuracy: 0.31049999594688416


Epoch: 2
Train Loss: 1.8642311096191406
Learning Rate: 0.00057
Train Accuracy: 0.2683746814727783
Peak Mem Reserved: 33416019968
Peak Mem Allocated: 32241442816
Current train time: 24.93927995300293 s
Throughout: 2004.869430642063 samples per second
Val Loss: 1.8046875
Val Accuracy: 0.35919997096061707


Epoch: 3
Train Loss: 1.8989295959472656
Learning Rate: 0.000705
Train Accuracy: 0.2811632454395294
Peak Mem Reserved: 33416019968
Peak Mem Allocated: 32241442816
Current train time: 24.937581253051757 s
Throughout: 2005.0059984819582 samples per second
Val Loss: 1.71875
Val Accuracy: 0.39480000734329224


Epoch: 4
Train Loss: 1.860

KeyboardInterrupt: 

In [7]:
controller.meta

{'group_size': 256,
 'fp8': False,
 'patch_embed.proj': {'bits': 2,
  'group_size': 512,
  'act_padding': False,
  'N': 512,
  'DIVISION': {'pool_kernel_size': 3},
  'pack_only': False},
 'batch_size': 512,
 'blocks.0.attn.qkv': {'bits': 2,
  'group_size': 256,
  'act_padding': False,
  'N': 512,
  'DIVISION': None,
  'pack_only': False},
 'blocks.0.attn.proj': {'bits': 2,
  'group_size': 256,
  'act_padding': False,
  'N': 512,
  'DIVISION': None,
  'pack_only': False},
 'blocks.0.mlp.fc1': {'bits': 2,
  'group_size': 256,
  'act_padding': False,
  'N': 512,
  'DIVISION': None,
  'pack_only': False},
 'blocks.0.mlp.fc2': {'bits': 2,
  'group_size': 512,
  'act_padding': False,
  'N': 512,
  'DIVISION': {'pool_kernel_size': 3},
  'pack_only': False},
 'blocks.1.attn.qkv': {'bits': 2,
  'group_size': 256,
  'act_padding': False,
  'N': 512,
  'DIVISION': None,
  'pack_only': False},
 'blocks.1.attn.proj': {'bits': 2,
  'group_size': 256,
  'act_padding': False,
  'N': 512,
  'DIVISION':

In [None]:
x, y = next(iter(train_loader))
x, y = x.to('cuda'), y.to('cuda')

In [None]:
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA
    ],
    record_shapes=True,
    with_stack=True,
    profile_memory=True,
) as prof:

    for step in range(10):
        with record_function("train_step"):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                output = controller.traced_model(x)
                loss = criterion(output, y)
                loss.backward()

torch.cuda.synchronize()

print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=200
))
