In [3]:
import matplotlib.pyplot as plt
import os
import numpy as np
import sys
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

def path_link(path:str):
    sys.path.append(path)

path_link('master/code/lib')

import dataLoading as dl
from measure import plotStdMessage

yessss


In [4]:
#PATH = 'master/code/runs1'
#PATH = 'master/code/runs2'
PATH = 'master/code/test_new_activation'

#DISPLAY_PATH = 'master/code/display_l1'
#DISPLAY_PATH = '/master/code/display_l1_2'
DISPLAY_PATH = 'master/code/display_l1_test_new_activation'

MODEL_PATH = 'master/code/models/mod_base'

In [5]:
def findModels(path):
    pathLists = []
    for root, dirs, files in tqdm(os.walk(path)):
            for file in files:
                  
                  if file.endswith('.pt'):
                        pathLists.append(os.path.join(root, file))


    return pathLists


def delete_wandb_dirs(start_path):
    for root, dirs, files in os.walk(start_path, topdown=False):
        for dir_name in dirs:
            if dir_name == "wandb":
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting: {dir_path}")
                shutil.rmtree(dir_path)


def getLoader(path, batch_size = 32, shuffleBool = True, root = None, jsonFile = None, mode = 'training'):
    datasetTraining = dl.DataLoader2(root, path = path, jsonFile = jsonFile, mode = mode)
    loader = DataLoader(datasetTraining, batch_size=batch_size, shuffle = shuffleBool)
    
    return loader


def getName(path):
    run_name = path.split('/')[-3]

    model_type = path.split('/')[-1].split('.')[0]

    if 'best' in model_type:
        model_type = '_best'

    else:
         model_type = '_latest'

    name = run_name + model_type

    return name


def getModelName(key):

    name = ''

    #if 'simplest' in key:
    #    name = name + 'simplest'
    
    name = name + 'simplest'

    ## other possibilities

    if 'no-dropout' in key:
        name = name + '_no-dropout'
    else:
        name = name + '_dropout'

    if 'no-encoder' in key:
        name = name + '_no-encoder'
    else:
        name = name + '_encoder'

    if 'relu' in key:
        name = name + '-relu'

    return name
         

def loadModel(modelName:str, inputShape:int = 8, edges_shape = 5, path = None):
    """ 
    Function to import the model

    Args:
    -----
        - `modelName`: name of the model
        - `inputShape`: inout shape of the NN
        - `edges_shape`: edge shape of the NN
        - `path`: path where the models are
    """

    sys.path.append(path)

    loadFun = __import__(f'{modelName}', fromlist = ('loadNetwork'))

    model = loadFun.loadNetwork(inputShape, edges_shape)

    return model



def findLoader(key):

    # remains initial conditions to consider
    if 'normal' in key:
        if '0_01' in key:
            dataloader_path = '/master/code/simulation/path/mew_0_01_normal.json'

        else:       # 0.001
            dataloader_path = '/master/code/simulation/path/mew_0_001_normal.json'

    else:       # noisy
        if '0_01' in key:
            dataloader_path = '/master/code/simulation/path/mew_0_01_noisy.json'

        else:       # 0.001
            dataloader_path = '/master/code/simulation/path/mew_0_001_noisy.json'

    print(dataloader_path)


    return getLoader(None, 128, True, None, dataloader_path, 'test')


def getOrderedVals(attribute, message, bins):
    bin_edges = np.linspace(np.min(attribute), np.max(attribute), bins + 1)
    bin_indices = np.digitize(attribute, bin_edges) - 1

    means = np.zeros(bins)
    stds = np.zeros(bins)

    for i in range(bins):
        bin_mask = bin_indices == i
        if np.any(bin_mask):
            means[i] = np.mean(message[bin_mask])
            stds[i] = np.std(message[bin_mask])
        else:
            means[i] = np.nan
            stds[i] = np.nan

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2


    return bin_centers, means, stds



def findIndices(message, nb = 5):
    stdv=plotStdMessage(message)
    plt.close()

    inds = np.argsort(stdv)

    return inds[-nb:]




def plotMessage(graph, messages, i_attr, id_message):
    edges = graph.edge_attr.cpu().detach().numpy().copy()
    x, mean, std = getOrderedVals(edges[:, i_attr], messages[:, id_message], 100)
    plt.scatter(edges[:, i_attr], messages[:, id_message])
    #plt.plot(x, mean, 'green')
    #plt.fill_between(x, mean-std, mean+ std, color = 'red', alpha = 0.4)
    #plt.vlines(x = 2, ymin = np.min(messages[:, id]), ymax = np.max(messages[:, id]))
    #plt.vlines(x = 4, ymin = np.min(messages[:, id]), ymax = np.max(messages[:, id]))


def plotMessageEvol(modelList, pathPlot = DISPLAY_PATH):

    # create folders

    out_poss = ['distance', 'cosine', 'sine', 'radius_1', 'radius_2']
    number_messages = 5
    
    nb = 0
    nb_sim = 0
    with torch.no_grad():
        # get the data for the different models

        for model_path in tqdm(modelList):
            nb_sim += 1

            # path for the experiment
            p_exp = os.path.join(pathPlot, f'exp_{nb_sim}')
            if not os.path.exists(p_exp):
                os.makedirs(p_exp)
            else:
                print('WARNING: weird stuff here')

            for f in out_poss:
                p = os.path.join(p_exp, f)
                if not os.path.exists(p):
                    os.makedirs(p)

            # find name that identify the model
            name_plot = getName(model_path)

            # load model
            try:
                model = loadModel(getModelName(name_plot), path=MODEL_PATH)
                std_dict = torch.load(model_path)
                model.load_state_dict(std_dict)
                model.eval()

                # condition loader on the dt
                loader = findLoader(model_path)

                # get data
                for data, _ in loader:
                    break

                # get messages
                message = model.message(data).cpu().detach().numpy()



                # best messages indices
                inds = findIndices(message, nb = number_messages)

                plotStdMessage(message.copy())
                p_std = os.path.join(p_exp, f'{name_plot}.png')
                if os.path.exists(p_std):
                    print("WARNING >>> issue here")
                plt.savefig(p_std)
                plt.close()

                for i in range(len(out_poss)):      # radius, ...
                    for j in range(number_messages):    # id of the message

                        plotMessage(data, message.copy(), i, inds[j])
                        
                        path = os.path.join(p_exp, out_poss[i])
                        path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}.png")

                        nb_plot = 0
                        while os.path.exists(path):
                            nb_plot += 1
                            path = os.path.join(p_exp, out_poss[i])
                            path = os.path.join(path, f"{name_plot}_attr-{out_poss[i]}_nb-{j}_nbPlot-{nb_plot}.png")

                        plt.savefig(path)
                        plt.close()

            except:
                print(f'Issue >>> {nb}')
                nb += 1

In [6]:
model_list = findModels('/master/code/analyze_models/exp/test_new_activation_0')

0it [00:00, ?it/s]


In [9]:
print(os.path.exists('/master/code/analyze_models/exps/test_new_activation_0'))

True


In [8]:
plotMessageEvol(model_list)

  0%|          | 0/6 [00:00<?, ?it/s]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with DROPOUT
INFO >>> with NO encoder
/master/code/simulation/path/mew_0_001_normal.json


 17%|█▋        | 1/6 [00:13<01:09, 13.92s/it]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with DROPOUT
INFO >>> with NO encoder
/master/code/simulation/path/mew_0_001_normal.json


 33%|███▎      | 2/6 [00:30<01:02, 15.72s/it]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with DROPOUT
INFO >>> with NO encoder
/master/code/simulation/path/mew_0_001_normal.json


 50%|█████     | 3/6 [00:44<00:44, 14.68s/it]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with DROPOUT
INFO >>> with NO encoder
/master/code/simulation/path/mew_0_001_normal.json


 67%|██████▋   | 4/6 [01:01<00:31, 15.69s/it]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with NO encoder
INFO >>> with NO dropout
/master/code/simulation/path/mew_0_001_normal.json


 83%|████████▎ | 5/6 [01:14<00:14, 14.63s/it]

>>>> loading simplest
INFO >>> relu end of message MLP
INFO >>> with NO encoder
INFO >>> with NO dropout
/master/code/simulation/path/mew_0_001_normal.json


100%|██████████| 6/6 [01:31<00:00, 15.24s/it]


### Linear regression analysis

In [10]:
from simulation import calculate_interaction

EPSILON = ...
K = ...
RADII = 1


def getGTMessages(graph, epsilon = EPSILON, k = K, radii = RADII):

    ri = graph.x[graph.edge_inds[0, :], :2].cpu().detach().numpy()
    rj = graph.x[graph.edge_inds[1, :], :2].cpu().detach().numpy()

    N = graph.x.shape[0]

    forces = calculate_interaction(N, ri, rj, k, epsilon, radii = 1.0)

    return forces

In [11]:
def getMessage(graph, model):
    message = model.message(graph).cpu().detach().numpy()
    return message

## analyser les dimensions pour savoir quel message correspond a quoi

In [None]:
def getSim():
    lim = 0.85 * 100

    xPos = np.linspace(-lim, lim, 10)
    yPos = np.linspace(-lim, lim, 10)
    gridX, gridY = np.meshgrid(xPos, yPos)
    delta = np.random.uniform(0, 7, gridX.shape + (2,))

    gridX2 = gridX + delta[:, :, 0]
    gridY2 = gridY + delta[:, :, 1]

    pos = np.stack([gridX.ravel(), gridY.ravel()], axis=1)
    pos_perturbed = np.stack([gridX2.ravel(), gridY2.ravel()], axis=1)

    pos = np.concatenate([pos, pos_perturbed], axis=0)

    angles = np.random.rand(pos.shape[0]) * 2 * np.pi

    data = sim.compute_main(200, (60, 3.5, 70, 0.5), 120, T = 200, initialization = (pos, angles), dt = 0.001)[0]

    x, y, attr, inds = ft.processSimulation(data)

    return x, y, attr, inds

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

def getLinearComp(model):

    # get data
    data = getSim()
    inds = np.arange(data.shape[0])
    random.shuffle(inds)
    inds = inds[:5]
    graphs = data[inds]

    gt = None
    messages = None

    for graph in graphs:

        graph = None
        # get the gt messages and the predicted messages

        if gt is None:
            gt = getGTMessages(graph)
            messages = getMessage(graph, model)
        else:
            gt = np.vstack((gt, getGTMessages(graph)))
            messages = np.vstack((messages, getMessage(graph, model)))


    # fit a linear model
    # MAKE 2 models if need to dissociate the two (should be fine though)
    mod = LinearRegression()
    mod.fit(messages, gt)


    # get new data
    data = getSim()
    inds = np.arange(data.shape[0])
    random.shuffle(inds)
    inds = inds[:5]
    graphs = data[inds]

    graph = None

    gt = None
    preds = None

    for graph in graphs:
        # get the new gt and predicted messages

        if gt is None:
            gt = getGTMessages(graph)
            messages = getMessage(graph, model)
            preds = model.predict(messages)

        else:
            gt = np.vstack((gt, getGTMessages(graph)))
            messages = getMessage(graph, model)

            preds = np.vstack((preds, model.predict(messages)))

    r2 = r2_score(gt, preds)

    return r2, preds, gt

### Get data into a csv dataframe for the errors

In [20]:
def combinePandas(df1, df2):

    combined_columns = df1.columns.union(df2.columns)
    df1 = df1.reindex(columns=combined_columns, fill_value=0)
    df2 = df2.reindex(columns=combined_columns, fill_value=0)
    
    combined_df = pd.concat([df1, df2], ignore_index=True)
    
    return combined_df

In [21]:
import pandas as pd

# Example DataFrames
df1 = pd.DataFrame({
    'A': [1, 2],
    'B': [3, 4]
})

df2 = pd.DataFrame({
    'B': [5, 6],
    'C': [7, 8]
})

# Combine the DataFrames
combined_df = combinePandas(df1, df2)
print(df1)
print('\n')
print(df2)
print('\n')
print(combined_df)

   A  B
0  1  3
1  2  4


   B  C
0  5  7
1  6  8


   A  B  C
0  1  3  0
1  2  4  0
2  0  5  7
3  0  6  8
