
Commit d0469ee

Merge pull request #1 from gpucce/beamsearch
Simpler beam search for now
Soonhwan-Kwon committed Jan 14, 2023
2 parents f640696 + dc8c128 commit d0469ee
Showing 2 changed files with 48 additions and 140 deletions.
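
For context, the rewritten generate_beamsearch in this commit hands most of the beam bookkeeping to Hugging Face's generation utilities instead of reimplementing them. Below is a minimal sketch of how those building blocks are typically constructed; the batch size, beam counts, token id, and length values are illustrative and not taken from this commit.

import torch
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

batch_size, num_beams, num_beam_groups = 2, 6, 3
device = torch.device("cpu")

# Keeps per-example beam hypotheses and decides when each batch element is finished.
beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    num_beams=num_beams,
    device=device,
    num_beam_groups=num_beam_groups,
)

# Rescores logits at every step; here it just blocks the end-of-sequence token
# until a minimum length is reached (the eos id 0 is a placeholder).
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id=0)])

# Halts the decoding loop once sequences reach max_length.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=77)])
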
167 changes: 47 additions & 120 deletions src/open_clip/coca_model.py
@@ -13,13 +13,10 @@
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p, prepare_inputs_for_generation, torch_int_div
from transformers import BeamSearchScorer, LogitsProcessorList, HammingDiversityLogitsProcessor, \
MinLengthLogitsProcessor,StoppingCriteriaList
from .generation_utils import top_a, top_k, top_p, prepare_inputs_for_generation
from transformers import BeamSearchScorer, LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList

from .generation_utils import validate_stopping_criteria
import os
import gc
import warnings
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Dict

@dataclass
@@ -130,15 +127,15 @@ def encode_image(self, images, normalize=True, return_tokens=False):
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return (image_latent, tokens_embs) if return_tokens else image_latent

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
def encode_text(self, text, normalize=True, return_tokens=False, add_cls=True):
text = text[:, :-1] if add_cls else text # make space for CLS token
text_latent, token_emb = self.text.encoder(text, output_tokens=True)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text, output_dict=False, image_latent=None, image_embs=None):
def forward(self, image, text, output_dict=False, image_latent=None, image_embs=None, add_cls=True):

text_latent, token_embs = self.encode_text(text, return_tokens=True)
text_latent, token_embs = self.encode_text(text, return_tokens=True, add_cls=add_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self.encode_image(image, return_tokens=True)

@@ -223,7 +220,7 @@ def _update_model_kwargs_for_generation(self,
model_kwargs["past"] = outputs["past_key_values"]
elif "mems" in outputs:
model_kwargs["past"] = outputs.memes
elif "past_buckets_states" in outpus:
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
else:
model_kwargs["past"] = None
@@ -242,32 +239,29 @@ def _update_model_kwargs_for_generation(self,
)
return model_kwargs

def generate_beamseach(
self,
image_inputs,
max_length=None,
pad_token_id=0,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
output_scores=True,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=True,
synced_gpus=False,
**kwargs,
def generate_beamsearch(
self,
image_inputs,
max_length=None,
pad_token_id=0,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
synced_gpus=False,
**kwargs,
):
device = image_inputs.device
image_inputs = image_inputs.repeat(num_beams, 1, 1, 1)
image_latent, image_embs = self.encode_image(image_inputs, return_token=True)
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self.encode_image(image_inputs, return_tokens=True)

input_ids = torch.ones((num_beams, 1), device=device, dtype=torch.long)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=1,
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=device,
@@ -284,55 +278,24 @@ def generate_beamseach(
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
validate_stopping_criteria(stopping_criteria, max_length)

# TODO: where it gets config
pad_token_id = pad_token_id if pad_token_id is not None else self.text_cfg.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.text_cfg.pad_token_id

# in HF reads output_scores from config when it is None
output_scores = output_scores if output_scores is not None else False
# in HF reads output_attention from config when it is None
output_attention = output_attentions if output_attentions is not None else False
# in HF reads return_dict_in_generate from config when it is None
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else False
)
pad_token_id = pad_token_id if pad_token_id is not None else self.pad_id
eos_token_id = eos_token_id if eos_token_id is not None else self.pad_id

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_beam_size, cur_len = input_ids.shape

if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
else:
beam_indices = None
beam_indices = None

if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
# if return_dict_in_generate and model.config.is_encoder_decoder:
# encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
# encoder_hidden_states = (
# model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# )

beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
# the same group don't produce same tokens everytime.
@@ -344,7 +307,7 @@ def generate_beamseach(
while True:
if synced_gpus:
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
torch.distributed.all_reduce(this_peer_finished_flag, op=torch.distributed.ReduceOp.SUM)
if this_peer_finished_flag.item() == 0.0:
break

@@ -356,23 +319,20 @@ def generate_beamseach(

# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
# logits = self(image, x, attn_mask=self.attn_mask[:x_seq_len, :x_seq_len])[2][:, -1]

outputs = self(model_inputs['images'],
model_inputs['text'],
attn_mask[:cur_len, :cur_len],
image_latent=image_latent,
image_embs=image_embs,
output_dict=True)
outputs["past_key_values"] = None
outputs = self(
model_inputs['images'],
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs,
output_dict=True,
add_cls=False
)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue

if output_scores:
processed_score = torch.zeros_like(outputs['logits'][:, -1, :])

for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
@@ -389,32 +349,22 @@ def generate_beamseach(

# select outputs of beams of current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]

# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
# next_token_logits = model.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
)# (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1]
vocab_size = next_token_logits.shape[-1]

next_token_scores_processed = logits_processor(
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

if output_scores:
processed_score[batch_group_indices] = next_token_scores_processed


# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)

next_indices = torch_int_div(next_tokens, vocab_size)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size

# stateless
@@ -432,48 +382,25 @@ def generate_beamseach(
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

if return_dict_in_generate and output_scores:
beam_indices[beam_group_idx] = tuple(
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
)
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]

# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)

if return_dict_in_generate:
if output_scores:
scores += (processed_score,)
# TODO : deal with when output_attentions in next time
# if output_attentions:
# decoder_attentions += (
# (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions, )
# )
# if model.config.is_encoder_decoder:
# cross_attention += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if model.config.is_encoder_decoder
else (outputs.hidden_states,)
)

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs
)
# TODO: support it in next step
# if model_kwargs["past"] is not None:
# model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if beam_scorer.is_done or stopping_criteria(input_ids, None):
if not synced_gpus:
break
else:
@@ -489,4 +416,4 @@ def generate_beamseach(
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences'][0]
return sequence_outputs['sequences']
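
As a rough usage sketch of the renamed generate_beamsearch method: the model construction, pretrained config name, and token ids below are assumptions for illustration, not part of this diff.

import torch
import open_clip

# Hypothetical setup: a CoCa model built from an open_clip config, in eval mode.
model, _, preprocess = open_clip.create_model_and_transforms("coca_ViT-B-32")
model.eval()

images = torch.randn(2, 3, 224, 224)  # dummy batch of preprocessed images

with torch.no_grad():
    sequences = model.generate_beamsearch(
        image_inputs=images,
        sot_token_id=49406,   # assumed start-of-text id of the CLIP BPE tokenizer
        eos_token_id=49407,   # assumed end-of-text id
        max_length=77,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
    )

# sequences holds one row of token ids per input image; decode it with the
# matching tokenizer to obtain captions.
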
21 changes: 1 addition & 20 deletions src/open_clip/generation_utils.py
@@ -9,27 +9,12 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng
for stopping_criterium in stopping_criteria:
if isinstance(stopping_criterium, MaxLengthCriteria):
found = True
if stopping_criterium.max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
)
if not found:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))

def torch_int_div(tensor1, tensor2):
"""
A function that performs integer division across different versions of PyTorch.
"""
if is_torch_less_than_1_8:
return tensor1 // tensor2
else:
return torch.div(tensor1, tensor2, rounding_mode="floor")

def exists(val):
return val is not None

# nucleus

def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -41,25 +26,21 @@ def top_p(logits, thres = 0.9):
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs

# top_a

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits


# prep for adding past_key_values
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
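
The helpers that remain in generation_utils.py filter a logits distribution before sampling. A small sketch of how top_k (and, analogously, top_p) might be applied; the import path, vocabulary size, and threshold are assumptions.

import torch
import torch.nn.functional as F
from open_clip.generation_utils import top_k, top_p  # assumed import path

logits = torch.randn(1, 49408)  # dummy logits over a hypothetical vocabulary

# top_k keeps the ceil((1 - thres) * vocab_size) largest logits and sets the
# rest to -inf, so they receive zero probability after the softmax.
filtered = top_k(logits, thres=0.9)

# top_p instead masks the tail of the cumulative probability of the sorted
# distribution, again marking dropped tokens with -inf:
# filtered = top_p(logits, thres=0.9)

probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample one token id
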
