In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/u/shuhan/projects/vla')

In [2]:
from src.models.vlas.cont_obs_token_action_cot_unified_token_collision import ContObsTokenActionCOTVLAUnifiedTokenCollision
from src.auto_labeling.highway_env.lane_change import LaneChangeTaskSpecCollision
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint identifier shared by the tokenizer and the causal-LM backbone.
llm_model = 'HuggingFaceTB/SmolLM2-135M-Instruct'
# Alternative checkpoint; swap the assignment above for this one to use it.
# llm_model = "HuggingFaceTB/SmolLM2-360M-Instruct"

# The two loads are independent of each other; both read the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(llm_model)
llm_backbone = AutoModelForCausalLM.from_pretrained(llm_model)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from transformers import AutoModel, AutoConfig

class ScaledDotProductAttention(nn.Module):
    """Multi-head self-attention computed with ``scaled_dot_product_attention``.

    Queries, keys and values are all projected from the same input sequence.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Projection layers
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    @staticmethod
    def _prepare_mask(attention_mask, batch_size, seq_length):
        """Convert ``attention_mask`` into a form the fused kernel accepts.

        Boolean and floating-point masks are already valid inputs for
        ``scaled_dot_product_attention`` and are returned untouched. Integer
        masks (the 0/1 padding masks produced by HuggingFace tokenizers) are
        rejected by that function, so they are cast to bool (non-zero = attend).
        An integer mask of shape ``(batch, seq)`` is additionally reshaped to
        ``(batch, 1, 1, seq)`` so it masks keys and broadcasts over heads and
        query positions.
        """
        if attention_mask is None:
            return None
        if attention_mask.dtype == torch.bool or attention_mask.is_floating_point():
            return attention_mask
        bool_mask = attention_mask.to(torch.bool)
        if bool_mask.dim() == 2 and bool_mask.shape == (batch_size, seq_length):
            bool_mask = bool_mask[:, None, None, :]
        return bool_mask

    def _split_heads(self, projected, batch_size, seq_length):
        """Reshape ``(batch, seq, embed)`` into ``(batch, heads, seq, head_dim)``."""
        return projected.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, embed_dim)``.
            attention_mask: Optional mask. Bool or float masks must be
                broadcastable to ``(batch, heads, seq, seq)`` and are passed
                through as-is; integer padding masks of shape ``(batch, seq)``
                are also accepted (see ``_prepare_mask``).

        Returns:
            Tensor of shape ``(batch, seq, embed_dim)``.
        """
        batch_size, seq_length, embed_dim = hidden_states.size()

        # Linear projections for query, key, and value
        query = self._split_heads(self.q_proj(hidden_states), batch_size, seq_length)
        key = self._split_heads(self.k_proj(hidden_states), batch_size, seq_length)
        value = self._split_heads(self.v_proj(hidden_states), batch_size, seq_length)

        attn_mask = self._prepare_mask(attention_mask, batch_size, seq_length)

        # Compute attention using scaled_dot_product_attention
        attention_output = scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

        # Merge the heads back into the feature dimension and apply the output projection
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)
        return self.out_proj(attention_output)

# Load a HuggingFace model and replace its attention layers
class FlashAttentionTransformer(nn.Module):
    """Wrap a HuggingFace ``AutoModel`` and swap its ``nn.MultiheadAttention``
    submodules for ``ScaledDotProductAttention``.

    Only modules that are instances of ``nn.MultiheadAttention`` are touched.
    Models whose attention is implemented with other classes are left
    unchanged, in which case ``replaced_layers`` is empty.

    NOTE(review): ``nn.MultiheadAttention`` is called with separate
    query/key/value tensors and returns a tuple, whereas
    ``ScaledDotProductAttention.forward`` takes ``(hidden_states,
    attention_mask)`` and returns a tensor. Code inside ``base_model`` that
    calls a swapped layer may need adapting -- TODO confirm per architecture.

    Args:
        model_name: Checkpoint name or path passed to ``AutoModel.from_pretrained``.

    Attributes:
        base_model: The loaded model, with attention layers swapped where possible.
        replaced_layers: Dotted names of the submodules that were replaced.
    """

    def __init__(self, model_name):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name)
        self.replaced_layers = self.replace_attention_layers()

    def replace_attention_layers(self):
        """Swap every ``nn.MultiheadAttention`` in ``base_model``.

        Pretrained projection weights are copied into each replacement so the
        swap does not reset the attention layers to random initialisation.

        Returns:
            List of dotted module names that were replaced (empty if none).
        """
        # Snapshot the matches first: the module tree is mutated in the loop
        # below, and it should not change underneath the traversal.
        targets = [
            (name, module)
            for name, module in self.base_model.named_modules()
            if isinstance(module, nn.MultiheadAttention)
        ]

        replaced = []
        for name, module in targets:
            if module.in_proj_weight is None:
                # Separate q/k/v weights mean key/value inputs have a different
                # width than embed_dim, which the replacement cannot represent.
                print(f'Skipped {name}: separate key/value input sizes are not supported')
                continue

            # Replace with custom ScaledDotProductAttention
            replacement_layer = self._build_replacement(module)
            if name == '':
                # The root module has no parent attribute to overwrite.
                self.base_model = replacement_layer
            else:
                parent_module = self.get_parent_module(name)
                setattr(parent_module, name.split('.')[-1], replacement_layer)
            replaced.append(name)
            print(f'Replaced {name} with FlashAttention')

        if not targets:
            print('No nn.MultiheadAttention modules found; model left unchanged')
        return replaced

    @staticmethod
    def _build_replacement(module):
        """Create a ``ScaledDotProductAttention`` initialised from ``module``'s weights."""
        replacement = ScaledDotProductAttention(module.embed_dim, module.num_heads)
        reference = module.out_proj.weight
        replacement.to(device=reference.device, dtype=reference.dtype)

        projections = (replacement.q_proj, replacement.k_proj, replacement.v_proj)
        with torch.no_grad():
            # The packed in-projection stacks the q, k, v weights along dim 0.
            for proj, weight in zip(projections, module.in_proj_weight.chunk(3, dim=0)):
                proj.weight.copy_(weight)
            if module.in_proj_bias is None:
                # The source layer has no bias; zero the replacement's so it adds nothing.
                for proj in projections:
                    proj.bias.zero_()
            else:
                for proj, bias in zip(projections, module.in_proj_bias.chunk(3, dim=0)):
                    proj.bias.copy_(bias)

            replacement.out_proj.weight.copy_(module.out_proj.weight)
            if module.out_proj.bias is None:
                replacement.out_proj.bias.zero_()
            else:
                replacement.out_proj.bias.copy_(module.out_proj.bias)
        return replacement

    def get_parent_module(self, name):
        """Return the module that directly contains the submodule at dotted path ``name``.

        For a top-level child (no dot in ``name``) this is ``base_model`` itself.
        """
        parent = self.base_model
        for part in name.split('.')[:-1]:
            parent = getattr(parent, part)
        return parent

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``base_model`` and return its output."""
        return self.base_model(*args, **kwargs)


In [5]:
# Wrap the checkpoint named by `llm_model`: the constructor loads it with AutoModel
# and then attempts the attention swap, printing one line per replaced module.
flash_model = FlashAttentionTransformer(llm_model)

In [14]:
# Inspect one decoder layer to see which attention class is in use after wrapping.
flash_model.base_model.layers[3]

LlamaDecoderLayer(
  (self_attn): LlamaSdpaAttention(
    (q_proj): Linear(in_features=576, out_features=576, bias=False)
    (k_proj): Linear(in_features=576, out_features=192, bias=False)
    (v_proj): Linear(in_features=576, out_features=192, bias=False)
    (o_proj): Linear(in_features=576, out_features=576, bias=False)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
    (up_proj): Linear(in_features=576, out_features=1536, bias=False)
    (down_proj): Linear(in_features=1536, out_features=576, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
  (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
)