In [1]:
import os, sys
sys.path.append('.')

from dataclasses import dataclass
from pathlib import Path

import draccus
import torch
import torch.distributed as dist
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.optim.lr_scheduler import StepLR
from transformers import get_linear_schedule_with_warmup

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchCotDataset
import numpy as np
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint locations used by this notebook.
# Base model checkpoint: the processor and the model weights are loaded from here.
base_model_path = "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge"
# Adapter directory, given relative to the notebook's working directory.
adapter_path = "adapter-tmp/nll_loss_cot+pick_described_object+e5+b8+lr-5e-05+lora-r16+dropout-0.0+q-4bit"
# Output location for checkpoints (not referenced by the cells visible here).
save_path = "/media/lawrence/Work/checkpoints/vla-rl-ecot"

In [3]:
# Min-max normalization statistics, stored as a (minimum, maximum) pair of
# 7-element arrays. They are handed to the pose tokenizer in a later cell.
dataset_statistics: tuple = (
    # Per-dimension minimum.
    np.array([-0.20173775, -0.36754665, 0.81396234, -3.14153998, -0.38798628, -3.14158631, 0.0]),
    # Per-dimension maximum.
    np.array([0.41802976, 0.45118147, 1.47966564, 3.14159215, 0.30391057, 3.14157801, 1.0]),
)

In [4]:
# Load the processor bundled with the base checkpoint; later cells use its
# `tokenizer` and `image_processor` attributes.
# NOTE: trust_remote_code=True runs Python code shipped with the checkpoint,
# so only point this at checkpoints you trust.
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)

# Pose tokenizer built from the text tokenizer and the min-max statistics.
action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer, dataset_statistics)

# Test data file consumed by the dataset cell below. A plain string literal is
# used because the path contains no interpolated fields.
test_data_path: Path = Path("./datasets/pick_described_object/test_data1.pt")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# 4-bit NF4 weight quantization with bfloat16 compute.
# (Option left disabled in the original cell: llm_int8_skip_modules=['projector'].)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model with the quantization settings above and place it on CUDA.
vla = AutoModelForVision2Seq.from_pretrained(
    base_model_path,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cuda",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

# Load the adapter stored at `adapter_path` into the base model.
vla.load_adapter(adapter_path)

Loading checkpoint shards: 100%|██████████| 3/3 [00:12<00:00,  4.27s/it]


In [6]:
# Text tokenizer shared by the dataset and the collator.
_tokenizer = processor.tokenizer

# Test dataset backed by `test_data_path`.
testset = RLbenchCotDataset(
    test_data_path,
    action_tokenizer,
    _tokenizer,
    prompt_builder_fn=PurePromptBuilder,
    image_transform=processor.image_processor.apply_transform,
)

# Collator configured with the tokenizer's max length and pad token id,
# padding on the right.
collator = PaddedCollatorForPosePrediction(
    _tokenizer.model_max_length,
    _tokenizer.pad_token_id,
    padding_side="right",
)

# Batches of two samples, no custom sampler.
# Upstream note on num_workers: set it to 0 if using RLDS, because TFDS rolls
# its own parallelism.
test_dataloader = DataLoader(
    testset,
    batch_size=2,
    collate_fn=collator,
    sampler=None,
    num_workers=1,
)


In [7]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not
# currently using, before the evaluation loop runs.
torch.cuda.empty_cache()

In [8]:
# Evaluate the adapted model on the first batches of the test dataloader and
# collect the loss the model reports for each batch.

# Number of batches to evaluate. The original loop broke once test_idx reached
# 30, i.e. after 31 batches; the same count is kept here under a name.
MAX_EVAL_BATCHES = 31

device_id = vla.device.type
test_nll_loss = []
vla.eval()  # switch the model to evaluation mode
for test_idx, batch in enumerate(test_dataloader):
    # No gradients are needed for evaluation; run the forward pass under
    # bfloat16 autocast.
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        output: CausalLMOutputWithPast = vla(
            input_ids=batch["input_ids"].to(device_id),
            attention_mask=batch["attention_mask"].to(device_id),
            pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
            # Move the labels like every other tensor input instead of relying
            # on the model to relocate them internally.
            labels=batch["labels"].to(device_id),
        )
        test_nll_loss_ = output.loss
        test_nll_loss.append(test_nll_loss_)
    # Stop after the last requested batch without fetching another one.
    if test_idx + 1 >= MAX_EVAL_BATCHES:
        break

In [10]:
# Display the per-batch losses collected by the evaluation loop.
test_nll_loss

[tensor(0.7865, device='cuda:0'),
 tensor(0.8690, device='cuda:0'),
 tensor(0.7530, device='cuda:0'),
 tensor(0.6961, device='cuda:0'),
 tensor(0.3100, device='cuda:0'),
 tensor(0.2857, device='cuda:0'),
 tensor(0.3724, device='cuda:0'),
 tensor(0.3880, device='cuda:0'),
 tensor(0.3730, device='cuda:0'),
 tensor(0.2115, device='cuda:0'),
 tensor(0.6261, device='cuda:0'),
 tensor(0.6178, device='cuda:0'),
 tensor(0.6179, device='cuda:0'),
 tensor(0.6286, device='cuda:0'),
 tensor(0.4834, device='cuda:0'),
 tensor(0.3462, device='cuda:0'),
 tensor(0.3569, device='cuda:0'),
 tensor(0.3540, device='cuda:0'),
 tensor(0.2164, device='cuda:0'),
 tensor(0.2830, device='cuda:0'),
 tensor(0.5385, device='cuda:0'),
 tensor(0.5433, device='cuda:0'),
 tensor(0.4642, device='cuda:0'),
 tensor(0.4272, device='cuda:0'),
 tensor(0.3361, device='cuda:0'),
 tensor(0.2896, device='cuda:0'),
 tensor(0.3263, device='cuda:0'),
 tensor(0.3334, device='cuda:0'),
 tensor(0.0704, device='cuda:0'),
 tensor(0.5370