<a href="https://colab.research.google.com/github/Rfrowein/armAIF/blob/main/ArmAIF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import sys
import numpy as np
import random
import math

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle, Circle
from matplotlib.transforms import Bbox

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import imageio

from statistics import mean
from IPython.display import clear_output

from tqdm import tqdm

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim

from torchvision.utils import save_image

import cv2
from skimage.util import random_noise
from PIL import Image

# Create the folders /data (rendered images, state frames) and /networks (saved
# models) in the current working directory. exist_ok=True makes this idempotent
# and avoids the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs(os.getcwd() + '/data', exist_ok=True)
os.makedirs(os.getcwd() + '/networks', exist_ok=True)

In [4]:
def create_image_v2(position, name):
    """Render the arm at horizontal `position` and save it as `name + '.png'`.

    Args:
        position: Location of the arm in the environment, expected in [-1, 1].
        name: Output path/name without extension ('.png' is appended).

    The figure is closed after saving, so repeated calls (e.g. while generating
    a data set) do not accumulate open matplotlib figures.
    """
    image = mpimg.imread('arm.png')

    # frameon=False removes the outer frame; axis('off') would also remove the background
    fig = plt.figure(figsize=(2, 1.5), frameon=False)

    # Environment axes, y: [0, 1.5], x: [0, 2.7] (the [-1, 1] range shifted to
    # [0, 2] plus 0.7 of arm width so the arm never leaves the environment)
    env_ax = fig.add_axes([0, 0, 2.7, 1.5])

    # Arm axes inside the environment; +1 maps [-1, 1] onto [0, 2]
    arm_ax = fig.add_axes([position + 1, 0, 0.7, 1])
    arm_ax.axis('off')
    arm_ax.imshow(image, aspect='auto', zorder=-1)

    # Save exactly the environment axes: the image feeds the neural network, so
    # padding (e.g. from bbox_inches='tight') must not appear around it.
    extent = env_ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(name + '.png', bbox_inches=extent)
    plt.close(fig)



In [13]:
def create_tensor_v2(name, location = None, noise = True):
    """Render the arm, read it back as a 40x40 grayscale image and return it as a tensor.

    Args:
        name: Name/path of the image to write, without the '.png' extension.
        location: Horizontal arm position in [-1, 1]. When None, a random
            location rounded to 2 decimals is drawn.
        noise: If True, add Gaussian noise (skimage `random_noise`, mode='gaussian').

    Returns:
        (img_tensor, location): the 40x40 image on a 0-255 pixel scale (uint8
        without noise, float64 with noise) and the location that was rendered.
        Callers divide by 255 to obtain values in [0, 1].

    Raises:
        FileNotFoundError: if the rendered image cannot be read back.
    """
    if location is None:
        location = round(random.uniform(-1, 1), 2)  # random location in [-1, 1], 2 decimals

    create_image_v2(location, name)

    path = name + '.png'
    img = cv2.imread(path, 0)  # 0 -> read as single-channel grayscale
    if img is None:  # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError('Could not read rendered image {!r}'.format(path))
    img = cv2.resize(img, (40, 40))
    if noise:
        # random_noise converts the uint8 input to floats in [0, 1]; scale back to
        # 0-255 so both branches share the same scale and callers' /255 gives [0, 1].
        img = random_noise(img, mode='gaussian') * 255

    return torch.from_numpy(img), location



In [6]:
def create_data_v2(nr_data=100, noise=True):
    """Generate `nr_data` arm images at random locations, stored under <cwd>/data.

    Side effect: switches the global matplotlib style to 'dark_background', which
    affects every plot made afterwards.

    Args:
        nr_data: number of samples to generate.
        noise: forwarded to create_tensor_v2 (add Gaussian noise or not).

    Returns:
        (data_X, data_I): a list of 1-element FloatTensor locations and a list of
        image tensors divided by 255 (the normalisation assumes create_tensor_v2
        returns 0-255 pixel values).
    """
    plt.style.use('dark_background')
    locations, images = [], []

    for idx in range(nr_data):
        # No location is passed, so create_tensor_v2 picks a random one
        img, loc = create_tensor_v2(f"{os.getcwd()}/data/true_image{idx}", noise=noise)
        locations.append(torch.FloatTensor([loc]))
        images.append(img / 255)
        plt.close('all')  # avoid piling up open figures while generating

    return (locations, images)


In [6]:
class Net(nn.Module):
    """Decoder generating a 40 x 40 image from a single horizontal arm location.

    Four transposed-convolution layers (two with rectangular kernels) map an input
    of shape (batch, 1, 1, 1) to an image of shape (batch, 1, 40, 40). The final
    Tanh bounds the output to [-1, 1]. The instance also records its training
    history (epoch/test losses, learning rates, batch sizes). That history is kept
    when the whole module is saved with torch.save.
    """

    def __init__(self):
        # nn.Module must be initialised before attributes are assigned to it.
        super(Net, self).__init__()
        self.epoch_losses = []
        self.test_losses = []
        self.saved_lr = []
        self.saved_batch_size = []
        # Default used by eval_model() before train_model() has set a batch size.
        self.batch_size = 16

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 16, (1, 10)),     # (B,1,1,1)   -> (B,16,1,10)
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, (10, 1)),     # (B,16,1,10) -> (B,8,10,10)
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 4, 9, stride=2),  # (B,8,10,10) -> (B,4,27,27)
            nn.ReLU(True),
            nn.ConvTranspose2d(4, 1, 14),           # (B,4,27,27) -> (B,1,40,40)
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode locations of shape (B, 1, 1, 1) into images of shape (B, 1, 40, 40)."""
        return self.decoder(x)

    def eval_model(self, test_data=None):
        """Return the MSE of the model on `test_data`, evaluated as a single batch.

        Args:
            test_data: tuple(list[location tensor], list[image tensor]). When None,
                `self.batch_size` fresh samples are generated with create_data_v2.

        Returns:
            float: mean squared error between generated and original images.
        """
        self.eval()  # previously switched the global `model`, not this instance
        if test_data is None:
            test_data = create_data_v2(self.batch_size)
        locations, images = test_data[0], test_data[1]

        with torch.no_grad():
            test_input = torch.stack(locations).view(len(locations), 1, 1, 1)
            test_output = self.decoder(test_input)
            test_original = torch.stack(images).view(len(images), 1, 40, 40)
            return nn.MSELoss()(test_output, test_original.float()).item()

    def train_model(self, data=None, epochs=20, plot=True, batch_size=16, hybrid=True):
        """Train the decoder and evaluate it on a held-out set after every epoch.

        The held-out set is generated without noise and sized so that the training
        data is ~70% of the total. The learning rate is reduced on plateaus of the
        training loss (ReduceLROnPlateau, factor 0.8).

        Args:
            data: training data as tuple(list[location], list[image]); 500 random
                samples are generated when None.
            epochs: number of passes over the training data.
            plot: if True, show progress every 10 epochs. A final plot is always shown.
            batch_size: samples per optimizer step. Leftover samples that do not
                fill a whole batch are skipped each epoch.
            hybrid: if True, grow the batch size by 1 every 10 epochs (starting at
                epoch 0), capped at the number of training samples.

        Returns:
            (epoch_losses, test_losses): the full per-epoch loss histories.

        Raises:
            ValueError: if batch_size is not between 1 and the number of training samples.
        """
        self.batch_size = batch_size
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=.8)
        loss_func = nn.MSELoss()

        if data is None:
            data = create_data_v2(500)
        n_train = len(data[0])
        if not 0 < batch_size <= n_train:
            raise ValueError(
                'batch_size must be between 1 and the number of training samples '
                '({}), got {}'.format(n_train, batch_size))

        # Held-out data ~30% of the total. len(data) is the tuple length (2), so the
        # sample count comes from len(data[0]). Keep at least one evaluation sample.
        data_eval = create_data_v2(max(1, math.floor(n_train / 0.7 * 0.3)), noise=False)

        for epoch in range(epochs):
            for param_group in optimizer.param_groups:
                self.saved_lr.append(param_group['lr'])

            if hybrid and epoch % 10 == 0:
                self.batch_size = min(self.batch_size + 1, n_train)
            self.saved_batch_size.append(self.batch_size)

            # Shuffle both sets each epoch, keeping every location paired with its image
            comb_data = list(zip(data[0], data[1]))
            comb_data_eval = list(zip(data_eval[0], data_eval[1]))
            random.shuffle(comb_data)
            random.shuffle(comb_data_eval)
            X, I = zip(*comb_data)
            X_eval, I_eval = zip(*comb_data_eval)  # previously taken from the training data

            self.train()
            bs = self.batch_size
            batch_losses = []
            for start in range(0, (len(X) // bs) * bs, bs):  # full batches only
                inputs = torch.stack(X[start:start + bs]).view(bs, 1, 1, 1)
                original = torch.stack(I[start:start + bs]).view(bs, 1, 40, 40)

                optimizer.zero_grad()
                output = self(inputs)
                loss = loss_func(output, original.float())
                loss.backward()
                optimizer.step()

                batch_losses.append(loss.item())
                output = output.detach()  # kept only for plotting

            test_loss = self.eval_model((X_eval, I_eval))
            self.test_losses.append(test_loss)
            epoch_loss = mean(batch_losses)
            self.epoch_losses.append(epoch_loss)
            scheduler.step(epoch_loss)

            if (plot and epoch % 10 == 0 and epoch != 0) or epoch == epochs - 1:
                self._plot_progress(epoch, epochs, epoch_loss, test_loss, output, original)
        return self.epoch_losses, self.test_losses

    def _plot_progress(self, epoch, epochs, epoch_loss, test_loss, output, original):
        """Print the current losses, plot the loss history and show one predicted/original pair."""
        clear_output(wait=True)
        print('epoch [{}/{}]\nepoch loss: {}\ntest loss: {}\n'.format(epoch + 1, epochs, epoch_loss, test_loss))

        # The loss history excludes the first 5 epochs. It is skipped until there is
        # something to show (np.linspace with a negative count used to raise here).
        if len(self.epoch_losses) > 5:
            print('Loss plot (excluding first 5 epochs)')
            plt.plot(np.arange(5, len(self.epoch_losses)), self.epoch_losses[5:], 'w', label='train')
            plt.plot(np.arange(5, len(self.test_losses)), self.test_losses[5:], 'r--', label='test', alpha=0.5)
            plt.legend(loc='upper right')
            plt.xlabel('epoch')
            plt.ylabel('MSE')
            plt.xlim(left=5)
            plt.show()
            plt.clf()

        # First sample of the last batch: network output vs. original image
        print('\nVisualization\n epoch: {}\n batch_size: {}'.format(epoch, self.batch_size))
        fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(3, 5))
        ax[0].imshow(output[0][0].view(40, 40), origin='upper', cmap='gray')
        ax[0].set_title('predicted')
        ax[0].axis('off')
        ax[1].imshow(original[0][0].view(40, 40), origin='upper', cmap='gray')
        ax[1].set_title('original')
        ax[1].axis('off')
        plt.show()
        plt.clf()



In [7]:
def visualize_learning(model, reduce_start=0, reduce_end=0):
    """Plot the train and test MSE history recorded on a trained network.

    Args:
        model: trained network exposing `epoch_losses` and `test_losses` lists
            (as filled by Net.train_model).
        reduce_start: index of the first epoch to show.
        reduce_end: number of epochs to leave out at the end.

    Output: a plot of the train and test losses.
    """
    # Clamp at 0 so a too-large reduce_end yields an empty window rather than a
    # negative slice bound that would count from the end of the list.
    stop = max(len(model.epoch_losses) - reduce_end, 0)
    test_stop = max(len(model.test_losses) - reduce_end, 0)

    # One x value per epoch index. np.linspace(start, stop, stop - start) spaced the
    # points slightly more than one epoch apart and raised for negative counts.
    plt.plot(np.arange(reduce_start, stop), model.epoch_losses[reduce_start:stop], 'b', label='train')
    plt.plot(np.arange(reduce_start, test_stop), model.test_losses[reduce_start:test_stop], 'r', label='test')

    plt.legend(loc='upper right')
    plt.xlabel('epoch')
    plt.ylabel('MSE loss')
    plt.xlim(left=reduce_start)
    plt.title('epoch losses: ' + str(reduce_start) + ' - ' + str(len(model.epoch_losses) - reduce_end))
    plt.show()

def visualize_params(model, lr_ylim=None):
    """Plot the per-epoch learning rate and batch size recorded during training.

    Args:
        model: trained network exposing `saved_lr` and `saved_batch_size` lists
            (as filled by Net.train_model).
        lr_ylim: optional (bottom, top) limits for the learning-rate axis. By
            default the bottom is 0 and the top follows the data. The former fixed
            top of 0.00012 hid the curve until the rate had decayed from
            train_model's starting value of 0.001.

    Output: a learning-rate plot and a batch-size plot side by side.
    """
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    # One x value per recorded epoch; each series gets its own x so differing
    # lengths (e.g. several optimizer param groups) cannot break the plot.
    ax1.plot(np.arange(len(model.saved_lr)), model.saved_lr)
    ax1.set_title('learning rate')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('lr')
    if lr_ylim is None:
        ax1.set_ylim(bottom=0)
    else:
        ax1.set_ylim(lr_ylim)

    ax2.plot(np.arange(len(model.saved_batch_size)), model.saved_batch_size)
    ax2.set_title('batch size')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('size')


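Only run the next two or three cells if you want to train your own neural network. The first generates the training data, the second trains and saves the network, and the optional third shows its training results.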

In [None]:
# Info: Generate the training data for the neural network (a large data set is
# recommended). Only needs to run once; the images are written to <cwd>/data.
train_data_large = create_data_v2(nr_data=1600, noise=True)

# Smaller alternatives:
# train_data_medium = create_data_v2(nr_data=800, noise=True)
# train_data_small = create_data_v2(nr_data=200, noise=True)

In [None]:
# Info: Create, train and save the neural network.
#   - There is a small chance the network gets stuck during training.
model = Net()

# Hybrid mode grows the batch size during training, so start from a small one.
train_loss, test_loss = model.train_model(
    data=train_data_large,
    epochs=800,
    batch_size=1,
    hybrid=True,
)

# Save the whole module (architecture, weights and recorded training history).
torch.save(model, os.getcwd() + '/networks/trained_network_DataLarge_Hybrid.pth')


In [None]:
'''
Info: Show additional/more specific results
'''
model = torch.load(os.getcwd()+'/networks/trained_network_DataLarge_Hybrid.pth')  
visualize_learning(model,50,0)
visualize_params(model)


In [8]:
def save_state_v4(armAI, itt, visual, proprioception, goal):
    """Draw the current armAI state and save it as '<cwd>/data/state: <itt>.png'.

    Layers, drawn in this order:
    - the environment, showing the decoder's image of the mental location;
    - the belief box (blue);
    - optionally the visual arm (red), the proprioceptive marker (yellow circle)
      and the goal arm (green).

    Each marker's location is clipped to [-1, 1] before placement.

    Args:
        armAI: the ArmAI instance providing locations and the decoder network.
        itt: current iteration/state number (used in the title and file name).
        visual: draw the sensory visual state.
        proprioception: draw the sensory proprioceptive state.
        goal: draw the goal state.

    Returns:
        The matplotlib Figure. It is left open; the caller is responsible for closing it.
    """
    def outline(ax, colour):
        # Hide the ticks and colour all four spines of a marker axes
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        for side in ('bottom', 'top', 'right', 'left'):
            ax.spines[side].set_color(colour)

    def marker_axes(fig, location, **kwargs):
        # Clip to stay in bounds; +1 maps [-1, 1] onto the [0, 2] environment, 0.7 = arm width
        return fig.add_axes([np.clip(location, -1, 1) + 1, 0, 0.7, 1], **kwargs)

    handles = []
    arm_picture = mpimg.imread('arm.png')

    fig = plt.figure(figsize=(2, 1.5), frameon=False)

    # Environment: [-1, 1] plus 0.7 arm width, so arms at -1 and 1 stay inside
    env = fig.add_axes([0, 0, 2.7, 1.5], alpha=0.5, facecolor='white')
    env.get_yaxis().set_visible(False)
    env.get_xaxis().set_visible(False)

    # Mental/belief state (blue)
    belief = marker_axes(fig, armAI.mental_x, facecolor='b')
    outline(belief, 'b')
    belief.patch.set_alpha(0.5)
    handles.append(mpatches.Patch(color='b', label='Mental | ' + str(round(armAI.mental_x.item(), 3))))

    # The environment shows the image the network generates for the mental location
    with torch.no_grad():
        generated = armAI.network.decoder(armAI.mental_x.detach().view(-1, 1, 1, 1))
    env.imshow(generated.view(40, 40), cmap='gray', aspect='auto')

    # Sensory visual arm (red)
    if visual:
        vis = marker_axes(fig, armAI.visual_x, facecolor='r')
        vis.imshow(arm_picture, aspect='auto', alpha=0.8)
        outline(vis, 'red')
        vis.patch.set_alpha(0.5)
        handles.append(mpatches.Patch(color='red', label='Visual | ' + str(round(armAI.visual_x.item(), 3))))

    # Sensory proprioception marker (yellow circle)
    if proprioception:
        prop = marker_axes(fig, armAI.proprioception_loc)
        prop.add_patch(Circle((0.35, 0.5), 0.1, color='yellow'))
        prop.patch.set_alpha(0.5)
        prop.axis('off')
        handles.append(mpatches.Patch(color='yellow', label='Prop | ' + str(round(armAI.proprioception_loc.item(), 3))))

    # Goal (green)
    if goal:
        attr = marker_axes(fig, armAI.goal_x, facecolor='g')
        outline(attr, 'g')
        attr.imshow(arm_picture, aspect='auto', alpha=0.5)
        attr.patch.set_alpha(0.2)
        handles.append(mpatches.Patch(color='g', label='Goal | ' + str(round(armAI.goal_x.item(), 3))))

    env.text(.5, .9, 'State: ' + str(itt), horizontalalignment='center', transform=env.transAxes, color='white')
    env.legend(handles=handles, prop={'size': 10})

    plt.savefig(os.getcwd() + '/data/state: ' + str(itt) + '.png', bbox_inches='tight')
    return fig


In [9]:
class AI_helper():
    """Helper routines for ArmAI: network loading, state checks, visual updates, noise and GIF export.

    Holds a back-reference to the ArmAI instance it serves (`self.armAI`).
    """

    _NETWORK_FILE = '/networks/trained_network_DataLarge_Hybrid.pth'

    def __init__(self, armAI):
        self.armAI = armAI

    @staticmethod
    def _load_network(path):
        """Load a network saved whole with torch.save(model, path)."""
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses pickled
            # nn.Module objects. This unpickles arbitrary objects: only load files you created.
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older PyTorch versions do not accept the weights_only keyword.
            return torch.load(path)

    @staticmethod
    def _numel(value):
        """Number of elements of a tensor; 0 for anything that is not a tensor (e.g. unset defaults)."""
        return value.numel() if torch.is_tensor(value) else 0

    def check_network(self):
        """Return armAI.network, or load 'trained_network_DataLarge_Hybrid.pth' from <cwd>/networks if it is None.

        Raises:
            FileNotFoundError: no network was given and the saved file does not exist.
        """
        if self.armAI.network is not None:
            return self.armAI.network
        print('No neural network was given. Searching in folder networks for \'trained_network_DataLarge_Hybrid.pth\'')
        path = os.getcwd() + self._NETWORK_FILE
        if not os.path.exists(path):
            print('No network found, please run the \'Train, Test and Save model\' section ')
            raise FileNotFoundError(path)
        return self._load_network(path)

    def check_initialization(self, goal=False):
        """Return True if the visual image, the mental location and (if required) the goal image are set.

        Args:
            goal: also require a goal image (set via ArmAI.set_goal). A truthy
                `induce_movement` attribute on armAI, if present, has the same effect.
        """
        armAI = self.armAI
        visual_img = getattr(armAI, 'visual_img', None)
        mental_x = getattr(armAI, 'mental_x', None)
        goal_vis = getattr(armAI, 'goal_vis', None)
        need_goal = goal or getattr(armAI, 'induce_movement', False)

        visual_missing = self._numel(visual_img) == 0
        # ArmAI initialises mental_x to 0 and needs a tensor (it calls .view on it)
        mental_missing = not torch.is_tensor(mental_x)
        goal_missing = need_goal and self._numel(goal_vis) == 0

        if visual_missing or mental_missing or goal_missing:
            print('\nINITIALIZATION INCOMPLETE: check set_mental(), set_visual() and/or set_goal()')
            print('vis: {}\nmental: {}\nattr: {}'.format(self._numel(visual_img), mental_x, self._numel(goal_vis)))
            return False
        return True

    def update_visual_state(self, new_location):
        """Re-render the visual arm at `new_location` (no noise) and store it on armAI via set_visual."""
        self.armAI.visual_prop = new_location
        img, loc = create_tensor_v2('visual_state', new_location, noise = False)
        self.armAI.set_visual(torch.FloatTensor([loc]), img/255)
        plt.clf()

    def add_noise(self, value, noise):
        """Return `value` plus one Gaussian sample with mean 0 and standard deviation `noise`."""
        return value + np.random.normal(0, noise)

    def save_gif(self, name, action):
        """Combine the saved state frames and the mu_dot (and a_dot) traces into `name`.gif.

        Frame i shows '<cwd>/data/state: i.png' above the traces up to iteration i.
        Side effect: switches the global matplotlib style to 'classic'.
        """
        mu_data = self.armAI.mu_dot_data
        if not mu_data:
            return  # nothing was recorded, so there is nothing to animate

        plt.style.use('classic')
        frames = []
        for i in range(len(mu_data)):
            fig, ax = plt.subplots(nrows=2, figsize = (10,9))
            ax[1].plot(range(i), mu_data[0:i], 'r', label='mu_dot | ' + str(round(mu_data[i].item(),3)))
            if action:
                ax[1].plot(range(i), self.armAI.a_dot_data[0:i], 'g--', label = 'a_dot | ' + str(round(self.armAI.a_dot_data[i].item(),3)))
            ax[1].set_xlabel('itteration')
            ax[1].set_title('mu')
            ax[1].legend(loc='lower center', fontsize='x-large')
            ax[1].plot(range(i), np.zeros(i), 'b--')

            # Close the file handle; the pixels are copied into an array first
            with Image.open(os.getcwd()+'/data/state: '+ str(i) + '.png') as state:
                ax[0].imshow(np.asarray(state))
            ax[0].set_xlabel('location')
            ax[0].set_title('Environment')
            ax[0].axis('off')

            fig.canvas.draw()
            # buffer_rgba replaces tostring_rgb, which was removed in Matplotlib 3.10;
            # drop alpha and copy so the frame outlives the closed figure.
            frame = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            plt.close(fig)
            frames.append(frame)

        gif_path = str(name) + '.gif'
        try:
            imageio.mimsave(gif_path, frames, fps=4)
        except TypeError:
            # Newer imageio (Pillow GIF plugin) rejects `fps`; duration is ms per frame.
            imageio.mimsave(gif_path, frames, duration=1000 / 4)

          

      
    

In [41]:
'''
Info: armAI, active inference | Using 2 sensory inputs, internal state (and goal state) to minimize prediction errors.
*Note* somehow the vision is reversed, unable to find the cause..
  -perceptual minimization
  -action minimization
  --goal minimization

Input:
  -Starting values
  -limitations

Output: -- .gif containing |all states| process of reaching equilibrium
'''
class ArmAI():
    """Active-inference agent for a 1-D arm.

    Holds a mental/belief location (`mental_x`), sensory states (visual image and
    location, proprioceptive location) and an optional goal (`goal_x`, `goal_vis`).
    `active_inference` updates these from prediction errors; visual errors are
    back-propagated through `network.decoder`.
    """
    def __init__(self,max_itt=100, network = None, dt = 0.1):
        """Set the default parameters and prepare the decoder network.

        Args:
            max_itt: number of iterations run by active_inference().
            network: decoder network. When None, AI_helper.check_network() loads
                the saved network from <cwd>/networks.
            dt: step size used when integrating the state updates.
        """
        
      #INITIALIZE CLASS VARIABLES
        self.dt = dt
        self.max_itt = max_itt
        self.network = network
        self.AI_helper = AI_helper(self)
        self.finished = False
        
        #Sigma's
        # NOTE(review): the sigma_* values, adapt_sigma_vis and sigma_v_gamma are not
        # read anywhere in this class. visual_error() derives its precision from the
        # error variance, and prop_error() uses a fixed 1/0.05.
        self.sigma_vis= 0.1
        self.sigma_prop = 0.95
        self.sigma_dyn = 0.95
        self.sigma_dyn_prop = 1
        
        #adap sigma
        self.adapt_sigma_vis = True 
        self.sigma_v_gamma = 0.1

        #Gain's
        # mu_gain / a_gain scale mu_dot / a_dot in active_inference().
        # NOTE(review): action_gain, perception_gain, goal_gain and visual_gain are not read in this class.
        self.mu_gain = 0.1
        self.a_gain = 0.2

        self.action_gain = 0.1
        self.perception_gain = 0.1
        self.goal_gain = 0.05

        self.visual_gain = 10
        
        #Error data (plotting purposes)
        # mu_dot_data (and a_dot_data when action=True) are appended in active_inference()
        # and read by AI_helper.save_gif(); the error_* lists are not filled in this class.
        self.error_vis = [] 
        self.error_prop = []
        self.error_mental = []
        self.error_mental_prop = []
        self.mu_dot_data = []
        self.a_dot_data = []

        #Visual (perception)
        self.visual_img = torch.FloatTensor() #visual_state
        self.visual_x = 0 #visual_location (plotting, and the base of the visual-arm update in active_inference)
        self.pred_err_vis = torch.FloatTensor()
        self.pred_err_vis_prior = torch.FloatTensor()

        #Proprioception (perception)
        self.proprioception_loc = 0

        #Mental/belief
        self.mental_img = torch.FloatTensor() #mental state (overwritten by visual_error with the decoded image)
        self.mental_x = 0 #mental location; must be a tensor (see set_mental) before active_inference calls .view on it

        #Attractor/goal
        # NOTE(review): set_goal() stores the goal image in self.goal_vis; goal_img is not updated.
        self.goal_img = torch.FloatTensor()
        self.goal_x = 0 #attractor location
      
      #CHECK FOR NETWORK
        #initialize network
        self.network = self.AI_helper.check_network() 
        #set network to evaluation mode
        self.network.eval()
        
    #SET CLASS VARIABLES
    def set_mental(self,belief_loc): #Mental/Belief position
        """Set the mental/belief location (a tensor, e.g. torch.FloatTensor([x]))."""
        self.mental_x = belief_loc
        
    def set_goal(self, goal_loc, goal_vis):#Goal position
        """Set the goal location (goal_x) and goal image (goal_vis) used when active_inference(goal=True)."""
        self.goal_vis = goal_vis
        self.goal_x = goal_loc 

    def set_visual(self, visual_x, visual_img):#Sensory visual
        """Set the sensory visual image and its location (used for plotting and as the base of the visual update)."""
        self.visual_img = visual_img

        self.visual_x = visual_x #(plotting/visualizing purposes)

    def set_proprioception(self, prop_loc): #Sensory proprioception
        """Set the sensory proprioceptive location."""
        self.proprioception_loc = prop_loc
    
    #Visual error (error based on vision)
    def visual_error(self, input_image):
        """Return the gradient of the precision-weighted visual error with respect to the mental location.

        Decodes self.mental_x into self.mental_img (side effect). The prediction
        error (input_image - mental_img) is weighted by the inverse of its variance.
        0.1 times that weighted error is back-propagated through the decoder. The
        returned gradient has the shape of self.mental_x.view(-1,1,1,1).

        NOTE(review): if the error has zero variance, the precision 1/var is
        infinite; the clip below is commented out.
        """
        #Generate mental/belief image
        input = Variable(self.mental_x.view(-1,1,1,1), requires_grad=True)           
        self.mental_img = self.network.decoder(input)
    
        err_vis = (input_image - self.mental_img.detach()) #prediction error
        err_vis_var = torch.var(err_vis) #Variance, Note that precision is the inversed variance
        #Error visual
        #err_vis_var = np.clip(err_vis_var,0.0014,1.5)
        error_vis = 1/err_vis_var * (input_image - self.mental_img) 
 
        #Backward pass
        # Start from a zero gradient so .grad holds only this backward pass
        input.grad = torch.zeros(input.size())
        # Vector-Jacobian product: input.grad = J^T (0.1 * error_vis), where J = d decoder / d input
        self.mental_img.backward(0.1 * error_vis, retain_graph=True)

        return input.grad

    #Locational/proprioception error (error based on horizontal locations)
    def prop_error(self,input_prop):
        """Return the precision-weighted location error (1/0.05) * (input_prop - mental_x).

        NOTE(review): the precision is hard-coded (1/0.05); self.sigma_prop is not used.
        """
        error_prop = (1/0.05) * (input_prop - self.mental_x) #Precision * prediciton error
        return error_prop   

    '''
    Info: Active inference with the capability of moving towards a goal
    Input: based on initialization
      -perception (optional):True/False for using perception to minimize variational free energy
      -action (optional): True/False for using action to minimize variational free energy
      -sense_vis (optional): True/False for using sensory vision
      -sense_prop (optional): True/False for using sensory proprioception 
      -goal (optional): True/False for moving towards a goal state
      -name (optional): Name of the resulting .gif
    Output: -- a name.gif containing the progress --
    '''
    def active_inference(self, perception = True, action = True, sense_vis = True, sense_prop = True, goal = False, name = 'armAI'):
      """Run max_itt update steps and save the whole process as `name`.gif.

      Each iteration:
      - save the current state image (save_state_v4);
      - compute the perception terms (mu_vis, mu_prop) and the action terms
        (mu_action_vis, mu_action_prop), plus the goal terms when both action
        and goal are set;
      - scale the velocities by a_gain / mu_gain and record them;
      - move the visual arm, the proprioceptive location and the mental
        location by dt * velocity.
      """
      
      for i in range(self.max_itt):
          #Reset variables and save current state
          save_state_v4(self, i, sense_vis, sense_prop, goal)
          mu_vis = 0
          mu_prop = 0
          mu_action_vis = 0
          mu_action_prop = 0
          mu_goal_vis = 0 #preferred state
          mu_goal_prop = 0 #preferred state

          #Perception part of minimizing surprise
          if perception:
            if sense_vis:
              # NOTE(review): negated, unlike mu_prop below, although both errors are (sensory - predicted);
              # see the "vision is reversed" note in the description above. Confirm before changing.
              mu_vis =  -1 * self.visual_error(self.visual_img)
            if sense_prop:
              # add_noise with standard deviation 0: no noise is actually added
              self.proprioception_loc_noise = self.AI_helper.add_noise(self.proprioception_loc, 0)
              mu_prop = self.prop_error(self.proprioception_loc_noise)
          #Action part of minimizing surprise
          if action:
            if goal: #This would be an inferred action that would reduce future free energy (expected free energy [https://arxiv.org/pdf/2004.08128.pdf])
              #goal is not a sensory state, thus needs to be generated trough the neural network to compare with mental state
              # Goal terms are only computed when action is True as well
              mu_goal_vis = -1 * self.visual_error(self.goal_vis)
              mu_goal_prop =  self.prop_error(self.goal_x)

            #Variational free energy
            if sense_vis:
                #error sensory to mental (-1 * error mental to sensory)
                mu_action_vis =  self.visual_error(self.visual_img)
                 
            if sense_prop:
                self.proprioception_loc_noise = self.AI_helper.add_noise(self.proprioception_loc, 0)
                #error proprioception to mental (-1 * error mental to sensory)
                mu_action_prop = (-1) *self.prop_error(self.proprioception_loc_noise) 

          #Sum the prediction errors
          a_dot =    mu_action_vis +  mu_action_prop #action 'velocity'
          if goal:
            mu_dot =    mu_goal_vis + mu_goal_prop #mental 'velocity': goal terms only (mu_vis/mu_prop are not used; goal terms stay 0 if action=False)
          else:
            mu_dot =    mu_vis+  mu_prop 

          #print('mu_vis: {}\n mu_prop: {}\n mu_action_vis: {}\n mu_action_prop: {}\n mu_goal_vis: {}\nmu_goal_prop: {}\n\n mu_dot: {}\na_dot: {}\n\n'.format(mu_vis,mu_prop,mu_action_vis,mu_action_prop,mu_goal_vis,mu_goal_prop,mu_dot,a_dot))

          #save data (plotting purposes)
          a_dot = a_dot * self.a_gain
          mu_dot = mu_dot * self.mu_gain
          if action:
            self.a_dot_data.append(a_dot)
          self.mu_dot_data.append(torch.FloatTensor([mu_dot]))

          #Update states
          # Euler step: state + dt * velocity
          
          new_vis_loc = torch.add(self.visual_x, a_dot, alpha=self.dt)
          self.AI_helper.update_visual_state(new_vis_loc) #Update visual arm (re-renders the visual image at the new location)
          self.proprioception_loc = torch.add(self.proprioception_loc,  a_dot, alpha=self.dt) #Update proprioceptive location
          self.mental_x = torch.add(self.mental_x,  mu_dot, alpha=self.dt) #Update mental location
          plt.close('all') # close the figures created during this iteration


          
      #save all states to a gif
      self.AI_helper.save_gif(name, action)
    
              
            


The cell below contains the test setups: a base case, a perception test, an action test and a goal test.

In [None]:
# Disabled test setups.
# Everything between the triple quotes below is a bare string literal, so none
# of these cases execute. To run one, move it outside the string.
# Each case follows the same pattern: build the initial visual (and optionally
# goal) image with create_tensor_v2, choose mental/proprioceptive locations,
# construct a fresh ArmAI, load the states, then call active_inference with the
# perception/action/sensing/goal switches under test.
# NOTE(review): the cases are not fully independent if re-enabled.
#   - 'Proprioception only' passes goal_img/goal_loc to set_goal although its
#     own goal line is commented out, so it reuses whatever goal an earlier
#     case left behind.
#   - 'Vision only' assigns proprioception_loc but never calls
#     set_proprioception, yet it runs with action=True. Confirm that ArmAI
#     provides a usable default proprioceptive state.
'''
## Base Case
# -All variables are set to 0. Too much movement indicating variable adjustments need to be made
visual_img, visual_loc = create_tensor_v2('init vision', location = 0, noise = False)  #Create visual state
mental_loc = 0 # Does not require an image
proprioception_loc = 0
goal_img,goal_loc = create_tensor_v2('init goal', location = 0, noise = False) 

armAI = ArmAI(max_itt = 50,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))
armAI.set_goal(torch.FloatTensor([goal_loc]),goal_img/255)

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = True, goal = True, name = 'Base_case')


## Perception test
#   -Check the model while only based on perception can perform as expected
#     -Both visual and proprioception states are initialized at the same location. Mental state at a random location with the expectation to correct itself

visual_img, visual_loc = create_tensor_v2('init vision', location = 0.7, noise = False)  #Create visual state
mental_loc = -0.7 # Does not require an image
proprioception_loc = 0.7
goal_img,goal_loc = create_tensor_v2('init goal', location = 0, noise = False) 

armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))
armAI.set_goal(torch.FloatTensor([goal_loc]),goal_img/255)

armAI.active_inference(perception = True, action = False, sense_vis = True, sense_prop = True, goal = False, name = 'Perception_test')


## Action test
#   -Check the model while only based on action can perform as expected
#     -Both visual and proprioception states are initilized at the same location. Mental state at a random location with the expectation of the sensory states to correct themselved

visual_img, visual_loc = create_tensor_v2('init vision', location = 0.7, noise = False)  #Create visual state
mental_loc = -0.7 # Does not require an image
proprioception_loc = 0.7
goal_img,goal_loc = create_tensor_v2('init goal', location = 0, noise = False) 

armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))

armAI.active_inference(perception = False, action = True, sense_vis = True, sense_prop = True, goal = False, name = 'Action_test')


## Goal test
#   -Check if the model is capable of moving towards a goal state
#     -Visual, proprioception and mental states are initialized at the same location. A goal state is set with the expectation that all finish at the goal state, with the mental state leading the way.

visual_img, visual_loc = create_tensor_v2('init vision', location = 0.7, noise = False)  #Create visual state
mental_loc = 0.7 # Does not require an image
proprioception_loc = 0.7
goal_img,goal_loc = create_tensor_v2('init goal', location = -0.4, noise = False) 


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))
armAI.set_goal(torch.FloatTensor([goal_loc]),goal_img/255)

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = True, goal = True, name = 'Goal_test')



#--- Special cases --

## Rubbed hand illusion
#   -The proprioception and mental state are initialized at the same location. The visual state is the rubber hand used in the RHI expriments, expecting that some mental discplacement will come fort

visual_img, visual_loc = create_tensor_v2('init vision', location = 0.1, noise = False)  #Create visual state
mental_loc = 0.5 # Does not require an image
proprioception_loc = 0.5


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))

armAI.active_inference(perception = True, action = False, sense_vis = True, sense_prop = True, goal = False, name = 'Rubber_hand_illusion')

## Full random placement
#   -The mental state and the proprioception state are expected to move towards eachother. 
#     However, in the case that the mental state and visual state find overlapment a visual error can be produced and taken into account.

visual_img, visual_loc = create_tensor_v2('init vision', location = 0, noise = False)  #Create visual state
mental_loc = -0.8 # Does not require an image
proprioception_loc = 0.8


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = True, goal = False, name = 'Full_random_visMid')

#----
visual_img, visual_loc = create_tensor_v2('init vision', location = -0.5, noise = False)  #Create visual state
mental_loc = 0.3 # Does not require an image
proprioception_loc = 0.7


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = True, goal = False, name = 'Full_random_mentalMid')


## Proprioception only
#   - Active inference performed without Vision

visual_img, visual_loc = create_tensor_v2('init vision', location = -0.7, noise = False)  #Create visual state
mental_loc = -0.7 # Does not require an image
proprioception_loc = 0.7
#goal_img,goal_loc = create_tensor_v2('init goal', location = 0.4, noise = False) 


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))
armAI.set_goal(torch.FloatTensor([goal_loc]),goal_img/255)

armAI.active_inference(perception = True, action = True, sense_vis = False, sense_prop = True, goal = False, name = 'Prop_only')


## Vision only
#   - Active inference performed without proprioception (Losing Touch: A man without his body)
#       -No overlap (with and wihout goal)
#       -Overlap (with and without goal)

visual_img, visual_loc = create_tensor_v2('init vision', location = 0.2, noise = False)  #Create visual state
mental_loc = -0.7 # Does not require an image
proprioception_loc = 0.7


armAI = ArmAI(max_itt = 80,dt = 0.05)

armAI.set_visual(torch.FloatTensor([visual_loc]),visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = False, goal = False, name = 'VisOnly_NoOverlap')

#-----
'''
## Vision only (overlap)
#   - Active inference without proprioceptive sensing (sense_prop = False).
#     The visual arm starts at 0.6, close to the mental state at 0.8.
visual_img, visual_loc = create_tensor_v2('init vision', location = 0.6, noise = False)  # Create visual state
mental_loc = 0.8  # Does not require an image
# Bug fix: this line used to be
#     armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))
# That call ran before this cell created `armAI` or assigned `proprioception_loc`.
# Every earlier setup sits inside a disabled string, so on a fresh run it raised
# NameError; otherwise it acted on a stale object that is replaced just below.
# The intended statement is the proprioceptive initialisation used by every
# other setup. 0.7 matches the preceding 'Vision only' (no overlap) setup.
proprioception_loc = 0.7

armAI = ArmAI(max_itt = 100, dt = 0.05)

# Load the initial states into the agent. The visual image is divided by 255
# before being handed over, as in the other setups.
armAI.set_visual(torch.FloatTensor([visual_loc]), visual_img/255)
armAI.set_mental(torch.FloatTensor([mental_loc]))
armAI.set_proprioception(torch.FloatTensor([proprioception_loc]))

armAI.active_inference(perception = True, action = True, sense_vis = True, sense_prop = False, goal = False, name = 'VisOnly_Overlap')

