In [523]:
#@title ##### License { display-mode: "form" }
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# OpenSpiel

* This Colab gets you started with installing OpenSpiel and its dependencies.
* OpenSpiel is a framework for reinforcement learning in games.
* The instructions are adapted from [here](https://github.com/deepmind/open_spiel/blob/master/docs/install.md).

## Install

Install OpenSpiel via pip:


In [524]:
!rm -r figure
!rm -r 0
!mkdir figure
!mkdir 0

In [525]:
! pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html 
! pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html 
! pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html 
! pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html 
! pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html 

Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu102.html


# Environment

In [526]:
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Reinforcement Learning (RL) Environment for Open Spiel.

This module wraps Open Spiel Python interface providing an RL-friendly API. It
covers both turn-based and simultaneous move games. Interactions between agents
and the underlying game occur mostly through the `reset` and `step` methods,
which return a `TimeStep` structure (see its docstrings for more info).

The following example illustrates the interaction dynamics. Consider a 2-player
Kuhn Poker (turn-based game). Agents have access to the `observations` (a dict)
field from `TimeSpec`, containing the following members:
 * `info_state`: list containing the game information state for each player. The
   size of the list always correspond to the number of players. E.g.:
   [[0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0]].
 * `legal_actions`: list containing legal action ID lists (one for each player).
   E.g.: [[0, 1], [0]], which corresponds to actions 0 and 1 being valid for
   player 0 (the 1st player) and action 0 being valid for player 1 (2nd player).
 * `current_player`: zero-based integer representing the player to make a move.

At each `step` call, the environment expects a singleton list with the action
(as it's a turn-based game), e.g.: [1]. This (zero-based) action must correspond
to the player specified at `current_player`. The game (which is at decision
node) will process the action and take as many steps necessary to cover chance
nodes, halting at a new decision or final node. Finally, a new `TimeStep`is
returned to the agent.

Simultaneous-move games follow analogous dynamics. The only differences is the
environment expects a list of actions, one per player. Note the `current_player`
field is "irrelevant" here, admitting a constant value defined in spiel.h, which
defaults to -2 (module level constant `SIMULTANEOUS_PLAYER_ID`).

See open_spiel/python/examples/rl_example.py for example usages.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import enum
from absl import logging
import numpy as np

import pyspiel

SIMULTANEOUS_PLAYER_ID = pyspiel.PlayerId.SIMULTANEOUS

class TimeStep(
    collections.namedtuple(
        "TimeStep", ["observations", "rewards", "discounts", "step_type"])):
  """Returned with every call to `step` and `reset`.

  A `TimeStep` contains the data emitted by a game at each step of interaction.
  A `TimeStep` holds an `observation` (list of dicts, one per player),
  associated lists of `rewards`, `discounts` and a `step_type`.

  The first `TimeStep` in a sequence will have `StepType.FIRST`. The final
  `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence will
  have `StepType.MID.

  Attributes:
    observations: a list of dicts containing observations per player.
    rewards: A list of scalars (one per player), or `None` if `step_type` is
      `StepType.FIRST`, i.e. at the start of a sequence.
    discounts: A list of discount values in the range `[0, 1]` (one per player),
      or `None` if `step_type` is `StepType.FIRST`.
    step_type: A `StepType` enum value.
  """
  __slots__ = ()

  def first(self):
    return self.step_type == StepType.FIRST

  def mid(self):
    return self.step_type == StepType.MID

  def last(self):
    return self.step_type == StepType.LAST

  def is_simultaneous_move(self):
    return self.observations["current_player"] == SIMULTANEOUS_PLAYER_ID

  def current_player(self):
    return self.observations["current_player"]


class StepType(enum.Enum):
  """Defines the status of a `TimeStep` within a sequence."""

  FIRST = 0  # Denotes the first `TimeStep` in a sequence.
  MID = 1  # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
  LAST = 2  # Denotes the last `TimeStep` in a sequence.

  def first(self):
    return self is StepType.FIRST

  def mid(self):
    return self is StepType.MID

  def last(self):
    return self is StepType.LAST


# Global pyspiel members
def registered_games():
  return pyspiel.registered_games()



class ChanceEventSampler(object):
  """Default sampler for external chance events."""

  def __init__(self, seed=None):
    self.seed(seed)

  def seed(self, seed=None):
    self._rng = np.random.RandomState(seed)

  def __call__(self, state):
    """Sample a chance event in the given state."""
    actions, probs = zip(*state.chance_outcomes())
    return self._rng.choice(actions, p=probs)


class ObservationType(enum.Enum):
  """Defines what kind of observation to use."""
  OBSERVATION = 0  # Use observation_tensor
  INFORMATION_STATE = 1  # Use information_state_tensor


class Environment(object):
  """Open Spiel reinforcement learning environment class."""

  def __init__(self,
               game,
               discount=1.0,
               chance_event_sampler=None,
               observation_type=None,
               include_full_state=False,
               distribution=None,
               mfg_population=None,
               enable_legality_check=False,
               **kwargs):
    """Constructor.

    Args:
      game: [string, pyspiel.Game] Open Spiel game name or game instance.
      discount: float, discount used in non-initial steps. Defaults to 1.0.
      chance_event_sampler: optional object with `sample_external_events` method
        to sample chance events.
      observation_type: what kind of observation to use. If not specified, will
        default to INFORMATION_STATE unless the game doesn't provide it.
      include_full_state: whether or not to include the full serialized
        OpenSpiel state in the observations (sometimes useful for debugging).
      distribution: the distribution over states if the game is a mean field
        game.
      mfg_population: The Mean Field Game population to consider.
      enable_legality_check: Check the legality of the move before stepping.
      **kwargs: dict, additional settings passed to the Open Spiel game.
    """
    self._chance_event_sampler = chance_event_sampler or ChanceEventSampler()
    self._include_full_state = include_full_state
    self._distribution = distribution
    self._mfg_population = mfg_population
    self._enable_legality_check = enable_legality_check

    if isinstance(game, str):
      if kwargs:
        game_settings = {key: val for (key, val) in kwargs.items()}
        logging.info("Using game settings: %s", game_settings)
        self._game = pyspiel.load_game(game, game_settings)
      else:
        logging.info("Using game string: %s", game)
        self._game = pyspiel.load_game(game)
    else:  # pyspiel.Game or API-compatible object.
      logging.info("Using game instance: %s", game.get_type().short_name)
      self._game = game

    self._num_players = self._game.num_players()
    self._state = None
    self._should_reset = True

    # Discount returned at non-initial  steps.
    self._discounts = [discount] * self._num_players

    # Determine what observation type to use.
    if observation_type is None:
      if self._game.get_type().provides_information_state_tensor:
        observation_type = ObservationType.INFORMATION_STATE
      else:
        observation_type = ObservationType.OBSERVATION

    # Check the requested observation type is supported.
    if observation_type == ObservationType.OBSERVATION:
      if not self._game.get_type().provides_observation_tensor:
        raise ValueError(f"observation_tensor not supported by {game}")
    elif observation_type == ObservationType.INFORMATION_STATE:
      if not self._game.get_type().provides_information_state_tensor:
        raise ValueError(f"information_state_tensor not supported by {game}")
    self._use_observation = (observation_type == ObservationType.OBSERVATION)

    if self._game.get_type().dynamics == pyspiel.GameType.Dynamics.MEAN_FIELD:
      assert distribution is not None
      assert mfg_population is not None
      assert 0 <= mfg_population < self._num_players

  def seed(self, seed=None):
    self._chance_event_sampler.seed(seed)

  def get_time_step(self):
    """Returns a `TimeStep` without updating the environment.

    Returns:
      A `TimeStep` namedtuple containing:
        observation: list of dicts containing one observations per player, each
          corresponding to `observation_spec()`.
        reward: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discount: list of discounts in the range [0, 1], or None if step_type is
          `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    observations = {
        "info_state": [],
        "edge_index":[],
        "legal_actions": [],
        "current_player": [],
        "serialized_state": []
    }
    rewards = []
    step_type = StepType.LAST if self._state.is_terminal() else StepType.MID
    self._should_reset = step_type == StepType.LAST
    cur_rewards = self._state.rewards()
    player_id = 0
    rewards.append(cur_rewards[player_id])
    observations["info_state"].append(
          self._state.observation_tensor(player_id) if self._use_observation
          else self._state.information_state_tensor(player_id))
    observations["legal_actions"].append(self._state.legal_actions(player_id))
    observations["edge_index"]=self._state.edge_index
    observations["current_player"] = self._state.current_player()
    discounts = self._discounts
    if step_type == StepType.LAST:
      # When the game is in a terminal state set the discount to 0.
      discounts = [0. for _ in discounts]

    if self._include_full_state:
      observations["serialized_state"] = pyspiel.serialize_game_and_state(
          self._game, self._state)

    return TimeStep(
        observations=observations,
        rewards=rewards,
        discounts=discounts,
        step_type=step_type)

  def _check_legality(self, actions):
    if self.is_turn_based:
      legal_actions = self._state.legal_actions()
      if actions[0] not in legal_actions:
        raise RuntimeError(f"step() called on illegal action {actions[0]}")
    else:
      for p in range(len(actions)):
        legal_actions = self._state.legal_actions(p)
        if legal_actions and actions[p] not in legal_actions:
          raise RuntimeError(f"step() by player {p} called on illegal " +
                             f"action: {actions[p]}")

  def step(self, actions):
    """Updates the environment according to `actions` and returns a `TimeStep`.

    If the environment returned a `TimeStep` with `StepType.LAST` at the
    previous step, this call to `step` will start a new sequence and `actions`
    will be ignored.

    This method will also start a new sequence if called after the environment
    has been constructed and `reset` has not been called. Again, in this case
    `actions` will be ignored.

    Args:
      actions: a list containing one action per player, following specifications
        defined in `action_spec()`.

    Returns:
      A `TimeStep` namedtuple containing:
        observation: list of dicts containing one observations per player, each
          corresponding to `observation_spec()`.
        reward: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discount: list of discounts in the range [0, 1], or None if step_type is
          `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    assert len(actions) == self.num_actions_per_step, (
        "Invalid number of actions! Expected {}".format(
            self.num_actions_per_step))
    if self._should_reset:
      return self.reset()

    if self._enable_legality_check:
      self._check_legality(actions)

    if self.is_turn_based:
      self._state.apply_action(actions[0])
    else:
      self._state.apply_actions(actions)


    return self.get_time_step()

  def reset(self):
    """Starts a new sequence and returns the first `TimeStep` of this sequence.

    Returns:
      A `TimeStep` namedtuple containing:
        observations: list of dicts containing one observations per player, each
          corresponding to `observation_spec()`.
        rewards: list of rewards at this timestep, or None if step_type is
          `StepType.FIRST`.
        discounts: list of discounts in the range [0, 1], or None if step_type
          is `StepType.FIRST`.
        step_type: A `StepType` value.
    """
    self._should_reset = False
    if self._game.get_type(
    ).dynamics == pyspiel.GameType.Dynamics.MEAN_FIELD and self._num_players > 1:
      self._state = self._game.new_initial_state_for_population(
          self._mfg_population)
    else:
      self._state = self._game.new_initial_state()

    observations = {
        "info_state": [],
        "edge_index":[],
        "legal_actions": [],
        "current_player": [],
        "serialized_state": []
    }
    player_id = 0
    observations["info_state"].append(
        self._state.observation_tensor(player_id) if self._use_observation
        else self._state.information_state_tensor(player_id))
    observations["legal_actions"].append(self._state.legal_actions(player_id))
    observations["edge_index"]=self._state.edge_index
    observations["current_player"] = self._state.current_player()


    if self._include_full_state:
      observations["serialized_state"] = pyspiel.serialize_game_and_state(
          self._game, self._state)

    return TimeStep(
        observations=observations,
        rewards=None,
        discounts=None,
        step_type=StepType.FIRST)


  def observation_spec(self):
    """Defines the observation per player provided by the environment.

    Each dict member will contain its expected structure and shape. E.g.: for
    Kuhn Poker {"info_state": (6,), "legal_actions": (2,), "current_player": (),
                "serialized_state": ()}

    Returns:
      A specification dict describing the observation fields and shapes.
    """
    return dict(
        info_state=tuple([
            self._game.observation_tensor if self._use_observation else
            self._game.information_state_tensor_size
        ]),
        edge_index=(),
        legal_actions=(self._game.num_distinct_actions(),),
        current_player=(),
        serialized_state=(),
    )

  def action_spec(self):
    """Defines per player action specifications.

    Specifications include action boundaries and their data type.
    E.g.: for Kuhn Poker {"num_actions": 2, "min": 0, "max":1, "dtype": int}

    Returns:
      A specification dict containing per player action properties.
    """
    return dict(
        num_actions=self._game.num_distinct_actions(),
        min=0,
        max=self._game.num_distinct_actions() - 1,
        dtype=int,
    )

  # Environment properties
  @property
  def use_observation(self):
    """Returns whether the environment is using the game's observation.

    If false, it is using the game's information state.
    """
    return self._use_observation

  # Game properties
  @property
  def name(self):
    return self._game.get_type().short_name

  @property
  def num_players(self):
    return self._game.num_players()

  @property
  def num_actions_per_step(self):
    return 10 #set as number of actions

  # New RL calls for more advanced use cases (e.g. search + RL).
  @property
  def is_turn_based(self):
    return ((self._game.get_type().dynamics
             == pyspiel.GameType.Dynamics.SEQUENTIAL) or
            (self._game.get_type().dynamics
             == pyspiel.GameType.Dynamics.MEAN_FIELD))

  @property
  def max_game_length(self):
    return self._game.max_game_length()

  @property
  def is_chance_node(self):
    return self._state.is_chance_node()

  @property
  def game(self):
    return self._game

  def set_state(self, new_state):
    """Updates the game state."""
    assert new_state.get_game() == self.game, (
        "State must have been created by the same game.")
    self._state = new_state

  @property
  def get_state(self):
    return self._state

# Game

In [532]:
def gen_graph(cur_n, g_type):
    if g_type == 'erdos_renyi':
        g = nx.erdos_renyi_graph(n=cur_n, p=0.15)
    elif g_type == 'powerlaw':
        g = nx.powerlaw_cluster_graph(n=cur_n, m=4, p=0.05)
    elif g_type == 'small-world':
        g = nx.connected_watts_strogatz_graph(n=cur_n, k=8, p=0.1)
    elif g_type == 'barabasi_albert':
        g = nx.barabasi_albert_graph(n=cur_n, m=4)
    return g

def gen_new_graphs(number_nodes):
    print('\ngenerating new training graphs...')
    graph_type = ['erdos_renyi', 'powerlaw','small-world', 'barabasi_albert']
    a = np.random.choice(graph_type)
    graph = gen_graph(number_nodes, a)
    #graph = gen_graph(number_nodes, 'small-world')
    active = 1
    nx.set_node_attributes(graph,active, "active")
    index = {n: {"index":n} for n in np.arange(len(graph))}
    nx.set_node_attributes(graph,index)
    return graph    

In [527]:
def gen_new_graphs(number_nodes):
    print('\ngenerating new training graphs...')
    graph_type = ['erdos_renyi', 'powerlaw','small-world', 'barabasi_albert']
    a = np.random.choice(graph_type)
    graph = gen_graph(number_nodes, a)
    #graph = gen_graph(number_nodes, 'small-world')
    active = 1
    nx.set_node_attributes(graph,active, "active")
    index = {n: {"index":n} for n in np.arange(len(graph))}
    nx.set_node_attributes(graph,index)
    return graph   
nx.write_edgelist(gen_new_graphs(50), "./real/test")


generating new training graphs...


In [528]:
#Re-register the Environment 

#GRAPH = nx.read_edgelist("./real/corruption.txt")
GRAPH = nx.read_edgelist("./real/test")
_NUM_PLAYERS = 2
_NUM_CELLS = len(GRAPH)  
print(_NUM_CELLS)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import random
import copy
import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric import utils
from open_spiel.python.observation import IIGObserverForPublicInfoGame
import pyspiel
from operator import itemgetter



def reset(graph):
    active = int(1)
    nx.set_node_attributes(graph,active, "active")
    index = {int(n): {"index":int(n)} for n in np.arange(len(graph))}
    nx.set_node_attributes(graph,index)
    return graph   


class GraphGame(pyspiel.Game):
  """A Python version of the Graph game."""

  def __init__(self, params=None):
    super().__init__(_GAME_TYPE, _GAME_INFO, params or dict())

  def new_initial_state(self):
    """Returns a state corresponding to the start of a game."""
    return GraphState(self)

  def make_py_observer(self, iig_obs_type=None, params=None):
    """Returns an object used for observing game state."""
    '''
    if ((iig_obs_type is None) or
        (iig_obs_type.public_info and not iig_obs_type.perfect_recall)):
      return BoardObserver(params)
    else:
      return IIGObserverForPublicInfoGame(iig_obs_type, params)
    '''
    return BoardObserver(params)


class GraphState(pyspiel.State):
  """A python version of the Tic-Tac-Toe state."""
  def __init__(self, game):
    """Constructor; should only be called by Game.new_initial_state."""
    super().__init__(game)
    self._is_terminal = False
    #self.board = reset(nx.read_edgelist("./real/corruption.txt"))
    self.board = reset(GRAPH)
    self.num_nodes = len(self.board)
    self.num_feature = 3
    self.edge_index = utils.convert.from_networkx(self.board).edge_index
    self._rewards = np.zeros(_NUM_PLAYERS)
    self._returns = np.zeros(_NUM_PLAYERS)
    self.lcc = [1]
    self.r = []
    self.timestep = 1

  # OpenSpiel (PySpiel) API functions are below. This is the standard set that
  # should be implemented by every perfect-information sequential-move game.

  def current_player(self):
    """Returns id of the next player to move, or TERMINAL if game is over."""
    #return pyspiel.PlayerId.TERMINAL if self._is_terminal else pyspiel.PlayerId.SIMULTANEOUS
    return pyspiel.PlayerId.TERMINAL if self._is_terminal else pyspiel.PlayerId.SIMULTANEOUS
  

  def _legal_actions(self, player):
    """Returns a list of legal actions, sorted in ascending order."""
    all_nodes = np.array(list(self.board.nodes(data="active")),dtype=int)[:,1]
    action_sequence = np.where(all_nodes == 1)[0]
    return action_sequence

  def _apply_actions(self, actions):
    """Applies the specified action to the state."""
    self.r.append(self._rewards[0])
    actions = np.array(actions,dtype=np.dtype('<U21'))
    for action in actions:
      self.board.nodes[action]["active"] = '0'
    _, cond, l = _network_dismantle(self.board)
    #self.data = features(subgraph)
    decrease_lcc = (self.lcc[-1] - l)
    self._rewards[0] = decrease_lcc*10
    self._rewards[1] = -self._rewards[0]
    self._returns += self._rewards
    self.lcc.append(l)
    if self.timestep==5:
        #cond = True 
        None
    else: 
        self.timestep+=1
    self._is_terminal = cond


  def _action_to_string(self, player, action):
    """Action -> string."""
    return "{}({})".format(0 if player == 0 else 1, action)

  def is_terminal(self):
    """Returns True if the game is over."""
    return self._is_terminal

  def returns(self):
    """Total reward for each player over the course of the game so far."""
    return self._returns
  def rewards(self):
    """Total reward for each player over the course of the game so far."""
    return self._rewards

  def __str__(self):
    """String for debug purposes. No particular semantics are required."""
    return _board_to_string(self.board)

  def new_initial_state(self):
    self.edge_index = utils.convert.from_networkx(self.board).edge_index
    self.board = reset(self.board)

class BoardObserver:
  """Observer, conforming to the PyObserver interface (see observation.py)."""

  def __init__(self,params):
    """Initializes an empty observation tensor."""
    if params:
      raise ValueError(f"Observation parameters not supported; passed {params}")
    # The observation should contain a 1-D tensor in `self.t ensor` and a
    # dictionary of views onto the tensor, which may be of any shape.
    # Here the observation is indexed `(cell state, row, column)
    _NUM_FEATURES = 3
    shape = ( _NUM_CELLS,_NUM_FEATURES)
    self.tensor = np.zeros(np.prod(shape), np.float32)
    self.dict = {"observation":np.reshape(self.tensor, shape)}
    #self.dict = {"observation":self.tensor}


  def set_from(self, state, player):
    """Updates `tensor` and `dict` to reflect `state` from PoV of `player`."""
    # We update the observation via the shaped tensor since indexing is more
    # convenient than with the 1-D tensor. Both are views onto the same memory.
    obs = self.dict["observation"]
    obs.fill(0)
    all_nodes = np.array(list(state.board.nodes(data="active")))[:,1]
    attack_nodes = np.array(np.where(all_nodes == '1')[0],dtype=str)
    subGraph = state.board.subgraph(np.array(attack_nodes,dtype=np.dtype('<U21')))
    data = np.array(features(subGraph))
    for i, x in enumerate(subGraph.nodes()):
        (i)
        obs[i,:] = data[i,:]
    self.tensor =obs.flatten()
    return self.tensor

  def string_from(self, state, player):
    """Observation of `state` from the PoV of `player`, as a string."""
    return _board_to_string(state.board)


# Helper functions for game details.
def get_index(g,index):
    node_location = nx.get_node_attributes(g, "index")
    return [node_location.get(key) for key in index]

def features(g):
    degree_centrality = list(nx.degree_centrality(g).values())
    #precolation_centrality = list(nx.percolation_centrality(g).values())
    #closeness_centrality = list(nx.closeness_centrality(g).values())
    #eigen_centrality = list(nx.eigenvector_centrality(g,tol=1e-04).values())
    clustering_coeff = list(nx.clustering(g).values())
    core_num = list(nx.core_number(g).values())
    #pagerank = list(nx.pagerank(g).values())
    #x = np.column_stack((degree_centrality,clustering_coeff,pagerank, core_num ))
    x = np.column_stack((degree_centrality,clustering_coeff, core_num ))
    return x

def _network_dismantle(board):
    """Checks if a line exists, returns "x" or "o" if so, and None otherwise."""
    all_nodes = np.array(list(board.nodes(data="active")))[:,1]
    alive_nodes = np.array(np.where(all_nodes == '1')[0],dtype=str)
    subGraph = board.subgraph(np.array(alive_nodes,dtype=np.dtype('<U21')))
    largest_cc = len(max(nx.connected_components(subGraph), key=len))/len(board)
    cond = True if largest_cc <= 0.1 else False
    cond = True if len(alive_nodes) <= 10 or len(alive_nodes) <= 10 else False
    return subGraph, cond, largest_cc


def _board_to_string(board):
    """Returns a string representation of the board."""
    value = np.array(list(board.nodes(data="active")))
    return " ".join(str(board.nodes[e]["index"])+"{"+str(f)+"}" for e, f in value)
_GAME_TYPE = pyspiel.GameType(
    short_name="graph_attack_defend",
    long_name="Python Attack Defend",
    dynamics=pyspiel.GameType.Dynamics.SIMULTANEOUS,
    #dynamics=pyspiel.GameType.Dynamics.SEQUENTIAL,
    chance_mode=pyspiel.GameType.ChanceMode.EXPLICIT_STOCHASTIC,
    information=pyspiel.GameType.Information.IMPERFECT_INFORMATION,
    utility=pyspiel.GameType.Utility.ZERO_SUM,
    reward_model=pyspiel.GameType.RewardModel.REWARDS,
    max_num_players=1,
    min_num_players=1,
    provides_information_state_string=True,
    provides_information_state_tensor=True,
    provides_observation_string=False,
    provides_observation_tensor=False,
    provides_factored_observation_string=True)

_GAME_INFO = pyspiel.GameInfo(
    num_distinct_actions=10,
    max_chance_outcomes=0,
    num_players=1,
    min_utility=-1.0,
    max_utility=1.0,
    utility_sum=0.0,
    max_game_length=_NUM_CELLS)
pyspiel.register_game(_GAME_TYPE, GraphGame)

50


# Evaluate

In [529]:
# Got it from DQN class to get the probabilities for the best action
from open_spiel.python import rl_agent
def step(agent,time_step):
    """Returns the action to be taken and updates the Q-network if needed.
    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.
      add_transition_record: Whether to add to the replay buffer on this step.
    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """

    # Act step: don't act at terminal info states or if its not our turn.
    index = torch.Tensor(np.arange(10))
    if (not time_step.last()) and (
        time_step.is_simultaneous_move() or
        agent.player_id == time_step.current_player()):
      info_state = time_step.observations["info_state"][agent.player_id]
      legal_actions = np.array(time_step.observations["legal_actions"][agent.player_id])
      edge_index = time_step.observations["edge_index"]
      probs = np.zeros(agent._num_actions)
      #get values in terms of (Data infostate, Edgelist)
      shape = (int(len(info_state)/agent.num_feature),agent.num_feature)
      x = np.reshape(info_state, shape)
      data = torch.from_numpy(x.astype(np.float32))
      print(agent._q_network)
      q_values = agent._q_network(data.type(torch.FloatTensor),edge_index.type(torch.LongTensor)).detach()
      legal_q_values = q_values[legal_actions]
      sorted = torch.sort(legal_q_values,descending=True,stable=True)
      print(q_values)
      if torch.sum(torch.eq(sorted[0],torch.ones(agent._num_actions)))!= (agent._num_actions-10):
         action = np.random.randint(agent._num_actions,size=10)
         #action = torch.index_select(legal_actions,1,torch.rand(10))
      else:
        action = sorted[1][:10]
      probs[action] = 1.0
    return rl_agent.StepOutput(action=action, probs=probs)

In [530]:
def large(env, attacker):
    """Evaluates `attacker` against a new graph."""
    time_step = env.reset()
    episode_rewards = []
    i = 0
    while not time_step.last():
        agents_output = step(attacker,time_step)
        action_list = agents_output.action
        time_step = env.step(action_list)
        #drawNetwork(env.get_state.board,str(i))
        i+=1
        episode_rewards.append(env.get_state._rewards[0])
    episode_rewards,lcc = env.get_state.lcc
    # build gif
  


In [531]:
#Create new environment
game = "graph_attack_defend"
num_players = 1
env = Environment(game)
info_state_size = env.observation_spec()["info_state"][0]
feature_size = 3
num_actions = 20 #env.action_spec()["num_actions"]
agent = torch.load('dqn_test0')
agent._num_actions = _NUM_CELLS
rewards, lcc, action_lists = large(env, agent)
print(rewards)
plt.plot(lcc)
plt.title("(Real World)LCC/n vs iteration")
plt.savefig("(Real World)LCC/n vs iteration")
plt.show()
plt.clf()
plt.plot(rewards)
plt.title("Rewards for Attacker")
plt.savefig("(Real World)Rewards")
plt.show()

GraphNN(
  (conv1): GATv2Conv(3, 64, heads=1)
  (conv2): GATv2Conv(64, 1, heads=1)
)


AttributeError: 'GraphNN' object has no attribute 'conv3'