In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from torchviz import make_dot

%matplotlib inline

In [2]:
# Download and preprocess the MNIST dataset
mnist_dataset = torchvision.datasets.MNIST(root='./data', download=True)
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
mnist_dataset = torchvision.datasets.MNIST(root='./data', transform=transforms, download=True)

# Then, create a data loader that will generate minibatches of the preprocessed MNIST dataset
mnist_data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset,
                                                batch_size=32,
                                                shuffle=True)

In [3]:
target_layer_index = 99
num_layers = 2
learning_rate = .001
hidden_size = 10


# Define the model
class ResNet(nn.Module):
    def __init__(self, num_layers, hidden_size):
        super().__init__()
        # Define the layers of the model
        self.fc1 = nn.Linear(in_features=28*28, out_features=10)
        self.residual_layers = nn.ModuleList([nn.Linear(in_features=hidden_size, out_features=hidden_size) for _ in range(num_layers)])

    def forward(self, x):
        # Flatten the input tensor
        x = x.view(x.shape[0], -1)
        # Apply the first fully-connected layer
        x = self.fc1(x)
        # Apply the residual layers
        for i, layer in enumerate(self.residual_layers):
            # Skip the specified layer if its index matches the target index
            if i == target_layer_index:
                continue
        #
            x = F.relu(layer(x) + x)
        # Apply the log_softmax activation function
        return F.log_softmax(x, dim=1)
    
    def add_layer(self):
        new_layer = nn.Linear(in_features=10, out_features=10)
        new_layer.reset_parameters()
        self.residual_layers.append(new_layer)
        
        
    def remove_layer(self, index):
        # Check if the index is valid
        if index >= 0 and index < len(self.residual_layers):
        # Remove the layer at the specified index
            del self.residual_layers[index]
        else:
            print("Invalid layer index")
            
    def prune(self, threshold):
      # Create a list to store the layers that need to be pruned
      pruned_layers = []

      # Loop over the layers in the model
      for ind, layer in enumerate(self.residual_layers):
        # Compute the L2 norm of the weight tensor of the current layer
        weight_norm = torch.norm(layer.weight)

        # If the L2 norm is less than the specified threshold,
        # add the layer to the list of pruned layers
        if weight_norm < threshold:
            remove_layer(ind)
            pruned_layers.append(ind)

      # Return the list of pruned layers
      return pruned_layers
        
            

In [10]:

in_index = 0
out_index = 1


# Define the model
class WideNet(nn.Module):
    def __init__(self, num_layers, hidden_size):
        super().__init__()
        # Define the layers of the model
        self.fc1 = nn.Linear(in_features=28*28, out_features=10)
        self.layers = nn.ModuleList([nn.Linear(in_features=hidden_size, out_features=hidden_size) for _ in range(num_layers)])
        self.top_layers = None
        
    def forward(self, x):
        # Flatten the input tensor
        x = x.view(x.shape[0], -1)
        # Apply the first fully-connected layer
        x = self.fc1(x)
        # Apply the residual layers
        
    
        for i, layer in enumerate(self.layers):
            # Skip the specified layer if its index matches the target index
          #  if (i == in_index) and (self.top_layers is not None):
          #      x_ = x
          #      for j, wlayer in enumerate(self.top_layers):
          #          x_ = F.relu(wlayer(x_))
          #  else:
          #      x_ = torch.zeros_like(x)
          #      
          #  if i == out_index:
          #      x = x + x_
                    
        
            x = F.relu(layer(x))
        # Apply the log_softmax activation function
        return F.log_softmax(x, dim=1)
    
    def add_layer(self):
        new_layer = nn.Linear(in_features=10, out_features=10)
        new_layer.reset_parameters()
        self.layers.append(new_layer)
      

    def add_width(self, num_layers):
        # Create a new list of layers
        new_layers = nn.ModuleList([nn.Linear(in_features=hidden_size, out_features=hidden_size) for _ in range(num_layers)])
        # Loop over the layers in the new list and add them to the existing list of residual layers
        
        self.top_layers = new_layers
        
        #for layer in new_layers:
        #    self.layers.extend(layer)      
        
    def remove_layer(self, index):
        # Check if the index is valid
        if index >= 0 and index < len(self.residual_layers):
        # Remove the layer at the specified index
            del self.residual_layers[index]
        else:
            print("Invalid layer index")
            
    def prune(self, threshold):
      # Create a list to store the layers that need to be pruned
      pruned_layers = []

      # Loop over the layers in the model
      for ind, layer in enumerate(self.layers):
        # Compute the L2 norm of the weight tensor of the current layer
        weight_norm = torch.norm(layer.weight)

        # If the L2 norm is less than the specified threshold,
        # add the layer to the list of pruned layers
        if weight_norm < threshold:
            remove_layer(ind)
            pruned_layers.append(ind)

      # Return the list of pruned layers
      return pruned_layers


In [11]:
# Define the loss function and the optimizer
model = WideNet(num_layers, hidden_size)

criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0)

# Define the accuracy function
def accuracy(outputs, targets):
    # Compare the model's outputs to the targets
    correct = (outputs == targets).sum()*100
    # Return the accuracy
    return (correct).float()

In [12]:


# Train the model
for epoch in range(10):
    avg_loss = 0
    avg_acc = 0

    for minibatch in mnist_data_loader:
        # Extract the data and the labels from the minibatch
        images, labels = minibatch

        # Forward pass: compute the predicted log-probabilities
        log_probs = model(images)

        # Compute the predictions by applying the log_softmax function to the log-probabilities
        # and taking the index of the maximum value along the log_probabilities' second dimension
        preds = log_probs.argmax(dim=1)

        # Compute the loss
        loss = criterion(log_probs, labels)

        # Backward pass: compute the gradients of the loss with respect to the model's parameters
        loss.backward()

        # Update the model's parameters using the optimizer
        optimizer.step()

        # Zero the gradients
        optimizer.zero_grad()

        # Compute the loss
        # Update the average loss
        avg_loss += loss.item() / len(mnist_dataset)
        avg_acc += accuracy(preds, labels) / len(mnist_dataset)
        #plt.plot(avg_loss, 'b')
        #plt.draw()
        #print(model.)

        # Print the average loss
    print("Average loss for epoch", epoch + 1, ":", avg_loss)
    print("Average acc for epoch", epoch + 1, ":", avg_acc.numpy())


Average loss for epoch 1 : 0.07066851193507498
Average acc for epoch 1 : 23.946783
Average loss for epoch 2 : 0.06791916280984878
Average acc for epoch 2 : 33.380127
Average loss for epoch 3 : 0.06390713739395139
Average acc for epoch 3 : 33.688442
Average loss for epoch 4 : 0.05860562411944079
Average acc for epoch 4 : 39.791943


KeyboardInterrupt: 