# Installations, Imports, and Drive Setup

In [None]:
    !pip install PyDrive

In [None]:
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
from google.colab import files

In [None]:
# Authenticate this Colab session first (presumably required so that the
# application-default credentials fetched below are available).
auth.authenticate_user()
gauth = GoogleAuth()
# Reuse the Colab user's application-default credentials for PyDrive
gauth.credentials = GoogleCredentials.get_application_default()
# PyDrive client; used below to download the games archive by file id
drive = GoogleDrive(gauth)

In [None]:
# Fetch the games archive from Google Drive by file id and write it to
# ./games.zip in the current working directory.
# NOTE(review): the unzip cell below extracts a copy from the mounted Drive
# path ("/content/gdrive/My Drive/AlphaHitler/games.zip") rather than this
# downloaded file, so this download may be redundant — confirm.
download = drive.CreateFile({'id': '1Bwd_Ma8mqCho0diPLoqPrLhsrAA77oEi'})
download.GetContentFile('games.zip')

In [None]:
# Import the Colab Drive helper under a distinct name: importing it as `drive`
# would shadow the PyDrive client bound to `drive` in the authentication cell,
# breaking any re-run of the download cell.
from google.colab import drive as colab_drive
# Mount the user's Google Drive so the archive can be unzipped from it below
colab_drive.mount('/content/gdrive')

In [None]:
!unzip "/content/gdrive/My Drive/AlphaHitler/games.zip"

# Network Training

In [None]:
import json

def populate_inputs(game_number_start, game_number_end, max_num_govs):
    """
    Populates the input nodes with game information.

    Every accepted game file ``games/<n>.json`` (7 players, not custom, not
    rebalanced 7p) is encoded once from the point of view of each of its four
    liberal seats, as data augmentation. Each government becomes one row of
    89 features, laid out as:
    31 government features, 8 investigation, 7 special election,
    7 first bullet, 7 second bullet, 1 topdeck, then 7 each of
    confirmed / outed / hitler / confirmed-not-hitler flags.

    Parameters
    ----------
    game_number_start : int
        Start game parsing at game_number_start.json (inclusive)
    game_number_end : int
        End game parsing at game_number_end.json (exclusive)
    max_num_govs : int
        Truncate elements of games greater than length max_num_govs; shorter
        games are zero-padded up to max_num_govs rows

    Returns
    -------
    games : list
        List of size num_valid_games X max_num_govs X num_input_nodes containing the input layers created for each valid game
    results : list
        List of size num_valid_games X 7 containing fascist results for each valid game (1 = fascist, 0 = liberal)
    game_numbers : list
        List of length num_valid_games containing game numbers for all valid games
    game_lengths : list
        List of length num_valid_games containing game lengths for all valid games for use by torch.nn.utils.rnn.pack_padded_sequence to keep forward() parallelizable
    outed : list
        List of size num_valid_games X 7 storing who is "outed" fascist for each valid game (1 = outed, 0 = not outed) for use in calculation of "conflict confidence" metric

    Raises
    ------
    OSError
        If a game file in the requested range cannot be opened.
    """

    games = []
    game_lengths = []
    results = []
    game_numbers = []
    outed = []
    num_input_nodes = 89

    # Treat each lib seat POV as a different game to augment data
    for lib_number in range(0, 4):
        for game_number in range(game_number_start, game_number_end):
            # Is the game valid
            game_valid = True
            # Contains all the data for this game (one row per government)
            game_data = []
            # Who is "outed" as a fascist
            game_outed = [0] * 7
            # Who is "confirmed" as a liberal
            confirmed = [0] * 7
            # Who is "outed" as hitler
            game_hitler = [0] * 7
            # Who is confirmed not hitler (cnh)
            game_cnh = [0] * 7
            # Number of played (majority vote "Ja") governments
            num_played_govs = 0
            # Roles of the players in the game (1 = fascist, 0 = liberal)
            roles = []

            file_name = "games/" + str(game_number) + ".json"

            # Number of lib and fas cards on the board
            lib_cards_played = 0
            fas_cards_played = 0

            with open(file_name) as f:
                data = json.load(f)

                # Check not custom
                if data["customGameSettings"]["enabled"]:
                    continue
                # Check if the game is 7 players
                if len(data["players"]) != 7:
                    continue
                # Check not rebalanced 7p
                if data["gameSetting"]["rebalance7p"]:
                    continue

                # Fill roles
                for seat in range(0, 7):
                    roles.append(1 if (data["players"][seat]["role"] == "fascist" or data["players"][seat]["role"] == "hitler") else 0)

                # Play as the lib_number-th liberal seat (counting in seat order)
                lib_count = 0
                confirmed_seat = 0
                for seat in range(7):
                    if data["players"][seat]["role"] == "liberal":
                        if lib_count == lib_number:
                            confirmed_seat = seat
                        lib_count = lib_count + 1

                # One-hot encode which seat the network is playing
                confirmed[confirmed_seat] = 1

                if len(data["logs"]) <= 0:
                    game_valid = False

                # Length 8 (1-7 - investigated seat number) (8 - claimed result, 1 = fascist)
                investigation_data = [0] * 8
                # Length 7 (1-7 - chosen seat number)
                special_election_data = [0] * 7
                # Length 7 (1-7 - first shot seat number)
                bullet_data_1 = [0] * 7
                # Length 7 (1-7 - second shot seat number)
                bullet_data_2 = [0] * 7

                # For each government
                for gov in range(0, len(data["logs"])):
                    # Did a veto force a td
                    veto_and_td = False
                    # Length 31 (1-7 - pres, 8-14 chanc, 15-18 pres claim, 19-21 chanc claim, 22 veto, 23 blue, 24 red, 25-31 vote data)
                    gov_data = []
                    # Contains topdeck data
                    topdeck = []

                    # If the government was played
                    if len(data["logs"][gov]) >= 7:
                        num_played_govs += 1

                        # President seat number
                        for pres in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["presidentId"] == pres else 0)

                        # Chancellor seat number
                        for chan in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["chancellorId"] == chan else 0)

                        # Number of reds claimed (-1 = no claim); the network's own seat uses its true hand
                        pres_claim = (data["logs"][gov]["presidentClaim"]["reds"] if "presidentClaim" in data["logs"][gov] else -1) if data["logs"][gov]["presidentId"] != confirmed_seat else data["logs"][gov]["presidentHand"]["reds"]
                        chanc_claim = (data["logs"][gov]["chancellorClaim"]["reds"] if "chancellorClaim" in data["logs"][gov] else -1) if data["logs"][gov]["chancellorId"] != confirmed_seat else data["logs"][gov]["chancellorHand"]["reds"]

                        # President number of reds claimed
                        if "presidentClaim" in data["logs"][gov]:
                            gov_data.append(1 if pres_claim == 0 else 0)
                            gov_data.append(1 if pres_claim == 1 else 0)
                            gov_data.append(1 if pres_claim == 2 else 0)
                            gov_data.append(1 if pres_claim == 3 else 0)
                        elif "chancellorClaim" in data["logs"][gov]:
                            # Infer the president's hand as one more red than the chancellor claimed
                            pres_claim = chanc_claim + 1
                            gov_data.append(0)
                            gov_data.append(1 if chanc_claim == 0 else 0)
                            gov_data.append(1 if chanc_claim == 1 else 0)
                            gov_data.append(1 if chanc_claim == 2 else 0)
                        elif "enactedPolicy" in data["logs"][gov]:
                            gov_data.append(0)
                            gov_data.append(0)
                            gov_data.append(0 if data["logs"][gov]["enactedPolicy"] == "fascist" else 1)
                            gov_data.append(1 if data["logs"][gov]["enactedPolicy"] == "fascist" else 0)
                        else:
                            game_valid = False

                        # Chancellor number of reds claimed
                        if "chancellorClaim" in data["logs"][gov]:
                            gov_data.append(1 if chanc_claim == 0 else 0)
                            gov_data.append(1 if chanc_claim == 1 else 0)
                            gov_data.append(1 if chanc_claim == 2 else 0)
                        elif "presidentClaim" in data["logs"][gov]:
                            if "enactedPolicy" in data["logs"][gov] and data["logs"][gov]["enactedPolicy"] == "fascist":
                                chanc_claim = 2
                                gov_data.append(0)
                                gov_data.append(0)
                                gov_data.append(1)
                            else:
                                # Otherwise (liberal or nothing enacted) assume the president
                                # discarded a red: the chancellor got one fewer red than the
                                # president claimed, never fewer than zero.
                                chanc_claim = max(0, pres_claim - 1)
                                gov_data.append(1 if chanc_claim == 0 else 0)
                                gov_data.append(1 if chanc_claim == 1 else 0)
                                gov_data.append(0)
                        elif "enactedPolicy" in data["logs"][gov]:
                            gov_data.append(0)
                            gov_data.append(0 if data["logs"][gov]["enactedPolicy"] == "fascist" else 1)
                            gov_data.append(1 if data["logs"][gov]["enactedPolicy"] == "fascist" else 0)
                        else:
                            game_valid = False

                        # Encode card outed: the network sat in this government, the claims
                        # conflict, and a fascist policy was enacted -> the partner is outed
                        if (confirmed_seat == data["logs"][gov]["presidentId"] or confirmed_seat == data["logs"][gov]["chancellorId"]) and (pres_claim - chanc_claim != 1 and pres_claim != 0) and "enactedPolicy" in data["logs"][gov] and data["logs"][gov]["enactedPolicy"] == "fascist":
                            game_outed[data["logs"][gov]["chancellorId" if confirmed_seat == data["logs"][gov]["presidentId"] else "presidentId"]] = 1

                        # Veto
                        gov_data.append(1 if ("presidentVeto" in data["logs"][gov] and "chancellorVeto" in data["logs"][gov] and data["logs"][gov]["presidentVeto"] and data["logs"][gov]["chancellorVeto"]) else 0)

                        # Enacted policy
                        if "enactedPolicy" in data["logs"][gov]:
                            gov_data.append(0 if data["logs"][gov]["enactedPolicy"] == "fascist" else 1)
                            gov_data.append(1 if data["logs"][gov]["enactedPolicy"] == "fascist" else 0)
                        else:
                            gov_data.append(0)
                            gov_data.append(0)

                        # Vote data
                        for seat in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["votes"][seat] else 0)

                        # If investigation
                        if "investigationId" in data["logs"][gov]:
                            investigation_data[data["logs"][gov]["investigationId"]] = 1
                            if not "investigationClaim" in data["logs"][gov]:
                                game_valid = False
                            else:
                                investigation_data[7] = 1 if data["logs"][gov]["investigationClaim"] == "fascist" else 0

                        # Encode inv outed
                        if "investigationId" in data["logs"][gov] and (confirmed_seat == data["logs"][gov]["presidentId"] or confirmed_seat == data["logs"][gov]["investigationId"]) and "investigationClaim" in data["logs"][gov] and data["logs"][gov]["investigationClaim"] == "fascist":
                            game_outed[data["logs"][gov]["investigationId" if confirmed_seat == data["logs"][gov]["presidentId"] else "presidentId"]] = 1

                        # Encode inv confirmed (the network investigated a true liberal)
                        if "investigationId" in data["logs"][gov] and confirmed_seat == data["logs"][gov]["presidentId"] and roles[data["logs"][gov]["investigationId"]] == 0:
                            confirmed[data["logs"][gov]["investigationId"]] = 1

                        # Did a veto force a topdeck
                        if "presidentVeto" in data["logs"][gov] and "chancellorVeto" in data["logs"][gov] and data["logs"][gov]["presidentVeto"] and data["logs"][gov]["chancellorVeto"] and "enactedPolicy" in data["logs"][gov]:
                            veto_and_td = True

                        # If Special Election
                        if "specialElection" in data["logs"][gov]:
                            special_election_data[data["logs"][gov]["specialElection"]] = 1

                        # If bullet
                        if "execution" in data["logs"][gov]:
                            # If first bullet
                            if not 1 in bullet_data_1:
                                bullet_data_1[data["logs"][gov]["execution"]] = 1
                            # If second bullet
                            else:
                                bullet_data_2[data["logs"][gov]["execution"]] = 1

                    # Neined government
                    elif "presidentId" in data["logs"][gov] and "chancellorId" in data["logs"][gov] and "votes" in data["logs"][gov] and len(data["logs"][gov]["votes"]) == 7:
                        # President seat number
                        for pres in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["presidentId"] == pres else 0)

                        # Chancellor seat number
                        for chan in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["chancellorId"] == chan else 0)

                        # Empty claim / veto data
                        for fill in range(8):
                            gov_data.append(0)

                        # Encode enacted policy if topdecked
                        if "enactedPolicy" in data["logs"][gov]:
                            gov_data.append(0 if data["logs"][gov]["enactedPolicy"] == "fascist" else 1)
                            gov_data.append(1 if data["logs"][gov]["enactedPolicy"] == "fascist" else 0)
                        else:
                            gov_data.append(0)
                            gov_data.append(0)

                        # Vote data
                        for seat in range(0, 7):
                            gov_data.append(1 if data["logs"][gov]["votes"][seat] else 0)

                    else:
                        game_valid = False

                    # If the government was a topdeck
                    if (len(data["logs"][gov]) == 4 and ("enactedPolicy" in data["logs"][gov])) or veto_and_td:
                        topdeck.append(1)
                    else:
                        topdeck.append(0)

                    # If Hitler was elected in HZ
                    if "chancellorId" in data["logs"][gov] and len(data["logs"][gov]) == 3 and gov == len(data["logs"]) - 1 and data["players"][data["logs"][gov]["chancellorId"]]["role"] == "hitler" and fas_cards_played >= 3:
                        game_outed[data["logs"][gov]["chancellorId"]] = 1
                        game_hitler[data["logs"][gov]["chancellorId"]] = 1

                    # If Hitler was shot at any point
                    if "execution" in data["logs"][gov] and gov == len(data["logs"]) - 1 and data["players"][data["logs"][gov]["execution"]]["role"] == "hitler":
                        game_outed[data["logs"][gov]["execution"]] = 1
                        game_hitler[data["logs"][gov]["execution"]] = 1

                    # Encode chancellors in HZ being CNH
                    if "chancellorId" in data["logs"][gov] and len(data["logs"][gov]) >= 7 and fas_cards_played >= 3:
                        game_cnh[data["logs"][gov]["chancellorId"]] = 1

                    # All the people confirmed lib are also CNH, for example after inv
                    for seat in range(7):
                        game_cnh[seat] = game_cnh[seat] | confirmed[seat]

                    # If a policy was enacted, updated board fas and lib card counts
                    if "enactedPolicy" in data["logs"][gov]:
                        if data["logs"][gov]["enactedPolicy"] == "fascist":
                            fas_cards_played += 1
                        else:
                            lib_cards_played += 1

                    # Fill empty government data with 0s up to its fixed width of 31
                    for i in range(len(gov_data), 31):
                        gov_data.append(0)

                    # List concatenation snapshots the running per-game state into this row
                    game_data.append(gov_data + investigation_data + special_election_data + bullet_data_1 + bullet_data_2 + topdeck + confirmed + game_outed + game_hitler + game_cnh)

                    # If some misclaim / troll play will mess up training
                    for seat in range(7):
                        if confirmed[seat] == 1 and roles[seat] == 1:
                            game_valid = False

                if num_played_govs < 3:
                    game_valid = False

                if game_valid:
                    game_lengths.append(min(max_num_govs, len(game_data)))
                    # Zero-pad short games up to max_num_govs rows
                    for i in range(len(game_data), max_num_govs):
                        game_data.append([0] * num_input_nodes)
                    games.append(game_data[:max_num_govs])
                    results.append(roles)
                    game_numbers.append(game_number)
                    outed.append(game_outed)

    # Remove games with fake conflicts (a true liberal was marked as outed);
    # iterate backwards so deletions don't shift indices still to be visited
    for t in range(len(games) - 1, -1, -1):
        for seat in range(7):
            if outed[t][seat] == 1 and results[t][seat] == 0:
                del games[t]
                del results[t]
                del game_numbers[t]
                del game_lengths[t]
                del outed[t]
                break

    return games, results, game_numbers, game_lengths, outed

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import gc
import sys

# Print numpy arrays in full instead of eliding the middle
np.set_printoptions(threshold=sys.maxsize)

# Deterministic randomness
torch.manual_seed(0)

# Run a garbage-collection pass (useful when re-running notebook cells)
gc.collect()

# Enables device agnostic tensor creation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Release cached, unused CUDA memory. Autograd mode has no effect on the
# caching allocator, so a single call is enough.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Recurrent Neural Network
class RNN(nn.Module):
    """Single-layer LSTM that scores each of the 7 seats.

    The module holds its own input batch and sequence lengths, so ``forward``
    takes no arguments.

    Parameters
    ----------
    n_inputs : int
        Number of features per time step (LSTM input size).
    n_hidden : int
        LSTM hidden size.
    X_in : torch.Tensor
        Zero-padded input batch; assumed (batch, max_len, n_inputs) since the
        LSTM is ``batch_first``.
    seq_lengths : sequence of int or torch.Tensor
        True (unpadded) length of each sequence in ``X_in``; may live on any
        device.
    """
    def __init__(self, n_inputs, n_hidden, X_in, seq_lengths):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(n_inputs, n_hidden, batch_first=True, num_layers=1, dropout=0.0)
        self.n_hidden = n_hidden
        self.X = X_in
        self.seq_lengths = seq_lengths
        self.FC = nn.Linear(self.n_hidden, 7)

    def forward(self):
        """Return per-seat sigmoid scores of shape (batch, 7)."""
        # Initialize the hidden state with all zeroes
        self.init_hidden()
        # pack_padded_sequence requires lengths as a 1-D CPU int64 tensor, but
        # callers may hand us lengths that were already moved to the GPU.
        lengths = torch.as_tensor(self.seq_lengths, dtype=torch.int64).cpu()
        # Pack X so padded time steps are skipped by the LSTM
        self.X_packed = torch.nn.utils.rnn.pack_padded_sequence(self.X, lengths, batch_first=True, enforce_sorted=False)
        # hx is (h_n, c_n); h_n[0] is the final hidden state of the only layer,
        # taken at each sequence's true last step, in the original batch order
        _, self.hx = self.rnn(self.X_packed)
        # Map the final hidden state to 7 seat scores in (0, 1)
        out = torch.sigmoid(self.FC(self.hx[0][0])).to(device)

        return out

    def init_hidden(self):
        """Reset ``self.hx`` to zeros (forward() overwrites it with the LSTM state)."""
        self.hx = torch.zeros(1, len(self.X), self.n_hidden).to(device)

# Hyperparameters
# N_EPOCHS determined at runtime based on training_length
LEARNING_RATE = 0.0015
# Passed to Adam as weight_decay
LAMBDA_ = .0004
N_HIDDEN = 17

# Train one model per maximum game length: 1 through 29 governments
training_lengths = list(range(1, 30))

# Number of nodes in input layer (width of each per-government row)
N_INPUT = 89

# Number of games; game numbers below this bound are split 80/20 into
# training and validation sets
N_GAMES = 25000

# Train a model for each game_length
for training_length in training_lengths:

    # Training set
    X, Y, train_game_numbers, train_seq_lengths, train_outed = populate_inputs(2, int(N_GAMES * 0.8), training_length)

    # Validation set
    validation_X, validation_Y, validation_game_numbers, validation_seq_lengths, validation_outed = populate_inputs(int(N_GAMES * 0.8), int(N_GAMES * 1), training_length)

    # Convert to tensors
    X = torch.as_tensor(X, dtype=torch.float32).to(device)
    Y = torch.as_tensor(Y).to(device)
    train_game_numbers = torch.as_tensor(train_game_numbers).to(device)
    train_seq_lengths = torch.as_tensor(train_seq_lengths).to(device)
    train_outed = torch.as_tensor(train_outed).to(device)

    validation_X = torch.as_tensor(validation_X, dtype=torch.float32).to(device)
    validation_Y = torch.as_tensor(validation_Y).to(device)
    validation_game_numbers = torch.as_tensor(validation_game_numbers).to(device)
    validation_seq_lengths = torch.as_tensor(validation_seq_lengths).to(device)
    validation_outed = torch.as_tensor(validation_outed).to(device)

    # Training model
    train_model = RNN(N_INPUT, N_HIDDEN, X, train_seq_lengths).to(device)

    # Validation model: same architecture bound to the validation tensors; it
    # receives a copy of the training weights every epoch
    validation_model = RNN(N_INPUT, N_HIDDEN, validation_X, validation_seq_lengths).to(device)

    # Set the loss function to binary cross entropy and use Adam optimizer
    criterion = nn.BCELoss().to(device)
    optimizer = optim.Adam(train_model.parameters(), lr=LEARNING_RATE, weight_decay=LAMBDA_)

    # Training metrics
    train_seat_errors = []

    # Validation metrics
    validation_seat_errors = []
    rounded_validation_seat_errors = []
    sorted_validation_seat_errors = []
    threshold_validation_seat_errors = []

    # The longer game lengths need more epochs to be fit
    N_EPOCHS = 5000 if training_length < 15 else 10000

    for epoch in range(N_EPOCHS):
        # Print the epoch number every 100 epochs
        if epoch % 100 == 0:
            print("Epoch " + str(epoch))

        # ---- Training step ----
        train_model.train()

        # zero the parameter gradients
        optimizer.zero_grad()

        # Run the model (forward() resets the hidden state itself)
        last_states = train_model()

        # Calculate loss and backpropagate
        loss = criterion(last_states.float(), Y.float())

        loss.backward()
        optimizer.step()

        # Evaluate the training set errors and store them
        train_seat_error = torch.sum(torch.abs(Y - last_states)) / len(Y)
        train_seat_errors.append(train_seat_error.detach().item())

        # ---- Validation step ----
        train_model.eval()
        validation_model.eval()

        # Transfer the freshly updated weights in memory; no need to
        # round-trip through a file on disk every epoch
        validation_model.load_state_dict(train_model.state_dict())

        # Validation needs no gradients, so don't build an autograd graph
        with torch.no_grad():
            prediction = validation_model()

        # Store the errors for the validation set
        validation_seat_error = torch.sum(torch.abs(validation_Y - prediction)) / len(validation_Y)
        validation_seat_errors.append(validation_seat_error.item())

        # Store the errors for the validation set, but round to either 0 or 1
        rounded_validation_seat_error = torch.sum(torch.abs(validation_Y - torch.round(prediction))) / len(validation_Y)
        rounded_validation_seat_errors.append(rounded_validation_seat_error.item())

        # If this is the middle or last epoch
        if (epoch + 1) % (N_EPOCHS // 2) == 0:

            # Assume the 3 highest values are fascist and the rest are lib:
            # after an ascending sort, columns 4..6 hold the top-3 seat indices
            _, sorted_validation_seat_indices = torch.sort(prediction)
            sorted_prediction = torch.zeros_like(prediction)
            sorted_prediction.scatter_(1, sorted_validation_seat_indices[:, 4:], 1.0)
            sorted_validation_seat_error = torch.sum(torch.abs(validation_Y - sorted_prediction)) / len(validation_Y)
            sorted_validation_seat_errors.append(sorted_validation_seat_error.item())

    # Save the parameters for this model
    save_name = "parameters-" + str(training_length) + ".pt"
    torch.save(train_model.state_dict(), save_name)

    # Print validation set results
    print("\nValidation Set Data")
    print("Game #" + str(validation_game_numbers[-1].item()))
    print("Pred: " + str(prediction[-1]) + "\nReal: " + str(validation_Y[-1]))

    # Data after all epochs ('\033[1m' turns bold on, '\033[0m' resets it)
    print("\n" + '\033[1m' + "End")
    print('\033[0m' + "Train error: " + "                      %.4f" % train_seat_errors[-1])
    print("Validation error: " + "                 %.4f" % validation_seat_errors[-1])
    print("Validation error with threshold: " + "  %.4f" % rounded_validation_seat_errors[-1])
    print("Validation error with sort: " + "       %.4f" % sorted_validation_seat_errors[-1])

    # Graph Train Error
    plt.subplot(3, 1, 1)
    plt.title("Raw Seat Errors")
    plt.plot(train_seat_errors)
    plt.plot(validation_seat_errors)

    # Graph validation set error
    plt.subplot(3, 1, 2)
    plt.title("Threshold Validation Error")
    plt.plot(rounded_validation_seat_errors)

    # Pad graphs so the titles and graphs don't overlap
    plt.tight_layout()

    # Show plot
    plt.show()

# Live Game Prediction

In [None]:
import json

def load_game(file_name):
    """
    Loads game information from a json file and populates the input layer of the LSTM

    The network's seat is the player whose role is "liberal" (the last such
    seat if several are marked). Each entry of ``data["logs"]`` becomes one
    89-wide row in the same column order as the training rows: 31 government
    features, 8 investigation, 7 special election, 7 + 7 bullets, 1 topdeck,
    then 7 each of confirmed / outed / hitler / confirmed-not-hitler flags.
    A missing president or chancellor claim is encoded as all-zero claim
    bits, so every later feature stays in its fixed column.

    Parameters
    ----------
    file_name : str
        Parse json file file_name

    Returns
    -------
    list
        List containing one element: the list of per-government rows for this game
    """

    game_data = []
    game_outed = [0] * 7
    game_hitler = [0] * 7
    game_cnh = [0] * 7
    confirmed = [0] * 7
    # Seat the network plays; stays None if no seat is marked "liberal"
    confirmed_seat = None

    lib_cards_played = 0
    fas_cards_played = 0

    with open(file_name) as file_in:
        data = json.load(file_in)

    # Encode network's seat
    for seat in range(0, 7):
        if data["players"][seat]["role"] == "liberal":
            confirmed[seat] = 1
            confirmed_seat = seat

    # Length 8 (1-7 - investigated seat number) (8 - claimed result, 1 = fascist)
    investigation_data = [0] * 8
    # Length 7 (1-7 - chosen seat number)
    special_election_data = [0] * 7
    # Length 7 (1-7 - first shot seat number)
    bullet_data_1 = [0] * 7
    # Length 7 (1-7 - second shot seat number)
    bullet_data_2 = [0] * 7

    # For each government
    for log in data["logs"]:

        # If a veto forces a td
        veto_and_td = False
        # Length 31 (1-7 - pres, 8-14 chanc, 15-18 pres claim, 19-21 chanc claim, 22 veto, 23 blue, 24 red, 25-31 vote data)
        gov_data = []
        # Topdeck data
        topdeck = []

        # If the government was played
        if len(log) >= 5:

            # President seat number
            for pres in range(0, 7):
                gov_data.append(1 if log["presidentId"] == pres else 0)

            # Chancellor seat number
            for chan in range(0, 7):
                gov_data.append(1 if log["chancellorId"] == chan else 0)

            # Number of reds claimed, -1 when that player made no claim
            has_pres_claim = "presidentClaim" in log
            has_chanc_claim = "chancellorClaim" in log
            pres_claim = log["presidentClaim"]["reds"] if has_pres_claim else -1
            chanc_claim = log["chancellorClaim"]["reds"] if has_chanc_claim else -1

            # President number of reds claimed (all zeros when unclaimed)
            for reds in range(4):
                gov_data.append(1 if pres_claim == reds else 0)

            # Chancellor number of reds claimed (all zeros when unclaimed)
            for reds in range(3):
                gov_data.append(1 if chanc_claim == reds else 0)

            # Encode card outed: the network sat in this government, both claims
            # exist and conflict, and a fascist policy was enacted
            if has_pres_claim and has_chanc_claim and (confirmed_seat == log["presidentId"] or confirmed_seat == log["chancellorId"]) and (pres_claim - chanc_claim != 1 and pres_claim != 0) and log.get("enactedPolicy") == "fascist":
                game_outed[log["chancellorId" if confirmed_seat == log["presidentId"] else "presidentId"]] = 1

            # Veto
            gov_data.append(1 if ("presidentVeto" in log and "chancellorVeto" in log and log["presidentVeto"] and log["chancellorVeto"]) else 0)

            # Enacted policy
            if "enactedPolicy" in log:
                gov_data.append(0 if log["enactedPolicy"] == "fascist" else 1)
                gov_data.append(1 if log["enactedPolicy"] == "fascist" else 0)
            else:
                gov_data.append(0)
                gov_data.append(0)

            # Vote data
            for seat in range(0, 7):
                gov_data.append(1 if log["votes"][seat] else 0)

            investigation_claim = log.get("investigationClaim")

            # If investigation
            if "investigationId" in log:
                investigation_data[log["investigationId"]] = 1
                investigation_data[7] = 1 if investigation_claim == "fascist" else 0

            # Encode inv outed
            if "investigationId" in log and (confirmed_seat == log["presidentId"] or confirmed_seat == log["investigationId"]) and investigation_claim == "fascist":
                game_outed[log["investigationId" if confirmed_seat == log["presidentId"] else "presidentId"]] = 1

            # Encode inv confirmed
            if "investigationId" in log and confirmed_seat == log["presidentId"] and investigation_claim == "liberal":
                confirmed[log["investigationId"]] = 1

            # Did a veto force a topdeck
            if "presidentVeto" in log and "chancellorVeto" in log and log["presidentVeto"] and log["chancellorVeto"] and "enactedPolicy" in log:
                veto_and_td = True

            # If Special Election
            if "specialElection" in log:
                special_election_data[log["specialElection"]] = 1

            # If bullet
            if "execution" in log:
                # If first bullet
                if not 1 in bullet_data_1:
                    bullet_data_1[log["execution"]] = 1
                # If second bullet
                else:
                    bullet_data_2[log["execution"]] = 1

        # Neined government
        else:
            # President seat number
            for pres in range(0, 7):
                gov_data.append(1 if log["presidentId"] == pres else 0)

            # Chancellor seat number
            for chan in range(0, 7):
                gov_data.append(1 if log["chancellorId"] == chan else 0)

            # Empty claim / veto data
            for fill in range(8):
                gov_data.append(0)

            # Encode enacted policy if topdecked
            if "enactedPolicy" in log:
                gov_data.append(0 if log["enactedPolicy"] == "fascist" else 1)
                gov_data.append(1 if log["enactedPolicy"] == "fascist" else 0)
            else:
                gov_data.append(0)
                gov_data.append(0)

            # Vote data
            for seat in range(0, 7):
                gov_data.append(1 if log["votes"][seat] else 0)

        # If the government was a topdeck
        if (len(log) == 4 and ("enactedPolicy" in log)) or veto_and_td:
            topdeck.append(1)
        else:
            topdeck.append(0)

        # If Hitler was elected in HZ
        if len(log) == 4 and "hitler" in log and fas_cards_played >= 3:
            game_outed[log["chancellorId"]] = 1
            game_hitler[log["chancellorId"]] = 1

        # If Hitler was shot at any point
        if "execution" in log and "hitler" in log:
            game_outed[log["execution"]] = 1
            game_hitler[log["execution"]] = 1

        # Encode chancellors in HZ being CNH
        if len(log) >= 5 and fas_cards_played >= 3:
            game_cnh[log["chancellorId"]] = 1

        # All the people confirmed lib are also CNH, for example after inv
        for seat in range(7):
            game_cnh[seat] = game_cnh[seat] | confirmed[seat]

        if "enactedPolicy" in log:
            if log["enactedPolicy"] == "fascist":
                fas_cards_played += 1
            else:
                lib_cards_played += 1

        # Fill empty government data with 0s up to its fixed width of 31
        for i in range(len(gov_data), 31):
            gov_data.append(0)

        # List concatenation snapshots the running per-game state into this row
        game_data.append(gov_data + investigation_data + special_election_data + bullet_data_1 + bullet_data_2 + topdeck + confirmed + game_outed + game_hitler + game_cnh)

    return [game_data]

In [None]:
import json
import re

def convert_to_json(file_name, new_file_name="game_out.json"):
    """
    Takes a file containing plaintext game notation and converts it to a json file that can then be interpreted by load_game()

    Example: file_name is the name of a file with the following plaintext game notation:

    SEAT 1
    1111111 - 15 RRB RB B

    This method will take that game notation and convert it to a json file containing the following, which can be interpreted by load_game()

    {"logs": [{"votes": [true, true, true, true, true, true, true], "presidentId": 0, "chancellorId": 4, "presidentClaim": {"reds": 2, "blues": 1}, "chancellorClaim": {"reds": 1, "blues": 1}, "enactedPolicy": "liberal"}],
    "players": [{"role": "liberal"}, {"role": "not_me"}, {"role": "not_me"}, {"role": "not_me"}, {"role": "not_me"}, {"role": "not_me"}, {"role": "not_me"}]}

    Notation accepted by the parser below:
      - First line: "SEAT n", the 1-based seat of the player running the model.
      - Each further line: "<7 vote chars> - <pres seat><chanc seat>" optionally followed by
        " <pres claim> <chanc claim> <enacted>" (a played government) or " <policy>" (a topdeck),
        and optionally " - <action>" where action contains INV <seat> LIB|..., SE <seat>,
        KILL <seat> or VETO. An "H" anywhere on the line sets the "hitler" flag.
      - Blank lines are ignored. Line endings ("\\n", "\\r\\n" or none on the last line) do not matter.

    Parameters
    ----------
    file_name : str
        Parse plaintext file file_name
    new_file_name : str, optional
        Path of the json file to write (default "game_out.json")

    Returns
    -------
    new_file_name : str
        String containing the name of a new file containing game information encoded in json format
    """
    # A topdeck line looks like "0000000 - 12 R": votes, separator, two seat digits and one policy letter.
    # Anything longer is a government with claims; anything shorter has no enacted policy.
    topdeck_length = len("0000000 - 12 R")

    with open(file_name) as file_in:
        # Strip line endings so the length-based classification below does not depend on
        # whether a line ends in "\n", "\r\n" or nothing (typically the last line of the file).
        lines = [line.rstrip() for line in file_in]

    data = {"logs": []}

    # Create players list that stores my_seat as liberal
    my_seat_index = int(lines[0].split("SEAT")[1]) - 1
    data["players"] = [
        {"role": "liberal" if seat == my_seat_index else "not_me"} for seat in range(7)
    ]

    # Create government logs
    for line in lines[1:]:
        # Skip empty lines (e.g. trailing blank lines at the end of the file)
        if not line:
            continue

        fields = line.split(" - ")
        gov_field = fields[1]
        gov_tokens = gov_field.split()

        # Create dict to represent this government
        gov = {}
        # One '1' (ja) / other (nein) character per seat, converted to booleans
        gov["votes"] = [vote == "1" for vote in fields[0]]

        # President and chancellor seats are the first two characters (1-based in the notation)
        gov["presidentId"] = int(gov_field[0]) - 1
        gov["chancellorId"] = int(gov_field[1]) - 1

        # If the gov was played: president claim, chancellor claim, enacted policy
        if len(line) > topdeck_length:
            pres_reds = gov_tokens[1].count("R")
            gov["presidentClaim"] = {"reds": pres_reds, "blues": 3 - pres_reds}

            chanc_reds = gov_tokens[2].count("R")
            gov["chancellorClaim"] = {"reds": chanc_reds, "blues": 2 - chanc_reds}

            gov["enactedPolicy"] = "fascist" if gov_tokens[3] == "R" else "liberal"

            # If a special action (inv, se, bullet) was taken or veto took place
            if line.count("-") == 2:
                action_tokens = fields[2].split()

                # If investigation
                if "INV" in line:
                    gov["investigationId"] = int(action_tokens[1]) - 1
                    gov["investigationClaim"] = "liberal" if action_tokens[2] == "LIB" else "fascist"

                # If special election
                if "SE" in line:
                    gov["specialElection"] = int(action_tokens[1]) - 1

                # If bullet
                if "KILL" in line:
                    gov["execution"] = int(action_tokens[1]) - 1

                if "VETO" in line:
                    gov["presidentVeto"] = True
                    gov["chancellorVeto"] = True

        # Check if this gov was a topdeck
        elif len(line) == topdeck_length:
            gov["enactedPolicy"] = "fascist" if gov_tokens[1] == "R" else "liberal"

        if "H" in line:
            gov["hitler"] = True

        # Append this gov to the data
        data["logs"].append(gov)

    with open(new_file_name, "w") as file_out:
        json.dump(data, file_out)

    return new_file_name

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import gc
import sys

# Print entire numpy arrays instead of abbreviating them
np.set_printoptions(threshold=sys.maxsize)

# Deterministic randomness for torch's RNG
torch.manual_seed(0)

# Garbage collect objects left over from previously run notebook cells
gc.collect()

# Enables device agnostic tensor creation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Release unused cached GPU memory held by PyTorch's caching allocator.
# One call is enough; the allocator cache is unrelated to autograd, so torch.no_grad() is not needed.
torch.cuda.empty_cache()

# Recurrent Neural Network
class RNN(nn.Module):
    """
    Single-layer LSTM whose final hidden state goes through a linear layer and a sigmoid,
    giving 7 outputs per game in the batch.

    The padded input batch and the per-game lengths are stored on the module at construction
    time, so forward() takes no arguments.

    Parameters
    ----------
    n_inputs : int
        Number of input features per time step (LSTM input size)
    n_hidden : int
        Size of the LSTM hidden state
    X_in : torch.Tensor
        Padded batch in batch-first layout; presumably (num_games, max_num_govs, n_inputs) -- confirm against caller
    seq_lengths : list
        Unpadded length of each game in X_in, forwarded to pack_padded_sequence
    """

    def __init__(self, n_inputs, n_hidden, X_in, seq_lengths):
        super().__init__()
        # Submodule names "rnn" and "FC" are the keys of saved state dicts; keep them (and their
        # creation order, which determines the random initialization) unchanged.
        self.rnn = nn.LSTM(n_inputs, n_hidden, batch_first=True, num_layers=1, dropout=0.0)
        self.FC = nn.Linear(n_hidden, 7)
        self.n_hidden = n_hidden
        self.X = X_in
        self.seq_lengths = seq_lengths

    def forward(self):
        # Reset self.hx to zeros; the LSTM call below is not given it and starts from zero state itself
        self.init_hidden()

        # Pack the padded batch so the LSTM stops at each game's true length
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            Variable(self.X),
            self.seq_lengths,
            batch_first=True,
            enforce_sorted=False,
        )
        self.X_packed = packed

        # nn.LSTM returns (output, (h_n, c_n)); only the final states are used
        _, final_states = self.rnn(packed)
        self.hx = final_states

        # h_n[0]: final hidden state of the single LSTM layer, one row per game
        last_hidden = final_states[0][0]
        return torch.sigmoid(self.FC(last_hidden)).to(device)

    def init_hidden(self):
        # Zero hidden state of shape (num_layers=1, batch size, n_hidden) on the global `device`
        zeros = torch.zeros(1, len(self.X), self.n_hidden)
        self.hx = Variable(zeros.to(device))

# Size of input layer
N_INPUT = 89
# Size of hidden layer
N_HIDDEN = 17

# Name of game file
file_name = "game_in.txt"

# Input for the game being predicted: load_game returns a batch containing this one game
X = load_game(convert_to_json(file_name))

# Number of governments in the game
game_length = len(X[0])

# Convert to tensor
X = torch.as_tensor(X, dtype=torch.float32).to(device)

# Create model
model = RNN(N_INPUT, N_HIDDEN, X, [game_length]).to(device)

# Load the parameters trained for games of this length. map_location lets a checkpoint saved
# on a GPU be loaded on a CPU-only runtime (and vice versa).
parameter_file_name = f"parameters-{game_length}.pt"
model.load_state_dict(torch.load(parameter_file_name, map_location=device))
model.eval()

# Get prediction; inference only, so no autograd graph is built
with torch.no_grad():
    prediction = model()

print(prediction)