In [88]:
# ['T1','TH','EA','EL','PI','PR','PG','SF','SR','SA']
import torch


In [89]:
# Feature columns, in the order they appear along the last axis of each input batch.
features_order = ['T1','TH','EA','EL','PI','PR','PG','SF','SR','SA']
# Derive the feature count from the names so the two can never disagree.
input_size = len(features_order)

# NOTE(review): absolute, machine-specific paths — consider a configurable base directory.
model_path = '/Users/monaabd/Desktop/EmotiBit Final Models/3S_32H_3L/model29.pt'
val_loader_path = '/Users/monaabd/Desktop/3 sec dataloader/val3_dataloader.pt'

# These files presumably hold whole pickled objects (a model and a DataLoader), not state
# dicts: torch.load unpickles them, so only load trusted files. On torch >= 2.6 (where
# `weights_only` defaults to True) these calls additionally need `weights_only=False`.
# map_location='cpu': the importance code below never moves batches to a device, so keep
# the model (and any stored tensors) on CPU even if they were saved from a GPU session.
model = torch.load(model_path, map_location='cpu')
val_loader = torch.load(val_loader_path, map_location='cpu')

In [90]:
def feature_min_max(loader, n_features):
    """Return per-feature ``(max_values, min_values)`` lists over every batch of ``loader``.

    Each batch is indexed as ``batch[0][:, :, j]``, i.e. the loader is assumed to yield
    ``(inputs, targets)`` pairs with inputs laid out as (batch, time, feature) — TODO confirm.
    Only the first ``n_features`` features are scanned.

    Raises:
        ValueError: if ``loader`` yields no batches (the extrema would stay at +/-inf).
    """
    max_vals = [-float('inf')] * n_features
    min_vals = [float('inf')] * n_features
    seen_batch = False
    for batch in loader:
        seen_batch = True
        # Collapse the batch and time axes so every column holds one feature's values.
        flat = batch[0][:, :, :n_features].reshape(-1, n_features)
        batch_max = flat.max(dim=0).values.tolist()
        batch_min = flat.min(dim=0).values.tolist()
        max_vals = [max(old, new) for old, new in zip(max_vals, batch_max)]
        min_vals = [min(old, new) for old, new in zip(min_vals, batch_min)]
    if not seen_batch:
        raise ValueError("loader yielded no batches")
    return max_vals, min_vals


max_value, min_value = feature_min_max(val_loader, input_size)
print("Maximum value across all batches:", max_value)
print("Minimum value across all batches:", min_value)
# Per-feature value range; used below to scale each feature's perturbation.
ranges = [a - b for a, b in zip(max_value, min_value)]
print(ranges)

Maximum value across all batches: [8.556126594543457, 7.506616592407227, 7.407805919647217, 7.06660270690918, 4.0715813636779785, 7.478794097900391, 5.979840278625488, 4.783245086669922, 3.9023795127868652, 7.018608093261719]
Minimum value across all batches: [-11.803434371948242, -10.856104850769043, -6.742403030395508, -7.212161064147949, -9.146245956420898, -5.77009391784668, -3.604616403579712, -2.8141942024230957, -1.8856086730957031, -1.8956528902053833]
[20.3595609664917, 18.36272144317627, 14.150208950042725, 14.278763771057129, 13.217827320098877, 13.24888801574707, 9.5844566822052, 7.597439289093018, 5.787988185882568, 8.914260983467102]


In [91]:
def get_importances(model, noise_or_value, val_loader, input_size, ranges):
    """Estimate per-feature importance by shifting one input feature at a time.

    For every batch and every feature ``j``, feature ``j`` is shifted by the constant
    ``ranges[j] * noise_or_value`` across all batch items and timesteps. The mean absolute
    change in the model output is then divided by the mean absolute original output.
    These relative changes are averaged over batches (each batch weighted equally).

    Args:
        model: module called as ``model(inputs)``; switched to eval mode during the
            computation and restored to its previous train/eval mode afterwards.
        noise_or_value: fraction of each feature's range to add. This is a deterministic
            constant shift, not random noise.
        val_loader: iterable yielding ``(inputs, targets)`` batches. Inputs are indexed as
            ``inputs[:, :, j]``, presumably (batch, time, feature) — TODO confirm.
        input_size: number of leading features to perturb.
        ranges: per-feature ranges (max - min); needs at least ``input_size`` entries.

    Returns:
        1-D float tensor of length ``input_size``: batch-averaged relative output change.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    importances = torch.zeros(input_size)
    n_batches = 0
    try:
        # Inference only: skip building autograd graphs for the 1 + input_size passes per batch.
        with torch.no_grad():
            for inputs, _targets in val_loader:
                original_output = model(inputs)
                original_output_magnitude = torch.abs(original_output).mean()

                for j in range(input_size):
                    perturbed_input = inputs.clone()
                    # Constant shift of feature j, scaled by that feature's observed range.
                    perturbed_input[:, :, j] += ranges[j] * noise_or_value
                    perturbed_output = model(perturbed_input)

                    change = torch.abs(perturbed_output - original_output)
                    importances[j] += (change.mean() / original_output_magnitude).item()
                n_batches += 1
    finally:
        # Don't leave a hidden mode change on the caller's model.
        model.train(was_training)

    if n_batches == 0:
        raise ValueError("val_loader yielded no batches")
    # Mean over batches (not samples); values are ratios, not percentages.
    return importances / n_batches

In [92]:
# Shift sizes to test, as fractions of each feature's range.
noise_or_value_tests = [0.01, 0.02, 0.03, 0.04, 0.05]
importances_results = []  # raw importances per test, aligned with features_order
sorted_importances = []   # importances per test, sorted descending
sorted_features = []      # feature names aligned with sorted_importances
for val in noise_or_value_tests:
    importances = get_importances(model, val, val_loader, input_size, ranges)
    importance_list = importances.tolist()
    importances_results.append(importance_list)
    # Sort (importance, name) pairs descending; exact ties fall back to the feature name.
    paired_sorted = sorted(zip(importance_list, features_order), reverse=True)
    sorted_imp = [imp for imp, _ in paired_sorted]
    sorted_feats = [feat for _, feat in paired_sorted]
    sorted_importances.append(sorted_imp)
    sorted_features.append(sorted_feats)

In [93]:
# One row per shift size in noise_or_value_tests (no hardcoded count).
for imp in sorted_importances:
    print(imp)

[0.25833773612976074, 0.2245100736618042, 0.05852538347244263, 0.03712789714336395, 0.02556931972503662, 0.02301660366356373, 0.0159677192568779, 0.013350997120141983, 0.013140418566763401, 0.010995127260684967]
[0.4095552861690521, 0.373066782951355, 0.11558151245117188, 0.07173412293195724, 0.05107859894633293, 0.045493338257074356, 0.031567081809043884, 0.02681049145758152, 0.025997398421168327, 0.02242385223507881]
[0.5138230323791504, 0.49912014603614807, 0.16905856132507324, 0.10419774055480957, 0.07623367756605148, 0.06520522385835648, 0.045852988958358765, 0.039734940975904465, 0.03851357474923134, 0.03393643721938133]
[0.599682629108429, 0.5956429243087769, 0.2273334264755249, 0.13868477940559387, 0.10236848890781403, 0.08406738936901093, 0.060597945004701614, 0.052752334624528885, 0.05085563287138939, 0.04582067206501961]
[0.6754988431930542, 0.6541175842285156, 0.2856563627719879, 0.17442677915096283, 0.12865689396858215, 0.10164093226194382, 0.07485121488571167, 0.065108634

In [94]:
# One row per shift size in noise_or_value_tests (no hardcoded count).
for feats in sorted_features:
    print(feats)

['EA', 'EL', 'T1', 'TH', 'SA', 'PI', 'SF', 'PR', 'SR', 'PG']
['EA', 'EL', 'T1', 'TH', 'SA', 'PI', 'SF', 'PR', 'SR', 'PG']
['EA', 'EL', 'T1', 'TH', 'SA', 'PI', 'SF', 'PR', 'SR', 'PG']
['EL', 'EA', 'T1', 'TH', 'SA', 'PI', 'SF', 'PR', 'SR', 'PG']
['EL', 'EA', 'T1', 'TH', 'SA', 'PI', 'SF', 'PR', 'SR', 'PG']
