In [387]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# torch.set_printoptions(profile="full")

In [388]:
class Controller(nn.Module):
    def __init__(self, Imatrix, Rmatrix):
        super(Controller, self).__init__()
        self.bias = torch.ones(Rmatrix.shape[0])
        self.Imatrix = Imatrix
        self.Rmatrix = Rmatrix
        self.register_buffer('prev_x', torch.zeros(Rmatrix.shape[0]))

    def forward(self, x):
        self.prev_x = self.Imatrix @ x + self.Rmatrix @ self.prev_x + self.bias
        return torch.argmax(self.prev_x)

In [389]:
class Optimizer():
    def __init__(self, game, num_interactions, initial_variance):
        # TODO: organize this shit
        self.game = game
        self.env = gym.make(game, obs_type="rgb", frameskip=5)
        self.search_dimensionality = self.env.action_space.n + self.env.action_space.n ** 2
        self.search_u = torch.zeros(self.search_dimensionality) 
        self.search_Sigma = initial_variance * torch.eye(self.search_dimensionality)
        self.search_A = torch.linalg.cholesky(self.search_Sigma)
        self.sigma = torch.pow(torch.abs(torch.linalg.det(self.search_A)), 1 / self.search_dimensionality)
        self.B = self.search_A / self.sigma
        self.normal_distribution = torch.distributions.MultivariateNormal(torch.zeros(self.search_dimensionality), torch.eye(self.search_dimensionality))
        self.population_size = int(1.5 * (4 + int(3 * np.log(self.search_dimensionality))))
        self.num_interactions = num_interactions
        self.lr_u = 0.5
        self.lr_s = (3 / 5) * ((3 + np.log(self.search_dimensionality)) / self.search_dimensionality * np.sqrt(self.search_dimensionality))
        self.lr_b = self.lr_s
        self.env.close()

    def generate_controller_weights(self):
        controller_weights = []
        for k in range(self.population_size):
            sk = self.normal_distribution.sample()
            zk = self.search_u + self.sigma * self.B.T @ sk
            controller_weights.append((sk, zk))
        return controller_weights

    def initialise_controllers(self, controller_weights):
        controllers = []
        for weighttuple in controller_weights:
            weight = weighttuple[1]
            Rmatrix = torch.reshape(weight[:self.env.action_space.n ** 2], (self.env.action_space.n, self.env.action_space.n))
            Imatrix = torch.reshape(weight[self.env.action_space.n ** 2:], (self.env.action_space.n, -1))
            controllers.append(Controller(Imatrix, Rmatrix))
        return controllers

    def run_episodes(self, controllers, compressor):
        """
        TODO: split the function into smaller functions
        """
        fitness = []
        memory = torch.zeros((1, 80,70))
        for controller in controllers:
            env = gym.make(self.game, obs_type="rgb", frameskip=5)
            observation, info = env.reset()
            cummulative_reward = 0

            for i in range(self.num_interactions):
                comp_observation = compressor.downsize_image(torch.tensor(observation))
                encoded_observation = compressor.encode_observation(comp_observation)
                action = controller(encoded_observation)
                observation, reward, terminated, truncated, info = env.step(action)
                cummulative_reward += reward

                if not torch.any(memory):
                    memory[0] = comp_observation
                else:
                    memory = torch.cat((memory, comp_observation), 0)

                if terminated or truncated:
                    break
            env.close()
            fitness.append(cummulative_reward)
                    
        return fitness, memory
    
    def update_search_distribution(self, controller_weights, fitness):
        """
        TODO: split the function into smaller functions
        """
        sorted_weights = sorted(zip(fitness, controller_weights), key=lambda item: (item[0], item[1][0][0].item()), reverse=False)
        sorted_weights = [sk for _, (sk, _) in sorted_weights]
        divisor = sum(np.max((0.0, np.log(self.population_size / 2 + 1) - np.log(j))) for j in range(1, self.population_size + 1))
        utilities = [np.max((0.0, np.log(self.population_size / 2 + 1) - np.log(i))) / divisor for i in range(1, self.population_size + 1)]

        # Gradients calculation
        grad_d = torch.zeros(self.search_dimensionality)
        grad_m = torch.zeros(self.search_dimensionality, self.search_dimensionality)

        for i, sk in enumerate(sorted_weights):
            uk = utilities[i]
            grad_d += uk * sk
            grad_m += uk * (sk.unsqueeze(1) @ sk.unsqueeze(0) - torch.eye(self.search_dimensionality))
        
        grad_s = torch.trace(grad_m) / self.search_dimensionality
        grad_b = grad_m - grad_s * torch.eye(self.search_dimensionality)

        # Parameters update
        self.search_u += self.lr_u * self.sigma * self.B @ grad_d
        self.sigma = self.sigma * torch.exp(self.lr_s / 2 * grad_s)
        self.B = self.B @ torch.linalg.matrix_exp(self.lr_b / 2 * grad_b)
        self.search_A = self.sigma * self.B
        self.search_Sigma = self.search_A.T @ self.search_A

    def rescale_search_distribution(self, compressor, variance_new_weights=1): # worked semi-fine on 1
        """
        
        """
        no_RNNinputs = (self.search_dimensionality - self.env.action_space.n ** 2) / self.env.action_space.n
        additional_inputs = compressor.dictionary.shape[0] - no_RNNinputs
        if additional_inputs > 0:
            # Update mu
            self.search_u = torch.cat((self.search_u, torch.zeros(int(additional_inputs * self.env.action_space.n))))

            # Extend Sigma with new rows and columns
            new_dimensionality = self.search_u.shape[0]
            extended_Sigma = torch.eye(new_dimensionality) * variance_new_weights
            extended_Sigma[:self.search_Sigma.shape[0], :self.search_Sigma.shape[0]] = self.search_Sigma
            self.search_Sigma = extended_Sigma
            self.search_Sigma += 1e-4 * torch.eye(new_dimensionality)
            
            # Update consequent parameters
            # TODO: make seperate function that is called both in init and here
            self.search_dimensionality = new_dimensionality
            self.search_A = torch.linalg.cholesky(self.search_Sigma)
            self.sigma = torch.pow(torch.abs(torch.linalg.det(self.search_A)), 1 / self.search_dimensionality)
            self.B = self.search_A / self.sigma
            self.normal_distribution = torch.distributions.MultivariateNormal(torch.zeros(self.search_dimensionality), torch.eye(self.search_dimensionality))

            # Update learning parameters
            self.population_size = 4 + int(3 * np.log(self.search_dimensionality))
            self.lr_s = (3 / 5) * ((3 + np.log(self.search_dimensionality)) / self.search_dimensionality * np.sqrt(self.search_dimensionality))
            self.lr_b = self.lr_s
            

In [390]:
class Compressor():
    def __init__(self, epsilon, omega, residual_threshold):
        self.input_pixels = 70 * 80                                 # This could be done cleaner
        self.dictionary = torch.zeros((1, self.input_pixels))
        self.epsilon = epsilon
        self.omega = omega
        self.threshold = self.input_pixels * residual_threshold

    def downsize_image(self, observation):
        comp_observation = observation.permute(2, 0, 1)
        comp_observation = comp_observation.float().mean(0).unsqueeze(0).unsqueeze(0)
        comp_observation = F.interpolate(comp_observation, (80, 70), mode='bilinear', align_corners=False)
        return comp_observation.squeeze(0)

    def encode_observation(self, observation):
        """
        DRSC
        """
        P = torch.flatten(observation) / 255
        o = torch.zeros(self.dictionary.shape[0])
        w = 0

        while torch.sum(P) / 5600 > self.epsilon and w < self.omega:
            S = torch.norm(self.dictionary - P.unsqueeze(0), dim=1)
            msc = torch.argmin(S)
            
            o[msc] = 1
            w += 1
            P = P - self.dictionary[msc]
            F.relu(P, inplace=True)
        return o
    
    def update_dictionary(self, training_set):
        """
        IDVQ
        """
        for image in training_set:
            P = torch.flatten(image) / 255

            if not torch.any(self.dictionary):
                self.dictionary = P.unsqueeze(0)
            else:
                o = self.encode_observation(image)
                P_hat = self.dictionary.T @ o
                R = P - P_hat
                torch.nn.functional.relu(R, inplace=True)
                if torch.sum(R) > self.threshold:
                    self.dictionary = torch.cat((self.dictionary, R.unsqueeze(0)), dim=0)

In [391]:
class Trainer():
    def __init__(self, game, num_interactions,initial_variance, epsilon, omega, residual_threshold):
        # TODO: template system for different games pre-settings
        self.optimizer = Optimizer(game, num_interactions, initial_variance)
        self.compressor = Compressor(epsilon, omega, residual_threshold)

    def train_population(self):
        print("Population size is: ", self.optimizer.population_size)
        controller_weights = self.optimizer.generate_controller_weights()
        controllers = self.optimizer.initialise_controllers(controller_weights)
        fitness, memory = self.optimizer.run_episodes(controllers, self.compressor)
        print("The average reward is: ", sum(fitness) / len(fitness))
        self.optimizer.update_search_distribution(controller_weights, fitness)
        self.compressor.update_dictionary(memory)
        print("Dictionary size is: ", self.compressor.dictionary.shape[0])
        self.optimizer.rescale_search_distribution(self.compressor)
        print("")

In [392]:
trainer = Trainer(game="ALE/Qbert-v5",
                  num_interactions=200,
                  initial_variance=1,
                  epsilon=0.005,
                  omega=3,
                  residual_threshold=0.00499
                 )
for i in range(60):
    print(f"--------- RUN {i + 1} ---------")
    trainer.train_population()

--------- RUN 1 ---------
Population size is:  22
The average reward is:  73.86363636363636
Dictionary size is:  4

--------- RUN 2 ---------
Population size is:  16
The average reward is:  3.125
Dictionary size is:  4

--------- RUN 3 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 4 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 5 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 6 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 7 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 8 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 9 ---------
Population size is:  16
The average reward is:  0.0
Dictionary size is:  4

--------- RUN 10 ---------
Population size is:  16
The average rew