Commit 9a63688 (1 parent: e53d016)
Showing 1 changed file with 321 additions and 0 deletions.
import os
from pathlib import Path

import torch
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
from torch import nn
from transformers.activations import get_activation
from transformers import AutoModelForSequenceClassification


class CerberusSharedEncoder(nn.Module):
    """Runs the first (shared) transformer layers that all Cerberus heads build on."""

    def __init__(self, layer, config):
        super().__init__()
        self.layer = layer
        self.config = config

    def forward(self,
                hidden_states,
                attention_mask=None,
                head_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False,
                output_hidden_states=False
                ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, 'gradient_checkpointing', False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        return hidden_states, all_hidden_states, all_attentions


class CerberusHead(nn.Module):
    """Runs the remaining, per-head transformer layers (and an optional classifier)
    on top of the shared encoder output."""

    def __init__(self, config, start_idx, layer, classifier):
        super().__init__()
        self.config = config
        self.start_idx = start_idx
        self.layer = layer
        self.classifier = classifier

    def forward(self, hidden_states, all_hidden_states, all_attentions,
                attention_mask=None,
                head_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False,
                output_hidden_states=False):
        first = True
        for i, layer_module in enumerate(self.layer):
            i += self.start_idx
            if first:
                # The shared encoder already appended the current hidden states,
                # so skip the first iteration to avoid duplicating them.
                first = False
            elif output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, 'gradient_checkpointing', False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)

        if self.classifier is None:
            return outputs

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        outputs = (logits,) + outputs[1:]  # add hidden states and attentions if they are here
        return outputs


class CerberusClassifier(nn.Module):
    def __init__(self, in_features, out_features, dropout_prob):
        super().__init__()
        self.dense = nn.Linear(in_features, in_features)
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(in_features, out_features)

    def forward(self, x, **kwargs):
        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = get_activation('gelu')(x)  # BERT uses tanh here, but the ELECTRA authors used gelu
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class CerberusModel(nn.Module):
    '''
    Multi-headed ("Cerberus") model: several sequence-classification models share their
    first `num_shared_blocks` encoder layers and keep the remaining layers as separate heads.

    Example configuration (YAML):

    head_configs:
      - model:
          class!: TandaElectra
          pretrained!: ${student_model_path}/best
        base_model: electra
        classifier: classifier
      - model:
          class!: TandaElectra
          pretrained!: ${student_model_path}/best
        base_model: electra
        classifier: classifier
      - model:
          class!: TandaElectra
          pretrained!: ${student_model_path}/best
        base_model: electra
        classifier: classifier
    num_shared_blocks: 11
    requires_head_classifier: True
    uses_softmax_avg: True
    final_classifier_config:
    '''

    def __init__(self, head_configs=None, num_shared_blocks=11, requires_head_classifier=True,
                 final_classifier_config=None, start_ckpt_file_path=None, save_ckpt_file_name='cerberus_model.pt',
                 freezes_all_except_classifier=False, uses_softmax_avg=True):
        super().__init__()
        if head_configs is None:
            head_configs = [
                {'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
                 'base_model': 'electra', 'classifier': 'classifier'},
                {'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
                 'base_model': 'electra', 'classifier': 'classifier'},
                {'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
                 'base_model': 'electra', 'classifier': 'classifier'}
            ]

        self.num_shared_blocks = num_shared_blocks
        self.requires_head_classifier = requires_head_classifier
        models = [AutoModelForSequenceClassification.from_pretrained(**head_config['model'])
                  for head_config in head_configs]
        core_config = head_configs[0]
        core_base_model = getattr(models[0], core_config['base_model'])
        self.core_base_model = core_base_model
        self.embeddings = core_base_model.embeddings
        self.embeddings_project = \
            core_base_model.embeddings_project if hasattr(core_base_model, 'embeddings_project') else None
        # Pass the sub-model's config (not the head_config dict / attribute name) so that
        # gradient checkpointing can be detected inside the encoder modules.
        self.cerberus_shared_encoder = \
            CerberusSharedEncoder(core_base_model.encoder.layer[:num_shared_blocks], core_base_model.config)
        module_list = list()
        for model, head_config in zip(models, head_configs):
            base_model = getattr(model, head_config['base_model'])
            classifier = getattr(model, head_config['classifier']) if requires_head_classifier else None
            cerberus_head = CerberusHead(base_model.config, num_shared_blocks,
                                         base_model.encoder.layer[num_shared_blocks:], classifier)
            module_list.append(cerberus_head)

        self.cerberus_heads = nn.ModuleList(module_list)
        self.final_classifier = \
            CerberusClassifier(**final_classifier_config) if final_classifier_config is not None else None
        self.ckpt_file_name = save_ckpt_file_name
        if start_ckpt_file_path is not None and os.path.isfile(start_ckpt_file_path):
            self.load_state_dict(torch.load(start_ckpt_file_path, map_location='cpu'), strict=False)
            print('Loaded model weights from {}'.format(start_ckpt_file_path))
        elif start_ckpt_file_path is not None:
            print('Start checkpoint `{}` was given but could not be found'.format(start_ckpt_file_path))

        self.freezes_all_except_classifier = freezes_all_except_classifier
        self.uses_softmax_avg = uses_softmax_avg

    def sub_forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds,
                    labels, output_attentions, output_hidden_states):

        output_attentions = \
            output_attentions if output_attentions is not None else self.core_base_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.core_base_model.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.core_base_model.get_extended_attention_mask(attention_mask, input_shape, device)
        head_mask = self.core_base_model.get_head_mask(head_mask, self.core_base_model.config.num_hidden_layers)

        hidden_states = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        if self.embeddings_project is not None:
            hidden_states = self.embeddings_project(hidden_states)

        hidden_states, all_hidden_states, all_attentions = self.cerberus_shared_encoder(
            hidden_states=hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        outputs_list = list()
        for cerberus_head in self.cerberus_heads:
            outputs = cerberus_head(hidden_states, all_hidden_states, all_attentions,
                                    attention_mask=extended_attention_mask,
                                    head_mask=head_mask,
                                    output_attentions=output_attentions,
                                    output_hidden_states=output_hidden_states)
            outputs_list.append(outputs[0])
        return outputs_list

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):

        if self.freezes_all_except_classifier:
            with torch.no_grad():
                outputs_list = self.sub_forward(input_ids, attention_mask, token_type_ids, position_ids, head_mask,
                                                inputs_embeds, labels, output_attentions, output_hidden_states)
        else:
            outputs_list = self.sub_forward(input_ids, attention_mask, token_type_ids, position_ids, head_mask,
                                            inputs_embeds, labels, output_attentions, output_hidden_states)

        if self.final_classifier is None:
            if self.requires_head_classifier:
                if self.uses_softmax_avg:
                    return sum(torch.softmax(outputs, dim=1) for outputs in outputs_list) / len(outputs_list)
                return sum(outputs_list) / len(outputs_list)
            return outputs_list

        for i in range(len(outputs_list)):
            outputs_list[i] = outputs_list[i][:, 0, :]  # take the <s>/[CLS] token representation
        logits = self.final_classifier(torch.cat(outputs_list, dim=1))
        return logits

    def save_pretrained(self, ckpt_dir_path):
        if os.path.isfile(ckpt_dir_path):
            print('Provided path ({}) should be a directory, not a file'.format(ckpt_dir_path))
            return

        Path(ckpt_dir_path).mkdir(parents=True, exist_ok=True)
        # Only save the model itself, not a distributed/parallel wrapper
        model_to_save = self.module if hasattr(self, 'module') else self

        ckpt_file_path = os.path.join(ckpt_dir_path, self.ckpt_file_name)
        torch.save(model_to_save.state_dict(), ckpt_file_path)
        print('Model weights saved in {}'.format(ckpt_file_path))
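
A minimal usage sketch of the model above, assuming the default head_configs (which load google/electra-base-discriminator three times) and a matching tokenizer. The input text and the 'checkpoints' directory are illustrative assumptions, and the positional layer calls in the encoders assume the transformers version this commit was written against.

if __name__ == '__main__':
    from transformers import AutoTokenizer

    # Illustrative only: downloads google/electra-base-discriminator three times.
    tokenizer = AutoTokenizer.from_pretrained('google/electra-base-discriminator')
    model = CerberusModel()  # three heads sharing the first 11 encoder blocks
    model.eval()

    batch = tokenizer('Is this answer relevant to the question?', return_tensors='pt')
    with torch.no_grad():
        # With the defaults (requires_head_classifier=True, uses_softmax_avg=True,
        # no final classifier), forward() returns the softmax-averaged head logits.
        probs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
    print(probs.shape)  # (batch_size, num_labels)

    # Weights can be written back out; 'checkpoints' is a hypothetical directory name.
    model.save_pretrained('checkpoints')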