In [1]:
import sys
sys.path.insert(0, "..")
from testbed.models import Idefics
import torch
import config

# Load the checkpoint at config.idefics_9b_path in bfloat16 onto CUDA device index 1.
device = torch.device("cuda:1")
model = Idefics(config.idefics_9b_path, dtype=torch.bfloat16, device=device)


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [2]:
import pytorch_lightning as pl
import exp_settings as setting
import datasets
import os
import sys

sys.path.insert(0, "..")
from testbed.data import prepare_caption_input, prepare_dataloader, prepare_vqa_input
import config


class ICVDataModule(pl.LightningDataModule):
    """Lightning data module that builds few-shot training batches.

    Each collated sample is rendered as
    ``<in-context examples> [SEP] <query> [SEP] <answer> <EOS>``.
    The separator token only marks segment boundaries; it is removed from the
    attention mask before the batch is returned.

    Args:
        lmm: Model wrapper. This class uses ``lmm.processor.tokenizer``,
            ``lmm.model``, ``lmm.apply_prompt_template`` and
            ``lmm.process_input``.
    """

    SUPPORTED_TASKS = ("vqa", "caption")

    def __init__(self, lmm) -> None:
        super().__init__()
        self.lmm = lmm
        tokenizer = self.lmm.processor.tokenizer
        if tokenizer.sep_token is None:
            # The tokenizer has no separator: register one and grow the
            # embedding matrix so the new token id has a row.
            tokenizer.add_special_tokens({"sep_token": "[SEP]"})
            self.lmm.model.resize_token_embeddings(len(tokenizer))
            self.lmm.model.tie_weights()

    @classmethod
    def _unsupported_task_error(cls) -> ValueError:
        """Build the error raised when ``setting.task`` is not a known task."""
        return ValueError(
            f"unsupported task {setting.task!r}; expected one of {cls.SUPPORTED_TASKS}"
        )

    def setup(self, stage: str) -> None:
        """Load the training split for the task selected in ``setting.task``.

        Args:
            stage: Lightning stage name; data is only loaded for ``"fit"`` or ``None``.

        Raises:
            ValueError: If ``setting.task`` is neither ``"vqa"`` nor ``"caption"``.
        """
        if stage != "fit" and stage is not None:
            return

        if setting.task == "vqa":
            self.dataset = datasets.load_dataset(
                os.path.join(config.testbed_dir, "data", "vqav2"),
                split="train",
                data_dir=config.vqav2_dir,
                images_dir=config.coco_dir,
                trust_remote_code=True,
            )
        elif setting.task == "caption":
            # Request the train split explicitly, as the vqa branch does.
            # Without `split`, load_dataset returns a DatasetDict of all
            # splits rather than a single Dataset.
            # NOTE(review): assumes the coco loading script names its training
            # split "train" - confirm.
            self.dataset = datasets.load_dataset(
                os.path.join(config.testbed_dir, "data", "coco"),
                split="train",
                data_dir=config.karpathy_coco_caption_dir,
                images_dir=config.coco_dir,
                trust_remote_code=True,
            )
        else:
            raise self._unsupported_task_error()

    def collate_fn(self, batch):
        """Turn a batch of few-shot samples into model inputs.

        Args:
            batch: Sequence of samples; the last element of each sample is
                treated as the query whose answer/caption is the target.

        Returns:
            The output of ``lmm.process_input`` with separator positions
            cleared in ``attention_mask``.

        Raises:
            ValueError: If ``setting.task`` is neither ``"vqa"`` nor ``"caption"``.
        """
        if setting.task == "vqa":
            context, images = prepare_vqa_input(
                batch, instruction=setting.vqa_instruction
            )
            # we use the first answer as the ground truth
            answers = [item[-1]["answers"][0]["answer"] for item in batch]
        elif setting.task == "caption":
            context, images = prepare_caption_input(
                batch, instruction=setting.caption_instruction
            )
            answers = [item[-1]["sentences_raw"][0] for item in batch]
        else:
            raise self._unsupported_task_error()

        # the last two items (take vqa as an example):
        # [
        #   { "role" : "question"
        #     "content" :  ... },
        #   { "role" : "short answer" }
        # ]
        query = [ctx[-2:] for ctx in context]
        context = [ctx[:-2] for ctx in context]

        context = self.lmm.apply_prompt_template(context)
        query = self.lmm.apply_prompt_template(query)

        # The sep token delimits the in-context examples, the query and the answer.
        tokenizer = self.lmm.processor.tokenizer
        sep_token = tokenizer.sep_token
        sep_token_id = tokenizer.sep_token_id
        eos_token = tokenizer.eos_token
        text_inputs = [
            ctx + sep_token + q + sep_token + ans + eos_token
            for ctx, q, ans in zip(context, query, answers)
        ]

        model_inputs = self.lmm.process_input(text_inputs, images)
        # The sep token is only a marker for locating examples, query and
        # answer; it must be ignored in the forward and backward passes.
        # masked_fill keeps the mask dtype and does not need `torch` in scope.
        model_inputs["attention_mask"] = model_inputs["attention_mask"].masked_fill(
            model_inputs["input_ids"] == sep_token_id, 0
        )
        return model_inputs

    def train_dataloader(self):
        """Return the shuffled training dataloader built by ``prepare_dataloader``."""
        return prepare_dataloader(
            self.dataset,
            setting.batch_size,
            setting.num_shot,
            collate_fn=self.collate_fn,
            num_workers=setting.num_workers,
            shuffle=True,
        )

In [None]:
from functools import partial
import hydra
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from deepspeed.ops.adam import DeepSpeedCPUAdam
from transformers import get_cosine_schedule_with_warmup
import exp_settings as setting
from testbed.models.model_base import HookType


class ICVModel(pl.LightningModule):
    """Lightning module that trains ``icv_encoder`` by distillation.

    Every batch row is laid out as ``ICE [SEP] query [SEP] answer [EOS]``.
    A "teacher" pass runs the model on the full row without gradients; a
    "student" pass runs it with the encoder's hook installed and the
    in-context examples hidden. The loss is the KL divergence between the two
    output distributions plus a weighted language-modelling loss.

    Only parameters of ``icv_encoder`` are handed to the optimizer.

    Args:
        lmm: Model wrapper exposing ``processor.tokenizer``, ``model`` and
            ``register_forward_hook``.
        icv_encoder: Module providing a ``hook`` callable and the trainable
            parameters.
    """

    def __init__(self, lmm, icv_encoder: torch.nn.Module) -> None:
        super().__init__()
        self.lmm = lmm

        self.icv_encoder = icv_encoder

    def forward(self, inputs):
        """Compute the distillation losses for one batch.

        Args:
            inputs: Mapping of model inputs containing at least ``input_ids``
                and ``attention_mask``; rows follow
                ``ICE [SEP] query [SEP] answer [EOS]``.

        Returns:
            ``(loss_dict, icv_encoder_output)`` where ``loss_dict`` has the
            keys ``"kl_loss"``, ``"ce_loss"`` and ``"loss"``.
        """
        sep_token_id = self.lmm.processor.tokenizer.sep_token_id
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Locate, per row, everything strictly after the first [SEP]
        # (query [SEP] answer [EOS]). `seps_before` counts the separators
        # that precede each position.
        is_sep = input_ids == sep_token_id
        seps_before = is_sep.cumsum(dim=1) - is_sep.long()
        after_first_sep = seps_before >= 1
        # Separators also carry attention_mask == 0, so they are excluded
        # from the padding test.
        is_padding = (attention_mask == 0) & ~is_sep
        target_mask = after_first_sep & ~is_padding

        # NOTE(review): the in-context examples are hidden from the student by
        # zeroing their attention-mask entries instead of slicing the tensors,
        # so every other input keeps its shape. Confirm this matches the
        # intended "ICV + query" setup.
        student_inputs = {
            **inputs,
            "attention_mask": attention_mask.masked_fill(~after_first_sep, 0),
        }
        labels = input_ids.masked_fill(~target_mask | is_sep, -100)

        # NOTE(review): assumes calling the encoder without arguments returns
        # an object exposing `.alpha` (read by training_step) - confirm
        # against the encoder implementation.
        icv_encoder_output = self.icv_encoder()

        # step 1. ICV + query [SEP] answer [EOS]  (student, with gradients)
        hooks = self.lmm.register_forward_hook(
            HookType.TEXT_MODEL_LAYER,
            partial(self.icv_encoder.hook, model_inputs=inputs),
        )
        try:
            student_outputs = self.lmm.model(**student_inputs, labels=labels)
        finally:
            # Always detach the hooks so a failed pass cannot leave them
            # installed for the teacher pass or the next step.
            for hook in hooks:
                hook.remove()

        # step 2. ICE [SEP] query [SEP] answer [EOS]  (teacher, no gradients)
        with torch.no_grad():
            teacher_logits = self.lmm.model(**inputs)["logits"]

        # F.kl_div expects log-probabilities as input (student) and, with
        # log_target=False, probabilities as target (teacher).
        student_log_probs = student_outputs["logits"][target_mask].log_softmax(dim=-1)
        teacher_probs = teacher_logits[target_mask].softmax(dim=-1)
        kl_loss = F.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean", log_target=False
        )

        # NOTE(review): the weight is read from exp_settings like the other
        # hyper-parameters; 1.0 is used when it is not defined - confirm.
        hard_loss_weight = getattr(setting, "hard_loss_weight", 1.0)
        ce_loss = student_outputs["loss"]
        loss = kl_loss + hard_loss_weight * ce_loss

        loss_dict = {"kl_loss": kl_loss, "ce_loss": ce_loss, "loss": loss}
        return loss_dict, icv_encoder_output

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and log the losses and alpha values.

        Returns:
            The total loss tensor (``loss_dict["loss"]``).
        """
        # forward takes the whole batch as a single positional argument.
        loss_dict, icv_encoder_output = self(batch)
        self.log_dict(loss_dict, sync_dist=True, prog_bar=True)
        # Flatten instead of squeeze so a single-element alpha is still iterable.
        alpha = icv_encoder_output.alpha.detach().reshape(-1)
        for i, value in enumerate(alpha):
            self.log(f"alpha/alpha-{i}", value)
        return loss_dict["loss"]

    def configure_optimizers(self):
        """Build the optimizer and cosine schedule for the encoder parameters.

        Trainable encoder parameters are split into three groups: tensors with
        two or more dimensions (weight decay), tensors with fewer dimensions
        (no decay), and parameters whose name contains ``"alpha"`` (own
        learning rate ``setting.alpha_lr``).

        Raises:
            ValueError: If ``setting.warmup_step`` is neither int nor float.
        """
        param_dict = {
            pn: p for pn, p in self.icv_encoder.named_parameters() if p.requires_grad
        }

        decay_params = [
            p for n, p in param_dict.items() if p.dim() >= 2 and "alpha" not in n
        ]
        nodecay_params = [
            p for n, p in param_dict.items() if p.dim() < 2 and "alpha" not in n
        ]

        alpha_params = [p for n, p in param_dict.items() if "alpha" in n]

        optim_groups = [
            {"params": decay_params, "weight_decay": setting.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
            {
                "params": alpha_params,
                "weight_decay": setting.weight_decay,
                "lr": setting.alpha_lr,
            },
        ]

        optimizer = DeepSpeedCPUAdam(
            optim_groups,
            lr=setting.icv_lr,
            weight_decay=setting.weight_decay,
        )

        step_batches = self.trainer.estimated_stepping_batches
        warmup_steps = setting.warmup_step
        if isinstance(warmup_steps, float):
            # A float is a fraction of the total steps; the scheduler expects
            # a whole number of steps.
            warm_steps = int(warmup_steps * step_batches)
        elif isinstance(warmup_steps, int):
            warm_steps = warmup_steps
        else:
            raise ValueError(
                f"the warm_steps should be int or float, but got {type(warmup_steps)}"
            )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warm_steps, num_training_steps=step_batches
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def on_save_checkpoint(self, checkpoint):
        """Drop the wrapped model's entries from the checkpoint state dict."""
        # The wrapped model is stored as `self.lmm`, so its state-dict keys
        # carry the "lmm." prefix.
        for name in list(checkpoint["state_dict"].keys()):
            if name.startswith("lmm."):
                checkpoint["state_dict"].pop(name)

In [None]:
import os
import shutil
from pathlib import Path

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    RichModelSummary,
    RichProgressBar,
)
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)
import sys

sys.path.insert(0, "..")
import config
from testbed.models import Idefics
from transformers import IdeficsModel


os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    """Train the ICV model, save the final checkpoint and post-process it.

    Seeds all RNGs, creates ``config.result_dir``, fits ``ICVModel`` on
    ``ICVDataModule`` with a DeepSpeed stage-2 offload trainer, writes
    ``last.ckpt`` into the result directory and calls ``postprocess``.
    """
    pl.seed_everything(436)
    os.makedirs(config.result_dir, exist_ok=True)

    # Checkpoints are written next to the other run artefacts.
    # NOTE(review): no dedicated checkpoint directory is configured - confirm.
    save_path = Path(config.result_dir)

    wb_logger = WandbLogger(
        save_dir=config.result_dir,
        name="trial",
        project="VQAInContextVector",
        log_model=False,
    )
    trainer = pl.Trainer(
        logger=wb_logger,
        callbacks=[
            LearningRateMonitor(),
            RichModelSummary(max_depth=2),
            RichProgressBar(),
        ],
        max_epochs=10,
        strategy="deepspeed_stage_2_offload",
        precision="bf16-mixed",
        gradient_clip_val=1.0,
        log_every_n_steps=25,
        accumulate_grad_batches=8,
        enable_checkpointing=False,
    )

    # NOTE(review): ICVModel.forward reads "logits" and passes `labels` to
    # lmm.model; confirm that the class selected here provides both.
    lmm = Idefics(config.idefics_9b_path, model_class=IdeficsModel)
    # ICVDataModule's constructor takes the model wrapper as its only argument.
    data_module = ICVDataModule(lmm)
    # TODO: pass a real ICV encoder; with None, ICVModel.configure_optimizers
    # cannot enumerate any parameters.
    model = ICVModel(lmm=lmm, icv_encoder=None)

    trainer.fit(model, data_module)
    trainer.save_checkpoint(
        filepath=os.path.join(save_path, "last.ckpt"),
        weights_only=True,
    )
    postprocess(config, save_path)


@rank_zero_only
def postprocess(config, save_path):
    """Convert the DeepSpeed checkpoint into a standalone ICV weight file.

    For a DeepSpeed run, ``save_path/last.ckpt`` is consolidated into an fp32
    state dict, entries whose name contains ``"lmm"`` or ``"interface.model"``
    are dropped, two metadata entries are added, and the result is saved as
    ``save_path/icv_cpk.pth``. The intermediate file and the original
    checkpoint directory are then deleted.

    Args:
        config: Configuration object. ``config.trainer.strategy`` and
            ``config.icv_module.icv_encoder.use_sigmoid`` are used when
            present; missing attributes are tolerated.
        save_path: Directory containing ``last.ckpt``.
    """
    # TODO: Save layer map
    save_path = Path(save_path)
    cpk_save_path = save_path / "last.ckpt"

    strategy = getattr(getattr(config, "trainer", None), "strategy", None)
    if strategy is None:
        # The config carries no trainer.strategy. The cleanup below removes the
        # checkpoint with rmtree, i.e. a DeepSpeed checkpoint is a directory,
        # so use that to detect one.
        uses_deepspeed = cpk_save_path.is_dir()
    else:
        uses_deepspeed = "deepspeed" in strategy
    if not uses_deepspeed:
        return

    output_file = save_path / "lightning_module.bin"
    convert_zero_checkpoint_to_fp32_state_dict(cpk_save_path, output_file)

    # Load on CPU so post-processing does not depend on the training devices.
    checkpoint = torch.load(output_file, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    for name in list(state_dict.keys()):
        if "lmm" in name or "interface.model" in name:
            state_dict.pop(name)

    icv_encoder_cfg = getattr(getattr(config, "icv_module", None), "icv_encoder", None)
    state_dict["use_sigmoid"] = getattr(icv_encoder_cfg, "use_sigmoid", None)
    # Hyper-parameters are only present when the module saved them; store
    # None instead of failing after training has already finished.
    state_dict["lmm_args"] = checkpoint.get("hyper_parameters", {}).get("lmm_config")

    torch.save(state_dict, save_path / "icv_cpk.pth")
    os.remove(output_file)
    shutil.rmtree(cpk_save_path)


# Entry point: start training when this code is executed directly.
if __name__ == "__main__":
    main()