In [1]:
from models import GAT, GCN
from utils import TSP
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

BATCH_SIZE = 100
EPOCH = 1000
TEST = 100
LR = 0.001

# Data

In [2]:
trainset = TSP(20, 'train_tsp20')
validset = TSP(20, 'valid_tsp20')

trainloader = DataLoader(trainset, BATCH_SIZE, shuffle=True, num_workers=6)
validloader = DataLoader(validset, BATCH_SIZE, shuffle=True, num_workers=6)

# Model

In [3]:
model = GCN(trainset[0][2].shape[1], 1024, trainset[0][0].shape[0], 0.3, True).cuda()
# model = GAT(trainset[0][1].shape[1], 16, trainset[0][0].shape[0], 0.3, True).cuda()
loss_fn = nn.CrossEntropyLoss()
optim = Adam(model.parameters(), lr=LR)

# Train

In [4]:
for epoch in range(EPOCH):
    for i, (adjs, labels, features) in enumerate(trainloader):
        features = features.type(torch.float).cuda()
        adjs = adjs.type(torch.float).cuda()
        labels = labels.type(torch.long).cuda()
        
        optim.zero_grad()
        outputs = model(features, adjs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optim.step()

        predicts = outputs.argmax(dim=2)
        precision = accuracy_score(labels.flatten().cpu(), predicts.flatten().cpu())
    print("[{:3d}/{:3d}]  loss: {:.4f}   precision: {:5.2%}".format(epoch+1, EPOCH, loss, precision))

#     if(epoch % TEST == TEST-1):
#         model.eval()
        
#         outputs = model(features)
#         loss = loss_fn(outputs, adjs)
        
#         mask = outputs.cpu() > 0.9
#         predicts = torch.zeros(outputs.shape).masked_fill_(mask, 1.0)
#         precision = 0.0
#         for adj, predict in zip(adjs, predicts):
#             precision += accuracy_score(adj.cpu(), predict) / adjs.shape[0]
        
#         print("=============================================")
#         print("[Testing]  loss: {:.4f}   precision: {:5.2%}".format(loss, precision))
#         print("=============================================")
        
#         model.train()

[  1/1000]  loss: 3.0008   precision: 4.20%
[  2/1000]  loss: 2.9959   precision: 5.45%
[  3/1000]  loss: 2.9971   precision: 5.35%
[  4/1000]  loss: 2.9978   precision: 4.50%
[  5/1000]  loss: 2.9959   precision: 5.25%
[  6/1000]  loss: 2.9962   precision: 5.05%
[  7/1000]  loss: 2.9978   precision: 5.05%
[  8/1000]  loss: 2.9951   precision: 4.85%
[  9/1000]  loss: 2.9955   precision: 4.90%
[ 10/1000]  loss: 2.9943   precision: 4.60%
[ 11/1000]  loss: 2.9948   precision: 5.30%
[ 12/1000]  loss: 2.9972   precision: 5.05%
[ 13/1000]  loss: 2.9961   precision: 5.50%
[ 14/1000]  loss: 2.9964   precision: 4.95%
[ 15/1000]  loss: 2.9963   precision: 4.60%
[ 16/1000]  loss: 2.9953   precision: 4.80%
[ 17/1000]  loss: 2.9973   precision: 4.85%
[ 18/1000]  loss: 2.9967   precision: 4.80%
[ 19/1000]  loss: 2.9950   precision: 5.45%
[ 20/1000]  loss: 2.9945   precision: 5.05%
[ 21/1000]  loss: 2.9959   precision: 5.00%
[ 22/1000]  loss: 2.9952   precision: 4.90%
[ 23/1000]  loss: 2.9967   preci

[188/1000]  loss: 2.9938   precision: 5.20%
[189/1000]  loss: 2.9947   precision: 5.10%
[190/1000]  loss: 2.9941   precision: 4.90%
[191/1000]  loss: 2.9952   precision: 4.35%
[192/1000]  loss: 2.9947   precision: 5.15%
[193/1000]  loss: 2.9937   precision: 5.45%
[194/1000]  loss: 2.9950   precision: 5.15%
[195/1000]  loss: 2.9960   precision: 4.85%
[196/1000]  loss: 2.9947   precision: 4.55%
[197/1000]  loss: 2.9923   precision: 5.15%
[198/1000]  loss: 2.9960   precision: 5.00%
[199/1000]  loss: 2.9961   precision: 4.95%
[200/1000]  loss: 2.9975   precision: 5.35%
[201/1000]  loss: 2.9944   precision: 5.75%
[202/1000]  loss: 2.9957   precision: 5.25%
[203/1000]  loss: 2.9934   precision: 4.90%
[204/1000]  loss: 2.9934   precision: 5.35%
[205/1000]  loss: 2.9970   precision: 5.25%
[206/1000]  loss: 2.9952   precision: 4.75%
[207/1000]  loss: 2.9954   precision: 5.30%
[208/1000]  loss: 2.9912   precision: 5.00%
[209/1000]  loss: 2.9987   precision: 5.15%
[210/1000]  loss: 2.9940   preci

[375/1000]  loss: 2.9927   precision: 5.70%
[376/1000]  loss: 2.9952   precision: 4.95%
[377/1000]  loss: 2.9981   precision: 4.80%
[378/1000]  loss: 2.9933   precision: 5.40%
[379/1000]  loss: 2.9931   precision: 4.95%
[380/1000]  loss: 2.9968   precision: 4.45%
[381/1000]  loss: 2.9933   precision: 4.95%
[382/1000]  loss: 2.9947   precision: 5.00%
[383/1000]  loss: 2.9931   precision: 4.65%
[384/1000]  loss: 2.9945   precision: 4.90%
[385/1000]  loss: 2.9913   precision: 4.95%
[386/1000]  loss: 2.9960   precision: 5.20%
[387/1000]  loss: 2.9935   precision: 5.35%
[388/1000]  loss: 2.9937   precision: 4.95%
[389/1000]  loss: 2.9959   precision: 4.95%
[390/1000]  loss: 2.9949   precision: 4.65%
[391/1000]  loss: 2.9955   precision: 5.05%
[392/1000]  loss: 2.9944   precision: 4.70%
[393/1000]  loss: 2.9921   precision: 5.40%
[394/1000]  loss: 2.9942   precision: 5.40%
[395/1000]  loss: 2.9930   precision: 4.95%
[396/1000]  loss: 2.9936   precision: 4.95%
[397/1000]  loss: 2.9942   preci

[562/1000]  loss: 2.9959   precision: 5.20%
[563/1000]  loss: 2.9937   precision: 5.25%
[564/1000]  loss: 2.9937   precision: 4.65%
[565/1000]  loss: 2.9935   precision: 5.00%
[566/1000]  loss: 2.9936   precision: 5.15%
[567/1000]  loss: 2.9950   precision: 5.10%
[568/1000]  loss: 2.9940   precision: 5.10%
[569/1000]  loss: 2.9961   precision: 5.95%
[570/1000]  loss: 2.9977   precision: 4.95%
[571/1000]  loss: 2.9945   precision: 4.75%
[572/1000]  loss: 2.9931   precision: 5.25%
[573/1000]  loss: 2.9921   precision: 4.70%
[574/1000]  loss: 2.9942   precision: 5.30%
[575/1000]  loss: 2.9956   precision: 4.75%
[576/1000]  loss: 2.9934   precision: 5.35%
[577/1000]  loss: 2.9920   precision: 5.15%
[578/1000]  loss: 2.9933   precision: 4.90%
[579/1000]  loss: 2.9937   precision: 5.20%
[580/1000]  loss: 2.9939   precision: 5.05%
[581/1000]  loss: 2.9932   precision: 4.75%
[582/1000]  loss: 2.9926   precision: 5.15%
[583/1000]  loss: 2.9921   precision: 5.30%
[584/1000]  loss: 2.9951   preci

[749/1000]  loss: 2.9944   precision: 4.70%
[750/1000]  loss: 2.9938   precision: 5.15%
[751/1000]  loss: 2.9917   precision: 5.25%
[752/1000]  loss: 2.9900   precision: 4.70%
[753/1000]  loss: 2.9915   precision: 5.40%
[754/1000]  loss: 2.9940   precision: 4.90%
[755/1000]  loss: 2.9983   precision: 4.80%
[756/1000]  loss: 2.9952   precision: 5.10%
[757/1000]  loss: 2.9942   precision: 4.15%
[758/1000]  loss: 2.9951   precision: 5.15%
[759/1000]  loss: 2.9906   precision: 5.05%
[760/1000]  loss: 2.9924   precision: 5.20%
[761/1000]  loss: 2.9974   precision: 5.35%
[762/1000]  loss: 2.9936   precision: 4.80%
[763/1000]  loss: 2.9926   precision: 5.05%
[764/1000]  loss: 2.9901   precision: 5.25%
[765/1000]  loss: 2.9976   precision: 4.75%
[766/1000]  loss: 2.9911   precision: 5.10%
[767/1000]  loss: 2.9924   precision: 4.40%
[768/1000]  loss: 2.9918   precision: 4.55%
[769/1000]  loss: 2.9919   precision: 5.00%
[770/1000]  loss: 2.9907   precision: 5.25%
[771/1000]  loss: 2.9910   preci

[936/1000]  loss: 2.9926   precision: 5.05%
[937/1000]  loss: 2.9938   precision: 4.85%
[938/1000]  loss: 2.9895   precision: 4.45%
[939/1000]  loss: 2.9930   precision: 5.35%
[940/1000]  loss: 2.9901   precision: 5.80%
[941/1000]  loss: 2.9918   precision: 4.50%
[942/1000]  loss: 2.9910   precision: 5.70%
[943/1000]  loss: 2.9903   precision: 5.80%
[944/1000]  loss: 2.9925   precision: 5.10%
[945/1000]  loss: 2.9936   precision: 4.70%
[946/1000]  loss: 2.9944   precision: 5.05%
[947/1000]  loss: 2.9920   precision: 5.30%
[948/1000]  loss: 2.9941   precision: 4.20%
[949/1000]  loss: 2.9926   precision: 5.35%
[950/1000]  loss: 2.9934   precision: 5.90%
[951/1000]  loss: 2.9920   precision: 5.20%
[952/1000]  loss: 2.9963   precision: 5.30%
[953/1000]  loss: 2.9919   precision: 5.25%
[954/1000]  loss: 2.9921   precision: 4.80%
[955/1000]  loss: 2.9919   precision: 4.90%
[956/1000]  loss: 2.9929   precision: 4.75%
[957/1000]  loss: 2.9941   precision: 5.05%
[958/1000]  loss: 2.9925   preci