# Import modules

In [None]:
import train
import environment as environment
import neuralnet
#from neuralnet import convolutional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from util import flatten_state, flatten_state_not_first_board

# Set training and network parameters

In [None]:
# Training-loop settings; forwarded as keyword arguments to
# train.PolicyTraining when the trainer is built.
# (A rollout-length cap such as env.spec.timestep_limit is not configured.)
trainingKwargs = dict(
    num_episodes=100,      # number of training episodes
    discount_factor=0.9,   # reward discount (gamma); 1.0 means no discounting
    val_freq=25,           # validation frequency
)

# Constructor arguments for PolicyNet (expanded with PolicyNet(**netKwargs)).
netKwargs = dict(
    n_inputs=614,            # MLP input width: conv features + other state features
    n_hidden=500,            # width of every hidden MLP layer
    n_outputs=6,             # action count (env.action_space.n); per the original
                             # note the training class updates this — confirm
    learning_rate=0.001,     # Adam learning rate
    batch_norm=False,        # BatchNorm2d after each conv layer when True
    conv1_in_channels=1,     # channels of the board input
    conv1_out_channels=3,
    conv2_out_channels=3,
    conv3_out_channels=3,
    kernel_size=5,           # square conv kernel size
)


In [None]:
class PolicyNet(nn.Module):
    """Policy network: a three-layer conv stack over the board plus an MLP head.

    Each observation's board is passed through three same-padded
    convolutions. The flattened feature maps are concatenated with the
    non-board state features produced by ``flatten_state_not_first_board``
    and fed to a fully connected network whose softmax output is the action
    distribution.

    Args:
        n_inputs: Width of the MLP input. It must equal the conv features
            (``board_size * board_size * conv3_out_channels``) plus the width
            of the non-board features concatenated in ``forward``.
        n_hidden: Width of every hidden MLP layer.
        n_outputs: Number of actions (size of the softmax output).
        learning_rate: Learning rate of the Adam optimizer stored on
            ``self.optimizer``.
        batch_norm: If True, apply ``BatchNorm2d`` after each convolution;
            otherwise those steps are identities.
        conv1_in_channels: Channels of the board input. ``forward`` builds a
            single-channel batch, so this should be 1.
        conv1_out_channels: Output channels of the first convolution.
        conv2_out_channels: Output channels of the second convolution.
        conv3_out_channels: Output channels of the third convolution.
        kernel_size: Square kernel size. Should be odd so that padding
            ``kernel_size // 2`` preserves the board's spatial size.
        board_size: Side length of the square board. Defaults to 11.
    """

    def __init__(self, n_inputs, n_hidden, n_outputs, learning_rate, batch_norm,
                 conv1_in_channels, conv1_out_channels, conv2_out_channels,
                 conv3_out_channels, kernel_size, board_size=11):
        super(PolicyNet, self).__init__()
        # Not read anywhere in this class; kept in case external code uses it.
        self.other_shape = [3]

        # "Same" padding keeps every feature map at board_size x board_size,
        # which convolution_out_size below relies on (2 for kernel_size=5).
        padding = kernel_size // 2

        # Conv2d input layout is (batch_size, channels, height, width).
        self.conv1 = nn.Conv2d(in_channels=conv1_in_channels, out_channels=conv1_out_channels,
                               kernel_size=kernel_size, stride=1, padding=padding)

        self.conv2 = nn.Conv2d(in_channels=conv1_out_channels, out_channels=conv2_out_channels,
                               kernel_size=kernel_size, stride=1, padding=padding)

        self.conv3 = nn.Conv2d(in_channels=conv2_out_channels, out_channels=conv3_out_channels,
                               kernel_size=kernel_size, stride=1, padding=padding)

        # Flattened size of conv3's output for one board. It is derived from
        # the layer settings instead of the previous hard-coded 11*11*3.
        self.convolution_out_size = board_size * board_size * conv3_out_channels

        self.ffn_input_size = n_inputs

        self.ffn = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(n_hidden, n_outputs),
        )

        self.activation = F.relu

        if batch_norm:
            # BatchNorm2d normalizes per channel, so num_features must be each
            # conv's output channel count (not the board size), and each conv
            # needs its own layer with its own running statistics.
            self.bn1 = nn.BatchNorm2d(conv1_out_channels)
            self.bn2 = nn.BatchNorm2d(conv2_out_channels)
            self.bn3 = nn.BatchNorm2d(conv3_out_channels)
        else:
            # Identity modules (parameter-free) instead of lambdas, so the
            # attributes are proper submodules.
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()
            self.bn3 = nn.Identity()

        self.ffn.apply(self.init_weights)

        # Created last so the optimizer sees every registered parameter,
        # including the batch-norm affine parameters.
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return action probabilities for a sequence of observations.

        Args:
            x: Non-empty sequence of observation dicts. Each must hold a
                2-D ``'board'`` entry accepted by ``torch.tensor``. The whole
                sequence is also passed to ``flatten_state_not_first_board``
                for the non-board features.

        Returns:
            Tensor of shape ``(len(x), n_outputs)``; each row is a softmax
            distribution over actions.
        """
        # One stack instead of repeated torch.cat: gives a (N, 1, H, W) float batch.
        boards = torch.stack([torch.tensor(obs['board']).float() for obs in x]).unsqueeze(1)

        features = boards
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            features = self.activation(bn(conv(features)))

        # (N, C, H, W) -> (N, C*H*W); the width equals self.convolution_out_size.
        conv_features = features.flatten(start_dim=1)

        other_features = flatten_state_not_first_board(x)
        combined = torch.cat([conv_features, other_features], dim=1)

        return F.softmax(self.ffn(combined), dim=1)

    def loss(self, action_probabilities, returns):
        """Return ``-mean(log(action_probabilities) * returns)``.

        The product is elementwise, so both tensors must be broadcast-compatible.
        NOTE(review): a probability that underflows to 0 gives ``log = -inf``;
        consider clamping if NaN losses show up.
        """
        return -torch.mean(torch.mul(torch.log(action_probabilities), returns))

    @staticmethod
    def init_weights(m, *args):
        """Xavier-initialize an ``nn.Linear``'s weight and set its bias to 0.01.

        Meant for ``Module.apply``, which calls it once per submodule. It is
        a staticmethod so that ``m`` is that submodule. As a bound method,
        ``m`` would be the network itself and no layer would be initialized.
        Modules of other types are left unchanged.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)

# Instantiate environment, network, and trainer

In [None]:
# Build the policy network, the environment, and the trainer that ties them together.
neuralNet = PolicyNet(**netKwargs)
# randomEnv: per the original note all agents are of the same type; stopEnv and
# simpleEnv are named as alternatives.
env = environment.randomEnv()
# Mixed agent types via manualEnv, one code per agent
# (si = simple, st = stop, ra = random), e.g.:
#env = environment.manualEnv(["si", "ra", "ra"])
# NOTE: despite the name, valueTrainer is a train.PolicyTraining instance.
valueTrainer = train.PolicyTraining(env, neuralNet, **trainingKwargs)

# Start training

In [None]:
# Run training with the settings the trainer was constructed with (trainingKwargs).
valueTrainer.train()

# Visualize 