In [1]:
import torch
from tqdm import tqdm
from model import Vocaburary, TextGCN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hidden-layer width handed to the TextGCN constructor.
HIDDEN_DIM = 200
# Number of iterations of the training loop (one full-graph step each).
EPOCH = 300

In [3]:
# Corpus display name -> name of its sub-directory under ./ProcessedData.
dataset_names = dict(
    [
        ("20NewsGroup", "20NG"),
        ("MR", "mr"),
        ("Ohsumed", "ohsumed_single_23"),
        ("R52", "R52"),
        ("R8", "R8"),
    ]
)
# Corpus selected for this run.
dir_name = dataset_names["MR"]

In [4]:
# Load the preprocessed bundle for the selected corpus and show its keys.
# NOTE(review): torch.load unpickles the file's contents, so only point it
# at files you trust.
graph_path = f'./ProcessedData/{dir_name}/WholeGraphDict.gh'
dict_data = torch.load(graph_path)
dict_data.keys()

dict_keys(['voc', 'train_word', 'test_word', 'whole_graph', 'doc_Y', 'word_Y', 'label_Y', 'train_mask', 'D', 'W', 'L'])

In [5]:
# --- Plain Python values -------------------------------------------------
voc: Vocaburary = dict_data["voc"]
doc_num = dict_data["D"]
word_num = dict_data["W"]
label_num = dict_data["L"]
train_words = list(dict_data["train_word"])
test_words = list(dict_data["test_word"])

# --- Tensors, moved to the GPU -------------------------------------------
train_mask = dict_data["train_mask"].cuda()
doc_Y: torch.Tensor = dict_data["doc_Y"].cuda()
# word_Y is stored transposed relative to how it is used here.
word_Y: torch.Tensor = dict_data["word_Y"].T.cuda()

# Drop the trailing `label_num` rows and columns from the stored graph.
# The slice needs a dense tensor, so densify, cut, then go back to sparse.
# The dense intermediate is intentionally left unnamed so it can be freed
# as soon as this statement finishes.
whole_graph = (
    dict_data["whole_graph"]
    .to_dense()[:-label_num, :-label_num]
    .to_sparse()
    .cuda()
)

In [6]:
# Sort both word-index lists into ascending order, in place.
for word_ids in (train_words, test_words):
    word_ids.sort()

In [7]:
# Count the documents flagged in the training mask (as a Python int);
# every remaining document belongs to the test split.
train_num = int(train_mask.count_nonzero())
test_num = doc_num - train_num

In [8]:
# Size summary of the corpus and graph; the dict is this cell's displayed value.
dataset_stats = {
    "#DOC": doc_num,
    "#Word": word_num,
    "#Class": label_num,
    "#Train": train_num,
    "#Test": test_num,
    "#NODE": word_num + doc_num,
}
dataset_stats

{'#DOC': 10662,
 '#Word': 18764,
 '#Class': 2,
 '#Train': 7108,
 '#Test': 3554,
 '#NODE': 29426}

In [9]:
# Per-epoch history: held-out accuracy (in percent) and the combined training loss.
log = {"ACC": [], "LOSS": []}

# Put the model on whatever device the graph tensor already lives on, so the
# two can never end up on different devices.
model = TextGCN(whole_graph.shape[0], HIDDEN_DIM, label_num).to(whole_graph.device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Loop invariants: the documents outside the training split and their count.
eval_mask = ~train_mask
eval_count = eval_mask.sum()

trainingProcess = tqdm(range(EPOCH))
for epoch in trainingProcess:
    optim.zero_grad()

    # One forward pass over the whole graph. The output rows are split as
    # documents first (the first `doc_num` rows), then words.
    y_hat = model(whole_graph)
    doc_Y_hat = y_hat[:doc_num]
    word_Y_hat = y_hat[doc_num:]

    # Supervise training documents and training words, with equal weight.
    doc_loss = loss_fn(doc_Y_hat[train_mask], doc_Y[train_mask])
    word_loss = loss_fn(word_Y_hat[train_words], word_Y[train_words])
    loss = 1.0 * doc_loss + 1.0 * word_loss
    loss.backward()
    optim.step()
    loss_val = loss.item()

    # Accuracy on the non-training documents, taken from the same forward
    # pass that produced the loss (i.e. the weights before this step).
    with torch.no_grad():
        correct = (doc_Y_hat.argmax(1)[eval_mask] == doc_Y[eval_mask]).sum()
        acc_val = (correct / eval_count).item() * 100.

    trainingProcess.set_postfix({"LOSS": loss_val, "Accuracy": acc_val})
    log["ACC"].append(acc_val)
    log["LOSS"].append(loss_val)

100%|██████████| 300/300 [00:14<00:00, 20.90it/s, LOSS=0.371, Accuracy=85.9]


In [10]:
# Highest held-out accuracy (%) reached at any epoch. Because it is a maximum
# over epochs of accuracy on the non-training documents, it is an optimistic
# figure; the final-epoch value is log["ACC"][-1].
best_acc = max(log["ACC"])
best_acc

90.65841436386108