
Commit

Update
jpthu17 committed Apr 12, 2024
1 parent eef46c8 commit f63c247
Showing 3 changed files with 56 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ChatUniVi/model/arch.py
@@ -5,6 +5,7 @@
 from ChatUniVi.constants import *
 from .cluster import CTM, TCBlock
 from collections import OrderedDict
+from .multimodal_projector.builder import build_vision_projector
 
 
 class MetaModel:
@@ -58,8 +59,8 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         else:
             self.vision_tower = vision_tower
 
-        if not hasattr(self, 'mm_projector') or not self.mm_projector.weight.size(0):
-            self.mm_projector = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)
+        if not hasattr(self, 'mm_projector'):
+            self.mm_projector = build_vision_projector(self.config)
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
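The diff is collapsed at this point, just after the checkpoint is read. As context, a hedged sketch of how a pretrained adapter checkpoint can be filtered and restored into a freshly built projector; this is an illustration under stated assumptions, not the commit's hidden lines, and filter_weights, the config sizes, and the checkpoint path are all hypothetical:

import torch
from types import SimpleNamespace

from ChatUniVi.model.multimodal_projector.builder import build_vision_projector

def filter_weights(weights, keyword):
    # Keep only state-dict entries stored under "<keyword>." and strip that prefix.
    return {k.split(keyword + '.')[1]: v
            for k, v in weights.items() if keyword in k}

cfg = SimpleNamespace(mm_projector_type='linear',
                      mm_hidden_size=1024, hidden_size=4096)  # illustrative sizes
mm_projector = build_vision_projector(cfg)

ckpt = torch.load('pretrain_mm_mlp_adapter.bin', map_location='cpu')  # hypothetical path
mm_projector.load_state_dict(filter_weights(ckpt, 'mm_projector'))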
52 changes: 52 additions & 0 deletions ChatUniVi/model/multimodal_projector/builder.py
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    # Pass-through projector: returns the vision features unchanged.
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    # LayerNorm followed by a residual two-layer MLP.
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    print("projector_type:", projector_type)

    # 'mlpNx_gelu' builds an N-layer MLP with GELU activations in between.
    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
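For illustration, a minimal usage sketch of the new builder; this is not part of the commit, and the dimensions below are stand-ins, not the values Chat-UniVi actually configures:

from types import SimpleNamespace

from ChatUniVi.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu',
                      mm_hidden_size=1024, hidden_size=4096)  # illustrative dims

projector = build_vision_projector(cfg)
# 'mlp2x_gelu' matches the ^mlp(\d+)x_gelu$ pattern with depth 2, yielding
# Sequential(Linear(1024->4096), GELU(), Linear(4096->4096)).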
1 change: 1 addition & 0 deletions ChatUniVi/train/train.py
@@ -57,6 +57,7 @@ class ModelArguments:
     mm_use_im_patch_token: bool = field(default=True)
     mm_vision_select_feature: Optional[str] = field(default="patch")
 
+    mm_projector_type: Optional[str] = field(default='linear')
     model_use: str = field(default="BASE")
     mm_use_box_start_end: bool = field(default=False)
 
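The new dataclass field surfaces as a command-line flag when ModelArguments is parsed with transformers.HfArgumentParser, the parser LLaVA-style training scripts typically use; that parsing code sits outside this diff, so treat this as an assumption. A hypothetical invocation selecting the two-layer GELU projector:

python ChatUniVi/train/train.py --mm_projector_type mlp2x_gelu  # plus the usual model/data flags

Leaving the flag unset keeps the previous behavior, since the default 'linear' reproduces the old single nn.Linear projector.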
