In [42]:
import time
import torch 
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
import matplotlib.pyplot as plt

# Generate sample data
seed = 42
np.random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms = True

## start 
tensor_2d = torch.rand(100, 100)
sum_list=[]
for i in range(100): 
    sum_list.append(torch.sum(tensor_2d[:, i]) # sum of ith row 
        + torch.sum(tensor_2d[i, :]) # sum of ith column
        - tensor_2d[i, i]) # delete the duplicate value



eliminate_indices = np.argsort(sum_list)[:10] 
print(eliminate_indices)


keep_indices = torch.tensor(range(100))[~torch.isin(torch.tensor(range(100)), torch.tensor(eliminate_indices))]
tensor_eliminated = tensor_2d[keep_indices][:, keep_indices]
print(tensor_eliminated.shape)


# result_array = np.where(np.arange(100) == eliminate_indices[:, None], 1, 0)

# np.ones(100)-eliminate_indices

[44  2 84 34  5  1 99 29 90 79]
torch.Size([90, 90])


In [78]:

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms = True
tensor_2d = torch.rand(100, 100)

def eliminate_frames(tensor_2d, desired_length):
    '''For least relative node/frame elimination'''
    # Sum all frames' values both horizontally and vertically
    sum_list = torch.sum(tensor_2d, dim=-1) + torch.sum(tensor_2d, dim=-2) - torch.diagonal(tensor_2d) 
    # keep_indices = torch.tensor(range(100))[~torch.isin(torch.tensor(range(100)), eliminate_indices)]
    keep_indices = torch.argsort(sum_list)[-desired_length:] 
    print("sum_list.shape", sum_list.shape)
    
    # Add : later for dimension format
    print("keepindices.shape", keep_indices.shape)
    tensor_eliminated = tensor_2d[keep_indices][:, keep_indices]
    return tensor_eliminated

eliminate_frames(tensor_2d, 90)

sum_list.shape torch.Size([100])
keepindices.shape torch.Size([90])
keepindices.shape torch.Size([90])


tensor([[0.1391, 0.6287, 0.1732,  ..., 0.2334, 0.7894, 0.3108],
        [0.1162, 0.4568, 0.4720,  ..., 0.4970, 0.6782, 0.0195],
        [0.5927, 0.9907, 0.9064,  ..., 0.5642, 0.5358, 0.2363],
        ...,
        [0.8791, 0.3773, 0.2866,  ..., 0.6580, 0.4450, 0.8967],
        [0.5484, 0.0074, 0.3120,  ..., 0.5515, 0.1742, 0.8009],
        [0.8953, 0.1670, 0.5482,  ..., 0.4421, 0.1089, 0.3625]])

In [144]:


def eliminate_frames_4d(tensor_4d, desired_length):
    '''For least relative node/frame elimination'''
    # Sum all frames' values both horizontally and vertically
    sum_list = torch.sum(tensor_4d, dim=-1) + torch.sum(tensor_4d, dim=-2) - torch.diagonal(tensor_4d, dim1=-1, dim2=-2) 
    # keep_indices = torch.tensor(range(100))[~torch.isin(torch.tensor(range(100)), eliminate_indices)]
    keep_indices = torch.argsort(sum_list)[:, :, -desired_length:]
    tensor_eliminated = torch.Tensor(tensor_4d.shape[0],tensor_4d.shape[1], desired_length,desired_length)
    # Add : later for dimension format
    # print("keepindices.shape", keep_indices)
    for i in range(tensor_4d.shape[0]):
        for j in range(tensor_4d.shape[1]):
            # print(keep_indices[i][j])
            keep_indices[i][j], _ = torch.sort(keep_indices[i][j])
            # print(keep_indices[i][j])
            
            tensor_eliminated[i][j] = tensor_4d[i][j][keep_indices[i][j]][:, keep_indices[i][j]]
    # print("tensor_eliminated.shape is ", tensor_eliminated.shape)
    return tensor_eliminated

tensor_4d = torch.rand(2, 2, 5, 5)
print(tensor_4d)
tensor_eliminated= eliminate_frames_4d(tensor_4d, 3)
print(tensor_eliminated)

tensor([[[[0.6567, 0.4544, 0.0739, 0.1156, 0.2893],
          [0.1070, 0.2347, 0.0941, 0.0122, 0.3473],
          [0.8115, 0.8089, 0.1653, 0.4264, 0.0856],
          [0.1094, 0.6539, 0.3113, 0.1869, 0.4624],
          [0.4149, 0.5288, 0.8059, 0.7772, 0.4932]],

         [[0.4842, 0.4247, 0.7233, 0.1519, 0.4604],
          [0.1028, 0.1991, 0.3812, 0.7244, 0.2655],
          [0.8306, 0.3553, 0.5770, 0.6246, 0.3571],
          [0.6791, 0.1092, 0.5737, 0.3948, 0.2869],
          [0.9413, 0.8814, 0.1266, 0.6658, 0.0695]]],


        [[[0.2900, 0.5260, 0.1316, 0.3836, 0.0215],
          [0.1216, 0.1701, 0.9455, 0.4624, 0.5055],
          [0.4775, 0.3073, 0.4956, 0.1203, 0.1343],
          [0.7602, 0.2926, 0.6492, 0.5469, 0.1954],
          [0.6509, 0.5519, 0.3704, 0.7545, 0.6853]],

         [[0.3599, 0.8620, 0.0731, 0.6629, 0.9526],
          [0.8854, 0.9016, 0.1605, 0.0291, 0.4890],
          [0.4297, 0.4537, 0.6170, 0.5924, 0.7929],
          [0.3188, 0.0508, 0.7596, 0.5502, 0.3539],
    