In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cdist
from tqdm import tqdm

from PSO import PSO
from utils import sammon_error

In [2]:
df = pd.read_csv('radon_metrics.csv').drop('Unnamed: 0', axis=1)

In [3]:
df

Unnamed: 0,HCPL,HDIF,HEFF,HNDB,HPL,HPV,HTRP,HVOL,MI
0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,24.545304
1,0.000000,0.5,1.000000,0.000667,2.0,2.0,0.055556,2.000000,24.545304
2,44.828921,2.5,156.275624,0.020837,16.0,15.0,8.681979,62.510250,24.545304
3,0.000000,0.5,1.000000,0.000667,2.0,2.0,0.055556,2.000000,24.545304
4,6.754888,1.0,11.609640,0.003870,5.0,5.0,0.644980,11.609640,24.545304
...,...,...,...,...,...,...,...,...,...
50816,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,72.224158
50817,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,72.224158
50818,16.364528,1.5,36.000000,0.008000,8.0,8.0,2.000000,24.000000,79.069071
50819,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,77.709124


In [4]:
data = torch.from_numpy(df.to_numpy())
data = data[:5000, :]
data.shape

torch.Size([5000, 9])

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
cache = {"Sammon": {}}

In [7]:
feat_idx, error_log = PSO(data, sammon_error, 5, device=device, batch_size=5000, batch_threshold=20000, cache=cache["Sammon"])

100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [00:02<00:00, 10.69it/s]


In [8]:
df.columns[feat_idx]

array(['HCPL', 'HEFF', 'HVOL', 'MI', 'HPL'], dtype=object)

In [9]:
error_log

[0.00034244993003085256,
 0.00024594453861936927,
 6.578013562830165e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05,
 1.4396099686564412e-05]

In [10]:
cache

{'Sammon': {419: 0.5578544872663692,
  426: 0.590123813778799,
  342: 0.00427335723395819,
  433: 0.553269569723618,
  218: 0.629381193614853,
  62: 0.05007400823778433,
  242: 0.6261443581268455,
  484: 0.00034244994298180623,
  313: 0.7172641820335784,
  151: 0.04048992698654887,
  454: 0.00044414890933690884,
  186: 0.639599958603808,
  369: 0.6918277899334592,
  465: 0.5431633619384753,
  179: 0.6049099058584627,
  348: 0.004282864410663792,
  460: 0.0004476054479980257,
  115: 0.7576614315173906,
  299: 0.7292772222013919,
  91: 0.7635963093333871,
  358: 0.004544971388056145,
  362: 0.7889076045686535,
  118: 0.04995442065657446,
  341: 0.002392135747698292,
  310: 0.004025990601952433,
  307: 0.7171230952216927,
  496: 0.5715028405693255,
  482: 0.5766989392003954,
  468: 0.0002982357338148633,
  229: 0.040658021143709415,
  436: 0.0002459445364574924,
  244: 0.04121232732742903,
  199: 0.04141589060256334,
  227: 0.5986901942257651,
  481: 0.545930901194443,
  213: 0.0404966207