In [1]:
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import DataLoader

In [2]:
from lstm_scratch import SimpleLSTM
from helpers import CVFConfigForAnalysisDataset

In [None]:
# Evaluation target: the graph whose configurations are scored.
# Use "star_graph_n7" here to run against the smaller star graph instead.
graph_name = "star_graph_n15"

# Checkpoint to evaluate, loaded from trained_models/<model_name>.pt.
model_name = "lstm_trained_at_2025_04_10_00_11"

In [4]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on a kernel without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# weights_only=False unpickles the whole nn.Module, so its class (SimpleLSTM,
# imported from lstm_scratch) must be importable in this kernel.
# SECURITY: unpickling can execute arbitrary code - only load trusted checkpoints.
# map_location puts the weights on `device` regardless of where the checkpoint
# was saved; the dataset is presumably built on `device` as well (TODO confirm),
# so model and inputs end up on the same device.
model = torch.load(
    f"trained_models/{model_name}.pt",
    weights_only=False,
    map_location=device,
)
model.eval()

SimpleLSTM(
  (lstm): GRU(3, 32, batch_first=True)
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (h2o): Linear(in_features=32, out_features=1, bias=True)
)

In [6]:
# Analysis dataset of all configurations of `graph_name`; tensors are
# presumably created on `device` (TODO confirm in helpers.CVFConfigForAnalysisDataset).
dataset = CVFConfigForAnalysisDataset(device, graph_name)

Total configs: 448.


In [7]:
def compute_rank_effects(model, dataset, batch_size=1):
    """Predict the rank effect of every perturbation reachable from each config.

    For every configuration yielded by a DataLoader over ``dataset``, the model's
    predicted rank of that configuration (``frm_rank``) is compared with the
    predicted rank of each configuration returned by
    ``dataset.cvf_analysis.possible_perturbed_state_frm(frm_idx)`` (``to_rank``).
    The difference ``frm_rank - to_rank`` is rounded half-up to an integer.

    Args:
        model: callable taking one batched input (``x.unsqueeze(0)``); the rank
            difference must be a single-element tensor because ``.item()`` is
            called on it.
        dataset: indexable dataset whose items are ``(features, index, ...)``.
            NOTE(review): inferred from the ``batch[0]`` / ``batch[1]`` and
            ``dataset[to_indx][0]`` usage - confirm against the dataset class.
        batch_size: DataLoader batch size; affects iteration only, not results.

    Returns:
        DataFrame with columns ``node`` (int64) and ``rank_effect`` (float64),
        one row per (source configuration, perturbation) pair, in iteration order.
    """
    records = []
    # A target configuration is reached from many sources; predict its rank once.
    to_rank_cache = {}

    def predict_to_rank(to_indx):
        """Return the (memoized) model prediction for ``dataset[to_indx]``."""
        if to_indx not in to_rank_cache:
            to_rank_cache[to_indx] = model(dataset[to_indx][0].unsqueeze(0))
        return to_rank_cache[to_indx]

    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size):
            features, indices = batch[0], batch[1]
            for i in range(len(features)):
                frm_idx = indices[i].item()
                frm_rank = model(features[i].unsqueeze(0))
                perturbations = dataset.cvf_analysis.possible_perturbed_state_frm(frm_idx)
                for position, to_indx in perturbations:
                    to_rank = predict_to_rank(to_indx)
                    # floor(x + 0.5) rounds halves up; torch.round would round half to even.
                    rank_effect = torch.floor(frm_rank - to_rank + 0.5).item()
                    records.append({"node": position, "rank_effect": rank_effect})

    # Build the frame once (no per-batch concat) and pin dtypes so an empty
    # result still has the int/float schema.
    return pd.DataFrame(records, columns=["node", "rank_effect"]).astype(
        {"node": "int64", "rank_effect": "float64"}
    )


result_df = compute_rank_effects(model, dataset)

output_path = Path("ml_predictions") / f"{model_name}__{graph_name}__cvf.csv"
# Create the output directory on a fresh checkout instead of failing in to_csv.
output_path.parent.mkdir(parents=True, exist_ok=True)
result_df.to_csv(output_path)
