In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
from os.path import join, exists
from os import mkdir
import time
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import ConcatDataset
from omegaconf import OmegaConf
from dotenv import load_dotenv



from MAWM.core import get_cls, PRIMITIVE_TEMPLATES

from MAWM.data.utils import lejepa_train_tf, lejepa_test_tf, BufferAwareConcatDataset
from MAWM.data.loaders import RolloutObservationDataset

from MAWM.trainers.trainer_progjepa import ProgLejepaTrainer
from MAWM.writers.wandb_writer import WandbWriter
from MAWM.program.creator import create_specs_from_image, batchify_programs


{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [3]:
try:
    cfg = OmegaConf.load("../cfgs/prog_lejepa/marlrid_cfg.yaml")
except:
    print("Invalid config file path")


In [4]:
cfg.data.data_dir = '../marl_grid_data'

In [5]:
num_primitives = len(PRIMITIVE_TEMPLATES)
grid_size = 7
void_transform = transforms.Lambda(lambda x: x)

def init_data(agent_id):
    train_ds = RolloutObservationDataset(agent= agent_id,
                                        root= cfg.data.data_dir,
                                        transform= lejepa_train_tf,
                                        buffer_size= cfg.data.buffer_size,
                                        train=True,
                                        obs_key= cfg.data.obs_key
                                        )
        
    test_ds = RolloutObservationDataset(agent= agent_id,
                                        root= cfg.data.data_dir,
                                        transform= lejepa_test_tf,
                                        buffer_size= cfg.data.buffer_size,
                                        train=False,
                                        obs_key= cfg.data.obs_key
                                        )
    
    return train_ds, test_ds

In [6]:
def init_models():
    model_cls = get_cls("MAWM.models.encoder", "ResNet18")
    v_encoder = model_cls()

    model_cls = get_cls("MAWM.models.program_encoder", "ProgramEncoder")
    p_encoder = model_cls(num_primitives, [grid_size, grid_size], 2, seq_len=49)
    return v_encoder, p_encoder

def init_opt(model):
    optimizer_cls = get_cls("torch.optim", cfg.optimizer.name)
    optimizer = optimizer_cls(model.parameters(), lr=cfg.optimizer.lr)
    return optimizer

In [7]:
dss_train, dss_test = [], []
for agent in cfg.env.agents:
    train_ds, test_ds = init_data(agent)
    dss_train.append(train_ds)
    dss_test.append(test_ds)


dataset_train = BufferAwareConcatDataset(datasets= dss_train)
dataset_test = BufferAwareConcatDataset(datasets= dss_test)

# train_loader = torch.utils.data.DataLoader(
#     dataset_train, batch_size=cfg.data.batch_size, shuffle=True, drop_last=True, num_workers=0)

val_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=cfg.data.batch_size, shuffle=False, num_workers=0)

Loading file buffer ...: 100%|██████████| 5/5 
Loading file buffer ...: 100%|██████████| 5/5 
Loading file buffer ...: 100%|██████████| 5/5 
Loading file buffer ...: 100%|██████████| 5/5 


In [8]:
v_encoder, p_encoder = init_models()

In [9]:
v_encoder.eval()
p_encoder.eval()

for p in v_encoder.parameters():
    p.requires_grad = False
for p in p_encoder.parameters():
    p.requires_grad = False


In [11]:
def denormalize(tensor):
    return tensor * 0.5 + 0.5

In [12]:
import torch.nn as nn
prog_dim = embed_dim = 256
model = nn.Linear(prog_dim, embed_dim, bias=False)

In [13]:
import torch
tau = 0.1
device = "cpu"
optimizer = init_opt(model)

In [14]:
losses = 0
for epoch in range(200)
for data in val_loader:
    obs, dones, agent_id = data
    mask = ~dones.bool()     # keep only where done is False

    if mask.sum() == 0:
        continue  # entire batch is terminals

    obs = obs[mask]          # filter observations
    programs = [create_specs_from_image(denormalize(img).permute(1, 2, 0).numpy()) for img in obs]
    batch_prim_ids, batch_param_tensor = batchify_programs(programs)

    batch_prim_ids = batch_prim_ids.to(device)
    batch_param_tensor = batch_param_tensor.to(device)
    obs = obs.to(device)

    optimizer.zero_grad()

    img_proj = v_encoder(obs)
    prog_proj = p_encoder(batch_prim_ids, batch_param_tensor)

    z_prog = model(prog_proj)

    z_prog = F.normalize(z_prog, dim=-1)
    z_img  = F.normalize(img_proj, dim=-1)

    logits = z_prog @ z_img.T                  # [B, B]
    B = len(obs)
    labels = torch.arange(B, device=logits.device)

    loss = F.cross_entropy(logits / tau, labels)
    losses += loss


KeyboardInterrupt: 

In [15]:
losses

tensor(103.4109, grad_fn=<AddBackward0>)