In [81]:
import pandas as pd
import numpy as np
from tqdm import tqdm

from sklearn import preprocessing
import random

import itertools
import os
import re
import math

import torch
from torch_geometric.data import HeteroData, Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx, subgraph, degree, from_networkx
from torch.utils.data import Dataset, dataloader
import torch_geometric.transforms as T

from torch_geometric.nn import GCNConv, summary, GraphSAGE
from torch.nn import Sequential, Linear, ReLU, Dropout
from sklearn.model_selection import ParameterGrid
from torch_geometric.nn import global_add_pool, global_mean_pool
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, f1_score
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from sklearn.utils import class_weight

import community.community_louvain as community_louvain  # python-louvain

# visual
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate

In [82]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

#device = torch.device('cpu')
print(device)

cuda


In [83]:
import os
cwd = os.getcwd()
print(cwd)

d:\Work\UQAM\Doctorat\Projets\oignion_GNN\cultures_GNN\carotte\Automne_2023\Script\graph_masking


In [84]:
carrot_df = pd.read_csv('../../Output/carrot_no_sensitive_data.csv', index_col=0)
meteo_df = pd.read_csv('../../Output/combined_daily_meteo.csv', index_col=0)
plant_distance_filepath = "../../Output/field_distance.txt"

In [85]:
with open(plant_distance_filepath) as file:
    distance_txt_file = [line.rstrip() for line in file]

In [86]:
carrot_df.rename(columns={'SampleDate':'Date'}, inplace=True)

In [87]:
carrot_df['cote_c_carotae'].value_counts()

0    493
1    257
2     10
3      9
4      1
Name: cote_c_carotae, dtype: int64

On retire la ferme 0 car elle debalance le jeu de donnees 

In [88]:
carrot_df = carrot_df.drop(carrot_df[carrot_df['FarmID'] == 0].index)

In [89]:
carrot_df['cote_c_carotae'].value_counts()

0    276
1    218
2      4
3      1
Name: cote_c_carotae, dtype: int64

In [90]:
carrot_df.loc[carrot_df['cote_c_carotae'] >= 1, 'cote_c_carotae'] = 1
carrot_df['cote_c_carotae'].value_counts()

0    276
1    223
Name: cote_c_carotae, dtype: int64

In [91]:
# we remove every date in the meteo dataframe where no sample have been taken
unique_sample_date = carrot_df['Date'].unique()
print(unique_sample_date)
unique_sample_date = meteo_df[meteo_df['Date'].isin(unique_sample_date)]

unique_sample_date.head()

[202 204 207 209 214 216 222 225 237 243 244 251 253 259 260 265 266 271
 272]


Unnamed: 0,FarmID,Date,Day_avg_Temp_C,Day_max_Temp_C,Day_min_Temp_C,Day_less_5_Temp_C,Day_less_13_Temp_C,Day_more_30_Temp_C,Day_more_35_Temp_C,Day_15_25_Temp_C,...,Rolling_Quot_RH_max_14D,Rolling_Quot_DewPoint_sum_14D,Rolling_Quot_SolarRadiation_sum_14D,Rolling_Quot_SolarRadiation_mean_14D,Rolling_Quot_Rain_sum_14D,Rolling_Quot_Rain_mean_14D,Rolling_Quot_more_0_Rain_14D,Rolling_Quot_WindSpeed_mean_14D,Rolling_Quot_GustSpeed_mean_14D,Rolling_DegresJours_sum_14D
37,0,202,19.633077,23.93,15.27,0,0,0,0,13,...,99.8,241.599583,2473.5,176.678571,47.2,3.371429,2.071429,0.449405,2.42381,232.5
39,0,204,22.406923,26.11,16.18,0,0,0,0,9,...,99.8,239.610833,2842.041667,203.002976,28.2,2.014286,1.285714,0.552083,2.609821,239.375
42,0,207,24.604615,28.1,16.75,0,0,0,0,6,...,99.8,243.713333,2762.166667,197.297619,39.6,2.828571,1.642857,0.703869,3.064286,244.495
44,0,209,21.417692,25.26,15.08,0,0,0,0,12,...,99.8,234.027917,2749.791667,196.41369,39.4,2.814286,1.642857,0.574405,2.685714,235.605
49,0,214,20.373077,23.62,14.96,0,0,0,0,12,...,99.8,206.191667,2762.833333,197.345238,45.2,3.228571,1.857143,0.741071,3.085119,212.38


In [92]:
carrot_df

Unnamed: 0,FarmID,GreenLeavesNum_carrots,Plant_ID,cote_c_carotae,Date,incidence_a_dauci,incidence_s_sclerotiorum,carrot_stage
520,2,4,1,0,202,0,0,0
521,2,4,2,0,202,0,0,0
522,2,4,3,0,202,0,0,0
523,2,4,4,0,202,0,0,0
524,2,4,5,0,202,0,0,0
...,...,...,...,...,...,...,...,...
515,1,6,21,1,272,0,0,2
516,1,6,22,1,272,0,0,2
517,1,6,23,1,272,0,0,2
518,1,6,24,0,272,0,0,2


In [93]:
# Initialize a dictionary to store the progression of sickness ratings
sickness_progression = {}

# Iterate through the dataset
for index, row in carrot_df.iterrows():
    farm_id = row['FarmID']
    plant_id = row['Plant_ID']
    sickness_rating = row['cote_c_carotae']
    
    # Create a unique key using farm id and plant id
    key = (int(farm_id), int(plant_id))
    
    # If the key is not already in the dictionary, create a new entry
    if key not in sickness_progression:
        sickness_progression[key] = []
    
    # Add date and sickness rating to the respective lists
    sickness_progression[key].append(int(sickness_rating))

In [94]:
for key, value in sickness_progression.items():
    print(key, value)

(2, 1) [0, 0, 1, 1, 0, 0, 0, 0, 0, 0]
(2, 2) [0, 0, 1, 0, 0, 1, 0, 0, 0, 1]
(2, 3) [0, 0, 1, 1, 0, 1, 0, 0, 0, 0]
(2, 4) [0, 0, 1, 1, 0, 1, 1, 0, 0, 0]
(2, 5) [0, 0, 1, 1, 0, 0, 1, 0, 0, 0]
(2, 6) [0, 0, 0, 1, 0, 1, 1, 1, 1, 0]
(2, 7) [0, 0, 1, 1, 0, 1, 1, 0, 0, 0]
(2, 8) [0, 0, 1, 1, 1, 1, 1, 0, 0, 0]
(2, 9) [1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
(2, 10) [0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
(2, 11) [0, 0, 1, 1, 0, 1, 1, 0, 0, 0]
(2, 12) [0, 0, 1, 1, 0, 1, 1, 0, 0, 0]
(2, 13) [0, 0, 1, 1, 0, 1, 0, 0, 0, 0]
(2, 14) [0, 0, 1, 1, 1, 1, 1, 0, 1, 0]
(2, 15) [0, 0, 1, 1, 0, 0, 1, 0, 0, 0]
(2, 16) [0, 1, 1, 1, 0, 1, 0, 1, 1, 0]
(2, 17) [0, 0, 1, 1, 1, 0, 1, 0, 0, 0]
(2, 18) [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
(2, 19) [0, 0, 1, 1, 0, 0, 0, 0, 0, 1]
(2, 20) [0, 0, 1, 0, 1, 1, 0, 0, 0, 0]
(2, 21) [0, 0, 0, 1, 0, 1, 1, 1, 1, 0]
(2, 22) [0, 0, 1, 0, 1, 0, 0, 1, 0, 0]
(2, 23) [0, 0, 0, 1, 1, 0, 1, 1, 0, 0]
(2, 24) [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
(2, 25) [0, 0, 0, 1, 0, 0, 1, 0, 0, 0]
(1, 1) [0, 0, 1, 1, 1, 1, 1, 1, 0,

In [95]:
def check_class_distribution(distribution, percentage_threshold=0.1):

    smallest_class = min(distribution)

    # Calculate the expected minimum and maximum class count
    max_class_count = (1 + percentage_threshold) * smallest_class

    # Check if the distribution is within the specified percentage
    within_threshold = all(count <= max_class_count for count in distribution)

    return within_threshold

In [96]:
def count_values(dictionary, keys):
    value_counter = {}

    for key in keys:
        values = dictionary.get(key, [])
        for value in values:
            if value in value_counter:
                value_counter[value] += 1
            else:
                value_counter[value] = 1

    return value_counter

In [97]:
def count_zeros_ones(binary_list):
    zeros = binary_list.count(0)
    ones = len(binary_list) - zeros
    return zeros, ones

def split_keys_by_ratio_old(dictionary, ratios):
    #sorted_keys = sorted(dictionary.keys(), key=lambda k: abs(count_zeros_ones(dictionary[k])[0] - count_zeros_ones(dictionary[k])[1]))
    sorted_keys = list(dictionary.keys())
    random.shuffle(sorted_keys)

    total_keys = len(sorted_keys)

    group_sizes = []
    for ratio_index in range(len(ratios)):
        if ratio_index == len(ratios) - 1:
            group_sizes.append(math.floor((total_keys * ratios[ratio_index])))
        else:
            group_sizes.append(math.ceil((total_keys * ratios[ratio_index])))

    #groups = [sorted_keys[i:i+group_size] for i, group_size in enumerate(group_sizes)]
    groups = []
    start_idx = 0
    for group_size in group_sizes:
        groups.append(sorted_keys[start_idx: start_idx + group_size])
        start_idx = group_size

    
    return groups


In [98]:
def split_keys_by_ratio(dictionary, ratios):
    #sorted_keys = sorted(dictionary.keys(), key=lambda k: abs(count_zeros_ones(dictionary[k])[0] - count_zeros_ones(dictionary[k])[1]))
    sorted_keys = list(dictionary.keys())

    random.shuffle(sorted_keys)

    total_length = len(sorted_keys)
    # Calculate the starting and ending indices for each segment
    start_idx_1 = 0
    end_idx_1 = int(total_length * ratios[0])

    start_idx_2 = end_idx_1
    end_idx_2 = start_idx_2 + int(total_length * ratios[1])

    start_idx_3 = end_idx_2
    end_idx_3 = total_length

    # Split the list into three segments
    part_1 = sorted_keys[start_idx_1:end_idx_1]
    part_2 = sorted_keys[start_idx_2:end_idx_2]
    part_3 = sorted_keys[start_idx_3:end_idx_3]
    
    return [part_1, part_2, part_3]

In [99]:
balanced = False
iter = 0
perc = 0.02
ratios = [0.70, 0.15, 0.15]
while not balanced:
    groups = split_keys_by_ratio(sickness_progression, ratios)

    temp_balance = True
    for idx in range(len(ratios)):
        class_ratio = count_values(sickness_progression, groups[idx])
        class_ratio_list = list(class_ratio.values())
        if check_class_distribution(class_ratio_list, percentage_threshold=perc) is False:
            temp_balance = False

    iter +=1
    if temp_balance:
          balanced = True
    if iter == 200:
        perc += 0.02
        iter = 0

for i, group in enumerate(groups):
        zeros = sum(count_zeros_ones(sickness_progression[key])[0] for key in group)
        ones = sum(count_zeros_ones(sickness_progression[key])[1] for key in group)
        print(f"Group {i+1}: {len(group)} - Zeros: {zeros}, Ones: {ones}") # {group}
print(f'final perc: {perc}')
    

Group 1: 35 - Zeros: 194, Ones: 156
Group 2: 7 - Zeros: 38, Ones: 31
Group 3: 8 - Zeros: 44, Ones: 36
final perc: 0.25999999999999995


In [100]:
# Convert floats to integers
for group in groups:
    for i, (farm_id, plant_id) in enumerate(group):
        group[i] = (int(farm_id), int(plant_id))

In [101]:
filtered_lists = [[y for x, y in sublist if x == 0] for sublist in groups]
test = list(itertools.chain.from_iterable(filtered_lists))
print(test)
print(set(test))


[]
set()


In [102]:

result = count_values(sickness_progression, groups[0])
print(result)
result = count_values(sickness_progression, groups[1])
print(result)
result = count_values(sickness_progression, groups[2])
print(result)


{0: 194, 1: 156}
{0: 38, 1: 31}
{0: 44, 1: 36}


In [103]:
groups[1]

[(2, 21), (1, 24), (2, 22), (2, 11), (1, 13), (2, 20), (1, 1)]

In [104]:
groups[2]

[(1, 10), (1, 15), (2, 10), (1, 9), (1, 14), (2, 18), (1, 2), (1, 12)]

In [105]:
print(len(sickness_progression.keys()))

50


In [None]:
import json

# Save the data to a JSON file
with open('graph_masking_carrot.json', 'w') as json_file:
    json.dump(groups, json_file)

In [None]:
import json

with open('graph_masking_carrot.json', 'r') as json_file:
    loaded_data = json.load(json_file)