# 1. INIT - Imports

In [1]:
import os
import math as mt
import numpy as np
import matplotlib.pyplot as plt
import random
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pympler import asizeof
from visdom import Visdom
from copy import deepcopy
import time
import networkx as nx
# Console formatting for tensors (used when printing Q-matrices):
# 3 decimals, no scientific notation, lines up to 100 characters wide.
torch.set_printoptions(precision=3, sci_mode=False, linewidth=100)

from tqdm.auto import tqdm, trange

# Expose only physical GPU 3 to this process. This is assigned before the first
# CUDA query below, which is what makes the restriction take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Global compute device used by the rest of the notebook; falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Suppress scientific notation when printing NumPy arrays.
np.set_printoptions(suppress=True)

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU, plus the current and all CUDA devices when CUDA is
    available), Python's ``random`` and NumPy, and exports ``PYTHONHASHSEED``.
    With CUDA available, cuDNN is additionally put into deterministic mode
    and its benchmark autotuner is disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    torch.manual_seed(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    for host_seeder in (random.seed, np.random.seed):
        host_seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

# 2. INIT - Parameters

In [3]:
# --- Plot colors ---
init_color_hex: str = "#2f78b6" # hex color of the start cell (top-left)
goal_color_hex: str = '#f0b226' # hex color of the goal cell (bottom-right)
wall_color: list[float] = [0.7, 0.7, 0.7] # RGB color of walls

# --- Gridworld geometry ---
n_cell: int = 4 # number of original cells per side
grid_dim = n_cell * 2 - 1 # side length of the square gridworld, always of the form 2n-1
n_actions: int = 4 # number of actions available in each state
lava: bool = False # use lava states (i.e. accessible wall states) when True, plain wall states when False (currently always False)
starting_state: int = 0
goal_state = grid_dim ** 2 - 1 # index of the last cell of the grid

# --- Maze generation ---
# NOTE(review): the original comment on maze_gen was cut off after "consists of";
# it described wall_state_dict as {0: [], 1: [1], 2: [2], 3: [3], ... }.
maze_gen: bool = False # whether to generate wall_states
n_mazes: int = 3000 # total number of mazes; also the number of rows of each agent's concept embedding table

# --- Rewards in the gridworld ---
# NOTE(review): the original comment listed "(-0.1, 100, -5)", which does not match the
# values below; presumably an earlier setting.
step_reward: float = -0.1 # for taking a step
goal_reward: float = 2. # for reaching the goal
wall_reward: float = -1. # for bumping into a wall

# --- CA-TS Net settings (several of these are embedded in the checkpoint file name) ---
input_neurons: int = 2 # state inputs of the network (x, y coordinates)
output_neurons = n_actions # one output neuron per action when modelling Q(s) -> a
concept_size: int = 64 # size of the concept vector of the CA-TS DQN
using_concept_eps: bool = True
concept_eps: float = 0.1
concept_eps_str = f"_cpteps{concept_eps}" if using_concept_eps else "" # file-name fragment
using_res: bool = True # enable residual connections in the networks
res_str = "_res" if using_res else "" # file-name fragment
hidden_dims: list = [768] * 16 # widths of the hidden layers
dim_str = "-".join(str(d) for d in hidden_dims) # file-name fragment
q_s2a: bool = False # True: Q(s) -> a; False: Q(s, a)
q_str = "" if q_s2a else "_sa2q" # file-name fragment
n_agents: int = 5 # number of agents used in the communication games
n_unique: int = 20 # number of unique mazes learned by each agent
n_group: int = 7 # group id, used in the checkpoint directory names

# --- CA-TS Net training settings ---
batch_size: int = 512 # 0 means: use all data in the buffer as one batch
epsilon: float = 0.1 # epsilon of the epsilon-greedy action policy
lr: float = 1e-4 # learning rate
gamma_bellman: float = 0.9999 # discount factor in the Bellman equation
target_replace_steps: int = 0 # refresh target_net from eval_net every this many iterations; 0 means: use eval_net directly as target_net
memory_capacity: int = 0 # number of transitions stored; 0 means: pre-store all transitions in memory (training then proceeds epoch-wise)
cap_str = "" if memory_capacity != 0 else "_prestore" # file-name fragment
n_episode: int = 100000

# --- AE Net settings ---
n_recons_test: int = 20 # number of mazes held out per agent for the reconstruction test split
n_align_test: int = 20 # number of common mazes held out for the alignment test split
ae_dims: list = [128, 192, 256, 192, 128] # autoencoder layer widths; the middle entry is the bottleneck
ae_dim_str = "-".join(str(d) for d in ae_dims) # file-name fragment

# --- AE Net training settings ---
ae_batch_size: int = 128
ae_weight_decay: float = 1e-5
ae_align_lr: float = 1e-4
align_weight: float = 1.0 # weight of the latent alignment term
using_ae_eps: bool = False # add noise scaled by concept_eps to the input of Autoencoder.encode
n_epoch: int = 10000
print_frequency: int = 100 # print training statistics every this many epochs
train_ae: bool = True
load_ae: bool = False
loss_type: str = 'L1Loss' # one of 'MSELoss', 'L1Loss', 'SmoothL1Loss' (looked up by name in torch.nn)


# 3. FUNCTIONS - CA-TS Net class

In [4]:
class Net_sa2q(nn.Module):
    """Concept-gated network mapping an encoded (state, action) pair to one Q-value.

    Two parallel stacks with ``len(hidden_dims)`` layers each are built from the
    notebook-level configuration (``input_neurons``, ``n_actions``,
    ``hidden_dims``, ``concept_size``, ``n_mazes``, ``using_res``):

    * ``ts_*``:  Linear -> LayerNorm (+ residual when ``using_res``) -> ReLU,
      followed by a final ``Linear(prev_dim, 1)`` producing the Q-value.
    * ``cdp_*``: Linear -> LayerNorm (+ residual when ``using_res``) -> Sigmoid,
      fed by a learned per-maze concept embedding. The output of layer ``i``
      multiplies the ``ts`` activation of the same layer.

    Sub-module names (``fc{i}``, ``norm{i}``, ``skip{i}``) and the order in which
    they are created are unchanged, so existing checkpoints and seeded
    initialisation behave as before.
    """

    def __init__(self):
        super(Net_sa2q, self).__init__()
        self.using_res = using_res

        self.ts_fc_layers = nn.ModuleDict()
        self.ts_norm_layers = nn.ModuleDict()
        if self.using_res:
            self.ts_skip_layers = nn.ModuleDict()

        # Input layout: dims 0-1 are the xy coordinates, the remaining n_actions
        # dims encode the action 0..3 (right, up, left, down).
        prev_dim = input_neurons + n_actions
        for i, dim in enumerate(hidden_dims):
            self.ts_fc_layers[f'fc{i}'] = nn.Linear(prev_dim, dim, bias=True)
            # A projection on the skip path is only needed when the width changes.
            if self.using_res and prev_dim != dim:
                self.ts_skip_layers[f'skip{i}'] = nn.Linear(prev_dim, dim, bias=False)
            self.ts_norm_layers[f'norm{i}'] = nn.LayerNorm(dim)
            prev_dim = dim

        # Output layer: a single Q(s, a) value.
        self.ts_fc_layers[f'fc{len(hidden_dims)}'] = nn.Linear(prev_dim, 1, bias=True)

        self.cdp_fc_layers = nn.ModuleDict()
        self.cdp_norm_layers = nn.ModuleDict()
        if self.using_res:
            self.cdp_skip_layers = nn.ModuleDict()

        prev_dim = concept_size
        for i, dim in enumerate(hidden_dims):
            self.cdp_fc_layers[f'fc{i}'] = nn.Linear(prev_dim, dim, bias=True)
            if self.using_res and prev_dim != dim:
                self.cdp_skip_layers[f'skip{i}'] = nn.Linear(prev_dim, dim, bias=False)
            self.cdp_norm_layers[f'norm{i}'] = nn.LayerNorm(dim)
            prev_dim = dim

        self.ts_afun = nn.ReLU()
        self.cdp_afun = nn.Sigmoid()

        # One learnable concept vector per maze, looked up by maze index.
        self.concept_embedding_layer = nn.Embedding(num_embeddings=n_mazes, embedding_dim=concept_size)

    def _block(self, h, fc_layers, norm_layers, skip_layers, i):
        """Apply Linear -> LayerNorm of layer ``i`` and add the residual path when enabled.

        ``skip_layers`` may be ``None`` when residual connections are disabled;
        it is not accessed in that case.
        """
        out = norm_layers[f'norm{i}'](fc_layers[f'fc{i}'](h))
        if not self.using_res:
            return out
        identity = h
        if f'skip{i}' in skip_layers:
            identity = skip_layers[f'skip{i}'](identity)
        return out + identity

    def forward(self, x, concept_idx=None):
        """Compute the Q-value for every row of ``x``.

        Args:
            x: input whose last dimension is ``input_neurons + n_actions``.
            concept_idx: optional index tensor into the concept embedding. When
                given, the concept stack is evaluated and its per-layer sigmoid
                outputs gate the task-solving activations; when ``None`` the
                task-solving stack runs ungated.

        Returns:
            Tensor with last dimension 1 (output of the final linear layer).
        """
        gates = None
        if concept_idx is not None:
            c = self.concept_embedding_layer(concept_idx)
            cdp_skips = getattr(self, 'cdp_skip_layers', None)
            gates = []
            for i in range(len(self.cdp_fc_layers)):
                c = self._block(c, self.cdp_fc_layers, self.cdp_norm_layers, cdp_skips, i)
                c = self.cdp_afun(c)
                gates.append(c)

        # ts_fc_layers holds the hidden layers plus the output layer.
        n_hidden = len(self.ts_fc_layers) - 1
        ts_skips = getattr(self, 'ts_skip_layers', None)
        for i in range(n_hidden):
            x = self._block(x, self.ts_fc_layers, self.ts_norm_layers, ts_skips, i)
            x = self.ts_afun(x)

            if gates is not None:
                x = torch.mul(x, gates[i])

        # The output layer is resolved from the task-solving stack itself (it was
        # previously indexed via the size of the concept stack).
        x = self.ts_fc_layers[f'fc{n_hidden}'](x)

        return x

# 4. FUNCTIONS - Recons datasets

In [5]:
class ReconsDataset(Dataset):
    """Dataset over a single indexable collection of concept vectors.

    Item ``idx`` is simply ``concept_data[idx]``; the length is
    ``len(concept_data)``.
    """

    def __init__(self, concept_data):
        # Stored by reference; nothing is copied or converted.
        self.concept_data = concept_data

    def __len__(self):
        n_items = len(self.concept_data)
        return n_items

    def __getitem__(self, idx):
        item = self.concept_data[idx]
        return item

# 5. FUNCTIONS - Alignment datasets

In [6]:
class AlignDataset(Dataset):
    """Dataset yielding, for one index, the matching item of every agent.

    ``concept_data_list`` holds one indexable collection per agent. The
    dataset length is the length of the first collection, and item ``idx``
    is the list ``[concept_data_list[0][idx], concept_data_list[1][idx], ...]``.
    """

    def __init__(self, concept_data_list):
        self.concept_data_list = concept_data_list
        self.n_agents = len(concept_data_list)

    def __len__(self):
        first_agent_data = self.concept_data_list[0]
        return len(first_agent_data)

    def __getitem__(self, idx):
        # One entry per agent, in agent order.
        return [self.concept_data_list[agent][idx] for agent in range(self.n_agents)]

# 6. FUNCTIONS - Processing exclude mazes

In [7]:
def get_recons_indices(n_mazes, n_agents, i_agent, n_unique, n_recons_test):
    """Split the mazes known to one agent into reconstruction train/test indices.

    The last ``n_agents * n_unique`` maze indices are reserved as per-agent
    unique mazes (``n_unique`` consecutive indices per agent, agents numbered
    from 1); everything before that is common to all agents. The candidate
    pool is the common mazes plus the unique mazes of ``i_agent``.

    Args:
        n_mazes: total number of mazes.
        n_agents: number of agents.
        i_agent: 1-based id of the agent.
        n_unique: number of unique mazes per agent.
        n_recons_test: number of indices drawn for the test split.

    Returns:
        ``(train_indices, test_indices)``: the test indices in the order they
        were drawn by ``random.sample`` (global ``random`` state), and the
        remaining candidates in ascending pool order.

    Raises:
        ValueError: from ``random.sample`` when ``n_recons_test`` exceeds the
            size of the candidate pool.
    """
    common_end = n_mazes - n_agents * n_unique
    unique_start = common_end + (i_agent - 1) * n_unique
    unique_end = unique_start + n_unique

    combined_list = list(range(0, common_end)) + list(range(unique_start, unique_end))
    test_indices = random.sample(combined_list, n_recons_test)

    # Set membership keeps the filter linear instead of scanning the test list per candidate.
    held_out = set(test_indices)
    train_indices = [item for item in combined_list if item not in held_out]

    return train_indices, test_indices

In [8]:
def get_align_indices(n_mazes, n_agents, n_unique, n_align_test):
    """Split the mazes common to all agents into alignment train/test indices.

    Common mazes are the indices ``0 .. n_mazes - n_agents * n_unique - 1``.

    Args:
        n_mazes: total number of mazes.
        n_agents: number of agents.
        n_unique: number of unique mazes per agent.
        n_align_test: number of indices drawn for the test split.

    Returns:
        ``(train_indices, test_indices)``: the test indices in the order they
        were drawn by ``random.sample`` (global ``random`` state), and the
        remaining common indices in ascending order.

    Raises:
        ValueError: from ``random.sample`` when ``n_align_test`` exceeds the
            number of common mazes.
    """
    common_end = n_mazes - n_agents * n_unique
    indices_list = list(range(0, common_end))

    test_indices = random.sample(indices_list, n_align_test)

    # Set membership keeps the filter linear instead of scanning the test list per candidate.
    held_out = set(test_indices)
    train_indices = [item for item in indices_list if item not in held_out]

    return train_indices, test_indices

In [9]:
def get_unique_indices(n_mazes, n_agents, i_agent, n_unique):
    """Return the maze indices that belong only to agent ``i_agent`` (1-based).

    The unique mazes occupy the last ``n_agents * n_unique`` indices, split
    into consecutive runs of ``n_unique`` indices per agent.
    """
    offset_in_unique_region = (i_agent - 1) * n_unique
    first = (n_mazes - n_agents * n_unique) + offset_in_unique_region
    return [*range(first, first + n_unique)]

In [10]:
def get_common_indices(n_mazes, n_agents, n_unique):
    """Return the indices of the mazes shared by all agents.

    These are all indices below ``n_mazes - n_agents * n_unique``, i.e. the
    part of the index range not reserved for per-agent unique mazes.
    """
    n_common = n_mazes - n_agents * n_unique
    return [*range(n_common)]

# 7. FUNCTIONS - Autoencoder

In [11]:
class Autoencoder(nn.Module):
    """Autoencoder mapping concept vectors to a latent space and back.

    Built from the notebook-level ``ae_dims`` and ``concept_size``:

    * encoder: ``concept_size -> ae_dims[0] -> ... -> ae_dims[mid]`` where
      ``mid = len(ae_dims) // 2``; every layer except the bottleneck is
      followed by LayerNorm and ReLU.
    * decoder: ``ae_dims[mid] -> ae_dims[mid-1] -> ... -> ae_dims[0] ->
      concept_size``; every layer except the last is followed by LayerNorm
      and ReLU.

    Note that the decoder mirrors the first half of ``ae_dims``; entries after
    the bottleneck are never read.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()

        mid_index = len(ae_dims) // 2
        # encoder from concept_size to bottleneck
        encoder_layers = []
        prev_dim = concept_size
        for i in range(mid_index + 1):  # include bottleneck
            encoder_layers.append(nn.Linear(prev_dim, ae_dims[i]))
            # LayerNorm + ReLU on every layer except the bottleneck
            if i != mid_index:
                encoder_layers.append(nn.LayerNorm(ae_dims[i]))
                encoder_layers.append(nn.ReLU())
            prev_dim = ae_dims[i]
        self.encoder = nn.Sequential(*encoder_layers)

        # decoder from bottleneck to concept_size
        decoder_layers = []
        prev_dim = ae_dims[mid_index]  # latent dim
        for i in range(mid_index - 1, -1, -1):
            decoder_layers.append(nn.Linear(prev_dim, ae_dims[i]))
            decoder_layers.append(nn.LayerNorm(ae_dims[i]))
            decoder_layers.append(nn.ReLU())
            prev_dim = ae_dims[i]
        # last layer, no normalisation or activation
        decoder_layers.append(nn.Linear(prev_dim, concept_size))
        self.decoder = nn.Sequential(*decoder_layers)

        # Learnable scalar initialised to 1.0. It is not used inside this class;
        # presumably the temperature handed to align_clip_loss - TODO confirm.
        self.t = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Encode and decode ``x``. No input noise is added on this path."""
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed

    def encode(self, x):
        """Return the latent code of ``x``.

        When the notebook-level flag ``using_ae_eps`` is set, uniform noise in
        ``[0, concept_eps)`` is added to ``x`` first.
        NOTE(review): the noise is applied regardless of train/eval mode.
        """
        if using_ae_eps:
            # rand_like already matches x's device and dtype, so no transfer to
            # the global device is needed (and x may live on another device).
            x = x + concept_eps * torch.rand_like(x)
        return self.encoder(x)

    def decode(self, latent):
        """Map a latent code back to concept space."""
        return self.decoder(latent)

# 8. FUNCTIONS - Test cross-agent concept reconstruction

In [12]:
def test_alignment(agents_ae, align_testloader, recons_criterion, agent_i = None, agent_j = None):
    # Get a batch of aligned concepts from the test set
    test_batch = next(iter(align_testloader))
    
    if agent_i is None or agent_j is None:
        agent_indices = list(range(1, n_agents + 1))
        agent_i, agent_j = random.sample(agent_indices, 2)

    with torch.no_grad():
        # Get agent1's concept vectors
        agenti_concepts = test_batch[agent_i - 1].to(device)
        agentj_concepts = test_batch[agent_j - 1].to(device)

        agents_ae[agent_i].eval()
        agents_ae[agent_j].eval()
        # Encode using agent1's autoencoder
        latent_agenti = agents_ae[agent_i].encode(agenti_concepts)
        latent_agentj = agents_ae[agent_j].encode(agentj_concepts)

        # Decode using agent2's autoencoder
        recons_agentjbylatenti = agents_ae[agent_j].decode(latent_agenti)
        recons_agentjbylatentj = agents_ae[agent_j].decode(latent_agentj)
        recons_agentibylatentj = agents_ae[agent_i].decode(latent_agentj)
        recons_agentibylatenti = agents_ae[agent_i].decode(latent_agenti)

        agents_ae[agent_i].train()
        agents_ae[agent_j].train()

        # Calculate MSE between reconstructed and original
        recons_loss1 = recons_criterion(recons_agentibylatentj, agenti_concepts)
        recons_loss2 = recons_criterion(recons_agentjbylatentj, agentj_concepts)
        recons_loss3 = recons_criterion(recons_agentjbylatenti, agentj_concepts)
        recons_loss4 = recons_criterion(recons_agentibylatenti, agenti_concepts)

        latent_loss1 = recons_criterion(latent_agenti, latent_agentj)
        latent_loss2 = recons_criterion(latent_agenti[0:-2], latent_agentj[1:-1])

    print(f"Loss between same latents of agent{agent_i}'s and agent{agent_j}'s: {latent_loss1.item():.4f}")
    print(f"Loss between different latents of agent{agent_i}'s and agent{agent_j}'s: {latent_loss2.item():.4f}")
    print(f"Loss between agent{agent_i}'s original concepts and agent{agent_j}'s latent reconstructed in agent{agent_i}'s space: {recons_loss1.item():.4f}")
    print(f"Loss between agent{agent_j}'s original concepts and agent{agent_j}'s latent reconstructed in agent{agent_j}'s space: {recons_loss2.item():.4f}")
    print(f"Loss between agent{agent_j}'s original concepts and agent{agent_i}'s latent reconstructed in agent{agent_j}'s space: {recons_loss3.item():.4f}")
    print(f"Loss between agent{agent_i}'s original concepts and agent{agent_i}'s latent reconstructed in agent{agent_i}'s space: {recons_loss4.item():.4f}")


# 9. FUNCTIONS - Alignment contrast loss

In [13]:
def align_clip_loss(latent_i, latent_j, t = torch.tensor(0.0)):
    """Symmetric contrastive loss between two batches of row-aligned latents.

    Rows are L2-normalised, the pairwise cosine similarities are scaled by
    ``exp(t)``, and cross-entropy is applied along both rows and columns with
    the diagonal (row ``k`` of ``latent_i`` with row ``k`` of ``latent_j``) as
    the target class.

    Args:
        latent_i, latent_j: 2-D tensors with the same number of rows.
        t: tensor holding the log of the similarity scale (default 0, i.e. scale 1).

    Returns:
        Scalar tensor: the mean of the two cross-entropy terms.
    """
    normalized_latent_i = torch.nn.functional.normalize(latent_i, p = 2, dim = 1)
    normalized_latent_j = torch.nn.functional.normalize(latent_j, p = 2, dim = 1)
    logits = torch.matmul(normalized_latent_i, normalized_latent_j.T) * torch.exp(t)
    # Targets live on the same device as the logits rather than on a global device.
    labels = torch.arange(latent_i.size(0), device=logits.device)
    loss_i = torch.nn.functional.cross_entropy(logits, labels)
    loss_j = torch.nn.functional.cross_entropy(logits.T, labels)
    return (loss_i + loss_j) * 0.5

In [14]:
def align_clip_l1loss(latent_i, latent_j, margin = 0.3):
    """Margin-based L1 contrastive loss over all pairs of a batch.

    The pairwise L1 distance between every row of ``latent_i`` and every row of
    ``latent_j`` is divided by the latent width. Matching pairs (the diagonal)
    are penalised by their distance; non-matching pairs are penalised by
    ``max(margin - distance, 0)``.

    Args:
        latent_i, latent_j: 2-D tensors with the same number of rows and the
            same width.
        margin: minimum normalised distance required between non-matching
            pairs (default 0.3, the value previously hard-coded).

    Returns:
        Scalar tensor: (diagonal term + off-diagonal term) / n_rows**2.
    """
    n_rows = latent_i.size(0)
    hidden_dim = latent_i.size(1)
    dev = latent_i.device  # stay on the inputs' device instead of a global one

    distance_matrix = torch.cdist(latent_i, latent_j, p=1) / hidden_dim

    # Matching pairs sit on the diagonal
    diag_mask = torch.eye(n_rows, dtype=torch.bool, device=dev)
    diag_loss = torch.sum(distance_matrix[diag_mask])

    # Non-matching pairs are pushed at least `margin` apart
    off_diag_mask = ~diag_mask
    margin = torch.as_tensor(margin, device=dev)
    off_diag_loss = torch.sum(torch.clamp_min(margin - distance_matrix[off_diag_mask], 0))

    total_loss = diag_loss + off_diag_loss
    return total_loss / (n_rows * n_rows)

In [15]:
def align_balanced_l1loss(latent_i, latent_j, margin = 0.3):
    """L1 alignment loss with one randomly shifted negative pairing.

    Combines, with equal weight, (1) the mean L1 distance between matching rows
    of ``latent_i`` and ``latent_j`` and (2) ``max(margin - d, 0)`` where ``d``
    is the mean L1 distance between ``latent_i`` and ``latent_j`` cyclically
    shifted by a random non-zero number of rows.

    Args:
        latent_i, latent_j: tensors of identical shape, rows aligned.
        margin: target minimum mean distance for the shifted pairing
            (default 0.3, the value previously hard-coded).

    Returns:
        Scalar tensor ``(direct + shifted) * 0.5``. For a batch with fewer than
        two rows no non-matching pairing exists, so the shifted term is taken
        as zero (previously this case raised inside ``torch.randint``).
    """
    n_rows = latent_i.size(0)
    dev = latent_i.device  # stay on the inputs' device instead of a global one

    # Part 1: direct L1 loss between matching rows
    direct_loss = torch.nn.functional.l1_loss(latent_i, latent_j)

    if n_rows < 2:
        return direct_loss * 0.5

    # Part 2: cyclic shift by 1..n_rows-1 so that no row is paired with itself
    shift = torch.randint(low=1, high=n_rows, size=(), device=dev)
    shifted_indices = (torch.arange(n_rows, device=dev) + shift) % n_rows
    latent_j_shifted = latent_j[shifted_indices]

    margin = torch.as_tensor(margin, device=dev)
    shifted_loss = torch.clamp_min(margin - torch.nn.functional.l1_loss(latent_i, latent_j_shifted), 0)

    # Combine both terms with equal weight
    total_loss = (direct_loss + shifted_loss) * 0.5

    return total_loss

# 10. EXECUTION - Constructing datasets

In [16]:
# set_seed(0)

In [None]:
# Rebuild every agent's network and restore its trained weights.
# `agents` is keyed by the 1-based agent id.
agents = {}
for i in range(n_agents):
    i_agent = i + 1
    agents[i_agent] = Net_sa2q()
    # The file name encodes the full training configuration from the parameter cell.
    ckpt_file: str = f"agent_ckpt/group{n_group}/ckpt{i_agent}of{n_agents}_unique{n_unique}_{grid_dim}x{grid_dim}_n{n_mazes}{q_str}{cap_str}_dim{dim_str}{res_str}_cptsz{concept_size}{concept_eps_str}_lr{lr}_epsi{n_episode}_gamma{gamma_bellman}_bs{batch_size}_tr{target_replace_steps}.pt"
    # map_location="cpu": the networks stay on the CPU here, and this lets
    # checkpoints saved from a CUDA device load even when that device is not
    # visible to this process.
    agents[i_agent].load_state_dict(torch.load(ckpt_file, map_location="cpu"))

In [18]:
# Per-agent reconstruction splits and loaders, keyed by the 1-based agent id.
recons_trainsets, recons_trainloaders = {}, {}
recons_testsets, recons_testloaders = {}, {}

for i in range(n_agents):
    i_agent = i + 1
    # Each agent gets its own random split over (common mazes + its unique mazes).
    recons_train_indices, recons_test_indices = get_recons_indices(
        n_mazes, n_agents, i_agent, n_unique, n_recons_test
    )

    recons_trainsets[i_agent] = ReconsDataset(
        agents[i_agent].concept_embedding_layer.weight.data[recons_train_indices]
    )
    recons_testsets[i_agent] = ReconsDataset(
        agents[i_agent].concept_embedding_layer.weight.data[recons_test_indices]
    )

    recons_trainloaders[i_agent] = DataLoader(recons_trainsets[i_agent], batch_size=ae_batch_size, shuffle=True)
    recons_testloaders[i_agent] = DataLoader(recons_testsets[i_agent], batch_size=ae_batch_size, shuffle=False)

# Alignment split: one shared split over the common mazes, so that row k refers
# to the same maze for every agent.
align_train_indices, align_test_indices = get_align_indices(n_mazes, n_agents, n_unique, n_align_test)
align_train_concept = [
    agents[agent_id].concept_embedding_layer.weight.data[align_train_indices]
    for agent_id in range(1, n_agents + 1)
]
align_test_concept = [
    agents[agent_id].concept_embedding_layer.weight.data[align_test_indices]
    for agent_id in range(1, n_agents + 1)
]

align_trainset = AlignDataset(align_train_concept)
align_testset = AlignDataset(align_test_concept)

align_trainloader = DataLoader(align_trainset, batch_size=ae_batch_size, shuffle=True)
align_testloader = DataLoader(align_testset, batch_size=ae_batch_size, shuffle=False)


# 12. EXECUTION - Two training stages

## method 1: odd-numbered epochs, self-reconstruction; even-numbered epochs, alignment

In [19]:
# agents_ae = {}
# for i in range(n_agents):
#     i_agent = i + 1
#     agents_ae[i_agent] = Autoencoder().to(device)
#     if load_ae:
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{n_epoch}_{loss_type}/ae{i_agent}of{n_agents}.pt"
#         agents_ae[i_agent].load_state_dict(torch.load(ae_ckpt_file))
#         print(f'Autoencoder of agent {i_agent} loaded!')

# global_optim = torch.optim.Adam([p for model in agents_ae.values() for p in model.parameters()], lr = ae_align_lr, weight_decay = ae_weight_decay)

# recons_criterion = nn.__dict__[loss_type]()

# if train_ae:
#     try:
#         for epoch in range(n_epoch):
#             if epoch % 2 == 0:
#                 for i in range(n_agents):
#                     i_agent = i + 1
#                     train_l_sum = 0.
#                     train_batch_count = 0
#                     test_l_sum = 0.
#                     test_batch_count = 0
#                     for x in recons_trainloaders[i_agent]:
#                         x = x.to(device)
#                         x_hat = agents_ae[i_agent](x)
#                         recons_loss = recons_criterion(x_hat, x)

#                         global_optim.zero_grad()
#                         recons_loss.backward()
#                         global_optim.step()

#                         train_l_sum += recons_loss.cpu().item()
#                         train_batch_count += 1
                    
#                     for x in recons_testloaders[i_agent]:
#                         x = x.to(device)
#                         with torch.no_grad():
#                             agents_ae[i_agent].eval()
#                             x_hat = agents_ae[i_agent](x)
#                             agents_ae[i_agent].train()
#                             recons_loss = recons_criterion(x_hat, x)

#                         test_l_sum += recons_loss.cpu().item()
#                         test_batch_count += 1
                    
#                     if (epoch) % print_frequency == 0:
#                         print(f'epoch {epoch + 1}, agent {i_agent}, recons training loss {train_l_sum / train_batch_count:.4f}, recons testing loss {test_l_sum / test_batch_count:.4f}')
                
#                 # pass

#             else:
#                 train_l_sum = 0.
#                 train_batch_count = 0
#                 test_l_sum = 0.
#                 test_l_sum_recons = 0.
#                 test_l_sum_align = 0.
#                 test_batch_count = 0
#                 for x in align_trainloader:
#                     latent = []
#                     for i in range(n_agents):
#                         i_agent = i + 1
#                         x_in = x[i].to(device)
#                         x_hat = agents_ae[i_agent].encode(x_in)
#                         latent.append(x_hat)

#                     # # method 1: mean of all latent vectors
#                     # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                     # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                     # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                     # method 2: align between each latent vector plus mutual reconstruction loss
#                     align_loss = 0.0
#                     for i in range(n_agents):
#                         for j in range(i+1, n_agents):
#                             align_loss += align_weight * recons_criterion(latent[i], latent[j])

#                     recon_loss = 0.0
#                     for i in range(n_agents):
#                         x_in = x[i].to(device)
#                         for j in range(n_agents):
#                             # if i != j:
#                                 x_tar = x[j].to(device)
#                                 j_agent = j + 1
#                                 recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar)

#                     align_loss = align_loss + recon_loss

#                     global_optim.zero_grad()
#                     align_loss.backward()
#                     global_optim.step()

#                     train_l_sum += align_loss.cpu().item()
#                     train_batch_count += 1

#                 with torch.no_grad():
#                     for x in align_testloader:
#                         latent = []
#                         for i in range(n_agents):
#                             i_agent = i + 1
#                             agents_ae[i_agent].eval()
#                             x_in = x[i].to(device)
#                             x_hat = agents_ae[i_agent].encode(x_in)
#                             latent.append(x_hat)
#                             agents_ae[i_agent].train()

#                         # # method 1: mean of all latent vectors
#                         # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                         # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                         # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                         # method 2: align between each latent vector plus mutual reconstruction loss
#                         align_loss = 0.0
#                         for i in range(n_agents):
#                             for j in range(i+1, n_agents):
#                                 align_loss += align_weight * recons_criterion(latent[i], latent[j])
#                         test_l_sum_align += align_loss.cpu().item()

#                         recon_loss = 0.0
#                         for i in range(n_agents):
#                             x_in = x[i].to(device)
#                             for j in range(n_agents):
#                                 # if i != j:
#                                     x_tar = x[j].to(device)
#                                     j_agent = j + 1
#                                     agents_ae[j_agent].eval()
#                                     recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar) # was set to x_in before
#                                     agents_ae[j_agent].train()
#                         test_l_sum_recons += recon_loss.cpu().item()
                        
#                         align_loss = align_loss + recon_loss

#                         test_l_sum += align_loss.cpu().item()
#                         test_batch_count += 1

#                 if (epoch + 1) % print_frequency == 0:
#                     print(
#                         f'epoch {epoch + 1},' 
#                         f'align training loss {train_l_sum / train_batch_count:.4f},' 
#                         f'align sub-recons testing loss {test_l_sum_recons / test_batch_count:.4f},'
#                         f'align sub-align testing loss {test_l_sum_align / test_batch_count:.4f},'
#                         f'align testing loss {test_l_sum / test_batch_count:.4f}'
#                     )
#                     print('On training set alignment...')
#                     test_alignment(agents_ae, align_trainloader, recons_criterion)
#                     print('On testing set alignment...')
#                     test_alignment(agents_ae, align_testloader, recons_criterion)
    
#     except KeyboardInterrupt:
#         pass

#     for i in range(n_agents):
#         i_agent = i + 1
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{epoch+1}_{loss_type}/ae{i_agent}of{n_agents}.pt"
        
#         os.makedirs(os.path.dirname(ae_ckpt_file), exist_ok=True)

#         torch.save(agents_ae[i_agent].state_dict(), ae_ckpt_file)
#         print(f'Autoencoder of agent {i_agent} saved!')

## method 2: encoder alignment until convergence, then training decoder reconstruction using fixed encoder

In [20]:
# agents_ae = {}
# for i in range(n_agents):
#     i_agent = i + 1
#     agents_ae[i_agent] = Autoencoder().to(device)
#     if load_ae:
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{n_epoch}_{loss_type}/ae{i_agent}of{n_agents}.pt"
#         agents_ae[i_agent].load_state_dict(torch.load(ae_ckpt_file))
#         print(f'Autoencoder of agent {i_agent} loaded!')

# global_optim = torch.optim.Adam([p for model in agents_ae.values() for p in model.parameters()], lr = ae_align_lr, weight_decay = ae_weight_decay)

# recons_criterion = nn.__dict__[loss_type]()

# if train_ae:
#     try:
#         for epoch in range(n_epoch):
#             if epoch < 1000:
#                 train_l_sum = 0.
#                 train_batch_count = 0
#                 test_l_sum = 0.
#                 test_l_sum_recons = 0.
#                 test_l_sum_align = 0.
#                 test_batch_count = 0
#                 for x in align_trainloader:
#                     latent = []
#                     for i in range(n_agents):
#                         i_agent = i + 1
#                         x_in = x[i].to(device)
#                         x_hat = agents_ae[i_agent].encode(x_in)
#                         latent.append(x_hat)

#                     # # method 1: mean of all latent vectors
#                     # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                     # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                     # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                     # method 2: align between each latent vector plus mutual reconstruction loss
#                     align_loss = 0.0
#                     for i in range(n_agents):
#                         for j in range(i+1, n_agents):
#                             align_loss += align_weight * recons_criterion(latent[i], latent[j])

#                     global_optim.zero_grad()
#                     align_loss.backward()
#                     global_optim.step()

#                     train_l_sum += align_loss.cpu().item()
#                     train_batch_count += 1

#                 if (epoch + 1) % print_frequency == 0:
#                     print(
#                         f'epoch {epoch + 1},' 
#                         f'align training loss {train_l_sum / train_batch_count:.4f},' 
#                     )
#                     print('On training set alignment...')
#                     test_alignment(agents_ae, align_trainloader, recons_criterion)
#                     print('On testing set alignment...')
#                     test_alignment(agents_ae, align_testloader, recons_criterion)

#             else:
#                 train_l_sum = 0.
#                 train_batch_count = 0
#                 test_l_sum = 0.
#                 test_l_sum_recons = 0.
#                 test_l_sum_align = 0.
#                 test_batch_count = 0
#                 for x in align_trainloader:
#                     latent = []
#                     with torch.no_grad():
#                         for i in range(n_agents):
#                             i_agent = i + 1
#                             x_in = x[i].to(device)
#                             x_hat = agents_ae[i_agent].encode(x_in)
#                             latent.append(x_hat)

#                     # # method 1: mean of all latent vectors
#                     # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                     # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                     # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                     # method 2: align between each latent vector plus mutual reconstruction loss
#                     align_loss = 0.0
#                     for i in range(n_agents):
#                         for j in range(i+1, n_agents):
#                             align_loss += align_weight * recons_criterion(latent[i], latent[j])

#                     recon_loss = 0.0
#                     for i in range(n_agents):
#                         x_in = x[i].to(device)
#                         for j in range(n_agents):
#                             # if i != j:
#                                 x_tar = x[j].to(device)
#                                 j_agent = j + 1
#                                 recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar)

#                     align_loss = align_loss + recon_loss

#                     global_optim.zero_grad()
#                     align_loss.backward()
#                     global_optim.step()

#                     train_l_sum += align_loss.cpu().item()
#                     train_batch_count += 1
                    
#                 if (epoch + 1) % print_frequency == 0:
#                     with torch.no_grad():
#                         for x in align_testloader:
#                             latent = []
#                             for i in range(n_agents):
#                                 i_agent = i + 1
#                                 agents_ae[i_agent].eval()
#                                 x_in = x[i].to(device)
#                                 x_hat = agents_ae[i_agent].encode(x_in)
#                                 latent.append(x_hat)
#                                 agents_ae[i_agent].train()

#                             # # method 1: mean of all latent vectors
#                             # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                             # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                             # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                             # method 2: align between each latent vector plus mutual reconstruction loss
#                             align_loss = 0.0
#                             for i in range(n_agents):
#                                 for j in range(i+1, n_agents):
#                                     align_loss += align_weight * recons_criterion(latent[i], latent[j])
#                             test_l_sum_align += align_loss.cpu().item()

#                             recon_loss = 0.0
#                             for i in range(n_agents):
#                                 x_in = x[i].to(device)
#                                 for j in range(n_agents):
#                                     # if i != j:
#                                         x_tar = x[j].to(device)
#                                         j_agent = j + 1
#                                         agents_ae[j_agent].eval()
#                                         recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar) # was set to x_in before
#                                         agents_ae[j_agent].train()
#                             test_l_sum_recons += recon_loss.cpu().item()
                            
#                             align_loss = align_loss + recon_loss

#                             test_l_sum += align_loss.cpu().item()
#                             test_batch_count += 1

#                     print(
#                         f'epoch {epoch + 1},' 
#                         f'align training loss {train_l_sum / train_batch_count:.4f},' 
#                         f'align sub-recons testing loss {test_l_sum_recons / test_batch_count:.4f},'
#                         f'align sub-align testing loss {test_l_sum_align / test_batch_count:.4f},'
#                         f'align testing loss {test_l_sum / test_batch_count:.4f}'
#                     )
#                     print('On training set alignment...')
#                     test_alignment(agents_ae, align_trainloader, recons_criterion)
#                     print('On testing set alignment...')
#                     test_alignment(agents_ae, align_testloader, recons_criterion)
    
#     except KeyboardInterrupt:
#         pass

#     for i in range(n_agents):
#         i_agent = i + 1
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{epoch+1}_{loss_type}/ae{i_agent}of{n_agents}.pt"
        
#         os.makedirs(os.path.dirname(ae_ckpt_file), exist_ok=True)

#         torch.save(agents_ae[i_agent].state_dict(), ae_ckpt_file)
#         print(f'Autoencoder of agent {i_agent} saved!')

## method 3: contrastive loss

In [21]:
# agents_ae = {}
# for i in range(n_agents):
#     i_agent = i + 1
#     agents_ae[i_agent] = Autoencoder().to(device)
#     if load_ae:
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{n_epoch}_{loss_type}/ae{i_agent}of{n_agents}.pt"
#         agents_ae[i_agent].load_state_dict(torch.load(ae_ckpt_file))
#         print(f'Autoencoder of agent {i_agent} loaded!')

# global_optim = torch.optim.Adam([p for model in agents_ae.values() for p in model.parameters()], lr = ae_align_lr, weight_decay = ae_weight_decay)

# recons_criterion = nn.__dict__[loss_type]()

# if train_ae:
#     try:
#         for epoch in range(n_epoch):
#             if epoch % 2 == 0:
#                 for i in range(n_agents):
#                     i_agent = i + 1
#                     train_l_sum = 0.
#                     train_batch_count = 0
#                     test_l_sum = 0.
#                     test_batch_count = 0
#                     for x in recons_trainloaders[i_agent]:
#                         x = x.to(device)
#                         x_hat = agents_ae[i_agent](x)
#                         recons_loss = recons_criterion(x_hat, x)

#                         global_optim.zero_grad()
#                         recons_loss.backward()
#                         global_optim.step()

#                         train_l_sum += recons_loss.cpu().item()
#                         train_batch_count += 1
                    
#                     for x in recons_testloaders[i_agent]:
#                         x = x.to(device)
#                         with torch.no_grad():
#                             agents_ae[i_agent].eval()
#                             x_hat = agents_ae[i_agent](x)
#                             agents_ae[i_agent].train()
#                             recons_loss = recons_criterion(x_hat, x)

#                         test_l_sum += recons_loss.cpu().item()
#                         test_batch_count += 1
                    
#                     if (epoch) % print_frequency == 0:
#                         print(f'epoch {epoch + 1}, agent {i_agent}, recons training loss {train_l_sum / train_batch_count:.4f}, recons testing loss {test_l_sum / test_batch_count:.4f}')
                
#                 if (epoch) % print_frequency == 0:
#                     print('On training set alignment...')
#                     test_alignment(agents_ae, align_trainloader, recons_criterion)
#                     print('On testing set alignment...')
#                     test_alignment(agents_ae, align_testloader, recons_criterion)

#                 pass

#             else:
#                 train_l_sum = 0.
#                 train_batch_count = 0
#                 test_l_sum = 0.
#                 test_l_sum_recons = 0.
#                 test_l_sum_align = 0.
#                 test_batch_count = 0
#                 for x in align_trainloader:
#                     latent = []
#                     for i in range(n_agents):
#                         i_agent = i + 1
#                         x_in = x[i].to(device)
#                         x_hat = agents_ae[i_agent].encode(x_in)
#                         latent.append(x_hat)

#                     # # method 1: mean of all latent vectors
#                     # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                     # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                     # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                     # method 2: align between each latent vector plus mutual reconstruction loss
#                     align_loss = 0.0
#                     for i in range(n_agents):
#                         for j in range(i+1, n_agents):
#                             align_loss += align_weight * align_balanced_l1loss(latent[i], latent[j])

#                     recon_loss = 0.0
#                     for i in range(n_agents):
#                         x_in = x[i].to(device)
#                         for j in range(n_agents):
#                             # if i != j:
#                                 x_tar = x[j].to(device)
#                                 j_agent = j + 1
#                                 recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar)

#                     align_loss = align_loss + recon_loss

#                     global_optim.zero_grad()
#                     align_loss.backward()
#                     global_optim.step()

#                     train_l_sum += align_loss.cpu().item()
#                     train_batch_count += 1

#                 with torch.no_grad():
#                     for x in align_testloader:
#                         latent = []
#                         for i in range(n_agents):
#                             i_agent = i + 1
#                             agents_ae[i_agent].eval()
#                             x_in = x[i].to(device)
#                             x_hat = agents_ae[i_agent].encode(x_in)
#                             latent.append(x_hat)
#                             agents_ae[i_agent].train()

#                         # # method 1: mean of all latent vectors
#                         # stacked_latent = torch.stack(latent)  # shape (n_agents, batch_size, dim_latent)
#                         # mean_latent = stacked_latent.mean(dim = 0, keepdim = True).repeat(n_agents, 1, 1)
#                         # align_loss = F.mse_loss(stacked_latent, mean_latent)

#                         # method 2: align between each latent vector plus mutual reconstruction loss
#                         align_loss = 0.0
#                         for i in range(n_agents):
#                             for j in range(i+1, n_agents):
#                                 align_loss += align_weight * align_balanced_l1loss(latent[i], latent[j])
#                         test_l_sum_align += align_loss.cpu().item()

#                         recon_loss = 0.0
#                         for i in range(n_agents):
#                             x_in = x[i].to(device)
#                             for j in range(n_agents):
#                                 # if i != j:
#                                     x_tar = x[j].to(device)
#                                     j_agent = j + 1
#                                     agents_ae[j_agent].eval()
#                                     recon_loss += recons_criterion(agents_ae[j_agent].decode(latent[i]), x_tar) # was set to x_in before
#                                     agents_ae[j_agent].train()
#                         test_l_sum_recons += recon_loss.cpu().item()
                        
#                         align_loss = align_loss + recon_loss

#                         test_l_sum += align_loss.cpu().item()
#                         test_batch_count += 1

#                 if (epoch + 1) % print_frequency == 0:
#                     print(
#                         f'epoch {epoch + 1},' 
#                         f'align training loss {train_l_sum / train_batch_count:.4f},' 
#                         f'align sub-recons testing loss {test_l_sum_recons / test_batch_count:.4f},'
#                         f'align sub-align testing loss {test_l_sum_align / test_batch_count:.4f},'
#                         f'align testing loss {test_l_sum / test_batch_count:.4f}'
#                     )
#                     print('On training set alignment...')
#                     test_alignment(agents_ae, align_trainloader, recons_criterion)
#                     print('On testing set alignment...')
#                     test_alignment(agents_ae, align_testloader, recons_criterion)
                
#                 pass

#     except KeyboardInterrupt:
#         pass

#     for i in range(n_agents):
#         i_agent = i + 1
#         ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{epoch+1}_{loss_type}/ae{i_agent}of{n_agents}.pt"
        
#         os.makedirs(os.path.dirname(ae_ckpt_file), exist_ok=True)

#         torch.save(agents_ae[i_agent].state_dict(), ae_ckpt_file)
#         print(f'Autoencoder of agent {i_agent} saved!')

## method 4: balanced align + balanced reconstruction

In [None]:
# Method 4: every agent's latent code is aligned with the latent code of a
# different agent (picked through a random cyclic shift), and every agent's
# decoder has to reconstruct the agent's own input from both latent codes.

# One autoencoder per agent; agents are keyed 1..n_agents.
agents_ae = {}
for i in range(n_agents):
    i_agent = i + 1
    agents_ae[i_agent] = Autoencoder().to(device)
    if load_ae:
        ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{n_epoch}_{loss_type}/ae{i_agent}of{n_agents}.pt"
        # map_location lets a checkpoint written on one device (e.g. a GPU) be
        # restored on whatever `device` resolves to in the current session.
        agents_ae[i_agent].load_state_dict(torch.load(ae_ckpt_file, map_location = device))
        print(f'Autoencoder of agent {i_agent} loaded!')

# A single optimizer updates the parameters of all agents' autoencoders jointly.
global_optim = torch.optim.Adam([p for model in agents_ae.values() for p in model.parameters()], lr = ae_align_lr, weight_decay = ae_weight_decay)

# `loss_type` is the name of a loss class in torch.nn; instantiate it with defaults.
recons_criterion = nn.__dict__[loss_type]()


def _shifted_align_losses(batch):
    """Compute ``(align_loss, recon_loss)`` for one batch of the align loaders.

    ``batch[i]`` is moved to ``device`` and encoded by agent ``i + 1``. A single
    shift in ``[1, n_agents - 1]`` is drawn per batch, so agent ``i`` is always
    paired with a different agent ``(i + shift) % n_agents``.

    * ``align_loss``: sum over agents of
      ``align_weight * align_balanced_l1loss(own latent, partner latent)``.
    * ``recon_loss``: sum over agents of ``recons_criterion`` between the agent's
      decoder output for ``[own latent; partner latent]`` and its own input
      repeated twice.

    NOTE: ``torch.randint(1, n_agents, ...)`` needs ``n_agents >= 2``.
    """
    inputs = [batch[i].to(device) for i in range(n_agents)]
    latent = [agents_ae[i + 1].encode(inputs[i]) for i in range(n_agents)]

    # The shift is drawn after encoding, matching the original statement order.
    shifted_num = torch.randint(1, n_agents, ()).item()
    latent_shifted = [latent[(i + shifted_num) % n_agents] for i in range(n_agents)]

    align_loss = 0.0
    for i in range(n_agents):
        align_loss += align_weight * align_balanced_l1loss(latent[i], latent_shifted[i])

    recon_loss = 0.0
    for i in range(n_agents):
        recon_loss += recons_criterion(
            agents_ae[i + 1].decode(
                torch.cat([latent[i], latent_shifted[i]], dim = 0)
            ),
            torch.cat([inputs[i], inputs[i]], dim = 0)
        )

    return align_loss, recon_loss


if train_ae:
    # Start at -1 so that `epoch + 1` reads "0 epochs" in the checkpoint name if
    # the loop never starts (n_epoch == 0 or an immediate interrupt), instead of
    # raising NameError or silently reusing a stale `epoch` from an earlier cell.
    epoch = -1
    try:
        for epoch in range(n_epoch):
            train_l_sum = 0.
            train_batch_count = 0
            test_l_sum = 0.
            test_l_sum_recons = 0.
            test_l_sum_align = 0.
            test_batch_count = 0

            # ---- training pass ----
            for x in align_trainloader:
                align_loss, recon_loss = _shifted_align_losses(x)
                total_loss = align_loss + recon_loss

                global_optim.zero_grad()
                total_loss.backward()
                global_optim.step()

                train_l_sum += total_loss.cpu().item()
                train_batch_count += 1

            # ---- evaluation pass ----
            # Switch every model to eval mode once; `finally` guarantees they are
            # back in train mode even if the evaluation is interrupted.
            for model in agents_ae.values():
                model.eval()
            try:
                with torch.no_grad():
                    for x in align_testloader:
                        align_loss, recon_loss = _shifted_align_losses(x)

                        test_l_sum_align += align_loss.cpu().item()
                        test_l_sum_recons += recon_loss.cpu().item()
                        test_l_sum += (align_loss + recon_loss).cpu().item()
                        test_batch_count += 1
            finally:
                for model in agents_ae.values():
                    model.train()

            if (epoch + 1) % print_frequency == 0:
                # max(..., 1) avoids ZeroDivisionError when a loader yields no batches.
                n_train = max(train_batch_count, 1)
                n_test = max(test_batch_count, 1)
                print(
                    f'epoch {epoch + 1},' 
                    f'align training loss {train_l_sum / n_train:.4f},' 
                    f'align sub-recons testing loss {test_l_sum_recons / n_test:.4f},'
                    f'align sub-align testing loss {test_l_sum_align / n_test:.4f},'
                    f'align testing loss {test_l_sum / n_test:.4f}'
                )
                print('On training set alignment...')
                test_alignment(agents_ae, align_trainloader, recons_criterion)
                print('On testing set alignment...')
                test_alignment(agents_ae, align_testloader, recons_criterion)

    except KeyboardInterrupt:
        # Deliberate best effort: stop training early but still save the weights.
        print(f'Training interrupted during epoch {epoch + 1}; saving current weights.')

    # Save every agent's autoencoder; the folder name records the last epoch index + 1.
    for i in range(n_agents):
        i_agent = i + 1
        ae_ckpt_file: str = f"ae_ckpt/group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{epoch+1}_{loss_type}/ae{i_agent}of{n_agents}.pt"
        
        os.makedirs(os.path.dirname(ae_ckpt_file), exist_ok=True)

        torch.save(agents_ae[i_agent].state_dict(), ae_ckpt_file)
        print(f'Autoencoder of agent {i_agent} saved!')

# 13. EXECUTION - Concept reconstruction

Get the unique concepts of each agent.

Get the latent representation of each agent's unique concepts.

Get all reconstructed concepts of each agent.

In [23]:
# Per-agent dataset and loader over the rows of the agent's concept embedding
# selected by get_unique_indices(...).
unique_conceptsets = {}
unique_conceptloaders = {}

for i in range(n_agents):
    i_agent = i + 1
    unique_indices = get_unique_indices(n_mazes, n_agents, i_agent, n_unique)
    unique_conceptsets[i_agent] = ReconsDataset(agents[i_agent].concept_embedding_layer.weight.data[unique_indices])
    unique_conceptloaders[i_agent] = DataLoader(unique_conceptsets[i_agent], batch_size = ae_batch_size, shuffle = False)

latent_unique = {}    # agent id -> latent codes of that agent's unique concepts (None if it has none)
latent_common = {}    # agent id -> latent codes of the common concepts (on CPU)
recons_concepts = {}  # agent id -> concepts reconstructed by that agent's decoder (on CPU)

# This cell is pure inference: run the autoencoders in eval mode and without
# autograd so no graph is kept alive in the stored tensors. The models are put
# back in train mode afterwards, which is the state the training cells leave them in.
for model in agents_ae.values():
    model.eval()
try:
    with torch.no_grad():
        # Encode every agent's unique concepts with its own encoder.
        for i in range(n_agents):
            i_agent = i + 1
            latent_chunks = [
                agents_ae[i_agent].encode(x.to(device))
                for x in unique_conceptloaders[i_agent]
            ]
            # Concatenate once instead of growing the tensor batch by batch.
            latent_unique[i_agent] = torch.cat(latent_chunks, dim = 0) if latent_chunks else None

        common_indices = get_common_indices(n_mazes, n_agents, n_unique)

        for i in range(n_agents):
            i_agent = i + 1
            common_conceptset = ReconsDataset(agents[i_agent].concept_embedding_layer.weight.data[common_indices])
            common_conceptloaders = DataLoader(common_conceptset, batch_size = ae_batch_size, shuffle = False)

            latent_common_chunks = []
            recons_chunks = []

            # Common concepts: encode and decode with the agent's own autoencoder.
            for x in common_conceptloaders:
                x = x.to(device)
                latent_hat = agents_ae[i_agent].encode(x)
                latent_common_chunks.append(latent_hat)
                recons_chunks.append(agents_ae[i_agent].decode(latent_hat))

            # Unique concepts of every agent (including this one), decoded by this
            # agent's decoder from the owner's latent codes.
            # TODO (from the original author): plain concatenation is wrong here; the
            # result is ordered [common, unique of agent 1, ..., unique of agent n]
            # and should instead be assembled according to the concept indices.
            for j in range(n_agents):
                j_agent = j + 1
                if latent_unique[j_agent] is None:
                    # Agent j has no unique concepts; nothing to decode.
                    continue
                recons_chunks.append(agents_ae[i_agent].decode(latent_unique[j_agent]))

            latent_common[i_agent] = torch.cat(latent_common_chunks, dim = 0).detach().cpu()
            recons_concepts[i_agent] = torch.cat(recons_chunks, dim = 0).detach().cpu()
finally:
    for model in agents_ae.values():
        model.train()

# If no training loop ran in this session, fall back to the configured epoch count.
epoch = n_epoch - 1 if 'epoch' not in locals() or epoch is None else epoch
recons_concepts_file: str = f"recons_concepts/concepts_group{n_group}_dim{ae_dim_str}_ReconsTest{n_recons_test}_AlignTest{n_align_test}_bs{ae_batch_size}_AlignLr{ae_align_lr}_epo{epoch+1}_{loss_type}.pt"
# torch.save does not create missing directories; mirror the checkpoint-saving code.
os.makedirs(os.path.dirname(recons_concepts_file), exist_ok=True)
torch.save(recons_concepts, recons_concepts_file)

# 14. TEST - Cross-agent concept reconstruction with aligndata