In [2]:
# Make the local `src` directory importable so the project modules can be imported.
import sys

SRC_DIR = 'src'
# Guard the append: re-running this cell must not pile up duplicate sys.path entries.
if SRC_DIR not in sys.path:
	sys.path.append(SRC_DIR)

In [3]:
from siglip.model import SiglipVisionModel
from siglip.config import SiglipVisionConfig

from paligemma.config import PaliGemmaConfig
from paligemma.model import PaliGemmaForConditionalGeneration
import torch
import einops
from torchinfo import summary

In [6]:
# Build a SigLIP vision configuration. Only the fields listed here are set in this
# notebook; how SiglipVisionConfig fills in anything else is not visible from here.
config_siglip = SiglipVisionConfig(
	model_name="siglip-vision",
	hidden_size=512,
	intermediate_size=2048,
	num_hidden_layers=12,
	num_attention_heads=8,
)

# Instantiate the vision model from that configuration.
model = SiglipVisionModel(config_siglip)
# Print the model architecture (torchinfo summary, nested modules shown down to depth 5).
# As the last expression of the cell, the summary is rendered as the cell output.
summary(model, depth=5)

Layer (type:depth-idx)                             Param #
SiglipVisionModel                                  --
├─SiglipVisionTransformer: 1-1                     --
│    └─SiglipVisionEmbeddings: 2-1                 --
│    │    └─Conv2d: 3-1                            393,728
│    │    └─Embedding: 3-2                         100,352
│    └─SiglipEncoder: 2-2                          --
│    │    └─ModuleList: 3-3                        --
│    │    │    └─SiglipEncoderLayer: 4-1           --
│    │    │    │    └─SiglipAttention: 5-1         1,050,624
│    │    │    │    └─SiglipMLP: 5-2               2,099,712
│    │    │    │    └─LayerNorm: 5-3               1,024
│    │    │    │    └─LayerNorm: 5-4               1,024
│    │    │    └─SiglipEncoderLayer: 4-2           --
│    │    │    │    └─SiglipAttention: 5-5         1,050,624
│    │    │    │    └─SiglipMLP: 5-6               2,099,712
│    │    │    │    └─LayerNorm: 5-7               1,024
│    │    │    │    └─LayerNor

In [7]:
# Sub-configurations, written as plain dicts and handed to PaliGemmaConfig below.
# NOTE(review): "vocab_size" and "max_position_embeddings" in vision_config look
# text-specific — confirm the vision side actually reads them.
vision_config = {
	"hidden_size": 512,
	"intermediate_size": 2048,
	"num_hidden_layers": 4,
	"num_attention_heads": 8,
	"layer_norm_eps": 1e-12,
	"max_position_embeddings": 512,
	"vocab_size": 30522,
	"image_size": 224,
	"patch_size": 16,
}
# Same keys as above minus the image fields, plus "num_key_value_heads".
text_config = {
	"hidden_size": 512,
	"intermediate_size": 2048,
	"num_hidden_layers": 4,
	"num_attention_heads": 8,
	"layer_norm_eps": 1e-12,
	"max_position_embeddings": 512,
	"vocab_size": 30522,
	"num_key_value_heads": 8,
}
# Top-level configuration: repeats the common sizes and embeds both sub-configs.
config_pali = PaliGemmaConfig(
	model_name="paligemma",
	hidden_size=512,
	intermediate_size=2048,
	num_hidden_layers=4,
	num_attention_heads=8,
	vision_config=vision_config,
	text_config=text_config,
)
# Rebinds `model`, which previously held the standalone SiglipVisionModel.
model = PaliGemmaForConditionalGeneration(config_pali)
# Print the model architecture (torchinfo summary, nested modules shown down to depth 5).
summary(model, depth=5)

Layer (type:depth-idx)                                       Param #
PaliGemmaForConditionalGeneration                            --
├─SiglipVisionModel: 1-1                                     --
│    └─SiglipVisionTransformer: 2-1                          --
│    │    └─SiglipVisionEmbeddings: 3-1                      --
│    │    │    └─Conv2d: 4-1                                 393,728
│    │    │    └─Embedding: 4-2                              100,352
│    │    └─SiglipEncoder: 3-2                               --
│    │    │    └─ModuleList: 4-3                             --
│    │    │    │    └─SiglipEncoderLayer: 5-1                3,152,384
│    │    │    │    └─SiglipEncoderLayer: 5-2                3,152,384
│    │    │    │    └─SiglipEncoderLayer: 5-3                3,152,384
│    │    │    │    └─SiglipEncoderLayer: 5-4                3,152,384
│    │    └─LayerNorm: 3-3                                   1,024
├─PaliGemmaMultiModalProjector: 1-2                       

In [None]:
# Check: turn the 2-D patch grid into a sequence axis using plain torch ops.
B, E, H, W = 2, 512, 7, 7
patch_embeds = torch.randn(B, E, H, W)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
#   -> flatten from dim 2: [Batch_Size, Embed_Dim, Num_Patches]   (Num_Patches = Num_Patches_H * Num_Patches_W)
#   -> transpose(1, 2):    [Batch_Size, Num_Patches, Embed_Dim]
embeddings = patch_embeds.flatten(start_dim=2).transpose(1, 2)

embeddings.shape

torch.Size([2, 512, 49])


torch.Size([2, 49, 512])

In [20]:
# Same reshaping as flatten(2) + transpose(1, 2), expressed as one einops pattern:
# the h and w axes are merged into a single axis and e is moved to the last position.
embeddings_2 = einops.rearrange(patch_embeds, "b e h w -> b (h w) e")

In [21]:
# Element-wise comparison of the torch result and the einops result, reduced with .all().
(embeddings == embeddings_2).all() # True

tensor(True)

In [24]:
from einops import repeat
from time import time
# Comparison of the two implementations:
def __repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
	"""
	Repeat every key/value head n_rep times using expand + reshape.
	Args:
		hidden_states (torch.Tensor): 4-D tensor [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim].
		n_rep (int): Number of copies of each head.
	Returns:
		torch.Tensor: [Batch_Size, Num_Heads_KV * n_rep, Seq_Len, Head_Dim]; the copies of a
		given head are adjacent along dim 1. When n_rep == 1 the input itself is returned.
	"""
	# Unpacked before the shortcut so a non 4-D input is rejected in every case.
	bsz, kv_heads, length, dim = hidden_states.size()
	if n_rep == 1:
		return hidden_states

	# Insert a singleton axis after the head axis and broadcast it to n_rep ...
	expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, length, dim)
	# ... then fold that axis into the head axis.
	return expanded.reshape(bsz, kv_heads * n_rep, length, dim)

def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
	"""
	Repeat every key/value head n_rep times using einops.repeat.
	Args:
		hidden_states (torch.Tensor): Keys or values. [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
		n_rep (int): Number of copies of each head.
	Returns:
		torch.Tensor: Repeated keys or values. [Batch_Size, Num_Heads_KV * n_rep, Seq_Len, Head_Dim]
		When n_rep == 1 the input itself is returned.
	"""
	# '(h r)' places the r copies of each head next to each other along the head axis.
	return hidden_states if n_rep == 1 else repeat(hidden_states, 'b h s d -> b (h r) s d', r=n_rep)


# Equivalence check: both implementations must agree element-wise on a random input.
n_rep = 4
B, H, S, D = 2, 8, 7, 64
hidden_states = torch.randn(B, H, S, D)
# Run the expand/reshape version first, then the einops version, on the same tensor.
repeated_states, repeated_states_2 = (
	fn(hidden_states, n_rep) for fn in (__repeat_kv, _repeat_kv)
)

# Exact element-wise comparison reduced to a single boolean tensor.
print(torch.eq(repeated_states, repeated_states_2).all())  # True



tensor(True)


In [15]:
# Speed test: time both implementations on the same input.
from time import perf_counter  # monotonic, high-resolution clock intended for measuring durations


def _seconds_for(fn, n_calls: int, *args) -> float:
	"""Return the seconds spent calling ``fn(*args)`` ``n_calls`` times in a row."""
	start = perf_counter()
	for _ in range(n_calls):
		fn(*args)
	return perf_counter() - start


n_time = 100_000
# One shared helper so both candidates are measured by exactly the same loop.
print(f"__repeat_kv : {_seconds_for(__repeat_kv, n_time, hidden_states, n_rep)} sec")
print(f"repeat_kv avec einops : {_seconds_for(_repeat_kv, n_time, hidden_states, n_rep)} sec")


__repeat_kv : 6.073018312454224 sec
repeat_kv avec einops : 4.460999011993408 sec


In [45]:
# test de rotate
def rotate_half(x: torch.Tensor) -> torch.Tensor:
	"""
	Swap the two halves of the last dimension, negating the half that moves to the front.
	With the last dimension split as [first | second], the result is [-second | first],
	e.g. [1, 2, 3, 4] -> [-3, -4, 1, 2].
	Both halves are printed as they are extracted (inspection output for this notebook).
	Args:
		x (torch.Tensor): Input tensor; only its last dimension is split.
			In this notebook: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
	Returns:
		torch.Tensor: Tensor of the same shape as x with the halves swapped as described.
	"""
	half = x.shape[-1] // 2
	first = x[..., :half]  # leading half of the last dimension
	print("x1, ", first)
	second = x[..., half:]  # trailing half (one element longer when the size is odd)
	print("x2, ", second)
	return torch.cat([-second, first], dim=-1)

# Demo on a small integer tensor so each half is easy to follow by eye.
# Layout: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim] = [1, 2, 2, 4], values 1..16
hidden_states = torch.arange(1, 17).reshape(1, 2, 2, 4)
print(hidden_states)
# Apply rotate_half (it prints both halves as it goes).
rotated_states = rotate_half(hidden_states)
# Bare last expression so the result is displayed as the cell output.
rotated_states


tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8]],

         [[ 9, 10, 11, 12],
          [13, 14, 15, 16]]]])
x1,  tensor([[[[ 1,  2],
          [ 5,  6]],

         [[ 9, 10],
          [13, 14]]]])
x2,  tensor([[[[ 3,  4],
          [ 7,  8]],

         [[11, 12],
          [15, 16]]]])


tensor([[[[ -3,  -4,   1,   2],
          [ -7,  -8,   5,   6]],

         [[-11, -12,   9,  10],
          [-15, -16,  13,  14]]]])

In [52]:
# Check: the frequency product can be written three equivalent ways:
#   1. batched matmul followed by transpose(1, 2)
#   2. einsum that emits the transposed layout directly ('bdr,brs->bsd')
#   3. einsum in matmul layout ('bdr,brs->bds') followed by transpose(1, 2)

# [Batch_Size, Head_Dim // 2, 1]
inv_freq_expanded = torch.randn(3, 4, 1)
# [Batch_Size, 1, Seq_Len]
position_ids_expanded = torch.randn(3, 1, 5)

# Each result below has shape [Batch_Size, Seq_Len, Head_Dim // 2].
freqs = torch.matmul(inv_freq_expanded, position_ids_expanded).transpose(1, 2)
freqs_2 = torch.einsum('bdr,brs->bsd', inv_freq_expanded, position_ids_expanded)
freqs_3 = torch.einsum('bdr,brs->bds', inv_freq_expanded, position_ids_expanded).transpose(1, 2)

# Pairwise exact comparison, in the order (1 vs 2), (1 vs 3), (2 vs 3).
for left, right in ((freqs, freqs_2), (freqs, freqs_3), (freqs_2, freqs_3)):
	print((left == right).all())  # True

tensor(True)
tensor(True)
tensor(True)
