In [1]:
import pickle
import math
import time
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.functional import gumbel_softmax
from utils import gumbel_softmax_3d
%matplotlib inline

In [2]:
def load_data(path):
    """Load a pickled adjacency dict from ``path`` and build a networkx graph.

    The pickle is passed straight to ``nx.from_dict_of_lists``, so it is
    presumably a mapping ``node -> list of neighbours`` (as in the Planetoid
    ``ind.<name>.graph`` files) -- TODO confirm for other datasets.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    unpickling; only call this on files from a trusted source.
    """
    with open(path, mode='rb') as handle:
        adjacency = pickle.load(handle)
    return nx.from_dict_of_lists(adjacency)

In [7]:
def train(args, G):
    """Minimize a penalized vertex-cover relaxation of the MIS problem on ``G``.

    ``args.batch_size`` candidates are optimized in parallel. For each one, a
    vector ``s`` (one entry per node) is drawn with ``gumbel_softmax_3d`` from
    the learnable per-node logits ``x``. The loss is
    ``sum(s) + args.eta * (1 - s)^T A (1 - s)``. For 0/1 samples the penalty is
    zero exactly when no edge of ``G`` joins two nodes with ``s == 0``. In that
    case the ``s == 0`` nodes form an independent set and the ``s == 1`` nodes
    form a vertex cover.

    Args:
        args: namespace read for ``batch_size``, ``lr``, ``iterations``,
            ``tau``, ``hard`` and ``eta``.
        G: networkx graph. Its dense adjacency matrix is built once, so
            memory grows as O(n^2).

    Returns:
        A list of 0-d CPU float tensors. There is one entry per iteration in
        which at least one batch sample was feasible (zero penalty). Each entry
        is the smallest ``sum(s)`` among that iteration's feasible samples.
    """
    bs = args.batch_size
    n = G.number_of_nodes()
    # Fall back to CPU instead of calling .cuda() unconditionally.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dense adjacency matrix. `to_numpy_array` replaces `to_numpy_matrix`,
    # which was removed in networkx 3.0.
    A = torch.from_numpy(nx.to_numpy_array(G)).to(device=device, dtype=torch.float32)

    # Learnable logits, one column per node and one row per batch member.
    # Scale first and only then enable grad, so `x` is a leaf tensor that
    # Adam can optimize on every device. The old CPU path created a non-leaf.
    x = (torch.randn(bs, n, 1, device=device) * 1e-5).requires_grad_()
    optimizer = torch.optim.Adam([x], lr=args.lr)

    cost_arr = []
    for it in range(args.iterations):
        optimizer.zero_grad()

        p = torch.sigmoid(x)                        # [bs, n, 1]
        # Two-class probabilities (p, 1 - p). Built with cat rather than
        # squeeze(), so bs == 1 or n == 1 keeps the right shape.
        probs = torch.cat([p, 1 - p], dim=-1)       # [bs, n, 2]
        logits = torch.log(probs + 1e-10)           # epsilon avoids log(0)
        # NOTE(review): gumbel_softmax_3d is project-local; assumed to sample
        # along the last dim like torch.nn.functional.gumbel_softmax.
        s = gumbel_softmax_3d(logits, tau=args.tau, hard=args.hard)[:, :, 0]
        s = s.unsqueeze(-1)                         # [bs, n, 1]

        not_s = 1 - s
        # Per-sample penalty, shape [bs, 1, 1]. This is the expensive
        # O(bs * n^2) product, so it is computed once and reused below.
        violation = not_s.transpose(1, 2) @ A @ not_s
        loss = torch.sum(s) + args.eta * torch.sum(violation)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Exact-zero test: with hard sampling `s` is presumably exactly
            # 0/1, so the penalty is an exact integer. Soft samples will
            # rarely hit 0.
            feasible = violation.detach().reshape(bs) == 0
            if feasible.any():
                sizes = s.detach().sum(dim=1).reshape(bs)
                cost_arr.append(sizes[feasible].min().cpu())

            if it % 100 == 0:
                print(it)
                if cost_arr:
                    print('# {}, cost: {}'.format(it, np.sort(cost_arr)[0:8]))
                else:
                    print('Failed!')

    return cost_arr

In [None]:
def str2bool(value):
    """Convert a command-line token such as 'true', 'False', '1' or 'no' to a bool.

    Needed because ``type=bool`` in argparse is wrong for flags:
    ``bool('False')`` is True, as is any non-empty string.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


def main(argv=()):
    """Parse settings, load the graph and run ``args.ensemble`` independent trainings.

    Args:
        argv: command-line tokens to parse, for example ``['--data', 'cora']``.
            Defaults to an empty sequence, so ``sys.argv`` is never read. In a
            notebook, ``sys.argv`` holds the kernel's launch arguments.
    """
    # Training settings. Help strings state the actual defaults.
    parser = argparse.ArgumentParser(
        description='Solving MIS problems (with fixed tau in GS, parallel version)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size (default: 128)')
    parser.add_argument('--data', type=str, default='pubmed',
                        help='data name (default: pubmed)')
    parser.add_argument('--tau', type=float, default=1.,
                        help='tau value in Gumbel-softmax (default: 1)')
    parser.add_argument('--hard', type=str2bool, default=True,
                        help='hard sampling in Gumbel-softmax (default: True)')
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--eta', type=float, default=3.,
                        help='constraint penalty weight (default: 3)')
    parser.add_argument('--ensemble', type=int, default=10,
                        help='# experiments (default: 10)')
    parser.add_argument('--iterations', type=int, default=10000,
                        help='# iterations in gradient descent (default: 10000)')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    args = parser.parse_args(args=list(argv))

    # Honour --seed, which was previously parsed but ignored. Seeding happens
    # once, so the ensemble runs below still differ from each other.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    # loading data
    G = load_data('./data/ind.' + args.data + '.graph')

    for i in range(args.ensemble):
        cost = train(args, G)
        if len(cost) != 0:
            print('# {}, cost: {}'.format(i, min(cost)))
        else:
            print('Failed!')


# Entry point. In a notebook cell __name__ is also '__main__', so this runs
# there too; importing the code as a module does not trigger a run.
if '__main__' == __name__:
    main()

cuda
0
Failed!
100
Failed!
200
Failed!
300
Failed!
400
Failed!
500
Failed!
600
Failed!
700
Failed!
800
Failed!
900
Failed!
1000
Failed!
1100
Failed!
1200
Failed!
1300
Failed!
1400
Failed!
1500
Failed!
1600
Failed!
1700
Failed!
1800
Failed!
1900
Failed!
2000
Failed!
2100
Failed!
2200
Failed!
2300
Failed!
2400
Failed!
2500
# 2500, cost: [4823. 4877. 4883. 4890. 4902. 4903.]
2600
# 2600, cost: [4697. 4730. 4737. 4744. 4750. 4751. 4764. 4767.]
2700
# 2700, cost: [4609. 4616. 4617. 4620. 4620. 4623. 4628. 4630.]
2800
# 2800, cost: [4525. 4529. 4531. 4533. 4535. 4537. 4538. 4539.]
2900
# 2900, cost: [4433. 4436. 4439. 4442. 4445. 4447. 4449. 4449.]
3000
# 3000, cost: [4377. 4378. 4378. 4380. 4380. 4386. 4386. 4386.]
3100
# 3100, cost: [4304. 4310. 4310. 4313. 4314. 4314. 4315. 4316.]
3200
# 3200, cost: [4254. 4255. 4258. 4259. 4259. 4264. 4266. 4266.]
3300
# 3300, cost: [4205. 4206. 4206. 4207. 4208. 4212. 4212. 4212.]
3400
# 3400, cost: [4152. 4161. 4163. 4165. 4166. 4166. 4169. 4172.]
3500

6000
# 6000, cost: [3843. 3843. 3843. 3843. 3843. 3844. 3844. 3844.]
6100
# 6100, cost: [3839. 3841. 3842. 3842. 3842. 3842. 3842. 3843.]
6200
# 6200, cost: [3839. 3839. 3839. 3840. 3840. 3840. 3840. 3840.]
6300
# 6300, cost: [3838. 3839. 3839. 3839. 3839. 3839. 3839. 3839.]
6400
# 6400, cost: [3836. 3838. 3838. 3839. 3839. 3839. 3839. 3839.]
6500
# 6500, cost: [3836. 3837. 3837. 3837. 3838. 3838. 3838. 3838.]
6600
# 6600, cost: [3836. 3837. 3837. 3837. 3837. 3837. 3837. 3838.]
6700
# 6700, cost: [3835. 3835. 3836. 3836. 3836. 3836. 3836. 3837.]
6800
# 6800, cost: [3835. 3835. 3835. 3835. 3835. 3835. 3835. 3835.]
6900
# 6900, cost: [3833. 3834. 3834. 3835. 3835. 3835. 3835. 3835.]
7000
# 7000, cost: [3833. 3833. 3834. 3834. 3834. 3834. 3834. 3834.]
7100
# 7100, cost: [3833. 3833. 3833. 3834. 3834. 3834. 3834. 3834.]
7200
# 7200, cost: [3832. 3833. 3833. 3833. 3833. 3833. 3834. 3834.]
7300
# 7300, cost: [3832. 3832. 3833. 3833. 3833. 3833. 3833. 3833.]
7400
# 7400, cost: [3832. 3832. 38