In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Installing collected packages: xxhash, fsspec, dill, multiprocess, datasets
[2K  Attemptin

In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image, UnidentifiedImageError
import requests
from transformers import AutoProcessor
from utils import load_hf_model  # your local loader for PaliGemma models

# ========================
# CONFIG
# ========================
# --- model locations -------------------------------------------------------
# Teacher checkpoint.
TARGET_PATH = "/home/jupyter/Paligemma/google/paligemma-3b-pt-896"
# Student checkpoint (layer-pruned draft).
DRAFT_PATH = "/home/jupyter/Paligemma/google/draft/16"
# Prefix of the output file; ".pt" is appended when the student is saved.
SAVE_PATH = "/home/jupyter/Paligemma/google/draft_distilled"

# --- runtime ---------------------------------------------------------------
DEVICE = "cuda"

# --- optimisation ----------------------------------------------------------
BATCH_SIZE = 2
LR = 1e-5
EPOCHS = 1

# --- tokenisation ----------------------------------------------------------
# Passed to the processor as `max_length`.
MAX_LEN = 128

# --- distillation ----------------------------------------------------------
# Temperature dividing both teacher and student logits in the KL term.
TEMPERATURE = 2.0
# Weight of the cross-entropy term; the KL term is weighted by (1 - ALPHA).
ALPHA = 0.5

# ========================
# LOAD MODELS + PROCESSOR
# ========================
# Teacher: put in eval mode and frozen, so it only ever supplies target logits.
print("Loading teacher (target)...")
teacher, _ = load_hf_model(TARGET_PATH, DEVICE)
teacher = teacher.to(DEVICE)
teacher = teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# Student: the draft model being trained, kept in train mode.
print("Loading student (draft)...")
student, _ = load_hf_model(DRAFT_PATH, DEVICE)
student = student.to(DEVICE)
student = student.train()

# The processor is loaded from the teacher checkpoint and handles images + text.
processor = AutoProcessor.from_pretrained(TARGET_PATH)

# ========================
# LOAD DATASET
# ========================
print("Loading image_captions_x dataset (1%)...")
full_dataset = load_dataset(
    "kamruzzaman-asif/image_captions_x",
    split="laion[:1%]",
)

# Keep one row in a hundred of the slice loaded above (itself already the first 1%
# of the split), but never fewer than a single row.
subset_size = len(full_dataset) // 100
if subset_size < 1:
    subset_size = 1

dataset = full_dataset.select(range(subset_size))
print(f"Subset size: {len(dataset)}")

# ========================
# COLLATE FUNCTION
# ========================
def collate_fn(batch):
    """Download every example's image and encode the usable (image, caption) pairs.

    Args:
        batch: iterable of dataset rows; each row is indexed with "url" and "caption".

    Returns:
        A dict of processor outputs moved to ``DEVICE``, or ``None`` when no image in
        the batch could be fetched and decoded (the training loop skips such batches).
    """
    import io  # only needed here, to hand the downloaded bytes to PIL

    images, captions = [], []

    for ex in batch:
        url, caption = ex["url"], ex["caption"]
        try:
            # The context manager releases the connection even if decoding fails.
            with requests.get(url, timeout=5) as resp:
                # Reject 4xx/5xx explicitly instead of trying to decode an error page.
                resp.raise_for_status()
                # `.content` (unlike `.raw`) is fully read and content-decoded.
                payload = resp.content
            img = Image.open(io.BytesIO(payload)).convert("RGB")
        except (UnidentifiedImageError, requests.RequestException, OSError):
            # OSError also covers truncated/corrupt files that fail during decoding.
            continue  # skip broken images
        images.append(img)
        captions.append(caption)

    if len(images) == 0:
        return None  # skip batch if all images failed

    enc = processor(
        images=images,
        text=captions,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LEN
    )
    return {k: v.to(DEVICE) for k, v in enc.items()}

# Batches are assembled by `collate_fn`, which may return None for a batch.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# ------------------------
# Loss + optimizer
# ------------------------
# Only the student's parameters are handed to the optimizer.
optimizer = AdamW(params=student.parameters(), lr=LR)

def distillation_loss(student_logits, teacher_logits, labels, temperature=TEMPERATURE, alpha=ALPHA):
    """Blend hard-label cross-entropy with temperature-scaled KL distillation.

    Args:
        student_logits: student output; second-to-last dim is the sequence position,
            last dim is the vocabulary.
        teacher_logits: teacher output with the same shape as ``student_logits``.
        labels: token ids aligned with the model inputs (callers pass ``input_ids``).
            Entries equal to -100 are excluded from both loss terms.
        temperature: value both sets of logits are divided by before the KL term.
        alpha: weight of the cross-entropy term; the KL term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * temperature**2 * KL`` where both
        terms are averaged over the supervised token positions.
    """
    F = torch.nn.functional
    vocab_size = student_logits.size(-1)

    # For a causal LM the logits at position t score token t + 1, so the last
    # prediction and the first label have no partner and are dropped.
    s_logits = student_logits[..., :-1, :].reshape(-1, vocab_size)
    t_logits = teacher_logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)

    # Apply the same -100 mask to both terms so they average over the same tokens.
    keep = targets != -100
    if not keep.any():
        # Nothing to supervise: return a zero that is still attached to the graph
        # instead of the NaN a mean over zero elements would produce.
        return student_logits.sum() * 0.0
    s_logits, t_logits, targets = s_logits[keep], t_logits[keep], targets[keep]

    # CE loss (mean over kept tokens)
    ce_loss = F.cross_entropy(s_logits, targets)

    # KL divergence. The inputs are 2-D (tokens, vocab), so "batchmean" divides by the
    # number of tokens; on the original 3-D tensors it divided by batch size only and
    # the term grew with sequence length.
    log_p = F.log_softmax(s_logits / temperature, dim=-1)
    q = F.softmax(t_logits / temperature, dim=-1)
    kl_loss = F.kl_div(log_p, q, reduction="batchmean") * (temperature ** 2)
    return alpha * ce_loss + (1 - alpha) * kl_loss

# ========================
# TRAIN LOOP
# ========================
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_steps = 0  # batches that actually produced a loss (skipped ones excluded)
    for step, batch in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}")):
        if batch is None:
            continue  # skip batch if all images failed

        optimizer.zero_grad()

        # Teacher forward (frozen)
        with torch.no_grad():
            teacher_out = teacher(**batch)
            teacher_logits = teacher_out['logits']

        # Student forward
        student_out = student(**batch)
        student_logits = student_out['logits']

        # Hard labels are the input tokens; padded positions are set to -100 so the
        # cross-entropy (ignore_index=-100) does not train on padding.
        labels = batch["input_ids"]
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            labels = labels.masked_fill(attention_mask == 0, -100)

        # Compute distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        num_steps += 1
        if step % 10 == 0:
            print(f"Step {step}, Loss: {step_loss:.4f}")

    # Average over the steps that ran; len(loader) would also count skipped batches
    # and would divide by zero for an empty loader.
    avg_loss = total_loss / max(num_steps, 1)
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

# ========================
# SAVE STUDENT
# ========================
# Persist only the student's weights (a plain state dict), next to SAVE_PATH.
student_weights = student.state_dict()
torch.save(student_weights, f"{SAVE_PATH}.pt")
print(f"✅ Distilled student saved at {SAVE_PATH}.pt")


FlashAttention is available and will be used.
Loading teacher (target)...
Loading student (draft)...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading image_captions_x dataset (1%)...


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Subset size: 4000


Epoch 1:   0%|          | 0/2000 [00:00<?, ?it/s]You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Epoch 1:   0%|          | 0/2000 [00:02<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 258.00 MiB. GPU 

In [2]:
# train_ddp.py
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from datasets import load_dataset
from transformers import AutoProcessor
from PIL import Image, UnidentifiedImageError
import requests
from utils import load_hf_model  # your loader for PaliGemma

# ========================
# CONFIG
# ========================
# --- model locations -------------------------------------------------------
# Checkpoint loaded as the teacher.
TARGET_PATH = "/home/jupyter/Paligemma/google/paligemma-3b-pt-896"
# Checkpoint loaded as the student.
DRAFT_PATH = "/home/jupyter/Paligemma/google/draft/16"
# Prefix of the output file; ".pt" is appended when rank 0 saves the student.
SAVE_PATH = "/home/jupyter/Paligemma/google/draft_distilled_ddp"

# --- optimisation (BATCH_SIZE is per process) -------------------------------
BATCH_SIZE = 2
LR = 1e-5
EPOCHS = 1

# --- tokenisation ----------------------------------------------------------
# Passed to the processor as `max_length`.
MAX_LEN = 128

# --- distillation ----------------------------------------------------------
# Temperature dividing both teacher and student logits in the KL term.
TEMPERATURE = 2.0
# Weight of the cross-entropy term; the KL term is weighted by (1 - ALPHA).
ALPHA = 0.5

# ========================
# DISTILLATION LOSS
# ========================
def distillation_loss(student_logits, teacher_logits, labels, temperature=TEMPERATURE, alpha=ALPHA):
    """Blend hard-label cross-entropy with temperature-scaled KL distillation.

    Args:
        student_logits: student output; second-to-last dim is the sequence position,
            last dim is the vocabulary.
        teacher_logits: teacher output with the same shape as ``student_logits``.
        labels: token ids aligned with the model inputs (callers pass ``input_ids``).
            Entries equal to -100 are excluded from both loss terms.
        temperature: value both sets of logits are divided by before the KL term.
        alpha: weight of the cross-entropy term; the KL term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * temperature**2 * KL`` where both
        terms are averaged over the supervised token positions.
    """
    F = torch.nn.functional
    vocab_size = student_logits.size(-1)

    # For a causal LM the logits at position t score token t + 1, so the last
    # prediction and the first label have no partner and are dropped.
    s_logits = student_logits[..., :-1, :].reshape(-1, vocab_size)
    t_logits = teacher_logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)

    # Apply the same -100 mask to both terms so they average over the same tokens.
    keep = targets != -100
    if not keep.any():
        # Nothing to supervise: return a zero that is still attached to the graph
        # instead of the NaN a mean over zero elements would produce.
        return student_logits.sum() * 0.0
    s_logits, t_logits, targets = s_logits[keep], t_logits[keep], targets[keep]

    ce_loss = F.cross_entropy(s_logits, targets)

    # The inputs are 2-D (tokens, vocab), so "batchmean" divides by the number of
    # tokens; on the original 3-D tensors it divided by batch size only and the
    # term grew with sequence length.
    log_p = F.log_softmax(s_logits / temperature, dim=-1)
    q = F.softmax(t_logits / temperature, dim=-1)
    kl_loss = F.kl_div(log_p, q, reduction="batchmean") * (temperature ** 2)
    return alpha * ce_loss + (1 - alpha) * kl_loss

# ========================
# DATASET & COLLATE
# ========================
class ImageCaptionDataset(Dataset):
    """Map-style dataset that serves raw rows of the image-caption subset.

    Rows are returned untouched; downloading images and tokenising captions is left
    to the collate function.
    """

    def __init__(self, split="laion[:1%]", max_len=MAX_LEN):
        """Load `split` and keep its first 1% of rows (at least one row).

        Args:
            split: split expression forwarded to `load_dataset`.
            max_len: stored on the instance as `self.max_len`.
        """
        loaded = load_dataset("kamruzzaman-asif/image_captions_x", split=split)

        keep_rows = len(loaded) // 100
        if keep_rows < 1:
            keep_rows = 1

        self.dataset = loaded.select(range(keep_rows))
        self.processor = AutoProcessor.from_pretrained(TARGET_PATH)
        self.max_len = max_len

    def __len__(self):
        """Number of rows kept in the subset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return row `idx` of the underlying dataset exactly as stored."""
        return self.dataset[idx]

_PROCESSOR = None  # per-process cache filled by _get_processor()


def _get_processor():
    """Return this process's processor, loading it from TARGET_PATH on first use."""
    global _PROCESSOR
    if _PROCESSOR is None:
        _PROCESSOR = AutoProcessor.from_pretrained(TARGET_PATH)
    return _PROCESSOR


def collate_fn(batch):
    """Download every example's image and encode the usable (image, caption) pairs.

    Args:
        batch: iterable of dataset rows; each row is indexed with "url" and "caption".

    Returns:
        A dict of processor outputs moved to the current CUDA device, or ``None``
        when no image in the batch could be fetched and decoded.
    """
    import io  # only needed here, to hand the downloaded bytes to PIL

    images, captions = [], []

    for ex in batch:
        url, caption = ex["url"], ex["caption"]
        try:
            # The context manager releases the connection even if decoding fails.
            with requests.get(url, timeout=5) as resp:
                # Reject 4xx/5xx explicitly instead of trying to decode an error page.
                resp.raise_for_status()
                # `.content` (unlike `.raw`) is fully read and content-decoded.
                payload = resp.content
            img = Image.open(io.BytesIO(payload)).convert("RGB")
        except (UnidentifiedImageError, requests.RequestException, OSError):
            # OSError also covers truncated/corrupt files that fail during decoding.
            continue
        images.append(img)
        captions.append(caption)

    if len(images) == 0:
        return None

    # This script never defines a module-level `processor`, so a bare reference
    # raises NameError inside spawned workers; load it lazily instead.
    enc = _get_processor()(
        images=images,
        text=captions,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LEN
    )
    return {k: v.to(torch.cuda.current_device()) for k, v in enc.items()}

# ========================
# TRAIN FUNCTION
# ========================
def train_ddp(rank, world_size):
    """Run one distributed-data-parallel distillation worker.

    Each process trains a DDP-wrapped student on its own shard of the dataset against
    a frozen teacher held on the same GPU. Rank 0 prints progress and saves the
    student's state dict to ``f"{SAVE_PATH}.pt"`` at the end.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # The default env:// rendezvous needs these; keep whatever a launcher already set.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Initialize process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        # Every rank holds its own frozen teacher. DistributedSampler gives each rank
        # a different shard, so targets must come from the rank's own batch;
        # broadcasting rank 0's logits would pair them with unrelated inputs (and
        # with mismatching shapes, since batches are padded independently).
        # This costs one teacher copy per GPU.
        teacher, _ = load_hf_model(TARGET_PATH, device)
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False

        # Load student
        student, _ = load_hf_model(DRAFT_PATH, device)
        student.train()
        student = DDP(student, device_ids=[rank], output_device=rank)

        # Dataset & Dataloader
        dataset = ImageCaptionDataset()
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, collate_fn=collate_fn)

        optimizer = AdamW(student.parameters(), lr=LR)

        for epoch in range(EPOCHS):
            sampler.set_epoch(epoch)
            total_loss = 0.0
            num_steps = 0  # batches that actually produced a loss on this rank

            # Ranks can run different numbers of steps because batches whose images
            # all failed are skipped; join() keeps the gradient all-reduces of the
            # remaining ranks from blocking on a rank that finished early.
            with student.join():
                for step, batch in enumerate(loader):
                    if batch is None:
                        continue

                    optimizer.zero_grad()

                    # Teacher forward (frozen, local to this rank)
                    with torch.no_grad():
                        teacher_logits = teacher(**batch)['logits']

                    # Student forward
                    student_logits = student(**batch)['logits']

                    # Padded positions are set to -100 so the cross-entropy
                    # (ignore_index=-100) does not train on padding.
                    labels = batch["input_ids"]
                    attention_mask = batch.get("attention_mask")
                    if attention_mask is not None:
                        labels = labels.masked_fill(attention_mask == 0, -100)

                    loss = distillation_loss(student_logits, teacher_logits, labels)
                    loss.backward()
                    optimizer.step()

                    step_loss = loss.item()
                    total_loss += step_loss
                    num_steps += 1
                    if step % 10 == 0 and rank == 0:
                        print(f"Rank {rank}, Step {step}, Loss: {step_loss:.4f}")

            # Local (per-rank) average over the steps that ran; len(loader) would
            # also count skipped batches.
            avg_loss = total_loss / max(num_steps, 1)
            if rank == 0:
                print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

        if rank == 0:
            torch.save(student.module.state_dict(), f"{SAVE_PATH}.pt")
            print(f"✅ Distilled student saved at {SAVE_PATH}.pt")
    finally:
        # Always tear the group down, including when a worker raises.
        dist.destroy_process_group()

# ========================
# SPAWN MULTI-PROCESS
# ========================
if __name__ == "__main__":
    # mp.spawn starts fresh interpreters that look `train_ddp` up by module path, so
    # this code must live in an importable file (e.g. train_ddp.py, run with
    # `python train_ddp.py`). Run from a notebook cell, the workers fail with
    # "Can't get attribute 'train_ddp' on <module '__main__'>".
    world_size = torch.cuda.device_count()  # one process per visible GPU
    if world_size < 1:
        # With nprocs=0 mp.spawn would return without training anything.
        raise RuntimeError("train_ddp needs at least one CUDA device, but none is visible.")
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_ddp' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
    exitcode = _main(fd, parent_sentinel)
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train_ddp' on <module '__main__' (built-in)>
    exitcode = _main(fd

ProcessExitedException: process 6 terminated with exit code 1

In [2]:
from gemma_flash import PaliGemmaForConditionalGeneration
from transformers import AutoProcessor
import torch

# Goal of this cell: turn the raw state-dict checkpoint at `pt_path` into a
# Hugging Face-style model directory at `save_dir`.
# NOTE(review): per the output recorded below this cell, the
# `PaliGemmaForConditionalGeneration` class imported from `gemma_flash` has no
# `from_pretrained`, so execution stops at step 1 and nothing below it runs.
pt_path = "/home/jupyter/Paligemma/google/draft_distilled_ddp.pt"
save_dir = "/home/jupyter/Paligemma/google/draft_distilled_ddp"
teacher_path = "/home/jupyter/Paligemma/google/paligemma-3b-pt-896"

# 1. Build the model object from the teacher checkpoint directory.
# NOTE(review): the weights loaded in step 2 come from the distilled student; if the
# student's architecture differs from the teacher's, the strict load below is
# expected to reject the state dict — confirm which config should be used here.
model = PaliGemmaForConditionalGeneration.from_pretrained(teacher_path)

# 2. Load the trained weights (state dict) onto CPU and copy them into the model;
#    strict=True requires the checkpoint keys to match the model's keys.
state_dict = torch.load(pt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)

# 3. Save in Hugging Face format
model.save_pretrained(save_dir)

# 4. Save the processor (tokenizer + image processor) taken from the teacher
#    checkpoint next to the model files.
processor = AutoProcessor.from_pretrained(teacher_path)
processor.save_pretrained(save_dir)

print(f"✅ Converted .pt checkpoint into Hugging Face directory: {save_dir}")


FlashAttention is available and will be used.


AttributeError: type object 'PaliGemmaForConditionalGeneration' has no attribute 'from_pretrained'

In [3]:
import os
import shutil

import torch
from transformers import AutoProcessor

# Assemble a model folder by hand: config files + weights + processor files.
pt_path = "/home/jupyter/Paligemma/google/draft_distilled_ddp.pt"
save_dir = "/home/jupyter/Paligemma/google/draft_distilled_ddp"
teacher_path = "/home/jupyter/Paligemma/google/paligemma-3b-pt-896"

# shutil.copy treats a missing destination as a *file* name, which would turn
# `save_dir` into a copy of config.json and break every later step, so make sure
# the directory exists first.
os.makedirs(save_dir, exist_ok=True)

# 1. Copy the teacher's config files into the new dir.
# NOTE(review): these describe the teacher; if the distilled student has a different
# architecture (e.g. fewer layers), config.json should come from the student — confirm.
shutil.copy(f"{teacher_path}/config.json", save_dir)
shutil.copy(f"{teacher_path}/generation_config.json", save_dir)

# 2. Re-save the state dict under the file name Hugging Face loaders look for.
state_dict = torch.load(pt_path, map_location="cpu")
torch.save(state_dict, f"{save_dir}/pytorch_model.bin")

# 3. Copy tokenizer + processor from teacher
processor = AutoProcessor.from_pretrained(teacher_path)
processor.save_pretrained(save_dir)

print(f"✅ Draft model folder created at {save_dir}")


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


✅ Draft model folder created at /home/jupyter/Paligemma/google/draft_distilled_ddp
