In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    # "backend:cudaMallocAsync,"
    "expandable_segments:True,"
    # "garbage_collection_threshold:0.6"
)
os.environ["TORCH_HOME"] = "/home/hice1/yyu496/scratch/torch_cache"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import (
    LinearLR,
    CosineAnnealingLR,
    SequentialLR
)

import torchvision.models as models
import torchvision.transforms.v2 as v2
from torchvision.datasets import ImageFolder
from torchvision import datasets

from torchmetrics import Accuracy
import timm

from freq_utils import radial_spectrum_2d, temp_kl, log_huber_loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 512
warmup_epochs = 10
num_epochs = 512



# ============= Data ==================
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

transform_train = v2.Compose([
    # --- Convert to tensor FIRST (kills PIL early) ---
    v2.ToImage(),                          # handles PIL â†’ Tensor safely
    v2.ToDtype(torch.float32, scale=True),

    v2.Resize((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomApply([
        v2.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )
    ], p=0.5),
    v2.RandAugment(num_ops=5, magnitude=9),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    v2.RandomErasing(p=0.5, scale=(0.02, 0.33),
                     ratio=(0.3, 3.3), value=0),
])



val_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),

    v2.Resize((224, 224), antialias=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = datasets.CIFAR100(root="./data", train=True,
                                 download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=15, pin_memory=True, drop_last=True,
                          persistent_workers=True, prefetch_factor=3)



val_dataset = datasets.CIFAR100(
    root="./data",
    train=False,
    download=True,
    transform=val_transform
)

val_loader = DataLoader(
    val_dataset,
    batch_size=512,
    shuffle=False,   
    num_workers=15,
    pin_memory=True,
    persistent_workers=True, 
    prefetch_factor=3
)

In [3]:
act_cache = {}
def forward_hook(name):
    def hook(module, input, output):
        global act_cache
        act_cache[name] = output
    return hook

In [7]:
num_classes = 100


model = models.efficientnet_b0(weights=None)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

teacher_model = timm.create_model('efficientnet_b0.ra_in1k', pretrained=True, num_classes=num_classes)
teacher_model.load_state_dict(torch.load("/home/hice1/yyu496/scratch/Model_Checkpoint/efficientnet_b0_bs_512_teacher_cifar100.pth"))

model.cuda()
teacher_model.cuda()


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4 * 8, fused=True, capturable=True)
# optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=1e-4 * 8, fused=True, capturable=True)

warmup_scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=warmup_epochs)
main_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs)
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_epochs])

for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Conv1d)):
        if 'features.8.0' in name or 'features.1.0.block.2.0' in name:
            m = module.register_forward_hook(forward_hook(name))
            print(module)
        # print(name)

for name, module in teacher_model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Conv1d)):
        if 'conv_head' in name or 'blocks.0.0.conv_pw' in name:
            m = module.register_forward_hook(forward_hook(name))
            print(module)
        # print(name)

# features.1.0.block.2.0
# blocks.0.0.conv_pw


Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [8]:
acc = Accuracy(task='multiclass', num_classes=num_classes, average='macro')

total_timer_start = torch.cuda.Event(enable_timing=True)
total_timer_end = torch.cuda.Event(enable_timing=True)
total_time = 0.0

e_timer_start = torch.cuda.Event(enable_timing=True)
e_timer_end = torch.cuda.Event(enable_timing=True)
partile_time = 0.0

In [9]:
def checkpoint_save_helper(model, acc, max_acc, opt):
    if acc >= max_acc:
        max_acc = acc
        torch.save(model.state_dict(), "/home/hice1/yyu496/scratch/Model_Checkpoint/efficientnet_b0_bs_512_student_cifar100.pth")
        torch.save(opt.state_dict(), '/home/hice1/yyu496/scratch/Model_Checkpoint/efficientnet_b0_bs_512_opt_student_cifar100_.pth')
        return max_acc
    return max_acc

In [10]:
x, y = next(iter(train_loader))
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

with torch.no_grad():
    teacher_logits = teacher_model(x)

teacher_act_spectrum1 = radial_spectrum_2d(act_cache['conv_head'], raw=False, num_rad_bins=7, mode='magnitude')
teacher_act_spectrum2 = radial_spectrum_2d(act_cache['blocks.0.0.conv_pw'], raw=False, num_rad_bins=7, mode='magnitude')


In [None]:
max_acc = 0.0
global_step = 0
torch.cuda.synchronize()
total_timer_start.record()
for i in range(num_epochs):
    train_logits = []
    train_y = []
    

    val_logits = []
    val_y = []
    partile_time = 0
    # controller.traced_model.train()
    model.train()
    # teacher_model.train()
    G = []
    torch.cuda.synchronize()
    for step, (x, y) in enumerate(train_loader):
        
        # torch.cuda.current_stream().wait_stream(warmup)  
        x, y = x.to('cuda', non_blocking=True), y.to('cuda', non_blocking=True)
        
        e_timer_start.record()
        # compute_stream.wait_stream(torch.cuda.current_stream())
        # with torch.cuda.stream(compute_stream):
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                # static_x.copy_(x)
                # static_y.copy_(y)
                # opt.zero_grad(set_to_none=False)

                # graphs[0].replay()
                
                # train_logits.append(logits.detach().cpu())
                # train_y.append(static_y.detach().cpu())  

                # for m in controller.traced_model.modules():
                #     if hasattr(m, "qdrop"):
                #         temp = torch.rand(1).item()
                #         m.qdrop.copy_(temp)

            optimizer.zero_grad()
            # raw_logits = controller.traced_model(x)
            raw_logits = model(x)
            # raw_logits = teacher_model(x)

            if isinstance(raw_logits, tuple):
                logits = raw_logits[0]
                freq_maps = raw_logits[1]
            else:
                logits = raw_logits
                freq_maps = None

            CE_loss = criterion(logits, y)

            # 'features.8.0' in name or 'features.1.0.block.2.0'
            # 'features.8.0' in name or "features.6.2.block.0.0" in name or "features.5.1.block.0.0"
            model_spectrum1 = radial_spectrum_2d(act_cache['features.8.0'], raw=False, num_rad_bins=7, mode='magnitude')
            model_spectrum2 = radial_spectrum_2d(act_cache['features.1.0.block.2.0'], raw=False, num_rad_bins=7, mode='magnitude')
            # model_spectrum1 = radial_spectrum_2d(freq_maps["freq_token"], num_rad_bins=16, num_ang_bins=8)
            # model_spectrum2 = radial_spectrum_2d(freq_maps["patch_token"], num_rad_bins=16, num_ang_bins=8)
            # model_spectrum1 = radial_spectrum_2d(freq_maps, num_bins=8)
            # model_spectrum2 = radial_spectrum_2d(act_cache['features.1.0.block.2.0'], num_bins=32)
            # model_spectrum3 = radial_spectrum_2d(act_cache['features.5.1.block.0.0'], num_bins=16)
            # target_spectrum = radial_spectrum_2d(x, num_rad_bins=7, num_ang_bins=1, mode='magnitude')
            # target_spectrum2 = radial_spectrum_2d(x, num_rad_bins=32, num_ang_bins=32)
            # model_spectrum = radial_spectrum_2d_channelwise_dct(act, num_bins=16)
            # target_spectrum = radial_spectrum_2d_channelwise_dct(x, num_bins=16).detach()

            # freq_loss1 = wasserstein(model_spectrum1, target_spectrum)
            freq_loss1 = log_huber_loss(model_spectrum1, teacher_act_spectrum1, delta=0.1)
            freq_loss2 = log_huber_loss(model_spectrum2, teacher_act_spectrum2, delta=0.1)
            # freq_loss1 = temp_kl(teacher_act_spectrum, model_spectrum1)
            # freq_loss2 = wasserstein_1d(model_spectrum2, target_spectrum)
            # freq_loss3 = wasserstein_1d(model_spectrum3, target_spectrum)
            # freq_loss1 = kl_div_spectrum(target_spectrum, model_spectrum1)
            # freq_loss = temp_kl(target_spectrum, model_spectrum)
            # freq_loss = vit_spectral_slope_loss(act)
            # freq_loss = local_freq_kl_loss(act, x)
            # freq_loss = act_freq_loss(act, x, num_bins=8)
            # freq_loss1 = freq_maps["anchor_loss"]
            # freq_loss1 = freq_maps.get("anchor_loss",0) + freq_maps.get("usage_reg",0)

            # freq_loss = polar_wasserstein_rtheta(model_spectrum, target_spectrum)
            
            # global_step += 1
            # w = freq_weight(global_step, warmup_steps=100, max_w=1.0)

            # cache = PolarBinCache()

            # polar_s = polar_spectrum_2d(
            #     act, num_bins=32, num_theta_bins=64, remove_dc=True, power=True, bin_cache=cache
            # )
            # polar_t = polar_spectrum_2d(
            #     x, num_bins=32, num_theta_bins=64, remove_dc=True, power=True, bin_cache=cache
            # )

            # freq_loss = polar_kl_loss(polar_s, polar_t, T=1.5, k_low=12)       # strong + fast
            # freq_loss = lf_hf_logratio_loss(polar_s, polar_t, k_low=12)

            w = 1.0
            # w = w * (CE_loss.detach() / (freq_loss.detach() + 1e-8))  # keep terms balanced
            loss = CE_loss + w * freq_loss1 + w * freq_loss2
            # freq_loss.backward()
            # loss = CE_loss
            # loss = freq_loss1
            

            # debug_check_grads_and_step(controller.traced_model, opt)
            # inject_grad_noise_large_batch(controller.traced_model, step, batch_size, len(train_loader))
            # inject_grad_noise_large_batch(model, step, batch_size, len(train_loader))
            loss.backward()
            optimizer.step()
            # scheduler.step()     
            
            # if step % 20 == 0:
            #     temp = torch.nn.utils.parameters_to_vector(p.reshape(-1) for p in controller.traced_model.parameters() if p.grad is not None)
            #     # temp = torch.nn.utils.parameters_to_vector(p.reshape(-1) for p in model.parameters() if p.grad is not None)
            #     G.append(temp)     
                
            e_timer_end.record()
            torch.cuda.synchronize()
            partile_time += e_timer_start.elapsed_time(e_timer_end)

            train_logits.append(logits.detach().cpu())
            train_y.append(y.detach().cpu())

                # print("train logits mean:", logits.mean().item())
                # print("train logits std:", logits.std().item())

    # partile_time = e_timer_start.elapsed_time(e_timer_end)
    # total_time += e_timer_start.elapsed_time(e_timer_end)


    train_logits = torch.cat(train_logits)
    train_y = torch.cat(train_y)
    computed_acc = acc(train_logits, train_y)
    throughtout = 50000 / (partile_time / 1000)

    # temp_G = torch.cat(G)
    # GNS = gradient_noise_scale(temp_G)
    # SNR = signal_to_noise_ratio(temp_G)

    print(f'Epoch: {i}')
    print(f"Train Loss: {loss}")
    print(f"CE Loss: {CE_loss.item()}, Freq(freq) Loss1: {freq_loss1.item() * w}, Freq(freq) Loss2: {freq_loss2.item() * w} ")
    # print(f"CE Loss: {CE_loss.item()}")
    print(f"Model ACT Freq Bins: {model_spectrum1}")
    print(f"Target Freq Bins: {teacher_act_spectrum1}")
    print(f"Model ACT Freq Bins2: {model_spectrum2}")
    print(f"Target Freq Bins2: {teacher_act_spectrum2}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]}")
    print(f"Train Accuracy: {computed_acc}")
    print(f'Peak Mem Reserved: {torch.cuda.max_memory_reserved()}')
    print(f'Peak Mem Allocated: {torch.cuda.max_memory_allocated()}')
    print(f'Current train time: {partile_time / 1000} s')
    print(f"Throughout: {throughtout} samples per second")
    # for name, m in controller.traced_model.named_modules():
    #     if isinstance(m, DOBatchNormReLU2d):
    #         if not torch.isfinite(m.running_mean).all() or not torch.isfinite(m.running_var).all():
    #             print("BAD BN:", name)
    #             print("mean finite:", torch.isfinite(m.running_mean).all().item(),
    #                 "var finite:", torch.isfinite(m.running_var).all().item(),
    #                 "var min:", m.running_var.min().item())
    #             break
    # print(f'GNS: {GNS}')
    # print(f'SNR: {SNR}')
    # G_GNS.append(GNS)
    # G_SNR.append(SNR)
    scheduler.step()
    # update_lr(controller.traced_model, scheduler)

    train_logits = []
    train_y = []

    # controller.traced_model.eval()
    model.eval()
    # teacher_model.eval()
    with torch.compiler.set_stance("force_eager"):
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to('cuda'), y_val.to('cuda')

                with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                    # raw_y_preds  = controller.traced_model(x_val)
                    raw_y_preds = model(x_val)
                    # raw_y_preds = teacher_model(x_val)

                    if isinstance(raw_y_preds, tuple):
                        y_preds = raw_y_preds[0]
                    else:
                        y_preds = raw_y_preds

                val_loss = F.cross_entropy(y_preds, y_val)

                val_logits.append(y_preds.detach().cpu())
                val_y.append(y_val.detach().cpu())

    val_logits = torch.cat(val_logits)
    val_y = torch.cat(val_y)
    

    valid_computed_acc = acc(val_logits, val_y)
    max_acc = checkpoint_save_helper(model, valid_computed_acc, max_acc, optimizer)
    # max_acc = checkpoint_save_helper(teacher_model, valid_computed_acc, max_acc, optimizer)
    print(f"Val Loss: {val_loss}")
    print(f"Val Accuracy: {valid_computed_acc}\n\n")


total_timer_end.record()
torch.cuda.synchronize()


full_time = total_timer_start.elapsed_time(total_timer_end)

Epoch: 0
Train Loss: 4.46845817565918
CE Loss: 4.4681243896484375, Freq(freq) Loss1: 0.0002053702191915363, Freq(freq) Loss2: 0.00012830065679736435 
Model ACT Freq Bins: tensor([9.1713e-01, 6.9896e-02, 9.9915e-03, 6.0615e-04, 1.8983e-03, 3.2765e-04,
        1.4551e-04], device='cuda:0', grad_fn=<ViewBackward0>)
Target Freq Bins: tensor([8.4898e-01, 9.8057e-02, 3.9212e-02, 9.0733e-03, 4.2394e-03, 3.1231e-04,
        1.2754e-04], device='cuda:0')
Model ACT Freq Bins2: tensor([9.3732e-01, 2.7976e-02, 1.5269e-02, 1.3102e-02, 5.7089e-03, 6.1556e-04,
        3.3476e-06], device='cuda:0', grad_fn=<ViewBackward0>)
Target Freq Bins2: tensor([9.9367e-01, 2.4138e-03, 1.7728e-03, 1.3765e-03, 6.5403e-04, 1.1379e-04,
        8.0846e-07], device='cuda:0')
Learning Rate: 0.0004
Train Accuracy: 0.015784112736582756
Peak Mem Reserved: 25943867392
Peak Mem Allocated: 25752589312
Current train time: 27.74788525390625 s
Throughout: 1801.939122296218 samples per second
Val Loss: 4.46875
Val Accuracy: 0.025



Val Loss: 1.78125
Val Accuracy: 0.5157000422477722


Epoch: 10
Train Loss: 2.1031758785247803
CE Loss: 2.102389097213745, Freq(freq) Loss1: 3.629889033618383e-05, Freq(freq) Loss2: 0.0007506452966481447 
Model ACT Freq Bins: tensor([8.5551e-01, 1.1466e-01, 2.3856e-02, 2.0590e-03, 3.7270e-03, 1.4750e-04,
        3.9772e-05], device='cuda:0', grad_fn=<ViewBackward0>)
Target Freq Bins: tensor([8.4898e-01, 9.8057e-02, 3.9212e-02, 9.0733e-03, 4.2394e-03, 3.1231e-04,
        1.2754e-04], device='cuda:0')
Model ACT Freq Bins2: tensor([8.6797e-01, 7.7040e-02, 3.1310e-02, 1.6893e-02, 6.2042e-03, 5.8184e-04,
        3.3855e-06], device='cuda:0', grad_fn=<ViewBackward0>)
Target Freq Bins2: tensor([9.9367e-01, 2.4138e-03, 1.7728e-03, 1.3765e-03, 6.5403e-04, 1.1379e-04,
        8.0846e-07], device='cuda:0')
Learning Rate: 0.0008
Train Accuracy: 0.4246404767036438
Peak Mem Reserved: 26132611072
Peak Mem Allocated: 25916421632
Current train time: 26.88254458618164 s
Throughout: 1859.9429767411734 sam