In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.nonlinearity_metrics import GradientMeanEdgeMetric, PerturbationSensitivityEdgeMetric
from senmodel.metrics.edge_finder import EdgeFinder


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from sklearn.preprocessing import LabelEncoder

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
    'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
    'hours-per-week', 'native-country', 'income'
]
data = pd.read_csv(url, names=columns, na_values=" ?", skipinitialspace=True)
data = data.dropna()

X = data.drop('income', axis=1)
y = data['income']

for col in X.select_dtypes(include=['object']).columns:
    X[col] = LabelEncoder().fit_transform(X[col])

y = LabelEncoder().fit_transform(y)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)


In [5]:
class TabularDataset(Dataset):
    def __init__(self, features, targets):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

In [6]:
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [7]:
class EnhancedFCN(nn.Module):
    def __init__(self, input_size=14, hidden_sizes=None, output_size=2, dropout_rate=0.3):
        super(EnhancedFCN, self).__init__()
        if hidden_sizes is None:
            hidden_sizes = [128, 64, 32]
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout_rate)

        self.output = nn.Linear(hidden_sizes[2], output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        # x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu2(x)
        # x = self.dropout2(x)

        x = self.fc3(x)
        x = self.relu3(x)
        # x = self.dropout3(x)

        x = self.output(x)
        return x


In [24]:
from copy import deepcopy
dense_model = EnhancedFCN(input_size=X.shape[1])

sparse_model = convert_dense_to_sparse_network(deepcopy( dense_model))

In [25]:
ze = torch.zeros((1, X.shape[1]))
dense_model(ze), sparse_model(ze)

(tensor([[-0.1328, -0.0299]], grad_fn=<AddmmBackward0>),
 tensor([[-0.1328, -0.0299]], grad_fn=<AsStridedBackward0>))

In [26]:
aaa = []
for batch, label in val_loader:
    # print(sparse_model(batch))
    aaa.append(sparse_model(batch))

In [27]:
layer = get_model_last_layer(sparse_model)
layer.replace_many([0],[0])

In [28]:
aaa, sparse_model(ze)

([tensor([[-0.1362, -0.0227],
          [-0.1286, -0.0784],
          [-0.1356, -0.0376],
          ...,
          [-0.2246,  0.0251],
          [-0.1107, -0.0863],
          [-0.1766, -0.0806]], grad_fn=<AsStridedBackward0>),
  tensor([[-0.1107, -0.0206],
          [-0.1299, -0.0209],
          [-0.3321, -0.0567],
          ...,
          [-0.1511, -0.0109],
          [-0.1070, -0.0416],
          [-0.1819, -0.0564]], grad_fn=<AsStridedBackward0>),
  tensor([[-0.1732, -0.0333],
          [-0.1442, -0.0279],
          [-0.0953, -0.0660],
          ...,
          [-0.1258, -0.0393],
          [-0.1331, -0.0158],
          [-0.1694, -0.0059]], grad_fn=<AsStridedBackward0>),
  tensor([[-0.2321, -0.0086],
          [-0.1885, -0.0315],
          [-0.1839, -0.0473],
          ...,
          [-0.1706, -0.0226],
          [-0.1609, -0.0700],
          [-0.1546, -0.0237]], grad_fn=<AsStridedBackward0>),
  tensor([[-0.1881, -0.0141],
          [-0.1390, -0.0453],
          [-0.1627, -0.0222],
  

In [29]:
print(aaa[0])

tensor([[-0.1362, -0.0227],
        [-0.1286, -0.0784],
        [-0.1356, -0.0376],
        ...,
        [-0.2246,  0.0251],
        [-0.1107, -0.0863],
        [-0.1766, -0.0806]], grad_fn=<AsStridedBackward0>)


In [30]:

for i, (batch, label) in enumerate(val_loader):
    # print(sparse_model(batch))
    print(sparse_model(batch) - aaa[i])

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00],
        [-7.4506e-09,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00]], grad_fn=<SubBackward0>)
tensor([[0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [7.4506e-09, 0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00]], grad_fn=<SubBackward0>)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,

In [1]:
 import torch
 a = torch.tensor([ 6.9022e+04,  1.6965e+06,  1.1345e+15,  5.5997e+05,  1.1029e+07,
         6.6347e+06,  6.6089e+04,  8.5784e+05,  1.5985e+06,  2.8082e+05,
         6.4803e+06,  3.4058e+04,  8.0083e+04,  1.6350e+08,  5.8992e+04,
         4.5622e+05,  5.6431e+06,  7.8509e+05,  1.4978e+06,  1.3414e+08,
         4.4167e+05,  2.8922e+10,  2.2747e+06,  1.3983e+08,  2.5155e+05,
         6.2198e+04,  9.0821e+04,  2.1378e+05,  1.6509e+05,  1.7527e+06,
         2.1345e+07,  2.1246e+06,  2.7472e+07,  1.3090e+05,  2.1024e+05,
         3.3485e+07,  1.3051e+06,  6.1822e+05,  1.0921e+05,  5.2002e+04,
         3.3412e+04,  5.9295e+05,  1.5586e+05,  2.2295e+06,  1.2166e+06,
         3.2241e+05,  1.7997e+06,  4.5363e+05,  1.6789e+05,  6.9436e+05,
         8.9839e+05,  4.3433e+04,  1.4715e+06,  1.0955e+06,  1.1203e+06,
         7.0053e+05,  2.1473e+06,  7.3261e+07,  2.2886e+06,  2.1422e+04,
        -1.0737e+09,  9.6736e+05,  3.2057e+05,  2.7164e+05,  6.0661e+04,
         3.8687e+08,  1.6877e+15,  1.3528e+06,  1.1994e+05,  1.5147e+06,
         7.8676e+04,  3.3642e+07,  2.0233e+04,  1.1023e+06,  1.8402e+06,
         6.1932e+04,  1.1434e+06,  3.1452e+05,  9.9346e+05,  4.2186e+03,
         1.0330e+05,  1.3591e+05,  4.2181e+06,  2.0733e+06,  4.9743e+05,
         1.1679e+06,  1.9500e+05,  3.6932e+09,  1.2239e+04,  4.7405e+04,
         8.7372e+04,  2.4216e+06,  2.6856e+05,  4.2048e+07,  2.3840e+05,
         3.6185e+04,  1.3817e+06,  7.2997e+05,  9.9176e+05,  1.6360e+05,
         4.2012e+05,  2.6256e+06,  3.1839e+05,  5.8632e+05,  1.9068e+07,
         3.6611e+04,  9.8132e+06,  2.1928e+05,  2.2181e+06,  8.5119e+08,
         2.4173e+04,  2.3829e+06,  7.8264e+04,  1.4169e+05,  6.0406e+05,
         1.1157e+05,  2.4950e+06,  1.1776e+05,  5.2810e+07,  4.5327e+06,
         1.8013e+06,  4.0606e+04,  4.6533e+07,  7.1381e+05, -1.0737e+09,
         4.5376e+05,  2.8130e+04,  5.6960e+07,  4.6980e+05,  6.2472e+06,
         1.0322e+15,  1.0527e+06,  1.4322e+07,  1.3890e+05,  1.9316e+05,
         6.5512e+06,  2.7770e+06,  3.7175e+05,  2.4857e+06,  2.6948e+04,
         4.5100e+03,  1.5158e+09,  1.4491e+06,  2.4482e+04,  2.5775e+05,
         1.5083e+06,  3.2840e+07,  7.5077e+06,  2.0845e+06,  1.6012e+05,
         1.2100e+06,  4.7585e+06,  6.0649e+04,  4.4847e+04,  5.7350e+05,
         1.2643e+06,  2.5817e+05,  2.6936e+06,  4.0905e+05,  9.5338e+04,
         1.0102e+06,  4.0300e+05,  3.8632e+05,  8.9918e+05,  1.7925e+05,
         1.0981e+06,  1.6795e+06,  7.1568e+04,  8.7205e+03,  7.6574e+05,
         9.1398e+05,  7.0512e+05,  1.7622e+05,  9.3289e+06,  1.0948e+06,
         1.9818e+07,  1.7306e+06,  3.2573e+06,  2.1825e+06,  5.6478e+05,
         1.0972e+05,  8.3738e+05,  4.3233e+05,  2.1176e+06,  1.8724e+05,
         1.6533e+05,  1.9857e+06,  9.4836e+04, -1.0737e+09,  5.2462e+04,
         1.7566e+05,  5.7480e+04,  4.0811e+05,  5.1872e+05,  1.7569e+15,
         1.4161e+06,  4.4382e+09,  7.2311e+04,  1.1362e+04,  2.0309e+06,
         2.3171e+06,  4.8009e+04,  2.0915e+05,  1.3529e+05,  6.7281e+05,
         1.9006e+07,  1.7175e+07,  3.4387e+05,  6.9306e+05,  7.1946e+06,
         9.1985e+06,  3.0890e+05,  7.4677e+06,  2.7651e+06,  2.8398e+05,
         1.2612e+07,  8.4599e+04,  5.1609e+05,  1.5134e+04,  3.4960e+06,
         1.2442e+08,  1.6875e+06,  1.3678e+05,  4.0722e+04,  3.0225e+04,
         2.0721e+04,  1.8493e+04,  1.7154e+05,  3.1312e+06,  3.7191e+05,
         1.2708e+08,  5.5560e+04,  4.4528e+05,  3.5258e+05,  3.1834e+05,
         3.2890e+05,  2.2338e+08,  4.7685e+07,  9.1055e+08,  8.1262e+06,
         1.5079e+05,  3.9948e+07,  1.3972e+07,  4.2602e+05,  1.0166e+05,
         3.3161e+05,  5.4861e+04,  1.8846e+04,  4.6496e+04,  1.7156e+04,
         6.1701e+06,  1.1200e+06, -1.0737e+09,  2.7737e+04,  6.1520e+04,
         3.2491e+05,  1.1640e+07,  1.4336e+06,  1.7662e+15,  6.6933e+04,
         4.6476e+07,  2.5762e+05,  2.1320e+05,  1.1276e+05,  2.2650e+06,
         1.9050e+05,  4.1473e+08,  1.7188e+03,  1.5505e+05,  4.0181e+08,
         3.9801e+05,  6.5862e+03,  1.6302e+04,  6.9644e+03,  4.4675e+07,
         2.3290e+05,  1.3027e+06,  1.5770e+05,  1.0570e+05,  2.1958e+06,
         2.4959e+05,  3.3234e+04,  8.5533e+07,  2.9434e+04,  5.3264e+05,
         5.2528e+05,  1.3188e+08,  1.8995e+06,  1.3163e+04,  2.8630e+07,
         3.1971e+05,  1.2480e+05,  1.1379e+05,  6.1255e+07,  5.2971e+05,
         5.4051e+04,  1.0784e+06,  1.1039e+06,  9.3173e+04,  4.8783e+04,
         6.4800e+04,  7.5631e+05,  5.6531e+04,  4.6193e+06,  6.5102e+05,
         3.5427e+05,  2.9322e+06,  2.6021e+06,  1.5071e+06,  9.1090e+05,
         4.7211e+05,  1.9186e+05,  9.6160e+05,  2.3102e+04,  1.3799e+06,
         4.8442e+05, -1.0737e+09,  1.2692e+04,  1.1325e+04,  6.1503e+05,
         5.8174e+04,  2.0043e+06,  1.7225e+15,  2.7384e+07,  1.3660e+07,
         1.7985e+07,  9.0011e+05,  1.0493e+04,  1.4072e+06,  2.8411e+05,
         2.9821e+05,  5.7245e+05,  1.6910e+06,  7.1180e+06,  1.3653e+04,
         8.8879e+03,  3.2289e+06,  4.5018e+06,  8.6668e+08,  3.0394e+06,
         1.5276e+09,  8.4264e+04,  1.1716e+06,  1.2263e+08,  3.0405e+05,
         7.9007e+05,  4.4359e+06,  2.2415e+08,  2.0396e+06,  1.0158e+09,
         7.5569e+06,  3.9150e+04,  2.6479e+06,  5.4375e+05,  3.3545e+05,
         1.5833e+05,  1.2243e+05,  1.9202e+06,  2.5889e+06,  9.1108e+04,
         8.3320e+03,  5.1863e+06,  9.7801e+05,  5.5139e+06,  1.6086e+06,
         5.6013e+07,  1.6992e+07,  6.4646e+07,  1.2294e+05,  1.0803e+09,
         1.0373e+07,  5.5174e+04,  1.0667e+04,  7.7791e+03,  1.0855e+05,
         2.0791e+07,  6.9263e+05,  8.3305e+03,  8.7472e+05,  8.6060e+04,
        -1.0737e+09,  9.9370e+04,  2.1399e+06,  1.3596e+06,  1.8856e+05,
         9.4710e+05,  1.7227e+15,  6.0531e+05,  2.6186e+05,  3.4315e+06,
         5.6631e+05,  2.5760e+07,  1.4048e+05,  7.7653e+04,  3.0004e+08,
         9.8352e+05,  3.1408e+06,  7.0761e+07,  1.9264e+06,  5.5613e+06,
         5.1148e+05,  6.1304e+08,  1.0187e+07,  1.7149e+06,  4.1827e+06,
         2.4639e+05,  7.7710e+04,  1.0114e+05,  1.1714e+05,  3.2073e+05,
         1.0336e+06,  9.6792e+05,  1.4767e+05,  1.3099e+06,  5.1196e+06,
         2.1578e+06,  8.3911e+05,  1.2149e+06,  6.0688e+05,  5.0724e+06,
         2.4198e+05,  2.1475e+05,  1.0435e+07,  5.5377e+04,  7.1657e+04,
         2.1288e+06,  5.5558e+05,  2.8647e+05,  9.5689e+04,  1.2884e+05,
         1.8075e+11,  3.1971e+05,  8.1447e+04,  1.1965e+07,  6.1282e+05,
         2.2299e+06,  3.6986e+04,  2.1502e+05,  2.7610e+06,  5.4668e+04,
         1.3818e+07,  2.8310e+05,  2.4442e+06,  2.1964e+05, -1.0737e+09,
         1.8199e+05,  6.9338e+05,  3.3376e+07,  9.2876e+04,  1.5306e+06,
         1.7357e+15,  1.0608e+07,  6.0825e+06,  2.0652e+08,  1.2264e+04,
         9.8798e+04,  1.9173e+06,  5.7814e+05,  3.1685e+06,  7.8763e+04,
         1.6250e+06,  1.2782e+07,  2.4862e+06,  4.1529e+05,  2.4538e+05,
         6.1900e+05,  1.1938e+05,  1.2544e+06,  1.5332e+05,  5.1154e+05,
         1.4514e+05,  3.0013e+04,  1.5999e+06,  2.1513e+07,  3.3458e+05,
         6.9581e+05,  1.1185e+06,  2.8333e+05,  1.1035e+05,  6.5994e+06,
         1.5657e+05,  2.2601e+05,  4.0251e+05,  7.9971e+05,  3.0209e+04,
         6.4070e+04,  1.4849e+07,  5.7654e+04,  1.7160e+08,  9.9640e+05,
         6.4653e+06,  8.0233e+04,  4.7267e+05,  2.5737e+07,  2.1029e+07,
         2.2886e+05,  1.3321e+07,  3.9171e+09,  1.6287e+06,  2.8563e+07,
         1.6277e+04,  3.7437e+05,  5.8582e+04,  5.9566e+05,  2.4321e+05,
         3.3402e+08,  1.7962e+10,  2.4821e+05, -1.0737e+09,  3.3052e+04,
         2.8007e+04,  2.3257e+05,  2.7617e+05,  7.4281e+06,  1.7167e+15,
         1.8180e+06,  4.9983e+05,  1.0955e+05,  4.0377e+04,  1.2621e+05,
         3.6683e+05,  6.2151e+04,  1.4191e+05,  9.2158e+04,  2.6365e+04,
         2.9282e+06,  1.9357e+08,  1.4001e+05,  1.3656e+07,  1.9241e+06,
         5.4744e+06,  7.7778e+06,  1.7003e+07,  6.5253e+05,  2.1619e+06,
         7.3109e+06,  1.7913e+05,  2.4120e+04,  1.4304e+05,  5.7588e+04,
         7.6912e+05,  5.5723e+05,  7.6723e+05,  2.0018e+04,  2.3947e+05,
         7.7393e+05,  2.1976e+04,  1.4885e+05,  2.0574e+06,  5.3609e+05,
         7.3149e+08,  2.1269e+04,  2.3472e+05,  9.4693e+05,  3.4018e+06,
         6.5202e+05,  2.4331e+05,  4.7398e+05,  4.9529e+05,  5.9774e+06,
         7.8946e+05,  1.6068e+07,  1.0209e+07,  3.0638e+06,  1.2867e+07,
         6.1751e+08,  2.0917e+05,  2.2926e+04,  2.1828e+06,  1.0832e+06,
         2.6399e+12,  6.0480e+04, -1.0737e+09,  7.7761e+04,  6.0114e+05,
         2.0353e+04,  3.2125e+08,  1.4544e+06,  1.2283e+15,  1.4444e+06,
         2.2557e+05,  4.2982e+09,  7.3784e+04,  1.2856e+09,  1.8579e+06,
         1.9405e+07,  3.9299e+06,  2.3553e+05,  3.7936e+05,  1.1273e+08,
         1.5759e+06,  8.2796e+05,  3.9179e+06,  9.4623e+05,  6.4623e+10,
         1.3429e+06,  3.4329e+06,  2.3788e+06,  7.4012e+05,  4.9648e+05,
         5.8744e+05,  2.2171e+05,  2.2823e+09,  3.4151e+07,  1.6501e+06,
         1.7440e+06,  2.8996e+06,  2.0826e+06,  1.0962e+07,  3.0342e+05,
         1.3277e+08,  1.4171e+06,  3.6114e+07,  4.4559e+06,  3.5964e+07,
         1.1699e+06,  6.2368e+06,  2.4476e+06,  5.3971e+06,  1.9589e+07,
         7.0238e+04,  1.8816e+06,  2.6660e+06,  3.3107e+06,  1.4654e+06,
         7.0771e+05,  2.5440e+09,  1.3343e+06,  3.7616e+05,  3.1955e+07,
         7.8976e+05,  2.0539e+07,  5.3754e+07,  2.1626e+08,  2.3734e+06,
         2.2390e+06, -1.0737e+09,  4.8421e+06,  7.9953e+07,  1.9528e+07,
         2.9976e+04,  2.0778e+05,  1.6097e+15,  2.0020e+06,  1.8793e+05,
         4.8044e+05,  4.1244e+05,  6.1046e+05,  4.6489e+04,  2.7531e+07,
         1.9704e+08,  5.8004e+03,  2.2890e+03,  2.6257e+07,  1.0250e+05,
         9.4852e+04,  1.0868e+06,  8.7815e+04,  2.6772e+06,  4.0043e+07,
         3.1864e+04,  3.1928e+04,  2.7496e+07,  1.9132e+06,  5.7716e+05,
         2.0775e+05,  1.4103e+06,  2.9357e+04,  6.0570e+04,  9.4551e+04,
         8.1903e+04,  2.6056e+06,  5.6959e+04,  8.1645e+04,  1.1047e+06,
         2.4534e+04,  6.6156e+05,  2.4232e+05,  7.7093e+04,  2.3483e+04,
         2.8769e+05,  2.0361e+04,  1.2382e+04,  4.1633e+04,  1.1614e+05,
         4.5655e+07,  1.2108e+06,  1.0321e+07,  1.2846e+09,  7.7650e+06,
         8.5084e+05,  7.1255e+06,  2.7860e+06,  1.4759e+05,  5.9931e+06,
         4.9695e+04,  6.1360e+05,  1.8028e+06,  1.5477e+06,  6.6570e+04,
        -1.0737e+09,  1.4013e+05,  1.9480e+03,  7.6531e+03,  1.2203e+05,
         6.3133e+05,  1.6186e+15,  1.9502e+06,  1.6707e+07,  1.4344e+05,
         7.6748e+05,  1.3662e+06,  1.3198e+07,  3.8499e+04,  6.3779e+06,
         7.4082e+04,  5.2442e+05,  2.6990e+07,  7.3961e+04,  1.0166e+06,
         5.9046e+06,  4.8680e+06,  2.1410e+07,  7.4278e+05,  6.2513e+05,
         2.9080e+07,  1.6122e+07,  7.8371e+05,  7.4335e+05,  4.9615e+06,
         7.5239e+05,  8.4515e+04,  1.7720e+05,  4.8655e+07,  1.0344e+06,
         3.1655e+05,  3.5247e+06,  7.4497e+07,  2.4600e+06,  2.0998e+07,
         1.2326e+05,  6.6976e+05,  1.5027e+06,  1.3801e+06,  7.7076e+04,
         1.3554e+09,  2.2258e+05,  1.8279e+06,  1.3333e+06,  2.7559e+07,
         1.4753e+06,  3.8020e+05,  8.2397e+06,  2.1972e+06,  5.2903e+06,
         1.5422e+05,  5.7865e+06,  5.0108e+05,  2.5681e+05,  2.7210e+05,
         4.8094e+05,  2.8220e+05,  3.1210e+06,  1.7270e+06, -1.0737e+09,
         4.1770e+05,  7.1151e+04,  3.6814e+04,  6.0023e+04,  9.8481e+06,
         1.7344e+15,  8.1547e+04,  2.0074e+05,  9.6472e+04,  1.7500e+06,
         4.9126e+06,  2.2706e+05,  1.8629e+06,  3.4255e+05,  1.6954e+05,
         6.2106e+03,  9.6489e+07,  1.3572e+04,  4.3388e+04,  7.9632e+07,
         1.5320e+05,  6.4194e+06,  5.8146e+04,  2.3669e+03,  8.4420e+05,
         5.3070e+05,  5.2215e+05,  3.8956e+05,  1.6046e+05,  8.5742e+04,
         9.0920e+05,  2.3120e+06,  2.3767e+04,  1.6787e+06,  8.5800e+04,
         1.4237e+07,  5.7784e+05,  1.4798e+05,  9.3130e+05,  1.2365e+08,
         4.1065e+04,  1.7777e+06,  4.4280e+04,  2.5382e+04,  8.2928e+04,
         1.4994e+05,  7.9756e+07,  5.1382e+06,  4.2399e+07,  2.5679e+06,
         8.5661e+05,  2.2237e+06,  1.1125e+05,  1.0724e+04,  5.6282e+05,
         1.1352e+06,  1.1549e+06,  2.9292e+05,  1.3065e+06,  7.9756e+05,
         2.0591e+04,  6.0646e+05,  2.1438e+07, -1.0737e+09,  7.9538e+07,
         4.1547e+04,  4.4193e+05,  2.2058e+06,  8.5313e+07,  1.5849e+15,
         2.7540e+05,  1.9075e+08,  1.7983e+04,  8.0696e+05,  2.8090e+08,
         1.3213e+06,  1.5375e+05,  4.5946e+04,  8.3624e+05,  5.4634e+05,
         4.2375e+09,  1.9557e+08,  3.9414e+05,  5.0170e+09,  3.1679e+04,
         2.1343e+07,  5.5658e+05,  1.1819e+07,  2.6015e+04,  9.5560e+05,
         9.4436e+06,  5.5924e+04,  6.1732e+04,  1.7511e+04,  3.1911e+05,
         1.7203e+05,  4.4412e+06,  3.1255e+06,  1.2842e+05,  1.0371e+05,
         2.1808e+05,  1.1485e+05,  1.6526e+05,  4.0723e+05,  3.9758e+05,
         6.8848e+05,  3.0073e+04,  1.5447e+04,  6.0986e+06,  7.8246e+04,
         4.1866e+05,  3.2743e+06,  2.9617e+06,  4.3631e+07,  3.2540e+05,
         3.5259e+07,  5.3797e+06,  8.2419e+05,  4.0758e+08,  6.7773e+05,
         4.5097e+05,  2.3413e+05,  9.9457e+05,  4.0782e+06,  3.1981e+06,
         4.5689e+06,  5.8190e+05, -1.0737e+09,  7.6782e+05,  1.9774e+06,
         3.9805e+04,  5.6290e+05,  2.0419e+05,  1.7399e+15,  1.7597e+05,
         9.6401e+08,  3.3126e+05,  1.1482e+05,  1.0958e+09,  8.1410e+06,
         6.8039e+04,  1.7328e+05,  8.5170e+04,  3.2466e+05,  6.5024e+08,
         2.7624e+04,  1.6515e+05,  5.6193e+05,  3.9692e+06,  1.0877e+07,
         1.2148e+07,  7.0940e+05,  3.2655e+05,  1.6925e+07,  3.3275e+05,
         2.4807e+04,  3.0758e+05,  1.3699e+06,  1.4988e+06,  2.7995e+04,
         1.4021e+06,  3.7506e+07,  4.3797e+04,  3.6877e+06,  6.6619e+04,
         2.3461e+06,  1.3804e+06,  1.2844e+05,  1.7303e+06,  2.0032e+07,
         7.3303e+04,  5.6849e+03,  5.8306e+04,  1.1375e+06,  1.8448e+06,
         6.1352e+06,  9.3832e+07,  2.2952e+05,  7.4949e+05,  9.4896e+05,
         1.7528e+06,  2.3270e+05,  3.3131e+09,  3.1651e+04,  1.7344e+07,
         2.3878e+04,  5.1868e+08,  5.3087e+04,  1.2478e+06,  2.3258e+04,
         2.8566e+03, -1.0737e+09,  1.0594e+06,  1.9259e+05,  6.4458e+04])

In [39]:
# temperature = 1e4
a_log = torch.log(torch.abs(a) + 1e-8)
# a_scaled = a_log / temperature
# print(a_scaled)
softmax = torch.softmax(a_log, dim=-1)

softmax

tensor([2.9008e-12, 7.1300e-11, 4.7680e-02, 2.3534e-11, 4.6352e-10, 2.7884e-10,
        2.7776e-12, 3.6053e-11, 6.7181e-11, 1.1802e-11, 2.7235e-10, 1.4314e-12,
        3.3657e-12, 6.8715e-09, 2.4793e-12, 1.9174e-11, 2.3717e-10, 3.2995e-11,
        6.2949e-11, 5.6376e-09, 1.8562e-11, 1.2155e-06, 9.5600e-11, 5.8767e-09,
        1.0572e-11, 2.6140e-12, 3.8170e-12, 8.9847e-12, 6.9383e-12, 7.3662e-11,
        8.9708e-10, 8.9292e-11, 1.1546e-09, 5.5014e-12, 8.8359e-12, 1.4073e-09,
        5.4850e-11, 2.5982e-11, 4.5898e-12, 2.1855e-12, 1.4042e-12, 2.4920e-11,
        6.5504e-12, 9.3700e-11, 5.1131e-11, 1.3550e-11, 7.5637e-11, 1.9065e-11,
        7.0560e-12, 2.9182e-11, 3.7757e-11, 1.8254e-12, 6.1844e-11, 4.6041e-11,
        4.7084e-11, 2.9442e-11, 9.0246e-11, 3.0790e-09, 9.6184e-11, 9.0031e-13,
        4.5125e-08, 4.0656e-11, 1.3473e-11, 1.1416e-11, 2.5494e-12, 1.6259e-08,
        7.0930e-02, 5.6855e-11, 5.0408e-12, 6.3659e-11, 3.3066e-12, 1.4139e-09,
        8.5034e-13, 4.6327e-11, 7.7339e-

In [33]:
a[a < 0]

tensor([-1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09,
        -1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09,
        -1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09, -1.0737e+09])

In [27]:
a[torch.softmax(a ,dim=-1) > 0]

tensor([1.7662e+15])