In [1]:
import torch
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import dgl.sparse as dglsp
from nltk.corpus import stopwords
from utils import clean_str, remove_stopwords, nomalize_Adj
from model import Vocaburary
from IPython.display import clear_output

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input layout: each entry under ROOT is treated as a dataset directory that
# holds a train/test CSV pair with the file names below.
ROOT: str = './ProcessedData'
TRAIN_DATA: str = 'train_data.csv'
TEST_DATA: str = 'test_data.csv'

# Sliding-window width (in tokens) used when counting word co-occurrences.
MAX_WINDOW_SIZE: int = 20

In [3]:
def _count_words(*text_columns):
    """Count whitespace-separated tokens across one or more text columns.

    Args:
        *text_columns: iterables of strings (here: the cleaned ``text`` columns).

    Returns:
        dict mapping word -> number of occurrences. Keys are inserted in
        first-seen order, matching the order of the original nested loops in
        case ``Vocaburary`` depends on it.
    """
    word_count = {}
    for column in text_columns:
        for text in tqdm(column, desc='Counting Word...'):
            for word in text.split():
                word_count[word] = word_count.get(word, 0) + 1
    return word_count


def _tfidf(count_mat):
    """Dense TF-IDF for a (rows x vocab) occurrence-count matrix.

    tf  = count / row total            (+1e-9 guards rows with no tokens)
    idf = log10(n_rows / df)           (df = number of rows containing the word;
                                        +1e-9 guards words that never occur)

    The "contains" indicator is derived from the counts. It is therefore exact
    and does not depend on how a sparse library densifies duplicate entries.
    """
    contain_mat = (count_mat > 0).to(count_mat.dtype)
    tf = count_mat / (count_mat.sum(dim=1, keepdim=True) + 1e-9)
    idf = torch.log10(count_mat.shape[0] / (contain_mat.sum(dim=0) + 1e-9))
    return tf * idf


# Build one heterogeneous doc+word graph per dataset directory under ROOT and
# cache it next to the CSVs. Datasets that already have a cached file are skipped.
for dataset in sorted(os.listdir(ROOT)):
    file_path = os.path.join(ROOT, dataset)
    # ROOT may contain stray files (e.g. .DS_Store); only sub-directories hold data.
    if not os.path.isdir(file_path):
        continue
    window_size = MAX_WINDOW_SIZE
    clear_output(wait = True)
    print("==================================")
    print(f'Current Dataset = {dataset}')
    print('**********************************')
    print(f'Current Window Size = {window_size}')
    # NOTE: the 'wihout' spelling is kept on purpose: previously cached files and
    # any downstream loader use this exact file name.
    TARGET_DATA = f'WholeGraphDict_w{window_size}_wihout_c.gh'
    save_path = os.path.join(file_path, TARGET_DATA)
    if os.path.isfile(save_path):
        print("File already exists, skip it")
        continue

    # ---- Load & clean text -------------------------------------------------
    train_df = pd.read_csv(os.path.join(file_path, TRAIN_DATA), index_col=False)
    test_df = pd.read_csv(os.path.join(file_path, TEST_DATA), index_col=False)
    train_df['text'] = train_df['text'].map(clean_str)
    test_df['text'] = test_df['text'].map(clean_str)
    # The 'mr' dataset skips stop-word removal (and the min_time cut-off below).
    if dataset != 'mr':
        train_df['text'] = train_df['text'].map(remove_stopwords)
        test_df['text'] = test_df['text'].map(remove_stopwords)

    # ---- Vocabulary & token ids ---------------------------------------------
    word_count = _count_words(train_df['text'], test_df['text'])
    if dataset == 'mr':
        voc = Vocaburary(word_count = word_count)
    else:
        voc = Vocaburary(word_count = word_count, min_time = 5)

    train_df['ids'] = train_df['text'].map(lambda x : voc.encode(x.split()))
    test_df['ids'] = test_df['text'].map(lambda x : voc.encode(x.split()))

    train_word_set = set()
    test_word_set = set()
    for ids in tqdm(train_df['ids'], desc='constructing train_word_set...'):
        train_word_set.update(ids)
    for ids in tqdm(test_df['ids'], desc='constructing test_word_set...'):
        test_word_set.update(ids)

    # Train rows come first, then test rows, so the doc node id equals the row index of ALL_df.
    ALL_df = pd.concat([train_df, test_df], axis=0, ignore_index=True)
    # Assumes targets are 0-based integer class ids — TODO confirm.
    label_num = ALL_df['target'].unique().max() + 1
    doc_num = len(ALL_df)

    # ---- Doc-word / label-word incidence: one (row, word) entry per token ----
    doc_word_graph = {"doc_node": [], "word_node": []}
    label_word_graph = {"label_node": [], "word_node": []}
    for doc_id, (label_id, ids) in enumerate(tqdm(zip(ALL_df['target'], ALL_df['ids']), total=doc_num)):
        doc_word_graph['doc_node'] += [doc_id] * len(ids)
        doc_word_graph['word_node'] += ids
        label_word_graph['label_node'] += [label_id] * len(ids)
        label_word_graph['word_node'] += ids

    doc_word_mat = dglsp.spmatrix(
        indices = torch.tensor([doc_word_graph['doc_node'], doc_word_graph['word_node']]),
        shape = (doc_num, len(voc))
        )
    label_word_mat = dglsp.spmatrix(
        indices = torch.tensor([label_word_graph['label_node'], label_word_graph['word_node']]),
        shape = (label_num, len(voc))
        )

    # ---- TF-IDF edges ---------------------------------------------------------
    # coalesce() merges duplicate (row, word) entries by summing them -> occurrence counts.
    doc_word_tfidf = _tfidf(doc_word_mat.coalesce().to_dense()).to_sparse()
    label_word_count = label_word_mat.coalesce().to_dense()
    label_word_tfidf = _tfidf(label_word_count).to_sparse()
    # word_Y[l, w] = fraction of word w's occurrences found in documents of label l.
    # clamp(min=1) maps vocab entries that never occur to 0 instead of NaN (0/0);
    # every other column is divided by its (>= 1) total exactly as before.
    word_Y = label_word_count / label_word_count.sum(dim=0).clamp(min=1)

    # ---- Word-word edges from sliding-window co-occurrence -------------------
    # src_dst_nodes[(i, j)] = number of windows containing both i and j;
    # the diagonal (i, i) is the number of windows containing i.
    src_dst_nodes = {}
    total_window = 0.
    for ids in tqdm(ALL_df['ids'], desc="Constructing Graph..."):
        # Documents shorter than the window (including empty ones) count as one window.
        for w in range(max(len(ids) - window_size + 1, 1)):
            window = set(ids[w : w + window_size])
            for i in window:
                for j in window:
                    src_dst_nodes[(i, j)] = src_dst_nodes.get((i, j), 0) + 1
            total_window += 1.

    src_nodes = []
    dst_nodes = []
    values = []
    for (i, j), count in tqdm(src_dst_nodes.items(), desc='Building Graph...'):
        src_nodes.append(i)
        dst_nodes.append(j)
        values.append(count)

    co_occurMat = dglsp.spmatrix(
        torch.tensor([src_nodes, dst_nodes]),
        val= torch.tensor(values, dtype=torch.float32),
        shape=(len(voc), len(voc))
        )

    # p(i, j) = co-occurring windows / total windows; the diagonal holds p(i).
    p_mat = (co_occurMat.to_dense() / total_window)
    p_diag = torch.diag(p_mat)
    p_diag = p_diag.unsqueeze(1) @ p_diag.unsqueeze(0)  # outer product p(i) * p(j)
    # log10(ratio + 1) rather than plain log: always >= 0, and 0 where words never co-occur.
    word_word_PMI = (torch.log10(p_mat/(p_diag + 1e-9) + 1)).to_sparse()
    word_word_PMI = nomalize_Adj(word_word_PMI)

    # ---- Assemble the (D + W) x (D + W) adjacency: doc nodes first, then words ----
    D = doc_word_tfidf.shape[0]
    W = word_word_PMI.shape[0]
    L = label_word_tfidf.shape[0]
    whole_graph = torch.zeros(size = (D + W, D + W))
    A = whole_graph.shape[0]
    whole_graph[:D, :D] = torch.eye(D) # doc-doc identity
    whole_graph[:D, D:] = doc_word_tfidf.to_dense() # tf-idf doc
    whole_graph[D:, :D] = doc_word_tfidf.T.to_dense() # tf-idf doc T
    whole_graph[D:, D:] = word_word_PMI.to_dense() # word-word PMI
    # Label nodes are left out in this "_wihout_c" variant. Re-enabling the lines
    # below also requires sizing whole_graph as (D + W + L) square.
    # whole_graph[D:A-L, D+W:] = label_word_tfidf.T.to_dense() # label-word T
    # whole_graph[D+W:, D:A-L] = label_word_tfidf.to_dense() #label-word
    # whole_graph[D+W:, D+W:] = torch.ones(size = (label_num,)).diagflat() # label-label identity
    whole_graph = whole_graph.to_sparse()
    doc_Y = torch.tensor(ALL_df['target'].to_list(), dtype=torch.int64)
    # label_Y = torch.arange(0, L, 1, dtype = torch.int64)
    # Train documents occupy the first len(train_df) rows of ALL_df; the rest stay False.
    train_mask = torch.zeros(size=(len(ALL_df), ), dtype=torch.bool)
    train_mask[:len(train_df)] = True
    # NOTE(review): the dict contains a pickled Vocaburary object, so loading it
    # requires torch.load(..., weights_only=False) on newer PyTorch versions.
    torch.save({
        "voc":voc,
        'train_word' : train_word_set,
        'test_word': test_word_set,
        "whole_graph": whole_graph,
        "doc_Y":doc_Y,
        'word_Y': word_Y,
        # "label_Y":label_Y,
        "train_mask": train_mask,
        "D":D,
        "W":W,
        "L":L
    }, save_path)
    print('**********************************')
    print(f'Done ! Dataset :{dataset}')
    print("==================================")

Current Dataset = R8
**********************************
Current Window Size = 20


Counting Word...: 100%|██████████| 5485/5485 [00:00<00:00, 81847.72it/s]
Counting Word...: 100%|██████████| 2189/2189 [00:00<00:00, 87265.89it/s]
constrcuting train_word_set...: 100%|██████████| 5485/5485 [00:00<00:00, 456998.42it/s]
constructing test_word_set...: 100%|██████████| 2189/2189 [00:00<00:00, 437685.63it/s]
7674it [00:00, 182671.53it/s]
Constructing Graph...: 100%|██████████| 7674/7674 [00:22<00:00, 343.66it/s]
Building Graph...: 100%|██████████| 3585748/3585748 [00:01<00:00, 2284997.96it/s]


**********************************
Done ! Dataset :R8
