In [1]:

import torch
import torch.nn as nn
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training, PeftConfig
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertTokenizer, BertModel
import transformers
from transformers import AutoModel
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.optim.lr_scheduler import StepLR
from transformers import get_linear_schedule_with_warmup

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchCotDataset
import numpy as np
import torch.nn.functional as F
from typing import Callable, Dict, Sequence, Tuple
import numpy as np
from PIL import Image

from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import ArmActionMode, JointVelocity, JointPosition, EndEffectorPoseViaPlanning, EndEffectorPoseViaIK


from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig, CameraConfig
# from rlbench.tasks.pick_described_object import PickDescribedObject
from rlbench.tasks import PutGroceriesInCupboard, PickAndLift, StackBlocks, PlaceHangerOnRack, PickDescribedObject, TakeLidOffSaucepan, SetTheTable, PutGroceriesInCupboard
from scipy.spatial.transform import Rotation as R
from matplotlib import pyplot as plt
from PIL import Image
from pyrep.const import RenderMode

from torch.utils.data import DataLoader, random_split


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ActorCriticModel(nn.Module):
    def __init__(self, state_dim, action_dim, dropout_rate, device='cuda'):
        super(ActorCriticModel, self).__init__()
        self.device = device

        # Define the Actor as a nested class
        class Actor(nn.Module):
            def __init__(self):
                super(Actor, self).__init__()
                self.net = nn.Sequential(
                    nn.Linear(state_dim, 4096),
                    nn.Linear(4096, 4096),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(4096, 2048),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(2048, 2048),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(2048, 1024),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(1024, action_dim),
                )

            def forward(self, state):
                action_probs = self.net(state)
                # last 2 logits for gripper open/close
                pos_logits = action_probs[:, :300]
                pos_logits = pos_logits.reshape(-1, 3, 100)
                pos_logprob = F.log_softmax(pos_logits, dim=-1)
                rot_logits = action_probs[:, 300:600]
                rot_logits = rot_logits.reshape(-1, 3, 100)
                rot_logprob = F.log_softmax(rot_logits, dim=-1)
                open_logits = action_probs[:, 600:].reshape(-1, 1, 2)
                open_logprob = F.log_softmax(open_logits, dim=-1)
                return pos_logprob, rot_logprob, open_logprob

        # Define the Critic as a nested class
        class Critic(nn.Module):
            def __init__(self):
                super(Critic, self).__init__()
                self.net = nn.Sequential(
                    nn.Linear(state_dim, 4096),
                    nn.Linear(4096, 2048),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(2048, 1024),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(1024, 512),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(512, 256),
                    nn.LeakyReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(256, 1),
                    # nn.Tanh()
                )

            def forward(self, state):
                state_value = self.net(state)
                return state_value

        # Initialize actor and critic
        self.actor = Actor()
        self.critic = Critic()

        self.init_weight()
        self.to(device)

    def forward(self, state):
        # Actor forward pass
        pos_logprob, rot_logprob, open_logprob = self.actor(state)

        # Critic forward pass
        state_value = self.critic(state)

        return pos_logprob, rot_logprob, open_logprob, state_value

    def init_weight(self):
        for layer in self.actor.net.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.constant_(layer.bias, 0)

        for layer in self.critic.net.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.constant_(layer.bias, 0)

In [3]:
data_path = "datasets/pick_described_object_replay1/data.pt"
data = torch.load(data_path)

In [4]:
data.keys()

dict_keys(['images', 'instructions', 'grippers', 'items', 'objects', 'targets', 'stages', 'actions', 'rewards', 'dones', 'next_images', 'next_grippers'])

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
#  theaiinstitute/theia-base-patch16-224-cdiv
# theaiinstitute/theia-tiny-patch16-224-cdiv
image_encoder = AutoModel.from_pretrained("theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True,).to(device)


# Function to encode the descriptor using BERT
def encode_descriptor(descriptor):
    inputs = tokenizer(descriptor, return_tensors="pt", padding=True, truncation=True, max_length=10)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

def encode_image(image):
    assert image.shape == (224, 224, 3)
    with torch.no_grad():
        image_encoding = image_encoder.forward_feature(image)
    return image_encoding.flatten().cpu().numpy()

def pose_processor(pose):
    #from quat to euler
    pos = pose[:3]
    pos_lower_bound = np.array([-0.2, -0.35, 0.752])
    pos_upper_bound = np.array([0.5, 0.35, 1.3])
    pos = np.clip(pos, pos_lower_bound, pos_upper_bound-1e-8)
    pos_bins = np.linspace(pos_lower_bound, pos_upper_bound, 101)
    x_idx = np.digitize(pos[0], pos_bins[:,0]) -1 
    y_idx = np.digitize(pos[1], pos_bins[:,1]) -1
    z_idx = np.digitize(pos[2], pos_bins[:,2]) -1
    x = np.zeros(100)
    x[x_idx] = 1
    y = np.zeros(100)
    y[y_idx] = 1
    z = np.zeros(100)
    z[z_idx] = 1
    pos = np.concatenate([x,y,z])

    euler = pose[3:6]
    euler[0] = euler[0] + np.pi if euler[0] < 0 else euler[0] - np.pi
    euler_lower_bound = np.array([-np.pi/4, -np.pi/4, -np.pi/2])
    euler_upper_bound = np.array([np.pi/4, np.pi/4, np.pi/2])
    euler = np.clip(euler, euler_lower_bound, euler_upper_bound-1e-8)
    euler_bins = np.linspace(euler_lower_bound, euler_upper_bound, 101)
    rx_idx = np.digitize(euler[0], euler_bins[:,0]) -1
    ry_idx = np.digitize(euler[1], euler_bins[:,1]) -1
    rz_idx = np.digitize(euler[2], euler_bins[:,2]) -1
    rx = np.zeros(100)
    rx[rx_idx] = 1
    ry = np.zeros(100)
    ry[ry_idx] = 1
    rz = np.zeros(100)
    rz[rz_idx] = 1
    
    rot = np.concatenate([rx,ry,rz])
    open_state = np.zeros(2)
    open_state[int(pose[-1])] = 1

    return np.concatenate([pos, rot, open_state])



Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
class ACDataLoader(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['images'])

    def __getitem__(self, idx):
        img = np.array(Image.open(self.data['images'][idx]))
        img_encoding = encode_image(img)
        descriptor = self.data['instructions'][idx]
        descriptor_encoding = encode_descriptor(descriptor)
        pose = self.data['grippers'][idx]
        pose_encoding = pose_processor(pose)
        state = torch.tensor(np.concatenate([img_encoding, descriptor_encoding, pose_encoding]))
        action = torch.tensor(self.data['actions'][idx])
        return dict(states = state.float(), actions = action.float())

In [20]:
trainset = ACDataLoader(data)

In [21]:
train_size = int(0.9 * len(trainset))
test_size = len(trainset) - train_size
train_dataset, test_dataset = random_split(trainset, [train_size, test_size])


In [22]:
class Collator:
    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        states, actions = tuple([instance[key] for instance in instances] for key in ("states", "actions"))
        states = torch.stack(states)
        actions = torch.stack(actions)
        output = dict(
            states = states,
            actions = actions
        )
        return output

In [23]:
collator = Collator()
train_dataloader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=False,
        # sampler=sampler,
        collate_fn=collator,
        num_workers=0, 
    )

test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    # sampler=sampler,
    collate_fn=collator,
    num_workers=0, 
)

In [24]:
batch = next(iter(train_dataloader))

In [25]:
bert_encoding_dim = 768  # BERT base model output dimension
image_encoding_dim = 37632
action_dim =  6*100 + 2 # 6 for endeffector pose + 1 for gripper

state_dim = image_encoding_dim + action_dim + bert_encoding_dim
dropout_rate = 0.2

acmodel = ActorCriticModel(state_dim, action_dim, dropout_rate)

In [26]:
learning_rate = 1e-4
weight_decay = 0.01

In [27]:
optimizer = AdamW(acmodel.actor.parameters(), lr=learning_rate,weight_decay=weight_decay)

In [28]:
batch['actions']

tensor([[ 3.1510e-01, -2.2684e-01,  9.0183e-01,  3.1356e+00,  3.5455e-03,
         -1.8505e+00,  0.0000e+00],
        [-2.0000e-01, -3.2500e-01,  1.2000e+00,  3.1416e+00, -1.3101e-14,
          3.1416e+00,  1.0000e+00],
        [ 6.0139e-03,  2.0162e-02,  1.1981e+00, -2.8774e+00,  5.0894e-01,
          1.1299e+00,  0.0000e+00],
        [ 9.3510e-02,  4.2717e-02,  9.8968e-01,  2.9928e+00,  2.7490e-01,
          1.3478e+00,  0.0000e+00],
        [-2.0000e-01, -3.2500e-01,  1.2000e+00,  3.1416e+00, -1.3101e-14,
          3.1416e+00,  1.0000e+00],
        [ 2.5556e-01,  3.2610e-01,  8.2499e-01,  3.1416e+00,  2.6340e-05,
          2.8547e+00,  0.0000e+00],
        [-1.8882e-02,  2.3190e-01,  8.6805e-01, -2.8813e+00,  7.9399e-01,
          2.8683e-01,  0.0000e+00],
        [-2.0000e-01, -3.2500e-01,  1.2000e+00,  3.1416e+00, -1.3101e-14,
          3.1416e+00,  1.0000e+00],
        [-2.0126e-03,  3.3283e-01,  1.2427e+00,  2.6184e+00,  4.5629e-01,
         -5.4830e-01,  0.0000e+00],
        [-

In [29]:
def get_action_loss(pos_logprob,rot_logprob,open_logprob, action_gt):
    pos_gt = action_gt[:,:3]
    pos_lower_bound = np.array([-0.2, -0.35, 0.752])
    pos_upper_bound = np.array([0.5, 0.35, 1.3])
    pos_bins = np.linspace(pos_lower_bound, pos_upper_bound, 101)
    pos_bin_centers = (pos_bins[:-1] + pos_bins[1:]) / 2
    pos_pred = (pos_logprob.exp()@torch.tensor(pos_bin_centers).to(torch.float32).to(device))[:,range(3),range(3)]
    assert pos_pred.shape == pos_gt.shape, f"{pos_pred.shape} != {pos_gt.shape}"
    pos_loss = F.mse_loss(pos_pred,pos_gt.to(device))

    euler_gt = action_gt[:,3:6]
    torch.where(euler_gt[:,0:1]>0, euler_gt[:,0:1] - torch.pi, euler_gt[:,0:1] + torch.pi) 
    euler_lower_bound = np.array([-np.pi/4, -np.pi/4, -np.pi/2])
    euler_upper_bound = np.array([np.pi/4, np.pi/4, np.pi/2])
    euler_bins = np.linspace(euler_lower_bound, euler_upper_bound, 101)
    euler_bin_centers = (euler_bins[:-1] + euler_bins[1:]) / 2
    euler_pred = (rot_logprob.exp()@torch.tensor(euler_bin_centers).to(torch.float32).to(device))[:,range(3),range(3)]
    euler_loss = F.mse_loss(euler_pred,euler_gt.to(device))

    open_gt = F.one_hot(action_gt[:,6:7].to(torch.int64).view(-1), num_classes=2).to(device)
    open_loss = F.cross_entropy(open_logprob.exp().squeeze(1).to(torch.float32),open_gt.to(device).to(torch.float32))
    
    return (pos_loss * 0.6 + euler_loss * 0.3 + open_loss * 0.1).to(torch.float16)

In [30]:
euler_lower_bound = np.array([-np.pi/4, -np.pi/4, -np.pi/2])
euler_upper_bound = np.array([np.pi/4, np.pi/4, np.pi/2])
euler_bins = np.linspace(euler_lower_bound, euler_upper_bound, 101)
euler_bin_centers = (euler_bins[:-1] + euler_bins[1:]) / 2

In [39]:
# scaler = torch.cuda.amp.GradScaler()
acmodel.actor.train()
epochs = 5
total_action_loss = []
with tqdm.tqdm(total=epochs*train_dataloader.__len__() , leave=False) as progress:
    for epoch in range(epochs):
        for step_idx, batch in enumerate(train_dataloader):
            # with torch.autocast("cuda"):
            inputs = batch['states'].to(torch.float32).to(device)
            outputs = acmodel.actor(inputs)
            pos_logprob,rot_logprob,open_logprob = outputs
            action_loss =get_action_loss(pos_logprob,rot_logprob,open_logprob, batch['actions'])
            optimizer.zero_grad()
            action_loss.backward()
            optimizer.step()
            print(action_loss)
            if step_idx % 20 == 0:
                action_losses = []
                for batch in test_dataloader:
                    with torch.no_grad():
                        inputs = batch['states'].to(torch.float32).to(device)
                        outputs = acmodel.actor(inputs)
                        pos_logprob,rot_logprob,open_logprob = outputs
                        action_loss =get_action_loss(pos_logprob,rot_logprob,open_logprob, batch['actions'])
                        action_losses.append(action_loss.cpu().numpy())
                avg_loss = np.mean(action_losses)
                total_action_loss.append(avg_loss)
                print(f"avg_loss {avg_loss}")
        progress.update()


  0%|          | 0/99 [00:00<?, ?it/s]

tensor(1.2314, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
avg_loss 1.1630859375
tensor(1.1963, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2852, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.1416, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3369, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.6396, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(0.9854, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3174, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2568, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2686, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3750, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.1191, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.1426, dev

  1%|          | 1/99 [00:50<1:22:25, 50.47s/it]

tensor(1.5107, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2939, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2314, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
avg_loss 1.162109375
tensor(1.1963, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2852, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.1416, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3369, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.6396, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(0.9854, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3174, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2568, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.2686, device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)
tensor(1.3750, devi

In [None]:
pos_logprob.dtype

torch.float32