<a href="https://colab.research.google.com/github/Nderwoodfrank/Prediction-of-Sparse-Network/blob/main/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch and install Tim Dettmers' sparse_learning library, which provides the
# `sparselearning` package (Masking, CosineDecay, ...) imported later.
# NOTE: `%cd` changes the working directory, so the relative ./data and ./logs
# paths used later end up inside the cloned sparse_learning/ directory.
!git clone https://github.com/TimDettmers/sparse_learning.git
%cd sparse_learning
!pip install -r requirements.txt
# NOTE(review): `setup.py install` is deprecated in recent setuptools; `pip install .`
# is the usual replacement — confirm before switching.
!python setup.py install

In [None]:
# fvcore supplies FlopCountAnalysis / parameter_count_table, imported in the next cell.
!pip install fvcore

In [None]:
import warnings

# Hide UserWarnings raised from torch's LR-scheduler module so they do not
# clutter the training output.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module='torch.optim.lr_scheduler',
)

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from matplotlib import pyplot as plt
import argparse
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
import sparselearning
from sparselearning.core import add_sparse_args, CosineDecay, Masking
from sparselearning.funcs import no_redistribution, magnitude_prune, random_growth
import time
import math
from math import ceil
from fvcore.nn import FlopCountAnalysis, parameter_count_table
import logging
import os
import sys

In [None]:
def setup_logger(log_path='./logs/training.log'):
    """Configure and return the root logger for training runs.

    The root logger's level is set to INFO. Two INFO-level handlers sharing a
    ``HH:MM:SS: message`` format are attached: one appending to ``log_path``
    and one writing to the console (stderr).

    Safe to call repeatedly (e.g. when a notebook cell is re-run). Each handler
    is tagged with a name and is only added if a handler with that name is not
    already attached. Handlers attached by something else (a notebook kernel,
    a test runner) therefore no longer prevent this configuration.

    Args:
        log_path: File the log is appended to. Its parent directory is created
            if missing. Defaults to ``./logs/training.log``.

    Returns:
        The root ``logging.Logger``; named loggers propagate to it.
    """
    logger = logging.getLogger()  # root logger
    # Set unconditionally: if the level were only set when the root logger had
    # no handlers, INFO records would be dropped whenever it came pre-configured.
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
    attached = {handler.name for handler in logger.handlers}

    # One file handler per distinct log file.
    file_handler_name = f'training_file:{os.path.abspath(log_path)}'
    if file_handler_name not in attached:
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        fh = logging.FileHandler(log_path)
        fh.set_name(file_handler_name)
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    # A single console handler, whatever the log file.
    if 'training_console' not in attached:
        ch = logging.StreamHandler()
        ch.set_name('training_console')
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    return logger


logger = setup_logger()
log_interval = 100  # train_epoch prints progress every `log_interval` batches


In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ResNet34 here).

    ``expansion`` is the ratio of output channels to ``planes``; it is 1, so
    the block emits ``planes`` channels. ``stride`` is applied by the first conv.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity skip unless the block changes resolution or channel count; then
        # a strided 1x1 conv + BN projects the input to the main path's shape.
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (used by ResNet50 here).

    The last 1x1 conv widens the output to ``expansion * planes`` (4x)
    channels. ``stride`` is applied by the middle 3x3 conv.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Projection shortcut: strided 1x1 conv + BN matching the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()  # identity

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))

class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stride-1 stem followed by four residual stages.

    Stages have 64/128/256/512 base channels. The last three stages downsample
    by 2 in their first block. A global average pool feeds a linear classifier.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose ``expansion`` and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64  # channels entering the next stage; advanced by _make_layer

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the layer4 map is 4x4 (three
        # stride-2 stages), so this equals F.avg_pool2d(out, 4). Unlike a fixed
        # 4x4 window, it keeps the linear layer's input size correct for other
        # input resolutions too.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def ResNet34(c=10):
    """Build a ResNet-34: BasicBlock stages of depth 3-4-6-3 with ``c`` output classes."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, c)

def ResNet50(c=100):
    """Build a ResNet-50: Bottleneck stages of depth 3-4-6-3 with ``c`` output classes.

    Note the default ``c`` is 100, unlike ``ResNet34`` (10).
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, c)


In [None]:
# Layer layouts consumed by VGG16.make_layers. Each entry is one of:
#   int     -> conv (kernel 3, padding 1) with that many output channels
#   'M'     -> 2x2 max-pool with stride 2
#   (k, c)  -> conv with kernel size k and c output channels (padding 1)
# 'like' and 'D' are identical 13-conv VGG-16 layouts. 'C' replaces the last
# conv of each of the three 256/512-channel stages with a 1x1 conv.
VGG_CONFIGS = {
    'like': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'D': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'C': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, (1, 256), 'M',
        512, 512, (1, 512), 'M',
        512, 512, (1, 512), 'M',
    ],
}


class VGG16(nn.Module):
    """VGG-style classifier with batch-normalised conv stages.

    Args:
        config: Key into ``VGG_CONFIGS`` selecting the conv layout. ``'C'`` and
            ``'D'`` get a classifier with two hidden 512-unit layers; any other
            key gets a single hidden layer.
        num_classes: Number of output classes.
        save_features: If True, every forward pass appends a detached copy of
            each ReLU output to ``self.feats`` and that output's fraction of
            non-zero entries to ``self.densities``. Nothing here clears them.

    ``forward`` returns ``log_softmax`` over the class dimension.
    """

    def __init__(self, config, num_classes=10, save_features=False):
        super().__init__()

        self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)
        self.feats = []
        self.densities = []
        self.save_features = save_features

        if config in ('C', 'D'):
            # Flattened feature size, assuming 32x32 inputs: 512 for 'D'
            # (512 x 1 x 1). 'C' gives 2048 (512 x 2 x 2) because its 1x1 convs
            # keep padding=1, which enlarges the feature map.
            flat_dim = 512 if config == 'D' else 2048
            head = [
                nn.Linear(flat_dim, 512),  # 512 * 7 * 7 in the original VGG
                nn.ReLU(True),
                nn.BatchNorm1d(512),  # instead of dropout
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.BatchNorm1d(512),  # instead of dropout
            ]
        else:
            head = [
                nn.Linear(512, 512),  # 512 * 7 * 7 in the original VGG
                nn.ReLU(True),
                nn.BatchNorm1d(512),  # instead of dropout
            ]
        self.classifier = nn.Sequential(*head, nn.Linear(512, num_classes))

    @staticmethod
    def make_layers(config, batch_norm=False):
        """Turn a ``VGG_CONFIGS``-style list into an ``nn.Sequential`` feature extractor.

        Each conv (padding 1) is followed by an optional BatchNorm2d and an
        in-place ReLU. Input is assumed to have 3 channels.
        """
        modules = []
        channels = 3
        for spec in config:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            kernel, width = spec if isinstance(spec, tuple) else (3, spec)
            modules.append(nn.Conv2d(channels, width, kernel_size=kernel, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(width))
            modules.append(nn.ReLU(inplace=True))
            channels = width
        return nn.Sequential(*modules)

    def forward(self, x):
        for module in self.features:
            x = module(x)
            if self.save_features and isinstance(module, nn.ReLU):
                self.feats.append(x.clone().detach())
                self.densities.append((x.data != 0.0).sum().item()/x.numel())

        logits = self.classifier(x.view(x.size(0), -1))
        return F.log_softmax(logits, dim=1)


In [None]:
def get_transforms(dataset_name):
    """Return ``(transform_train, transform_test)`` for ``dataset_name``.

    Only ``'cifar10'`` is supported.

    - Training images are reflect-padded by 4 pixels, randomly cropped back to
      32x32 and randomly flipped horizontally. They are then converted to a
      tensor and normalised.
    - Test images are only converted and normalised.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name != 'cifar10':
        # Fail explicitly instead of an UnboundLocalError at the return statement.
        raise ValueError(f"Unsupported dataset: {dataset_name!r} (only 'cifar10' is supported)")

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        # Reflect-pad the PIL image directly. A ToTensor -> F.pad -> ToPILImage
        # round trip pads the same way, but needs a lambda. Lambdas cannot be
        # pickled when DataLoader workers start with 'spawn', and the round trip
        # adds an extra float -> uint8 conversion.
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_test

In [None]:
def load_dataset(dataset_name, batch_size, num_workers=2):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The CIFAR-10 training split is divided 80/20 into train and validation
    subsets with ``random_split``. No generator is passed, so the split varies
    between runs unless torch's global seed is fixed. Validation images are read
    through the test-time transform, so validation metrics are not computed on
    randomly augmented crops.

    Args:
        dataset_name: Only ``'cifar10'`` is supported.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per DataLoader (default 2).

    Returns:
        ``(train_loader, validate_loader, test_loader)``; only the train loader shuffles.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    from torch.utils.data import Subset

    if dataset_name != 'cifar10':
        raise ValueError(f"Unsupported dataset: {dataset_name!r} (only 'cifar10' is supported)")

    transform_train, transform_test = get_transforms(dataset_name)

    # Two views of the same training images: augmented for training, plain for
    # validation. (Both load the image array, roughly doubling its memory use.)
    train_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    eval_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)

    train_size = int(0.8 * len(train_view))
    validate_size = len(train_view) - train_size
    trainset, validate_split = random_split(train_view, [train_size, validate_size])
    # Reuse the random validation indices, but serve them with the test-time transform.
    validateset = Subset(eval_view, validate_split.indices)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    validate_loader = DataLoader(validateset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, validate_loader, test_loader

In [None]:
class SelectivePruning:
    """Sparse-training helper around ``sparselearning.core.Masking``.

    On construction it registers ``model`` with a ``Masking`` instance at
    initial density ``density``. It also picks a "selected block": the
    eligible ``nn.Conv2d``/``nn.Linear`` layer with the largest total absolute
    parameter magnitude (see ``select_block_for_pruning``). That block's zero
    pattern can be copied onto other layers with the same weight shape
    (``share_weights``).

    Args:
        model: Network whose weights are masked.
        optimizer: Optimizer of ``model``; handed to ``Masking``.
        density: Initial fraction of weights kept, passed to ``Masking.add_module``.
        prune_rate: Prune rate given to ``Masking`` and to its
            ``CosineDecay(prune_rate, 100)`` schedule. It is re-applied before
            every ``prune_and_regrow`` call.
        sparsity_ratio: Stored on the instance; no method of this class reads it.
    """

    def __init__(self, model, optimizer, density, prune_rate, sparsity_ratio):
        self.model = model
        self.optimizer = optimizer
        self.density = density
        self.prune_rate = prune_rate
        self.sparsity_ratio = sparsity_ratio  # currently unused by the methods below
        self.logger = logging.getLogger("SelectivePruning")

        self.masking = Masking(optimizer, CosineDecay(prune_rate, 100), prune_rate, growth_mode='random', prune_mode='magnitude', redistribution_mode='none', verbose=True)
        self.masking.add_module(model, density)

        self.selected_block_name = self.select_block_for_pruning()
        self.logger.info(f"Selected block for focused pruning and growth: {self.selected_block_name}")

    def apply_masks(self):
        """Re-apply the ``Masking`` masks; call after each optimizer step to keep pruned weights at zero."""
        self.masking.apply_mask()

    def calculate_total_magnitude(self, module):
        """Return the sum of absolute values of all trainable parameters of ``module`` (weights and biases)."""
        return sum(p.data.abs().sum().item() for p in module.parameters() if p.requires_grad)

    def select_block_for_pruning(self):
        """Return the name of the eligible layer with the largest total magnitude.

        Eligible layers are ``nn.Conv2d`` / ``nn.Linear`` modules whose
        qualified name:

        - does not contain ``'shortcut'``,
        - does not start with ``'conv1'`` or ``'bn1'`` (the ResNet stem), and
        - does not end with ``'conv3'`` or ``'bn3'``.

        The final classifier (e.g. ResNet's ``linear``) is still eligible.
        Returns ``None`` when no layer qualifies.
        """
        block_magnitudes = {}
        for name, module in self.model.named_modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)) or "shortcut" in name:
                continue
            if name.startswith(('conv1', 'bn1')) or name.endswith(('conv3', 'bn3')):
                continue
            block_magnitudes[name] = self.calculate_total_magnitude(module)

        selected_block = max(block_magnitudes, key=block_magnitudes.get, default=None)
        self.logger.info(f"Eligible blocks for pruning: {list(block_magnitudes.keys())}")
        return selected_block

    def prune_and_regrow(self):
        """Run one end-of-epoch prune/regrow cycle of the wrapped ``Masking``.

        ``Masking`` acts on every layer it manages, so pruning is global, not
        restricted to the selected block. The selected block only gates the
        call: nothing happens when no eligible block was found.
        """
        if self.selected_block_name is None:
            return
        self.masking.prune_rate = self.prune_rate
        # In sparse_learning, Masking.step() is the drop-in replacement for
        # optimizer.step(): it steps the optimizer itself and only truncates
        # weights every `prune_every_k_steps` steps. Calling it here would add an
        # extra optimizer update with the last batch's gradients and no pruning.
        # The library's end-of-epoch hook is the call that prunes and regrows.
        # NOTE(review): verify against the installed sparselearning/core.py.
        self.masking.at_end_of_epoch()
        self.masking.apply_mask()

    def share_weights(self, enable_sharing=True):
        """Copy the selected block's zero pattern onto every same-shaped layer.

        The target layers are every module with a non-``None`` ``weight`` of the
        same shape as the selected block's. The selected block itself and any
        module whose name contains ``'shortcut'`` are skipped. In each target,
        weights are zeroed wherever the selected block's weight is zero.

        NOTE(review): the targets' ``Masking`` masks are not updated, so the
        positions zeroed here stay trainable and may become non-zero after the
        next optimizer step. Confirm whether that is intended.
        """
        if not enable_sharing:
            self.logger.info("Weight sharing is disabled.")
            return

        source_block = next((module for name, module in self.model.named_modules() if name == self.selected_block_name), None)
        if source_block is None or getattr(source_block, 'weight', None) is None:
            self.logger.error(f"No source block with weights found for {self.selected_block_name}.")
            return

        shared = False
        for target_name, target_module in self.model.named_modules():
            if target_name == self.selected_block_name or 'shortcut' in target_name:
                continue
            target_weight = getattr(target_module, 'weight', None)
            # Skip modules without a weight tensor (or with weight=None) and shape mismatches.
            if target_weight is None or target_weight.size() != source_block.weight.size():
                continue
            self._share_sparsity_pattern(source_block, target_module)
            shared = True
            self.logger.info(f"Sparsity pattern shared from {self.selected_block_name} to {target_name}.")

        if not shared:
            self.logger.info("No patterns were shared due to dimensional mismatch or other conditions.")

    def _share_sparsity_pattern(self, source_module, target_module):
        """Zero ``target_module.weight`` in place wherever ``source_module.weight`` is zero."""
        sparsity_mask = source_module.weight.data != 0
        target_module.weight.data *= sparsity_mask.to(target_module.weight.dtype)

    def log_non_zero_weights(self, phase):
        """Log the percentage of non-zero weights of every ``nn.Conv2d`` in the model.

        ``phase`` is a free-form label prefixed to each line (e.g. 'Before Training').
        """
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                total_weights = module.weight.data.numel()
                non_zero_weights = torch.count_nonzero(module.weight.data).item()
                percentage_non_zero = non_zero_weights / total_weights * 100
                self.logger.info(f"{phase} - {name}: {percentage_non_zero:.2f}% non-zero weights.")


In [None]:
def train_epoch(model, train_loader, optimizer, device, epoch, selective_pruning):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    ``mean_loss`` is the average of the per-batch mean losses. After each
    optimizer step, ``selective_pruning.apply_masks()`` re-imposes the sparsity
    masks. A progress line is printed for every batch whose index is a multiple
    of the module-level ``log_interval``. ``epoch`` is zero-based and printed
    as ``epoch + 1``.
    """
    model.train()
    running_loss = 0
    num_correct = 0
    num_seen = 0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()
        selective_pruning.apply_masks()  # re-zero pruned weights after the update

        running_loss += batch_loss.item()
        predicted = logits.max(1).indices
        num_seen += labels.size(0)
        num_correct += predicted.eq(labels).sum().item()

        if step % log_interval == 0:
            samples_done = step * len(inputs)
            pct_done = 100. * step / len(train_loader)
            print(f'Train Epoch: {epoch + 1} [{samples_done}/{len(train_loader.dataset)} ({pct_done:.0f}%)]\tLoss: {batch_loss.item():.6f}')

    return running_loss / len(train_loader), 100. * num_correct / num_seen

def evaluate(model, test_loader, device):
    """Return test accuracy in percent, logging loss and accuracy via the root logger.

    The logged loss is the per-sample mean cross-entropy over the whole dataset.
    """
    model.eval()
    summed_loss = 0
    hits = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            summed_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            hits += (logits.argmax(dim=1) == labels).sum().item()

    dataset_size = len(test_loader.dataset)
    test_loss = summed_loss / dataset_size
    accuracy = 100. * hits / dataset_size
    logging.info(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {hits}/{dataset_size} ({accuracy:.0f}%)')
    return accuracy

def validate(model, validate_loader, device):
    """Return ``(mean_loss, accuracy_percent)`` of ``model`` on ``validate_loader``.

    The loss is averaged per sample (summed cross-entropy divided by the number
    of samples), as in ``evaluate``. Averaging per-batch means instead would
    over-weight a smaller final batch. An empty loader yields ``(nan, nan)``
    instead of a ZeroDivisionError.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in validate_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += F.cross_entropy(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += target.size(0)

    if total == 0:
        return float('nan'), float('nan')
    validation_loss = loss_sum / total
    validation_accuracy = 100. * correct / total
    return validation_loss, validation_accuracy

In [None]:
def select_model(model_name, num_classes):
    """Instantiate the network named ``model_name`` with ``num_classes`` outputs.

    Supported names: ``'resnet34'``, ``'resnet50'``, ``'vgg16'`` (VGG config 'D').

    Raises:
        ValueError: for any other name, including other ``resnet*`` variants.
            Those previously fell through and returned ``None``.
    """
    builders = {
        'resnet34': lambda: ResNet34(c=num_classes),
        'resnet50': lambda: ResNet50(c=num_classes),
        'vgg16': lambda: VGG16(config='D', num_classes=num_classes),
    }
    try:
        build = builders[model_name]
    except KeyError:
        raise ValueError(f"Unknown model type: {model_name!r}") from None
    return build()

In [None]:
def running_in_notebook():
    """Heuristically detect an interactive (IPython kernel or Spyder) session.

    True when ``ipykernel`` has been imported, or when any environment
    variable name contains ``'SPYDER'``.
    """
    kernel_loaded = 'ipykernel' in sys.modules
    return kernel_loaded or any('SPYDER' in var_name for var_name in os.environ)


In [None]:
def parse_args(argv=None):
    """Parse training options, falling back to the defaults inside a notebook.

    A single ``argparse`` parser defines every default, so notebook and
    command-line runs use the same settings.

    Args:
        argv: Explicit argument list to parse. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, or an empty list (all defaults) when
            ``running_in_notebook()`` is True.

    Returns:
        ``argparse.Namespace`` with epochs, batch_size, density, prune_rate,
        sparsity_ratio, dataset, model, lr, l2 and disable_weight_sharing.
    """
    parser = argparse.ArgumentParser(description='Train a model with options.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int, default=128, help='Training batch size.')
    parser.add_argument('--density', type=float, default=0.5, help='Density of the network connections.')
    parser.add_argument('--prune_rate', type=float, default=0.5, help='Pruning rate for the network.')
    # NOTE(review): notebook and CLI defaults used to differ (0.5 vs 0.1); unified on 0.5 — confirm.
    parser.add_argument('--sparsity_ratio', type=float, default=0.5, help='Sparsity ratio used in selective pruning.')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10'], help='Dataset to use (cifar10)')
    parser.add_argument('--model', type=str, default='resnet34', choices=['resnet34', 'resnet50', 'vgg16'], help='Model to use (resnet34, resnet50, vgg16)')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate for the optimizer.')
    parser.add_argument('--l2', type=float, default=5.0e-4, help='L2 weight decay for regularization.')
    parser.add_argument('--disable_weight_sharing', action='store_true', help='Disable the weight sharing feature.')

    if argv is None and running_in_notebook():
        # Inside a notebook, sys.argv is not parsed; use the defaults.
        print("Running in a Jupyter notebook or IPython environment.")
        argv = []
    return parser.parse_args(argv)

In [None]:
def main():
    """Parse options, then train, validate and finally test a sparse model.

    Each epoch runs:

    - one training pass (``train_epoch`` re-applies the masks after every step),
    - a validation pass,
    - a StepLR step (x0.1 every 15 epochs),
    - one ``prune_and_regrow`` cycle, and
    - sparsity-pattern sharing, unless ``--disable_weight_sharing`` is set.

    After the last epoch, the per-layer density is logged and the model is
    evaluated once on the test set.
    """
    args = parse_args()
    print(f"Training with settings: epochs={args.epochs}, batch_size={args.batch_size}, density={args.density}, "
          f"prune_rate={args.prune_rate}, sparsity_ratio={args.sparsity_ratio}, dataset={args.dataset}, disable_weight_sharing={args.disable_weight_sharing}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = {'cifar10': 10}.get(args.dataset, 10)
    model = select_model(args.model, num_classes).to(device)
    # NOTE(review): the default lr of 0.1 is unusually high for Adam — confirm it is intended.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)
    train_loader, validate_loader, test_loader = load_dataset(args.dataset, args.batch_size)

    selective_pruning = SelectivePruning(model, optimizer, args.density, args.prune_rate, args.sparsity_ratio)
    selective_pruning.log_non_zero_weights('Before Training')
    start_time = time.time()

    for epoch in range(args.epochs):
        train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, device, epoch, selective_pruning)
        validation_loss, validation_accuracy = validate(model, validate_loader, device)
        scheduler.step()

        selective_pruning.prune_and_regrow()
        # Log training and validation statistics for each epoch
        logger.info(f'Epoch {epoch+1}: Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%')
        logger.info(f'Epoch {epoch+1}: Validation Loss: {validation_loss:.4f}, Validation Accuracy: {validation_accuracy:.2f}%')

        # With sharing disabled, this call only logs that fact.
        selective_pruning.share_weights(not args.disable_weight_sharing)

    # Runs once, after training. Inside the epoch loop, a "Final" test accuracy
    # was reported every epoch, and `elapsed_time` was never assigned when
    # epochs == 0, which raised UnboundLocalError at the print below.
    selective_pruning.log_non_zero_weights('After Training')
    final_accuracy = evaluate(model, test_loader, device)
    logger.info(f'Final Test Accuracy: {final_accuracy:.2f}%')
    elapsed_time = time.time() - start_time
    print(f"Total training time: {elapsed_time:.2f} seconds")




# Entry point. In a notebook kernel `__name__` is also "__main__", so running
# this cell starts training; the guard matters when the code is imported as a module.
if __name__ == "__main__":
    main()