In [1]:
import json
import random
from PIL import Image
from torch.utils.data import Dataset, Sampler
import torch
from torchvision import transforms
from transformers import PreTrainedTokenizer
from typing import List
import torch
import torch.nn as nn
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup
from tqdm import tqdm
from transformers import GPT2TokenizerFast

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the pipeline touches: MixedBatchSampler shuffles with the
# stdlib `random` module, and the GPT-2 decoder is randomly initialised
# with torch's RNG.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 16
EPOCHS = 6
LR = 3e-4
WARMUP_STEPS = 500
MAX_STEPS = 30000  # hard cap on optimizer steps across all epochs

TOKENIZER_DIR = "/content/drive/MyDrive/Colab Notebooks/RT-2/rt2_tokenizer"
tokenizer = GPT2TokenizerFast.from_pretrained(TOKENIZER_DIR)

print("Tokenizer size:", len(tokenizer))
print("EOS ID:", tokenizer.eos_token_id)

# A GPT-2-style tokenizer may define no pad token (pad_token_id is None);
# fall back to EOS so padding always has a valid id.
pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)

In [3]:
class RobotDataset(Dataset):
    """Robot samples: image + instruction -> action-token sequence.

    Each JSONL line must provide ``image`` (path), ``instruction`` (text) and
    ``action_tokens`` (presumably a list of already-tokenized action ids —
    TODO confirm against the data-prep script).

    ``input_ids`` holds the full teacher-forced sequence
    ``[BOS] + instruction + [<act_start>] + action_tokens + [EOS]``.
    ``labels`` has the same length, with every prompt position set to -100,
    so only the action tokens and the closing EOS are scored. The causal
    shift (position i predicts label i + 1) is left to the model.
    """

    def __init__(self, jsonl_path: str, tokenizer: PreTrainedTokenizer, image_size=224):
        self.samples = []
        self.tokenizer = tokenizer

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank / trailing newline-only lines
                    self.samples.append(json.loads(line))

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        self.act_start_id = tokenizer.convert_tokens_to_ids("<act_start>")
        # An unknown token comes back as the unk id (or None); training on it
        # would silently corrupt the prompt format.
        if self.act_start_id is None or self.act_start_id == tokenizer.unk_token_id:
            raise ValueError(
                "Tokenizer has no '<act_start>' token; add it before building the dataset."
            )
        self.eos_id = tokenizer.eos_token_id

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Context manager closes the underlying file; convert() returns a copy.
        with Image.open(sample["image"]) as img:
            image = self.transform(img.convert("RGB"))

        instr_ids = self.tokenizer.encode(
            sample["instruction"],
            add_special_tokens=False
        )
        prompt_ids = (
            [self.tokenizer.bos_token_id]
            + instr_ids
            + [self.act_start_id]
        )
        target_ids = list(sample["action_tokens"]) + [self.eos_id]

        # Targets must also be decoder inputs (teacher forcing), and labels
        # must line up position-by-position with input_ids.
        input_ids = prompt_ids + target_ids
        labels = [-100] * len(prompt_ids) + target_ids

        return {
            "image": image,
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "type": "robot"
        }


In [4]:
class VLMDataset(Dataset):
    """Image-captioning samples used to keep the vision-language skills alive.

    Each JSONL line must provide ``image`` (path) and ``caption`` (text).

    ``input_ids`` is ``[BOS] + caption + [EOS]``. ``labels`` has the same
    length, with the BOS position set to -100, so the model learns to
    generate the caption and the closing EOS. The causal shift is left to
    the model.
    """

    def __init__(self, jsonl_path: str, tokenizer: PreTrainedTokenizer, image_size=224):
        self.samples = []
        self.tokenizer = tokenizer

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank / trailing newline-only lines
                    self.samples.append(json.loads(line))

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Context manager closes the underlying file; convert() returns a copy.
        with Image.open(sample["image"]) as img:
            image = self.transform(img.convert("RGB"))

        prompt_ids = [self.tokenizer.bos_token_id]
        caption_ids = self.tokenizer.encode(
            sample["caption"],
            add_special_tokens=False
        ) + [self.tokenizer.eos_token_id]

        # Caption tokens are fed to the decoder (teacher forcing), and labels
        # line up position-by-position with input_ids.
        input_ids = prompt_ids + caption_ids
        labels = [-100] * len(prompt_ids) + caption_ids

        return {
            "image": image,
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "type": "vlm"
        }


In [5]:
def get_robot_ratio(epoch):
    """Return the fraction of each batch drawn from robot data at `epoch`.

    Epochs are 0-based. The robot share ramps up as training progresses:
    0.6 before epoch 2, 0.7 before epoch 4, and 0.8 from then on.
    """
    for end_epoch, ratio in ((2, 0.6), (4, 0.7)):
        if epoch < end_epoch:
            return ratio
    return 0.8


In [6]:
class MixedBatchSampler(Sampler[List[int]]):
    """Yield index batches that mix robot and VLM samples at a fixed ratio.

    Indices address a concatenated dataset. ``[0, robot_len)`` are robot
    samples and ``[robot_len, robot_len + vlm_len)`` are VLM samples, which
    matches CombinedDataset's layout.

    Every batch holds exactly ``robot_bs`` robot indices and ``vlm_bs`` VLM
    indices, shuffled together. Incomplete tail batches are dropped, so
    ``__len__`` equals the number of batches ``__iter__`` yields.

    Args:
        robot_len: number of robot samples.
        vlm_len: number of VLM samples.
        batch_size: total indices per batch (>= 1).
        robot_ratio: fraction of each batch taken from robot data, in [0, 1];
            rounded to the nearest whole sample.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``robot_ratio`` is outside [0, 1].
    """

    def __init__(
        self,
        robot_len: int,
        vlm_len: int,
        batch_size: int,
        robot_ratio: float = 0.7
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not 0.0 <= robot_ratio <= 1.0:
            raise ValueError(f"robot_ratio must be in [0, 1], got {robot_ratio}")

        self.robot_len = robot_len
        self.vlm_len = vlm_len
        self.batch_size = batch_size

        # Round instead of truncating: int(16 * 0.6) would give 9 robot samples
        # (56%) where 10 (62.5%) is the closest match to the requested ratio.
        self.robot_bs = round(batch_size * robot_ratio)
        self.vlm_bs = batch_size - self.robot_bs

    def __iter__(self):
        robot_indices = list(range(self.robot_len))
        vlm_indices = list(range(self.vlm_len))

        random.shuffle(robot_indices)
        random.shuffle(vlm_indices)

        # Only full batches, so iteration agrees with __len__.
        for b in range(len(self)):
            r0 = b * self.robot_bs
            v0 = b * self.vlm_bs

            batch = robot_indices[r0:r0 + self.robot_bs]
            # VLM indices are offset past the robot block of the combined dataset.
            batch += [self.robot_len + i for i in vlm_indices[v0:v0 + self.vlm_bs]]

            random.shuffle(batch)
            yield batch

    def __len__(self):
        # A group with a zero per-batch quota (ratio 0.0 or 1.0) imposes no
        # limit; batch_size >= 1 guarantees at least one group has a quota.
        limits = []
        if self.robot_bs > 0:
            limits.append(self.robot_len // self.robot_bs)
        if self.vlm_bs > 0:
            limits.append(self.vlm_len // self.vlm_bs)
        return min(limits)


In [7]:
def collate_fn(batch, pad_token_id, eos_token_id):
    """Stack images and right-pad the token sequences of one batch.

    Args:
        batch: list of dataset items with ``image`` (tensor) and
            ``input_ids`` / ``labels`` (1-D long tensors).
        pad_token_id: id used to pad ``input_ids``; may be None (tokenizer
            without a pad token), in which case ``eos_token_id`` is used.
        eos_token_id: fallback padding id.

    Returns:
        dict with ``images`` (stacked), ``input_ids`` (B, T), ``attention_mask``
        (B, T; 1 on real tokens, 0 on padding) and ``labels`` padded with -100.
    """
    images = torch.stack([item["image"] for item in batch])

    input_seqs = [item["input_ids"] for item in batch]
    label_seqs = [item["labels"] for item in batch]

    pad_value = pad_token_id if pad_token_id is not None else eos_token_id

    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_seqs,
        batch_first=True,
        padding_value=pad_value
    )

    labels = torch.nn.utils.rnn.pad_sequence(
        label_seqs,
        batch_first=True,
        padding_value=-100
    )

    # Derive the mask from the true lengths rather than comparing to the pad
    # id: real EOS tokens (and BOS, when a tokenizer reuses the EOS id for it)
    # share that id and must still be attended to.
    lengths = torch.tensor([len(seq) for seq in input_seqs])
    positions = torch.arange(input_ids.size(1))
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    return {
        "images": images,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


In [20]:
import torch
import torch.nn as nn
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
from torch.cuda.amp import autocast


class MiniRT2(nn.Module):
    """Frozen pretrained ViT encoder feeding a GPT-2 decoder.

    The ViT's ``last_hidden_state`` is linearly projected to ``d_model`` and
    prepended to the token embeddings. The GPT-2 decoder is built from a
    fresh ``GPT2Config`` (no pretrained weights) and trained with next-token
    cross-entropy on ``labels``; positions labelled -100 are ignored.

    Args:
        vocab_size: decoder vocabulary size (normally ``len(tokenizer)``).
        vision_model_name: Hugging Face id of the pretrained ViT to load.
        d_model: decoder hidden size, also the projection output size.
        n_layer: number of decoder blocks.
        n_head: number of attention heads per block.
        max_seq_len: text-token budget; 256 extra positions are reserved for
            the prepended vision tokens.
    """

    def __init__(
        self,
        vocab_size: int,
        vision_model_name: str = "google/vit-base-patch16-224",
        d_model: int = 768,
        n_layer: int = 8,
        n_head: int = 8,
        max_seq_len: int = 128,
    ):
        super().__init__()

        self.vision_encoder = ViTModel.from_pretrained(vision_model_name)
        self.vision_encoder.eval()
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

        vision_dim = self.vision_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_dim, d_model)

        config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=d_model,
            n_layer=n_layer,
            n_head=n_head,
            # Headroom for the prepended vision tokens (for the default
            # 224px / patch-16 ViT: 14*14 patches + CLS = 197).
            n_positions=max_seq_len + 256,
            bos_token_id=None,
            eos_token_id=None,
        )

        self.decoder = GPT2LMHeadModel(config)

        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def train(self, mode: bool = True):
        """Set train/eval mode for the trainable parts; the ViT stays in eval.

        Without this override, ``model.train()`` would flip the frozen encoder
        back into training mode and undo the ``eval()`` call in ``__init__``.
        """
        super().train(mode)
        self.vision_encoder.eval()
        return self

    def forward(
        self,
        images,
        input_ids,
        attention_mask=None,
        labels=None
    ):
        """
        images: (B, 3, 224, 224)
        input_ids: (B, T)
        attention_mask: (B, T), 1 = attend, 0 = padding
        labels: (B, T), expected to be position-aligned with input_ids (the
            causal shift is applied here); -100 marks ignored positions

        Returns the decoder output; when ``labels`` is given, its ``loss``
        field holds the mean cross-entropy over non-ignored targets.
        """

        B = images.size(0)

        # Frozen encoder: no autograd graph is needed for it. The encoder and
        # the trainable projection run in fp32 regardless of any outer
        # autocast context.
        with autocast(enabled=False):
            with torch.no_grad():
                vision_outputs = self.vision_encoder(pixel_values=images.float())
            vision_embeds = self.vision_proj(
                vision_outputs.last_hidden_state
            )
        num_vision = vision_embeds.size(1)

        token_embeds = self.decoder.transformer.wte(input_ids)

        inputs_embeds = torch.cat(
            [vision_embeds, token_embeds], dim=1
        )

        if attention_mask is not None:
            # Vision tokens are always attended to.
            vision_mask = torch.ones(
                (B, num_vision),
                device=attention_mask.device,
                dtype=attention_mask.dtype
            )
            attention_mask = torch.cat(
                [vision_mask, attention_mask], dim=1
            )

        if labels is not None:
            # Vision positions never contribute to the loss.
            vision_label_pad = torch.full(
                (B, num_vision),
                -100,
                device=labels.device,
                dtype=labels.dtype
            )
            labels = torch.cat(
                [vision_label_pad, labels], dim=1
            )

        outputs = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=True
        )

        if labels is not None:
            # Logits at position i are scored against the label at i + 1.
            shift_logits = outputs.logits[:, :-1, :]
            shift_labels = labels[:, 1:]

            # If labels and input_ids differ in length, only the overlapping
            # prefix is scored (aligned inputs make this a no-op).
            min_len = min(
                shift_logits.size(1),
                shift_labels.size(1)
            )

            shift_logits = shift_logits[:, :min_len, :].contiguous()
            shift_labels = shift_labels[:, :min_len].contiguous()

            loss = self.loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            outputs.loss = loss

        return outputs


In [21]:
# Robot samples (image + instruction -> action tokens) and VLM caption
# samples, both encoded with the shared RT-2 tokenizer.
robot_ds = RobotDataset(
    "/content/drive/MyDrive/Colab Notebooks/RT-2/output/train_sequences_fixed.jsonl",
    tokenizer,
)
vlm_ds = VLMDataset(
    "/content/drive/MyDrive/Colab Notebooks/RT-2/vlm_dataset.jsonl",
    tokenizer,
)


In [22]:
class CombinedDataset(torch.utils.data.Dataset):
    """Concatenate the robot and VLM datasets into a single index space.

    Indices ``[0, robot_len)`` map to ``robot_ds``, and
    ``[robot_len, len(self))`` map to ``vlm_ds``. Negative indices count back
    from the end of the combined range, as for a Python sequence.

    Raises:
        IndexError: from ``__getitem__`` when the index is out of range.
    """

    def __init__(self, robot_ds, vlm_ds):
        self.robot_ds = robot_ds
        self.vlm_ds = vlm_ds
        self.robot_len = len(robot_ds)

    def __len__(self):
        return self.robot_len + len(self.vlm_ds)

    def __getitem__(self, idx):
        total = len(self)
        # Normalise first; otherwise a negative index would silently be
        # routed into robot_ds instead of counting from the combined end.
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(
                f"index out of range for CombinedDataset of length {total}"
            )

        if idx < self.robot_len:
            return self.robot_ds[idx]
        return self.vlm_ds[idx - self.robot_len]


In [23]:
# One index space for the mixed sampler: robot samples first, VLM samples after.
combined_ds = CombinedDataset(robot_ds=robot_ds, vlm_ds=vlm_ds)


In [None]:
model = MiniRT2(
    vocab_size=len(tokenizer),
    n_layer=8,
    n_head=8,
    d_model=768,
)

# The decoder is built with vocab_size=len(tokenizer), so its embedding table
# already matches the tokenizer; verify that rather than resizing (a no-op).
assert model.decoder.get_input_embeddings().num_embeddings == len(tokenizer)

# Moving the parent module moves every submodule, the ViT encoder included.
model = model.to(DEVICE)


In [25]:
# Only trainable parameters are optimised (the ViT encoder is frozen).
# Following common GPT-2 practice, 1-D parameters (biases, LayerNorm gains)
# are excluded from weight decay, while weight matrices and embeddings are
# decayed.
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    [
        {"params": [p for p in trainable_params if p.ndim >= 2], "weight_decay": 0.01},
        {"params": [p for p in trainable_params if p.ndim < 2], "weight_decay": 0.0},
    ],
    lr=LR,
)


In [26]:
# Size the cosine schedule to the number of steps the loop will actually run.
# Batches per epoch depend on that epoch's robot ratio, and the loop stops at
# MAX_STEPS. A horizon longer than the real run never finishes decaying the LR.
planned_steps = sum(
    len(MixedBatchSampler(
        robot_len=len(robot_ds),
        vlm_len=len(vlm_ds),
        batch_size=BATCH_SIZE,
        robot_ratio=get_robot_ratio(epoch),
    ))
    for epoch in range(EPOCHS)
)
total_train_steps = min(MAX_STEPS, planned_steps)
print("Scheduled training steps:", total_train_steps)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(WARMUP_STEPS, total_train_steps),
    num_training_steps=total_train_steps,
)


In [None]:
# Loss scaling is only meaningful for CUDA mixed precision; disable it
# explicitly on CPU instead of relying on the runtime warning/fallback.
scaler = GradScaler(enabled=(DEVICE == "cuda"))


In [None]:
MAX_GRAD_NORM = 1.0          # global L2-norm threshold for gradient clipping
USE_AMP = DEVICE == "cuda"   # autocast / loss scaling only on CUDA


def collate_batch(batch):
    """Collate with the tokenizer's padding ids bound (same for every epoch)."""
    return collate_fn(
        batch,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )


model.train()
global_step = 0

for epoch in range(EPOCHS):

    robot_ratio = get_robot_ratio(epoch)
    print(f"Epoch {epoch+1}: robot_ratio={robot_ratio}")

    # A new sampler each epoch picks up the epoch's robot/VLM mix.
    sampler = MixedBatchSampler(
        robot_len=len(robot_ds),
        vlm_len=len(vlm_ds),
        batch_size=BATCH_SIZE,
        robot_ratio=robot_ratio,
    )

    loader = DataLoader(
        combined_ds,
        batch_sampler=sampler,
        collate_fn=collate_batch,
        num_workers=0,
    )

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")

    for batch in pbar:
        if global_step >= MAX_STEPS:
            break

        images = batch["images"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()

        with autocast(enabled=USE_AMP):
            outputs = model(
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss

        scaler.scale(loss).backward()

        # Unscale before clipping so the threshold applies to true gradient
        # magnitudes rather than loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        global_step += 1

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "step": global_step
        })

    # Save everything needed to resume, including LR-schedule and
    # loss-scale state.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "epoch": epoch + 1,
        "step": global_step
    }, f"checkpoint_epoch_{epoch+1}.pt")

    if global_step >= MAX_STEPS:
        break
