In [2]:
import random
from collections import deque
from enum import Enum
from statistics import mean

import torch
import torch.nn.functional as F
import tqdm
from torch import nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Robot class

Given a maze matrix, the robot will start from the start state and go to the target state.

- Call `.train()` to train the robot.
- Read `.q_matrix` to get the learned Q-matrix.

In [3]:

class Robot:
    """
    Tabular Q-learning agent that walks a grid maze from a start cell to a target cell.

    States are ``(row, col)`` index pairs into ``maze_matrix``; a truthy maze entry is a
    wall. Each row of ``actions`` is a ``(d_row, d_col)`` offset added to the state.
    ``q_matrix`` is indexed as ``[action_id, row, col]``.

    A move that would leave the grid is handled like a move into a wall: it earns
    ``wall_reward`` and the robot stays where it is.
    """
    # NOTE: the default ``actions`` tensor is created once at definition time and shared
    # by every instance that relies on the default, so it must not be modified in place.
    def __init__(self, start_state: 'tuple[int,int]', target_state: 'tuple[int,int]', maze_matrix: torch.Tensor,
                 actions: torch.Tensor = torch.tensor([[0, -1], [0, 1], [-1, 0], [1, 0]], dtype=torch.int8),
                 reward_matrix: torch.Tensor = None, restrict_matrix: torch.Tensor = None, q_matrix: torch.Tensor = None,
                 max_step = 100, greedy_rate = 0.7, discount = 0.9, lr = 0.1,
                 wall_reward = -100, target_reward = 100):
        """
        ---
        INPUT
        start_state: (row, col) of the cell each episode starts from
        target_state: (row, col) of the goal cell
        maze_matrix: 2-D grid, truthy entries are walls
        actions: one (d_row, d_col) offset per action
        reward_matrix: reward for entering each cell; built from the maze when None
                       (-1 everywhere, wall_reward on walls, target_reward on the target)
        restrict_matrix: extra cells the robot may not enter; all False when None
        q_matrix: initial Q-values of shape (n_action, rows, cols); zeros when None.
                  It is updated in place during training.
        max_step: maximum number of moves per episode
        greedy_rate: probability of taking the greedy action (otherwise a random one)
        discount: discount factor of the Q-learning update
        lr: learning rate of the Q-learning update
        wall_reward: reward for hitting a wall or leaving the grid
        target_reward: reward for entering the target cell
        """
        self.start_state = torch.tensor(start_state)
        self.target_state = torch.tensor(target_state)
        self.maze_matrix = maze_matrix
        self.actions = actions
        self.max_step = max_step
        self.greedy_rate = greedy_rate
        self.discount = discount
        self.lr = lr
        # Kept so that moves off the grid can be penalised like wall hits.
        self.wall_reward = wall_reward

        self.n_action = actions.size(0)

        if reward_matrix is None:
            # float32 so that arbitrary reward magnitudes fit (int8 only holds -128..127).
            reward_matrix = -torch.ones_like(maze_matrix, dtype=torch.float32)
            # Force a boolean mask: an integer maze would otherwise be interpreted as
            # row indices and overwrite whole rows.
            reward_matrix[maze_matrix.bool()] = wall_reward
            reward_matrix[target_state[0], target_state[1]] = target_reward
        self.reward_matrix = reward_matrix

        if restrict_matrix is None:
            restrict_matrix = torch.zeros_like(reward_matrix, dtype=torch.bool)
        self.restrict_matrix = restrict_matrix

        if q_matrix is None:
            q_matrix = torch.zeros_like(reward_matrix, dtype=torch.float32).repeat(self.n_action,1,1)
        self.q_matrix = q_matrix

    def _in_bounds(self, state: torch.Tensor) -> bool:
        """Return True when ``state`` addresses a cell inside the maze grid."""
        n_rows, n_cols = self.maze_matrix.shape[0], self.maze_matrix.shape[1]
        return 0 <= int(state[0]) < n_rows and 0 <= int(state[1]) < n_cols

    def _is_blocked(self, state: torch.Tensor) -> bool:
        """Return True when ``state`` is off the grid, a wall, or a restricted cell."""
        if not self._in_bounds(state):
            return True
        row, col = state[0], state[1]
        return bool(self.maze_matrix[row, col]) or bool(self.restrict_matrix[row, col])

    def _at_target(self, state: torch.Tensor) -> bool:
        """Return True when ``state`` is exactly the target cell."""
        return bool((state == self.target_state).all())

    def choose_action_id(self, state: torch.Tensor, greedy_rate: float = None) -> int:
        """
        Choose an action based on the given state and greedy rate.

        With probability ``greedy_rate`` the action with the highest Q-value at ``state``
        is returned (ties are broken uniformly at random); otherwise a uniformly random
        action is returned. ``greedy_rate`` defaults to ``self.greedy_rate``.
        """
        if greedy_rate is None:
            greedy_rate = self.greedy_rate

        # random choice (exploration)
        if random.random() > greedy_rate:
            return random.randint(0, self.n_action - 1)
        # greedy choice (exploitation)
        q_values = self.q_matrix[:, state[0], state[1]] # Q-values of all actions at this cell
        max_value = torch.max(q_values)
        max_indices = torch.where(q_values == max_value)[0].tolist() # every action that attains the max
        return random.choice(max_indices)

    def update(self, state:torch.Tensor, action_id: int, next_state: torch.Tensor, discount = None, lr = None):
        """
        Update the q matrix based on the given state, action, next state, discount and learning rate.

        The entry ``q_matrix[action_id, state]`` is moved by ``lr`` towards the target
        value. Entering a wall or leaving the grid is terminal, so the target is just
        the reward; otherwise it is ``reward + discount * max_a Q(next_state, a)``.
        ``discount`` and ``lr`` default to the instance attributes.
        """
        if discount is None:
            discount = self.discount
        if lr is None:
            lr = self.lr

        q_old = self.q_matrix[action_id, state[0], state[1]]

        if not self._in_bounds(next_state):
            # Off the grid: there is no cell to look up, penalise like a wall.
            q_new = self.wall_reward
        else:
            reward = self.reward_matrix[next_state[0], next_state[1]]
            if self.maze_matrix[next_state[0], next_state[1]]:
                q_new = reward
            else:
                q_next_max = torch.max(self.q_matrix[:, next_state[0], next_state[1]])
                q_new = reward + discount * q_next_max

        self.q_matrix[action_id, state[0], state[1]] += lr * (q_new - q_old)

    def go(self, start_state: torch.Tensor, valid: bool = False, debug: bool = False) -> list:
        """
        Start from the start state and go to the target state.

        Runs one episode of at most ``max_step`` moves. In ``valid`` mode the greedy
        action is always taken and the Q-matrix is left untouched; otherwise actions
        follow ``greedy_rate`` and every move updates the Q-matrix.

        Returns the list of visited states, beginning with ``start_state``. Every
        attempted cell is recorded, including cells the robot bounced off.
        """
        states = [start_state]
        state = start_state
        while not self._at_target(state) and len(states) <= self.max_step:
            if valid: # validation: always pick the best known action
                action_id = self.choose_action_id(state, 1)
            else:   # training: explore according to the greedy rate
                action_id = self.choose_action_id(state)
            next_state = state + self.actions[action_id]
            if not valid:
                self.update(state, action_id, next_state)

            if self._is_blocked(next_state):
                # The robot stays where it is.
                if debug:
                    print(f'{next_state}, hit the wall or restricted')
            else:
                state = next_state
            states.append(next_state)
        return states

    def train(self, n_epoch: int = 100, valid: bool = False, debug: bool = False) -> list:
        """
        Train the robot for n_epoch times.

        Runs ``n_epoch`` episodes from ``self.start_state`` and returns the list of
        paths, one per episode. ``valid`` and ``debug`` are forwarded to ``go``.
        """
        paths = []
        for _ in tqdm.trange(n_epoch):
            paths.append(self.go(self.start_state, valid, debug))
        return paths

    def eval(self, debug: bool = False) -> list:
        """
        Evaluate the robot.

        Runs a single greedy episode from ``self.start_state`` without updating the
        Q-matrix and returns its path.
        """
        return self.go(self.start_state, True, debug)


## ConvAutoEncoder

In [4]:

class BiasLayer(nn.Module):
    '''
    Learnable additive bias with an arbitrary shape.

    The bias starts at zero and is added element-wise to the input, so there is one
    trainable offset per node / filter position.
    '''
    def __init__(self, shape: tuple):
        '''
        Create the zero-initialised bias parameter
        ---
        INPUT
        shape: Shape of the bias tensor
        '''
        super().__init__()
        # Allocated on the notebook-level `device`.
        self.bias = nn.Parameter(torch.zeros(shape, device=device), requires_grad=True)

    def forward(self, x: torch.Tensor)->torch.Tensor:
        '''
        Add the bias to the input
        ---
        INPUT
        x: Input features
        ---
        OUTPUT
        The input shifted by the bias
        '''
        return x + self.bias

def conv2d_output_dims(x: 'tuple[int,int,int]', layer: nn.Conv2d)->'tuple[int,int,int]':
    """
    Compute the output size (depth, height, width) that a Conv2D layer
    produces for an input of size x, without running the layer.
    ---
    INPUT
    Args:
    x: Input size (depth, height, width); only height and width are used
    layer: The Conv2D layer. Numeric padding as well as the string modes
           'same' and 'valid' are supported.
    ---
    OUTPUT:
    Tuple of out-depth/out-height and out-width
    Output shape as given in [Ref]
    Ref:
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    """
    assert isinstance(layer, nn.Conv2d)

    def _as_tuple(value):
        # Accept a bare int as well as the tuples nn.Conv2d normally stores.
        return value if isinstance(value, tuple) else (value,)

    _, in_height, in_width = x
    out_depth = layer.out_channels

    if isinstance(layer.padding, str):
        if layer.padding == 'same':
            # 'same' pads so that the spatial size is preserved.
            return (out_depth, in_height, in_width)
        # 'valid' means no padding at all.
        p = (0,)
    else:
        p = _as_tuple(layer.padding)

    k = _as_tuple(layer.kernel_size)
    d = _as_tuple(layer.dilation)
    s = _as_tuple(layer.stride)

    # Index 0 is the height entry, index -1 the width entry; this also works for
    # 1-tuples, where both refer to the same value.
    out_height = 1 + (in_height + 2 * p[0] - (k[0] - 1) * d[0] - 1) // s[0]
    out_width = 1 + (in_width + 2 * p[-1] - (k[-1] - 1) * d[-1] - 1) // s[-1]
    return (out_depth, out_height, out_width)

class ConvAutoEncoder(nn.Module):
    '''
    A Convolutional AutoEncoder

    Encoder: bias -> conv -> conv -> flatten -> linear (the message of length K).
    Decoder: linear -> unflatten -> deconv -> deconv -> bias.
    '''
    def __init__(self, x_dim: 'tuple[int,int,int]', K: int, nonlinear_ae: bool, nonlinear_std: bool, n_filters: int=10, filter_size: int=2, y_dim: 'tuple[int,int,int]'=None):
        '''
        Build the encoder and decoder layers
        ---
        INPUT
        x_dim: Input dimensions (channels, height, width)
        K: message length/hidden dimension
        nonlinear_ae: apply ReLU activations inside the autoencoder?
        nonlinear_std: flag for nonlinear activations in the student; only stored here
        n_filters: Number of filters (output channels) of the hidden conv layers
        filter_size: Kernel size
        y_dim: Output dimensions (channels, height, width); defaults to x_dim
        '''
        super().__init__()
        if y_dim is None:
            y_dim = x_dim
        in_channels, _, _ = x_dim
        out_channels, _, _ = y_dim
        pad = filter_size - 1

        # --- encoder -----------------------------------------------------------
        # Bias added to the raw input
        self.enc_bias = BiasLayer(x_dim)
        # Two stacked convolutions with n_filters output channels each; their
        # output shapes are tracked to size the dense bottleneck.
        self.enc_conv_1 = nn.Conv2d(in_channels, n_filters, filter_size, padding=pad, device=device)
        shape_after_conv_1 = conv2d_output_dims(x_dim, self.enc_conv_1)
        self.enc_conv_2 = nn.Conv2d(n_filters, n_filters, filter_size, padding=pad, device=device)
        shape_after_conv_2 = conv2d_output_dims(shape_after_conv_1, self.enc_conv_2)
        # The bottleneck is dense, so the feature maps are flattened first
        self.enc_flatten = nn.Flatten()
        depth, rows, cols = shape_after_conv_2
        n_flat = depth * rows * cols
        self.enc_lin = nn.Linear(n_flat, K, device=device)

        # --- decoder -----------------------------------------------------------
        # Dense layer back to the flattened feature-map size
        self.dec_lin = nn.Linear(K, n_flat, device=device)
        # Restore the (depth, height, width) layout
        self.dec_unflatten = nn.Unflatten(dim=-1, unflattened_size=shape_after_conv_2)
        # Two transposed convolutions ("deconvolutions")
        self.dec_deconv_1 = nn.ConvTranspose2d(n_filters, n_filters, filter_size, padding=pad, device=device)
        self.dec_deconv_2 = nn.ConvTranspose2d(n_filters, out_channels, filter_size, padding=pad, device=device)
        # Bias added to the reconstruction
        self.dec_bias = BiasLayer(y_dim)

        # Flags marking the nonlinearities
        self.nonlinear_ae=nonlinear_ae
        self.nonlinear_std=nonlinear_std

    def _activate(self, t: torch.Tensor)->torch.Tensor:
        '''
        Apply ReLU when the autoencoder is nonlinear, otherwise pass t through unchanged
        '''
        return F.relu(t) if self.nonlinear_ae else t

    def encode(self, q: torch.Tensor)->torch.Tensor:
        '''
        First half of the autoencoder: encode the Q-matrix into the message
        ---
        INPUT
        q: The Q-matrix
        ---
        OUTPUT
        m: The message, i.e. the encoded Q-matrix
        '''
        hidden = self.enc_bias(q)
        hidden = self._activate(self.enc_conv_1(hidden))
        hidden = self._activate(self.enc_conv_2(hidden))
        return self.enc_lin(self.enc_flatten(hidden))


    def decode_ae(self, m: torch.Tensor)->torch.Tensor:
        '''
        Second half of the autoencoder: reconstruct a Q-matrix from the message
        ---
        INPUT
        m: The message
        ---
        OUTPUT
        q: The decoded Q-matrix
        '''
        hidden = self._activate(self.dec_lin(m))
        hidden = self.dec_unflatten(hidden)
        hidden = self._activate(self.dec_deconv_1(hidden))
        # The last deconvolution is never followed by an activation
        return self.dec_bias(self.dec_deconv_2(hidden))


    def forward(self, q: torch.Tensor)->'tuple[torch.Tensor, torch.Tensor]':
        '''
        Full pass of the autoencoder, i.e. encoding followed by decoding (no student)
        ---
        INPUT
        q: The input Q-matrices
        OUTPUT
        m: The messages, i.e. the encoded inputs
        q_rec: The Q-matrices reconstructed from the messages
        '''
        message = self.encode(q)
        return message, self.decode_ae(message)



## Train

In [None]:
# load data
# map_location puts every tensor on the device the model will live on, so the file can
# be read on a CPU-only machine even if it was saved from a GPU (and vice versa).
# NOTE: torch.load unpickles the file -- only load 'batches.pt' from a trusted source.
batches = torch.load('batches.pt', map_location=device)
n_batches = len(batches)

In [6]:
# split data into train and test: the first 80% of the batches are used for training,
# the remaining ones for testing
train_test_ratio = 0.8
split_index = int(n_batches*train_test_ratio)
train_batches, test_batches = batches[:split_index], batches[split_index:]

In [7]:
# get dimensions: read the per-sample input/output shapes off the first batch
(maze_tensor, d_tensor, q_matrices_tensor,
 q_matrix_combo_tensor, target_states, start_states) = batches[0]
x_dim = tuple(q_matrices_tensor[0].size()) # (depth, height, width) = (4*2, 9, 9)
y_dim = tuple(q_matrix_combo_tensor[0].size()) # (depth, height, width) = (4, 9, 9)

In [None]:
# train autoencoder
n_epoches = 100

# Fail early with a clear message instead of a NameError after the first epoch.
if len(train_batches) == 0:
    raise ValueError('train_batches is empty; nothing to train on')

autoencoder = ConvAutoEncoder(x_dim, 30, True, True, y_dim=y_dim).to(device)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.001)
criterion = nn.MSELoss()
for epoch in range(n_epoches):
    epoch_loss = 0.0
    for batch in train_batches:
        maze_tensor, d_tensor, q_matrices_tensor, q_matrix_combo_tensor, target_states, start_states = batch
        optimizer.zero_grad()
        m, q_rec = autoencoder(q_matrices_tensor)
        loss = criterion(q_rec, q_matrix_combo_tensor)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Report the mean over all batches of the epoch rather than only the last batch.
    print(f'epoch: {epoch}, mean loss: {epoch_loss / len(train_batches)}')

# torch.save(autoencoder, 'autoencoder.pt')

## Test

In [None]:
# autoencoder = torch.load('autoencoder.pt')

# For every test batch: reconstruct the Q-matrices, then let a greedy robot follow each
# reconstructed Q-matrix from every start cell and compare the shortest resulting path
# with the reference length stored in d_tensor for that target.
autoencoder.eval()
with torch.no_grad():
    for batch in test_batches:
        maze_tensor, d_tensor, q_matrices_tensor, q_matrix_combo_tensor, target_states, start_states = batch
        m, q_rec = autoencoder(q_matrices_tensor)
        loss = criterion(q_rec, q_matrix_combo_tensor)
        min_len_deltas = []

        # Kept under a separate name so that start_states remains the batch tensor.
        start_cells = [tuple(start_state) for start_state in start_states]
        for q_matrix_combo, target_state in zip(q_rec, target_states):
            target_state = tuple(target_state)
            # Number of moves of the greedy path from each start cell.
            lens = [
                len(Robot(start_cell, target_state, maze_tensor, q_matrix=q_matrix_combo).eval()) - 1
                for start_cell in start_cells
            ]
            min_len = min(lens)
            min_len_delta = int(min_len - d_tensor[target_state[0], target_state[1]])
            min_len_deltas.append(min_len_delta)
        print(f'loss: {loss.item()}, min_len_deltas_avg: {mean(min_len_deltas)}')