# Get Attention Weights
Scratch Code to get & do analysis on attention weights from Patient-GAT 

In [1]:
# Sets seed for Pytorch Lightning, sklearn, numpy, etc.
SEED = 1

In [2]:
## Standard libraries
import os
import os.path as osp
import numpy as np 
import pandas as pd
from dtw import *

## Imports for plotting
import matplotlib.pyplot as plt

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torchmetrics import F1Score

# Imports for f1, AUC/ROC
from sklearn.metrics import auc, roc_auc_score, RocCurveDisplay

# Torchvision
import torchvision
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from torch_geometric.data import Dataset, download_url, Data
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler, SMOTE

# torch geometric
try: 
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version. 
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details 
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')

    !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-geometric 
    import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../../saved_models/"

# Setting the seed for pl, numpy, sklearn
pl.seed_everything(SEED)
np.random.seed(SEED)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)



Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



Global seed set to 1


cuda:0


In [3]:
# converts weights returned by model into classifications.
# NOTE: returned tensor is cpu-side! If using GPU, you
#       need to switch it back over.
# Returns: LongTensor with classifications
def to_predictions(prob):
    pred = prob.cpu().detach().numpy().flatten()
    for i in range(len(pred)):
        if pred[i] < 0:
            pred[i] = 0
        else:
            pred[i] = 1
    return torch.LongTensor(pred)

## Graph Neural Network

In [4]:
gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv
}

In [5]:
# class from tutorial. defines Graph Neural Network with a 
# specified type of layer.
class GNNModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, layer_name="GCN", dp_rate=0.1, **kwargs): # heads=6, 
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            layer_name - String of the graph layer to use
            dp_rate - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]
    
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            in_channels_idx = in_channels # in_channels * heads if l_idx > 0 else in_channels
            layers += [
                gnn_layer(in_channels=in_channels_idx, 
                          out_channels=out_channels,
                          # heads=heads,
                          **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden  # * heads
        layers += [gnn_layer(in_channels=in_channels, 
                             out_channels=c_out,
                             # heads=heads,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x, edge_index, return_attention_weights=False):
        alpha = []
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for l in self.layers:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layer inherit the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(l, geom_nn.MessagePassing):
                if (return_attention_weights):
                    x, temp = l(x, edge_index, return_attention_weights=return_attention_weights)
                    alpha += [temp]
                else:
                    x = l(x, edge_index)
            else:
                x = l(x)
        return x, alpha if return_attention_weights else x

In [6]:
# class from tutorial. defomes baseline MLP model.
class MLPModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of hidden layers
            dp_rate - Dropout rate to apply throughout the network
        """
        super().__init__()
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                nn.Linear(in_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [nn.Linear(in_channels, c_out)]
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x, *args, **kwargs):
        """
        Inputs:
            x - Input features per node
        """
        return self.layers(x)

In [7]:
# class from tutorial. combines both classes above
# to streamline pytorch lightning train/eval.
class NodeLevelGNN(pl.LightningModule):
    
    def __init__(self, model_name, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        self.num_epochs = 0
        if model_name == "MLP":
            self.model = MLPModel(**model_kwargs)
        else:
            self.model = GNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss(reduction='mean') # nn.CrossEntropyLoss() 

    def forward(self, data, mode="train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)
        
        # Only calculate the loss on the nodes corresponding to the mask
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            assert False, f"Unknown forward mode: {mode}"
        # print(x[mask])
        loss = self.loss_module(x[mask], data.y[mask].unsqueeze(1).float())
        if (self.num_epochs % 1000 == 0):
            print('epoch', self.num_epochs // 2, ': loss =', loss.cpu().numpy())
        self.num_epochs += 1
        pred = to_predictions(x[mask])
        # move predictions to gpu if y is also in gpu
        if data.y.get_device() != -1:
            pred = pred.to(device=device)
        acc = (pred == data.y[mask]).sum().float() / mask.sum() # (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
        return loss, acc

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),
                             lr=l_rate, 
                             weight_decay=w_decay)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)

In [8]:
# function from tutorial. trains and evaluates model.
def train_node_classifier(model_name, dataset, **model_kwargs):
    pl.seed_everything(SEED)
    node_data_loader = geom_data.DataLoader(dataset, batch_size=1)
    
    # Create a PyTorch Lightning trainer with the generation callback
    root_dir = os.path.join(CHECKPOINT_PATH, "NodeLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=max_epochs,
                         progress_bar_refresh_rate=0) # 0 because epoch size is 1
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    model = NodeLevelGNN(model_name=model_name, c_in=dataset.num_node_features, c_out=dataset.num_classes, **model_kwargs)
    trainer.fit(model, node_data_loader, node_data_loader)
    model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    
    # Test best model on the test set
    test_result = trainer.test(model, node_data_loader, verbose=False)
    batch = next(iter(node_data_loader))
    batch = batch.to(model.device)
    _, train_acc = model.forward(batch, mode="train")
    _, val_acc = model.forward(batch, mode="val")
    result = {"train": train_acc,
              "val": val_acc,
              "test": test_result[0]['test_acc']}
    return model, result

Finally, we can train our models. First, let's train the simple MLP:

In [9]:
# function from tutorial. small function for printing the test scores
def print_results(result_dict):
    if "train" in result_dict:
        print(f"Train accuracy: {(100.0*result_dict['train']):4.2f}%")
    if "val" in result_dict:
        print(f"Val accuracy:   {(100.0*result_dict['val']):4.2f}%")
    print(f"Test accuracy:  {(100.0*result_dict['test']):4.2f}%")

In [10]:
import matplotlib.pyplot as plt

from sklearn import svm
from sklearn.metrics import auc
from sklearn.metrics import RocCurveDisplay

# creates graph of ROC curve based on k-fold list of ground truths
# and 1-dimensional array of probability. lightly modified sample code from 
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
def ROC_kfoldCV(y_tests, prob_lists, fig_title="Receiver operating characteristic"):
    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    fig, ax = plt.subplots(figsize=(7, 5.5), dpi=80)

    for i in range(len(y_tests)):
        viz = RocCurveDisplay.from_predictions(
            y_tests[i],
            prob_lists[i],
            name="ROC fold {}".format(i),
            alpha=0.3,
            lw=1,
            ax=ax,
        )
        interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        aucs.append(viz.roc_auc)

    ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8)

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    ax.plot(
        mean_fpr,
        mean_tpr,
        color="b",
        label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
        lw=2,
        alpha=0.8,
    )
    print('mean AUC:', mean_auc)
    
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(
        mean_fpr,
        tprs_lower,
        tprs_upper,
        color="grey",
        alpha=0.2,
        label=r"$\pm$ 1 std. dev.",
    )

    ax.set(
        xlim=[-0.05, 1.05],
        ylim=[-0.05, 1.05],
        title=fig_title,
    )
    ax.legend(loc="lower right")
    plt.show()

In [11]:
# creates lightly modified sample code from 
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
def get_mean_curve(y_tests, prob_lists, fig_title="Receiver operating characteristic", save_folder=None):
    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    fig, ax = plt.subplots(figsize=(7, 5.5), dpi=80)

    for i in range(len(y_tests)):
        viz = RocCurveDisplay.from_predictions(
            y_tests[i],
            prob_lists[i],
            name="ROC fold {}".format(i),
            alpha=0.3,
            lw=1,
            # ax=ax,
        )
        interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        aucs.append(viz.roc_auc)

    ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8)

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    print(type(mean_auc))
    std_auc = np.std(aucs)
    ax.plot(
        mean_fpr,
        mean_tpr,
        color="b",
        label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
        lw=2,
        alpha=0.8,
    )
    print('mean AUC:', mean_auc)
    
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
#     ax.fill_between(
#         mean_fpr,
#         tprs_lower,
#         tprs_upper,
#         color="grey",
#         alpha=0.2,
#         label=r"$\pm$ 1 std. dev.",
#     )

    ax.set(
        xlim=[-0.05, 1.05],
        ylim=[-0.05, 1.05],
        title=fig_title,
    )
    ax.legend(loc="lower right")
    plt.show()
    
    # save mean_auc, std_auc, mean_tpr, & mean_fpr into file
    if save_folder != None:
        np.savetxt(save_folder + '/mean_fpr.csv', mean_fpr, delimiter=',')
        np.savetxt(save_folder + '/mean_tpr.csv', mean_tpr, delimiter=',')
        np.savetxt(save_folder + '/mean_auc.csv', np.array([mean_auc]), delimiter=',')
        np.savetxt(save_folder + '/std_auc.csv', np.array([std_auc]), delimiter=',')

### Define dataset for storing tenfold CV Data objects

In [12]:
# Dataset class definition that loads the 10 data objects created 
# by tenfold_cross_val on a given root directory. 
# Cannot process / generate its own data: you must first run 
# tenfold_cross_val on the given root directory in the given notebook 
# in 'src/Data Object Generation' with the raw data files x_pd.csv, 
# y_pd.csv, and imputed_lab.csv. 
class TenfoldDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.num_classes = 1
        

    @property
    def raw_file_names(self):
        return ['x_pd.csv', 'y_pd.csv', 'imputed_lab.csv']

    @property
    def processed_file_names(self):
        return ['data_0.pt', 'data_1.pt', 'data_2.pt', 'data_3.pt', 'data_4.pt', 
                'data_5.pt', 'data_6.pt', 'data_7.pt', 'data_8.pt', 'data_9.pt']

    def download(self):
        pass

    def process(self):
        raise Exception('Please use tenfold_cross_val to generate the .pt files for this dataset.')
    
    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [13]:
# 'filters' out edge_index by a given mask, removing nodes that are not in the given mask
# returns: tensor representing 2 x (num edges) array for filtered edge index
def filter_edges(edge_index, mask):
    edge_index_np = edge_index.numpy()
    edge_index_filtered = [[], []]
    # only copy nodes to edge_index_filtered that are in mask
    for col in range(len(edge_index_np[0])):
        if mask[edge_index_np[0][col]] and mask[edge_index_np[1][col]]:
            edge_index_filtered[0] +=  [edge_index_np[0][col]]
            edge_index_filtered[1] += [edge_index_np[1][col]]
    # create array to convert indices to new, filtered indices
    index_change = np.zeros(len(mask))
    new_index = 0
    for i in range(len(index_change)):
        if (mask[i]):
            index_change[i] = new_index
            new_index += 1
    # convert indices to filtered indices
    for col in range(len(edge_index_filtered[0])):
        edge_index_filtered[0][col] = index_change[edge_index_filtered[0][col]]
        edge_index_filtered[1][col] = index_change[edge_index_filtered[1][col]]
    return torch.LongTensor(edge_index_filtered)

In [14]:
model = torch.load('./Model Train and Eval/node_gnn_model.pt')
data = torch.load('../data/lab-oversampled-mixed5,0.65/processed/data_0.pt')
result, alpha = model.model(data.x[data.test_mask], filter_edges(data.edge_index, data.test_mask), return_attention_weights=True)
# result = filter_edges(data.edge_index, data.test_mask)

  return torch.LongTensor(edge_index_filtered)


In [35]:
np.set_printoptions(threshold=sys.maxsize)


alpha[1]

(tensor([[  3,   4,   5,   8,   8,   9,  10,  10,  11,  17,  18,  19,  23,  25,
           28,  30,  31,  33,  38,  39,  40,  40,  41,  42,  43,  43,  44,  46,
           48,  48,  49,  51,  56,  57,  59,  59,  61,  62,  63,  64,  68,  70,
           71,  74,  74,  75,  76,  78,  80,  80,  90,  92,  95,  97, 101, 102,
          104, 105, 105, 106, 108, 111, 115, 116, 116, 119, 120, 123, 123, 126,
          126, 130, 132, 135, 140, 141, 142, 144, 144, 146, 148, 151, 154, 156,
          156, 157, 159, 161, 164, 166, 166, 169, 171,   0,   1,   2,   3,   4,
            5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  16,  17,  18,
           19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,
           33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,
           47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,
           61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
           75,  76,  77,  78,  79,  80, 

In [16]:
data.edge_index.numpy()

array([[   0,    0,    0,    0,    0,    1,    1,    1,    1,    1,    2,
           2,    2,    2,    2,    3,    3,    3,    3,    3,    4,    4,
           4,    4,    4,    5,    5,    5,    5,    5,    6,    6,    6,
           6,    6,    7,    7,    7,    7,    7,    8,    8,    8,    8,
           8,    9,    9,    9,    9,    9,   10,   10,   10,   10,   10,
          11,   11,   11,   11,   11,   12,   12,   12,   12,   12,   13,
          13,   13,   13,   13,   14,   14,   14,   14,   14,   15,   15,
          15,   15,   15,   16,   16,   16,   16,   16,   17,   17,   17,
          17,   17,   18,   18,   18,   18,   18,   19,   19,   19,   19,
          19,   21,   21,   21,   21,   21,   22,   22,   22,   22,   22,
          23,   23,   23,   23,   23,   24,   24,   24,   24,   24,   25,
          25,   25,   25,   25,   26,   26,   26,   26,   26,   27,   27,
          27,   27,   27,   28,   28,   28,   28,   28,   29,   29,   29,
          29,   29,   30,   30,   30, 

In [22]:
classifications = np.zeros(len(result), dtype=np.int64)
for i in range(len(result)):
    if (result[i][0] >= 0):
        classifications[i] = 1
classifications

array([0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
       1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,
       0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1,
       0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
       0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
       1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1,
       1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1],
      dtype=int64)

In [21]:
data.y[data.test_mask].numpy()

array([0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
      dtype=int64)

In [18]:
x_df = pd.DataFrame(data.x.numpy())
for index, row in x_df.iterrows():
    print(index, ' : ', row)

0  :  0      0.000000
1      0.000000
2      0.107900
3      0.428571
4      0.000000
         ...   
119    0.250688
120    0.329854
121    0.307292
122    0.329854
123    0.369836
Name: 0, Length: 124, dtype: float32
1  :  0      0.000000
1      0.000000
2      0.254335
3      0.571429
4      0.000000
         ...   
119    0.250856
120    0.359840
121    0.307290
122    0.379831
123    0.044905
Name: 1, Length: 124, dtype: float32
2  :  0      0.000000
1      0.000000
2      0.242775
3      0.571429
4      1.000000
         ...   
119    0.297011
120    0.283374
121    0.234271
122    0.255796
123    0.337628
Name: 2, Length: 124, dtype: float32
3  :  0      0.000000
1      0.000000
2      0.148362
3      0.142857
4      1.000000
         ...   
119    0.314860
120    0.307174
121    0.307490
122    0.384829
123    0.044905
Name: 3, Length: 124, dtype: float32
4  :  0      0.000000
1      0.000000
2      0.105973
3      0.142857
4      0.000000
         ...   
119    0.270936
120   

Name: 165, Length: 124, dtype: float32
166  :  0      0.000000
1      0.000000
2      0.227360
3      0.285714
4      1.000000
         ...   
119    0.250923
120    0.341569
121    0.307490
122    0.344847
123    0.044905
Name: 166, Length: 124, dtype: float32
167  :  0      0.000000
1      0.000000
2      0.418112
3      0.285714
4      0.000000
         ...   
119    0.179920
120    0.307039
121    0.234167
122    0.303842
123    0.027365
Name: 167, Length: 124, dtype: float32
168  :  0      0.000000
1      0.285714
2      0.186898
3      0.285714
4      0.000000
         ...   
119    0.239636
120    0.334851
121    0.307290
122    0.339849
123    0.044905
Name: 168, Length: 124, dtype: float32
169  :  0      0.000000
1      0.000000
2      0.285164
3      0.571429
4      0.000000
         ...   
119    0.202410
120    0.202410
121    0.204909
122    0.202410
123    0.161553
Name: 169, Length: 124, dtype: float32
170  :  0      0.000000
1      0.000000
2      0.163776
3      0.1428

Name: 355, Length: 124, dtype: float32
356  :  0      1.000000
1      0.000000
2      0.310212
3      0.571429
4      0.000000
         ...   
119    0.319858
120    0.381955
121    0.307492
122    0.394825
123    0.403472
Name: 356, Length: 124, dtype: float32
357  :  0      0.000000
1      0.000000
2      0.123314
3      0.571429
4      1.000000
         ...   
119    0.263099
120    0.306984
121    0.304865
122    0.252782
123    0.272379
Name: 357, Length: 124, dtype: float32
358  :  0      0.000000
1      0.285714
2      0.065511
3      0.285714
4      1.000000
         ...   
119    0.250610
120    0.306946
121    0.262383
122    0.319858
123    0.044905
Name: 358, Length: 124, dtype: float32
359  :  0      0.000000
1      0.000000
2      0.075145
3      0.428571
4      1.000000
         ...   
119    0.250922
120    0.307174
121    0.307489
122    0.253012
123    0.045172
Name: 359, Length: 124, dtype: float32
360  :  0      0.000000
1      0.000000
2      0.138728
3      0.5714

Name: 567, Length: 124, dtype: float32
568  :  0      0.000000
1      0.000000
2      0.125241
3      0.428571
4      1.000000
         ...   
119    0.250923
120    0.307174
121    0.307490
122    0.253012
123    0.045172
Name: 568, Length: 124, dtype: float32
569  :  0      1.000000
1      0.000000
2      0.144509
3      0.571429
4      0.000000
         ...   
119    0.250923
120    0.307175
121    0.307492
122    0.253035
123    0.302366
Name: 569, Length: 124, dtype: float32
570  :  0      0.000000
1      0.000000
2      0.206166
3      0.571429
4      1.000000
         ...   
119    0.244891
120    0.280759
121    0.297708
122    0.319858
123    0.075720
Name: 570, Length: 124, dtype: float32
571  :  0      1.000000
1      0.000000
2      0.211946
3      0.714286
4      0.000000
         ...   
119    0.244324
120    0.256792
121    0.384829
122    0.349845
123    0.339849
Name: 571, Length: 124, dtype: float32
572  :  0      0.000000
1      0.000000
2      0.254335
3      0.8571

Name: 758, Length: 124, dtype: float32
759  :  0      0.000000
1      0.050325
2      0.541031
3      0.420453
4      1.000000
         ...   
119    0.307222
120    0.265469
121    0.318468
122    0.287940
123    0.293608
Name: 759, Length: 124, dtype: float32
760  :  0      0.000000
1      0.000000
2      0.214671
3      0.714286
4      0.585579
         ...   
119    0.289629
120    0.415952
121    0.447629
122    0.355789
123    0.330788
Name: 760, Length: 124, dtype: float32
761  :  0      0.000000
1      0.000000
2      0.358250
3      0.450750
4      0.000000
         ...   
119    0.253139
120    0.307037
121    0.341509
122    0.251141
123    0.045125
Name: 761, Length: 124, dtype: float32
762  :  0      0.000000
1      0.000000
2      0.332561
3      0.147626
4      0.000000
         ...   
119    0.377971
120    0.354748
121    0.380862
122    0.340466
123    0.359210
Name: 762, Length: 124, dtype: float32
763  :  0      0.000000
1      0.000000
2      0.291255
3      0.4571

967  :  0      0.000000
1      0.100652
2      0.406819
3      0.470777
4      0.000000
         ...   
119    0.287584
120    0.291727
121    0.400582
122    0.252105
123    0.282376
Name: 967, Length: 124, dtype: float32
968  :  0      0.000000
1      0.000000
2      0.463410
3      0.235591
4      0.000000
         ...   
119    0.250760
120    0.299139
121    0.309761
122    0.270484
123    0.065002
Name: 968, Length: 124, dtype: float32
969  :  0      0.791209
1      0.000000
2      0.359758
3      0.029827
4      0.000000
         ...   
119    0.289571
120    0.393489
121    0.444117
122    0.315171
123    0.361762
Name: 969, Length: 124, dtype: float32
970  :  0      0.000000
1      0.000000
2      0.241260
3      0.036504
4      0.744471
         ...   
119    0.346456
120    0.307033
121    0.285537
122    0.284705
123    0.044973
Name: 970, Length: 124, dtype: float32
971  :  0      0.022282
1      0.000000
2      0.308156
3      0.292081
4      0.000000
         ...   
119 

Name: 1195, Length: 124, dtype: float32
1196  :  0      0.000000
1      0.142857
2      0.281310
3      0.571429
4      0.000000
         ...   
119    0.354842
120    0.306984
121    0.384829
122    0.252782
123    0.276877
Name: 1196, Length: 124, dtype: float32
1197  :  0      1.000000
1      0.571429
2      0.242775
3      0.142857
4      0.000000
         ...   
119    0.250756
120    0.338141
121    0.307492
122    0.253035
123    0.414816
Name: 1197, Length: 124, dtype: float32
1198  :  0      0.000000
1      0.142857
2      0.152216
3      0.000000
4      0.000000
         ...   
119    0.250922
120    0.307175
121    0.307492
122    0.242475
123    0.299867
Name: 1198, Length: 124, dtype: float32
1199  :  0      0.000000
1      0.000000
2      0.441233
3      0.571429
4      0.000000
         ...   
119    0.394825
120    0.245808
121    0.314860
122    0.332352
123    0.344847
Name: 1199, Length: 124, dtype: float32
1200  :  0      0.000000
1      0.000000
2      0.134875
3  