## Imports

In [15]:
from collections import namedtuple, deque
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from CellFreeNetwork import CellFreeNetwork
from SettingParams import mock_params
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Replay memory class

In [16]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

## Game Class

In [17]:
params = mock_params

class ClusteringGame:
    def __init__(self, num_users: int, num_aps: int, users_per_ap: int, network: CellFreeNetwork, alg: str):
        self.num_users = num_users
        self.num_aps = num_aps
        self.users_per_ap = users_per_ap
        self.num_actions = num_users * num_aps
        self.network = network
        self.all_actions = [i for i in range(num_actions)]
        self.alg = alg
        self.betas = None

    def generate_episode(self):
        self.network.generate_snapshot()
        betas = 10 ** (self.network.channel_model.path_loss_shadowing/10)
        self.betas = betas/np.max(betas)

    def get_init_state(self):
        """
        state description: ( user0_cluster, user1_cluster, ..., userK_cluster)
        :return:
        """
        return tuple(0 for _ in range(self.num_aps * self.num_users))

    def sample_action(self):
        return random.choice(self.all_actions)

    def result(self, state, action):
        mutable_state = list(state)
        mutable_state[action] = 1
        return tuple(mutable_state)

    def get_clusters_from_state(self, state):
        clusters = np.zeros((self.num_users, self.num_aps))
        for user in range(self.num_users):
            for ap in range(self.num_aps):
                clusters[user, ap] = state[user * num_aps + ap]

    def reward(self, state):
        self.network.set_clusters(self.get_clusters_from_state(state))

        num_frames = 50
        collective_channels, _, _, _ = self.network.generate_channel_realizations(num_frames)
        combiners = self.network.simulate_uplink_centralized(self.alg, collective_channels, collective_channels)
        precoders = self.network.simulate_downlink_centralized(self.alg, collective_channels, collective_channels)
        reward = \
            (self.network.compute_uplink_SE_centralized(collective_channels, combiners) +
             self.network.compute_downlink_SE_centralized(collective_channels, precoders)
             )/self.num_actions # average sum SE per UE
        return reward

    def terminal_test(self, state):
        clusters = self.get_clusters_from_state(state)
        return all(np.sum(clusters[:,ap]) == self.users_per_ap for ap in game.num_aps)


    def string_representation(self, state):
        return ''.join(state[i][j] for i in range(self.num_aps) for j in range(self.num_users))

### Initialize environment

In [18]:
num_users = params['num_users']
num_aps = params['num_aps']
num_actions = num_users * num_aps
users_per_ap = params['pilot_len']
alg = 'MMSE'
game = ClusteringGame(num_users, num_aps, users_per_ap, CellFreeNetwork(**params), alg)

## Neural Network

In [19]:
class DQN(nn.Module):

    def __init__(self, game: ClusteringGame):
        super(DQN, self).__init__()
        input_size = (game.num_aps * game.num_users) * 2
        self.layer1 = nn.Linear(input_size, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, game.num_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

## Hyperparameters and helper functions

In [20]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

policy_net = DQN(game).to(device)
target_net = DQN(game).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

def select_action(game: ClusteringGame, state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[game.sample_action()]], device=device, dtype=torch.long)

def create_network_state(game: ClusteringGame, state):
    state_processed = [game.betas[i][j] for i in range(game.num_users) for j in range(num_aps)]
    state_processed.extend([state[i][j] for i in range(game.num_aps) for j in range(game.num_users)])
    return state_processed


episode_rewards = []


def plot_rewards(show_result=False):
    plt.figure(1)
    rewards_t = torch.tensor(episode_rewards, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.plot(rewards_t.numpy())
    # Take 100 episode averages and plot them too
    if len(rewards_t) >= 100:
        means = rewards_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

## Model Optimization

In [21]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    next_states = torch.cat([s for s in batch.next_state])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    with torch.no_grad():
        next_state_values = target_net(next_states, device=device).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

## Training Loop

In [22]:
num_episodes = 600

for i_episode in range(num_episodes):
    # Initialize the environment and get it's state
    game.generate_episode()
    state = game.get_init_state()
    state_t = torch.tensor(create_network_state(game,state), dtype=torch.float32, device=device).unsqueeze(0)
    while True:
        action = select_action(game, state)
        new_state = game.result(state, action)
        reward = game.reward(new_state) - game.reward(state)
        reward_t = torch.tensor([reward], device=device)
        terminated = game.terminal_test(new_state)
        next_state_t = torch.tensor(create_network_state(game,new_state), dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state_t, action, next_state_t, reward_t)

        # Move to the next state
        state = new_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if terminated:
            episode_rewards.append(reward)
            plot_rewards()
            break

print('Complete')
plot_rewards(show_result=True)
plt.ioff()
plt.show()

TypeError: 'int' object is not subscriptable