# Imports

In [42]:
import os
import copy
import json
import logging
import pathlib
import torch
import transformers
import tokenizers

from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List
from torch.nn import functional as F
from torch.nn import ModuleList

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer
from llava.train.train import DataArguments, TrainingArguments, ModelArguments, LazySupervisedDataset, DataCollatorForSupervisedDataset

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token
from peft import LoraConfig, get_peft_model

from PIL import Image

# Utils

In [43]:
'''
Return all of the Linear Layers in the neural network
'''
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']

    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            # print(f"{name} is Linear")
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == -1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    
    return list(lora_module_names)

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                data_path=data_args.data_path,
                                data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

# Define Arguments

In [44]:
# Model Args
model_name = 'liuhaotian/llava-v1.5-7b'
attn_implementation = 'flash_attention_2'
cache_dir = None
torch_dtype = torch.bfloat16
model_max_length=2048

In [45]:
# Model Arguments
model_args = ModelArguments(
    model_name_or_path='liuhaotian/llava-v1.5-7b',
    version='v1',
    freeze_backbone=False,
    tune_mm_mlp_adapter=False,
    vision_tower='openai/clip-vit-large-patch14-336',
    mm_vision_select_layer=-2,
    pretrain_mm_mlp_adapter=None,
    mm_projector_type='mlp2x_gelu',
    mm_use_im_start_end=False,
    mm_use_im_patch_token=False,
    mm_patch_merge_type='flat',
    mm_vision_select_feature='patch'
)

# Data Arguments
data_args = DataArguments(
    data_path='/home/devonperoutky/LLaVA/dataset/augmented/full_dataset.json',
    lazy_preprocess=True, # ?????
    is_multimodal=True,   # What does this do? False originally?????
    image_folder='/home/devonperoutky/LLaVA/dataset/',
    image_aspect_ratio='pad'
)


# Training Arguments
training_args = TrainingArguments(
    group_by_modality_length=True,
    bf16=True,
    output_dir="./checkpoints/llava-v1.5-7b-run-1",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    save_strategy="steps",
    save_steps=50000,
    save_total_limit=1,
    learning_rate=4e-5,
    weight_decay=0,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=1,
    tf32=True,
    model_max_length=2048,
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    report_to="wandb",

    # Lora Args
    lora_enable=True,
    mm_projector_lr=2e-05,
    lora_r = 128,
    lora_bias = "none",
    lora_alpha = 256,
    lora_dropout = .05
)

In [None]:
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    cache_dir=training_args.cache_dir,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype
)
model.to(torch.bfloat16)

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards:  50%|███████████████████████████████████████████▌                                           | 1/2 [00:29<00:29, 29.12s/it]

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)

In [None]:
# Enable gradient checkpointing
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

In [None]:
# Enabling Lora 
lora_config = LoraConfig(
    r=training_args.lora_r,
    lora_alpha=training_args.lora_alpha,
    target_modules=find_all_linear_names(model),
    lora_dropout=training_args.lora_dropout,
    bias=training_args.lora_bias,
    task_type="CASUAL_LM"
)

model = get_peft_model(model, lora_config)

In [None]:
conversation_lib.default_conversation = conversation_lib.conv_templates["v1"]

In [None]:
# What does this do? Initialize Vision Tower aka CLIP?
model.get_model().initialize_vision_modules(
    model_args=model_args,
    fsdp=training_args.fsdp
)
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=torch_dtype, device=training_args.device)

# Set Image Processor and multimodal
data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True

# Prepare model to dataset's preprocessing configuration
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length

model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter # False

# We are NOT tuning
if model_args.tune_mm_mlp_adapter:
    print("Tuning the mm_mlp adapter")
    model.requires_grad_(False)
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True
else:
    print("We are not tuning the mm_mlp adapter")

model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
    print("We are freezing the mm_mlp adapter")
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False
else:
    print("We are not freezing the mm_mlp adapter")        

if training_args.bits in [4, 8]:
    print("LITTLE BITS")
    model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

In [None]:
print(data_args)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

In [None]:
data_module.keys()

In [None]:
trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

In [None]:
trainer.__dict__

In [None]:
trainer.train()

In [None]:
!nvidia-smi

In [40]:
def print_module_parameters(module, parent_name=''):
    """
    Recursively print a module's parameters and their 'requires_grad' status.

    Args:
    - module (nn.Module): The current module to print parameters from.
    - parent_name (str): Name of the parent module, used for hierarchical naming.
    """
    # If the module has children, we're going to call this function on them
    if list(module.children()):
        # Module name (type) for more readable output
        module_name = module.__class__.__name__
        current_name = f"{parent_name}.{module_name}" if parent_name else module_name

        if isinstance(module, ModuleList):
            layers = [(name, child) for name, child in module.named_children()]
            print(f"{len(layers)} {current_name} layers(s)")
            print(f"{'-'*15} Start of Layer Structure {'-'*15}")
            child_name, child = layers[0]
            print_module_parameters(child, "")
            print(f"{'-'*30} End of Layer Structure {'-'*15}")

            # for name, child in module.named_children():
            #     print(name)
            #     print(child)
            #     # Construct child's full name by appending its name to the parent's name
            #     child_name = f"{current_name}.{name}" if name else current_name
            #     print_module_parameters(child, child_name)
        else:
            for name, child in module.named_children():
                # Construct child's full name by appending its name to the parent's name
                child_name = f"{current_name}.{name}" if name else current_name
                print_module_parameters(child, child_name)
    else:
        # Leaf module, print its parameters
        for name, param in module.named_parameters(recurse=False):
            full_name = f"{parent_name}.{name}" if parent_name else name
            print(f"{full_name}: {'Frozen' if not param.requires_grad else 'Trainable'}")


def output_layer(name, module, depth=0):
    # print(name)
    print("=" * 30)
    prefix = " " * depth
    print(f"{prefix} {name}")

    # Print out information about current layer

    # For each module, output_layer
    named_modules = [(n, m) for n, m in module.named_modules()]
    if not named_modules or len(named_modules) < 1:
        return

    # print(named_modules[0])
    # print(named_modules[1])
    first = True
    for n, m in named_modules:
        if first:
            first = False
            continue
        
        print("-"*30)
        print(f"{n} is of type {m.__class__.__name__}")
        print(m.children())
        # output_layer(n, m, depth+1)

        # print("<---->")
        # for n1, m1 in m.named_modules():
        #     print(n1)

In [41]:
print_module_parameters(model.get_model(), "LlavaLlamaModel")
# output_layer("LlavaLlamaModel", model.get_model(), 1)

LlavaLlamaModel.LlavaLlamaModel.embed_tokens.weight: Trainable
32 LlavaLlamaModel.LlavaLlamaModel.layers.ModuleList layers(s)
--------------- Start of Layer Structure ---------------
LlamaDecoderLayer.self_attn.LlamaFlashAttention2.q_proj.weight: Trainable
LlamaDecoderLayer.self_attn.LlamaFlashAttention2.k_proj.weight: Trainable
LlamaDecoderLayer.self_attn.LlamaFlashAttention2.v_proj.weight: Trainable
LlamaDecoderLayer.self_attn.LlamaFlashAttention2.o_proj.weight: Trainable
LlamaDecoderLayer.mlp.LlamaMLP.gate_proj.weight: Trainable
LlamaDecoderLayer.mlp.LlamaMLP.up_proj.weight: Trainable
LlamaDecoderLayer.mlp.LlamaMLP.down_proj.weight: Trainable
LlamaDecoderLayer.input_layernorm.weight: Trainable
LlamaDecoderLayer.post_attention_layernorm.weight: Trainable
------------------------------ End of Layer Structure ---------------
LlavaLlamaModel.LlavaLlamaModel.norm.weight: Trainable
LlavaLlamaModel.LlavaLlamaModel.mm_projector.Sequential.0.weight: Trainable
LlavaLlamaModel.LlavaLlamaModel.

In [None]:
for name, module in model.get_model().named_modules():
    if not name or name == '' or name == 'layers':
        continue

    print("="*30)

    print(module.__class__.__name__)
    print(f"{module.__class__.__name__} aka {name}")

    for n, m in module.named_modules():
        print(f"{m.__class__.__name__} aka {n}")
        print(f"{len([mod for mod in m.modules()]) -1} sub modules")
        print(f"{len([mod for mod in m.parameters()])} parameters")
        # print(m)


    # print(module)
    # print(f"{len([m for m in module.modules()]) -1} sub modules")
    # print([(p.shape, p.requires_grad) for p in module.parameters()])
    
    # for p in module.parameters():
    #     print(p.shape)

print(model.get_model().named_modules())

In [None]:
# model.get_model().mm_projector.parameters()
model.get_model().layers.parameters()
model.get_model().layers.parameters()

In [None]:
for n, m in model.get_model().named_modules():
    if not list(m.children()):
        pass
    else:
        print(n)
        print(type(m))

In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer

vicuna_tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
vicuna = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")

In [None]:
model.embed_tokens()

In [None]:
vicuna

In [None]:
from transformers import CLIPConfig, CLIPModel

# CLIP
configuration = CLIPConfig()
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336")

In [None]:
model