In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms, models
from datetime import datetime
import os

In [None]:
def accuracy(output, target):
  """Return the fraction of samples whose top-scoring class equals the label.

  Args:
    output: tensor of per-class scores; the class axis is the last dimension.
    target: tensor of integer class labels, comparable element-wise with
      `output.argmax(dim=-1)`.

  Returns:
    A Python float in [0, 1] (NaN when the batch is empty).
  """
  # This is a metric only, so keep autograd out of the computation.
  with torch.no_grad():
    hits = torch.eq(output.argmax(dim=-1), target)
    return hits.to(torch.float32).mean().item()

In [None]:
# Per-channel (R, G, B) mean and standard deviation that get_dataloaders()
# passes to transforms.Normalize for both the train and the test pipeline.
# The values are applied after ToTensor(), i.e. presumably on the [0, 1]
# pixel scale -- confirm if the preprocessing ever changes.
AVERAGE_RED = 0.4914
AVERAGE_GREEN = 0.4822
AVERAGE_BLUE = 0.4465
STDDEV_RED = 0.2470
STDDEV_GREEN = 0.2435
STDDEV_BLUE = 0.2616


In [None]:
def get_dataloaders(data_dir: str, batch_size: int, num_workers: int = 2):
  """Build the CIFAR-10 train and test DataLoaders.

  The training pipeline augments with a padded random 32x32 crop and a random
  horizontal flip; both pipelines then convert to tensors and normalize with
  the module-level channel statistics.

  Args:
    data_dir: directory where CIFAR-10 is stored (downloaded if missing).
    batch_size: number of samples per batch for both loaders.
    num_workers: worker processes used by each loader.

  Returns:
    (train_loader, test_loader); only the training loader shuffles.
  """
  channel_mean = (AVERAGE_RED, AVERAGE_GREEN, AVERAGE_BLUE)
  channel_std = (STDDEV_RED, STDDEV_GREEN, STDDEV_BLUE)

  def normalized_tensor_steps():
    # Fresh transform objects per pipeline, shared recipe.
    return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

  augmentation_steps = [
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
  ]
  train_pipeline = transforms.Compose(augmentation_steps + normalized_tensor_steps())
  test_pipeline = transforms.Compose(normalized_tensor_steps())

  train_set = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_pipeline)
  test_set = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_pipeline)

  # Options common to both loaders; only `shuffle` differs.
  shared_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
  make_loader = torch.utils.data.DataLoader
  return (
      make_loader(dataset=train_set, shuffle=True, **shared_options),
      make_loader(dataset=test_set, shuffle=False, **shared_options),
  )

In [None]:
#Create a ResNet-18 network architecture


def build_model(num_classes:int=10):
  """Create a randomly initialised torchvision ResNet-18 with a new head.

  Args:
    num_classes: number of output units of the replacement `fc` layer.

  Returns:
    The ResNet-18 module (no pretrained weights are loaded).
  """
  net = models.resnet18(weights=None)
  # Swap the stock classifier for one sized to this task, keeping its input width.
  feature_dim = net.fc.in_features
  net.fc = nn.Linear(feature_dim, num_classes)
  return net

In [None]:
from contextlib import nullcontext

def train_epoch(model,loader,device,criterion, optimizer, scaler=None):
  """Run one training pass over `loader` and return (mean loss, mean accuracy).

  Args:
    model: module to train; switched to training mode here.
    loader: iterable of (inputs, labels) batches.
    device: torch.device the batches are moved to.
    criterion: callable `criterion(outputs, labels)` returning a scalar loss.
    optimizer: optimizer over `model`'s parameters.
    scaler: optional gradient scaler. Mixed precision is used only when a
      scaler is supplied AND `device` is a CUDA device.

  Returns:
    (loss, accuracy) averaged per sample: each batch's mean is weighted by
    its batch size before dividing by the total number of samples.

  Raises:
    ZeroDivisionError: if `loader` yields no samples.
  """
  model.train() #put model into training mode
  loss_sum = 0.0
  acc_sum = 0.0
  n = 0

  use_cuda_amp = (device.type == "cuda") and (scaler is not None)
  # torch.autocast is reached through the top-level `torch` import, so this
  # function does not depend on `amp` having been imported by a later cell.
  autocast_ctx = torch.autocast(device_type="cuda", enabled=True) if use_cuda_amp else nullcontext()

  #loop through batches
  for x, y in loader: #x = images, y = labels
    #.to(device) moves the batch to the target device;
    #non_blocking=True lets the copy overlap with compute when memory is pinned
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

    # Clear the previous batch's gradients. Without this, backward() keeps
    # adding into .grad and every step uses the sum of all gradients so far.
    optimizer.zero_grad(set_to_none=True)

    with autocast_ctx:
      out = model(x)
      loss = criterion(out,y)

    if use_cuda_amp:
      # Scale the loss before backward; the scaler then performs the step
      # and adjusts its scale factor.
      scaler.scale(loss).backward()
      scaler.step(optimizer)
      scaler.update()
    else:
      loss.backward()
      optimizer.step()

    bs = y.size(0)
    n += bs
    loss_sum += loss.item() * bs
    acc_sum += accuracy(out, y) * bs

  return loss_sum/n, acc_sum/n

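# train_epoch returns loss_sum/n and acc_sum/n: each batch's mean loss and
# accuracy is weighted by that batch's size before dividing by the total
# sample count n, so the results are true per-sample averages for the epoch
# even when the last batch is smaller than the others.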

In [None]:
@torch.no_grad()
def evaluate(model, loader, device, criterion):
  """Score `model` on every batch of `loader` without tracking gradients.

  Args:
    model: module to evaluate; switched to eval mode here.
    loader: iterable of (inputs, labels) batches.
    device: torch.device the batches are moved to.
    criterion: callable `criterion(outputs, labels)` returning a scalar loss.

  Returns:
    (loss, accuracy), each a per-sample average: batch means are weighted by
    batch size and divided by the total number of samples seen.
  """
  model.eval()
  total_loss, total_acc, seen = 0.0, 0.0, 0

  for images, labels in loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    logits = model(images)
    batch_loss = criterion(logits, labels).item()

    batch_count = labels.size(0)
    seen += batch_count
    total_loss += batch_loss * batch_count
    total_acc += accuracy(logits, labels) * batch_count

  return total_loss / seen, total_acc / seen

In [None]:
def save_model(model, out_dir):
  """Save `model` as a state_dict checkpoint and as a traced TorchScript file.

  Both files are written to `out_dir` (created if needed) and share a
  timestamped base name: `resnet18_cifar10_<YYYYmmdd_HHMMSS>.pth` for the
  state_dict and the same stem with `.ts.pt` for TorchScript.

  The model is traced in eval mode with a random (1, 3, 32, 32) input on the
  model's own device, and its previous train/eval mode is restored
  afterwards, so calling this mid-training does not change the model's mode.

  Args:
    model: module to save; must accept a (1, 3, 32, 32) float tensor.
    out_dir: destination directory.

  Returns:
    (ckpt_path, torchscript_path)
  """
  os.makedirs(out_dir, exist_ok=True)

  # A second-resolution timestamp keeps successive checkpoints from overwriting each other.
  stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  ckpt = os.path.join(out_dir, f"resnet18_cifar10_{stamp}.pth")

  #saves model's state_dict (all learned weights/biases)
  torch.save(model.state_dict(), ckpt)

  #dummy input shaped like one CIFAR-10 image (1 image, 3 channels, 32x32)
  example = torch.randn(1,3,32,32,
                        device=next(model.parameters()).device)

  # Trace in eval mode, then restore whatever mode the caller had.
  was_training = model.training
  model.eval()
  try:
    ts = torch.jit.trace(model, example)
  finally:
    model.train(was_training)

  # Swap only the file extension; a plain str.replace would also rewrite a
  # ".pth" that happens to appear in the directory part of the path.
  ts_path = os.path.splitext(ckpt)[0] + ".ts.pt"
  ts.save(ts_path)
  return ckpt, ts_path



In [None]:
from torch import amp
#Time to train the ResNet-18 on CIFAR-10


# ---- Hyperparameters -------------------------------------------------------
EPOCHS = 60
 #number of full passes over the training set (also the cosine schedule length, T_max)
BATCH_SIZE = 128 #samples per optimizer step (64 or 128 is a common starting point)
LR = 0.1 #initial learning rate (step size for weight updates)
WEIGHT_DECAY = 5e-4 #L2 regularization strength passed to SGD

#use the GPU ('cuda') when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using device: {device}")

train_loader, test_loader = get_dataloaders("/content/data", batch_size = BATCH_SIZE, num_workers = 2)
model = build_model()
model.to(device) #moves all model weights onto that device

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
#cross-entropy loss for multi-class classification, with label smoothing of 0.1

#Stochastic Gradient Descent: instead of computing the gradient over the
#entire dataset, each update uses one randomly drawn mini-batch of
#BATCH_SIZE samples -- that sampling is what makes it "stochastic"
optimizer = optim.SGD(model.parameters(),lr=LR,momentum =0.9,weight_decay=WEIGHT_DECAY)
#model.parameters() = all the numbers the model is allowed to change during training

scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)
#cosine schedule: the learning rate decays smoothly from LR over EPOCHS epochs

#Other Schedules:
  # - StepLR: drops LR by a factor every N epochs (piecewise constant)
  # - ExponentialLR: decays LR exponentially
  # - ReduceLROnPlateau: lowers LR when validation accuracy stops improving
  # - CosineAnnealingLR: smooth,gradual, widely used for ResNets on CIFAR/ImageNet

#gradient scaler for mixed precision; created disabled when running on CPU
scaler = amp.GradScaler(device='cuda', enabled=(device.type == "cuda"))


best_accuracy = 0.0
best_paths = (None, None)
for epoch in range(1, EPOCHS+1):
  tr_loss, tr_acc = train_epoch(model, train_loader, device,
                                    criterion, optimizer, scaler)
  #one full training pass; returns per-sample average loss/accuracy
  va_loss, va_acc = evaluate(model, test_loader, device, criterion)
  #NOTE(review): "validation" here is the CIFAR-10 test split (test_loader), so
  #best_accuracy is selected on test data and is an optimistic estimate
  scheduler.step()
  #advance the cosine schedule once per epoch


  #save a new timestamped checkpoint pair whenever validation accuracy improves;
  #earlier checkpoints are kept on disk, only best_paths is overwritten
  if va_acc > best_accuracy:
    best_accuracy = va_acc
    ckpt, ts_path = save_model(model, "/content/artifacts")  # writes the .pth state_dict and the TorchScript export
    best_paths = (ckpt, ts_path)
    print("Saved BEST:", ckpt)

  #formatted progress log
  print(f"Epoch {epoch:03d} | train_loss {tr_loss:.4f} acc {tr_acc:.4f} | val_loss {va_loss:.4f} acc {va_acc:.4f} | best {best_accuracy:.4f}")


print("Best val_acc:", best_accuracy)
print("Best artifacts:", best_paths)

using device: cuda


100%|██████████| 170M/170M [00:04<00:00, 39.3MB/s]


Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041418.pth
Epoch 001 | train_loss 2.2011 acc 0.2919 | val_loss 1.8362 acc 0.4067 | best 0.4067
Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041442.pth
Epoch 002 | train_loss 1.6982 acc 0.4406 | val_loss 1.5922 acc 0.4892 | best 0.4892
Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041504.pth
Epoch 003 | train_loss 1.5261 acc 0.5254 | val_loss 1.3748 acc 0.6015 | best 0.6015
Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041527.pth
Epoch 004 | train_loss 1.3949 acc 0.5941 | val_loss 1.2924 acc 0.6411 | best 0.6411
Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041550.pth
Epoch 005 | train_loss 1.3210 acc 0.6304 | val_loss 1.2534 acc 0.6582 | best 0.6582
Epoch 006 | train_loss 1.2648 acc 0.6561 | val_loss 1.2758 acc 0.6564 | best 0.6582
Saved BEST: /content/artifacts/resnet18_cifar10_20260115_041637.pth
Epoch 007 | train_loss 1.2195 acc 0.6793 | val_loss 1.2303 acc 0.6742 | best 0.6742
Save

In [None]:
import io, copy, time
import torch.nn.utils.prune as prune

# Use this for pruning fine-tune and training too

def train_epoch_fixed(model, loader, device, criterion, optimizer, scaler=None):
    """Train `model` for one epoch, zeroing gradients before every batch.

    Args:
        model: module to train; switched to training mode here.
        loader: iterable of (inputs, labels) batches.
        device: torch.device the batches are moved to.
        criterion: callable `criterion(outputs, labels)` returning a scalar loss.
        optimizer: optimizer over `model`'s parameters.
        scaler: optional gradient scaler; mixed precision is used only when
            it is given and `device` is a CUDA device.

    Returns:
        (loss, accuracy) as per-sample averages over the epoch.
    """
    model.train()

    mixed_precision = (device.type == "cuda") and (scaler is not None)
    if mixed_precision:
        forward_ctx = amp.autocast(device_type="cuda", enabled=True)
    else:
        forward_ctx = nullcontext()

    running_loss = 0.0
    running_acc = 0.0
    seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Start each step from clean gradients.
        optimizer.zero_grad(set_to_none=True)
        with forward_ctx:
            logits = model(inputs)
            loss = criterion(logits, targets)

        if not mixed_precision:
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        batch_count = targets.size(0)
        seen += batch_count
        running_loss += loss.item() * batch_count
        # relies on the notebook's accuracy(output, target) helper
        running_acc += accuracy(logits, targets) * batch_count

    return running_loss / seen, running_acc / seen


# -----------------------------
# Utilities: size + latency
# -----------------------------
def model_size_mb_via_torchsave(model) -> float:
    """Return the size, in MiB, of `model.state_dict()` as written by torch.save.

    The state_dict is serialized into an in-memory buffer, so nothing is
    written to disk.
    """
    bytes_per_mib = 1024 ** 2
    sink = io.BytesIO()
    torch.save(model.state_dict(), sink)
    serialized_bytes = sink.getbuffer().nbytes
    return serialized_bytes / bytes_per_mib

@torch.inference_mode()
def benchmark_cpu_latency_ms(model, iters=200, warmup=50, batch_size=1, num_threads=None):
    """Measure the mean CPU forward-pass latency of `model` in milliseconds.

    The model is put in eval mode and moved to the CPU (in place), then fed a
    random (batch_size, 3, 32, 32) input: `warmup` untimed passes followed by
    `iters` timed passes.

    Args:
        model: module accepting a (batch_size, 3, 32, 32) float tensor.
        iters: number of timed forward passes.
        warmup: number of untimed forward passes run first.
        batch_size: batch dimension of the random input.
        num_threads: if given, the intra-op thread count used for the
            benchmark. The previous thread count is restored on exit so the
            setting does not leak into the rest of the process.

    Returns:
        Average wall-clock time per timed forward pass, in milliseconds.

    Raises:
        ZeroDivisionError: if `iters` is 0.
    """
    model = model.eval().cpu()

    # torch.set_num_threads is process-wide; remember the old value so a
    # single-thread benchmark does not slow down everything that runs after it.
    previous_threads = torch.get_num_threads()
    if num_threads is not None:
        torch.set_num_threads(int(num_threads))

    try:
        x = torch.randn(batch_size, 3, 32, 32)  # CIFAR-10 shape

        # warmup: let one-time setup costs settle before timing
        for _ in range(warmup):
            model(x)

        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        t1 = time.perf_counter()
    finally:
        if num_threads is not None:
            torch.set_num_threads(previous_threads)

    return (t1 - t0) * 1000.0 / iters


# Load best FP32 checkpoint

# Path of the FP32 checkpoint to compress.
# NOTE(review): this is hard-coded to one specific run's timestamp; it has to
# be updated by hand after retraining (the training cell records the best
# checkpoint path in best_paths[0]).
BEST_CKPT_PATH = '/content/artifacts/resnet18_cifar10_20260115_042156.pth'

# Fine-tuning runs on the GPU when available; quantized inference and the
# latency benchmarks run on the CPU.
device_train = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")

# Rebuild the loaders with the notebook's helper (rebinds train_loader/test_loader)
train_loader, test_loader = get_dataloaders("/content/data", batch_size=128, num_workers=2)

# Rebuild the architecture, load the saved weights on the CPU, then move the
# model to the training device.
# NOTE(review): torch.load unpickles the file; only load checkpoints you produced.
model_fp32 = build_model()
model_fp32.load_state_dict(torch.load(BEST_CKPT_PATH, map_location="cpu"))
model_fp32 = model_fp32.to(device_train)

# Same loss as training so the reported loss values are comparable
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Baseline eval (on CPU for apples-to-apples with PTQ); the deep copy keeps
# model_fp32 itself on device_train for the pruning section.
model_fp32_cpu = copy.deepcopy(model_fp32).eval().cpu()
base_loss, base_acc = evaluate(model_fp32_cpu, test_loader, device_cpu, criterion)

# Serialized state_dict size and single-thread, batch-of-1 latency of the baseline
base_size = model_size_mb_via_torchsave(model_fp32_cpu)
base_lat = benchmark_cpu_latency_ms(model_fp32_cpu, iters=200, warmup=50, batch_size=1, num_threads=1)

print(f"[BASE FP32] acc={base_acc:.4f} loss={base_loss:.4f} size={base_size:.1f}MB latency={base_lat:.2f}ms (1 thread)")


# A) PTQ INT8 (FX graph mode) + accuracy + size + latency
# Notes:
# - Static PTQ runs on CPU.
# - For best results, calibrate on a few thousand real images (no grads).

# Quantization imports (handles minor API moves)
try:
    from torch.ao.quantization import QConfigMapping, get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
except Exception:
    from torch.ao.quantization import QConfigMapping, get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Pick quantized engine
# Prefer fbgemm, then qnnpack, otherwise fall back to the first engine this
# PyTorch build reports as supported.
supported = torch.backends.quantized.supported_engines
engine = "fbgemm" if "fbgemm" in supported else ("qnnpack" if "qnnpack" in supported else supported[0])
torch.backends.quantized.engine = engine
print("Quant engine:", engine)

# Quantize a copy so the FP32 CPU baseline stays untouched
model_to_quantize = copy.deepcopy(model_fp32_cpu).eval()

# QConfigMapping for PTQ: apply the engine's default qconfig to the whole model
qconfig = get_default_qconfig(engine)
qconfig_mapping = QConfigMapping().set_global(qconfig)

# One CIFAR-10-shaped dummy input, passed to prepare_fx as its example input
example_inputs = (torch.randn(1, 3, 32, 32),)

# `prepared` is the model readied for calibration; it must be run on real
# data (see calibrate) before convert_fx is called.
prepared = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

@torch.inference_mode()
def calibrate(prepared_model, loader, num_batches=100):
    """Run up to `num_batches` batches from `loader` through `prepared_model`.

    The forward passes are executed on the CPU in eval mode without gradient
    tracking; outputs are discarded, since the point is only to let the
    prepared model see representative inputs.

    Args:
        prepared_model: model to feed calibration data to.
        loader: iterable of (inputs, labels) batches; labels are ignored.
        num_batches: maximum number of batches to run. Zero or a negative
            value runs none. If the loader is shorter, all of it is used.
    """
    from itertools import islice

    prepared_model.eval()
    # islice stops before pulling batch number `limit + 1` from the loader and,
    # unlike a post-forward break, runs nothing at all when the limit is 0.
    limit = max(int(num_batches), 0)
    for x, _ in islice(loader, limit):
        prepared_model(x.to("cpu"))

# Calibrate on TRAINING images (100 batches x 128 = ~12.8k images), not on
# test_loader: the INT8 accuracy below is measured on the test split, so
# calibrating on that same split would tune the quantization ranges on the
# evaluation data and make the reported number optimistic.
calibrate(prepared, train_loader, num_batches=100)

# Convert the calibrated model to its quantized INT8 form (CPU only)
model_int8 = convert_fx(prepared).eval().cpu()

# Same three measurements as the FP32 baseline: test accuracy/loss,
# serialized size, and single-thread batch-of-1 latency.
int8_loss, int8_acc = evaluate(model_int8, test_loader, device_cpu, criterion)
int8_size = model_size_mb_via_torchsave(model_int8)
int8_lat = benchmark_cpu_latency_ms(model_int8, iters=200, warmup=50, batch_size=1, num_threads=1)

print(f"[PTQ INT8 FX] acc={int8_acc:.4f} loss={int8_loss:.4f} size={int8_size:.1f}MB latency={int8_lat:.2f}ms (1 thread)")
print(f"  Size shrink: {base_size/int8_size:.2f}x  |  Latency speedup: {base_lat/int8_lat:.2f}x  |  Acc drop: {(base_acc-int8_acc)*100:.2f} pts")

# Optional: save artifacts (state_dict of the quantized model)
os.makedirs("/content/artifacts_ptq", exist_ok=True)
torch.save(model_int8.state_dict(), "/content/artifacts_ptq/resnet18_cifar10_int8_fx_state_dict.pth")


# B) Structured pruning + fine-tune
# NOTE: torch.nn.utils.prune creates masks (structured sparsity).
# "Effective nonzero params" drop ~30%, but tensor shapes do NOT shrink unless you do channel-slimming surgery.

def apply_structured_pruning_resnet(model, amount=0.30, prune_linear=False):
    """Apply L2-norm structured pruning to the output channels of every Conv2d.

    For each `nn.Conv2d` in `model`, the `amount` fraction of output channels
    (dim=0) with the smallest L2 norm is masked via `torch.nn.utils.prune`.

    `nn.Linear` layers are left dense by default. Each row of a
    classification head's weight is one class, so pruning rows along dim=0
    zeroes whole classes (3 of 10 at amount=0.30), and because the mask is
    fixed, fine-tuning cannot bring those rows back.

    Args:
        model: module to prune in place.
        amount: fraction (or absolute number) of output channels to prune per
            layer, as accepted by `prune.ln_structured`.
        prune_linear: set True to also prune the output features of every
            `nn.Linear`, which was the previous behaviour.

    Returns:
        The same `model` object, with pruning reparametrizations attached.
    """
    for m in model.modules():
        is_conv = isinstance(m, nn.Conv2d)
        is_opted_in_linear = prune_linear and isinstance(m, nn.Linear)
        if is_conv or is_opted_in_linear:
            prune.ln_structured(m, name="weight", amount=amount, n=2, dim=0)
    return model

def effective_nonzero_params(model):
    """Count non-zero and total weight/bias entries of Conv2d and Linear layers.

    For a layer that carries both `weight_mask` and `weight_orig`, the weight
    counted is `weight_orig * weight_mask`; otherwise the plain `weight` is
    used. Biases are included when present.

    Returns:
        (nonzero, total) as Python ints.
    """
    n_total = 0
    n_nonzero = 0
    counted_layers = (
        layer for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    )
    for layer in counted_layers:
        has_pruning_mask = hasattr(layer, "weight_mask") and hasattr(layer, "weight_orig")
        if has_pruning_mask:
            weight = layer.weight_orig.detach() * layer.weight_mask.detach()
        else:
            weight = layer.weight.detach()

        tensors = [weight]
        if layer.bias is not None:
            tensors.append(layer.bias.detach())

        for tensor in tensors:
            n_total += tensor.numel()
            n_nonzero += torch.count_nonzero(tensor).item()
    return n_nonzero, n_total

# Start from FP32 weights on training device
# Work on a copy so the FP32 baseline (model_fp32) is left untouched
model_pruned = copy.deepcopy(model_fp32).to(device_train)

# Apply pruning (attaches masks in place and returns the same module)
model_pruned = apply_structured_pruning_resnet(model_pruned, amount=0.30)

# Report sparsity right after masking, before any fine-tuning
nz, tot = effective_nonzero_params(model_pruned)
print(f"[PRUNE] effective nonzero params: {nz:,}/{tot:,}  (sparsity={(1-nz/tot)*100:.2f}%)")

# Fine-tune (small LR, few epochs) to recover accuracy lost to pruning
FT_EPOCHS = 20
FT_LR = 0.01

# Fresh optimizer/scheduler/scaler for the fine-tuning run; the cosine
# schedule spans exactly FT_EPOCHS epochs.
optimizer_ft = optim.SGD(model_pruned.parameters(), lr=FT_LR, momentum=0.9, weight_decay=5e-4)
scheduler_ft = CosineAnnealingLR(optimizer_ft, T_max=FT_EPOCHS)
scaler_ft = amp.GradScaler(device='cuda', enabled=(device_train.type == "cuda"))

best_acc = 0.0
best_sd = None  # state_dict snapshot of the best epoch so far

for epoch in range(1, FT_EPOCHS + 1):
    tr_loss, tr_acc = train_epoch_fixed(model_pruned, train_loader, device_train, criterion, optimizer_ft, scaler_ft)
    # NOTE(review): the epoch to keep is chosen on test_loader, i.e. the same
    # split the final accuracy is reported on, so that number is optimistic.
    va_loss, va_acc = evaluate(model_pruned.eval(), test_loader, device_train, criterion)
    scheduler_ft.step()

    if va_acc > best_acc:
        best_acc = va_acc
        # deepcopy: state_dict() returns references to the live tensors, which
        # later epochs would otherwise overwrite
        best_sd = copy.deepcopy(model_pruned.state_dict())

    print(f"[FT {epoch:02d}/{FT_EPOCHS}] train_acc={tr_acc:.4f} val_acc={va_acc:.4f} best_val_acc={best_acc:.4f}")

# Load best fine-tuned pruned weights. This happens before prune.remove below,
# while the model still has the same pruning parametrization as the snapshot.
if best_sd is not None:
    model_pruned.load_state_dict(best_sd)

# OPTIONAL: make pruning permanent (removes reparam, keeps zeros)
for m in model_pruned.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)) and hasattr(m, "weight_mask"):
        prune.remove(m, "weight")

# Final pruned accuracy (CPU eval if you want to compare like PTQ)
model_pruned_cpu = copy.deepcopy(model_pruned).eval().cpu()
pr_loss, pr_acc = evaluate(model_pruned_cpu, test_loader, device_cpu, criterion)

nz2, tot2 = effective_nonzero_params(model_pruned_cpu)  # after prune.remove, should still reflect zeros
print(f"[PRUNED+FT] acc={pr_acc:.4f} loss={pr_loss:.4f} | effective sparsity={(1-nz2/tot2)*100:.2f}% | acc drop={(base_acc-pr_acc)*100:.2f} pts")

# Optional save. The tensors keep their original shapes (zeros are stored
# explicitly), so this file is not expected to be smaller than the dense one.
os.makedirs("/content/artifacts_pruned", exist_ok=True)
torch.save(model_pruned_cpu.state_dict(), "/content/artifacts_pruned/resnet18_cifar10_pruned_ft_state_dict.pth")


[BASE FP32] acc=0.7866 loss=0.9974 size=42.7MB latency=8.69ms (1 thread)
Quant engine: fbgemm


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torc

[PTQ INT8 FX] acc=0.7891 loss=0.9969 size=10.8MB latency=2.59ms (1 thread)
  Size shrink: 3.96x  |  Latency speedup: 3.36x  |  Acc drop: -0.25 pts
[PRUNE] effective nonzero params: 7,814,369/11,172,042  (sparsity=30.05%)
[FT 01/20] train_acc=0.6876 val_acc=0.6910 best_val_acc=0.6910
[FT 02/20] train_acc=0.7016 val_acc=0.7029 best_val_acc=0.7029
[FT 03/20] train_acc=0.7080 val_acc=0.7050 best_val_acc=0.7050
[FT 04/20] train_acc=0.7138 val_acc=0.7003 best_val_acc=0.7050
[FT 05/20] train_acc=0.7161 val_acc=0.6972 best_val_acc=0.7050
[FT 06/20] train_acc=0.7202 val_acc=0.7002 best_val_acc=0.7050
[FT 07/20] train_acc=0.7207 val_acc=0.7110 best_val_acc=0.7110
[FT 08/20] train_acc=0.7212 val_acc=0.7072 best_val_acc=0.7110
[FT 09/20] train_acc=0.7254 val_acc=0.7120 best_val_acc=0.7120
[FT 10/20] train_acc=0.7309 val_acc=0.7157 best_val_acc=0.7157
[FT 11/20] train_acc=0.7305 val_acc=0.7180 best_val_acc=0.7180
[FT 12/20] train_acc=0.7390 val_acc=0.7112 best_val_acc=0.7180
[FT 13/20] train_acc=0.