In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import DeepGraphInfomax


In [2]:
# Report the installed PyTorch build and the CUDA version it was built against.
print(torch.__version__)
print(torch.version.cuda)

# Choose the compute device once; later cells move the model and the graph onto `device`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")


2.0.1
11.7
Using GPU: NVIDIA A100-SXM4-40GB


In [3]:
# Load the word-pair table. Later cells read its a_voc / b_voc, supp*, cond*, lift*
# and b_recall_cnt / b_forget_cnt columns.
word_net = pd.read_csv('./data/tmp_word_pair_output.csv')

In [None]:
# Add lower-cased copies of the two vocabulary columns; these become the graph node names.
for source_col, node_col in (('a_voc', 'study_a'), ('b_voc', 'remember_b')):
    word_net[node_col] = word_net[source_col].str.lower()

In [None]:
word_net.info()

In [None]:
word_net.head(10)

In [None]:
word_net.describe()

In [None]:
# word_net_gp1 = word_net.groupby(['study_a']).sum()
# word_net_gp2 = word_net.groupby(['remember_b']).sum()
# word_net_gp = word_net_gp1+word_net_gp2
# word_net_gp = word_net_gp/word_net_gp.shape[0]/2
# word_net_gp.head()

In [None]:
# Filtering thresholds: the supp*/cond* cut-offs are the 5th percentile of their own
# column, every lift cut-off is 1, and the p-value cut-off is 0.05.
supp_thres1, supp_thres2, cond_thres1, cond_thres2 = (
    word_net[col].quantile(0.05) for col in ('supp1', 'supp2', 'cond1', 'cond2')
)
lift_thres1 = lift_thres2 = lift_thres3 = 1
p_thres = 0.05


def filter_high_quality_dynamic_pair(df, p_thres, supp_thres1, supp_thres2, cond_thres1, cond_thres2,
                                     lift_thres1, lift_thres2, lift_thres3=1):
    """Keep the word pairs whose statistics all exceed their thresholds.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns supp1, supp2, cond1, cond2, lift1, lift2, lift3,
        a_voc and b_voc.
    p_thres : float
        Accepted for interface compatibility but not used here: the p-value
        column is computed and filtered in later cells.
    supp_thres1, supp_thres2, cond_thres1, cond_thres2 : float
        Strict lower bounds for the supp1, supp2, cond1 and cond2 columns.
    lift_thres1, lift_thres2, lift_thres3 : float
        Strict lower bounds for the lift1, lift2 and lift3 columns.
        ``lift_thres3`` used to be read from a notebook-level global; it is now
        an explicit parameter whose default (1) equals the value set in the
        configuration above.

    Returns
    -------
    pandas.DataFrame
        Rows of ``df`` that pass every threshold and whose a_voc differs from
        b_voc (self-pairs are removed).
    """
    passes_thresholds = (
        (df.supp1 > supp_thres1) & (df.cond1 > cond_thres1) & (df.lift1 > lift_thres1)
        & (df.supp2 > supp_thres2) & (df.cond2 > cond_thres2) & (df.lift2 > lift_thres2)
        & (df.lift3 > lift_thres3)
    )
    df_temp = df[passes_thresholds]
    # The p-value filter is intentionally not applied here (no `p` column exists yet):
    # df_temp = df_temp[df_temp.p < p_thres]
    # Drop self-pairs: a word paired with itself carries no association information.
    df_temp = df_temp[df_temp['a_voc'] != df_temp['b_voc']]
    return df_temp

# Pass lift_thres3 explicitly so the configured value is what the filter actually uses.
word_pair_filter = filter_high_quality_dynamic_pair(
    word_net, p_thres, supp_thres1, supp_thres2, cond_thres1, cond_thres2,
    lift_thres1, lift_thres2, lift_thres3=lift_thres3)

#word_pair_filter.to_csv("./data/filtered_word_pairs.csv",index=None)


In [None]:
word_pair_filter.info()

In [None]:
len(set(word_pair_filter.a_voc))

In [None]:
## add chi_square for finding low p_value
from scipy.stats import chi2_contingency

# Self-pair rows (a_voc == b_voc) are turned into a lookup keyed by word:
#   {word: {"b_recall_cnt": ..., "b_forget_cnt": ...}}
is_self_pair = word_net["a_voc"] == word_net["b_voc"]
prior_pair = (
    word_net.loc[is_self_pair, ["b_voc", "b_recall_cnt", "b_forget_cnt"]]
    .set_index("b_voc")
    .to_dict("index")
)

In [None]:

def chi2_p(x):
    """Chi-square p-value for one (word_b, recall count, forget count) row.

    Parameters
    ----------
    x : sequence or pandas.Series of length 3
        ``(word_b, recall_a, forget_a)``: the word b, then the number of times b
        was recalled and forgotten when word a appeared.

    Returns
    -------
    float
        p-value of ``chi2_contingency`` on the 2x2 table
        ``[[recall_a, forget_a], [recall_b - recall_a, forget_b - forget_a]]``,
        where ``recall_b`` / ``forget_b`` are b's counts stored in the
        module-level ``prior_pair`` lookup. Returns 1 when scipy rejects the
        table (for example a zero expected frequency or a negative cell).

    Raises
    ------
    KeyError
        If ``word_b`` has no entry in ``prior_pair``.
    """
    # Unpack by iteration rather than x[0], x[1], x[2]: integer keys on a label-indexed
    # Series rely on a positional fallback that recent pandas versions deprecate.
    word_b, recall_a, forget_a = x
    # Counts of b being recalled / forgotten taken from the prior_pair lookup.
    b_counts = prior_pair[word_b]
    recall_b = b_counts["b_recall_cnt"]
    forget_b = b_counts["b_forget_cnt"]
    data = [[recall_a, forget_a], [recall_b - recall_a, forget_b - forget_a]]
    try:
        _, p, _, _ = chi2_contingency(data)
    except ValueError:
        # scipy raises ValueError for degenerate tables; treat those pairs as not
        # significant. Anything else (including KeyboardInterrupt) propagates.
        p = 1
    return p


In [None]:
tqdm.pandas()  # registers DataFrame.progress_apply, used below
# from pandarallel import pandarallel

batch_size = 100000
num_rows = word_pair_filter.shape[0]
# Ceiling division. `num_rows // batch_size + 1` scheduled an extra, empty chunk whenever
# num_rows was an exact multiple of batch_size (including 0); a row-wise apply on an empty
# frame gives back an empty frame rather than a Series of p-values, so assigning it to the
# single column 'p' is not reliable.
num_chunks = (num_rows + batch_size - 1) // batch_size
# Processed chunks are collected here and concatenated once after the loop.
processed_dfs = []
print(num_chunks)
for i in range(num_chunks):
    print("iter:",i)
    # Copy the slice so that adding the 'p' column does not touch word_pair_filter.
    df_chunk = word_pair_filter.iloc[i*batch_size : (i+1)*batch_size].copy()

    # Row-wise chi-square p-value computed from (b_voc, b_recall_cnt, b_forget_cnt).
    df_chunk['p'] = df_chunk[["b_voc", "b_recall_cnt", "b_forget_cnt"]].progress_apply(chi2_p, axis=1)

    processed_dfs.append(df_chunk)

if processed_dfs:
    # Rebuild the full frame with a fresh 0..n-1 index.
    word_pair_filter = pd.concat(processed_dfs, ignore_index=True)
else:
    # Nothing to process: keep the schema and add an empty float 'p' column, because
    # pd.concat raises on an empty list.
    word_pair_filter = word_pair_filter.reset_index(drop=True).assign(p=pd.Series(dtype=float))


In [15]:
#word_pair_filter.to_csv("./data/filtered_word_pairs_unfiltered.csv",index=None)

In [16]:
# Keep the modelling columns and only the pairs whose p-value is below p_thres (the
# cut-off defined with the other thresholds) instead of repeating the literal 0.05.
# One .loc call selects rows and columns together, so the full column subset is not
# copied before the mask is applied.
keep_cols = ['a_voc', 'b_voc', 'study_a', 'remember_b', 'b_recall_cnt', 'b_forget_cnt',
             'supp1', 'supp2', 'cond1', 'cond2', 'lift1', 'lift2', 'lift3', 'p']
word_pair_filter = word_pair_filter.loc[word_pair_filter.p < p_thres, keep_cols]
word_pair_filter.head()

Unnamed: 0,a_voc,b_voc,study_a,remember_b,b_recall_cnt,b_forget_cnt,supp1,supp2,cond1,cond2,lift1,lift2,lift3,p
0,April,Catholic,april,catholic,59,45,5e-05,0.5577,0.5673,0.0074,1.4964,1.471,1.7566,9.826235e-05
1,April,Christ,april,christ,98,32,7e-05,0.7462,0.7538,0.0123,1.4011,1.3868,1.656,9.175321e-07
2,April,Christian,april,christian,99,31,7e-05,0.7154,0.7615,0.0118,1.4324,1.3456,1.6068,1.775944e-07
3,April,Easter,april,easter,92,27,6e-05,0.7731,0.7731,0.0117,1.4034,1.4034,1.6758,1.305838e-06
5,April,Italian,april,italian,21,7,1e-05,0.7143,0.75,0.0025,1.6188,1.5417,1.8409,0.00410141


In [17]:
word_pair_filter.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 17236177 entries, 0 to 20404284
Data columns (total 14 columns):
 #   Column        Dtype  
---  ------        -----  
 0   a_voc         object 
 1   b_voc         object 
 2   study_a       object 
 3   remember_b    object 
 4   b_recall_cnt  int64  
 5   b_forget_cnt  int64  
 6   supp1         float64
 7   supp2         float64
 8   cond1         float64
 9   cond2         float64
 10  lift1         float64
 11  lift2         float64
 12  lift3         float64
 13  p             float64
dtypes: float64(8), int64(2), object(4)
memory usage: 1.9+ GB


In [18]:
len(set(word_pair_filter.a_voc))

6685

In [19]:
# Write the filtered pairs to disk; index=None leaves the row index out of the CSV.
word_pair_filter.to_csv("./data/filtered_word_pairs.csv",index=None)

In [20]:
'''
PyTorch Geometric using the Deep Graph Infomax (DGI) method, which is a popular unsupervised method for learning node representations:
    Deep Graph Infomax (DGI) is a method developed for unsupervised learning on graphs. The principle behind DGI is to maximize the mutual information between patch representations and corresponding high-level summaries of graphs, thereby capturing the global semantic information.
'''
# Fit one encoder on the union of both endpoint columns so that a word receives the same
# integer id whether it appears as study_a or as remember_b.
le = LabelEncoder()
nodes = pd.concat([word_pair_filter['study_a'], word_pair_filter['remember_b']])
le.fit(nodes)

# NOTE: the word columns are overwritten with their integer ids, so running this cell a
# second time would fit the encoder on ids instead of words.
for endpoint_col in ('study_a', 'remember_b'):
    word_pair_filter[endpoint_col] = le.transform(word_pair_filter[endpoint_col])

# id -> original word, in the encoder's class order.
label_to_original = dict(enumerate(le.classes_))


In [21]:
# Build edge_index with one column per word pair: row 0 holds the study_a ids and
# row 1 the remember_b ids.
source_ids = word_pair_filter['study_a'].to_numpy()
target_ids = word_pair_filter['remember_b'].to_numpy()
edges = np.vstack((source_ids, target_ids))
edge_index = torch.tensor(edges, dtype=torch.long)


In [22]:

# Per-node sums of the numeric edge statistics, once with the node as study_a and once as
# remember_b. numeric_only=True states explicitly that the string columns (a_voc, b_voc)
# are excluded, instead of relying on the deprecated silent dropping of non-numeric
# columns, which newer pandas versions no longer perform.
word_net_gp1 = word_pair_filter.groupby(['study_a']).sum(numeric_only=True)
word_net_gp2 = word_pair_filter.groupby(['remember_b']).sum(numeric_only=True)


  word_net_gp1 = word_pair_filter.groupby(['study_a']).sum()
  word_net_gp2 = word_pair_filter.groupby(['remember_b']).sum()


In [23]:
word_net_gp1.head()

Unnamed: 0_level_0,remember_b,b_recall_cnt,b_forget_cnt,supp1,supp2,cond1,cond2,lift1,lift2,lift3,p
study_a,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0,11877210,472200,201346,0.34443,2330.3668,2343.4091,24.7388,4672.6905,4646.5549,4729.316,30.103605
1,14608426,892036,450569,0.68627,2772.2334,2886.7844,34.498,5655.1331,5453.8447,6365.5551,15.558918
2,6120048,107320,87742,0.09959,734.5056,981.0143,12.0985,2802.7785,2113.5232,6284.8763,2.660573
3,5010755,151653,124088,0.14091,559.3964,769.5583,10.0518,2253.1123,1646.3749,5196.6668,1.40993
4,11104897,425208,237386,0.33856,2022.6403,2067.8426,20.2961,4300.8359,4219.9604,4832.3144,27.313952


In [24]:
word_net_gp2.head()

Unnamed: 0_level_0,study_a,b_recall_cnt,b_forget_cnt,supp1,supp2,cond1,cond2,lift1,lift2,lift3,p
remember_b,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1,2041549,129539,10939,0.07187,517.7436,534.6389,12.9193,624.0119,604.2944,670.5394,6.54336
2,17982069,316358,274801,0.30216,2470.5517,2865.9205,33.6287,8522.2332,7346.5565,13873.7099,8.08817
3,18413078,490757,482160,0.49737,2382.2476,2744.5355,50.185,8662.9278,7519.4101,14204.8094,7.874369
4,862304,62454,4840,0.03431,221.8084,230.289,6.2497,263.7049,253.9945,282.374,3.218194
5,1183202,92257,6599,0.05046,311.498,321.2218,9.0613,365.3824,354.3232,389.9901,3.941677


In [25]:
# Combine the two per-node sums, aligned on node id. fill_value=0 stands in for a value
# that is missing from only one of the two frames; a cell missing from both stays NaN.
word_net_gp = word_net_gp1.add(word_net_gp2, fill_value=0)

word_net_gp.head()

Unnamed: 0,b_forget_cnt,b_recall_cnt,cond1,cond2,lift1,lift2,lift3,p,remember_b,study_a,supp1,supp2
0,201346.0,472200.0,2343.4091,24.7388,4672.6905,4646.5549,4729.316,30.103605,11877210.0,,0.34443,2330.3668
1,461508.0,1021575.0,3421.4233,47.4173,6279.145,6058.1391,7036.0945,22.102279,14608426.0,2041549.0,0.75814,3289.977
2,362543.0,423678.0,3846.9348,45.7272,11325.0117,9460.0797,20158.5862,10.748743,6120048.0,17982069.0,0.40175,3205.0573
3,606248.0,642410.0,3514.0938,60.2368,10916.0401,9165.785,19401.4762,9.284299,5010755.0,18413078.0,0.63828,2941.644
4,242226.0,487662.0,2298.1316,26.5458,4564.5408,4473.9549,5114.6884,30.532146,11104897.0,862304.0,0.37287,2244.4487


In [26]:
# Turn the node-id index into a regular column (drop=False keeps it) and use a 0..n-1 index.
word_net_gp=word_net_gp.reset_index(drop=False)

In [27]:
# Create node feature tensor by aggregating edge features
num_nodes = nodes.nunique()
print("num_nodes",num_nodes)

# Row i of the feature matrix is used as the feature vector of node id i, so the
# aggregated frame must have exactly one row per encoded node. Fail loudly rather than
# build a graph whose features are misaligned with edge_index (e.g. after cells were
# run out of order).
if word_net_gp.shape[0] != num_nodes:
    raise ValueError(
        f"word_net_gp has {word_net_gp.shape[0]} rows but the encoder knows {num_nodes} nodes"
    )

feature_cols = ['supp1', 'supp2','cond1', 'cond2', 'lift1', 'lift2', 'lift3','p']
# The summed edge statistics are scaled by the number of rows (nodes) in word_net_gp.
node_features = word_net_gp[feature_cols].values/word_net_gp.shape[0]

x = torch.tensor(node_features, dtype=torch.float)

# Create the PyG Data object
data = Data(x=x, edge_index=edge_index)


num_nodes 6685


In [28]:

# Encoder definition
class Encoder(torch.nn.Module):
    """Two stacked GCNConv layers with a ReLU between them.

    Both layers are created with ``cached=True``.

    Parameters
    ----------
    in_channels : int
        Size of each input node feature vector.
    hidden_channels : int
        Output size of both convolution layers.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)
        self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=True)

    def forward(self, x, edge_index):
        """Return the second layer's output for node features ``x`` on ``edge_index``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Corruption function
def corruption(x, edge_index):
    """Return a corrupted copy of the graph: node features shuffled, edges unchanged.

    Parameters
    ----------
    x : torch.Tensor
        Node feature matrix; its rows (dim 0) are permuted at random.
    edge_index : torch.Tensor
        Returned as-is (the same object).

    Returns
    -------
    tuple
        ``(x_shuffled, edge_index)``.
    """
    # Create the permutation on x's own device so that indexing a GPU tensor does not
    # require shipping a CPU index tensor over on every call.
    shuffled_rows = torch.randperm(x.size(0), device=x.device)
    return x[shuffled_rows], edge_index


In [29]:

# Model and optimizer
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def dgi_summary(z, *args, **kwargs):
    """Graph summary: sigmoid of the node embeddings averaged over dim 0."""
    return torch.sigmoid(z.mean(dim=0))

# 8 input channels (one per node feature column), 64-dimensional embeddings.
encoder = Encoder(8, 64)
model = DeepGraphInfomax(
    hidden_channels=64,
    encoder=encoder,
    summary=dgi_summary,
    corruption=corruption,
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


In [30]:

# The graph does not change between epochs, so move it to the device once, not per epoch.
data = data.to(device)

model.train()
for epoch in range(200):
    print('epoch:',epoch)
    # Clear gradients first so that an interrupted earlier run cannot leak stale
    # gradients into this step.
    optimizer.zero_grad()
    pos_z, neg_z, summary = model(data.x, data.edge_index)
    loss = model.loss(pos_z, neg_z, summary)
    loss.backward()
    optimizer.step()

model.eval()
# To get the node embeddings after training, you can use:
with torch.no_grad():
    # `data` is already on `device`, so no further transfers are needed.
    z = model.encoder(data.x, data.edge_index)



epoch: 0
epoch: 1
epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10
epoch: 11
epoch: 12
epoch: 13
epoch: 14
epoch: 15
epoch: 16
epoch: 17
epoch: 18
epoch: 19
epoch: 20
epoch: 21
epoch: 22
epoch: 23
epoch: 24
epoch: 25
epoch: 26
epoch: 27
epoch: 28
epoch: 29
epoch: 30
epoch: 31
epoch: 32
epoch: 33
epoch: 34
epoch: 35
epoch: 36
epoch: 37
epoch: 38
epoch: 39
epoch: 40
epoch: 41
epoch: 42
epoch: 43
epoch: 44
epoch: 45
epoch: 46
epoch: 47
epoch: 48
epoch: 49
epoch: 50
epoch: 51
epoch: 52
epoch: 53
epoch: 54
epoch: 55
epoch: 56
epoch: 57
epoch: 58
epoch: 59
epoch: 60
epoch: 61
epoch: 62
epoch: 63
epoch: 64
epoch: 65
epoch: 66
epoch: 67
epoch: 68
epoch: 69
epoch: 70
epoch: 71
epoch: 72
epoch: 73
epoch: 74
epoch: 75
epoch: 76
epoch: 77
epoch: 78
epoch: 79
epoch: 80
epoch: 81
epoch: 82
epoch: 83
epoch: 84
epoch: 85
epoch: 86
epoch: 87
epoch: 88
epoch: 89
epoch: 90
epoch: 91
epoch: 92
epoch: 93
epoch: 94
epoch: 95
epoch: 96
epoch: 97
epoch: 98
epoch: 99
epoch: 100

In [31]:

# `z` is a torch tensor straight after training. Convert it only when it still is one, so
# that re-running this cell (when `z` is already a NumPy array) does not fail on `.cpu()`.
if torch.is_tensor(z):
    z = z.cpu().numpy()

# Create DataFrame from embeddings: one row per word (in the encoder's class order),
# with the word in column "w" followed by embed1..embedN.
embeddings_df = pd.DataFrame(
    z,
    index=pd.Index(le.classes_, name="w"),
    columns=['embed'+str(i+1) for i in range(z.shape[1])],
).reset_index()

embeddings_df.to_csv('./data/KGembeddings.csv',index=False)


In [32]:
embeddings_df.head()

Unnamed: 0,w,embed1,embed2,embed3,embed4,embed5,embed6,embed7,embed8,embed9,...,embed55,embed56,embed57,embed58,embed59,embed60,embed61,embed62,embed63,embed64
0,a,-0.112005,0.175023,0.081303,0.119429,0.234358,-0.285393,0.08896,-0.084222,-0.146858,...,0.121453,0.046255,0.027163,0.253145,0.250952,0.395841,0.263606,-0.024474,-0.125878,-0.034926
1,abandon,-0.400461,0.727612,0.237638,0.407411,1.023871,-1.093308,0.371002,-0.282474,-0.60378,...,0.57236,0.213719,0.139102,0.971377,0.966684,1.702201,1.043585,-0.034144,-0.49925,-0.070479
2,abdomen,-0.485593,1.103594,0.264719,0.511833,1.521983,-1.432766,0.503473,-0.349383,-0.807522,...,0.910728,0.306264,0.270944,1.298496,1.351456,2.541386,1.477359,0.069351,-0.689769,-0.026549
3,abide,-0.479631,1.098984,0.261993,0.506807,1.512208,-1.418678,0.499232,-0.345371,-0.799597,...,0.907186,0.303954,0.270819,1.286539,1.342442,2.527811,1.467622,0.073551,-0.683922,-0.024235
4,ability,-0.257335,0.465007,0.156881,0.264888,0.649716,-0.698882,0.233898,-0.185201,-0.38173,...,0.361201,0.133533,0.09038,0.622238,0.620555,1.080868,0.667867,-0.02428,-0.319865,-0.049535


In [33]:
embeddings_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6685 entries, 0 to 6684
Data columns (total 65 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   w        6685 non-null   object 
 1   embed1   6685 non-null   float32
 2   embed2   6685 non-null   float32
 3   embed3   6685 non-null   float32
 4   embed4   6685 non-null   float32
 5   embed5   6685 non-null   float32
 6   embed6   6685 non-null   float32
 7   embed7   6685 non-null   float32
 8   embed8   6685 non-null   float32
 9   embed9   6685 non-null   float32
 10  embed10  6685 non-null   float32
 11  embed11  6685 non-null   float32
 12  embed12  6685 non-null   float32
 13  embed13  6685 non-null   float32
 14  embed14  6685 non-null   float32
 15  embed15  6685 non-null   float32
 16  embed16  6685 non-null   float32
 17  embed17  6685 non-null   float32
 18  embed18  6685 non-null   float32
 19  embed19  6685 non-null   float32
 20  embed20  6685 non-null   float32
 21  embed21  6685 