In [1]:
import torch
from tqdm import tqdm
from model import Vocaburary, TextGCN

In [2]:
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
# Training hyper-parameters shared by every dataset run below.
EPOCH: int = 300       # number of full-graph optimisation steps per dataset
HIDDEN_DIM: int = 200  # second constructor argument handed to TextGCN

In [4]:
# Display name of each dataset -> sub-directory name under ./ProcessedData/.
# Insertion order is the order in which the datasets are processed.
dataset_names = dict(
    [
        ("20NewsGroup", "20NG"),
        ("MR", "mr"),
        ("Ohsumed", "ohsumed_single_23"),
        ("R52", "R52"),
        ("R8", "R8"),
    ]
)

In [5]:
# For every dataset: load its pre-built graph dictionary, train a fresh TextGCN
# for EPOCH full-graph steps on the combined document + word loss, and print
# the highest held-out document accuracy observed during training.
for key, dir_name in dataset_names.items():
    print(key)

    # NOTE(security): torch.load unpickles arbitrary Python objects; only load
    # files produced by a trusted preprocessing run.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # pickled custom classes - pass weights_only=False there if loading fails.
    # map_location='cpu' lets a file saved from a CUDA session load on a
    # CPU-only machine; the tensors actually used are moved to `device` below.
    dict_data = torch.load(
        f'./ProcessedData/{dir_name}/WholeGraphDict_w20_wihout_c.gh',
        map_location='cpu',
    )

    # `voc` and `test_words` are not used in this cell; they are kept because
    # notebook variables are global and later cells may read them.
    voc : Vocaburary = dict_data['voc']
    whole_graph = dict_data['whole_graph'].to(device)
    word_num = dict_data['W']
    label_num = dict_data['L']
    doc_num = dict_data['D']
    train_mask = dict_data['train_mask'].to(device)
    doc_Y : torch.Tensor = dict_data['doc_Y'].to(device)
    word_Y : torch.Tensor = dict_data['word_Y'].T.to(device)
    train_words = sorted(dict_data['train_word'])
    test_words = sorted(dict_data['test_word'])

    # Every document that is not a training document counts as a test document.
    train_num = train_mask.count_nonzero().cpu().item()
    test_num = doc_num - train_num
    print({
        "#DOC": doc_num,
        "#Word": word_num,
        "#Class": label_num,
        "#Train": train_num,
        "#Test": test_num,
        "#NODE": word_num + doc_num,
    })

    # Loop-invariant evaluation helpers, computed once instead of every epoch.
    test_mask = ~train_mask
    test_total = test_mask.sum()

    log = {"ACC": [], "LOSS": []}
    model = TextGCN(whole_graph.shape[0], HIDDEN_DIM, label_num).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    trainingProcess = tqdm(range(EPOCH))
    for epoch in trainingProcess:
        optim.zero_grad()
        y_hat = model(whole_graph)
        # The first doc_num output rows are scored against the document labels,
        # the remaining rows against the word labels.
        doc_Y_hat = y_hat[:doc_num]
        word_Y_hat = y_hat[doc_num:]
        doc_loss = loss_fn(doc_Y_hat[train_mask], doc_Y[train_mask])
        word_loss = loss_fn(word_Y_hat[train_words], word_Y[train_words])
        # Both terms are weighted 1.0; change the factors to rebalance them.
        loss = 1.0 * doc_loss + 1.0 * word_loss
        loss.backward()
        optim.step()
        loss_val = loss.item()

        # NOTE(review): accuracy reuses the logits of the training forward pass
        # (computed before this epoch's update, model not switched to eval
        # mode) - confirm this is intended if TextGCN uses dropout.
        with torch.no_grad():
            correct = doc_Y_hat.argmax(1)[test_mask] == doc_Y[test_mask]
            acc_val = (correct.sum() / test_total).item()
        trainingProcess.set_postfix({"LOSS": loss_val, "Accuracy": acc_val})
        log['ACC'].append(acc_val)
        log["LOSS"].append(loss_val)

    # NOTE(review): this is the maximum over per-epoch *test* accuracy, i.e. the
    # epoch is selected on the test set, so the figure is optimistic.
    print(max(log['ACC']))

20NewsGroup
{'#DOC': 18846, '#Word': 42757, '#Class': 20, '#Train': 11314, '#Test': 7532, '#NODE': 61603}
