In [489]:
from torch import nn
import torch


class MyModel(nn.Module):
    def __init__(self, all_sizes, activation=nn.ReLU):
        super(MyModel, self).__init__()

        input_size, *hidden_sizes, output_size = all_sizes

        layers = []
        current_size = input_size

        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(current_size, hidden_size))
            layers.append(activation())
            current_size = hidden_size

        self.layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(current_size, output_size)

        self.frozen_edges = set()

    def forward(self, x):
        x = self.layers(x)
        x = self.output_layer(x)
        return x

    def calculate_edge_gradients(
            self, dataloader, loss_fn, threshold,
            top_percent, auto_freeze=True, device='cpu'
    ):
        accumulated_grads = None
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)

            self.zero_grad()
            outputs = self.forward(data)
            loss = loss_fn(outputs, target)
            loss.backward()

            grads = self.layers[-2].weight.grad.abs()

            if accumulated_grads is None:
                accumulated_grads = torch.zeros_like(grads).to(device)

            accumulated_grads += grads

        avg_grads = accumulated_grads / len(dataloader)

        frozen_mask = torch.zeros(avg_grads.size(), dtype=torch.bool, device=device)
        if self.frozen_edges:
            frozen_edges_tensor = torch.tensor(list(self.frozen_edges), device=device).T
            frozen_mask[frozen_edges_tensor[0], frozen_edges_tensor[1]] = True

        mask = (avg_grads < threshold) & ~frozen_mask

        candidate_grads = avg_grads[mask]
        candidate_indices = torch.nonzero(mask, as_tuple=False)

        top_k = max(1, int(len(candidate_indices) * top_percent / 100))

        selected_edges_to_ignore = []
        if candidate_grads.numel() > 0:
            top_weights, top_indices = torch.topk(candidate_grads, top_k, largest=False)
            selected_edges_to_ignore = candidate_indices[top_indices].tolist()

        if auto_freeze and selected_edges_to_ignore:
            self.freeze_edges(selected_edges_to_ignore)

        return selected_edges_to_ignore

    def freeze_edges(self, edges):
        self.frozen_edges.update(tuple(edge) for edge in edges)


In [490]:
sizes = [10, 20, 40, 20, 1]
model = MyModel(sizes)

In [491]:
model

MyModel(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=40, bias=True)
    (3): ReLU()
    (4): Linear(in_features=40, out_features=20, bias=True)
    (5): ReLU()
  )
  (output_layer): Linear(in_features=20, out_features=1, bias=True)
)

In [492]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1000, n_features=sizes[0], n_informative=sizes[0], random_state=42)
X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()

dataset = list(zip(X, y))
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

criterion = nn.MSELoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MyModel(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=40, bias=True)
    (3): ReLU()
    (4): Linear(in_features=40, out_features=20, bias=True)
    (5): ReLU()
  )
  (output_layer): Linear(in_features=20, out_features=1, bias=True)
)

In [493]:
from torch import nn

selected_edges = model.calculate_edge_gradients(
    train_loader, criterion,
    threshold=0.001, top_percent=20,
    auto_freeze=True, device=device)

print(selected_edges)
# model.freeze_edges(selected_edges)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.])
[[14, 4], [0, 19], [2, 4], [2, 19], [18, 23], [18, 25], [18, 26], [18, 27], [5, 35], [18, 28], [9, 4], [9, 5], [9, 19], [9, 37], [18, 29], [18, 30], [14, 36], [18, 31], [18, 32], [18, 33], [18, 34], [18, 35], [12, 19], [18, 36], [18, 37]]


In [494]:
model.frozen_edges

{(0, 19),
 (2, 4),
 (2, 19),
 (5, 35),
 (9, 4),
 (9, 5),
 (9, 19),
 (9, 37),
 (12, 19),
 (14, 4),
 (14, 36),
 (18, 23),
 (18, 25),
 (18, 26),
 (18, 27),
 (18, 28),
 (18, 29),
 (18, 30),
 (18, 31),
 (18, 32),
 (18, 33),
 (18, 34),
 (18, 35),
 (18, 36),
 (18, 37)}