In [1]:
import os
import wandb
import argparse
import numpy as np
import yaml
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim import Adam, AdamW
from torchvision import transforms
import math
import torch.nn.functional as F

import torch.backends.cudnn as cudnn
from gnm_train.training.logger import Logger
from gnm_train.models.gnm import GNM
from gnm_train.models.gnm_modified import GNMModified

from gnm_train.models.siamese import SiameseModel
from gnm_train.models.stacked import StackedModel
from gnm_train.data.gnm_dataset import GNM_Dataset
from gnm_train.data.gnm_dataset_modified import GNM_Dataset_Modified

from gnm_train.data.pairwise_distance_dataset import PairwiseDistanceDataset
from gnm_train.training.train_utils import (
    train_eval_loop,
    load_model,
    get_saved_optimizer,
)

with open("config/gnm/gnm_carla.yaml", "r") as f:
    default_config = yaml.safe_load(f)

config = default_config

config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")+"_distance_two_datasets"
config["project_folder"] = os.path.join(
    "logs", config["project_name"], config["run_name"]
)
os.makedirs(
    config[
        "project_folder"
    ],  # should error if dir already exists to avoid overwriting and old project
)

if config["use_wandb"]:
    wandb.login()
    wandb.init(
        project=config["project_name"], settings=wandb.Settings(start_method="fork")
    )
    wandb.run.name = config["run_name"]
    # update the wandb args with the training configurations
    if wandb.run:
        wandb.config.update(config)

print(config)

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkojogyaase[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'project_name': 'gnm', 'run_name': 'gnm_public_2024_02_06_09_26_23_distance_two_datasets', 'use_wandb': True, 'train': True, 'batch_size': 10, 'eval_batch_size': 5, 'epochs': 50, 'gpu_ids': [0], 'num_workers': 4, 'lr': '1e-5', 'optimizer': 'adam', 'seed': 0, 'model_type': 'gnm-modified', 'obs_encoding_size': 1024, 'goal_encoding_size': 1024, 'normalize': True, 'context_type': 'temporal', 'context_size': 5, 'alpha': 0.6, 'distance': {'min_dist_cat': 0, 'max_dist_cat': 5}, 'action': {'min_dist_cat': 2, 'max_dist_cat': 5}, 'close_far_threshold': 5, 'len_traj_pred': 5, 'learn_angle': True, 'image_size': [85, 64], 'datasets': {'carla': {'data_folder': '../carla_trajectories', 'train': 'gnm_train/data/data_splits/carla_trajectories/train/', 'test': 'gnm_train/data/data_splits/carla_trajectories/test/', 'end_slack': 0, 'goals_per_obs': 2, 'negative_mining': True}, 'go_stanford': {'data_folder': '../go_stanford', 'train': 'gnm_train/data/data_splits/go_stanford/train/', 'test': 'gnm_train/dat

In [2]:

assert config["distance"]["min_dist_cat"] < config["distance"]["max_dist_cat"]
assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"]

if torch.cuda.is_available():
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if "gpu_ids" not in config:
        config["gpu_ids"] = [0]
    elif type(config["gpu_ids"]) == int:
        config["gpu_ids"] = [config["gpu_ids"]]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        [str(x) for x in config["gpu_ids"]]
    )
    print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"])
else:
    print("Using cpu")

first_gpu_id = config["gpu_ids"][0]
device = torch.device(
    f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu"
)

if "seed" in config:
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    cudnn.deterministic = True

cudnn.benchmark = True  # good if input sizes don't vary
transform = [
    transforms.ToTensor(),
    transforms.Resize(
        (config["image_size"][1], config["image_size"][0])
    ),  # torch does (h, w)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(transform)
aspect_ratio = config["image_size"][0] / config["image_size"][1]

# Load the data
train_dist_dataset = []
# train_action_dataset = []
test_dataloaders = {}


Using cuda devices: 0


In [3]:
from gnm_train.training.logger import Logger

if "context_type" not in config:
    config["context_type"] = "temporal"

for dataset_name in config["datasets"]:
    data_config = config["datasets"][dataset_name]
    if "negative_mining" not in data_config:
        data_config["negative_mining"] = True
    if "goals_per_obs" not in data_config:
        data_config["goals_per_obs"] = 1
    if "end_slack" not in data_config:
        data_config["end_slack"] = 0
    if "waypoint_spacing" not in data_config:
        data_config["waypoint_spacing"] = 1
        
    for data_split_type in ["train", "test"]:
        if data_split_type in data_config:
            for output_type in ["distance"]:
                if output_type == "pairwise":
                    dataset = PairwiseDistanceDataset(
                        data_folder=data_config["data_folder"],
                        data_split_folder=data_config[data_split_type],
                        dataset_name=dataset_name,
                        transform=transform,
                        aspect_ratio=aspect_ratio,
                        waypoint_spacing=data_config["waypoint_spacing"],
                        min_dist_cat=config["distance"]["min_dist_cat"],
                        max_dist_cat=config["distance"]["max_dist_cat"],
                        close_far_threshold=config["close_far_threshold"],
                        negative_mining=data_config["negative_mining"],
                        context_size=config["context_size"],
                        context_type=config["context_type"],
                        end_slack=data_config["end_slack"],
                    )
                else:
                    dataset = GNM_Dataset_Modified(
                        data_folder=data_config["data_folder"],
                        data_split_folder=data_config[data_split_type],
                        dataset_name=dataset_name,
                        is_action=(output_type == "action"),
                        transform=transform,
                        aspect_ratio=aspect_ratio,
                        waypoint_spacing=data_config["waypoint_spacing"],
                        min_dist_cat=config[output_type]["min_dist_cat"],
                        max_dist_cat=config[output_type]["max_dist_cat"],
                        negative_mining=data_config["negative_mining"],
                        len_traj_pred=config["len_traj_pred"],
                        learn_angle=config["learn_angle"],
                        context_size=config["context_size"],
                        context_type=config["context_type"],
                        end_slack=data_config["end_slack"],
                        goals_per_obs=data_config["goals_per_obs"],
                        normalize=config["normalize"],
                    )
                if data_split_type == "train":
                    if output_type == "distance":
                        train_dist_dataset.append(dataset)
                        print(
                            f"Loaded {len(dataset)} {dataset_name} training points"
                        )
                    elif output_type == "action":
                        train_action_dataset.append(dataset)
                else:
                    dataset_type = f"{dataset_name}_{data_split_type}"
                    if dataset_type not in test_dataloaders:
                        test_dataloaders[dataset_type] = {}
                    test_dataloaders[dataset_type][output_type] = dataset

# combine all the datasets from different robots
train_dist_dataset = ConcatDataset(train_dist_dataset)
# train_action_dataset = ConcatDataset(train_action_dataset)

train_dist_loader = DataLoader(
    train_dist_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_workers"],
    drop_last=True,
)
# train_action_loader = DataLoader(
#     train_action_dataset,
#     batch_size=config["batch_size"],
#     shuffle=True,
#     num_workers=config["num_workers"],
#     drop_last=True,
# )

if "eval_batch_size" not in config:
    config["eval_batch_size"] = config["batch_size"]

for dataset_type in test_dataloaders:
    for loader_type in test_dataloaders[dataset_type]:
        test_dataloaders[dataset_type][loader_type] = DataLoader(
            test_dataloaders[dataset_type][loader_type],
            batch_size=config["eval_batch_size"],
            shuffle=True,
            num_workers=config["num_workers"],
            drop_last=True,
        )

# Create the model
# if config["model_type"] == "gnm":
model = GNMModified(
    config["context_size"],
    config["len_traj_pred"],
    config["learn_angle"],
    config["obs_encoding_size"],
    config["goal_encoding_size"],
)
# elif config["model"] == "siamese":
#     model = SiameseModel(
#         config["context_size"],
#         config["len_traj_pred"],
#         config["learn_angle"],
#         config["obs_encoding_size"],
#         config["goal_encoding_size"],
#     )
# elif config["model"] == "stacked":
#     model = StackedModel(
#         config["context_size"],
#         config["len_traj_pred"],
#         config["learn_angle"],
#         config["obsgoal_encoding_size"],
#     )
# else:
#     raise ValueError(f"Model {config['model']} not supported")


if len(config["gpu_ids"]) > 1:
    model = nn.DataParallel(model, device_ids=config["gpu_ids"])
model = model.to(device)
lr = float(config["lr"])

config["optimizer"] = config["optimizer"].lower()
if config["optimizer"] == "adam":
    optimizer = Adam(model.parameters(), lr=lr)
elif config["optimizer"] == "adamw":
    optimizer = AdamW(model.parameters(), lr=lr)
elif config["optimizer"] == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    raise ValueError(f"Optimizer {config['optimizer']} not supported")

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',eps=1e-12)

current_epoch = 0
# if "load_run" in config:
#     load_project_folder = os.path.join("logs", config["load_run"])
#     print("Loading model from ", load_project_folder)
#     latest_path = os.path.join(load_project_folder, "latest.pth")
#     latest_checkpoint = torch.load(latest_path, map_location=device)
#     load_model(model, latest_checkpoint)
#     optimizer = get_saved_optimizer(latest_checkpoint, device)
#     current_epoch = latest_checkpoint["epoch"] + 1

torch.autograd.set_detect_anomaly(True)






Loaded 668 carla training points
Loaded 286756 go_stanford training points


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fcdc76200a0>

In [4]:

def pairwise_acc(
    model: nn.Module,
    eval_loader: DataLoader,
    device: torch.device,
    save_folder: str,
    epoch: int,
    eval_type: str,
    print_log_freq: int = 100,
    image_log_freq: int = 1000,
    num_images_log: int = 8,
    use_wandb: bool = True,
    display: bool = False,
):
    """
    Evaluate the model on the pairwise distance accuracy metric. Given 1 observation and 2 subgoals, the model should determine which goal is closer.

    Args:
        model (nn.Module): The model to evaluate.
        eval_loader (DataLoader): The dataloader for the evaluation dataset.
        device (torch.device): The device to use for evaluation.
        save_folder (str): The folder to save the evaluation results.
        epoch (int): The current epoch.
        eval_type (str): The type of evaluation. Can be "train" or "val".
        print_log_freq (int, optional): The frequency at which to print the evaluation results. Defaults to 100.
        image_log_freq (int, optional): The frequency at which to log the evaluation results. Defaults to 1000.
        num_images_log (int, optional): The number of images to log. Defaults to 32.
        use_wandb (bool, optional): Whether to use wandb for logging. Defaults to True.
        display (bool, optional): Whether to display the evaluation results. Defaults to False.
    """
    correct_list = []
    model.eval()
    num_batches = len(eval_loader)

    with torch.no_grad():
        for i, vals in enumerate(eval_loader):
            
            (
                obs_image,
                close_image,
                far_image,
                transf_obs_image,
                transf_close_image,
                transf_far_image,
                close_dist_label,
                far_dist_label,
            ) = vals
            transf_obs_image = transf_obs_image.to(device)
            transf_close_image = transf_close_image.to(device)
            transf_far_image = transf_far_image.to(device)

            close_pred, _ = model(transf_obs_image, transf_close_image)
            far_pred, _ = model(transf_obs_image, transf_far_image)

            close_pred_flat = close_pred.reshape(close_pred.shape[0])
            far_pred_flat = far_pred.reshape(far_pred.shape[0])

            close_pred_flat = to_numpy(close_pred_flat)
            far_pred_flat = to_numpy(far_pred_flat)

            correct = np.where(far_pred_flat > close_pred_flat, 1, 0)
            correct_list.append(correct)
            if i % print_log_freq == 0:
                print(f"({i}/{num_batches}) batch of points processed")

            if i % image_log_freq == 0:
                visualize_dist_pairwise_pred(
                    to_numpy(obs_image),
                    to_numpy(close_image),
                    to_numpy(far_image),
                    to_numpy(close_pred),
                    to_numpy(far_pred),
                    to_numpy(close_dist_label),
                    to_numpy(far_dist_label),
                    eval_type,
                    save_folder,
                    epoch,
                    num_images_log,
                    use_wandb,
                    display,
                )
        if len(correct_list) == 0:
            return 0
        return np.concatenate(correct_list).mean()


def get_total_loss(dist_loss, action_loss, alpha):
    """Get total loss from distance and action loss."""
    # return alpha * (1e-2 * dist_loss) + (1 - alpha) * action_loss
    return dist_loss


def load_model(model, checkpoint: dict) -> None:
    """Load model from checkpoint."""
    loaded_model = checkpoint["model"]
    try:  # for DataParallel
        state_dict = loaded_model.module.state_dict()
        model.load_state_dict(state_dict)
    except (RuntimeError, AttributeError) as e:
        state_dict = loaded_model.state_dict()
        model.load_state_dict(state_dict)


def get_saved_optimizer(
    checkpoint: dict, device: torch.device
) -> torch.optim.Optimizer:
    optimizer = checkpoint["optimizer"]
    optimizer_to(optimizer, device)
    return optimizer


def optimizer_to(optim, device):
    """Move optimizer state to device."""
    for param in optim.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)


In [5]:
def evaluate(
    eval_type: str,
    model: nn.Module,
    eval_dist_loader: DataLoader,
    # eval_action_loader: DataLoader,
    device: torch.device,
    project_folder: str,
    normalized: bool,
    epoch: int = 0,
    alpha: float = 0.5,
    learn_angle: bool = True,
    print_log_freq: int = 100,
    image_log_freq: int = 1000,
    num_images_log: int = 8,
    use_wandb: bool = True,
):
    """
    Evaluate the model on the given evaluation dataset.

    Args:
        eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.)
        model (nn.Module): model to evaluate
        eval_dist_loader (DataLoader): dataloader for distance prediction
        eval_action_loader (DataLoader): dataloader for action prediction
        device (torch.device): device to use for evaluation
        project_folder (string): path to project folder
        epoch (int): current epoch
        alpha (float): weight for action loss
        learn_angle (bool): whether to learn the angle of the action
        print_log_freq (int): frequency of printing loss
        image_log_freq (int): frequency of logging images
        num_images_log (int): number of images to log
        use_wandb (bool): whether to use wandb for logging
    """
    model.eval()
    dist_loss_logger = Logger("dist_loss", eval_type, window_size=print_log_freq)
    action_loss_logger = Logger("action_loss", eval_type, window_size=print_log_freq)
    action_waypts_cos_sim_logger = Logger(
        "action_waypts_cos_sim", eval_type, window_size=print_log_freq
    )
    multi_action_waypts_cos_sim_logger = Logger(
        "multi_action_waypts_cos_sim", eval_type, window_size=print_log_freq
    )
    total_loss_logger = Logger(
        "total_loss_logger", eval_type, window_size=print_log_freq
    )

    variables = [
        dist_loss_logger,
        action_loss_logger,
        action_waypts_cos_sim_logger,
        multi_action_waypts_cos_sim_logger,
        total_loss_logger,
    ]
    if learn_angle:
        action_orien_cos_sim_logger = Logger(
            "action_orien_cos_sim", eval_type, window_size=print_log_freq
        )
        multi_action_orien_cos_sim_logger = Logger(
            "multi_action_orien_cos_sim", eval_type, window_size=print_log_freq
        )
        variables.extend(
            [action_orien_cos_sim_logger, multi_action_orien_cos_sim_logger]
        )

    num_batches = len(eval_dist_loader)

    with torch.no_grad():
        for i, val in enumerate(eval_dist_loader):
              # for i ,val in zip(range(4),train_dist_dataset):
        #   if i < 1:
            dist_vals = val
            (
                dist_obs_image,
                dist_goal_image,
                dist_trans_obs_image,
                dist_trans_goal_image,
                dist_label,
                dist_dataset_index,
            ) = dist_vals
            # (
            #     action_obs_image,
            #     action_goal_image,
            #     action_trans_obs_image,
            #     action_trans_goal_image,
            #     action_goal_pos,
            #     action_label,
            #     action_dataset_index,
            # ) = action_vals
            dist_obs_data = dist_trans_obs_image.to(device)
            dist_goal_data = dist_trans_goal_image.to(device)
            dist_label = dist_label.to(device)

            dist_pred, _ = model(dist_obs_data, dist_goal_data)

            dist_loss = F.binary_cross_entropy(dist_pred.unsqueeze(dim=0), dist_label.unsqueeze(dim=0))
            # print("dist",dist_loss,dist_pred.unsqueeze(dim=0),dist_label.unsqueeze(dim=0))
            # print("dist",dist_pred,dist_label,dist_loss)
            assert not math.isnan(dist_loss)
            # action_obs_data = action_trans_obs_image.to(device)
            # action_goal_data = action_trans_goal_image.to(device)
            # action_label = action_label.to(device)

            # _, action_pred = model(action_obs_data, action_goal_data)
            # action_loss = F.mse_loss(action_pred, action_label)
            # action_waypts_cos_sim = F.cosine_similarity(
            #     action_pred[:, :, :2], action_label[:, :, :2], dim=-1
            # ).mean()
            # multi_action_waypts_cos_sim = F.cosine_similarity(
            #     torch.flatten(action_pred[:, :, :2], start_dim=1),
            #     torch.flatten(action_label[:, :, :2], start_dim=1),
            #     dim=-1,
            # ).mean()
            # if learn_angle:
            #     action_orien_cos_sim = F.cosine_similarity(
            #         action_pred[:, :, 2:], action_label[:, :, 2:], dim=-1
            #     ).mean()
            #     multi_action_orien_cos_sim = F.cosine_similarity(
            #         torch.flatten(action_pred[:, :, 2:], start_dim=1),
            #         torch.flatten(action_label[:, :, 2:], start_dim=1),
            #         dim=-1,
            #     ).mean()
            #     action_orien_cos_sim_logger.log_data(action_orien_cos_sim.item())
            #     multi_action_orien_cos_sim_logger.log_data(
            #         multi_action_orien_cos_sim.item()
            #     )
            total_loss =  dist_loss
            # print("loss",total_loss)
            # scheduler.step(dist_loss)


            dist_loss_logger.log_data(dist_loss.item())
            # action_loss_logger.log_data(action_loss.item())
            # action_waypts_cos_sim_logger.log_data(action_waypts_cos_sim.item())
            # multi_action_waypts_cos_sim_logger.log_data(
            #     multi_action_waypts_cos_sim.item()
            # )

            total_loss_logger.log_data(total_loss.item())

            if i % print_log_freq == 0:
                log_display = f"(epoch {epoch}) (batch {i}/{num_batches - 1}) "
                for var in variables:
                    print(log_display + var.display())
                print()

            if i % image_log_freq == 0:
                visualize_dist_pred(
                    to_numpy(dist_obs_image),
                    to_numpy(dist_goal_image),
                    to_numpy(dist_pred),
                    to_numpy(dist_label),
                    eval_type,
                    project_folder,
                    epoch,
                    num_images_log,
                    use_wandb=use_wandb,
                )
                # visualize_traj_pred(
                #     to_numpy(action_obs_image),
                #     to_numpy(action_goal_image),
                #     to_numpy(action_dataset_index),
                #     to_numpy(action_goal_pos),
                #     to_numpy(action_pred),
                #     to_numpy(action_label),
                #     eval_type,
                #     normalized,
                #     project_folder,
                #     epoch,
                #     num_images_log,
                #     use_wandb=use_wandb,
                # )
    data_log = {}
    for var in variables:
        log_display = f"(epoch {epoch}) "
        data_log[var.full_name()] = var.average()
        print(log_display + var.display())
    # print()
    if use_wandb:
        wandb.log(data_log)
    return dist_loss_logger.average(), 0

In [6]:

def train(
    model: nn.Module,
    optimizer: Adam,
    train_dist_loader: DataLoader,
    # train_action_loader: DataLoader,
    device: torch.device,
    project_folder: str,
    normalized: bool,
    epoch: int,
    alpha: float = 0.5,
    learn_angle: bool = True,
    print_log_freq: int = 100,
    image_log_freq: int = 1000,
    num_images_log: int = 8,
    use_wandb: bool = True,
):
    """
    Train the model for one epoch.

    Args:
        model: model to train
        optimizer: optimizer to use
        train_dist_loader: dataloader for distance training
        train_action_loader: dataloader for action training
        device: device to use
        project_folder: folder to save images to
        epoch: current epoch
        alpha: weight of action loss
        learn_angle: whether to learn the angle of the action
        print_log_freq: how often to print loss
        image_log_freq: how often to log images
        num_images_log: number of images to log
        use_wandb: whether to use wandb
    """
    model.train()
    dist_loss_logger = Logger("dist_loss", "train", window_size=print_log_freq)
    lr_logger = Logger("learning_rate", "train", window_size=print_log_freq)
    action_loss_logger = Logger("action_loss", "train", window_size=print_log_freq)
    action_waypts_cos_sim_logger = Logger(
        "action_waypts_cos_sim", "train", window_size=print_log_freq
    )
    multi_action_waypts_cos_sim_logger = Logger(
        "multi_action_waypts_cos_sim", "train", window_size=print_log_freq
    )
    total_loss_logger = Logger("total_loss", "train", window_size=print_log_freq)

    variables = [
        dist_loss_logger,
        # action_loss_logger,
        # action_waypts_cos_sim_logger,
        # multi_action_waypts_cos_sim_logger,
        total_loss_logger,
        lr_logger
    ]

    # if learn_angle:
    #     action_orien_cos_sim_logger = Logger(
    #         "action_orien_cos_sim", "train", window_size=print_log_freq
    #     )
    #     multi_action_orien_cos_sim_logger = Logger(
    #         "multi_action_orien_cos_sim", "train", window_size=print_log_freq
    #     )
    #     variables.extend(
    #         [action_orien_cos_sim_logger, multi_action_orien_cos_sim_logger]
    #     )

    num_batches = len(train_dist_loader)
    for i, val in enumerate(train_dist_loader):
    # for i ,val in zip(range(4),train_dist_dataset):
    #  if i < 1:
        dist_vals = val
        (
            dist_obs_image,
            dist_goal_image,
            dist_trans_obs_image,
            dist_trans_goal_image,
            dist_label,
            dist_dataset_index,
        ) = dist_vals
        dist_obs_data = dist_trans_obs_image.to(device)
        dist_goal_data = dist_trans_goal_image.to(device)
        dist_label = dist_label.to(device)

        optimizer.zero_grad()
        dist_pred, _ = model(dist_obs_data, dist_goal_data)
        # print("dist",dist_pred,dist_label)
        dist_loss = F.binary_cross_entropy(dist_pred.unsqueeze(dim=0), dist_label.unsqueeze(dim=0))
        # print("dist",dist_loss,dist_pred.unsqueeze(dim=0),dist_label.unsqueeze(dim=0))
        # print("dist",dist_pred,dist_loss)
        assert not math.isnan(dist_loss.item())
        # action_obs_data = action_trans_obs_image.to(device)
        # action_goal_data = action_trans_goal_image.to(device)
        # action_label = action_label.to(device)

        # _, action_pred = model(action_obs_data, action_goal_data)
        # action_loss = F.mse_loss(action_pred, action_label)
        # action_waypts_cos_similairity = F.cosine_similarity(
        #     action_pred[:, :, :2], action_label[:, :, :2], dim=-1
        # ).mean()
        # multi_action_waypts_cos_sim = F.cosine_similarity(
        #     torch.flatten(action_pred[:, :, :2], start_dim=1),
        #     torch.flatten(action_label[:, :, :2], start_dim=1),
        #     dim=-1,
        # ).mean()
        # if learn_angle:
        #     action_orien_cos_sim = F.cosine_similarity(
        #         action_pred[:, :, 2:], action_label[:, :, 2:], dim=-1
        #     ).mean()
        #     multi_action_orien_cos_sim = F.cosine_similarity(
        #         torch.flatten(action_pred[:, :, 2:], start_dim=1),
        #         torch.flatten(action_label[:, :, 2:], start_dim=1),
        #         dim=-1,
        #     ).mean()
        #     action_orien_cos_sim_logger.log_data(action_orien_cos_sim.item())
        #     multi_action_orien_cos_sim_logger.log_data(
        #         multi_action_orien_cos_sim.item()
        #     )
        total_loss = dist_loss
        # print("loss",total_loss)

        total_loss.backward()
        optimizer.step()


            # visualize_traj_pred(
            #     to_numpy(action_obs_image),
            #     to_numpy(action_goal_image),
            #     to_numpy(action_dataset_index),
            #     to_numpy(action_goal_pos),
            #     to_numpy(action_pred),
            #     to_numpy(action_label),
            #     "train",
            #     normalized,
            #     project_folder,
            #     epoch,
            #     num_images_log,
            #     use_wandb=use_wandb,
            # )

    dist_loss_logger.log_data(dist_loss.item())
    lr_logger.log_data(optimizer.param_groups[0]["lr"])
    # action_loss_logger.log_data(action_loss.item())
    # action_waypts_cos_sim_logger.log_data(action_waypts_cos_similairity.item())
    # multi_action_waypts_cos_sim_logger.log_data(multi_action_waypts_cos_sim.item())
    total_loss_logger.log_data(total_loss.item())

    if use_wandb:
        data_log = {}
        for var in variables:
            data_log[var.full_name()] = var.latest()
        wandb.log(data_log)

    if i % print_log_freq == 0:
        log_display = f"(epoch {epoch}) (batch {i}/{num_batches - 1}) "
        for var in variables:
            print(log_display + var.display())
        print()

    if i % image_log_freq == 0:
        visualize_dist_pred(
            to_numpy(dist_obs_image),
            to_numpy(dist_goal_image),
            to_numpy(dist_pred),
            to_numpy(dist_label),
            "train",
            project_folder,
            epoch,
            num_images_log,
            use_wandb=use_wandb,
        )
    return

In [7]:
import wandb
import os
import numpy as np
from typing import List, Optional, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from gnm_train.visualizing.action_utils import visualize_traj_pred
from gnm_train.visualizing.distance_utils import visualize_dist_pred, visualize_dist_pairwise_pred
from gnm_train.visualizing.visualize_utils import to_numpy
from gnm_train.training.logger import Logger



def train_eval_loop(
    model: nn.Module,
    optimizer: Adam,
    train_dist_loader: DataLoader,
    # train_action_loader: DataLoader,
    test_dataloaders: Dict[str, DataLoader],
    epochs: int,
    device: torch.device,
    project_folder: str,
    normalized: bool,
    print_log_freq: int = 100,
    image_log_freq: int = 1000,
    num_images_log: int = 8,
    pairwise_test_freq: int = 5,
    current_epoch: int = 0,
    alpha: float = 0.5,
    learn_angle: bool = True,
    use_wandb: bool = True,
):
    """
    Train and evaluate the model for several epochs.

    Args:
        model: model to train
        optimizer: optimizer to use
        train_dist_loader: dataloader for training distance predictions
        train_action_loader: dataloader for training action predictions
        test_dataloaders: dict of dataloaders for testing
        epochs: number of epochs to train
        device: device to train on
        project_folder: folder to save checkpoints and logs
        log_freq: frequency of logging to wandb
        image_log_freq: frequency of logging images to wandb
        num_images_log: number of images to log to wandb
        pairwise_test_freq: frequency of testing pairwise distance accuracy
        current_epoch: epoch to start training from
        alpha: tradeoff between distance and action loss
        learn_angle: whether to learn the angle or not
        use_wandb: whether to log to wandb or not
        load_best: whether to load the best model or not
    """
    assert 0 <= alpha <= 1
    latest_path = os.path.join(project_folder, f"latest.pth")

    for epoch in range(current_epoch, current_epoch + epochs):
        print(
            f"Start GNM Training Epoch {epoch}/{current_epoch + epochs - 1}"
        )
        train(
            model,
            optimizer,
            train_dist_loader,
            # train_action_loader,
            device,
            project_folder,
            normalized,
            epoch,
            alpha,
            learn_angle,
            print_log_freq,
            image_log_freq,
            num_images_log,
            use_wandb,
        )

        eval_total_losses = []
        for dataset_type in test_dataloaders:
            print(
                f"Start {dataset_type} GNM Testing Epoch {epoch}/{current_epoch + epochs - 1}"
            )
            dist_loader = test_dataloaders[dataset_type]["distance"]
            # action_loader = test_dataloaders[dataset_type]["action"]
            test_dist_loss, test_action_loss = evaluate(
                dataset_type,
                model,
                dist_loader,
                # action_loader,
                device,
                project_folder,
                normalized,
                epoch,
                alpha,
                learn_angle,
                print_log_freq,
                image_log_freq,
                num_images_log,
                use_wandb,
            )

            total_eval_loss = get_total_loss(test_dist_loss, test_action_loss, alpha)
            eval_total_losses.append(total_eval_loss)
            wandb.log({f"{dataset_type}_total_loss": total_eval_loss})
            print(f"{dataset_type}_total_loss: {total_eval_loss}")
            wandb.log({f"{dataset_type}_dist_loss": test_dist_loss})
            print(f"{dataset_type}_dist_loss: {test_dist_loss}")
            wandb.log({f"{dataset_type}_action_loss": test_action_loss})
            print(f"{dataset_type}_action_loss: {test_action_loss}")

        checkpoint = {
            "epoch": epoch,
            "model": model,
            "optimizer": optimizer,
            "avg_eval_loss": np.mean(eval_total_losses),
        }

        numbered_path = os.path.join(project_folder, f"{epoch}.pth")
        torch.save(checkpoint, latest_path)
        torch.save(checkpoint, numbered_path)  # keep track of model at every epoch

        if (epoch - current_epoch) % pairwise_test_freq == 0:
            print(f"Start Pairwise Testing Epoch {epoch}/{current_epoch + epochs - 1}")
            for dataset_type in test_dataloaders:
                if "pairwise" in test_dataloaders[dataset_type]:
                    pairwise_dist_loader = test_dataloaders[dataset_type]["pairwise"]
                    pairwise_accuracy = pairwise_acc(
                        model,
                        pairwise_dist_loader,
                        device,
                        project_folder,
                        epoch,
                        dataset_type,
                        print_log_freq,
                        image_log_freq,
                        num_images_log,
                        use_wandb=use_wandb,
                    )
                    wandb.log({f"{dataset_type}_pairwise_acc": pairwise_accuracy})
                    print(f"{dataset_type}_pairwise_acc: {pairwise_accuracy}")



In [8]:
if config["train"]:
    train_eval_loop(
        model=model,
        optimizer=optimizer,
        train_dist_loader=train_dist_loader,
        # train_action_loader=train_action_loader,
        test_dataloaders=test_dataloaders,
        epochs=config["epochs"],
        device=device,
        project_folder=config["project_folder"],
        normalized=config["normalize"],
        print_log_freq=config["print_log_freq"],
        image_log_freq=config["image_log_freq"],
        num_images_log=config["num_images_log"],
        pairwise_test_freq=config["pairwise_test_freq"],
        current_epoch=current_epoch,
        learn_angle=config["learn_angle"],
        alpha=config["alpha"],
        use_wandb=config["use_wandb"],
    )
print("FINISHED TRAINING")



Start GNM Training Epoch 0/49




Start carla_test GNM Testing Epoch 0/49
(epoch 0) (batch 0/41) dist_loss (carla_test): 1.0843 (100pt moving_avg: 1.0843) (avg: 1.0843)
(epoch 0) (batch 0/41) action_loss (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 0/41) action_waypts_cos_sim (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 0/41) multi_action_waypts_cos_sim (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 0/41) total_loss_logger (carla_test): 1.0843 (100pt moving_avg: 1.0843) (avg: 1.0843)
(epoch 0) (batch 0/41) action_orien_cos_sim (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 0/41) multi_action_orien_cos_sim (carla_test): nan (100pt moving_avg: nan) (avg: nan)

(epoch 0) dist_loss (carla_test): 0.7145 (100pt moving_avg: 0.4948) (avg: 0.4948)
(epoch 0) action_loss (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) action_waypts_cos_sim (carla_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) multi_action_waypts

  fig, ax = plt.subplots(1, len(imgs))


(epoch 0) (batch 2100/15847) dist_loss (go_stanford_test): 0.3493 (100pt moving_avg: 0.4076) (avg: 0.376)
(epoch 0) (batch 2100/15847) action_loss (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 2100/15847) action_waypts_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 2100/15847) multi_action_waypts_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 2100/15847) total_loss_logger (go_stanford_test): 0.3493 (100pt moving_avg: 0.4076) (avg: 0.376)
(epoch 0) (batch 2100/15847) action_orien_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 2100/15847) multi_action_orien_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)

(epoch 0) (batch 2200/15847) dist_loss (go_stanford_test): 0.5207 (100pt moving_avg: 0.3819) (avg: 0.3763)
(epoch 0) (batch 2200/15847) action_loss (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 2200/

  plt.figure()


(epoch 0) (batch 3100/15847) dist_loss (go_stanford_test): 0.7479 (100pt moving_avg: 0.3717) (avg: 0.378)
(epoch 0) (batch 3100/15847) action_loss (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 3100/15847) action_waypts_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 3100/15847) multi_action_waypts_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 3100/15847) total_loss_logger (go_stanford_test): 0.7479 (100pt moving_avg: 0.3717) (avg: 0.378)
(epoch 0) (batch 3100/15847) action_orien_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 3100/15847) multi_action_orien_cos_sim (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)

(epoch 0) (batch 3200/15847) dist_loss (go_stanford_test): 0.1352 (100pt moving_avg: 0.3936) (avg: 0.3785)
(epoch 0) (batch 3200/15847) action_loss (go_stanford_test): nan (100pt moving_avg: nan) (avg: nan)
(epoch 0) (batch 3200/