In [None]:
import torch
import torchvision.models as models
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from resnet_cl import SlimmableResNet34

# Fetch torchvision's ImageNet-pretrained ResNet-34 and write its state dict to
# disk so the next cells can reload the weights by file path.
resnet34 = models.resnet34(
    weights=models.ResNet34_Weights.IMAGENET1K_V1,
)
state_dict = resnet34.state_dict()
torch.save(obj=state_dict, f="resnet34_imagenet.pth")

In [None]:
def load_non_bn_weights(slimmable_model, resnet34_weights_path):
    """
    Load pre-trained weights into ``slimmable_model``, skipping every BatchNorm entry.

    An entry is treated as BatchNorm when its name contains ``"bn"`` (the
    original naming rule) OR when its owning module also stores a
    ``running_mean`` buffer in the same state dict. The second rule catches BN
    layers that are not named ``bn*`` (e.g. torchvision's ``downsample.1``).
    Only keys that also exist in the slimmable model's state dict are loaded;
    all other parameters of the slimmable model keep their current values.

    Args:
        slimmable_model: destination ``nn.Module``.
        resnet34_weights_path: path to a state dict saved with ``torch.save``.
    """
    # weights_only=True restricts unpickling to tensors/containers (no arbitrary
    # code execution) and silences torch's FutureWarning; map_location keeps the
    # load independent of the device the weights were saved from.
    source_state = torch.load(resnet34_weights_path, map_location="cpu", weights_only=True)
    target_keys = set(slimmable_model.state_dict().keys())

    def _module_prefix(key):
        # "layer1.0.bn1.weight" -> "layer1.0.bn1"; top-level "weight" -> "".
        return key.rpartition(".")[0]

    # Module prefixes that own running statistics, i.e. BatchNorm layers.
    bn_prefixes = {
        _module_prefix(key)
        for key in source_state
        if key.rpartition(".")[2] == "running_mean"
    }

    filtered_state_dict = {
        name: param
        for name, param in source_state.items()
        if "bn" not in name
        and _module_prefix(name) not in bn_prefixes
        and name in target_keys
    }

    # strict=False: keys absent from filtered_state_dict (all BN entries, plus
    # anything the slimmable model adds) are left untouched.
    slimmable_model.load_state_dict(filtered_state_dict, strict=False)

    print(f"Loaded {len(filtered_state_dict)} non-BN layers from ResNet-34.")
    
def compute_bn_statistics(model, data_loader, device, width_list):
    """
    Recalibrate BatchNorm running statistics for every width of a slimmable model.

    1. Resets the running statistics (mean, var, num_batches_tracked) of every
       BatchNorm layer once, before any width is processed.
    2. For each width in ``width_list``: switches the model to that width and
       runs forward passes over ``data_loader`` in training mode without
       gradients, so the BN layers active at that width accumulate statistics.
    3. Leaves the model in evaluation mode.

    During calibration, BN momentum is temporarily set to ``None``. Running
    statistics then become a cumulative (equal-weight) average over all
    batches, instead of an exponential moving average dominated by the last
    few batches. The original momentum values are restored afterwards, even if
    the run is interrupted.

    Args:
        model: module exposing ``switch_to_width(w)``.
        data_loader: iterable of ``(inputs, targets)`` pairs; targets are ignored.
        device: device the model and inputs are moved to.
        width_list: widths passed one at a time to ``model.switch_to_width``.
    """
    model.to(device)

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    # Layers with track_running_stats=False have no running buffers to reset.
    bn_layers = [
        m for m in model.modules()
        if isinstance(m, bn_types) and m.track_running_stats
    ]

    # Reset once, up front. Resetting inside the width loop would also wipe the
    # statistics just computed for earlier widths whenever each width owns its
    # own BN modules (switchable BN).
    # NOTE(review): assumes that layout in resnet_cl.SlimmableResNet34 — confirm.
    saved_momenta = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None  # cumulative moving average

    try:
        model.train()  # BN only updates running stats in training mode
        with torch.no_grad():
            for w in width_list:
                print(f"\nProcessing width={w}")
                model.switch_to_width(w)
                for inputs, _ in tqdm(data_loader, desc=f"Updating BN (Width={w})", leave=True):
                    model(inputs.to(device))
    finally:
        for bn, momentum in zip(bn_layers, saved_momenta):
            bn.momentum = momentum
        model.eval()

    print("\nBN statistics updated for all widths.\n")

In [None]:
# Build the slimmable network with a 1000-way head and copy in the pretrained
# ResNet-34 weights; BatchNorm entries are skipped by load_non_bn_weights and
# get recalibrated per width in a later cell.
resnet34_weights_path = "resnet34_imagenet.pth"
slimmable_resnet34 = SlimmableResNet34(num_classes=1000)
load_non_bn_weights(slimmable_resnet34, resnet34_weights_path)

Loaded 38 non-BN layers from ResNet-34.


  resnet34_state_dict = torch.load(resnet34_weights_path)


In [None]:
data_dir = "tiny_imgnet"  # dataset folder was deleted from the repo; provide it locally

# Despite the names, these are the standard ImageNet normalization constants,
# which is what the ImageNet-pretrained ResNet-34 weights expect.
tiny_imagenet_mean = [0.485, 0.456, 0.406]
tiny_imagenet_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=tiny_imagenet_mean, std=tiny_imagenet_std)

# Training pipeline: random resized crop to 64x64 (Tiny ImageNet resolution) + flip.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize to 64x64, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    normalize,
])

# NOTE(review): the stock Tiny ImageNet test/ split ships without labels
# (a single test/images/ folder). If that is what lives here, ImageFolder
# assigns every image the same class — confirm this folder is a labelled split.
train_dataset = datasets.ImageFolder(root=f"{data_dir}/train", transform=train_transform)
test_dataset = datasets.ImageFolder(root=f"{data_dir}/test", transform=test_transform)

loader_kwargs = dict(batch_size=64, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Sanity-check split sizes.
for split_name, split in (("Train", train_dataset), ("Test", test_dataset)):
    print(f"{split_name} samples: {len(split)}")

Train samples: 100000
Test samples: 10000


In [7]:
from flags import FLAGS

# Pick the best available backend: CUDA, then Apple-silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

compute_bn_statistics(slimmable_resnet34, train_loader, device, FLAGS.width_mult_list)


[Recomputing BN] Processing width=0.25


Updating BN (Width=0.25): 100%|██████████| 1563/1563 [03:09<00:00,  8.23it/s]



[Recomputing BN] Processing width=0.5


Updating BN (Width=0.5): 100%|██████████| 1563/1563 [05:45<00:00,  4.52it/s]



[Recomputing BN] Processing width=0.75


Updating BN (Width=0.75): 100%|██████████| 1563/1563 [09:18<00:00,  2.80it/s]



[Recomputing BN] Processing width=1.0


Updating BN (Width=1.0): 100%|██████████| 1563/1563 [13:21<00:00,  1.95it/s]


✅ BN statistics updated for all widths.






In [None]:
def evaluate_model(model, data_loader, device):
    """
    Compute top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking, with
    a tqdm progress bar over the batches. Each sample's prediction is the
    argmax over dimension 1 of the model output.

    Args:
        model: module mapping an input batch to per-class scores. It is NOT
            moved here; callers must place it on ``device`` first.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs and labels are moved to.

    Returns:
        Accuracy as a float in [0, 100]; it is also printed.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():  # inference only: no autograd bookkeeping
        for inputs, labels in tqdm(data_loader, desc="Evaluating", leave=True):
            inputs, labels = inputs.to(device), labels.to(device)

            predicted = model(inputs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    accuracy = 100 * correct / total
    print(f"\nAccuracy: {accuracy:.2f}%")
    return accuracy

# Pick the best available backend: CUDA, then Apple-silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
resnet34.to(device)

# NOTE(review): `resnet34` is the stock 1000-class ImageNet classifier, while
# ImageFolder labels the Tiny ImageNet classes 0..N-1 by sorted folder name.
# Predicted and true indices are therefore not in the same label space, so a
# near-zero accuracy here reflects that mismatch, not model quality. Map the
# folder wnids to ImageNet class indices (or train a matching head) before
# interpreting this number.
evaluate_model(resnet34, test_loader, device)

Evaluating: 100%|██████████| 157/157 [01:41<00:00,  1.55it/s]


Accuracy: 0.01%





0.01