In [None]:
import os
import time
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from torch.jit import trace
from tqdm import tqdm

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and CUDA) so head initialisation, DataLoader shuffling
# and random augmentations are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 32          # samples per batch for training, evaluation and TRT calibration
NUM_WORKERS = 4          # DataLoader worker processes
LR_HEAD = 1e-3           # learning rate for the linear classification head
LR_BACKBONE = 1e-5       # learning rate for unfrozen backbone blocks during fine-tuning
WEIGHT_DECAY = 1e-4      # AdamW weight decay (both stages)
EPOCHS_LP = 10           # epochs of linear probing (frozen backbone)
EPOCHS_FT = 8            # epochs of partial fine-tuning
UNFREEZE_BLOCKS = 2      # how many of the last backbone blocks are trained during fine-tuning

In [4]:
# Dataset location. Set the HAM10000_DIR environment variable to point at the
# dataset on another machine; the default is the original author's local path.
img_dir  = os.environ.get("HAM10000_DIR", "/home/zack/11785/project/data/HAM10000_dataset")
meta_csv = os.path.join(img_dir, "HAM10000_metadata.csv")

# Keep only rows with a diagnosis label and turn image ids into file names.
df = (
    pd.read_csv(meta_csv)
    .dropna(subset=['dx'])
    .assign(image_id=lambda d: d['image_id'].astype(str) + ".jpg")
)

# Stratified 80/20 image-level split.
# NOTE(review): HAM10000 is reported to contain several images of the same lesion.
# If the metadata has a `lesion_id` column, an image-level split can put images of
# one lesion in both train and val and inflate val accuracy. Consider a grouped
# split (e.g. sklearn StratifiedGroupKFold on lesion_id) — TODO confirm.
train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df['dx'], random_state=42
)

In [5]:
# Standard ImageNet per-channel mean/std normalisation, shared by both pipelines.
_imagenet_norm = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training: random scale/aspect-ratio crop to 224 px plus random horizontal flips.
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _imagenet_norm,
])

# Validation: deterministic resize to 256 then a central 224 px crop.
val_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenet_norm,
])

In [6]:
class HAM10000Dataset(Dataset):
    """Image-classification dataset over a HAM10000 metadata frame.

    Each row must provide ``image_id`` (file name relative to ``img_dir``) and
    ``dx`` (diagnosis label). Items are ``(image, label_index)``. ``image`` is the
    PIL RGB image, or whatever ``transform`` returns for it.

    Args:
        df: metadata frame; re-indexed to 0..n-1 internally.
        img_dir: directory holding the image files.
        transform: optional callable applied to the PIL image.
        classes: optional ordered list of class names defining the label
            mapping. Pass the training set's ``classes`` when building
            validation/test sets so every split shares one mapping, even if a
            split lacks some class. Defaults to the sorted unique ``dx`` values
            of ``df``.

    Raises:
        ValueError: if ``df`` contains a ``dx`` value not present in ``classes``.
    """

    def __init__(self, df, img_dir, transform=None, classes=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        if classes is None:
            classes = sorted(self.df['dx'].unique())
        self.classes = list(classes)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        # Fail at construction time rather than with a KeyError inside a DataLoader worker.
        unknown = set(self.df['dx'].unique()) - set(self.class_to_idx)
        if unknown:
            raise ValueError(f"labels not in `classes`: {sorted(unknown)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.img_dir, row['image_id'])
        # The context manager closes the file handle promptly. convert() reads the
        # pixel data first, so the returned image does not depend on the open file.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = self.class_to_idx[row['dx']]
        return img, label

In [7]:
t_train = HAM10000Dataset(train_df, img_dir, transform=train_tfms)
t_val   = HAM10000Dataset(val_df,   img_dir, transform=val_tfms)
# Each dataset derives its label mapping from its own split. Accuracy is only
# meaningful if both splits agree, so fail fast instead of silently mislabelling.
assert t_train.class_to_idx == t_val.class_to_idx, (
    f"train/val label mappings differ: {t_train.classes} vs {t_val.classes}"
)
train_loader = DataLoader(t_train, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
val_loader   = DataLoader(t_val,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
NUM_CLASSES = len(t_train.classes)

In [None]:
# Pretrained `dinov2_vits14` backbone from torch.hub, moved to the training device.
backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
backbone = backbone.to(device)
feat_dim = backbone.embed_dim        # width of the feature vector the head consumes
num_blocks = len(backbone.blocks)    # number of entries in backbone.blocks

# Linear classifier from backbone features to NUM_CLASSES logits.
head = nn.Linear(in_features=feat_dim, out_features=NUM_CLASSES).to(device)

class DinoClassifier(nn.Module):
    """Backbone followed by a classification head: ``forward(x) == head(backbone(x))``."""

    def __init__(self, backbone, head):
        super().__init__()
        # Assigned as attributes so both are registered submodules
        # (visible to .to(), .parameters(), state_dict()).
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        return self.head(self.backbone(x))
    

In [9]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` with ``model`` in train mode.

    Returns the per-sample loss averaged over ``len(loader.dataset)``: each
    batch's ``criterion`` value is weighted by its batch size. This is the mean
    loss assuming ``criterion`` averages over the batch.
    """
    model.train()
    running = 0
    for inputs, targets in tqdm(loader, desc="Train", leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item() * inputs.size(0)
    return running / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, device):
    """Top-1 accuracy of ``model`` over ``loader`` as a fraction in [0, 1].

    Switches the model to eval mode and runs without autograd.
    """
    model.eval()
    hits = 0
    for inputs, targets in tqdm(loader, desc="Eval", leave=False):
        logits = model(inputs.to(device))
        hits += (logits.argmax(dim=1) == targets.to(device)).sum().item()
    return hits / len(loader.dataset)

def measure_inference_time(model, device, runs=100, warmup=10, input_shape=(1, 3, 224, 224)):
    """Average forward-pass latency of ``model`` in milliseconds.

    Runs ``warmup`` untimed forward passes, then times ``runs`` passes on one
    random input of ``input_shape`` that already lives on ``device``. Autograd
    is disabled, so parameters with ``requires_grad=True`` (e.g. the trainable
    head) do not add graph-building overhead to the measurement.

    On CUDA devices, timing uses CUDA events bracketed by device
    synchronisation. On any other device it falls back to ``time.perf_counter``
    instead of crashing on ``torch.cuda`` calls.

    Args:
        model: module to time; switched to eval mode.
        device: ``torch.device`` or device string the model lives on.
        runs: number of timed forward passes (must be > 0).
        warmup: number of untimed warm-up passes.
        input_shape: shape of the random example input.

    Returns:
        Mean latency per forward pass in milliseconds (float).

    Raises:
        ValueError: if ``runs`` is not positive.
    """
    if runs <= 0:
        raise ValueError("runs must be a positive integer")
    dev_type = device.type if isinstance(device, torch.device) else torch.device(device).type
    model.eval()
    example = torch.randn(*input_shape).to(device)
    with torch.no_grad():
        for _ in range(warmup):
            model(example)
        if dev_type == "cuda":
            # CUDA kernels run asynchronously: synchronise before and after so
            # the events bracket exactly the timed launches.
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(runs):
                model(example)
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end) / runs
        t0 = time.perf_counter()
        for _ in range(runs):
            model(example)
        return (time.perf_counter() - t0) * 1000.0 / runs

In [None]:
# ========== Linear Probing ==========
# Freeze every backbone parameter; only the linear head is trained.
backbone.requires_grad_(False)
head.requires_grad_(True)
model_lp = DinoClassifier(backbone, head).to(device)
optimizer_lp = optim.AdamW(head.parameters(), lr=LR_HEAD, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()

print("=== Linear Probing Training ===")
for epoch in range(EPOCHS_LP):
    loss = train_epoch(model_lp, train_loader, optimizer_lp, criterion, device)
    acc = evaluate(model_lp, val_loader, device)
    print(f"Epoch {epoch+1}/{EPOCHS_LP} — loss: {loss:.4f}, val_acc: {acc:.4%}")

# Metrics for the comparison table: accuracy, latency (ms), TorchScript size (MB).
acc_lp = evaluate(model_lp, val_loader, device)
time_lp = measure_inference_time(model_lp, device)
torch.jit.trace(model_lp.eval(), torch.randn(1, 3, 224, 224).to(device)).save("model_lp.ts")
size_lp = os.path.getsize("model_lp.ts") / 1e6

In [None]:
# ========== Partial Fine-tuning ==========
# Continue training model_lp in place: re-freeze the backbone, then unfreeze the
# last UNFREEZE_BLOCKS entries of backbone.blocks plus the head.
backbone.requires_grad_(False)
# Use a non-negative start index: `blocks[-0:]` selects *every* block, so
# UNFREEZE_BLOCKS == 0 would silently turn this into full fine-tuning.
first_trainable_block = max(0, len(backbone.blocks) - UNFREEZE_BLOCKS)
for blk in backbone.blocks[first_trainable_block:]:
    blk.requires_grad_(True)
head.requires_grad_(True)

# Discriminative learning rates: LR_HEAD for the head, LR_BACKBONE for unfrozen blocks.
backbone_trainable = [p for p in backbone.parameters() if p.requires_grad]
params_ft = [{"params": head.parameters(), "lr": LR_HEAD}]
if backbone_trainable:  # no empty param group when nothing in the backbone is unfrozen
    params_ft.append({"params": backbone_trainable, "lr": LR_BACKBONE})
optimizer_ft = optim.AdamW(params_ft, weight_decay=WEIGHT_DECAY)

print("=== Partial Fine-tuning ===")
for epoch in range(EPOCHS_FT):
    loss = train_epoch(model_lp, train_loader, optimizer_ft, criterion, device)
    acc  = evaluate(model_lp, val_loader, device)
    print(f"Epoch {epoch+1}/{EPOCHS_FT} — loss: {loss:.4f}, val_acc: {acc:.4%}")

# model_lp now holds the fine-tuned weights; record metrics and export TorchScript.
acc_ft  = evaluate(model_lp, val_loader, device)
time_ft = measure_inference_time(model_lp, device)
model_ft = torch.jit.trace(model_lp.eval(), torch.randn(1,3,224,224).to(device))
model_ft.save("model_ft.ts")
size_ft = os.path.getsize("model_ft.ts")/1e6

In [None]:
# ========== ONNX + INT8 Quantization ==========
import torch
import onnx
import tensorrt as trt
import numpy as np
import os
import time
from tqdm import tqdm

# Reload the fine-tuned TorchScript model on CPU and export it to ONNX with a
# dynamic batch dimension on both the input and the output.
model_ft = torch.jit.load("model_ft.ts", map_location='cpu')
model_ft.eval()

onnx_path = "model_ft.onnx"
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model_ft,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={name: {0: 'batch_size'} for name in ('input', 'output')},
)

# Validate the exported graph with the ONNX checker.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

class CalibrationDataset:
    """Host-side cache of input batches for INT8 calibration.

    Eagerly pulls up to ``max_samples`` *batches* (not individual samples) of
    inputs from ``dataloader`` as NumPy arrays, discarding labels. It then
    hands them out one at a time through :meth:`get_batch`.
    """

    def __init__(self, dataloader, max_samples=100):
        self.dataloader = dataloader
        self.max_samples = max_samples
        self.batches = []
        progress = tqdm(dataloader, total=min(max_samples, len(dataloader)))
        for count, (images, _labels) in enumerate(progress, start=1):
            self.batches.append(images.numpy())
            if count >= self.max_samples:
                break
        # Cursor into self.batches used by get_batch().
        self.current_sample = 0

    def get_batch(self):
        """Return the next cached batch, or ``None`` once all have been served."""
        if self.current_sample < len(self.batches):
            batch = self.batches[self.current_sample]
            self.current_sample += 1
            return batch
        return None

class Int8Calibrator(trt.IInt8EntropyCalibrator2):
    """TensorRT entropy calibrator fed from a ``CalibrationDataset``.

    Each ``get_batch`` call copies the next cached host batch into a single,
    reused device buffer (through the module-level pycuda ``cuda`` driver) and
    returns its device pointer. The calibration cache is read from and written
    to ``cache_file``.

    NOTE(review): if ``cache_file`` already exists, its contents are handed
    back to TensorRT by ``read_calibration_cache``. A cache left over from an
    earlier model could be reused for a retrained model — delete it after
    retraining.

    Args:
        dataset: object whose ``get_batch()`` returns NumPy arrays, then ``None``.
        cache_file: calibration cache path.
        batch_size: batch size reported to TensorRT; defaults to ``BATCH_SIZE``.
            Only batches with exactly this many samples are used.
    """

    def __init__(self, dataset, cache_file="calibration.cache", batch_size=None):
        super().__init__()
        self.dataset = dataset
        self.cache_file = cache_file
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.device_input = None

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        """Return ``[device_ptr]`` for the next full batch, or ``None`` when exhausted."""
        batch = self.dataset.get_batch()
        # TensorRT reads get_batch_size() samples from the buffer, so a short
        # batch (e.g. the loader's final remainder) would expose stale data
        # left over from the previous batch. Skip such batches.
        while batch is not None and batch.shape[0] != self.batch_size:
            batch = self.dataset.get_batch()
        if batch is None:
            return None

        batch = np.ascontiguousarray(batch, dtype=np.float32)
        if self.device_input is None:
            self.device_input = cuda.mem_alloc(batch.nbytes)

        cuda.memcpy_htod(self.device_input, batch)
        return [int(self.device_input)]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# 4. TensorRT
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_trt_engine_int8(onnx_file_path, engine_path="model_ft_trt_int8.engine", calib_loader=None):
    """Build an INT8 TensorRT engine from an ONNX file and serialise it to disk.

    Args:
        onnx_file_path: path of the ONNX model to parse.
        engine_path: where the serialised engine is written.
        calib_loader: DataLoader used for entropy calibration. Defaults to the
            notebook-level ``val_loader``, looked up at call time.
            NOTE(review): calibrating on the validation split means INT8
            accuracy is later measured on data the quantiser has seen; a
            held-out slice of training data would be cleaner.

    Returns:
        The built engine, or ``None`` if ONNX parsing or engine building failed
        (errors are printed).
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            print("ERROR: Failed to parse the ONNX file.")
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    config = builder.create_builder_config()
    workspace_bytes = 1 << 30  # 1GB
    if hasattr(config, "set_memory_pool_limit"):
        # Newer TensorRT API; `max_workspace_size` is deprecated there.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes
    config.set_flag(trt.BuilderFlag.INT8)

    if calib_loader is None:
        calib_loader = val_loader
    calibration_dataset = CalibrationDataset(calib_loader)
    calibrator = Int8Calibrator(calibration_dataset)
    config.int8_calibrator = calibrator

    # Dynamic batch dimension: min 1, opt BATCH_SIZE, max 2 * BATCH_SIZE.
    profile = builder.create_optimization_profile()
    profile.set_shape("input", (1, 3, 224, 224), (BATCH_SIZE, 3, 224, 224), (BATCH_SIZE*2, 3, 224, 224))
    config.add_optimization_profile(profile)
    # INT8 calibration of a network with runtime dimensions needs an explicit
    # calibration profile. Its opt shape matches the calibrator's batch size.
    config.set_calibration_profile(profile)

    engine = builder.build_engine(network, config)
    if engine is None:
        # build_engine reports failure by returning None; do not crash on serialize().
        print("ERROR: Failed to build the TensorRT engine.")
        return None

    with open(engine_path, "wb") as f:
        f.write(engine.serialize())

    return engine

# pycuda must be imported, and its CUDA context created by pycuda.autoinit,
# *before* the engine is built: Int8Calibrator.get_batch calls cuda.mem_alloc /
# cuda.memcpy_htod during calibration, which raised NameError when these
# imports came after the build.
import pycuda.driver as cuda
import pycuda.autoinit

engine = build_trt_engine_int8(onnx_path)
if engine is None:
    # Fail loudly rather than letting load_engine() below pick up a stale engine file.
    raise RuntimeError("TensorRT INT8 engine build failed; see the errors printed above.")

def load_engine(engine_path):
    """Deserialise a TensorRT engine from ``engine_path``.

    Returns whatever ``Runtime.deserialize_cuda_engine`` returns for the file's bytes.
    """
    with open(engine_path, "rb") as fh:
        serialized = fh.read()
    with trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(serialized)

def allocate_buffers(engine, batch_size=1):
    """Allocate page-locked host and device buffers for every engine binding.

    A dynamic leading dimension (-1) is resolved to ``batch_size``. Each buffer
    holds ``volume(shape) * engine.max_batch_size`` elements of the binding's dtype.

    Returns:
        ``(inputs, outputs, bindings, stream)``. ``inputs`` and ``outputs`` are
        lists of dicts with keys "host", "device" and "name". ``bindings`` holds
        the device pointers as ints in engine binding order. ``stream`` is a
        new CUDA stream.
    """
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()

    for name in engine:
        shape = engine.get_binding_shape(name)
        if shape[0] == -1:
            shape[0] = batch_size
        n_elems = trt.volume(shape) * engine.max_batch_size
        host = cuda.pagelocked_empty(n_elems, trt.nptype(engine.get_binding_dtype(name)))
        dev = cuda.mem_alloc(host.nbytes)
        bindings.append(int(dev))

        target = inputs if engine.binding_is_input(name) else outputs
        target.append({"host": host, "device": dev, "name": name})

    return inputs, outputs, bindings, stream

def infer_trt(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one asynchronous TensorRT inference and wait for it to finish.

    Copies every ``inputs[i]["host"]`` array to its device buffer, enqueues
    execution on ``stream``, copies each output device buffer back into its
    ``host`` array, then synchronises ``stream``.

    Args:
        batch_size: unused; kept for call compatibility. For dynamic-shape
            engines the caller must set the input binding shape on ``context``.

    Returns:
        List of the output host arrays (the full buffers).

    Raises:
        RuntimeError: if TensorRT reports that enqueueing failed. Without this
            check, whatever stale data sits in the output buffers would be
            returned silently.
    """
    for inp in inputs:
        cuda.memcpy_htod_async(inp["device"], inp["host"], stream)

    if not context.execute_async_v2(bindings=bindings, stream_handle=stream.handle):
        stream.synchronize()
        raise RuntimeError("TensorRT execute_async_v2 failed (are all input shapes set?)")

    for out in outputs:
        cuda.memcpy_dtoh_async(out["host"], out["device"], stream)

    stream.synchronize()

    return [out["host"] for out in outputs]


# Reload the serialised INT8 engine from disk and prepare batch-1 buffers for latency timing.
engine = load_engine("model_ft_trt_int8.engine")
context = engine.create_execution_context()
inputs, outputs, bindings, stream = allocate_buffers(engine, batch_size=1)

def measure_inference_time_trt(context, inputs, outputs, bindings, stream, runs=100,
                               warmup=10, input_shape=(1, 3, 224, 224)):
    """Average wall-clock latency of one TensorRT inference, in milliseconds.

    NOTE: every timed call goes through ``infer_trt`` and so includes the
    host->device input copy and the device->host output copy. The PyTorch
    ``measure_inference_time`` times forward passes on an on-device tensor, so
    the two numbers are not strictly comparable.

    Args:
        context: execution context whose binding 0 is the (dynamic) input.
        inputs, outputs, bindings, stream: as returned by ``allocate_buffers``.
        runs: number of timed inferences (must be > 0).
        warmup: number of untimed warm-up inferences.
        input_shape: shape of the random input; must fit in ``inputs[0]["host"]``.

    Raises:
        ValueError: if ``runs`` is not positive or the input does not fit the buffer.
    """
    if runs <= 0:
        raise ValueError("runs must be a positive integer")

    dummy_input = np.random.rand(*input_shape).astype(np.float32)
    host = inputs[0]["host"]
    if dummy_input.size > host.size:
        raise ValueError(f"input of {dummy_input.size} elements exceeds host buffer of {host.size}")
    # The engine has a dynamic batch dimension, so the input shape must be set
    # on the context before execution.
    context.set_binding_shape(0, dummy_input.shape)
    # Copy into the existing page-locked buffer instead of replacing it with a
    # pageable array. The async H2D copy then stays asynchronous and the
    # caller's buffer list is not mutated.
    host[:dummy_input.size] = dummy_input.ravel()

    for _ in range(warmup):
        infer_trt(context, bindings, inputs, outputs, stream)

    start = time.perf_counter()
    for _ in range(runs):
        infer_trt(context, bindings, inputs, outputs, stream)
    end = time.perf_counter()

    return (end - start) * 1000 / runs

def evaluate_trt(engine, val_loader, device):
    """Top-1 accuracy of a TensorRT engine over ``val_loader`` (fraction in [0, 1]).

    Args:
        engine: deserialised engine with a dynamic batch dimension on binding 0
            and a ``(batch, NUM_CLASSES)`` output.
        val_loader: yields ``(images, labels)`` torch tensors.
        device: unused; kept for signature compatibility.

    Raises:
        ValueError: if a batch is larger than the allocated buffers or the
            loader yields no samples.
    """
    context = engine.create_execution_context()
    # Size the buffers for the largest batch the loader produces. Buffers
    # sized for batch 1 would be overrun by every full batch.
    capacity = val_loader.batch_size or 1
    inputs, outputs, bindings, stream = allocate_buffers(engine, batch_size=capacity)
    in_host = inputs[0]["host"]

    correct = 0
    total = 0

    for x, y in tqdm(val_loader, desc="评估TensorRT"):
        batch_size = x.shape[0]
        x_np = np.ascontiguousarray(x.cpu().numpy(), dtype=np.float32)
        if x_np.size > in_host.size:
            raise ValueError(f"batch of {batch_size} exceeds buffer capacity of {capacity}")

        # Fill the page-locked buffer in place; only the first x_np.size elements are used.
        in_host[:x_np.size] = x_np.ravel()
        context.set_binding_shape(0, tuple(x_np.shape))

        output = infer_trt(context, bindings, inputs, outputs, stream, batch_size)

        # The output buffer is sized for `capacity` rows; only the first batch_size are valid.
        logits = output[0][: batch_size * NUM_CLASSES].reshape(batch_size, NUM_CLASSES)
        pred = np.argmax(logits, axis=1)

        correct += int((pred == y.cpu().numpy()).sum())
        total += batch_size

    if total == 0:
        raise ValueError("val_loader produced no samples")
    return correct / total

# INT8 engine metrics for the comparison table: latency (ms), accuracy, and engine file size (MB).
time_qt = measure_inference_time_trt(context, inputs, outputs, bindings, stream)
acc_qt = evaluate_trt(engine, val_loader, device)
size_qt = os.path.getsize("model_ft_trt_int8.engine") / 1_000_000


In [None]:
def _format_result_row(label, acc, latency_ms, size_mb):
    """Format one comparison-table row; ``acc`` is a fraction, shown as a percentage."""
    return f"{label:<40} {acc*100:>9.2f}%   {latency_ms:>10.2f}   {size_mb:>8.2f}"

results_header = f"{'Scheme':<40} {'Top1 Acc':>10}   {'Infer(ms)':>10}   {'Size(MB)':>8}"
print("\n=== Results Comparison ===")
print(results_header)
# Derive the rule from the header so it spans the full table width
# (the previous hard-coded 68 dashes were shorter than the 75-char header).
print("-" * len(results_header))
print(_format_result_row("1. Linear Probing", acc_lp, time_lp, size_lp))
print(_format_result_row("2. Linear + Partial FT", acc_ft, time_ft, size_ft))
print(_format_result_row("3. INT8 Quantization", acc_qt, time_qt, size_qt))