In [1]:
from __future__ import absolute_import
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

In [2]:
# Check if CUDA is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available. Using GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

CUDA is available. Using GPU.


In [3]:
__all__ = ['preresnet']

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

In [4]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out

In [5]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out

In [6]:
class PreResNet(nn.Module):

    def __init__(self, depth, num_classes=10):
        super(PreResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        n = (depth - 2) // 9

        block = Bottleneck if depth >=44 else BasicBlock

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [7]:
def preresnet(**kwargs):
    """
    Constructs a ResNet model.
    """
    return PreResNet(**kwargs)

In [8]:
def train(model, train_loader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')
        # print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')


In [9]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy} %')

In [10]:
def prune_weights(model, pruning_ratio):
    # Get all the weights in the model and compute the threshold.
    all_weights = []
    for name, param in model.named_parameters():
        if 'weight' in name:
            all_weights += list(param.to(device).data.abs().numpy().flatten())
    threshold = torch.tensor([all_weights[int(len(all_weights) * pruning_ratio)]])

    # Prune the weights below the threshold.
    for name, param in model.named_parameters():
        if 'weight' in name:
            mask = param.abs() > threshold
            param.data *= mask

    return model

In [11]:

# Load and preprocess CIFAR-10 data...
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Initialize the model, loss function, and optimizer...
model = PreResNet(depth=20).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [13]:
# Train the model...
epochs = 20  # Set the number of epochs to a suitable value
train(model, train_loader, criterion, optimizer, epochs)
test(model, test_loader)

Epoch 1, Loss: 1.6370, Accuracy: 40.33%
Epoch 2, Loss: 1.1633, Accuracy: 58.74%
Epoch 3, Loss: 0.9512, Accuracy: 66.73%
Epoch 4, Loss: 0.8154, Accuracy: 71.60%
Epoch 5, Loss: 0.7261, Accuracy: 74.88%
Epoch 6, Loss: 0.6625, Accuracy: 77.00%
Epoch 7, Loss: 0.6094, Accuracy: 79.02%
Epoch 8, Loss: 0.5684, Accuracy: 80.34%
Epoch 9, Loss: 0.5276, Accuracy: 81.70%
Epoch 10, Loss: 0.4985, Accuracy: 82.68%
Epoch 11, Loss: 0.4709, Accuracy: 83.67%
Epoch 12, Loss: 0.4402, Accuracy: 84.68%
Epoch 13, Loss: 0.4191, Accuracy: 85.30%
Epoch 14, Loss: 0.3937, Accuracy: 86.22%
Epoch 15, Loss: 0.3778, Accuracy: 86.90%
Epoch 16, Loss: 0.3603, Accuracy: 87.39%
Epoch 17, Loss: 0.3396, Accuracy: 88.00%
Epoch 18, Loss: 0.3235, Accuracy: 88.65%
Epoch 19, Loss: 0.3054, Accuracy: 89.41%
Epoch 20, Loss: 0.2909, Accuracy: 89.74%
Accuracy of the network on the test images: 81.9 %


In [None]:
# Prune and retrain the model...
prune_ratio = 0.5
pruned_model = prune_weights(model, prune_ratio)
print(f'Pruning ratio: {prune_ratio}')
train(pruned_model, train_loader, criterion, optimizer, epochs)
test(pruned_model, test_loader)

In [None]:
prune_ratio = 0.7
pruned_model = prune_weights(model, prune_ratio)
print(f'Pruning ratio: {prune_ratio}')
train(pruned_model, train_loader, criterion, optimizer, epochs)
test(pruned_model, test_loader)

In [None]:
prune_ratio = 0.9
pruned_model = prune_weights(model, prune_ratio)
print(f'Pruning ratio: {prune_ratio}')
train(pruned_model, train_loader, criterion, optimizer, epochs)
test(pruned_model, test_loader)