In [10]:
import pickle
import math
import time
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.functional import gumbel_softmax
from utils import gumbel_softmax_3d
%matplotlib inline

In [11]:
def load_data(path):
    """Read a pickled adjacency mapping from ``path`` and return it as a NetworkX graph.

    The pickle is handed straight to ``nx.from_dict_of_lists``, so it is expected
    to hold a ``{node: [neighbour, ...]}`` mapping.

    Security: ``pickle.load`` can execute arbitrary code while deserializing;
    only point this at trusted data files.
    """
    with open(path, mode='rb') as handle:
        adjacency = pickle.load(handle)
    return nx.from_dict_of_lists(adjacency)

In [12]:
def _cover_violations(s, A):
    """Per-sample vertex-cover penalty ``(1 - s_b)^T A (1 - s_b)``, shape ``[bs]``.

    ``s`` is ``[bs, n, 1]`` (1 = node selected) and ``A`` is the ``[n, n]``
    adjacency matrix, so each entry sums ``A[i, j]`` over node pairs where both
    ``i`` and ``j`` are unselected. ``reshape(-1)`` is used instead of
    ``squeeze`` so the batch axis survives when the batch size is 1.
    """
    unselected = 1 - s
    return (unselected.transpose(1, 2) @ A @ unselected).reshape(-1)


def train(args, G):
    """Search for small vertex covers of ``G`` by gradient descent through Gumbel-softmax samples.

    ``args.batch_size`` independent solutions are optimized in parallel. Each
    iteration draws a sample ``s`` (1 = node selected) via
    ``utils.gumbel_softmax_3d`` and minimizes

        sum(s) + args.eta * sum_b (1 - s_b)^T A (1 - s_b)

    For a 0/1 sample, the second term is zero exactly when no edge has both
    endpoints unselected. Progress is printed every 100 iterations.

    Args:
        args: namespace with ``batch_size``, ``lr``, ``iterations``, ``tau``,
            ``hard`` and ``eta``.
        G: NetworkX graph whose vertex cover is sought.

    Returns:
        list of 0-d CPU tensors, one per iteration in which at least one batch
        sample had zero penalty: the size (number of selected nodes) of the
        smallest such sample. Empty if no feasible sample was ever drawn.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    bs = args.batch_size
    n = G.number_of_nodes()

    # Dense [n, n] adjacency matrix in G.nodes() order. nx.to_numpy_matrix was
    # removed in networkx 3.0; to_numpy_array yields the same entries. The old
    # code also called A.cuda() unconditionally, which crashed on CPU-only hosts.
    A = torch.from_numpy(nx.to_numpy_array(G, dtype=np.float32)).to(device)

    # One trainable value per (sample, node), used directly as the probability
    # of selecting that node. Both devices start from U[0, 1). The CPU branch
    # used to draw from randn, which puts most values outside the probability range.
    x = torch.rand(bs, n, 1, device=device, requires_grad=True)

    optimizer = torch.optim.Adam([x], lr=args.lr)

    cost_arr = []
    for it in range(args.iterations):
        optimizer.zero_grad()

        # NOTE(review): x is neither clamped nor squashed (a sigmoid was
        # commented out in the original), so after some Adam steps p can leave
        # [0, 1]. Kept as-is to preserve the optimization behavior.
        p = x
        # Last axis: [P(select node), P(skip node)].
        probs = torch.cat([p, 1 - p], dim=-1)  # [bs, n, 2]
        # NOTE(review): raw probabilities (not log-probabilities) are passed
        # here. The original also built an unused log(probs + 1e-10); confirm
        # which one utils.gumbel_softmax_3d expects.
        s = gumbel_softmax_3d(probs, tau=args.tau, hard=args.hard)[:, :, 0]
        s = s.unsqueeze(-1)  # [bs, n, 1]; 1 = node selected

        violations = _cover_violations(s, A)  # [bs]
        cost = torch.sum(s)
        loss = cost + args.eta * violations.sum()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # optimizer.step() updates x but not the already-drawn sample s,
            # so the per-sample penalties computed above are still valid here.
            # NOTE(review): the exact == 0 test is presumably reliable only
            # when s is a hard 0/1 sample; confirm for soft sampling.
            feasible = violations == 0
            if feasible.any():
                sizes = torch.sum(s, dim=1).squeeze(-1)[feasible]  # [k]
                cost_arr.append(torch.min(sizes.cpu()))

            if it % 100 == 0:
                print(it)
                if cost_arr:
                    print('# {}, cost: {}'.format(it, np.sort(cost_arr)[0:8]))
                else:
                    print('Failed!')

    return cost_arr

In [15]:
def _str2bool(value):
    """Convert a command-line token such as 'true'/'false' or 'yes'/'no' to a bool.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True, so every
    non-empty value would switch the flag on.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(value))


def main(argv=None):
    """Parse the experiment settings, load the graph and run the ensemble.

    Each of the ``args.ensemble`` runs calls ``train`` from scratch and prints
    the smallest vertex-cover size it found, or ``Failed!`` if no sampled
    solution ever satisfied the cover constraint.

    Args:
        argv: optional list of command-line tokens. ``None`` (the default)
            parses an empty list, so every option keeps its default. This keeps
            the call safe inside a notebook, whose ``sys.argv`` belongs to the
            kernel.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='Solving MVC problems (with fixed tau in GS, parallel version)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size (default: %(default)s)')
    parser.add_argument('--data', type=str, default='citeseer',
                        help='data name (default: %(default)s)')
    parser.add_argument('--tau', type=float, default=1.,
                        help='tau value in Gumbel-softmax (default: %(default)s)')
    # type=bool would treat '--hard False' as True; parse the text explicitly.
    parser.add_argument('--hard', type=_str2bool, default=True,
                        help='hard sampling in Gumbel-softmax (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--eta', type=float, default=2.5,
                        help='constraint (default: %(default)s)')
    parser.add_argument('--ensemble', type=int, default=10,
                        help='# experiments (default: %(default)s)')
    parser.add_argument('--iterations', type=int, default=10000,
                        help='# iterations in gradient descent (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: %(default)s)')
    args = parser.parse_args(args=[] if argv is None else argv)

    # --seed used to be parsed but never applied; seed every RNG in use so
    # runs are reproducible (ensemble members still differ because the RNG
    # state advances between calls to train).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    # loading data
    G = load_data('./data/ind.' + args.data + '.graph')

    for i in range(args.ensemble):
        cost = train(args, G)
        if cost:
            print('# {}, cost: {}'.format(i, min(cost)))
        else:
            print('Failed!')


# A notebook kernel executes cells with __name__ == '__main__', so running this
# cell starts the experiment; importing the code as a module does not.
if __name__ == "__main__":
    main()

cuda
0
Failed!
100
Failed!
200
Failed!
300
Failed!
400
Failed!
500
Failed!
600
Failed!
700
Failed!
800
Failed!
900
Failed!
1000
Failed!
1100
Failed!
1200
Failed!
1300
Failed!
1400
# 1400, cost: [1651.]
1500
# 1500, cost: [1651.]
1600
# 1600, cost: [1651.]
1700
# 1700, cost: [1606. 1651.]
1800
# 1800, cost: [1606. 1651.]
1900
# 1900, cost: [1606. 1651.]
2000
# 2000, cost: [1606. 1651.]
2100
# 2100, cost: [1582. 1606. 1651.]
2200
# 2200, cost: [1582. 1589. 1606. 1651.]
2300
# 2300, cost: [1582. 1589. 1606. 1651.]
2400
# 2400, cost: [1580. 1580. 1582. 1589. 1591. 1593. 1606. 1651.]
2500
# 2500, cost: [1575. 1580. 1580. 1581. 1582. 1589. 1591. 1593.]
2600
# 2600, cost: [1575. 1580. 1580. 1581. 1582. 1589. 1591. 1593.]
2700
# 2700, cost: [1570. 1573. 1575. 1579. 1580. 1580. 1581. 1582.]
2800
# 2800, cost: [1570. 1573. 1575. 1579. 1580. 1580. 1581. 1582.]
2900
# 2900, cost: [1566. 1569. 1570. 1573. 1575. 1578. 1579. 1580.]
3000
# 3000, cost: [1565. 1566. 1569. 1570. 1573. 1574. 1575. 1576.]


5400
# 5400, cost: [1561. 1567. 1568. 1568. 1568. 1568. 1570. 1570.]
5500
# 5500, cost: [1561. 1567. 1568. 1568. 1568. 1568. 1570. 1570.]
5600
# 5600, cost: [1561. 1567. 1568. 1568. 1568. 1568. 1570. 1570.]
5700
# 5700, cost: [1561. 1567. 1568. 1568. 1568. 1568. 1570. 1570.]
5800
# 5800, cost: [1561. 1567. 1568. 1568. 1568. 1568. 1570. 1570.]
5900
# 5900, cost: [1558. 1561. 1567. 1568. 1568. 1568. 1568. 1568.]
6000
# 6000, cost: [1558. 1561. 1563. 1567. 1568. 1568. 1568. 1568.]
6100
# 6100, cost: [1558. 1561. 1563. 1564. 1567. 1568. 1568. 1568.]
6200
# 6200, cost: [1558. 1561. 1563. 1564. 1567. 1568. 1568. 1568.]
6300
# 6300, cost: [1558. 1561. 1561. 1563. 1564. 1567. 1568. 1568.]
6400
# 6400, cost: [1558. 1561. 1561. 1563. 1564. 1567. 1568. 1568.]
6500
# 6500, cost: [1558. 1561. 1561. 1563. 1564. 1567. 1568. 1568.]
6600
# 6600, cost: [1558. 1561. 1561. 1563. 1564. 1567. 1568. 1568.]
6700
# 6700, cost: [1558. 1561. 1561. 1563. 1564. 1567. 1568. 1568.]
6800
# 6800, cost: [1558. 1561. 15

KeyboardInterrupt: 