# Create environment

In [1]:
from vlnce_baselines.common.env_utils import construct_envs
from habitat import Config
from vlnce_baselines.config.default import get_config
from habitat_baselines.common.environments import get_env_class

#Import config and perform some manipulation
exp_config ='vlnce_baselines/config/paper_configs/transformer_semantics.yaml'
config = get_config(exp_config, None)
split = config.TASK_CONFIG.DATASET.SPLIT
config.defrost()
config.TASK_CONFIG.TASK.NDTW.SPLIT = split
config.TASK_CONFIG.TASK.SDTW.SPLIT = split

# if doing teacher forcing, don't switch the scene until it is complete
if config.DAGGER.P == 1.0:
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = (
        -1
    )
config.freeze()
envs = construct_envs(config, get_env_class(config.ENV_NAME))
obs=envs.reset()

2020-08-11 01:00:47,596 Initializing dataset VLN-CE-v1
2020-08-11 01:00:48,422 [construct_envs] Using GPU ID 0
2020-08-11 01:00:48,423 [construct_envs] Using GPU ID 0
2020-08-11 01:00:48,424 [construct_envs] Using GPU ID 0
2020-08-11 01:00:48,424 [construct_envs] Using GPU ID 0
2020-08-11 01:00:48,425 [construct_envs] Using GPU ID 0
2020-08-11 01:00:48,426 [construct_envs] Using GPU ID 0


# Reset and get observations

## 1. Get Observations from env reset

In [None]:
import torch
from vlnce_baselines.common.utils import transform_obs
from habitat_baselines.common.utils import batch_obs

device = (
    torch.device("cuda", config.TORCH_GPU_ID)
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(envs.num_envs)
print(envs.current_episodes()[0].episode_id)
observations = envs.reset()
# observations[0]['instruction_text'] = 'Go to kitchen'

for k,v in observations[0].items():
    print(k)
# print(observations[0])

observations = transform_obs(
    observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID, is_bert=True
)

# print(observations)
batch = batch_obs(observations, device=device)

# print(batch['instruction_batch'])


# print (batch["instruction_batch"].shape)
# lengths = (batch['instruction_batch'] != 0.0).long().sum(dim=1)
# print(lengths)

# print(batch['instruction'])
instruction = batch["instruction"].long()
from transformers import BertTokenizer, BertModel

embedding_layer = BertModel.from_pretrained('bert-base-uncased').to(device)
embedding_layer.eval()

with torch.no_grad():
    embedded = embedding_layer(instruction)
    embedded = embedded[0]
     
print(embedded.shape)

lengths = (batch['instruction'] != 0.0).long().sum(dim=1)

# def get_trasnformer_mask(instr_embedding, instr_len, device):
#     mask = torch.ones((instr_embedding.shape[0], instr_embedding.shape[1]), dtype=torch.bool).to(device)
#     attention_mask = torch.ones((instr_embedding.shape[0], instr_embedding.shape[1], instr_embedding.shape[1]), dtype=torch.bool).to(device)
#     for i, _len in enumerate(instr_len):
#         mask[i, :_len] = 0
#         attention_mask[i, :_len, :_len] = 0
#     pe_mask = mask.unsqueeze(dim=-1)
#     value = (instr_embedding, pe_mask)
#     return value, attention_mask.unsqueeze(1), mask.unsqueeze(dim=1).unsqueeze(dim=1)


# value, attn_mask, mask = get_trasnformer_mask(embedded, lengths, device)

# print("attn_mask: ",attn_mask.shape)
# print("mask: ",mask.shape)

# print(value[1].shape)

def get_instruction_mask(instr_embedding, instr_len, device):
    mask = torch.ones((instr_embedding.shape[0], instr_embedding.shape[1]), dtype=torch.bool).to(device)
    for i, _len in enumerate(instr_len):
        mask[i, :_len] = 0
    return mask.unsqueeze(dim=1).unsqueeze(dim=1)
enc_mask = get_instruction_mask(embedded, lengths, device)


In [None]:
print(enc_mask.shape)

## 2. Get observations from file

In [None]:
print(batch['progress'].shape)

print(batch['instruction'].shape)

In [2]:
#Import config and perform some manipulation
from vlnce_baselines.config.default import get_config
exp_config ='vlnce_baselines/config/paper_configs/transformer_semantics.yaml'
config = get_config(exp_config, None)
split = config.TASK_CONFIG.DATASET.SPLIT
config.defrost()
config.TASK_CONFIG.TASK.NDTW.SPLIT = split
config.TASK_CONFIG.TASK.SDTW.SPLIT = split

# if doing teacher forcing, don't switch the scene until it is complete
if config.DAGGER.P == 1.0:
    config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = (
        -1
    )
config.freeze()

import torch
import copy
import gc
import json
import os
import random
import time
import warnings
from collections import defaultdict
from typing import Dict

import lmdb
import msgpack_numpy
import numpy as np
import torch
import torch.nn.functional as F
import tqdm

class ObservationsDict(dict):
    def pin_memory(self):
        for k, v in self.items():
            self[k] = v.pin_memory()

        return self

def collate_fn(batch):
    """Each sample in batch: (
            obs,
            prev_actions,
            oracle_actions,
            inflec_weight,
        )
    """

    def _pad_helper(t, max_len, fill_val=0):
        pad_amount = max_len - t.size(0)
        if pad_amount == 0:
            return t
        
        pad = torch.full_like(t[0:1], fill_val).expand(pad_amount, *t.size()[1:])
        return torch.cat([t, pad], dim=0)

    transposed = list(zip(*batch))

    observations_batch = list(transposed[0])
    prev_actions_batch = list(transposed[1])
    corrected_actions_batch = list(transposed[2])
    N = len(corrected_actions_batch)
    weights_batch = list(transposed[3])
    B = len(prev_actions_batch)
    new_observations_batch = defaultdict(list)
    for sensor in observations_batch[0]:
        if sensor == 'instruction':
            for bid in range(N):
                new_observations_batch[sensor].append(observations_batch[bid][sensor])
        else: 
            for bid in range(B):
                new_observations_batch[sensor].append(observations_batch[bid][sensor])

    observations_batch = new_observations_batch

    max_traj_len = max(ele.size(0) for ele in prev_actions_batch)
#     obs_lengths=[]
    for bid in range(B):
        for sensor in observations_batch:
            if sensor == 'instruction':
                continue
            observations_batch[sensor][bid] = _pad_helper(
                observations_batch[sensor][bid], max_traj_len, fill_val=0.0
            )
#         obs_lengths.append(prev_actions_batch[bid].shape[0])
        prev_actions_batch[bid] = _pad_helper(prev_actions_batch[bid], max_traj_len)
        corrected_actions_batch[bid] = _pad_helper(
            corrected_actions_batch[bid], max_traj_len, fill_val=-1.0
        )
        weights_batch[bid] = _pad_helper(weights_batch[bid], max_traj_len)

    for sensor in observations_batch:
        observations_batch[sensor] = torch.stack(observations_batch[sensor], dim=1)
        observations_batch[sensor] = observations_batch[sensor].transpose(1,0)
        observations_batch[sensor] = observations_batch[sensor].contiguous().view(
            -1, *observations_batch[sensor].size()[2:]
        )

    prev_actions_batch = torch.stack(prev_actions_batch, dim=1)
    corrected_actions_batch = torch.stack(corrected_actions_batch, dim=1)
    weights_batch = torch.stack(weights_batch, dim=1)
    not_done_masks = torch.ones_like(corrected_actions_batch, dtype=torch.float)
    not_done_masks[0] = 0
    
    prev_actions_batch = prev_actions_batch.transpose(1,0)
    not_done_masks = not_done_masks.transpose(1,0)
    corrected_actions_batch = corrected_actions_batch.transpose(1,0)
    weights_batch = weights_batch.transpose(1,0)
    
    observations_batch = ObservationsDict(observations_batch)

    return (
        observations_batch,
        prev_actions_batch.contiguous().view(-1, 1),
        not_done_masks.contiguous().view(-1, 1),
        corrected_actions_batch,
        weights_batch,
    )

def _block_shuffle(lst, block_size):
    blocks = [lst[i : i + block_size] for i in range(0, len(lst), block_size)]
    random.shuffle(blocks)

    return [ele for block in blocks for ele in block]

class IWTrajectoryDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        lmdb_features_dir,
        use_iw,
        inflection_weight_coef=1.0,
        lmdb_map_size=1e9,
        batch_size=1,
    ):
        super().__init__()
        self.lmdb_features_dir = lmdb_features_dir
        self.lmdb_map_size = lmdb_map_size
        self.preload_size = batch_size * 100
        self._preload = []
        self.batch_size = batch_size

        if use_iw:
            self.inflec_weights = torch.tensor([1.0, inflection_weight_coef])
        else:
            self.inflec_weights = torch.tensor([1.0, 1.0])

        with lmdb.open(
            self.lmdb_features_dir,
            map_size=int(self.lmdb_map_size),
            readonly=True,
            lock=False,
        ) as lmdb_env:
            self.length = lmdb_env.stat()["entries"]

    def _load_next(self):
        if len(self._preload) == 0:
            if len(self.load_ordering) == 0:
                raise StopIteration

            new_preload = []
            lengths = []
            with lmdb.open(
                self.lmdb_features_dir,
                map_size=int(self.lmdb_map_size),
                readonly=True,
                lock=False,
            ) as lmdb_env, lmdb_env.begin(buffers=True) as txn:
                for _ in range(self.preload_size):
                    if len(self.load_ordering) == 0:
                        break

                    new_preload.append(
                        msgpack_numpy.unpackb(
                            txn.get(str(self.load_ordering.pop()).encode()), raw=False
                        )
                    )

                    lengths.append(len(new_preload[-1][0]))

            sort_priority = list(range(len(lengths)))
            random.shuffle(sort_priority)

            sorted_ordering = list(range(len(lengths)))
            sorted_ordering.sort(key=lambda k: (lengths[k], sort_priority[k]))

            for idx in _block_shuffle(sorted_ordering, self.batch_size):
                self._preload.append(new_preload[idx])

        return self._preload.pop()

    def __next__(self):
        obs, prev_actions, oracle_actions= self._load_next()
        
        instruction_batch = obs['instruction'][0]
        instruction_batch = np.expand_dims(instruction_batch, axis=0)
        obs['instruction'] = instruction_batch
        for k, v in obs.items():
            obs[k] = torch.from_numpy(v)

        prev_actions = torch.from_numpy(prev_actions)
        
        
        oracle_actions = torch.from_numpy(oracle_actions)

        inflections = torch.cat(
            [
                torch.tensor([1], dtype=torch.long),
                (oracle_actions[1:] != oracle_actions[:-1]).long(),
            ]
        )
        return (obs, prev_actions, oracle_actions, self.inflec_weights[inflections])

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            start = 0
            end = self.length
        else:
            per_worker = int(np.ceil(self.length / worker_info.num_workers))

            start = per_worker * worker_info.id
            end = min(start + per_worker, self.length)

        # Reverse so we can use .pop()
        self.load_ordering = list(
            reversed(_block_shuffle(list(range(start, end)), self.preload_size))
        )

        return self

In [3]:
# lmdb_features_dir = 'data/trajectories_dirs/seq2seq/trajectories.lmdb'
lmdb_features_dir = '/data/zirshad/VLNCE-data/data/trajectory_dir/transformer_semantic/trajectories.lmdb'

USE_IW= False

dataset = IWTrajectoryDataset(
    lmdb_features_dir,
    USE_IW,
    inflection_weight_coef=1.0,
    lmdb_map_size=1e9,
    batch_size=5,
)
diter = torch.utils.data.DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=False,
    drop_last=True,  # drop last batch if smaller
    num_workers=3,
)

device = (
    torch.device("cuda", config.TORCH_GPU_ID)
    if torch.cuda.is_available()
    else torch.device("cpu")
)

print(dataset.length)

# iterloader = iter(diter)
observations_batch,prev_actions_batch,not_done_masks,corrected_actions_batch,weights_batch = next(iter(diter))
# try:
#     observations_batch,prev_actions_batch,not_done_masks,corrected_actions_batch,weights_batch = next(iterloader)
# except StopIteration:
#     batch_iter = iter(diter)
#     observations_batch,prev_actions_batch,not_done_masks,corrected_actions_batch,weights_batch = next(batch_iter)
    
# print(observations_batch['depth_features'].shape)
# print(prev_actions_batch.shape)
observations_batch = {
    k: v.to(device=device, non_blocking=True)
    for k, v in observations_batch.items()
}

for k, v in observations_batch.items():
    print(k)



10819




instruction
progress
heading
ego_sem_map
rgb_features
depth_features


In [6]:
print(observations_batch['ego_sem_map'].shape)

T,N = corrected_actions_batch.size()

esm= observations_batch['ego_sem_map']
sem_map          = esm.view(T, N, *esm.size()[1:])

print(sem_map.shape)

print(sem_map[:,10].shape)

torch.Size([490, 40, 40])
torch.Size([5, 98, 40, 40])
torch.Size([5, 40, 40])


In [10]:
pab= prev_actions_batch.view(T,N)


print(not_done_masks.shape)
ma = not_done_masks.view(T,N)

print(ma.shape)
# print(prev_actions_batch.view(T,N))

prev_actions_single = pab[:,3]


print(prev_actions_single.shape)

print(ma[:,1].shape)
print(prev_actions_single)


# print(corrected_actions_batch)

corrected_actions = corrected_actions_batch.contiguous().view(T*N)

# print(corrected_actions)

prev_action_embedding = nn.Embedding(num_actions+1, 32, padding_idx= 1)

torch.Size([490, 1])
torch.Size([5, 98])
torch.Size([5])
torch.Size([5])
tensor([2, 2, 2, 3, 3])


In [8]:
prev_actions_single = pab[:,1]

print(prev_actions_single)
masks_single = ma[:,:80+1]

# print(prev_actions_single)
# print(masks_single)

pad_mask = (((prev_actions_single+1)*masks_single)!= 1).unsqueeze(1).unsqueeze(2)

print(pad_mask)

seq_len = prev_actions_single.shape[1]

# prev_action_pad = ((prev_actions_batch+1)*not_done_masks).view(T,N)

# pad= prev_action_pad[:,0:80+1]

# prev_action_pad_mask = (pad != 1).unsqueeze(1).unsqueeze(2)

print(prev_action_pad_mask)

# print(prev_action_pad_mask)
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1).bool()
# print(torch.triu(torch.ones((seq_len, seq_len)), difagonal=1))
mask_self_attention = mask_self_attention  # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention.gt(0)  # (b_s, 1, seq_len, seq_len)

mask = prev_action_pad_mask & ~mask_self_attention
mask = ~mask
print("----------------------")
print(mask)

tensor([2, 2, 2, 3, 3])


RuntimeError: The size of tensor a (5) must match the size of tensor b (81) at non-singleton dimension 1

In [None]:
print(observations_batch['progress'].shape)

print(observations_batch['instruction'].shape)

print(prev_actions_batch.shape)

print(corrected_actions_batch.shape)

T,N = corrected_actions_batch.size()
# df = observations_batch['depth_features'].view(T,N,-1,-1)

In [None]:
batch=observations_batch
instruction = batch["instruction"].long()
from transformers import BertTokenizer, BertModel

embedding_layer = BertModel.from_pretrained('bert-base-uncased').to(device)
embedding_layer.eval()

with torch.no_grad():
    embedded = embedding_layer(instruction)
    embedded = embedded[0]
     
print(embedded.shape)

lengths = (batch['instruction'] != 0.0).long().sum(dim=1)

def get_trasnformer_mask(instr_embedding, instr_len, device):
    mask = torch.ones((instr_embedding.shape[0], instr_embedding.shape[1]), dtype=torch.bool).to(device)
    attention_mask = torch.ones((instr_embedding.shape[0], instr_embedding.shape[1], instr_embedding.shape[1]), dtype=torch.bool).to(device)
    for i, _len in enumerate(instr_len):
        mask[i, :_len] = 0
        attention_mask[i, :_len, :_len] = 0
    pe_mask = mask.unsqueeze(dim=-1)
    value = (instr_embedding, pe_mask)
    return value, attention_mask.unsqueeze(1), mask.unsqueeze(dim=1).unsqueeze(dim=1)


value, attn_mask, mask = get_trasnformer_mask(embedded, lengths, device)

print("attn_mask: ",attn_mask.shape)
print("mask: ",mask.shape)

print(value[1].shape)


In [None]:
import torch.nn as nn
# print(prev_actions_batch)

# print(prev_actions_batch.float() + 1)

# print(not_done_masks)

# print(prev_actions_batch)


num_embeddings = envs.action_spaces[0].n

print(num_embeddings)
embedding = nn.Embedding(num_actions+1, 32, padding_idx= 1)

prev_actions_embedding = embedding(((prev_actions_batch.float()+1) * not_done_masks).long())

print(prev_actions_embedding.shape)
T,N = corrected_actions_batch.size()
print(prev_actions_embedding[74])
print(T,N)

### Test Language encoder

In [None]:
from vlnce_baselines.models.transformer.transformer import TransformerLanguageEncoder
d_model = 512
dropout =0.1
h=8
d_att   = int(d_model / h)

model_config = config.MODEL
encoder = TransformerLanguageEncoder(model_config.TRANSFORMER_INSTRUCTION_ENCODER).to(device=device)

w_t = encoder(value, attention_mask=attn_mask, attention_weights=None, device=device)

print(w_t.shape)

In [None]:
model_config = config.MODEL
ins_fc =  torch.nn.Linear(model_config.TRANSFORMER_INSTRUCTION_ENCODER.d_in, model_config.TRANSFORMER_INSTRUCTION_ENCODER.d_model).to(device)
w_t = ins_fc(embedded)

print(w_t.shape)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from vlnce_baselines.common.utils import sinusoid_encoding_table

d_in = 768
dropout =0.1
fc = nn.Linear(d_in, d_model, bias=True).to(device)
data, mask = value
dropout = nn.Dropout(p=dropout)
layer_norm = nn.LayerNorm(d_model)

out = F.relu(fc(data))
out = dropout(out)
out = layer_norm(out)
pe = sinusoid_encoding_table(out.shape[1], out.shape[2])
pe = pe.expand(out.shape[0], pe.shape[0], pe.shape[1]).to(device)

print(pe.shape)
out = out + pe.masked_fill(mask, 0)

print(out.shape)

In [None]:
from vlnce_baselines.common.utils import sinusoid_encoding_table

pe = sinusoid_encoding_table(out.shape[1], out.shape[2])

print(pe.shape)

pos = torch.arange(out.shape[1], dtype=torch.float32)

dim = torch.arange(out.shape[2] // 2, dtype=torch.float32).view(1, -1)



pos = pos.view(-1, 1)
print(pos.shape)
print(dim.shape)

sin = torch.sin(pos / 10000 ** (2 * dim / out.shape[2]))

print(sin.shape)

In [None]:
import torchvision
model1 = torchvision.models.resnet18(pretrained=True)
model1

### DETR Backbone and Positional Encodings for images

In [None]:
print(batch['rgb'].shape)
from typing import Optional, List
from torch import Tensor
import torch.distributed as dist

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
    
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, x):
        xs = self.body(x)
#         out: Dict[str, NestedTensor] = {}
#         for name, x in xs.items():
#             m = tensor_list.mask
#             assert m is not None
#             mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
#             out[name] = NestedTensor(x, mask)
        return xs

class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
        

lr_backbone = 1e-5

backbone ='resnet50'
train_backbone = lr_backbone > 0

return_interm_layers = False
dilation = False

backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation).to(device)

# if isinstance(samples, (list, torch.Tensor)):
#     samples = nested_tensor_from_tensor_list(samples)

        
samples = batch['rgb'].permute(0,3,1,2)

features = backbone(samples)

print(features['0'].shape)


num_channels = 2048

hidden_dim = 512//2
input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1).to(device)

out = input_proj(features['0'])

print(out.shape)


src = out.flatten(2).permute(2, 0, 1)
print(src.shape)


In [None]:
hidden_dim =256
N_steps = hidden_dim // 2

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor):
        x = tensor
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        
        print("i",i.shape)
        j = torch.arange(h, device=x.device)
        print("j", j.shape)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        
        print(x_emb.shape)
        print(y_emb.shape)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

position_embedding = PositionEmbeddingLearned(N_steps).to(device)

pos = position_embedding(out)

print(pos.shape)


pos_embed = pos.flatten(2).permute(0, 2, 1)

print(pos_embed.shape)

# o = pos_embed +src

In [None]:
embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])

print(input.shape)
embed= embedding(input)
print(embed.shape)

In [None]:
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
observation_space=envs.observation_spaces[0]
action_space=envs.action_spaces[0]

model_config = config.MODEL

class TorchVisionResNet50(nn.Module):
    r"""
    Takes in observations and produces an embedding of the rgb component.

    Args:
        observation_space: The observation_space of the agent
        output_size: The size of the embedding vector
        device: torch.device
    """

    def __init__(
        self, observation_space, output_size, device, spatial_output: bool = False
    ):
        super().__init__()
        self.device = device
        self.resnet_layer_size = 2048
        linear_layer_input_size = 0
        if "rgb" in observation_space.spaces:
            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
            obs_size_0 = observation_space.spaces["rgb"].shape[0]
            obs_size_1 = observation_space.spaces["rgb"].shape[1]
            if obs_size_0 != 224 or obs_size_1 != 224:
                logger.warn(
                    f"WARNING: TorchVisionResNet50: observation size {obs_size} is not conformant to expected ResNet input size [3x224x224]"
                )
            linear_layer_input_size += self.resnet_layer_size
        else:
            self._n_input_rgb = 0

        if self.is_blind:
            self.cnn = nn.Sequential()
            return

        self.cnn = models.resnet50(pretrained=True)

        # disable gradients for resnet, params frozen
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.cnn.eval()

        self.spatial_output = spatial_output

        if not self.spatial_output:
            self.output_shape = (output_size,)
            self.fc = nn.Linear(linear_layer_input_size, output_size).cuda(2)
            self.activation = nn.ReLU()
        else:

            class SpatialAvgPool(nn.Module):
                def forward(self, x):
                    x = F.adaptive_avg_pool2d(x, (4, 4))

                    return x

            self.cnn.avgpool = SpatialAvgPool()
            self.cnn.fc = nn.Sequential()

            self.spatial_embeddings = nn.Embedding(4 * 4, 64)

            self.output_shape = (
                self.resnet_layer_size + self.spatial_embeddings.embedding_dim,
                4,
                4,
            )

        self.layer_extract = self.cnn._modules.get("layer4")

    @property
    def is_blind(self):
        return self._n_input_rgb == 0

    def forward(self, observations):
        r"""Sends RGB observation through the TorchVision ResNet50 pre-trained
        on ImageNet. Sends through fully connected layer, activates, and
        returns final embedding.
        """

        def resnet_forward(observation):
            resnet_output = torch.zeros(1, dtype=torch.float32, device=self.device)

            def hook(m, i, o):
                resnet_output.set_(o)

            # output: [BATCH x RESNET_DIM]
            h = self.layer_extract.register_forward_hook(hook)
            self.cnn(observation)
            h.remove()
            return resnet_output

        if "rgb_features" in observations:
            resnet_output = observations["rgb_features"]
        else:
            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT x WIDTH]
            rgb_observations = observations["rgb"].permute(0, 3, 1, 2)
            rgb_observations = rgb_observations / 255.0  # normalize RGB
            resnet_output = resnet_forward(rgb_observations.contiguous())

        if self.spatial_output:
            b, c, h, w = resnet_output.size()

            spatial_features = (
                self.spatial_embeddings(
                    torch.arange(
                        0,
                        self.spatial_embeddings.num_embeddings,
                        device=resnet_output.device,
                        dtype=torch.long,
                    )
                )
                .view(1, -1, h, w)
                .expand(b, self.spatial_embeddings.embedding_dim, h, w)
            )

            return torch.cat([resnet_output, spatial_features], dim=1)
        else:
            return resnet_output
        
        
rgb_encoder = TorchVisionResNet50(
                observation_space, model_config.RGB_ENCODER.output_size, device, spatial_output=False,
            )

rgb_encoder = rgb_encoder.to(device)
rgb_embedding = rgb_encoder(batch)

num_channels = 2048

hidden_dim = 512

input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1).to(device)

print(rgb_embedding.shape)

out = input_proj(rgb_embedding)

out = out.flatten(2).permute(0, 2, 1)


print(out.shape)

In [None]:
hidden_dim =512
N_steps = hidden_dim // 2

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor):
        x = tensor
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        print("i: ",i.shape)
        print("j: ",j.shape)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        
        print("x_embed",x_emb.shape)
        print("y_embed",x_emb.shape)
        
        print("x_repeat: ", x_emb.unsqueeze(0).repeat(h, 1, 1).shape)
        print("y_repeat: ", y_emb.unsqueeze(1).repeat(1, w, 1).shape)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

position_embedding = PositionEmbeddingLearned(N_steps).to(device)

pos = position_embedding(rgb_embedding)

print(pos.shape)


pos_embed = pos.flatten(2).permute(0, 2, 1)

print(pos_embed.shape)

# combined = pos_embed +out

# print(combined.shape)

In [None]:
from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder
from gym import spaces
from habitat_baselines.rl.ddppo.policy import resnet
from habitat_baselines.common.utils import Flatten
import numpy as np
class VlnResnetDepthEncoder(nn.Module):
    def __init__(
        self,
        observation_space,
        output_size=128,
        checkpoint="NONE",
        backbone="resnet50",
        resnet_baseplanes=32,
        normalize_visual_inputs=False,
        trainable=False,
        spatial_output: bool = False,
    ):
        super().__init__()
        self.visual_encoder = ResNetEncoder(
            spaces.Dict({"depth": observation_space.spaces["depth"]}),
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=normalize_visual_inputs,
        )

        for param in self.visual_encoder.parameters():
            param.requires_grad_(trainable)

        if checkpoint != "NONE":
            ddppo_weights = torch.load(checkpoint)

            weights_dict = {}
            for k, v in ddppo_weights["state_dict"].items():
                split_layer_name = k.split(".")[2:]
                if split_layer_name[0] != "visual_encoder":
                    continue

                layer_name = ".".join(split_layer_name[1:])
                weights_dict[layer_name] = v

            del ddppo_weights
            self.visual_encoder.load_state_dict(weights_dict, strict=True)

        self.spatial_output = spatial_output

        if not self.spatial_output:
            self.output_shape = (output_size,)
            self.visual_fc = nn.Sequential(
                Flatten(),
                nn.Linear(np.prod(self.visual_encoder.output_shape), output_size),
                nn.ReLU(True),
            ).cuda(2)
        else:
            self.spatial_embeddings = nn.Embedding(
                self.visual_encoder.output_shape[1]
                * self.visual_encoder.output_shape[2],
                64,
            )

            self.output_shape = list(self.visual_encoder.output_shape)
            self.output_shape[0] += self.spatial_embeddings.embedding_dim
            self.output_shape = tuple(self.output_shape)

    def forward(self, observations):
        """
        Args:
            observations: [BATCH, HEIGHT, WIDTH, CHANNEL]
        Returns:
            [BATCH, OUTPUT_SIZE]
        """
        if "depth_features" in observations:
            x = observations["depth_features"]
        else:
            x = self.visual_encoder(observations)

        if self.spatial_output:
            b, c, h, w = x.size()

            spatial_features = (
                self.spatial_embeddings(
                    torch.arange(
                        0,
                        self.spatial_embeddings.num_embeddings,
                        device=x.device,
                        dtype=torch.long,
                    )
                )
                .view(1, -1, h, w)
                .expand(b, self.spatial_embeddings.embedding_dim, h, w)
            )

            return torch.cat([x, spatial_features], dim=1)
        else:
            return x
        
depth_encoder = VlnResnetDepthEncoder(
                observation_space,
                output_size=model_config.DEPTH_ENCODER.output_size,
                checkpoint=model_config.DEPTH_ENCODER.ddppo_checkpoint,
                backbone=model_config.DEPTH_ENCODER.backbone,
                spatial_output=False,
            ).to(device)
        
depth_embedding = depth_encoder(batch)

print(depth_embedding.shape)

In [None]:
from torchvision.models import resnet50

model_1 = resnet50(pretrained=True)

print(model_1)

## Image Encoder

### RGB Depth Resnet Feature extractor

In [None]:
observations_batch = batch
model_config = config.MODEL

from vlnce_baselines.models.encoders.resnet_encoders import (
    TorchVisionResNet50,
    VlnResnetDepthEncoder,
)
observation_space=envs.observation_spaces[0]
action_space=envs.action_spaces[0]
depth_encoder = VlnResnetDepthEncoder(
                observation_space,
                output_size=model_config.DEPTH_ENCODER.output_size,
                checkpoint=model_config.DEPTH_ENCODER.ddppo_checkpoint,
                backbone=model_config.DEPTH_ENCODER.backbone,
                resnet_output=True,
            )
rgb_encoder = TorchVisionResNet50(
                observation_space, model_config.RGB_ENCODER.output_size,model_config.RGB_ENCODER.resnet_output_size, device, resnet_output=True,
            )

rgb_encoder = rgb_encoder.to(device)
depth_encoder = depth_encoder.to(device)

depth_embedding = depth_encoder(observations_batch)
print(depth_embedding.shape)
rgb_embedding = rgb_encoder(observations_batch)
print(rgb_embedding.shape)
# x = torch.cat([depth_embedding, rgb_embedding], dim=1)

In [None]:
print(observations_batch['rgb_features'].shape)
print(observations_batch['depth_features'].shape)

In [None]:
import torch.nn.functional as F
rgb_d = torch.cat((rgb_embedding, depth_embedding), dim=1)

print(rgb_d.shape)

fc = torch.nn.Linear(rgb_embedding.shape[1]+depth_embedding.shape[1], 512).to(device)

rgb_d = F.relu(fc(rgb_d.permute(0,2,3,1)))

rgb_d = rgb_d.permute(0,3,1,2)


print(rgb_d.shape)

# rgb_embedding = rgb_embedding.view(batch_size, max_len, -1, rgb_out_dim, rgb_out_dim)

In [None]:

print(rgb_embedding.shape)

print(depth_embedding.shape)

pooler = torch.nn.AdaptiveAvgPool2d(4).to(device)

rgb_up = pooler(rgb_embedding)

print(rgb_up.shape)

# rgb_embedding = rgb_embedding[:,1]

### 2D POS Encoding for Image

In [None]:
import torch.nn as nn
print(rgb_embedding.shape)

from vlnce_baselines.models.transformer.transformer import PositionEmbedding2DLearned

N_steps = model_config.IMAGE_CROSS_MODAL_ENCODER.d_model // 2
position_embedding_2d = PositionEmbedding2DLearned(N_steps).to(device)
rgbd_pos_embed =  position_embedding_2d(rgb_d)

print(rgbd_pos_embed.shape)

rgbd_out = rgb_d.flatten(2).permute(0, 2, 1)
pos_embed_out = rgbd_pos_embed.flatten(2).permute(0, 2, 1)

print(rgbd_out.shape)
print(pos_embed_out.shape)
# num_channels = 2048

# hidden_dim = 512

# input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1).to(device)

# rgb = input_proj(rgb_embedding)

### Image Encoder with Self and Cross Attention

In [None]:
import torch.nn as nn
from vlnce_baselines.models.transformer.transformer import ImageCrossModalEncoder, ImageEncoder_with_PosEncodings, SemMapEncoder_with_PosEncodings

image_encoder_trans = ImageEncoder_with_PosEncodings(model_config.IMAGE_CROSS_MODAL_ENCODER).to(device=device)

i_t = image_encoder_trans(rgbd_out, w_t, None, enc_mask, pos_embed_out)

print(i_t.shape)

list =[]

visual_pooler = nn.Sequential(
    nn.AdaptiveAvgPool1d((1)),
    nn.Flatten()
).to(device)


# i_t = visual_pooler(i_t.permute(0,2,1))

# print(i_t.shape)


In [None]:


print(batch['ego_sem_map'].shape)
print(w_t.shape)

### Semantic Map Self and Cross Attention

In [None]:
observation_space=envs.observation_spaces[0]
action_space=envs.action_spaces[0]
import torch.nn.functional as F
from vlnce_baselines.models.transformer.transformer import SemMapEncoder_with_PosEncodings
model_config.defrost()
model_config.SEM_MAP_TRANSFORMER.map_size = int(observation_space.spaces["ego_sem_map"].high.max() + 1)
model_config.SEM_MAP_TRANSFORMER.max_position_embeddings=observation_space.spaces["ego_sem_map"].shape[0] // 2
model_config.freeze()

sem_map_encoder = SemMapEncoder_with_PosEncodings(model_config.SEM_MAP_TRANSFORMER).to(device=device)
# sem = observations_batch['ego_sem_map']

print()
sem = batch['ego_sem_map'].unsqueeze(1)
sem = F.interpolate(sem, size=(20, 20), mode='nearest')

print(sem.shape)
s_t = sem_map_encoder(sem.squeeze(0), w_t, None, enc_mask)
print(s_t.shape)

sem_pooler = nn.Sequential(
    nn.AdaptiveAvgPool1d((16))
).to(device)


s_t = sem_pooler(s_t.permute(0,2,1)).permute(0,2,1)

# print(s_t.shape)

rgb_ds = torch.cat((i_t, s_t), dim=-1)

print(rgb_ds.shape)

In [None]:
pooler = nn.AdaptiveAvgPool1d(16).to(device)
sem_out = pooler(s_t.permute(0,2,1)).permute(0,2,1)

print(sem_out.shape)
rgb_ds = torch.cat((i_t, sem_out), dim=-1)

print(rgb_ds.shape)

fc = torch.nn.Linear(i_t.shape[2]+sem_out.shape[2], 512).to(device)

rgb_ds = F.relu(fc(rgb_ds))


print(rgb_ds.shape)


In [None]:
from PIL import Image
from habitat_sim.utils.common import d3_40_colors_rgb
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

batch = observations_batch

def semantic_to_image(semantic_obs, num_colors):
    from habitat_sim.utils.common import d3_40_colors_rgb
    semantic_img = Image.new("P", (semantic_obs.shape[1], semantic_obs.shape[0]))
    semantic_img.putpalette(d3_40_colors_rgb[:num_colors].flatten())
    semantic_img.putdata((semantic_obs.flatten() % num_colors).astype(np.uint8))
    semantic_img = semantic_img.convert("RGB")
    return np.array(semantic_img)


# def downsampling(x, size=None, scale_factor=None, mode='bilinear'):
#     # define size if user has specified scale_factor
#     if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3)))
#     # create coordinates
#     h = torch.arange(0,size[0]) / (size[0]-1) * 2 - 1
#     w = torch.arange(0,size[1]) / (size[1]-1) * 2 - 1
#     # create grid
#     print(h.shape)
#     print(w.shape)
#     grid = torch.zeros(size[0],size[1],2)
#     grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1)
#     grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1)
#     # expand to match batch size
#     grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1)
#     if x.is_cuda: grid = grid.cuda()
#     # do sampling
#     return F.grid_sample(x, grid, mode=mode)

import torch.nn.functional as F

image = observations_batch['ego_sem_map'][140].unsqueeze(0).unsqueeze(0)
# image = batch['ego_sem_map'].unsqueeze(0)

print(batch['ego_sem_map'].shape)
print(batch['ego_sem_map'].unsqueeze(0).shape)
print(batch['ego_sem_map'].unsqueeze(1).shape)

# image_small = downsampling(image, size=[20,20])
# print(image_small.shape)
# Create grid
out_size = 20
x = torch.linspace(-1, 1, out_size).view(-1, 1).repeat(1, out_size)
y = torch.linspace(-1, 1, out_size).repeat(out_size, 1)
grid = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), 2).to(device)

# plt.figure()
# plt.matshow(grid[:,:,0].cpu().numpy())

# plt.figure()
# plt.matshow(grid[:,:,1].cpu().numpy())

# grid.unsqueeze_(0)



grid = grid.unsqueeze_(0).repeat(2,1,1,1)

print("image", image.shape)

# image_small = F.grid_sample(image, grid, mode='nearest', align_corners=True)
image_small = F.interpolate(image, size=(15, 15), mode='nearest')

import numpy as np
np.set_printoptions(threshold=np.inf)

ego_map= image_small.squeeze(0).squeeze(0)
semantic= image.squeeze(0).squeeze(0)

# print(ego_map.shape)
# print(semantic.shape)

# print(semantic.cpu().numpy())
# print(ego_map)

# print(image_small)

object2idx = envs.call_at(0,'get_object2idx')
semantic_obs = object2idx[semantic.cpu().numpy().astype(np.int)] 

ego_map = object2idx[ego_map.cpu().numpy().astype(np.int)]

semantic_img = semantic_to_image(semantic_obs, 40)
ego_img= semantic_to_image(ego_map, 40)

plt.figure()
plt.imshow(semantic_img)
plt.figure()
plt.imshow(ego_img)

print("ego image", ego_img.shape)

# plt.imshow(image.squeeze(0).squeeze(0).cpu().numpy())
# plt.imtshow(image_small.squeeze(0).squeeze(0).cpu().numpy())


In [None]:
import torch.nn as nn

## Single prev action
prev_action_list=[]
prev_actions = torch.zeros(
    config.NUM_PROCESSES, 1, device=device, dtype=torch.long
)
not_done_masks = torch.zeros(config.NUM_PROCESSES, 1, device=device)
if config.NUM_PROCESSES==1:
    prev_actions = prev_actions.view(-1)
    not_done_masks = not_done_masks.view(-1)


## Batch prev actions
# prev_actions = prev_actions_batch
# batch_size = 5
# N = int(prev_actions_batch.size(0)/batch_size)

# prev_actions = prev_actions.view(batch_size, N).to(device)
# not_done_masks = not_done_masks.view(batch_size, N).to(device)

# prev_actions = prev_actions[:,:3]
# not_done_masks = not_done_masks[:,:3]

num_embeddings = envs.action_spaces[0].n
embedding = nn.Embedding(num_embeddings+1, 100, padding_idx= 1).to(device)

prev_actions_embedding = embedding(((prev_actions.float()+1)*not_done_masks).long())

print(prev_actions_embedding.shape)

prev_action_list=[]

In [None]:
x = torch.cat([rgb_ds, prev_actions_embedding], dim=1)

print(x.shape)

state_compress = nn.Sequential(
    nn.Linear(
        model_config.TRANSFORMER.output_size*2 + embedding.embedding_dim,
        model_config.TRANSFORMER.output_size,
    ),
    nn.ReLU(True),
).to(device)


x = state_compress(x)

print(x.shape)

## Hybrid RNN-Transformer Decoder

In [22]:
# x=rgb_ds
# print(x.shape)
from vlnce_baselines.models.decoder.hybrid_rnn_decoder import HybridRNNDecoder
model_config = config.MODEL
state_decoder = HybridRNNDecoder(
    model_config.TRANSFORMER.output_size,
    hidden_size=512, 
    num_layers=1,
    rnn_type=model_config.HYBRID_STATE_DECODER.rnn_type,
)

recurrent_hidden_states = torch.zeros(
    state_decoder.num_recurrent_layers,
    T,
    512,
    device=device,
)


masked= recurrent_hidden_states[0,:]*ma[:,0].unsqueeze(1).unsqueeze(0).to(device)

print(masked)

print(recurrent_hidden_states.shape)
print(state_decoder.num_recurrent_layers)


# not_done_masks = torch.zeros(config.NUM_PROCESSES, 1, device=device)

# out, rnn_hidden_states = state_decoder(x, recurrent_hidden_states, not_done_masks)

# print(out.shape)
# print(rnn_hidden_states.shape)

# # num_recurrent_layers = state_decoder.num_recurrent_layers

# # def unpack_hidden(hidden_states):
# #         hidden_states = (
# #             hidden_states[0 : num_recurrent_layers],
# #             hidden_states[num_recurrent_layers :],
# #         )

# #     return hidden_states

# # hidden_states = unpack_hidden(recurrent_hidden_states)
# # x, hidden_states = rnn(x,mask_hidden(hidden_states, masks.unsqueeze(0))
# # )

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')
torch.Size([2, 5, 512])
2


In [None]:
import numpy as np
spatial_dim = int(np.sqrt(out.shape[1]))
out_1 = out.view(out.shape[0], spatial_dim, spatial_dim, -1).permute(0,3,1,2)

print(out_1.shape)
out_pos_embed =  position_embedding_2d(out.view(out.shape[0], spatial_dim, spatial_dim, -1).permute(0,3,1,2))
out_pos_embed = out_pos_embed.flatten(2).permute(0, 2, 1)
print(out_pos_embed.flatten(2).permute(0, 2, 1).shape)
print(out.shape)

In [None]:
import torch.nn as nn
from vlnce_baselines.models.transformer.transformer import ImageCrossModalEncoder, ImageEncoder_with_PosEncodings, SemMapEncoder_with_PosEncodings

image_encoder_trans = ImageEncoder_with_PosEncodings(model_config.IMAGE_CROSS_MODAL_ENCODER).to(device=device)

a_t = image_encoder_trans(out, w_t, None, enc_mask, out_pos_embed)

print(a_t.shape)

In [None]:
pooler = nn.Sequential(
    nn.AdaptiveAvgPool1d((1)),
    nn.Flatten()
)

pooled_out = pooler(a_t.permute(0,2,1))

print(pooled_out.shape)

print(out[-1].shape)

out = torch.cat((pooled_out, out[:,-1,:]), dim=1)

print(out.shape)

In [None]:
# print(not_done_masks)
# print(prev_actions_embedding[90])

# prev_action_seq = prev_actions_embedding

# prev_action_list = []

prev_action_list.append(prev_actions_embedding)
prev_action_seq = torch.stack(prev_action_list).transpose(0,1)

print(prev_action_seq.shape)


In [None]:
print(prev_action_seq.shape)

print(len(prev_action_list))

print(prev_actions.shape)

In [None]:
from vlnce_baselines.models.trasnformer.trasnformer import ActionDecoderTrasnformer

seq_len = prev_action_seq.shape[1]

prev_action_pad = ((prev_actions+1)*not_done_masks.to(device)).view(T,N)

pad= prev_action_pad[:,0:76]

prev_action_pad_mask = (pad != 1).unsqueeze(1).unsqueeze(2)

# print(prev_action_pad_mask)
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1).bool().to(device)
# print(torch.triu(torch.ones((seq_len, seq_len)), difagonal=1))
mask_self_attention = mask_self_attention  # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention.gt(0)  # (b_s, 1, seq_len, seq_len)

mask = prev_action_pad_mask & ~mask_self_attention

print("--------------")
print(~mask[1,:,75])

In [None]:
from vlnce_baselines.models.transformer.transformer import ActionDecoderTransformer
action_decoder = ActionDecoderTransformer(model_config.ACTION_DECODER_TRANFORMER).to(device=device)

In [None]:
out = action_decoder(prev_action_seq, w_t, rgb_ds, enc_att_mask_w=enc_mask, 
                                    enc_att_mask_i=None, device=device, pos_embed= pos_embed_out)

In [None]:
print(out.shape)

In [None]:
print(probs.shape)

In [None]:
from vlnce_baselines.models.transformer.transformer import ActionDecoder_SemMap
action_decoder_sem = ActionDecoder_SemMap(model_config.ACTION_DECODER_TRANFORMER).to(device=device)

In [None]:
out_sem = action_decoder_sem(prev_action_seq, w_t, i_t,s_t, enc_att_mask_w=mask, 
                                    enc_att_mask_i=None,enc_att_mask_s=sem_att_mask, device=device, pos_embed= pos_embed_out)

In [None]:
print(out_sem.shape)