In [1]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Class index -> human-readable class name. Insertion order matters: test() uses
# the keys as tick positions and the values as tick labels of the confusion matrix.
index_to_label = dict(enumerate([
    'Catering',
    'Education',
    'Medical',
    'Outdoor',
    'Residential',
    'Shopping',
    'Tourism',
    'Transport',
    'Office',
]))

## Model

In [3]:
class GCN_RES(nn.Module):
    """Residual GCN: a stack of GCNConv + LayerNorm layers with skip connections.

    Layout (``num_layers`` convolutions in total):
      * layer 0:              GCNConv(dim_in, dim_h)
      * layers 1 .. L-2:      GCNConv(dim_h, dim_h)
      * layer L-1 (output):   GCNConv(dim_h, dim_out)

    Every conv is followed by LayerNorm. For every layer except the last, the
    normalised output goes through ReLU, is added to a running residual (which
    starts as the input ``x``), and dropout is applied to the sum. The last
    layer returns its LayerNorm-ed output with no activation, so callers apply
    softmax / log_softmax themselves.

    Args:
        dim_in: size of the input node features. Must equal ``dim_h`` because
            layer 0 adds ``x`` to its output as a residual.
        dim_h: hidden feature size.
        dim_out: output size (the number of classes in this notebook).
        num_layers: total number of GCNConv layers; must be at least 2.
        dropout: dropout probability applied after each residual addition.

    Raises:
        ValueError: if ``num_layers < 2`` or ``dim_in != dim_h``.
    """

    # Class-level default so models pickled with torch.save(model) before the
    # `dropout` argument existed still find a value when they are unpickled.
    dropout_p = 0.6

    def __init__(self, dim_in, dim_h, dim_out, num_layers=5, dropout=0.6):
        super(GCN_RES, self).__init__()
        if num_layers < 2:
            # The input and output convs are always built, but forward() only
            # runs `num_layers` of them: with a single layer it would return
            # dim_h-sized features instead of dim_out scores.
            raise ValueError(f'num_layers must be at least 2, got {num_layers}')
        if dim_in != dim_h:
            # relu(conv(x)) + x needs matching widths; with dim_in == 1 the add
            # would even broadcast silently instead of failing.
            raise ValueError(
                f'dim_in ({dim_in}) must equal dim_h ({dim_h}) for the input residual connection')

        # Total number of GCNConv layers.
        self.num_layers = num_layers
        self.dropout_p = dropout

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Input layer.
        self.convs.append(GCNConv(dim_in, dim_h))
        self.norms.append(nn.LayerNorm(dim_h))

        # Hidden layers.
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(dim_h, dim_h))
            self.norms.append(nn.LayerNorm(dim_h))

        # Output layer.
        self.convs.append(GCNConv(dim_h, dim_out))
        self.norms.append(nn.LayerNorm(dim_out))

        self.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for any ``nn.Linear`` submodule."""
        # NOTE(review): GCNConv keeps its projection in torch_geometric's own
        # Linear class, which (to my knowledge) does not subclass nn.Linear, so
        # this hook may match nothing in this model; GCNConv applies its own
        # Glorot init anyway. Confirm before relying on this hook.
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, x, edge_index, edge_weight=None):
        """Run all layers and return the final LayerNorm-ed scores (no softmax).

        Args:
            x: node feature matrix whose last dimension is ``dim_in``.
            edge_index: graph connectivity, passed straight to each GCNConv.
            edge_weight: optional edge weights, passed straight to each GCNConv.
        """
        last = self.num_layers - 1
        h = x
        res = x  # the first residual adds the raw input back
        for i in range(self.num_layers):
            h = self.norms[i](self.convs[i](h, edge_index, edge_weight))
            if i < last:
                # Hidden layers: ReLU, accumulate the residual, then dropout.
                res = F.relu(h) + res
                h = F.dropout(res, p=self.dropout_p, training=self.training)
        return h

## Train

In [7]:
def train(model, data, optimizer, criterion, save_model, epochs=200,
          patience=20, min_epochs=100):
    """Full-batch node-classification training with validation-based early stopping.

    Each epoch runs one optimizer step on the nodes selected by
    ``data.train_mask`` (model in train mode, so dropout is active). It then
    evaluates the updated model in eval mode on ``data.val_mask``. Whenever
    validation accuracy improves, the whole model is written to ``save_model``
    with ``torch.save``. Training stops after ``patience`` consecutive epochs
    without improvement, but never before epoch ``min_epochs``. Every 10
    epochs a progress line is printed; its ``Val_acc`` is the best validation
    accuracy seen since the previous progress line.

    Args:
        model: module called as ``model(x, edge_index, edge_attr)`` that returns
            per-node class scores. Data is moved to the device of its parameters.
        data: graph object exposing ``x``, ``edge_index``, ``edge_attr``, ``y``,
            ``train_mask``, ``val_mask`` and a ``to(device)`` method.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``F.log_softmax`` outputs and integer labels
            (``CrossEntropyLoss`` and ``NLLLoss`` both work on log-probabilities).
        save_model: path for the best-so-far checkpoint; ``None`` disables saving.
        epochs: index of the last epoch (the loop runs ``epochs + 1`` times).
        patience: epochs without validation-accuracy improvement before stopping.
        min_epochs: earliest epoch at which early stopping may trigger.

    Returns:
        dict of per-epoch lists: ``train_loss``, ``val_loss`` and ``val_acc``.
    """
    target_device = next(model.parameters()).device
    data = data.to(target_device)  # moving once is enough; it was done every epoch before
    train_mask, val_mask = data.train_mask, data.val_mask
    y_train, y_val = data.y[train_mask], data.y[val_mask]

    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    window_accs = []              # val accuracies since the last progress print
    best_acc = float('-inf')      # guarantees a checkpoint on the first epoch
    epochs_without_improvement = 0

    for epoch in range(epochs + 1):
        # --- training step ---
        # model.eval() below turns dropout off, so train mode must be restored
        # at the start of every epoch, not just once before the loop.
        model.train()
        optimizer.zero_grad()
        log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_attr), dim=1)
        loss = criterion(log_probs[train_mask], y_train)
        loss.backward()
        optimizer.step()
        history['train_loss'].append(loss.item())

        # --- validation: eval mode, weights after this epoch's update ---
        model.eval()
        with torch.no_grad():
            log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_attr), dim=1)
            val_loss = criterion(log_probs[val_mask], y_val).item()
            val_acc = (log_probs.argmax(dim=1)[val_mask] == y_val).float().mean().item()
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        window_accs.append(val_acc)

        # Early stopping on validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            epochs_without_improvement = 0
            if save_model is not None:
                torch.save(model, save_model)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience and epoch >= min_epochs:
                break

        if epoch % 10 == 0:
            print(f'Epoch [{epoch:03d}/{epochs}], Train_Loss: {loss.item():0.3f}, Val_Loss: {val_loss:0.3f}, Val_acc: {max(window_accs):.3f}')
            window_accs = []

    return history

## Test

In [29]:
def _plot_confusion_matrix(y_true, y_pred):
    """Show a row-normalised confusion matrix over every class in index_to_label, with counts and percentages."""
    class_ids = list(index_to_label.keys())
    class_names = list(index_to_label.values())
    # Pin the label set so the matrix always has one row/column per tick label,
    # even when a class is absent from this split.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalise; classes with no true samples show 0% instead of NaN
    # (the NaN used to crash the old int(...) check).
    cm_normal = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

    fig, ax = plt.subplots(1, 1, figsize=(18, 6))
    im = ax.imshow(cm_normal, interpolation='nearest', cmap=plt.get_cmap('Blues'))
    im.set_clim(0.0, 1.0)
    ax.set_xticks(class_ids)
    ax.set_yticks(class_ids)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    ax.set_title('Confusion matrix (row-normalised)')
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # White text on dark cells; change together with the colormap above.
            ax.text(j, i, f'({cm[i, j]})\n{round(cm_normal[i, j] * 100, 1)}%',
                    ha="center", va="center", fontsize=10,
                    color="white" if cm_normal[i, j] > 0.8 else "black")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()


def test(model, data, mask=None):
    """Evaluate ``model`` on a node subset, plot its confusion matrix and return metrics.

    Args:
        model: module called as ``model(x, edge_index, edge_attr)`` returning
            per-node class scores.
        data: graph object with ``x``, ``edge_index``, ``edge_attr``, ``y``,
            ``val_mask`` and a ``to(device)`` method.
        mask: boolean node mask to evaluate; defaults to ``data.val_mask``.
            NOTE(review): that is the split ``train`` uses for early stopping,
            so the default scores are optimistic. Pass a held-out mask if one exists.

    Returns:
        dict with ``accuracy`` plus macro-averaged ``precision``, ``recall``
        and ``f1``.
    """
    model.eval()
    target_device = next(model.parameters()).device
    data = data.to(target_device)
    mask = data.val_mask if mask is None else mask.to(target_device)

    with torch.no_grad():  # inference only: no autograd graph needed
        scores = model(data.x, data.edge_index, data.edge_attr)
    # argmax is unchanged by log_softmax, so it is applied to the raw scores.
    y_pred = scores.argmax(dim=1)[mask].cpu()
    y_true = data.y[mask].cpu()

    result = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro'),
        'recall': recall_score(y_true, y_pred, average='macro'),
        'f1': f1_score(y_true, y_pred, average='macro'),
    }
    _plot_confusion_matrix(y_true, y_pred)
    return result

## Experiments

In [6]:
# Output directory for the best-validation checkpoints written during training.
path_model_out = 'data/final_data/model/'
# Directory holding the preprocessed torch dataset files (*.pth).
path_dataset = 'data/final_data/dataset_torch/'
# GCN depth for each non-BERT dataset, matched by position to the files processed in the next cell.
layers = [2, 2, 3]

### Fine-tuned BERT

In [9]:
# Train and evaluate one GCN_RES per .pth dataset file whose name does NOT end in 'bert.pth'.
# os.listdir order is arbitrary, so sort to keep the file -> layers[k] pairing reproducible.
# NOTE(review): confirm the sorted order matches the layer counts intended in `layers`.
other_files = sorted(name for name in os.listdir(path_dataset)
                     if name.endswith('.pth') and not name.endswith('bert.pth'))
assert len(other_files) <= len(layers), (
    f'{len(other_files)} datasets but only {len(layers)} layer settings in `layers`')

for num_layers, dataset_ in zip(layers, other_files):
    print(dataset_)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    # On PyTorch >= 2.6 loading full objects needs weights_only=False.
    dataset = torch.load(path_dataset + dataset_).to(device)
    torch.manual_seed(0)  # reproducible weight init / dropout per dataset
    gcn = GCN_RES(dataset.num_node_features, dataset.num_node_features, dataset.num_classes, num_layers).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(gcn.parameters(), lr=1e-3, weight_decay=5e-4)
    checkpoint = path_model_out + dataset_
    train(gcn, dataset, optimizer, criterion, checkpoint, epochs=250)
    # Evaluate the best (early-stopping) checkpoint rather than the last-epoch weights.
    gcn = torch.load(checkpoint, map_location=device)
    result = test(gcn, dataset)
    print(result)

### BERT

In [1]:
# Train and evaluate one GCN_RES per dataset file whose name ends in 'bert.pth'.
# A separate name keeps the global `layers` intact, so re-running the previous cell still works.
bert_layers = [3, 2, 2]
# os.listdir order is arbitrary; sort so each file gets the same bert_layers entry on every run.
# NOTE(review): confirm the sorted order matches the intended file -> layer-count pairing.
bert_files = sorted(name for name in os.listdir(path_dataset) if name.endswith('bert.pth'))
assert len(bert_files) <= len(bert_layers), (
    f'{len(bert_files)} BERT datasets but only {len(bert_layers)} layer settings')

for num_layers, dataset_ in zip(bert_layers, bert_files):
    print(dataset_)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    # On PyTorch >= 2.6 loading full objects needs weights_only=False.
    dataset = torch.load(path_dataset + dataset_).to(device)
    torch.manual_seed(0)  # reproducible weight init / dropout per dataset
    gcn = GCN_RES(dataset.num_node_features, dataset.num_node_features, dataset.num_classes, num_layers).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(gcn.parameters(), lr=1e-3, weight_decay=5e-4)
    checkpoint = path_model_out + dataset_
    train(gcn, dataset, optimizer, criterion, checkpoint, epochs=250)
    # Evaluate the best (early-stopping) checkpoint rather than the last-epoch weights.
    gcn = torch.load(checkpoint, map_location=device)
    result = test(gcn, dataset)
    print(result)