In [1]:
import os

# Restrict the visible GPUs BEFORE importing libraries that may import (and
# initialise) torch. The variable only takes effect if it is set before CUDA
# is first initialised in this process, so it has to come first in the program.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

from accelerate import init_empty_weights
from rich import print  # NOTE: replaces the builtin print for the rest of the notebook
from transformers import AutoModelForCausalLM, AutoConfig

model_path = "/models/glm-4v-9b"
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Build the architecture under init_empty_weights so no real weights are
# materialised; the original author reports this is ~10x faster than a plain
# from_config call.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True)
print(model)  # model architecture with empty weights

# Remove the LLM submodules from model.transformer; whatever else lives under
# it (e.g. `vision`, which later cells reference) is left in place.
for llm_module_name in ('embedding', 'rotary_pos_emb', 'encoder', 'output_layer'):
    delattr(model.transformer, llm_module_name)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from accelerate import  infer_auto_device_map
import torch

# Number of visible GPUs. This honours CUDA_VISIBLE_DEVICES, but only the value
# that was in place when CUDA was first initialised in this process.
num_devices = torch.cuda.device_count()
print("CUDA device_count:", num_devices)
if num_devices == 0:
    # Fail with an explicit message instead of a bare KeyError on max_memory[0].
    raise RuntimeError(
        "No CUDA device is visible; check CUDA_VISIBLE_DEVICES and the GPU driver."
    )

# Per-device budget: the first element returned by torch.cuda.mem_get_info(i)
# (free memory on device i at the time of the call).
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(num_devices)}

# Decrease the first card's budget by this factor to reduce its usage.
FIRST_DEVICE_SHRINK_FACTOR = 16
max_memory[0] = max_memory[0] // FIRST_DEVICE_SHRINK_FACTOR
print(max_memory)  # e.g. {0: 3354591232//16, 1: 22564503552, 2: 22564503552}

# Vision transformer layers.
# If a layer class is not listed here, modules inside one layer may be split
# across different cards.
no_split_module_classes = ['TransformerLayer']
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=no_split_module_classes,
    dtype=torch.float16,
    max_memory=max_memory,
)
print(device_map)  # observed: OrderedDict([('transformer', 0)])

torch.cuda.empty_cache()
# Report what PyTorch itself has allocated; the author notes that nvidia-smi
# shows a misleading figure.
print("real CUDA memory_allocated: ", torch.cuda.memory_allocated())

# Groups of module/parameter names that are intended to share one device.
same_device_keys = [('transformer.vision.linear_proj',
                     'transformer.vision.boi',
                     'transformer.vision.eoi')]
print(type(model.transformer.vision.boi))  # torch.nn.Parameter

# TODO: placeholder — the groups above are not applied to device_map yet.
for keys in same_device_keys:
    pass
