In [1]:
import torch
from tqdm import tqdm
from model import Vocaburary, TextGCN
from time import time
from datetime import timedelta
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
EPOCH = 300
HIDDEN_DIM = 200

In [3]:
dataset_names = {
    "20NewsGroup": "20NG",
    "MR":"mr",
    "Ohsumed":"ohsumed_single_23",
    "R52":"R52",
    "R8":"R8"
}

result = {k : {} for k in dataset_names.keys()}

In [4]:
SAVE_PATH = './result/all_traning.result'

In [5]:
if not os.path.exists(SAVE_PATH):
    for key in dataset_names.keys():
        print("===================================")
        print(key)
        dir_name = dataset_names[key]
        # init to memory
        start_time = time()
        dict_data = torch.load(f'./ProcessedData/{dir_name}/WholeGraphDict.gh')
        voc : Vocaburary = dict_data['voc']
        whole_graph = dict_data['whole_graph'].cuda()
        word_num = dict_data['W']
        label_num = dict_data['L']
        doc_num = dict_data['D']
        train_mask = dict_data['train_mask'].cuda()
        doc_Y : torch.Tensor = dict_data['doc_Y'].cuda()
        word_Y : torch.Tensor = dict_data['word_Y'].T.cuda()
        label_Y : torch.Tensor = dict_data['label_Y'].cuda()
        train_words = list(dict_data['train_word'])
        test_words = list(dict_data['test_word'])
        train_words.sort()
        test_words.sort()
        train_num = train_mask.count_nonzero().cpu().item()
        test_num = doc_num - train_num

        result[key]['statistic'] = {
        "#DOC":doc_num,
        "#Word":word_num,
        "#Class":label_num,
        "#Train" : train_num,
        "#Test" : test_num,
        "#NODE" : word_num + doc_num + label_num
        }
        result[key]['init_time'] = time() - start_time
        # end of init

        model = TextGCN(whole_graph.shape[0], HIDDEN_DIM, label_num).cuda()
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        loss_fn = torch.nn.CrossEntropyLoss()
        trainingProcess = tqdm(range(EPOCH))
        result[key]['testing_time_per_epoch'] = []
        result[key]['training_time_per_epoch'] = []
        result[key]['loss'] = []
        result[key]['test_accuracy'] = []
        for epoch in trainingProcess:
            # start of training
            training_start_time_per_epoch = time()
            total_loss = 0.
            optim.zero_grad()
            y_hat = model(whole_graph)
            doc_Y_hat = y_hat[:doc_num]
            word_Y_hat = y_hat[doc_num:-label_num]
            label_Y_hat = y_hat[doc_num+word_num :]
            doc_loss = loss_fn(doc_Y_hat[train_mask], doc_Y[train_mask])
            word_loss = loss_fn(word_Y_hat[train_words], word_Y[train_words])
            label_loss = loss_fn(label_Y_hat, label_Y)
            loss = 1.0 * doc_loss + 1.0 * word_loss  + 1.0 * label_loss
            loss.backward()
            optim.step()
            result[key]['training_time_per_epoch'].append(time() - training_start_time_per_epoch)
            # end of training_per_epoch
            loss_val = loss.item()
            # start of testing_per_epoch
            testing_start_time_per_epoch = time()
            with torch.no_grad():
                acc_val = ((doc_Y_hat.argmax(1)[~train_mask] == doc_Y.cuda()[~train_mask]).sum() / (~train_mask).sum()).item()
            result[key]['testing_time_per_epoch'].append(time() - testing_start_time_per_epoch)
            # end of testing_per_epoch
            result[key]['loss'].append(loss_val)
            result[key]['test_accuracy'].append(acc_val)
            trainingProcess.set_postfix({"LOSS": loss_val, "ACC" : acc_val})
    with open(SAVE_PATH, 'wb') as f:
        pickle.dump(result, f)
else:
    with open(SAVE_PATH, 'rb') as f:
        result = pickle.load(f)

In [24]:
for key in result.keys():
    print(f"===================")
    print(f"Result of {key}")
    # statistic
    print(result[key]['statistic'])
    print(f"Init Time = {result[key]['init_time']}")
    print(f"Total Training Time = {sum(result[key]['training_time_per_epoch'])}")
    print(f"Total Testing Time = {sum(result[key]['testing_time_per_epoch'])}")
    print(f"Excution Time = {sum(result[key]['training_time_per_epoch']) + sum(result[key]['testing_time_per_epoch']) + result[key]['init_time']}")

    print(f"BEST_EPOCH Accuracy = {max(result[key]['test_accuracy'])}")

Result of 20NewsGroup
{'#DOC': 18846, '#Word': 42757, '#Class': 20, '#Train': 11314, '#Test': 7532, '#NODE': 61623}
Init Time = 0.7369239330291748
Total Training Time = 96.07711815834045
Total Testing Time = 0.12513041496276855
Excution Time = 96.9391725063324
BEST_EPOCH Accuracy = 0.9492830634117126
Result of MR
{'#DOC': 10662, '#Word': 18764, '#Class': 2, '#Train': 7108, '#Test': 3554, '#NODE': 29428}
Init Time = 0.03851199150085449
Total Training Time = 7.527169227600098
Total Testing Time = 0.1331028938293457
Excution Time = 7.698784112930298
BEST_EPOCH Accuracy = 0.9079909920692444
Result of Ohsumed
{'#DOC': 7400, '#Word': 14157, '#Class': 23, '#Train': 3357, '#Test': 4043, '#NODE': 21580}
Init Time = 0.08857846260070801
Total Training Time = 28.773300886154175
Total Testing Time = 0.14446735382080078
Excution Time = 29.006346702575684
BEST_EPOCH Accuracy = 0.8864704370498657
Result of R52
{'#DOC': 9100, '#Word': 8892, '#Class': 52, '#Train': 6532, '#Test': 2568, '#NODE': 18044}
I