# 1. INIT - Imports

In [1]:
import os
import math as mt
import numpy as np
import matplotlib.pyplot as plt
import random
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pympler import asizeof
from visdom import Visdom
from copy import deepcopy
import time
import networkx as nx
#this is set for the printing of Q-matrices via console
torch.set_printoptions(precision=3, sci_mode=False, linewidth=100)

from tqdm.auto import tqdm, trange

# Pin the notebook to GPU index 3 only when the caller has not already chosen
# devices (shell export, job scheduler, ...). The variable has to be in place
# before CUDA is first initialised, which the availability check below may do.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# suppress scientific notation in NumPy printouts
np.set_printoptions(suppress=True)

In [2]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU, and every CUDA device when CUDA is available),
    Python's ``random`` module and NumPy's global generator, and writes
    ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # reaches processes spawned after this call, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return

    # GPU-only part: seed the CUDA generators and ask cuDNN for deterministic
    # kernels instead of auto-tuned ones.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# 2. INIT - Parameters

In [3]:
# --- plotting colours ---
init_color_hex: str = '#2f78b6'  # hex colour of the start cell (top left)
goal_color_hex: str = '#f0b226'  # hex colour of the goal cell (bottom right)
wall_color: list[float] = [0.7] * 3  # RGB colour of walls

# --- gridworld layout ---
n_cell: int = 4  # number of original cells per side
grid_dim = 2 * n_cell - 1  # side length of the square gridworld, always of the form 2n-1
n_actions: int = 4  # number of actions available in each state
lava: bool = False  # True: wall states are accessible "lava" states; False: plain walls (always False now)
starting_state: int = 0
goal_state = grid_dim * grid_dim - 1

# --- maze generation ---
# Whether to generate wall states. wall_state_dict has the form
# {0: [], 1: [1], 2: [2], 3: [3], ...} (presumably maze index -> wall states; TODO confirm).
maze_gen: bool = False
n_mazes: int = 3800

# --- rewards in the gridworld ---
# NOTE(review): an older note listed (-0.1, 100, -5); the values below are the ones in effect.
step_reward: float = -0.1  # for taking a step
goal_reward: float = 2.0  # for reaching the goal
wall_reward: float = -1.0  # for bumping into a wall

# --- CA-TS Net settings ---
input_neurons: int = 2  # used when initialising the network
output_neurons = n_actions  # one Q(s, a) output per action
using_mazearray_as_concept: bool = False
using_mazearray_as_concept_str = "_fixedconcept" if using_mazearray_as_concept else ""
# concept vector size of the CA-TS DQN: 64 when learned, grid_dim^2 when the maze array is the concept
concept_size: int = grid_dim * grid_dim if using_mazearray_as_concept is not False else 64
using_concept_eps: bool = True
concept_eps: float = 1.0
concept_eps_str = f"_cpteps{concept_eps}" if using_concept_eps else ""
using_res: bool = True
res_str = "_res" if using_res else ""
hidden_dims: list = [768 for _ in range(18)]
dim_str = "-".join(map(str, hidden_dims))
q_s2a: bool = False  # True: Q(s) -> a, False: Q(s, a)
q_str = "_sa2q" if not q_s2a else ""
n_agents: int = 5  # number of agents used in the communication games
n_unique: int = 20  # unique mazes learned by each agent
n_group: int = 8

# --- CA-TS Net training settings ---
batch_size: int = 512  # 0 means: use all data in the buffer as one batch
epsilon: float = 0.1  # epsilon-greedy action policy
lr: float = 0.0001  # learning rate
gamma_bellman: float = 0.9999  # discount factor in the Bellman equation
target_replace_steps: int = 0  # refresh target_net from eval_net every N iterations; 0: use eval_net directly as target_net
memory_capacity: int = 0  # number of stored transitions; 0: pre-store all transitions (epoch-style training)
cap_str = "_prestore" if memory_capacity == 0 else ""
n_episode: int = 100_000

# --- AE Net settings ---
n_recons_test: int = 20
n_align_test: int = 20
translater_dims: list = [1024 for _ in range(5)]
translater_dim_str = "-".join(map(str, translater_dims))

# --- AE Net training settings ---
translater_batch_size: int = 128
translater_weight_decay: float = 0.00001
translater_align_lr: float = 0.00001
using_translater_eps: bool = False
n_epoch: int = 10_000
print_frequency: int = 20
train_translater: bool = True
load_translater: bool = False
loss_type: str = 'L1Loss'  # options: 'MSELoss', 'L1Loss', 'SmoothL1Loss'


# 3. FUNCTIONS - CA-TS Net class

In [4]:
class Net_sa2q(nn.Module):
    """Q(s, a) regressor with an optional concept-dependent gating pathway.

    Two parallel stacks are built from the module-level ``hidden_dims``:

    * the ``ts_*`` stack maps the ``input_neurons + n_actions`` input features
      through ``Linear -> LayerNorm (+ skip) -> ReLU`` blocks and ends in a
      linear head with a single output;
    * the ``cdp_*`` stack maps a learned per-maze concept embedding through
      ``Linear -> LayerNorm (+ skip) -> Sigmoid`` blocks; its i-th activation
      multiplies the i-th hidden activation of the ``ts_*`` stack.

    The sub-module names below determine the ``state_dict`` keys, so renaming
    them would make previously saved checkpoints unloadable.
    """

    def __init__(self):
        """Build both stacks and the concept embedding from module-level settings.

        Reads ``using_res``, ``input_neurons``, ``n_actions``, ``hidden_dims``,
        ``concept_size`` and ``n_mazes`` from the enclosing module.
        """
        super(Net_sa2q, self).__init__()
        self.using_res = using_res

        self.ts_fc_layers = nn.ModuleDict()
        self.ts_norm_layers = nn.ModuleDict()
        if self.using_res:
            self.ts_skip_layers = nn.ModuleDict()

        # dims 0, 1 are the xy coordinates; dims 2 to 5 encode the action 0..3 (right, up, left, down)
        prev_dim = input_neurons + n_actions
        for i, dim in enumerate(hidden_dims):
            self.ts_fc_layers[f'fc{i}'] = nn.Linear(prev_dim, dim, bias=True)
            # A projection on the skip path is only needed when the width changes.
            if self.using_res and prev_dim != dim:
                self.ts_skip_layers[f'skip{i}'] = nn.Linear(prev_dim, dim, bias=False)
            self.ts_norm_layers[f'norm{i}'] = nn.LayerNorm(dim)
            prev_dim = dim

        # Output head: one scalar Q-value per (state, action) input row.
        self.ts_fc_layers[f'fc{len(hidden_dims)}'] = nn.Linear(prev_dim, 1, bias=True)

        self.cdp_fc_layers = nn.ModuleDict()
        self.cdp_norm_layers = nn.ModuleDict()
        if self.using_res:
            self.cdp_skip_layers = nn.ModuleDict()

        prev_dim = concept_size
        for i, dim in enumerate(hidden_dims):
            self.cdp_fc_layers[f'fc{i}'] = nn.Linear(prev_dim, dim, bias=True)
            if self.using_res and prev_dim != dim:
                self.cdp_skip_layers[f'skip{i}'] = nn.Linear(prev_dim, dim, bias=False)
            self.cdp_norm_layers[f'norm{i}'] = nn.LayerNorm(dim)
            prev_dim = dim

        self.ts_afun = nn.ReLU()
        self.cdp_afun = nn.Sigmoid()

        # One learnable concept vector per maze.
        self.concept_embedding_layer = nn.Embedding(num_embeddings=n_mazes, embedding_dim=concept_size)

    def forward(self, x, concept_idx=None):
        """Compute the network output for a batch of (state, action) features.

        Args:
            x: Input tensor whose last dimension matches the first linear
                layer (``input_neurons + n_actions`` at construction time).
            concept_idx: Optional integer tensor of indices into the concept
                embedding. When ``None`` the gating pathway is skipped and the
                task stack runs on its own.

        Returns:
            Tensor with a last dimension of size 1 (output of the final
            linear layer).
        """
        # Every entry of ts_fc_layers except the last one is a hidden layer;
        # the last one is the output head. Derive the count from the task
        # stack itself rather than from the concept stack.
        n_hidden = len(self.ts_fc_layers) - 1

        if concept_idx is not None:
            concept = self.concept_embedding_layer(concept_idx)
            cdp_activations = []
            c = concept
            for i in range(len(self.cdp_fc_layers)):
                if self.using_res:
                    identity = c
                    out = self.cdp_fc_layers[f'fc{i}'](c)
                    out = self.cdp_norm_layers[f'norm{i}'](out)
                    if f'skip{i}' in self.cdp_skip_layers:
                        identity = self.cdp_skip_layers[f'skip{i}'](identity)
                    c = out + identity
                else:
                    c = self.cdp_fc_layers[f'fc{i}'](c)
                    c = self.cdp_norm_layers[f'norm{i}'](c)

                c = self.cdp_afun(c)
                cdp_activations.append(c)

        for i in range(n_hidden):
            if self.using_res:
                identity = x
                out = self.ts_fc_layers[f'fc{i}'](x)
                out = self.ts_norm_layers[f'norm{i}'](out)
                if f'skip{i}' in self.ts_skip_layers:
                    identity = self.ts_skip_layers[f'skip{i}'](identity)
                x = out + identity
            else:
                x = self.ts_fc_layers[f'fc{i}'](x)
                x = self.ts_norm_layers[f'norm{i}'](x)

            x = self.ts_afun(x)

            # Gate the task activation with the matching concept activation.
            if concept_idx is not None:
                x = torch.mul(x, cdp_activations[i])

        x = self.ts_fc_layers[f'fc{n_hidden}'](x)

        return x

# 4. FUNCTIONS - Recons datasets

In [5]:
class ReconsDataset(Dataset):
    """Map-style dataset that serves items of one indexable concept container."""

    def __init__(self, concept_data):
        """Store ``concept_data``; it only needs to support ``len()`` and indexing."""
        self.concept_data = concept_data

    def __len__(self):
        """Number of items in the wrapped container."""
        n_items = len(self.concept_data)
        return n_items

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.concept_data[idx]
        return item

# 5. FUNCTIONS - Alignment datasets

In [6]:
class AlignDataset(Dataset):
    """Map-style dataset that serves the same position from several containers.

    Item ``idx`` is a list holding ``concept_data_list[a][idx]`` for every
    agent slot ``a``, in order.
    """

    def __init__(self, concept_data_list):
        """Keep the per-agent containers and remember how many there are."""
        self.concept_data_list = concept_data_list
        self.n_agents = len(concept_data_list)

    def __len__(self):
        """Length of the first container (the others are not checked)."""
        first_container = self.concept_data_list[0]
        return len(first_container)

    def __getitem__(self, idx):
        """Return ``[container[idx] for each agent container]``."""
        return [self.concept_data_list[slot][idx] for slot in range(self.n_agents)]

# 6. FUNCTIONS - Processing exclude mazes

In [7]:
def get_recons_indices(n_mazes, n_agents, i_agent, n_unique, n_recons_test):
    """Split one agent's maze indices into reconstruction train/test sets.

    The candidate pool is the common block ``[0, n_mazes - n_agents * n_unique)``
    followed by the agent's own block of ``n_unique`` indices. ``n_recons_test``
    indices are drawn from the pool with ``random.sample``; the rest is the
    training set.

    Args:
        n_mazes: Total number of maze indices.
        n_agents: Number of agents sharing the index space.
        i_agent: 1-based agent number selecting the unique block.
        n_unique: Number of unique mazes per agent.
        n_recons_test: Number of indices to hold out for testing.

    Returns:
        ``(train_indices, test_indices)``: the training list keeps the pool
        order, the test list is in sampling order.

    Raises:
        ValueError: From ``random.sample`` when ``n_recons_test`` exceeds the
            pool size.
    """
    common_end = n_mazes - n_agents * n_unique
    unique_start = common_end + (i_agent - 1) * n_unique
    unique_end = unique_start + n_unique

    combined_list = list(range(0, common_end)) + list(range(unique_start, unique_end))
    test_indices = random.sample(combined_list, n_recons_test)

    # Set membership keeps the filtering linear in the pool size.
    held_out = set(test_indices)
    train_indices = [item for item in combined_list if item not in held_out]

    return train_indices, test_indices

In [8]:
def get_align_indices(n_mazes, n_agents, n_unique, n_align_test):
    """Split the common maze indices into alignment train/test sets.

    The pool is ``[0, n_mazes - n_agents * n_unique)``. ``n_align_test``
    indices are drawn with ``random.sample``; the rest is the training set.

    Args:
        n_mazes: Total number of maze indices.
        n_agents: Number of agents sharing the index space.
        n_unique: Number of unique mazes per agent.
        n_align_test: Number of indices to hold out for testing.

    Returns:
        ``(train_indices, test_indices)``: the training list is ascending,
        the test list is in sampling order.

    Raises:
        ValueError: From ``random.sample`` when ``n_align_test`` exceeds the
            pool size.
    """
    common_end = n_mazes - n_agents * n_unique
    indices_list = list(range(0, common_end))

    test_indices = random.sample(indices_list, n_align_test)

    # Set membership keeps the filtering linear in the pool size.
    held_out = set(test_indices)
    train_indices = [item for item in indices_list if item not in held_out]

    return train_indices, test_indices

In [9]:
def get_unique_indices(n_mazes, n_agents, i_agent, n_unique):
    """Return the block of ``n_unique`` maze indices owned by one agent.

    The unique blocks sit after the common block
    ``[0, n_mazes - n_agents * n_unique)``, one per agent, ordered by the
    1-based agent number ``i_agent``.
    """
    shared_count = n_mazes - n_agents * n_unique
    block_offset = (i_agent - 1) * n_unique
    first = shared_count + block_offset
    return [*range(first, first + n_unique)]

In [10]:
def get_common_indices(n_mazes, n_agents, n_unique):
    """Return the maze indices shared by all agents.

    These are ``0 .. n_mazes - n_agents * n_unique - 1``, i.e. everything
    before the per-agent unique blocks.
    """
    return [*range(n_mazes - n_agents * n_unique)]

# 7. FUNCTIONS - Translater

In [11]:
class Translater(nn.Module):
    """Residual MLP that maps one concept vector to another of the same size.

    The hidden widths come from the module-level ``translater_dims``; input
    and output width are both ``concept_size``. Each hidden block is
    ``Linear -> LayerNorm``, added to a skip path and passed through ReLU.
    """

    def __init__(self):
        """Build the layers from ``concept_size`` and ``translater_dims``."""
        super(Translater, self).__init__()

        self.trans_fc_layers = nn.ModuleDict()
        self.trans_norm_layers = nn.ModuleDict()
        self.trans_skip_layers = nn.ModuleDict()

        prev_dim = concept_size
        for i, dim in enumerate(translater_dims):
            self.trans_fc_layers[f'fc{i}'] = nn.Linear(prev_dim, dim, bias=True)
            # A projection on the skip path is only needed when the width changes.
            if prev_dim != dim:
                self.trans_skip_layers[f'skip{i}'] = nn.Linear(prev_dim, dim, bias=False)
            self.trans_norm_layers[f'norm{i}'] = nn.LayerNorm(dim)
            prev_dim = dim

        # Output layer: project back to the concept width.
        self.trans_fc_layers[f'fc{len(translater_dims)}'] = nn.Linear(prev_dim, concept_size, bias=True)
        self.trans_afun = nn.ReLU()

    def forward(self, x):
        """Translate a batch of concept vectors.

        Args:
            x: Tensor whose last dimension matches the first linear layer
                (``concept_size`` at construction time).

        Returns:
            Tensor with the same leading dimensions and a last dimension equal
            to the output layer's width.
        """
        # The last entry of trans_fc_layers is the output layer. Take the
        # count from the module itself so a later change of the global
        # `translater_dims` cannot select the wrong layer.
        n_hidden = len(self.trans_fc_layers) - 1

        for i in range(n_hidden):
            identity = x
            out = self.trans_fc_layers[f'fc{i}'](x)
            out = self.trans_norm_layers[f'norm{i}'](out)
            if f'skip{i}' in self.trans_skip_layers:
                identity = self.trans_skip_layers[f'skip{i}'](identity)
            x = out + identity
            x = self.trans_afun(x)

        x = self.trans_fc_layers[f'fc{n_hidden}'](x)

        return x

# 10. EXECUTION - Constructing datasets

In [12]:
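# set_seed(0)  # disabled: without a seed, the random.sample-based train/test splits below change on every run; uncomment for reproducible splits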

In [13]:
# Rebuild every agent network and restore its trained weights from disk.
agents = {}
for i in range(n_agents):
    i_agent = i + 1  # agents are numbered from 1 in checkpoint names and in the index helpers
    agents[i_agent] = Net_sa2q()
    ckpt_file: str = f"agent_ckpt/group{n_group}/ckpt{i_agent}of{n_agents}_unique{n_unique}_{grid_dim}x{grid_dim}_n{n_mazes}{q_str}{cap_str}_dim{dim_str}{res_str}_cptsz{concept_size}{concept_eps_str}{using_mazearray_as_concept_str}_lr{lr}_epsi{n_episode}_gamma{gamma_bellman}_bs{batch_size}_tr{target_replace_steps}.pt"
    # map_location="cpu": the agents stay on the CPU in this notebook, and
    # checkpoints saved from a GPU must still load on a CPU-only machine.
    # weights_only=True: restrict unpickling to tensors/containers, which is
    # all a state_dict holds (assumes the file stores a plain state_dict,
    # as load_state_dict below requires).
    agents[i_agent].load_state_dict(
        torch.load(ckpt_file, map_location="cpu", weights_only=True)
    )

  agents[i_agent].load_state_dict(torch.load(ckpt_file))


In [14]:
# --- Reconstruction data: one train/test split of concept vectors per agent ---
recons_trainsets, recons_testsets = {}, {}
recons_trainloaders, recons_testloaders = {}, {}

for i in range(n_agents):
    i_agent = i + 1
    recons_train_indices, recons_test_indices = get_recons_indices(
        n_mazes, n_agents, i_agent, n_unique, n_recons_test
    )

    # Rows of the agent's learned concept embedding, selected by maze index.
    recons_trainsets[i_agent] = ReconsDataset(
        agents[i_agent].concept_embedding_layer.weight.data[recons_train_indices]
    )
    recons_testsets[i_agent] = ReconsDataset(
        agents[i_agent].concept_embedding_layer.weight.data[recons_test_indices]
    )

    recons_trainloaders[i_agent] = DataLoader(
        recons_trainsets[i_agent], batch_size=translater_batch_size, shuffle=True
    )
    recons_testloaders[i_agent] = DataLoader(
        recons_testsets[i_agent], batch_size=translater_batch_size, shuffle=False
    )

# --- Alignment data: the same common mazes seen through every agent's embedding ---
align_train_indices, align_test_indices = get_align_indices(n_mazes, n_agents, n_unique, n_align_test)

# Position k of each list holds the concept vectors of agent k + 1.
align_train_concept = [
    agents[agent_id].concept_embedding_layer.weight.data[align_train_indices]
    for agent_id in range(1, n_agents + 1)
]
align_test_concept = [
    agents[agent_id].concept_embedding_layer.weight.data[align_test_indices]
    for agent_id in range(1, n_agents + 1)
]

align_trainset = AlignDataset(align_train_concept)
align_testset = AlignDataset(align_test_concept)

align_trainloader = DataLoader(align_trainset, batch_size=translater_batch_size, shuffle=True)
align_testloader = DataLoader(align_testset, batch_size=translater_batch_size, shuffle=False)


# 12. EXECUTION - Translation process

In [15]:
# Train a translater that maps agent 2's concept vectors (batch element 1)
# onto agent 1's concept vectors (batch element 0) for the common mazes.
translater_2to1 = Translater().to(device)

global_optim = torch.optim.Adam(translater_2to1.parameters(), lr = translater_align_lr, weight_decay = translater_weight_decay)
# Look the loss class up by name, e.g. 'L1Loss' -> nn.L1Loss().
recons_criterion = getattr(nn, loss_type)()

if train_translater:
    try:
        for epoch in range(n_epoch):
            train_l_sum = 0.
            train_batch_count = 0
            test_l_sum = 0.
            test_batch_count = 0
            for x in align_trainloader:
                # Uniform noise in [-concept_eps / 2, concept_eps / 2) is added to
                # both the target (agent 1) and the input (agent 2).
                # NOTE(review): the noise is applied unconditionally; the
                # `using_translater_eps` setting is not consulted here. Confirm
                # whether it was meant to gate these two lines.
                concept_1 = x[0].to(device)
                concept_1 = concept_1 + concept_eps * (torch.rand_like(concept_1) - 0.5)

                concept_2 = x[1].to(device)
                concept_2 = concept_2 + concept_eps * (torch.rand_like(concept_2) - 0.5)

                concept_1_recons = translater_2to1(concept_2)

                l = recons_criterion(concept_1_recons, concept_1)
                global_optim.zero_grad()
                l.backward()
                global_optim.step()

                train_l_sum += l.item()
                train_batch_count += 1

            if (epoch + 1) % print_frequency == 0:
                # Switch modes once per evaluation and always switch back, even
                # if the evaluation is interrupted part-way through.
                translater_2to1.eval()
                try:
                    with torch.no_grad():
                        for x in align_testloader:
                            # No noise is added at test time.
                            concept_1 = x[0].to(device)
                            concept_2 = x[1].to(device)
                            concept_1_recons = translater_2to1(concept_2)
                            l = recons_criterion(concept_1_recons, concept_1)
                            test_l_sum += l.item()
                            test_batch_count += 1
                finally:
                    translater_2to1.train()

                # Report NaN instead of dividing by zero when a loader is empty.
                train_l_avg = train_l_sum / train_batch_count if train_batch_count else float('nan')
                test_l_avg = test_l_sum / test_batch_count if test_batch_count else float('nan')

                print(
                    f'epoch {epoch + 1},'
                    f'align training loss {train_l_avg:.4f},'
                    f'align testing loss {test_l_avg:.4f}'
                )

    except KeyboardInterrupt:
        # Manual interruption is the intended way to stop early; the weights
        # reached so far are kept.
        print('training interrupted by user; keeping current translater weights')


epoch 20,align training loss 0.9694,align testing loss 0.9920
epoch 40,align training loss 0.9267,align testing loss 0.9965
epoch 60,align training loss 0.8802,align testing loss 1.0079
epoch 80,align training loss 0.8341,align testing loss 1.0217
epoch 100,align training loss 0.7947,align testing loss 1.0332
epoch 120,align training loss 0.7602,align testing loss 1.0373
epoch 140,align training loss 0.7287,align testing loss 1.0483
epoch 160,align training loss 0.7010,align testing loss 1.0573
epoch 180,align training loss 0.6761,align testing loss 1.0666
epoch 200,align training loss 0.6542,align testing loss 1.0734
epoch 220,align training loss 0.6335,align testing loss 1.0808
epoch 240,align training loss 0.6158,align testing loss 1.0879
epoch 260,align training loss 0.5976,align testing loss 1.0951
epoch 280,align training loss 0.5811,align testing loss 1.0976
epoch 300,align training loss 0.5675,align testing loss 1.1042
epoch 320,align training loss 0.5528,align testing loss 1.1