In [None]:
import matplotlib.pyplot as plt
import os
import numpy as np
import sys
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import shutil

def path_link(path:str):
    sys.path.append(path)

path_link('/master/code/lib')

import dataLoading as dl
from measure import plotStdMessage
import simulation_v2 as sim
import features as ft

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


In [3]:
##############################
# paths 

#PATH = 'master/code/runs1'
#PATH = 'master/code/runs2'
#PATH = ['/master/code/analyze_models/exps/test_new_activation_0']
PATH = ['/master/code/analyze_models/exps/exp-test']

#DISPLAY_PATH = 'master/code/display_l1'
#DISPLAY_PATH = '/master/code/display_l1_2'
#DISPLAY_PATH = ['/master/code/analyze_models/display2/test_new_activation_0']
DISPLAY_PATH = ['/master/code/analyze_models/display2/exp-test']

MODEL_PATH = '/master/code/models/mod_base'


##############################

In [4]:
class Parameters():
    def __init__(self):
        self.dt = 0.001
        self.v0 = 60
        self.k = 70
        self.epsilon = 0.5
        self.tau = 3.5
        self.R = 1
        self.N = 200
        self.boundary = 100

        self.nbStep = 150

In [117]:
def findModels(path):
    pathLists = []
    for root, dirs, files in tqdm(os.walk(path)):
            for file in files:
                  
                  if file.endswith('.pt'):
                        pathLists.append(os.path.join(root, file))


    return pathLists


def delete_wandb_dirs(start_path):
    for root, dirs, files in os.walk(start_path, topdown=False):
        for dir_name in dirs:
            if dir_name == "wandb":
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting: {dir_path}")
                shutil.rmtree(dir_path)



def getName(path):
    run_name = path.split('/')[-3]

    model_type = path.split('/')[-1].split('.')[0]

    if 'best' in model_type:
        model_type = '_best'

    else:
         model_type = '_latest'

    name = run_name + model_type

    return name


def loadModel(modelName:str, inputShape:int = 8, edges_shape = 5, path = None):
    """ 
    Function to import the model

    Args:
    -----
        - `modelName`: name of the model
        - `inputShape`: inout shape of the NN
        - `edges_shape`: edge shape of the NN
        - `path`: path where the models are
    """

    sys.path.append(path)

    loadFun = __import__(f'{modelName}', fromlist = ('loadNetwork'))

    model = loadFun.loadNetwork(inputShape, edges_shape)

    return model



def getModelName(key):

    name = ''

    #if 'simplest' in key:
    #    name = name + 'simplest'
    
    name = name + 'simplest'

    ## other possibilities

    if 'no-dropout' in key:
        name = name + '_no-dropout'
    else:
        name = name + '_dropout'

    if 'no-encoder' in key:
        name = name + '_no-encoder'
    else:
        name = name + '_encoder'

    if 'relu' in key:
        name = name + '-relu'

    return name


##############################
# Messages
#TODO linear 


def getOrderedVals(attribute, message, bins):
    bin_edges = np.linspace(np.min(attribute), np.max(attribute), bins + 1)
    bin_indices = np.digitize(attribute, bin_edges) - 1

    means = np.zeros(bins)
    stds = np.zeros(bins)

    for i in range(bins):
        bin_mask = bin_indices == i
        if np.any(bin_mask):
            means[i] = np.mean(message[bin_mask])
            stds[i] = np.std(message[bin_mask])
        else:
            means[i] = np.nan
            stds[i] = np.nan

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2


    return bin_centers, means, stds


def findIndices(message, nb = 5):
    stdv=plotStdMessage(message)
    plt.close()

    inds = np.argsort(stdv)
    # change of the order
    return np.flip(inds[-nb:])


def plotMessage(graph, messages, i_attr, id_message):
    edges = None
    
    for i in range(len(graph)):
        if edges is None:
            edges = graph[i].edge_attr.cpu().detach().numpy().copy()
        else:
            edges = np.vstack((edges, graph[i].edge_attr.cpu().detach().numpy().copy()))

    x, mean, std = getOrderedVals(edges[:, i_attr], messages[:, id_message], 100)
    plt.scatter(edges[:, i_attr], messages[:, id_message])


def calculate_interaction(dist, rij, k, epsilon, radii = 1.0):
    """
    Given the vectors ri and rj, compute the force between them
    """


    r = dist

    # Combined radius of both particles (assume unit radii for now)
    #bij = 2.0                       # Ri + Rj 
    bij = radii + radii

    if r < bij*(1 + epsilon):
        force = k*(r - bij)*rij/r  
    elif r < bij*(1 + 2*epsilon):
        force = -k*(r - bij - 2*epsilon*bij)*rij/r
    else:
        force = torch.tensor([0.0, 0.0])
    return force

In [87]:
def getData(params):
    N = params.N
    v0 = params.v0
    k = params.k
    eps = params.epsilon
    tau = params.tau
    R = params.R
    dt = params.dt
    nbStep = params.nbStep
    boundary = params.boundary

    lim = 0.85 * 100

    xPos = np.linspace(-lim, lim, 10)
    yPos = np.linspace(-lim, lim, 10)
    gridX, gridY = np.meshgrid(xPos, yPos)
    delta = np.random.uniform(0, 7, gridX.shape + (2,))

    gridX2 = gridX + delta[:, :, 0]
    gridY2 = gridY + delta[:, :, 1]

    pos = np.stack([gridX.ravel(), gridY.ravel()], axis=1)
    pos_perturbed = np.stack([gridX2.ravel(), gridY2.ravel()], axis=1)

    pos = np.concatenate([pos, pos_perturbed], axis=0)

    angles = np.random.rand(pos.shape[0]) * 2 * np.pi

    data = sim.compute_main(N, (v0, tau, k, eps), boundary, T = nbStep, initialization = (pos, angles), dt = dt)[0]

    x, y, attr, inds = ft.processSimulation(data)

    dataList = []

    for i in range(len(x)):
        g = Data(x = x[i][:, 2:], y = y[i], edge_attr = attr[i], edge_index = inds[i])
        dataList.append(g)

    return dataList

In [88]:
def getForces(graph, k, epsilon):
    inds = graph.edge_index
    #print(inds)

    messages = []

    for i in range(inds.shape[1]):
        messages.extend(calculate_interaction(graph.edge_attr[i, 0],
                                              torch.tensor([graph.edge_attr[i, 0] * graph.edge_attr[i, 1], graph.edge_attr[i, 0] * graph.edge_attr[i, 2]]),
                                              k = k,
                                              epsilon=epsilon
                                              ).cpu().detach().numpy().tolist())
        
    return messages


def getGroundTruth(data, k, epsilon):

    res = []

    for graph in data:

        res.extend(getForces(graph, k, epsilon))


    return np.array(res)



def getPrediction(model, data, inds = None):
    res = None

    for graph in data:

        if inds is None:
            messages = model.message(graph).cpu().detach().numpy()
        else:   
            messages = model.message(graph).cpu().detach().numpy()[:, inds]
        
        
        if res is None:
            res = messages
        else:
            res = np.vstack((res, messages))

    return res

In [89]:
def get_linear_predictions(X, Y):

    resMod = []
    pred = []
    for i in range(Y.shape[1]):
        model = LinearRegression()
        model.fit(X, Y[:, i])

        y_pred = model.predict(X)

        pred.append(y_pred)
        resMod.append(model)


    return resMod, np.column_stack(pred)

In [107]:
def plotMessageEvol(modelList, pathPlot = DISPLAY_PATH):

    # create folders

    out_poss = ['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
    number_messages = 5

    params = Parameters()

    r2List = []

    data = getData(params)

    gtMessage = getGroundTruth(data, params.k, params.epsilon)
    
    nb = 0
    nb_sim = 0
    with torch.no_grad():
        # get the data for the different models

        for model_path in tqdm(modelList):
            nb_sim += 1

            # path for the experiment
            p_exp = os.path.join(pathPlot, f'exp_{nb_sim}')
            if not os.path.exists(p_exp):
                os.makedirs(p_exp)
            else:
                print('WARNING: weird stuff here')

            for f in out_poss:
                p = os.path.join(p_exp, f)
                if not os.path.exists(p):
                    os.makedirs(p)

            # find name that identify the model
            name_plot = getName(model_path)

            # load model
            try:
                model = loadModel(getModelName(name_plot), path=MODEL_PATH)
                std_dict = torch.load(model_path)
                model.load_state_dict(std_dict)
                model.eval()

                # get messages
                message = getPrediction(model, data, inds = None)

                # best messages indices
                inds = findIndices(message, nb = number_messages)

                print(inds)

                plotStdMessage(message.copy())
                p_std = os.path.join(p_exp, f'{name_plot}.png')
                if os.path.exists(p_std):
                    print("WARNING >>> issue here")
                plt.savefig(p_std)
                plt.close()

                print(out_poss)
                print(number_messages)

                for i in range(len(out_poss)):      # radius, ...
                    for j in range(number_messages):    # id of the message

                        #plotMessage(data[0], message.copy(), i, inds[j])

                        
                        path = os.path.join(p_exp, out_poss[i])
                        path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}.png")

                        nb_plot = 0
                        
                        while os.path.exists(path):
                            nb_plot += 1
                            path = os.path.join(p_exp, out_poss[i])
                            path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}_nbPlot-{nb_plot}.png")

                        plt.savefig(path)
                        plt.close()

                print('nfjengjkfbdnkjvnksdbngv')

                messages_linear = getPrediction(model, data, inds = inds[:2])

                print(messages_linear.shape)
                print(gtMessage.shape)


                #mod, result_linear = get_linear_predictions(messages_linear.copy(), gtMessage.copy())
                score = r2_score(messages_linear,gtMessage.copy() 
                 #multioutput = 'raw_values',
                )
                print('nvdfjsghuidfnvjkbdfxwkjvnsbvkjfdsbvjkbdfkjvbkjdsbvjk dsbfjkdvbdj')
                print(score)


            except:
                print(f'Issue >>> {nb}')
                nb += 1



def main():
    if PATH is None:
        # check if listdir only outputs the last element (...)
        list_exp = [os.listdir('/master/code/analyze_models/exp/')]
        list_disp = [os.path.join('/master/code/analyze_models/display2', list_exp[i]) for i in range(len(list_exp))]

    else:
        list_exp = PATH
        list_disp = DISPLAY_PATH


    for i in range(len(list_exp)):

        exp = list_exp[i]
        disp = list_disp[i]

        model_list = findModels(exp)



        # messages analyzis

        plotMessageEvol(model_list, pathPlot = disp)

        # ...


In [34]:
out_poss = ['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
number_messages = 5

params = Parameters()

r2List = []

data = getData(params)
print(len(data))

gtMessage = getGroundTruth(data, params.k, params.epsilon)



v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 107.91it/s]


139
tensor([[  0,   1,   3,   8,   9,  10,  12,  13,  16,  17,  18,  19,  21,  22,
          24,  25,  26,  31,  32,  33,  34,  35,  36,  37,  46,  47,  50,  53,
          54,  56,  57,  60,  62,  66,  67,  69,  70,  71,  72,  73,  74,  77,
          79,  80,  81,  82,  85,  86,  87,  89,  91,  92,  93,  94,  95,  96,
          99, 100, 101, 103, 108, 109, 110, 112, 113, 116, 117, 118, 119, 121,
         122, 124, 125, 126, 131, 132, 133, 134, 135, 136, 137, 146, 147, 150,
         153, 154, 156, 157, 160, 162, 166, 167, 169, 170, 171, 172, 173, 174,
         177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 192, 193, 194, 195,
         196, 199],
        [100, 101, 103, 108, 109, 110, 112, 113, 116, 117, 118, 119, 121, 122,
         124, 125, 126, 131, 132, 133, 134, 135, 136, 137, 146, 147, 150, 153,
         154, 156, 157, 160, 162, 166, 167, 169, 170, 171, 172, 173, 174, 177,
         179, 180, 181, 182, 185, 186, 187, 189, 191, 192, 193, 194, 195, 196,
         199,   0,   1,   3,

In [104]:
r2_score(np.random.random((100, 2)), np.random.random((100, 2)))

-0.8434601560254499

In [111]:
if PATH is None:
    # check if listdir only outputs the last element (...)
    list_exp = [os.listdir('/master/code/analyze_models/exp/')]
    list_disp = [os.path.join('/master/code/analyze_models/display2', list_exp[i]) for i in range(len(list_exp))]

else:
    list_exp = PATH
    list_disp = DISPLAY_PATH


for i in range(len(list_exp)):

    exp = list_exp[i]
    disp = list_disp[i]

    model_list = findModels(exp)

0it [00:00, ?it/s]

4it [00:00, 33.98it/s]


In [112]:
print(model_list[0])

/master/code/analyze_models/exps/exp-test/test_new_activation/master-thesis_normal_classic_dt-0.01_lr-0.001_batch-16_encoder-no-encoder_dropout-no-dropout-exp/model_trained/simplest_no-dropout_no-encoder_aug_best_0-001-test.pt


In [120]:
out_poss = ['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
pathPlot = disp
modelList = model_list
number_messages = 5

params = Parameters()

r2List = []

data = getData(params)

gtMessage = getGroundTruth(data, params.k, params.epsilon).reshape(-1, 2)

nb = 0
nb_sim = 0
with torch.no_grad():
    # get the data for the different models

    for model_path in tqdm(modelList):
        nb_sim += 1

        # path for the experiment
        p_exp = os.path.join(pathPlot, f'exp_{nb_sim}')
        if not os.path.exists(p_exp):
            os.makedirs(p_exp)
        else:
            print('WARNING: weird stuff here')

        for f in out_poss:
            p = os.path.join(p_exp, f)
            if not os.path.exists(p):
                os.makedirs(p)

        # find name that identify the model
        name_plot = getName(model_path)

        # load model
        model = loadModel(getModelName(name_plot), path=MODEL_PATH)
        std_dict = torch.load(model_path)
        model.load_state_dict(std_dict)
        model.eval()

        # get messages
        message = getPrediction(model, data, inds = None)

        # best messages indices
        inds = findIndices(message, nb = number_messages)

        print(inds)

        plotStdMessage(message.copy())
        p_std = os.path.join(p_exp, f'{name_plot}.png')
        if os.path.exists(p_std):
            print("WARNING >>> issue here")
        plt.savefig(p_std)
        plt.close()

        print(out_poss)
        print(number_messages)

        for i in range(len(out_poss)):      # radius, ...
            for j in range(number_messages):    # id of the message

                plotMessage(data, message.copy(), i, inds[j])

                
                path = os.path.join(p_exp, out_poss[i])
                path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}.png")

                nb_plot = 0
                
                while os.path.exists(path):
                    nb_plot += 1
                    path = os.path.join(p_exp, out_poss[i])
                    path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}_nbPlot-{nb_plot}.png")

                plt.savefig(path)
                plt.close()

        print('nfjengjkfbdnkjvnksdbngv')

        messages_linear = getPrediction(model, data, inds = inds[:2])

        print(messages_linear.shape)
        print(gtMessage.shape)

        model = LinearRegression()
        model = model.fit(messages_linear, gtMessage)

        y_pred = model.predict(messages_linear)


        #mod, result_linear = get_linear_predictions(messages_linear.copy(), gtMessage.copy())
        score = r2_score(y_pred,gtMessage.copy() 
            #multioutput = 'raw_values',
        )
        print('nvdfjsghuidfnvjkbdfxwkjvnsbvkjfdsbvjkbdfkjvbkjdsbvjk dsbfjkdvbdj')
        print(score)



v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:17<00:00,  8.66it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

>>>> loading simplest
INFO >>> with NO encoder
INFO >>> with NO dropout
[37 16  7 23 20]
['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
5
nfjengjkfbdnkjvnksdbngv


 50%|█████     | 1/2 [00:28<00:28, 28.55s/it]

(14416, 2)
(14416, 2)
nvdfjsghuidfnvjkbdfxwkjvnsbvkjfdsbvjkbdfkjvbkjdsbvjk dsbfjkdvbdj
0.9965384682730526
>>>> loading simplest
INFO >>> with NO encoder
INFO >>> with NO dropout
[37 16  7 20 23]
['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
5
nfjengjkfbdnkjvnksdbngv


100%|██████████| 2/2 [00:46<00:00, 23.48s/it]

(14416, 2)
(14416, 2)
nvdfjsghuidfnvjkbdfxwkjvnsbvkjfdsbvjkbdfkjvbkjdsbvjk dsbfjkdvbdj
0.9962894975713092





In [None]:
def get_messages_model(model, data, nbMax:int = 2):
    message = getPrediction(model, data, inds = None)

    inds = findIndices(message, nb = number_messages)

    return message[:, inds]



def get_sum_messages_model(model, data):
    res = None

    for graph in data:

        if inds is None:
            messages = model.sum_message(graph).cpu().detach().numpy()
        else:   
            messages = model.sum_message(graph).cpu().detach().numpy()[:, inds]
        
        
        if res is None:
            res = messages
        else:
            res = np.vstack((res, messages))

    return res



def getDataPySr(model, data, type = 'e'):

    if type == 'e':
        return get_messages_model(model, data)
    
    elif type == 'sum':
        return get_sum_messages_model(model, data)
    
    else:
        print('No valid arg')
        return None
    