In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
from resnet_cl import SlimmableResNet18 as resnet18

# Device configuration: prefer CUDA, then MPS, then fall back to CPU.
# Keep this the single place where `device` is assigned: re-deriving it
# afterwards from CUDA availability alone would discard the MPS branch.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [9]:
# Hyperparameters
SEED = 42
num_epochs = 10
batch_size = 128
learning_rate = 0.001

# Seed PyTorch's global RNG before anything random is constructed, so that
# shuffling, the random augmentations and (presumably) weight initialisation
# start from the same state on every Restart & Run All.
# NOTE: GPU kernels can still introduce small run-to-run differences.
torch.manual_seed(SEED)

# Both pipelines normalise each channel as (x - 0.5) / 0.5.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Data transformations (augmentation is applied to the training split only)
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Load CIFAR-10 dataset (downloaded into ./data on first use)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform_test, download=True)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Load ResNet-18 model
model = resnet18(num_classes=10)  # Set num_classes to match CIFAR-10
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Training loop
def train(epochs=None, checkpoint_path='best_resnet18_cifar10.pth', width_mult_list=None):
    """Train the module-level ``model`` and checkpoint its best-validating weights.

    Relies on the notebook globals ``model``, ``train_loader``, ``criterion``,
    ``optimizer`` and ``device``, and calls ``validate()`` after every epoch.
    Whenever the validation accuracy improves, the model's ``state_dict`` is
    written to ``checkpoint_path``.

    Args:
        epochs: Number of epochs to run. ``None`` (default) uses the global
            ``num_epochs``.
        checkpoint_path: Destination file for the best ``state_dict``.
        width_mult_list: Optional iterable of width multipliers. When given,
            every batch is forwarded once per width (selected through
            ``model.switch_to_width``) and the gradients of all widths are
            accumulated before a single optimizer step, so each width is
            actually trained. With ``None`` (default) only the configuration
            the model is currently in gets trained; the evaluation recorded
            in this notebook after such a run shows widths 0.25/0.5/0.75 at
            roughly 10-12% accuracy versus 77.65% at width 1.0.
            Assumes ``switch_to_width`` may be called while the model is in
            training mode -- TODO confirm against ``resnet_cl``.

    Returns:
        float: The best validation accuracy (in percent) seen during the run,
        or ``0.0`` if no epoch produced an accuracy above zero.
    """
    total_epochs = num_epochs if epochs is None else epochs
    # Ascending order leaves the model at the widest configuration after each
    # batch, so the logged metrics and validate() refer to that width.
    widths = sorted(width_mult_list) if width_mult_list else [None]

    best_acc = 0.0
    for epoch in range(total_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        start_time = time.time()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            for width in widths:
                if width is not None:
                    model.switch_to_width(width)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward: gradients accumulate across widths
                loss.backward()
            optimizer.step()

            # `loss` / `outputs` come from the last (widest) forward pass.
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100. * correct / total
        epoch_time = time.time() - start_time
        print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {running_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}%, Time: {epoch_time:.2f}s')

        # Validate after each epoch
        test_acc = validate()

        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved with accuracy: {:.2f}%".format(test_acc))

    return best_acc

# Validation function
def validate():
    """Report the accuracy of the module-level ``model`` on ``test_loader``.

    Switches the model to evaluation mode, runs every batch of the global
    ``test_loader`` on ``device`` with autograd disabled, prints the result
    and returns it.

    Returns:
        float: Percentage of samples whose highest-scoring output index
        equals the label.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # Index of the largest output along dim 1 is the predicted class.
            guesses = model(batch).max(1)[1]
            seen += targets.size(0)
            hits += (guesses == targets).sum().item()
    acc = 100. * hits / seen
    print(f'Validation Accuracy: {acc:.2f}%')
    return acc

# A notebook cell runs with __name__ set to '__main__', so executing this
# cell starts training; the guard only matters if the code is imported.
if __name__ == "__main__":
    train()

Epoch [1/10], Loss: 1.5500, Train Acc: 43.50%, Time: 25.16s
Validation Accuracy: 48.76%
Model saved with accuracy: 48.76%
Epoch [2/10], Loss: 1.1847, Train Acc: 57.80%, Time: 21.02s
Validation Accuracy: 62.93%
Model saved with accuracy: 62.93%
Epoch [3/10], Loss: 1.0199, Train Acc: 64.02%, Time: 20.40s
Validation Accuracy: 67.13%
Model saved with accuracy: 67.13%
Epoch [4/10], Loss: 0.9091, Train Acc: 68.01%, Time: 21.61s
Validation Accuracy: 68.93%
Model saved with accuracy: 68.93%
Epoch [5/10], Loss: 0.8361, Train Acc: 70.60%, Time: 23.58s
Validation Accuracy: 72.98%
Model saved with accuracy: 72.98%
Epoch [6/10], Loss: 0.7718, Train Acc: 73.06%, Time: 20.94s
Validation Accuracy: 75.16%
Model saved with accuracy: 75.16%
Epoch [7/10], Loss: 0.7253, Train Acc: 74.73%, Time: 20.57s
Validation Accuracy: 75.21%
Model saved with accuracy: 75.21%
Epoch [8/10], Loss: 0.6883, Train Acc: 75.85%, Time: 22.61s
Validation Accuracy: 77.20%
Model saved with accuracy: 77.20%
Epoch [9/10], Loss: 0.65

In [18]:
from tqdm import tqdm
from flags import FLAGS

def evaluate_model(model, data_loader, device, width_mult_list=None):
    """Measure classification accuracy of ``model`` at several width multipliers.

    For each width the model is switched with ``model.switch_to_width``, put
    in evaluation mode and run over ``data_loader`` with autograd disabled.
    Progress and the per-width accuracy are printed as before; the accuracies
    are additionally returned so callers can use them.

    Note: the model is left in evaluation mode and at the last width that
    was evaluated.

    Args:
        model: Network exposing ``switch_to_width(width)``.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        width_mult_list: Optional iterable of widths to test. ``None``
            (default) uses ``FLAGS.width_mult_list``.

    Returns:
        dict: Maps each tested width to its accuracy in percent.
    """
    widths = FLAGS.width_mult_list if width_mult_list is None else width_mult_list
    results = {}
    for width in widths:
        print(f"\n=== Testing width multiplier: {width} ===")
        # Switch to desired width
        model.switch_to_width(width)
        model.eval()  # Set model to evaluation mode
        correct = 0
        total = 0

        with torch.no_grad():  # Disable gradient computation for speedup
            for inputs, labels in tqdm(data_loader, desc="Evaluating", leave=True):
                inputs, labels = inputs.to(device), labels.to(device)

                # Forward pass
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)  # Index of the largest output per sample

                # Update metrics
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f"\nAccuracy: {accuracy:.2f}%")
        results[width] = accuracy

    return results

# Score `model` on the test loader at every configured width
evaluate_model(model=model, data_loader=test_loader, device=device)


=== Testing width multiplier: 0.25 ===


Evaluating: 100%|██████████| 79/79 [00:02<00:00, 33.94it/s]



Accuracy: 10.87%

=== Testing width multiplier: 0.5 ===


Evaluating: 100%|██████████| 79/79 [00:02<00:00, 26.39it/s]



Accuracy: 11.08%

=== Testing width multiplier: 0.75 ===


Evaluating: 100%|██████████| 79/79 [00:03<00:00, 24.74it/s]



Accuracy: 12.06%

=== Testing width multiplier: 1.0 ===


Evaluating: 100%|██████████| 79/79 [00:02<00:00, 33.87it/s]


Accuracy: 77.65%



