-
Notifications
You must be signed in to change notification settings - Fork 2
/
modeling_git_opt.py
534 lines (448 loc) · 21.8 KB
/
modeling_git_opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
"""PyTorch GIT OPT model."""
import copy
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import (CLIPVisionConfig, CLIPVisionModel, OPTConfig,
OPTForCausalLM, OPTModel)
from transformers.modeling_outputs import (BaseModelOutputWithPast,
BaseModelOutputWithPooling,
CausalLMOutputWithPast)
from transformers.models.git.modeling_git import GitProjection
from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding
class GitOPTConfig(OPTConfig):
    """
    Configuration for GIT-style OPT models: an `OPTConfig` extended with vision settings.

    Extra attributes held on top of the OPT ones:
        vision_config (`CLIPVisionConfig`): configuration of the image encoder.
        num_image_with_embedding (`int` or `None`): number of per-frame temporal embeddings, or `None`.
        vision_model_name (`str` or `None`): name/path handed to `CLIPVisionConfig.from_pretrained`.
    """

    model_type = "git_opt"

    def __init__(
        self,
        **kwargs,
    ):
        """
        Args:
            **kwargs: forwarded to `OPTConfig`, except for the optional keys `vision_config`
                (a `CLIPVisionConfig` or its dict form), `num_image_with_embedding` and
                `vision_model_name`, which are stored on this config.
        """
        # Take the GIT-specific entries out first so that values coming from a serialized
        # config (see `to_dict`) are preserved instead of being reset to defaults.
        vision_config = kwargs.pop("vision_config", None)
        num_image_with_embedding = kwargs.pop("num_image_with_embedding", None)
        vision_model_name = kwargs.pop("vision_model_name", None)

        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = CLIPVisionConfig()
        elif isinstance(vision_config, dict):
            # `to_dict` stores the vision config as a plain dict; turn it back into a config object.
            vision_config = CLIPVisionConfig(**vision_config)

        self.vision_config = vision_config
        self.num_image_with_embedding = num_image_with_embedding
        self.vision_model_name = vision_model_name

    def set_vision_configs(
        self,
        num_image_with_embedding: Union[int, None] = None,
        vision_model_name: Union[str, None] = None,
    ):
        """
        Record the vision settings and load the matching `CLIPVisionConfig`.

        Args:
            num_image_with_embedding: number of per-frame temporal embeddings, or `None`.
            vision_model_name: identifier passed to `CLIPVisionConfig.from_pretrained`.
        """
        self.num_image_with_embedding = num_image_with_embedding
        self.vision_model_name = vision_model_name
        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
            with `vision_config` converted to its own dictionary form and `model_type` taken from the class.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class GitOPTModel(OPTModel):
    """
    OPT decoder stack fed with projected CLIP image features prepended to the text embeddings.

    The image encoder output is projected by `GitProjection`, concatenated in front of the
    text token embeddings, and the combined sequence is run through `self.decoder.layers`
    with a mask built by `create_attention_mask`.
    """

    config_class = GitOPTConfig

    def __init__(self, config: OPTConfig):
        """
        Args:
            config: model configuration; `vision_model_name`, `vision_config` and
                `num_image_with_embedding` are read from it in addition to the OPT fields.
        """
        super(GitOPTModel, self).__init__(config)

        # Git modules
        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
        self.visual_projection = GitProjection(config)

        if config.num_image_with_embedding is not None:
            # One learnable offset per frame, added to that frame's encoder output in `forward`.
            self.img_temporal_embedding = nn.ParameterList(
                nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
                for _ in range(config.num_image_with_embedding)
            )

        # (image_size / patch_size) ** 2 + 1 tokens per image, times the number of frames if set.
        self.image_patch_tokens = int(
            (config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1
        )
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding

        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.decoder.embed_tokens = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): this reads `self.encoder.layer`, but nothing in this class assigns
        # `self.encoder` -- confirm whether a parent class provides it before relying on this.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _generate_future_mask(
        self, size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Return a `(size, size)` additive mask with `-inf` strictly above the diagonal and 0 elsewhere."""
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def create_attention_mask(
        self,
        tgt,
        memory,
        tgt_mask,
        past_key_values_length,
        memory_key_padding_mask=None,
    ):
        """
        Build the additive attention mask over the concatenated [memory, tgt] sequence.

        Block layout (rows are queries, columns are keys):
            memory -> memory: 0 (visible)          memory -> tgt/past: -inf (hidden)
            tgt    -> memory: 0 (visible)          tgt    -> tgt: `tgt_mask`, or all zeros with
                                                   `past_key_values_length` extra columns when
                                                   `past_key_values_length > 0`

        Args:
            tgt: text embeddings; only `shape[1]`, `dtype` and `device` are used.
            memory: visual features; only `shape[0]` and `shape[1]` are used.
            tgt_mask: square additive mask for the text positions.
            past_key_values_length: number of cached key positions.
            memory_key_padding_mask: optional boolean mask over memory positions where `True`
                marks padding; defaults to all `False`.

        Returns:
            Tensor of shape
            `(batch, 1, num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)`.

        Raises:
            ValueError: if `memory_key_padding_mask` is not a boolean tensor.
        """
        num_tgt = tgt.shape[1]
        num_memory = memory.shape[1]
        device = tgt.device
        dtype = tgt.dtype

        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        top_right = torch.full(
            (num_memory, num_tgt + past_key_values_length),
            float("-inf"),
            device=tgt.device,
            dtype=dtype,
        )
        bottom_left = torch.zeros(
            (num_tgt, num_memory),
            dtype=dtype,
            device=tgt_mask.device,
        )

        if past_key_values_length > 0:
            # With a cache present the text block is replaced by zeros (nothing masked).
            tgt_mask = torch.zeros(
                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
                dtype=dtype,
                device=tgt_mask.device,
            )

        left = torch.cat((top_left, bottom_left), dim=0)
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full(
                (memory.shape[0], memory.shape[1]), fill_value=False, device=device
            )
        # if it is False, it means valid. That is, it is not a padding
        if memory_key_padding_mask.dtype != torch.bool:
            raise ValueError("Memory key padding mask must be a boolean tensor.")

        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        zero_negative_infinity[memory_key_padding_mask] = float("-inf")

        full_attention_mask = full_attention_mask.expand(
            (
                memory_key_padding_mask.shape[0],
                num_memory + num_tgt,
                num_memory + past_key_values_length + num_tgt,
            )
        )
        # `expand` returns a view; clone before the in-place write below.
        full_attention_mask = full_attention_mask.clone()
        origin_left = full_attention_mask[:, :, :num_memory]
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        # add axis for multi-head
        full_attention_mask = full_attention_mask[:, None, :, :]

        return full_attention_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        Run the decoder over the visual tokens (if any) followed by the text tokens.

        Args:
            input_ids: text token ids of shape `(batch_size, sequence_length)`. Mutually exclusive
                with `inputs_embeds`.
            attention_mask: mask over the text positions (past and current); its length must equal
                the past text length plus the current sequence length. Defaults to all ones.
            position_ids: accepted for interface compatibility; not used by this method.
            pixel_values: rank-4 image batch or rank-5 batch of frame sequences (frames on axis 1).
                Only encoded when `past_key_values` is `None`.
            head_mask: passed through `get_head_mask` and handed to each decoder layer.
            inputs_embeds: precomputed text token embeddings, used instead of `input_ids`.
            past_key_values: per-layer cached states. The cached length is read from
                `past_key_values[0][0].shape[2]` and is treated as including
                `self.image_patch_tokens` visual positions.
            use_cache (`bool`, *optional*):
                If set to `True`, key value states are returned and can be used to speed up decoding.
            output_attentions: whether to collect each layer's attention output.
            output_hidden_states: whether to collect the hidden states entering each layer plus the final ones.
            return_dict: whether to return a `BaseModelOutputWithPast` instead of a tuple.

        Returns:
            `BaseModelOutputWithPast`, or a tuple of its non-`None` fields when `return_dict` is false.

        Raises:
            ValueError: if both or neither of `input_ids`/`inputs_embeds` are given, if `pixel_values`
                is not rank 4 or 5, or if `attention_mask` has the wrong length.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        projected_visual_features = None
        if pixel_values is not None and past_key_values is None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(pixel_values).last_hidden_state
            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                visual_features = []
                for frame_idx in range(pixel_values.shape[1]):
                    visual_features_frame = self.image_encoder(
                        pixel_values[:, frame_idx, :, :]
                    ).last_hidden_state
                    visual_features_frame += self.img_temporal_embedding[frame_idx]
                    visual_features.append(visual_features_frame)

                # finally, concatenate all features along sequence dimension
                visual_features = torch.cat(visual_features, dim=1)
            else:
                raise ValueError("pixel_values must be of rank 4 or 5")

            projected_visual_features = self.visual_projection(visual_features)

        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L634-L658
        # Only embed the ids when the caller did not already supply embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.decoder.embed_tokens(input_ids)

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length
        if past_key_values is not None:
            # The text attention mask and text positions do not cover the visual tokens
            # held in the cache, so take them out of both lengths.
            mask_seq_length = mask_seq_length - self.image_patch_tokens
            past_key_values_length = past_key_values_length - self.image_patch_tokens

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if self.decoder.project_in is not None:
            inputs_embeds = self.decoder.project_in(inputs_embeds)

        embedding_output = inputs_embeds + pos_embeds

        if projected_visual_features is None:
            # No image features for this call: use an empty (zero-length) visual sequence.
            projected_visual_features = torch.zeros(
                (embedding_output.shape[0], 0, embedding_output.shape[2]),
                dtype=embedding_output.dtype,
                device=embedding_output.device,
            )

        # Repeat visual features to match embedding batch size.
        projected_visual_features = projected_visual_features.repeat(
            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
        )

        # concatenate patch token and text token embeddings
        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)

        # By default, an additive causal mask is created
        # for masking the future (one direction).
        tgt_mask = self._generate_future_mask(
            seq_length, embedding_output.dtype, embedding_output.device
        )

        # for full sequence (w/ image patch tokens)
        if past_key_values is not None:
            past_key_values_length = past_key_values_length + self.image_patch_tokens

        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
        combined_attention_mask = self.create_attention_mask(
            tgt=embedding_output,
            memory=projected_visual_features,
            tgt_mask=tgt_mask,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is not None:
            # if the user provides an attention mask, we add it to the default one
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
            ).to(embedding_output.device)
            if past_key_values_length > 0:
                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
            else:
                combined_attention_mask[
                    :, :, -input_shape[1] :, -input_shape[1] :
                ] += expanded_attn_mask

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.decoder.layers):
            # LayerDrop (see https://arxiv.org/abs/1909.11556 for description) is not applied in this loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.decoder.gradient_checkpointing and self.decoder.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    combined_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=combined_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry sits after the attention output when attentions are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.decoder.final_layer_norm is not None:
            hidden_states = self.decoder.final_layer_norm(hidden_states)

        if self.decoder.project_out is not None:
            hidden_states = self.decoder.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class GitOPTForCausalLM(OPTForCausalLM):
    """Causal language-modeling head on top of `GitOPTModel` (image-conditioned text generation)."""

    config_class = GitOPTConfig

    def __init__(
        self,
        config,
    ):
        """
        Args:
            config: model configuration, forwarded to the parent class and to `GitOPTModel`.
        """
        super(GitOPTForCausalLM, self).__init__(config)
        # Replace the backbone created by the parent constructor with the image-aware one.
        self.model = GitOPTModel(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        Run the backbone, project to vocabulary logits and optionally compute the LM loss.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the left-to-right language modeling loss (next word prediction).
                They are aligned with the text input: `labels[:, 1:]` is compared against the logits
                of text positions `0 .. sequence_length - 2`. The loss uses `CrossEntropyLoss` with
                its default settings. Passing labels forces `use_cache=False`.
            past_key_values: per-layer cached states forwarded to the backbone.
            use_cache (`bool`, *optional*):
                If set to `True`, key value states are returned and can be used to speed up decoding.

            All remaining arguments are forwarded unchanged to `self.model`.

        Returns:
            `CausalLMOutputWithPast`, or a tuple `(loss?, logits, *backbone_outputs[1:])` when
            `return_dict` is false. `logits` cover the visual positions followed by the text positions.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one.
            # The visual tokens are prepended to the text, so the logits are longer than the text
            # input by exactly the number of visual tokens used in this call (zero when no image
            # features were computed). Skip that prefix when scoring.
            text_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
            num_image_tokens = logits.shape[1] - text_length
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        Args:
            input_ids: all text token ids generated so far.
            past_key_values: cache from the previous step, or `None` on the first step.
            attention_mask: mask over all text tokens so far; defaults to all ones.
            use_cache: forwarded as is.
            **kwargs: only `pixel_values` is read from here.

        Returns:
            Dict with `input_ids`, `attention_mask`, `pixel_values`, `past_key_values` and `use_cache`.
        """
        # Build the default mask from the full sequence, before truncating the ids: the backbone
        # checks the mask length against past text length + current length.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every cached tensor along axis 0 according to `beam_idx`; returns a tuple of tuples."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )