In [1]:
import os, sys
sys.path.append('.')

from dataclasses import dataclass
from pathlib import Path

import draccus
import torch
import torch.distributed as dist
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from finetune.finetune_config import FinetuneConfig
from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForActionPrediction
from vla.action_tokenizer import ActionTokenizer, RLbenchActionTokenizer
from vla.dataset import save_dataset_statistics, RLbenchDataset
import numpy as np
import torch.nn.functional as F

2024-07-13 19:11:14.780845: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-13 19:11:14.780954: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-13 19:11:14.824436: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-13 19:11:14.917050: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
cfg = FinetuneConfig()

In [4]:
print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")

# [Validate] Ensure GPU Available & Set Device / Distributed Context
assert torch.cuda.is_available(), "Fine-tuning assumes at least one GPU is available!"
distributed_state = PartialState()
torch.cuda.set_device(device_id := distributed_state.local_process_index)
torch.cuda.empty_cache()

# Configure Unique Experiment ID & Log Directory
exp_id = (
    f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
    f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
    f"+lr-{cfg.learning_rate}"
)
if cfg.use_lora:
    exp_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
if cfg.use_quantization:
    exp_id += "+q-4bit"

# Start =>> Build Directories
run_dir, adapter_dir = cfg.run_root_dir / exp_id, cfg.adapter_dir / exp_id
os.makedirs(run_dir, exist_ok=True)

# Quantization Config =>> only if LoRA fine-tuning
quantization_config = None
if cfg.use_quantization:
    assert cfg.use_lora, "Quantized training only supported for LoRA fine-tuning!"
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", #llm_int8_skip_modules = ['projector'],
    )

# Load OpenVLA Processor and Model using HF AutoClasses
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    cfg.vla_path_q,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map = "auto"
)


Fine-tuning OpenVLA Model `/media/lawrence/Work/checkpoints/openvla-7b` on `test1`


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


In [8]:
# Device Placement =>> note that BitsAndBytes automatically handles for quantized training
if cfg.use_quantization:
    vla = prepare_model_for_kbit_training(vla)
else:
    vla = vla.to(device_id)

# [LoRA] Wrap Model w/ PEFT `LoraConfig` =>> by default we set `target_modules=all-linear`
if cfg.use_lora:
    lora_config = LoraConfig(
        r=cfg.lora_rank,
        lora_alpha=min(cfg.lora_rank, 16),
        lora_dropout=cfg.lora_dropout,
        target_modules="all-linear",
        init_lora_weights="gaussian",
    )
    vla = get_peft_model(vla, lora_config)
    vla.print_trainable_parameters()

# # Wrap VLA in PyTorch DDP Wrapper for Multi-GPU Training
# vla = DDP(vla, device_ids=[device_id], find_unused_parameters=True, gradient_as_bucket_view=True)

# Create Optimizer =>> note that we default to a simple constant learning rate!
trainable_params = [param for param in vla.parameters() if param.requires_grad]
optimizer = AdamW(trainable_params, lr=cfg.learning_rate)


trainable params: 13,853,536 || all params: 7,555,090,720 || trainable%: 0.1834


In [9]:
# Create Action Tokenizer
action_tokenizer = RLbenchActionTokenizer(processor.tokenizer)

train_dataset = RLbenchDataset(
    cfg.train_data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder,
)

test_dataset = RLbenchDataset(
    cfg.test_data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder,
)

valid_dataset = RLbenchDataset(
    cfg.valid_data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder,
)

# [Important] Save Dataset Statistics =>> used to de-normalize actions for inference!
if distributed_state.is_main_process:
    save_dataset_statistics(train_dataset.dataset_statistics, run_dir)

# Create Collator and DataLoader
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    sampler=None,
    collate_fn=collator,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    sampler=None,
    collate_fn=collator,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    sampler=None,
    collate_fn=collator,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)


In [10]:
step_idx, batch = next(enumerate(train_dataloader))

In [11]:
vla.train()
optimizer.zero_grad()

step_idx, batch = next(enumerate(train_dataloader))
with torch.autocast("cuda", dtype=torch.bfloat16):
    output: CausalLMOutputWithPast = vla(
        input_ids=batch["input_ids"].to(device_id),
        attention_mask=batch["attention_mask"].to(device_id),
        pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
        labels=batch["labels"],
    )
    loss = output.loss

# Backward!
# loss.backward()

# Compute Accuracy and L1 Loss for Logging
action_logits = output.logits[:, vla.vision_backbone.featurizer.patch_embed.num_patches : -1]#TODO: why num_patches:-1?
action_preds = action_logits.argmax(dim=2)
action_gt = batch["labels"][:, 1:].to(action_preds.device)
mask = action_gt > action_tokenizer.action_token_begin_idx

# masked_logits = action_logits[:,mask,action_tokenizer.action_token_begin_idx:]
masked_logits = action_logits[mask][:,action_tokenizer.action_token_begin_idx+1:processor.tokenizer.vocab_size].view(2,7,-1)



# Compute Accuracy
correct_preds = (action_preds == action_gt) & mask
action_accuracy = correct_preds.sum().float() / mask.sum().float()

# # Compute L1 Loss on Predicted (Continuous) Actions
# continuous_actions_pred = torch.tensor(
#     action_tokenizer.decode_token_ids_to_actions(action_preds[mask].cpu().numpy())
# )
# continuous_actions_gt = torch.tensor(
#     action_tokenizer.decode_token_ids_to_actions(action_gt[mask].cpu().numpy())
# )
# action_l1_loss = torch.nn.functional.l1_loss(continuous_actions_pred, continuous_actions_gt)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


In [39]:
action_tokenizer.tokenizer.vocab_size - action_preds[mask] - 1

tensor([ 86, 137, 177, 146, 161, 119, 255,  77, 173, 104, 198,  49, 127, 255],
       device='cuda:0')

In [43]:
(352 - action_logits[0][mask[0],action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size-1]).shape

torch.Size([7, 352])

In [25]:
masked_logits

tensor([[[-11.8125, -11.7500, -12.9375,  ...,  -3.2500,  -4.4375,  -0.0311],
         [-10.6875,  -9.1875,  -9.4375,  ...,   0.0542,  -0.7891,   1.5078],
         [ -8.0625,  -7.2500,  -8.6875,  ...,  -4.7812,  -5.4375,  -2.1406],
         ...,
         [ -9.3125,  -5.8125,  -5.5000,  ...,   6.1562,   2.8594,  16.7500],
         [ -6.6875,  -6.0312,  -6.8438,  ...,   4.7812,   5.9688,   3.5312],
         [ -7.4062,  -7.6250,  -8.5625,  ...,   4.4062,   4.9375,   8.1875]],

        [[ -9.8750, -13.4375,  -9.0000,  ...,  -3.6719,  -1.3359,   6.1250],
         [-10.1875,  -9.3125, -11.0000,  ...,  -3.5156,  -5.2188,  -1.6172],
         [ -5.2812,  -3.1250,  -4.5938,  ...,  -2.8906,   0.8789,   2.7031],
         ...,
         [ -3.7969,  -4.0938,  -1.2188,  ...,   8.8750,   5.2188,  11.8750],
         [ -5.6250,  -4.2812,  -7.5000,  ...,   7.2500,   3.6719,   4.9062],
         [ -9.1250,  -9.0000,  -9.5000,  ...,   2.0312,   2.7031,   4.3438]]],
       device='cuda:0', grad_fn=<ViewBackwar

In [12]:
masked_logits.shape

torch.Size([2, 7, 352])

In [24]:
pred_action = action_tokenizer.output_logit_to_continous_action(masked_logits,batch['actions'])

In [27]:
F.l1_loss(pred_action, batch['actions'].to(pred_action.device))

tensor(0.7198, device='cuda:0', grad_fn=<MeanBackward0>)

In [9]:
masked_logits[:,1,:50].dtype

torch.float32

In [16]:
x = (F.softmax(masked_logits[:, 0,:50], dim = 1)@torch.tensor(action_tokenizer.x_bin_centers,dtype=torch.float32).to('cuda')).unsqueeze(1)

In [17]:
y = (F.softmax(masked_logits[:, 1, 50:150], dim = 1)@torch.tensor(action_tokenizer.y_bin_centers,dtype=torch.float32).to('cuda')).unsqueeze(1)

In [18]:
z = (F.softmax(masked_logits[:,2,150:250], dim = 1)@torch.tensor(action_tokenizer.z_bin_centers,dtype=torch.float32).to('cuda')).unsqueeze(1)

In [19]:
rot = F.softmax(masked_logits[:,3:-1,250:350], dim = 2)@torch.tensor(action_tokenizer.rot_bin_centers,dtype=torch.float32).to('cuda')

In [20]:
g = (F.softmax(masked_logits[:,-1,350:], dim = 1) @ torch.tensor([0,1],dtype=torch.float32).to('cuda')).unsqueeze(1)

In [21]:
torch.cat([x,y,z,rot,g],dim=1)

tensor([[ 0.1625, -0.0087,  1.3602, -0.5318,  0.0773, -0.3480,  0.9033],
        [ 0.1935,  0.3486,  0.6879, -1.8607,  1.6878, -1.1623,  0.8819]],
       device='cuda:0', grad_fn=<CatBackward0>)

In [22]:
batch['actions']

tensor([[ 2.3545e-01,  1.4843e-01,  9.3349e-01, -3.1400e+00,  1.8622e-03,
          7.7924e-01,  1.0000e+00],
        [ 2.3545e-01,  1.4843e-01,  9.3349e-01, -3.1400e+00,  1.8622e-03,
          7.7924e-01,  1.0000e+00]])

In [None]:
action_logits.shape

torch.Size([2, 28, 32064])

In [None]:
processor.tokenizer.vocab_size

32000

In [None]:
action_tokenizer.action_token_begin_idx

31647

In [None]:
action_tokenizer.action_token_begin_idx

31647

In [None]:
import torch.nn.functional as F
action_logits
action1 = F.softmax(action_logits[:,:,-256:], dim=-1)

In [None]:
action_gt

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         31691, 31729, 31854, 31898, 31947, 31961, 31998,     2],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 31671,
         31712, 31835, 31898, 31948, 31992, 31999,     2,  -100]],
       device='cuda:0')

In [None]:
action1

tensor([[[7.8702e-11, 6.2277e-09, 1.8299e-10,  ..., 5.8241e-13,
          1.0021e-14, 1.8721e-14],
         [4.2377e-06, 8.5785e-05, 6.4618e-06,  ..., 4.0644e-11,
          4.1049e-10, 5.9137e-11],
         [7.2699e-04, 2.7969e-05, 1.5238e-04,  ..., 2.3909e-09,
          2.0654e-08, 4.1961e-09],
         ...,
         [1.7208e-08, 1.7372e-10, 1.1367e-12,  ..., 1.0394e-10,
          1.5449e-11, 7.8904e-12],
         [5.3094e-16, 9.9611e-14, 3.9790e-12,  ..., 2.4206e-18,
          5.8647e-17, 4.2907e-17],
         [1.9505e-06, 9.4124e-08, 1.5159e-07,  ..., 2.0914e-10,
          1.6805e-10, 4.8146e-11]],

        [[3.0287e-07, 3.1740e-07, 1.7730e-09,  ..., 3.4568e-11,
          1.1946e-11, 5.6431e-12],
         [3.2915e-07, 2.0081e-07, 1.5703e-06,  ..., 1.7571e-12,
          3.9596e-12, 7.8746e-12],
         [1.2685e-02, 6.3154e-04, 1.3927e-04,  ..., 1.1700e-07,
          2.1858e-07, 4.1717e-08],
         ...,
         [7.4718e-17, 7.0191e-17, 6.4411e-17,  ..., 7.2376e-21,
          8.461