In [1]:
from torchvision import datasets, transforms
import torch

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [2]:
mnsit = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

In [3]:
# Divide into train, test and validation
train_size = int(0.8 * len(mnsit))
test_size = len(mnsit) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    mnsit, [train_size, test_size]
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

In [4]:
class CNNModel(torch.nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(32 * 7 * 7, 128)
        self.fc2 = torch.nn.Linear(128, 10)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x

In [5]:
model = CNNModel()
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.2263,  0.0876, -0.0450],
                        [-0.3229,  0.0884, -0.1193],
                        [ 0.2156, -0.1474, -0.2008]]],
              
              
                      [[[-0.2970,  0.1203,  0.2262],
                        [ 0.2009,  0.2276, -0.0481],
                        [-0.2551,  0.0271, -0.2636]]],
              
              
                      [[[ 0.1325, -0.3077,  0.0899],
                        [ 0.3202, -0.0005,  0.1568],
                        [ 0.3011, -0.0493,  0.1492]]],
              
              
                      [[[ 0.2841,  0.1373, -0.2483],
                        [ 0.0939, -0.1654, -0.0769],
                        [ 0.2735,  0.3103, -0.0568]]],
              
              
                      [[[ 0.2746, -0.2575,  0.1935],
                        [-0.3154,  0.1762,  0.1813],
                        [ 0.0395, -0.2632,  0.2744]]],
              
              
               

In [6]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [10]:
def train(model, train_loader, loss_fn, optimizer, epochs=5):
    for epoch in range(epochs):
        for i, (images, labels) in enumerate(train_loader):
            model.train()
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")
                model.eval()
                correct = 0
                total = 0
                for images, labels in test_loader:
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                accuracy = correct / total
                print(f"Accuracy: {accuracy}")

In [11]:
ob = train(model, train_loader, loss_fn, optimizer, epochs=5)

torch.save(model.state_dict(), "model.pth")

Epoch: 0, Batch: 0, Loss: 2.3035407066345215
Accuracy: 0.094
Epoch: 0, Batch: 100, Loss: 1.8159127235412598
Accuracy: 0.6686666666666666
Epoch: 0, Batch: 200, Loss: 1.6802787780761719
Accuracy: 0.7594166666666666
Epoch: 0, Batch: 300, Loss: 1.6002156734466553
Accuracy: 0.8385833333333333
Epoch: 0, Batch: 400, Loss: 1.4762157201766968
Accuracy: 0.953
Epoch: 0, Batch: 500, Loss: 1.4881139993667603
Accuracy: 0.9644166666666667
Epoch: 0, Batch: 600, Loss: 1.5415012836456299
Accuracy: 0.9675833333333334
Epoch: 0, Batch: 700, Loss: 1.4800060987472534
Accuracy: 0.9675
Epoch: 1, Batch: 0, Loss: 1.5271815061569214
Accuracy: 0.9703333333333334
Epoch: 1, Batch: 100, Loss: 1.5184838771820068
Accuracy: 0.9739166666666667
Epoch: 1, Batch: 200, Loss: 1.479993462562561
Accuracy: 0.96675
Epoch: 1, Batch: 300, Loss: 1.4832701683044434
Accuracy: 0.9744166666666667
Epoch: 1, Batch: 400, Loss: 1.4967296123504639
Accuracy: 0.9765
Epoch: 1, Batch: 500, Loss: 1.508124589920044
Accuracy: 0.9730833333333333
Epo

In [19]:
model.state_dict()["conv1.weight"].shape
print("Num of weights in the first layer: ", model.state_dict()["conv1.weight"].numel())

Num of weights in the first layer:  144


In [60]:
def generate_population(pruning_percentage, model, population_size):
    population = []
    for i in range(population_size):
        num_weights = model.state_dict()["conv1.weight"].numel()
        num_pruned = int(num_weights * pruning_percentage)
        mask = torch.ones_like(model.state_dict()["conv1.weight"])

        random_indices = torch.randint(0, num_weights, (num_pruned,))
        # print(random_indices, i)
        mask.view(-1)[random_indices] = 0

        population.append(mask)
    return population

In [40]:
model = CNNModel()
model.load_state_dict(torch.load("model.pth"))
model.state_dict()

  model.load_state_dict(torch.load("model.pth"))


OrderedDict([('conv1.weight',
              tensor([[[[ 0.6052,  0.3859,  0.1449],
                        [-0.0127,  0.0975, -0.2710],
                        [ 0.1488, -0.3524, -0.4971]]],
              
              
                      [[[-0.3834,  0.1881,  0.4207],
                        [ 0.4898,  0.3829, -0.0098],
                        [-0.2248, -0.2529, -0.5597]]],
              
              
                      [[[ 0.0068, -0.7278, -0.2614],
                        [ 0.4699,  0.1245,  0.2678],
                        [ 0.4007,  0.1198,  0.1814]]],
              
              
                      [[[ 0.4538,  0.1196, -0.5974],
                        [ 0.3765, -0.1184, -0.4085],
                        [ 0.3901,  0.4503, -0.1015]]],
              
              
                      [[[ 0.1896, -0.2394,  0.3217],
                        [-0.4511,  0.0255,  0.3709],
                        [-0.0968, -0.3693,  0.5028]]],
              
              
               

In [77]:
def sparsity(mask):
    num_zeros = torch.sum(mask == 0).item()
    total = mask.numel()
    return num_zeros / total

In [61]:
model_cpy = CNNModel()
population = generate_population(0.4, model_cpy, 5)

In [80]:
def evaluate_candidate(model, mask):
    model_cpy = CNNModel()
    model_cpy.load_state_dict(model.state_dict())
    model_cpy.conv1.weight.data *= mask

    # print(model_cpy.state_dict())

    print("Sparsity: ", sparsity(mask))
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model_cpy(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

In [81]:
for mask in generate_population(0.4, model, 5):
    print(evaluate_candidate(model, mask))

Sparsity:  0.3194444444444444
0.9716666666666667
Sparsity:  0.3333333333333333
0.98075
Sparsity:  0.3472222222222222
0.9788333333333333
Sparsity:  0.3333333333333333
0.98225
Sparsity:  0.3055555555555556
0.9593333333333334


In [63]:
def mutation(mask, mutation_rate):
    # Flip random no of ones with zeroes
    random_indices = torch.randint(0, len(mask), (int(mutation_rate * len(mask)),))
    mask[random_indices] = 1 - mask[random_indices]
    return mask


def crossover(mask1, mask2):
    # Half genes from each parent
    new_mask = torch.zeros_like(mask1)
    new_mask[: len(mask1) // 2] = mask1[: len(mask1) // 2]
    new_mask[len(mask1) // 2 :] = mask2[len(mask1) // 2 :]

    return new_mask

In [66]:
population = generate_population(0.4, model, 5)
# mutation(population[0].clone(), 0.1)
# crossover(population[0].clone(), population[1].clone())

tensor([[[[0., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [0., 0., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [1., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],


        [[[0., 1., 0.],
          [0., 1., 0.],
          [0., 1., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 0.],
          [0., 1., 0.]]],


        [[[1

In [70]:
def apply_genetic(
    model, mutation_rate, pruning_percentage, population_size, generations
):
    population = generate_population(pruning_percentage, model, population_size)
    for generation in range(generations):
        new_population = []
        for i in range(population_size):
            mask1 = population[i]
            mask2 = population[(i + 1) % population_size]
            new_mask = crossover(mask1, mask2)
            new_mask = mutation(new_mask, mutation_rate)
            accuracy = evaluate_candidate(model, new_mask)
            new_population.append((new_mask, accuracy))
        new_population = sorted(new_population, key=lambda x: x[1], reverse=True)
        population = [x[0] for x in new_population]
        print(f"Generation: {generation}, Best accuracy: {new_population[0][1]}")
    return population

In [82]:
population = apply_genetic(model, 0.1, 0.4, 5, 5)

Sparsity:  0.3888888888888889
Sparsity:  0.3472222222222222
Sparsity:  0.2777777777777778
Sparsity:  0.4166666666666667
Sparsity:  0.3333333333333333
Generation: 0, Best accuracy: 0.9830833333333333
Sparsity:  0.3402777777777778
Sparsity:  0.3402777777777778
Sparsity:  0.3055555555555556
Sparsity:  0.3958333333333333
Sparsity:  0.4166666666666667
Generation: 1, Best accuracy: 0.9826666666666667
Sparsity:  0.3472222222222222
Sparsity:  0.3680555555555556
Sparsity:  0.3194444444444444
Sparsity:  0.4166666666666667
Sparsity:  0.4375
Generation: 2, Best accuracy: 0.982
Sparsity:  0.3888888888888889
Sparsity:  0.3611111111111111
Sparsity:  0.4444444444444444
Sparsity:  0.3888888888888889
Sparsity:  0.4652777777777778
Generation: 3, Best accuracy: 0.9794166666666667
Sparsity:  0.4166666666666667
Sparsity:  0.3958333333333333
Sparsity:  0.3888888888888889
Sparsity:  0.4375
Sparsity:  0.4305555555555556
Generation: 4, Best accuracy: 0.9685833333333334


In [83]:
population

[tensor([[[[1., 1., 0.],
           [1., 1., 0.],
           [1., 1., 1.]]],
 
 
         [[[0., 0., 1.],
           [1., 1., 0.],
           [1., 1., 0.]]],
 
 
         [[[1., 0., 1.],
           [1., 1., 0.],
           [1., 1., 0.]]],
 
 
         [[[1., 1., 1.],
           [0., 1., 1.],
           [1., 1., 0.]]],
 
 
         [[[1., 0., 0.],
           [1., 1., 0.],
           [1., 0., 0.]]],
 
 
         [[[0., 0., 0.],
           [0., 1., 0.],
           [0., 1., 0.]]],
 
 
         [[[0., 1., 0.],
           [1., 1., 1.],
           [1., 0., 1.]]],
 
 
         [[[1., 0., 0.],
           [1., 0., 0.],
           [0., 0., 0.]]],
 
 
         [[[0., 0., 1.],
           [0., 1., 0.],
           [0., 1., 0.]]],
 
 
         [[[0., 1., 1.],
           [0., 1., 0.],
           [0., 0., 1.]]],
 
 
         [[[0., 1., 1.],
           [1., 0., 0.],
           [1., 0., 0.]]],
 
 
         [[[0., 1., 0.],
           [1., 1., 0.],
           [1., 1., 1.]]],
 
 
         [[[1., 1., 1.],
   

In [84]:
pruned_model = CNNModel()
pruned_model.load_state_dict(model.state_dict())

pruned_model.conv1.weight.data *= population[0]

torch.save(pruned_model.state_dict(), "pruned_model.pth")