In [None]:
%cd /content/drive/MyDrive
!git clone https://github.com/Toan-it-mta/GraphConvolutionNetwork.git

In [None]:
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"

1.10.0+cu111
11.1


# Cài thư viện (Install dependencies)

In [None]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric
!pip install bpemb
!pip install sentence_transformers

In [None]:
# Run the repository's dataset script from the repo root.
# %cd (not !cd) changes the kernel's own working directory; later cells import repo modules
# (invoiceGCN, graph) by bare name, which presumably resolves because of this cwd — keep it.
%cd /content/drive/MyDrive/GraphConvolutionNetwork
# Presumably produces the train_data/test_data .dataset files loaded in the training cell — confirm in dataset.py.
!python dataset.py

In [None]:
# Run the repository's training script as a separate process, from the repo root.
# %cd also (re)sets the kernel's working directory, so this cell works on its own.
%cd /content/drive/MyDrive/GraphConvolutionNetwork
# NOTE(review): the following cells also train a model inside the kernel; confirm whether this
# script run is still needed or duplicates that work.
!python train.py

# Tiến hành train mô hình (Train the model)

In [None]:
import os
import torch
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from invoiceGCN import InvoiceGCN
import torch.nn.functional as F

def load_train_test_split(save_fd):
    """Load the saved train and test splits from a folder.

    Args:
        save_fd: folder containing ``train_data.dataset`` and ``test_data.dataset``.

    Returns:
        A ``(train_data, test_data)`` tuple holding the objects read with ``torch.load``.
    """
    def _load(file_name):
        # torch.load unpickles the file: only point this at trusted, locally built data.
        return torch.load(os.path.join(save_fd, file_name))

    train_split = _load('train_data.dataset')
    test_split = _load('test_data.dataset')
    return train_split, test_split

# Change this path to point at your data folder.
DATA_FD = "/content/drive/MyDrive/GraphConvolutionNetwork/dataset/Vietnamese"
# Seed torch's RNG before building the model so weight initialisation is repeatable.
SEED = 42
torch.manual_seed(SEED)

train_data, test_data = load_train_test_split(save_fd=DATA_FD)

model = InvoiceGCN(input_dim=train_data.x.shape[1], chebnet=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# NOTE(review): weight_decay=0.9 is far above AdamW's default of 0.01; kept to preserve the
# original hyperparameters — confirm it is intentional.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.001, weight_decay=0.9
)
train_data = train_data.to(device)
test_data = test_data.to(device)

# Class weights for imbalanced data: "balanced" weights, one per label value present in the
# training set (in sorted order).
_class_weights = compute_class_weight(
    class_weight="balanced",
    classes=train_data.y.unique().cpu().numpy(),
    y=train_data.y.cpu().numpy(),
)
print(_class_weights)

no_epochs = 2000
EVAL_EVERY = 200  # evaluate and log metrics every EVAL_EVERY epochs

# The class-weight tensor never changes, so build it once instead of on every epoch.
class_weights = torch.as_tensor(_class_weights, dtype=torch.float, device=device)
# Evaluation splits looked up by name (avoids building variable names for eval()).
eval_splits = {'train': train_data, 'test': test_data}

for epoch in range(1, no_epochs + 1):
    model.train()
    optimizer.zero_grad()

    # NOTE: just use boolean indexing to filter out test data, and backward after that!
    # the same holds true with test data :D
    # https://github.com/rusty1s/pytorch_geometric/issues/1928
    # `y - 1` shifts the labels to 0-based class indices for nll_loss (labels presumably start at 1).
    loss = F.nll_loss(model(train_data), train_data.y - 1, weight=class_weights)
    loss.backward()
    optimizer.step()

    if epoch % EVAL_EVERY != 0:
        continue

    # Periodic evaluation on the 5 classes.
    model.eval()
    with torch.no_grad():
        for name, _data in eval_splits.items():
            y_pred = model(_data).argmax(dim=1)
            y_true = _data.y - 1
            acc = y_pred.eq(y_true).sum().item() / y_pred.shape[0]
            print("\t{} acc: {}".format(name, acc))
            if name == 'test':
                y_true_np = y_true.cpu().numpy()
                y_pred_np = y_pred.cpu().numpy()
                print(confusion_matrix(y_true_np, y_pred_np))
                print(classification_report(y_true_np, y_pred_np))

        # Weight the validation loss the same way as the training loss so the two logged values
        # are comparable.
        loss_val = F.nll_loss(model(test_data), test_data.y - 1, weight=class_weights)

    fmt_log = "Epoch: {:03d}, train_loss:{:.4f}, val_loss:{:.4f}"
    print(fmt_log.format(epoch, loss.item(), loss_val.item()))
    print(">" * 50)
    

# Chạy tường minh kết quả (Run inference on the test set and save annotated images)

In [None]:
import shutil
from tqdm.notebook import tqdm as tqdm_nb
from graph import Grapher
from torch_geometric.utils.convert import from_networkx
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt



# Fresh output folder for the annotated test images: delete results from any previous run.
test_output_fd = "/content/drive/MyDrive/GraphConvolutionNetwork/dataset/Vietnamese/outputs"
# shutil.rmtree raises FileNotFoundError when the folder does not exist yet (first run),
# so only remove it when present, then (re)create it.
if os.path.exists(test_output_fd):
    shutil.rmtree(test_output_fd)
os.makedirs(test_output_fd, exist_ok=True)

def make_info(img_id='584', data_fd='/content/drive/MyDrive/GraphConvolutionNetwork/dataset/Vietnamese'):
    """Build the graph, box table, PyG data object and RGB image for one document image.

    Args:
        img_id: image identifier; passed to ``Grapher`` and used as the ``<img_id>.jpg`` file name.
        data_fd: dataset root folder; the image is read from ``<data_fd>/raw/img``.

    Returns:
        ``(G, df, individual_data, img)``: the graph from ``Grapher.graph_formation()``, the
        frame from ``Grapher.relative_distance()``, ``from_networkx(G)``, and the image with
        its channel order reversed from OpenCV's BGR to RGB.

    Raises:
        FileNotFoundError: if the image file cannot be read.
    """
    connect = Grapher(img_id, data_fd=data_fd)
    G, _, _ = connect.graph_formation()
    df = connect.relative_distance()
    individual_data = from_networkx(G)

    img_path = os.path.join(data_fd, 'raw', 'img', "{}.jpg".format(img_id))
    img_bgr = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    if img_bgr is None:
        raise FileNotFoundError("Could not read image: {}".format(img_path))
    # Reverse the channel axis: OpenCV loads BGR, callers work with RGB.
    img = img_bgr[:, :, ::-1]

    return G, df, individual_data, img

LABELS = ["company", "address", "date", "total", "other"]
# Boxes predicted as "other" are not drawn.
OTHER_INDEX = LABELS.index("other")

# Set eval mode and disable autograd explicitly instead of relying on the mode the
# training loop happened to leave the model in.
model.eval()
with torch.no_grad():
    y_preds = model(test_data).argmax(dim=1).cpu().numpy()
test_batch = test_data.batch.cpu().numpy()

for index in tqdm_nb(range(len(test_data.img_id))):
    img_id = test_data.img_id[index]  # not ordering by img_id
    # Predictions for the nodes whose batch entry equals this image's index.
    y_pred = y_preds[np.where(test_batch == index)[0]]

    try:
        _, df, _, img = make_info(img_id)
        if len(y_pred) != df.shape[0]:
            raise ValueError(
                "{} predictions but {} boxes".format(len(y_pred), df.shape[0])
            )

        # ndarray.copy() returns a C-contiguous, writable array; img itself is a
        # channel-reversed view, which OpenCV drawing functions do not accept.
        img2 = img.copy()
        # Match predictions to boxes by row position rather than by df's index labels.
        for pos, (_, row) in enumerate(df.iterrows()):
            # OpenCV needs plain int pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in row[['xmin', 'ymin', 'xmax', 'ymax']])
            true_label = row["labels"]

            # img2 is RGB: ground-truth boxes green, predicted boxes red, label text blue.
            if isinstance(true_label, str) and true_label != "invoice":
                cv2.rectangle(img2, (x1, y1), (x2, y2), (0, 255, 0), 2)

            _y_pred = y_pred[pos]
            if _y_pred != OTHER_INDEX:
                cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 3)
                cv2.putText(
                    img2, "{}".format(LABELS[_y_pred]), (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2
                )

        # Write the annotated image itself. Calling plt.savefig without drawing (imshow was
        # commented out) saved an empty figure. Flip back to BGR for OpenCV.
        out_path = os.path.join(test_output_fd, '{}_result.png'.format(img_id))
        if not cv2.imwrite(out_path, np.ascontiguousarray(img2[:, :, ::-1])):
            raise IOError("cv2.imwrite failed for {}".format(out_path))
    except Exception as exc:
        # Best effort: one bad image should not stop the export, but report why it was skipped.
        print("Skipping image {}: {!r}".format(img_id, exc))

  0%|          | 0/389 [00:00<?, ?it/s]

<Figure size 432x288 with 0 Axes>