In [10]:
import torch
import torchvision
import torch_pruning as tp
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader

In [11]:
import torch
import torchvision
import torch_pruning as tp
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from torch.nn import Module
from torch.optim import Optimizer
from torch.nn import CrossEntropyLoss


def load_data(batch_size: int = 128) -> tuple[DataLoader, DataLoader]:
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    trainset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    testset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    return trainloader, testloader


def initialize_model(device: torch.device) -> Module:
    weights = models.ResNet18_Weights.DEFAULT
    model = models.resnet18(weights=weights).to(device)
    return model


def build_dependency_graph(
    model: Module, example_inputs: torch.Tensor
) -> tp.DependencyGraph:
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=example_inputs)
    return DG


def prune_model_with_depgraph(
    DG: tp.DependencyGraph, model: Module, pruning_factor: float, device: torch.device
) -> None:
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune_idxs = list(range(0, int(module.out_channels * pruning_factor)))
            group = DG.get_pruning_group(
                module, tp.prune_conv_out_channels, idxs=prune_idxs
            )
            if DG.check_pruning_group(group):
                group.prune()


def train_and_prune(
    model: Module,
    DG: tp.DependencyGraph,
    trainloader: DataLoader,
    criterion: CrossEntropyLoss,
    optimizer: Optimizer,
    device: torch.device,
    iterative_steps: int,
    pruning_factor: float,
) -> None:
    for step in range(iterative_steps):
        model.train()
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        prune_model_with_depgraph(DG, model, pruning_factor, device)

        macs, params = tp.utils.count_ops_and_params(model, sample_inputs)
        print(f"Step {step+1}: MACs = {macs}, Params = {params}")


def evaluate(model: Module, dataloader: DataLoader, device: torch.device) -> float:
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100.0 * correct / total


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainloader, testloader = load_data()
    model = initialize_model(device)

    # Take some data from the training set as example inputs
    sample_inputs, _ = next(iter(trainloader))
    sample_inputs = sample_inputs.to(device)

    DG = build_dependency_graph(model, sample_inputs)

    criterion = CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_and_prune(
        model,
        DG,
        trainloader,
        criterion,
        optimizer,
        device,
        iterative_steps=5,
        pruning_factor=0.2,
    )

    accuracy = evaluate(model, testloader, device)
    print(f"Pruned Model Accuracy: {accuracy}%")

Files already downloaded and verified
Files already downloaded and verified
Step 1: MACs = 16123732.8125, Params = 4855110
Step 2: MACs = 7124367.8125, Params = 2039184
Step 3: MACs = 3282331.8125, Params = 877649
Step 4: MACs = 1617595.8125, Params = 385111
Step 5: MACs = 795440.8125, Params = 175252
Pruned Model Accuracy: 9.41%
