In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
import math
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
class CIFAR10VGG(nn.Module):
    def __init__(self):
        super(CIFAR10VGG, self).__init__()
        self.num_classes = 10
        self.weight_decay = 0.0005

        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(10)
        self.dropout1 = nn.Dropout(0.3)
        self.conv2 = nn.Conv2d(10, 64, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout2 = nn.Dropout(0.4)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.dropout3 = nn.Dropout(0.4)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(256)
        self.dropout4 = nn.Dropout(0.4)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False)
        self.bn8 = nn.BatchNorm2d(512)
        self.dropout5 = nn.Dropout(0.4)
        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False)
        self.bn9 = nn.BatchNorm2d(512)
        self.dropout6 = nn.Dropout(0.4)
        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False)
        self.bn10 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False)
        self.bn11 = nn.BatchNorm2d(512)
        self.dropout7 = nn.Dropout(0.4)
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False)
        self.bn12 = nn.BatchNorm2d(512)
        self.dropout8 = nn.Dropout(0.4)
        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False)
        self.bn13 = nn.BatchNorm2d(512)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout9 = nn.Dropout(0.5)

        self.fc1 = nn.Linear(512, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.dropout_fc1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, self.num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.dropout2(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        x = F.relu(self.bn5(self.conv5(x)))
        x = self.dropout3(x)
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.dropout4(x)
        x = F.relu(self.bn7(self.conv7(x)))
        x = self.pool3(x)

        x = F.relu(self.bn8(self.conv8(x)))
        x = self.dropout5(x)
        x = F.relu(self.bn9(self.conv9(x)))
        x = self.dropout6(x)
        x = F.relu(self.bn10(self.conv10(x)))
        x = self.pool4(x)

        x = F.relu(self.bn11(self.conv11(x)))
        x = self.dropout7(x)
        x = F.relu(self.bn12(self.conv12(x)))
        x = self.dropout8(x)
        x = F.relu(self.bn13(self.conv13(x)))
        x = self.pool5(x)
        x = self.dropout9(x)

        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout_fc1(x)
        x = self.fc2(x)
        return x

In [3]:
# class Network():
#     def __init__(self, model):
#         self.model = model
#         self.layer_outputs = {}

#     def get_layer_output(self, x, labels, layer_name):
#         self.model.eval()
#         with torch.no_grad():
#             layer_idx = list(model._modules).index(layer_name)
#             intermediate_model = nn.Sequential(*list(model.children())[:layer_idx+1]).to(device)
#             output = intermediate_model(x)

#         num_samples = labels.shape[0]
#         label_matrix = torch.zeros((num_samples, num_samples))

#         for i in range(num_samples):
#             for j in range(num_samples):
#                 if torch.equal(labels[i], labels[j]):
#                     label_matrix[i, j] = 1
#                 else:
#                     label_matrix[i, j] = math.sqrt(2)

#         # labels_expanded = labels[:, None] == labels[None, :]
#         # label_matrix = torch.where(labels_expanded, 1.0, math.sqrt(2)).to(device)

#         return label_matrix, output

#     def dist_mat(self, x):
#         if len(x.shape) == 4:
#             x = x.view(x.shape[0], -1)

#         dist = torch.norm(x[:, None] - x, dim=2)
#         # dist = torch.cdist(x, x, p=2)
#         return dist

#     def diff_and_forb(self, data):

#         num_images = data.shape[0]
#         num_filters = data.shape[1]

#         results = np.zeros((num_images, num_images, num_filters))
#         for i in range(num_images):
#             for j in range(num_images):
#                 diff = data[i] - data[j]
#                 norm = np.linalg.norm(diff.cpu().numpy(), axis=(1, 2))
#                 results[i, j] = norm

#         # diff_matrix = data.unsqueeze(0) - data.unsqueeze(1)
#         # results = torch.norm(diff_matrix, dim=(3, 4))

#         return results

#     def compute_gaussian_kernel(self, frobenius_norm_matrices, sigma):
#         s = frobenius_norm_matrices.shape[0]
#         num_filters = frobenius_norm_matrices.shape[2]
#         gaussian_kernels = np.zeros((s, s, num_filters))

#         for k in range(num_filters):
#             for i in range(s):
#                 for j in range(s):
#                     diff = frobenius_norm_matrices[i, j, k]
#                     gaussian_kernels[i, j, k] = np.exp(- diff ** 2 / (2 * sigma ** 2))

#         # sigma = torch.tensor(sigma, device=frobenius_norm_matrices.device)
#         # gaussian_kernels = torch.exp(-frobenius_norm_matrices ** 2 / (2 * sigma ** 2))
#         # gaussian_kernels = torch.exp(-frobenius_norm_matrices ** 2 / (2 * sigma ** 2))

#         return gaussian_kernels

#     def normalize_matrix(self, matrix):
#         assert matrix.shape[0] == matrix.shape[1], "Input matrix must be square"

#         # matrix = matrix.cpu()
#         normalized_matrix = np.zeros_like(matrix, dtype=np.float64)

#         for i in range(matrix.shape[0]):
#             for j in range(matrix.shape[1]):
#                 if i == j:
#                     normalized_matrix[i, j] = 1.0
#                 else:
#                     normalized_matrix[i, j] = matrix[i, j] / np.sqrt(matrix[i, i] * matrix[j, j])

#         # matrix = matrix.to(torch.float64)

#         # if len(matrix.shape) == 3:
#         #   matrix = matrix.view(matrix.shape[0], -1)  # Flatten the last two dimensions

#         # # Create a copy to preserve the original structure
#         # normalized_matrix = matrix.clone()

#         # # Calculate normalization factors for all non-diagonal elements
#         # diag_elements = torch.sqrt(torch.diag(matrix))  # Shape: (N,)
#         # norm_factors = diag_elements.unsqueeze(0) * diag_elements.unsqueeze(1)  # Shape: (N, N)

#         # # Normalize all elements except the diagonal
#         # normalized_matrix = torch.where(
#         #     torch.eye(matrix.size(0), device=matrix.device, dtype=torch.bool),
#         #     torch.tensor(1.0, dtype=torch.float64, device=matrix.device),  # Set diagonal to 1
#         #     matrix / norm_factors
#         # )

#         # assert matrix.shape[0] == matrix.shape[1], "Input matrix must be square"

#         # matrix = matrix.to(torch.float64)

#         # # Flatten 3D matrices (batch_size, channels, height, width) to 2D
#         # if len(matrix.shape) == 3:
#         #     matrix = matrix.view(matrix.shape[0], -1)  # Flatten last two dimensions

#         # # Create a copy to preserve the original structure
#         # normalized_matrix = matrix.clone()

#         # # Calculate normalization factors for all non-diagonal elements
#         # diag_elements = torch.sqrt(torch.diag(matrix))  # Diagonal elements (1D)
#         # norm_factors = diag_elements.unsqueeze(0) * diag_elements.unsqueeze(1)  # (N, N) normalization factors

#         # # Ensure matrix and norm_factors have the same dimensions
#         # if matrix.shape[0] != norm_factors.shape[0]:
#         #     raise ValueError(f"Shape mismatch: matrix shape {matrix.shape} and norm_factors shape {norm_factors.shape}")

#         # # Normalize all elements except the diagonal
#         # normalized_matrix = torch.where(
#         #     torch.eye(matrix.size(0), device=matrix.device, dtype=torch.bool),
#         #     torch.tensor(1.0, dtype=torch.float64, device=matrix.device),  # Set diagonal to 1
#         #     matrix / norm_factors
#         # )

#         # return normalized_matrix
#         return normalized_matrix

#     def hadamard_mult(self, data, labels):
#         if isinstance(labels, np.ndarray):
#             labels = torch.from_numpy(labels)
#         if isinstance(data, np.ndarray):
#             data = torch.from_numpy(data)
#         broadcasted_labels = labels.unsqueeze(-1).expand_as(data)
#         result = data * broadcasted_labels
#         return result

#     def entropy(self, result):
#         entropies = np.zeros(result.shape[-1])
#         epsilon = np.finfo(float).eps

#         for i in range(result.shape[-1]):
#             eigenvalues = np.linalg.eigvals(result[:, :, i])
#             eigenvalues = np.abs(eigenvalues)
#             eigenvalues[eigenvalues == 0] = epsilon

#             entropy = -np.sum(eigenvalues * np.log2(eigenvalues))

#             entropies[i] = entropy


#         # epsilon = torch.finfo(result.dtype).eps

#         # # Compute eigenvalues along the last dimension for each 2D slice in result
#         # # Note: PyTorch's `linalg.eigvals` doesn't exist, so use `linalg.eig` and take only the real parts.
#         # eigenvalues = torch.linalg.eigvals(result)  # Shape: (*, num_filters)

#         # # Taking the absolute value to handle potential negative small values due to floating-point errors
#         # eigenvalues = torch.abs(eigenvalues)

#         # # Replace zeros with epsilon to prevent log(0)
#         # eigenvalues = torch.where(eigenvalues == 0, epsilon, eigenvalues)

#         # # Calculate entropy across each set of eigenvalues
#         # entropies = -torch.sum(eigenvalues * torch.log2(eigenvalues), dim=0)


#         return entropies

#     def dist_mat(self, x):
#         if len(x.shape) == 4:
#             x = x.view(x.shape[0], -1)

#         distances = torch.norm(x[:, None] - x, dim=2)
#         # x_squared = (x ** 2).sum(dim=1, keepdim=True)  # Shape: (N, 1)
#         # distances = torch.sqrt(x_squared - 2 * (x @ x.T) + x_squared.T)
#         return distances


#     def cal_mi(self, x, labels, conv_layers):
#         layer_entropies = {}
#         # for layer in conv_layers:
#         #     print(f"Layer: {layer}")
#         #     layer_name = layer
#         #     label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
#         #     frobenius_norm = self.diff_and_forb(layer_output)
#         #     gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
#         #     normalized_gauss = self.normalize_matrix(gaussian_kernel)
#         #     normalized_labels = self.normalize_matrix(label_matrix)
#         #     hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
#         #     entropy = self.entropy(hadamard_mult)
#         #     label_matrix = self.normalize_matrix(label_matrix)

#         #     layer_entropies[layer_name] = entropy

#         #     # print("Layer output shape :",layer_output.shape)
#         #     # print("label_matrix :",label_matrix.shape)
#         #     # print("frobenius_norm shape: ",frobenius_norm.shape)
#         #     # print("gaussian_kernel shape: ",gaussian_kernel.shape)
#         #     # print("normalized_gauss shape: ",normalized_gauss.shape)
#         #     # print("normalized_labels: ",normalized_labels.shape)
#         #     # print("hadamard_mult shape: ",hadamard_mult.shape)
#         #     print("entropy shape: ",entropy.shape)

#         for layer in conv_layers:
#           # print(f"Layer: {layer}")

#           # start_time = time.time()

#           # # Getting layer output
#           # layer_name = layer
#           # label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
#           # print(f"Time for get_layer_output for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Calculating Frobenius norm
#           # frobenius_norm = self.diff_and_forb(layer_output)
#           # print(f"Time for diff_and_forb for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Computing Gaussian kernel
#           # gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
#           # print(f"Time for compute_gaussian_kernel for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Normalizing Gaussian kernel
#           # normalized_gauss = self.normalize_matrix(gaussian_kernel)
#           # print(f"Time for normalize_matrix (gaussian_kernel) for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Normalizing label matrix
#           # normalized_labels = self.normalize_matrix(label_matrix)
#           # print(f"Time for normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Hadamard product
#           # hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
#           # print(f"Time for hadamard_mult for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Calculating entropy
#           # entropy = self.entropy(hadamard_mult)
#           # print(f"Time for entropy calculation for {layer}: {time.time() - start_time:.4f} seconds")

#           # start_time = time.time()

#           # # Normalizing label matrix again (if required)
#           # label_matrix = self.normalize_matrix(label_matrix)
#           # print(f"Time for second normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

#           # # Storing entropy for the layer
#           # layer_entropies[layer_name] = entropy
#           # print(len(layer_entropies[layer_name]))
#           # print(f"Total time for layer {layer}: {time.time() - start_time:.4f} seconds")

#           print(f"Layer: {layer}")

#           start_time = time.time()

#           # Getting layer output
#           layer_name = layer
#           label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
#           print(f"Shape of layer_output for {layer}: {layer_output.shape}")
#           print(f"Shape of label_matrix for {layer}: {label_matrix.shape}")
#           print(f"Time for get_layer_output for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Calculating Frobenius norm
#           frobenius_norm = self.diff_and_forb(layer_output)
#           print(f"Shape of frobenius_norm for {layer}: {frobenius_norm.shape}")
#           print(f"Time for diff_and_forb for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Computing Gaussian kernel
#           gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
#           print(f"Shape of gaussian_kernel for {layer}: {gaussian_kernel.shape}")
#           print(f"Time for compute_gaussian_kernel for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Normalizing Gaussian kernel
#           normalized_gauss = self.normalize_matrix(gaussian_kernel)
#           print(f"Shape of normalized_gauss for {layer}: {normalized_gauss.shape}")
#           print(f"Time for normalize_matrix (gaussian_kernel) for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Normalizing label matrix
#           normalized_labels = self.normalize_matrix(label_matrix)
#           print(f"Shape of normalized_labels for {layer}: {normalized_labels.shape}")
#           print(f"Time for normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Hadamard product
#           hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
#           print(f"Shape of hadamard_mult for {layer}: {hadamard_mult.shape}")
#           print(f"Time for hadamard_mult for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Calculating entropy
#           entropy = self.entropy(hadamard_mult)
#           print(f"Shape of entropy for {layer}: {entropy.shape}")
#           print(f"Time for entropy calculation for {layer}: {time.time() - start_time:.4f} seconds")

#           start_time = time.time()

#           # Normalizing label matrix again (if required)
#           label_matrix = self.normalize_matrix(label_matrix)
#           print(f"Shape of label_matrix after second normalization for {layer}: {label_matrix.shape}")
#           print(f"Time for second normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

#           # Storing entropy for the layer
#           layer_entropies[layer_name] = entropy
#           print(f"Length of stored entropy for {layer_name}: {len(layer_entropies[layer_name])}")
#           print(f"Total time for layer {layer}: {time.time() - start_time:.4f} seconds")

#           print("\n")


#         return layer_entropies


In [4]:
class Network():
    def __init__(self, model):
        self.model = model
        self.layer_outputs = {}

    def get_layer_output(self, x, labels, layer_name):
        self.model.eval()
        with torch.no_grad():
            layer_idx = list(model._modules).index(layer_name)
            intermediate_model = nn.Sequential(*list(model.children())[:layer_idx+1]).to(device)
            output = intermediate_model(x)

        # num_samples = labels.shape[0]
        # label_matrix = torch.zeros((num_samples, num_samples))

        # for i in range(num_samples):
        #     for j in range(num_samples):
        #         if torch.equal(labels[i], labels[j]):
        #             label_matrix[i, j] = 1
        #         else:
        #             label_matrix[i, j] = math.sqrt(2)

        labels_expanded = labels[:, None] == labels[None, :]
        label_matrix = torch.where(labels_expanded, 1.0, math.sqrt(2)).to(device)

        return label_matrix, output

    def dist_mat(self, x):
        if len(x.shape) == 4:
            x = x.view(x.shape[0], -1)

        # dist = torch.norm(x[:, None] - x, dim=2)
        dist = torch.cdist(x, x, p=2)
        return dist

    def diff_and_forb(self, data):

        num_images = data.shape[0]
        num_filters = data.shape[1]

        # results = np.zeros((num_images, num_images, num_filters))
        # for i in range(num_images):
        #     for j in range(num_images):
        #         diff = data[i] - data[j]
        #         norm = np.linalg.norm(diff.cpu().numpy(), axis=(1, 2))
        #         results[i, j] = norm

        diff_matrix = data.unsqueeze(0) - data.unsqueeze(1)
        results = torch.norm(diff_matrix, dim=(3, 4))

        return results

    def compute_gaussian_kernel(self, frobenius_norm_matrices, sigma):
        # s = frobenius_norm_matrices.shape[0]
        # num_filters = frobenius_norm_matrices.shape[2]
        # gaussian_kernels = np.zeros((s, s, num_filters))

        # for k in range(num_filters):
        #     for i in range(s):
        #         for j in range(s):
        #             diff = frobenius_norm_matrices[i, j, k]
        #             gaussian_kernels[i, j, k] = np.exp(- diff ** 2 / (2 * sigma ** 2))

        sigma = torch.tensor(sigma, device=frobenius_norm_matrices.device)
        gaussian_kernels = torch.exp(-frobenius_norm_matrices ** 2 / (2 * sigma ** 2))
        gaussian_kernels = torch.exp(-frobenius_norm_matrices ** 2 / (2 * sigma ** 2))

        return gaussian_kernels

    def normalize_matrix(self, matrix):
        assert matrix.shape[0] == matrix.shape[1], "Input matrix must be square"

        matrix = matrix.cpu()
        normalized_matrix = np.zeros_like(matrix.numpy(), dtype=np.float64)

        for i in range(matrix.shape[0]):
            for j in range(matrix.shape[1]):
                if i == j:
                    normalized_matrix[i, j] = 1.0
                else:
                    normalized_matrix[i, j] = matrix[i, j] / np.sqrt(matrix[i, i] * matrix[j, j])

        # matrix = matrix.to(torch.float64)

        # if len(matrix.shape) == 3:
        #   matrix = matrix.view(matrix.shape[0], -1)  # Flatten the last two dimensions

        # # Create a copy to preserve the original structure
        # normalized_matrix = matrix.clone()

        # # Calculate normalization factors for all non-diagonal elements
        # diag_elements = torch.sqrt(torch.diag(matrix))  # Shape: (N,)
        # norm_factors = diag_elements.unsqueeze(0) * diag_elements.unsqueeze(1)  # Shape: (N, N)

        # # Normalize all elements except the diagonal
        # normalized_matrix = torch.where(
        #     torch.eye(matrix.size(0), device=matrix.device, dtype=torch.bool),
        #     torch.tensor(1.0, dtype=torch.float64, device=matrix.device),  # Set diagonal to 1
        #     matrix / norm_factors
        # )

        # assert matrix.shape[0] == matrix.shape[1], "Input matrix must be square"

        # matrix = matrix.to(torch.float64)

        # # Flatten 3D matrices (batch_size, channels, height, width) to 2D
        # if len(matrix.shape) == 3:
        #     matrix = matrix.view(matrix.shape[0], -1)  # Flatten last two dimensions

        # # Create a copy to preserve the original structure
        # normalized_matrix = matrix.clone()

        # # Calculate normalization factors for all non-diagonal elements
        # diag_elements = torch.sqrt(torch.diag(matrix))  # Diagonal elements (1D)
        # norm_factors = diag_elements.unsqueeze(0) * diag_elements.unsqueeze(1)  # (N, N) normalization factors

        # # Ensure matrix and norm_factors have the same dimensions
        # if matrix.shape[0] != norm_factors.shape[0]:
        #     raise ValueError(f"Shape mismatch: matrix shape {matrix.shape} and norm_factors shape {norm_factors.shape}")

        # # Normalize all elements except the diagonal
        # normalized_matrix = torch.where(
        #     torch.eye(matrix.size(0), device=matrix.device, dtype=torch.bool),
        #     torch.tensor(1.0, dtype=torch.float64, device=matrix.device),  # Set diagonal to 1
        #     matrix / norm_factors
        # )

        # return normalized_matrix
        return torch.tensor(normalized_matrix, device=matrix.device)

    def hadamard_mult(self, data, labels):
        if isinstance(labels, np.ndarray):
            labels = torch.from_numpy(labels)
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        broadcasted_labels = labels.unsqueeze(-1).expand_as(data)
        result = data * broadcasted_labels
        return result

    def entropy(self, result):
        entropies = np.zeros(result.shape[-1])
        epsilon = torch.finfo(float).eps

        for i in range(result.shape[-1]):
            eigenvalues = np.linalg.eigvals(result[:, :, i])
            eigenvalues = np.abs(eigenvalues)
            eigenvalues[eigenvalues == 0] = epsilon

            entropy = -np.sum(eigenvalues * np.log2(eigenvalues))

            entropies[i] = entropy


        # epsilon = torch.finfo(result.dtype).eps

        # # Compute eigenvalues along the last dimension for each 2D slice in result
        # # Note: PyTorch's `linalg.eigvals` doesn't exist, so use `linalg.eig` and take only the real parts.
        # eigenvalues = torch.linalg.eigvals(result)  # Shape: (*, num_filters)

        # # Taking the absolute value to handle potential negative small values due to floating-point errors
        # eigenvalues = torch.abs(eigenvalues)

        # # Replace zeros with epsilon to prevent log(0)
        # eigenvalues = torch.where(eigenvalues == 0, epsilon, eigenvalues)

        # # Calculate entropy across each set of eigenvalues
        # entropies = -torch.sum(eigenvalues * torch.log2(eigenvalues), dim=0)

        # epsilon = torch.finfo(result.dtype).eps

        # # Check if result is 3D or 4D and unpack accordingly
        # if len(result.shape) == 4:  # (batch_size, num_filters, height, width)
        #     batch_size, num_filters, height, width = result.shape
        # elif len(result.shape) == 3:  # (batch_size, num_filters, spatial_dim)
        #     batch_size, num_filters, spatial_dim = result.shape
        #     height = 1  # In case there's only one spatial dimension
        #     width = spatial_dim
        # else:
        #     raise ValueError("Expected 3D or 4D tensor, but got shape: {}".format(result.shape))

        # # Initialize entropy values for each filter
        # entropies = torch.zeros(num_filters, device=result.device)

        # # Loop over each filter
        # for i in range(num_filters):
        #     # Extract the 2D slice for the i-th filter (all samples in the batch)
        #     filter_matrix = result[:, i, :, :] if len(result.shape) == 4 else result[:, i, :].unsqueeze(2)  # Ensure it has two spatial dimensions

        #     # Compute Frobenius norm for each filter matrix (alternative to eigenvalue computation)
        #     frobenius_norm = torch.norm(filter_matrix, p='fro', dim=(1, 2))  # Sum of squared entries (frobenius norm)

        #     # Add a small epsilon to prevent division by zero
        #     frobenius_norm = torch.where(frobenius_norm == 0, epsilon, frobenius_norm)

        #     # Calculate entropy-like measure using Frobenius norm
        #     entropy = -torch.sum(frobenius_norm * torch.log2(frobenius_norm), dim=0)

        #     # Store the entropy for this filter
        #     entropies[i] = torch.mean(entropy)  # Average entropy across the batch

        return entropies

    def dist_mat(self, x):
        if len(x.shape) == 4:
            x = x.view(x.shape[0], -1)

        # dist = torch.norm(x[:, None] - x, dim=2)
        x_squared = (x ** 2).sum(dim=1, keepdim=True)  # Shape: (N, 1)
        distances = torch.sqrt(x_squared - 2 * (x @ x.T) + x_squared.T)
        return distances


    def cal_mi(self, x, labels, conv_layers):
        layer_entropies = {}
        # for layer in conv_layers:
        #     print(f"Layer: {layer}")
        #     layer_name = layer
        #     label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
        #     frobenius_norm = self.diff_and_forb(layer_output)
        #     gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
        #     normalized_gauss = self.normalize_matrix(gaussian_kernel)
        #     normalized_labels = self.normalize_matrix(label_matrix)
        #     hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
        #     entropy = self.entropy(hadamard_mult)
        #     label_matrix = self.normalize_matrix(label_matrix)

        #     layer_entropies[layer_name] = entropy

        #     # print("Layer output shape :",layer_output.shape)
        #     # print("label_matrix :",label_matrix.shape)
        #     # print("frobenius_norm shape: ",frobenius_norm.shape)
        #     # print("gaussian_kernel shape: ",gaussian_kernel.shape)
        #     # print("normalized_gauss shape: ",normalized_gauss.shape)
        #     # print("normalized_labels: ",normalized_labels.shape)
        #     # print("hadamard_mult shape: ",hadamard_mult.shape)
        #     print("entropy shape: ",entropy.shape)

        for layer in conv_layers:
          # print(f"Layer: {layer}")

          # starting_time = time.time()

          # # Getting layer output
          # layer_name = layer
          # label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
          # # print(f"Time for get_layer_output for {layer}: {time.time() - starting_time:.4f} seconds")

          # start_time = time.time()

          # # Calculating Frobenius norm
          # frobenius_norm = self.diff_and_forb(layer_output)
          # # print(f"Time for diff_and_forb for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Computing Gaussian kernel
          # gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
          # # print(f"Time for compute_gaussian_kernel for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Normalizing Gaussian kernel
          # normalized_gauss = self.normalize_matrix(gaussian_kernel)
          # # print(f"Time for normalize_matrix (gaussian_kernel) for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Normalizing label matrix
          # normalized_labels = self.normalize_matrix(label_matrix)
          # # print(f"Time for normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Hadamard product
          # hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
          # # print(f"Time for hadamard_mult for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Calculating entropy
          # entropy = self.entropy(hadamard_mult)
          # # print(f"Time for entropy calculation for {layer}: {time.time() - start_time:.4f} seconds")

          # start_time = time.time()

          # # Normalizing label matrix again (if required)
          # label_matrix = self.normalize_matrix(label_matrix)
          # # print(f"Time for second normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

          # # Storing entropy for the layer
          # layer_entropies[layer_name] = entropy
          # print(len(layer_entropies[layer]))
          # print(layer_entropies[layer_name])
          # print(f"Total time for layer {layer}: {time.time() - starting_time:.4f} seconds")

          print(f"Layer: {layer}")

          start_time = time.time()
          starting_time = time.time()

          # Getting layer output
          layer_name = layer
          label_matrix, layer_output = self.get_layer_output(x, labels, layer_name)
          print(f"Shape of layer_output for {layer}: {layer_output.shape}")
          print(f"Shape of label_matrix for {layer}: {label_matrix.shape}")
          print(f"Time for get_layer_output for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Calculating Frobenius norm
          frobenius_norm = self.diff_and_forb(layer_output)
          print(f"Shape of frobenius_norm for {layer}: {frobenius_norm.shape}")
          print(f"Time for diff_and_forb for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Computing Gaussian kernel
          gaussian_kernel = self.compute_gaussian_kernel(frobenius_norm, 2)
          print(f"Shape of gaussian_kernel for {layer}: {gaussian_kernel.shape}")
          print(f"Time for compute_gaussian_kernel for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Normalizing Gaussian kernel
          normalized_gauss = self.normalize_matrix(gaussian_kernel)
          print(f"Shape of normalized_gauss for {layer}: {normalized_gauss.shape}")
          print(f"Time for normalize_matrix (gaussian_kernel) for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Normalizing label matrix
          normalized_labels = self.normalize_matrix(label_matrix)
          print(f"Shape of normalized_labels for {layer}: {normalized_labels.shape}")
          print(f"Time for normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Hadamard product
          hadamard_mult = self.hadamard_mult(normalized_gauss, normalized_labels)
          print(f"Shape of hadamard_mult for {layer}: {hadamard_mult.shape}")
          print(f"Time for hadamard_mult for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Calculating entropy
          entropy = self.entropy(hadamard_mult)
          print(f"Shape of entropy for {layer}: {entropy.shape}")
          print(f"Time for entropy calculation for {layer}: {time.time() - start_time:.4f} seconds")

          start_time = time.time()

          # Normalizing label matrix again (if required)
          label_matrix = self.normalize_matrix(label_matrix)
          print(f"Shape of label_matrix after second normalization for {layer}: {label_matrix.shape}")
          print(f"Time for second normalize_matrix (label_matrix) for {layer}: {time.time() - start_time:.4f} seconds")

          # Storing entropy for the layer
          layer_entropies[layer_name] = entropy
          print(f"Length of stored entropy for {layer_name}: {len(layer_entropies[layer_name])}")
          print(f"Total time for layer {layer}: {time.time() - starting_time:.4f} seconds")

          print("\n")
        return layer_entropies


In [40]:
def calculate_top_filters(base_model, train_loader, conv_layers):
    batch_size = 128  # Choose an appropriate batch size
    network = Network(base_model)

    layer_votes_batch = {}
    layer_aggregated_votes = {}

    for conv in conv_layers:
        layer_votes_batch[conv] = []

    # layer_votes_batch = {conv: [] for conv in conv_layers}
    # layer_aggregated_votes = {}

    # print(layer_votes_batch)

    for batch_images, batch_labels in train_loader:

        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        layer_votes = network.cal_mi(batch_images, batch_labels, conv_layers)

        print("\n\n\n\n")
        # print(layer_votes)

        for layer in conv_layers:
            layer_votes_batch[layer].append(layer_votes[layer])

    print("Entire layer votes batch is: ",layer_votes_batch)

    for layer in conv_layers:
        layer_aggregated_votes[layer] = np.mean(layer_votes_batch[layer], axis=0)
        # layer_aggregated_votes[layer] = torch.mean(torch.tensor(layer_votes_batch[layer]), axis=0).to(device)

    print("Final layer votes aggregated is: ",layer_aggregated_votes)

    return layer_aggregated_votes

def sort(layer_aggregated_votes, conv_layers, k):
    layer_result = {}
    for layer in conv_layers:
        agg_votes = layer_aggregated_votes[layer]
        k_indices=int((agg_votes.size) * k)
        flat_indices = np.argsort(agg_votes.ravel())[-k_indices:]
        indices = np.unravel_index(flat_indices, agg_votes.shape)
        # top_k_indices = np.argpartition(agg_votes.ravel(), -k)[-k:]  # Faster than full argsort
        result = np.zeros_like(agg_votes)
        result[indices] = 1
        # result[np.unravel_index(top_k_indices, agg_votes.shape)] = 1
        layer_result[layer] = result

    return layer_result

def normalize_cifar10(train_data, test_data):
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    return transform_train, transform_test

In [33]:
filter_calculation = None
def get_important_filters(model, train_loader):
    global filter_calculation
    conv_layers = []
    for name,layer in model.named_children():
        if "conv" in name:
            conv_layers.append(str(name))
    # print(conv_layers)
    most_important_filters = calculate_top_filters(model, train_loader, conv_layers)
    filter_calculation = most_important_filters
    print("\n\n", most_important_filters)
    print("-----------------------------------------------------------------------")
    sorted_filters = sort(most_important_filters, conv_layers, 0.1)
    print(sorted_filters)
    return sorted_filters

In [34]:
from torch.utils.data import DataLoader, Subset
import random

transform_train, transform_test = normalize_cifar10(None, None)
train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform = transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform = transform_test)

# train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

model = CIFAR10VGG().to(device)
network = Network(model)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
train_size = int(0.01 * len(train_dataset))
test_size = int(0.01 * len(test_dataset))

train_indices = random.sample(range(len(train_dataset)), train_size)
test_indices = random.sample(range(len(test_dataset)), test_size)

train_subset = Subset(train_dataset, train_indices)
test_subset = Subset(test_dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=128, shuffle=False)


In [20]:
len(train_loader)

4

In [36]:
conv_layers = []
for name,layer in model.named_children():
        if "conv" in name:
            conv_layers.append(str(name))

In [37]:
top_filters_vgg16 = get_important_filters(model, train_loader)
print(f"Top filters selected: {top_filters_vgg16}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
       -1423.74467795, -1423.70117534, -1423.72062035, -1423.7001453 ,
       -1423.68103497, -1423.67534345, -1423.68099055, -1423.73102552,
       -1423.67192827, -1423.7117706 , -1423.69991067, -1423.74799611,
       -1423.69638688, -1423.67738393, -1423.7030034 , -1423.66083425,
       -1423.71992104, -1423.67970851, -1423.75872168, -1423.6618921 ,
       -1423.65777457, -1423.70293943, -1423.7411142 , -1423.65646559,
       -1423.72474939, -1423.7671299 , -1423.75919456, -1423.74194309,
       -1423.64345388, -1423.65569081, -1423.75851936, -1423.7665019 ,
       -1423.69655921, -1423.65983928, -1423.69003218, -1423.67597018,
       -1423.73164721, -1423.69956367, -1423.72515419, -1423.56023555,
       -1423.72700628, -1423.58982639, -1423.67488342, -1423.70240845,
       -1423.77901826, -1423.65564129, -1423.70079112, -1423.75387126,
       -1423.68845447, -1423.69207425, -1423.752583  , -1423.65235286,
       -1423

In [54]:
sorted_filters = sort(filter_calculation, conv_layers, 0.5)

In [55]:
sorted_filters

{'conv1': array([0., 0., 1., 0., 1., 0., 0., 1., 1., 1.]),
 'conv2': array([0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0.,
        0., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1.,
        0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1.]),
 'conv3': array([1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1.,
        0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.,
        1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1.,
        1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0.,
        1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 1., 1.,
        0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0.,
        1., 1., 1., 1., 1., 0., 1., 1., 1.]),
 'conv4': array([0., 0., 1., 1., 1., 0., 0., 1., 0., 0.

In [None]:
# def sort_filters(conv_layer, k):
#     k=int((conv_layer.size) * k)
#     flat_indices = np.argsort(conv_layer.ravel())[-k:]
#     indices = np.unravel_index(flat_indices, conv_layer.shape)
#     result = np.zeros_like(conv_layer)
#     result[indices] = 1

#     return result

# def sort(layer_aggregated_votes, conv_layers, k):
#     layer_result = {}
#     for layer in conv_layers:
#       agg_votes = layer_aggregated_votes[layer]
#       filters = int((agg_votes.shape[0]) * k)
#       print(filters)
#       flat_indices = np.argsort(agg_votes.ravel())[-filters:]
#       indices = np.unravel_index(flat_indices, agg_votes.shape)
#       result = np.zeros_like(agg_votes)
#       result[indices] = 1
#       layer_result[layer] = result

#     return layer_result

In [None]:
# chosen_filters = sort_filters(filter_calculation["conv1"], 0.1)
# print(chosen_filters)

In [None]:
model = CIFAR10VGG().to(device)
conv_layers = []
for name,layer in model.named_children():
  if "conv" in name:
    conv_layers.append(str(name))
sorted_filters = sort(filter_calculation, conv_layers, 0.5)

In [None]:
sorted_filters

In [None]:
top_filters_vgg16
for layer in top_filters_vgg16:
    print(layer)
    print(len(top_filters_vgg16[layer]))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
import math

In [None]:
for layer in top_filters_vgg16:
    print(layer)

In [57]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

for batch_idx, (inputs, targets) in enumerate(train_loader):
  inputs, targets = inputs.to(device), targets.to(device)
  optimizer.zero_grad()

  outputs = model(inputs)

  loss = criterion(outputs, targets)

  loss.backward()

  for layer_name in conv_layers:

    conv_layer = getattr(model, layer_name)
    print(conv_layer.weight.grad)
    filters_state = list(sorted_filters[layer_name])
    print(filters_state)
    filter_mask = torch.tensor(filters_state, dtype=torch.float32).view(-1, 1, 1, 1).to(device)

    with torch.no_grad():
      conv_layer.weight.grad *= filter_mask
    print(conv_layer.weight.grad)
    print("-"*50)

    optimizer.step()

    break


tensor([[[[ 1.2463e-09,  3.0188e-08,  5.2798e-08],
          [-4.4866e-09,  1.6343e-08,  3.8193e-08],
          [-3.5192e-08, -3.3329e-09,  7.6782e-09]],

         [[-2.1403e-08,  1.0225e-08,  3.5387e-08],
          [-2.9945e-08, -4.6645e-09,  2.1201e-08],
          [-5.8288e-08, -1.9462e-08, -2.1913e-09]],

         [[-6.8084e-08, -3.3555e-08, -6.4623e-09],
          [-7.2097e-08, -4.7521e-08, -1.6515e-08],
          [-9.2417e-08, -5.2755e-08, -3.0510e-08]]],


        [[[ 1.3014e-08,  1.5068e-08,  2.6020e-09],
          [ 1.8489e-08,  2.6129e-08,  5.2114e-09],
          [-1.2913e-09,  2.5092e-08,  8.8069e-09]],

         [[-1.1579e-08, -1.1404e-08, -1.6787e-08],
          [-4.7390e-09, -1.8116e-10, -1.3554e-08],
          [-2.6306e-08, -2.8438e-09, -1.1286e-08]],

         [[-6.5961e-09, -7.0161e-09, -1.2177e-08],
          [ 2.6798e-10,  5.0438e-09, -7.3919e-09],
          [-2.4465e-08, -1.0493e-10, -1.2538e-08]]],


        [[[ 6.7487e-08,  6.4198e-08,  3.8097e-08],
          [ 7.8

In [49]:
model.conv1.weight.grad

tensor([[[[ 4.4923e-08,  4.5557e-08,  2.6472e-08],
          [ 5.5847e-08,  4.7527e-08,  2.4037e-08],
          [ 5.8222e-08,  3.4506e-08, -5.4721e-09]],

         [[ 6.6519e-09,  5.7927e-09, -1.8712e-08],
          [ 1.9885e-08,  7.6635e-09, -1.6464e-08],
          [ 2.5769e-08, -3.6385e-09, -4.2959e-08]],

         [[-3.5160e-08, -2.8007e-08, -6.1323e-08],
          [-2.5058e-08, -2.7136e-08, -5.7919e-08],
          [-2.4984e-08, -4.8081e-08, -9.4239e-08]]],


        [[[-1.8590e-08, -5.3639e-09, -5.3247e-09],
          [-7.3579e-08, -4.1566e-08, -1.4217e-08],
          [-8.5084e-08, -6.3152e-08, -2.5315e-08]],

         [[ 1.5966e-08,  2.3649e-08,  2.0625e-08],
          [-4.8633e-08, -1.7029e-08,  1.5264e-08],
          [-6.7171e-08, -4.1292e-08,  3.9166e-09]],

         [[ 1.8149e-08,  3.5924e-08,  3.6997e-08],
          [-4.0390e-08, -5.1916e-10,  3.7265e-08],
          [-4.7376e-08, -1.3312e-08,  3.4783e-08]]],


        [[[ 7.4533e-08,  2.2517e-08,  2.5936e-08],
          [ 6.9

In [None]:
def train_model(model, train_loader, test_loader, epochs=250, lr=0.1, lr_drop=20, base=False, filter_data=None, conv_layers=[]):
    global pruning
    best_accuracy = 0
    best_model = None
    pruning += 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_drop, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0, 0, 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            if base:
                loss = criterion(outputs, targets)
            else:
                loss = custom_loss(outputs, targets, model, criterion, params["lambda_l1"])

            loss.backward()

            for layer_name in conv_layers:
                conv_layer = getattr(model, layer_name)
                filters_state = filter_data[layer_name]
                filter_mask = torch.tensor(filters_state, dtype=torch.float32).view(-1, 1, 1, 1).to(device)

                with torch.no_grad():
                    conv_layer.weight.grad *= filter_mask

            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        train_accuracy = 100. * correct / total
        avg_train_loss = train_loss / total

        test_loss, test_accuracy = test_model(model, test_loader, criterion)

        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

        writer.add_scalar(f'Loss/Train {pruning}Prune', avg_train_loss, epoch)
        writer.add_scalar(f'Accuracy/Train {pruning}Prune', train_accuracy, epoch)
        writer.add_scalar(f'Loss/Test {pruning}Prune', test_loss, epoch)
        writer.add_scalar(f'Accuracy/Test {pruning}Prune', test_accuracy, epoch)

        scheduler.step()

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model = deepcopy(model)

    return best_model

In [46]:
class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 10, kernel_size = (3,3), stride = 1, padding = 1)

  def forward(self, x):
    return self.conv1(x)

# model=SimpleModel()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

input_tensor = torch.randn(1, 3, 32, 32).to(device)
target = torch.randn(1, 1, 32, 32).to(device)

# Forward pass
output = model(input_tensor)
loss = criterion(output, target)

# Backward pass to calculate gradients
loss.backward()
conv1 = getattr(model,"conv1")
print(conv1.weight.grad)
# filter_mask = torch.tensor([1, 0, 0, 1], dtype = torch.float32).view(-1,1,1,1)

# with torch.no_grad():
  # model.conv1.weight.grad *= filter_mask
  # optimizer.state[model.conv1.weight]['momentum_buffer'] *= filter_mask


print("Filtered Gradients")
print(model.conv1.weight.grad)

RuntimeError: 0D or 1D target tensor expected, multi-target not supported

In [None]:
def update_gradients(model, filter_data, criterion, optimizer, conv_layers):
    input_tensor = torch.randn(1, 3, 32, 32)
    target = torch.randn(1, 4, 32, 32)


    output = model(input_tensor)
    loss = criterion(output, target)

    loss.backward()

    for layer in conv_layers:
        conv_layer = getattr(model, layer)
        filters_state = list(filter_data["layer"])
        filter_mask = torch.tensor(filters_state, dtype = torch.float32).view(-1,1,1,1)

        with torch.no_grad():
          conv_layer.weight.grad *= filter_mask
          # optimizer.state[model.conv1.weight]['momentum_buffer'] *= filter_mask




In [None]:
def train_model(model, train_loader, test_loader, epochs=250, lr=0.1, lr_drop=20, base=False, filter_data, conv_layers):
    global pruning
    best_accuracy=0
    best_model=None
    pruning += 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_drop, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0, 0, 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            if base:
                loss = criterion(outputs, targets)
            if not base:
                loss=custom_loss(outputs, targets, model, criterion, params["lambda_l1"])
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)  # Multiply by batch size to sum all loss
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        train_accuracy = 100. * correct / total
        avg_train_loss = train_loss / total

        test_loss, test_accuracy = test_model(model, test_loader, criterion)  # Test at the end of each epoch

        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
        writer.add_scalar(f'Loss/Train {pruning}Prune', train_loss, epoch)
        writer.add_scalar(f'Accuracy/Train {pruning}Prune', train_accuracy, epoch)
        writer.add_scalar(f'Loss/Test {pruning}Prune', test_loss, epoch)
        writer.add_scalar(f'Accuracy/Test {pruning}Prune', test_accuracy, epoch)

        scheduler.step()


        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model = deepcopy(model)


    return best_model

In [None]:
conv1 = getattr(model,'conv1')
conv1.weight.grad