Commit

Add CERBERUS model code
yoshitomo-matsubara committed Apr 12, 2023
1 parent e53d016 commit 9a63688
Showing 1 changed file with 321 additions and 0 deletions: cerberus.py
@@ -0,0 +1,321 @@
import os
from pathlib import Path

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import get_activation
from transformers import AutoModelForSequenceClassification


class CerberusSharedEncoder(nn.Module):
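    '''
    Encoder blocks shared by all Cerberus heads: runs the given transformer layers sequentially and returns
    the final hidden states together with the (optionally collected) per-layer hidden states and attentions.
    '''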
def __init__(self, layer, config):
super().__init__()
self.layer = layer
self.config = config

def forward(self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if getattr(self.config, 'gradient_checkpointing', False):

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return hidden_states, all_hidden_states, all_attentions


class CerberusHead(nn.Module):
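    '''
    A single Cerberus head: continues from the shared hidden states through its own remaining transformer
    layers (head_mask indices are offset by `start_idx`) and, if a classifier is attached, returns its logits
    as the first output.
    '''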
def __init__(self, config, start_idx, layer, classifier):
super().__init__()
self.config = config
self.start_idx = start_idx
self.layer = layer
self.classifier = classifier

def forward(self, hidden_states, all_hidden_states, all_attentions,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False):
        first = True
        for i, layer_module in enumerate(self.layer, start=self.start_idx):
            if first:
                # The shared encoder already appended these incoming hidden states, so skip them here
                first = False
            elif output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

if getattr(self.config, 'gradient_checkpointing', False):

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)

if self.classifier is None:
return outputs

sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here
return outputs


class CerberusClassifier(nn.Module):
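    '''Classification head applied to pooled features: dropout -> dense -> gelu -> dropout -> linear projection.'''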
def __init__(self, in_features, out_features, dropout_prob):
super().__init__()
self.dense = nn.Linear(in_features, in_features)
self.dropout = nn.Dropout(dropout_prob)
self.out_proj = nn.Linear(in_features, out_features)

def forward(self, x, **kwargs):
# x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
        x = get_activation('gelu')(x)  # BERT uses tanh here, but the ELECTRA authors used gelu
x = self.dropout(x)
x = self.out_proj(x)
return x


class CerberusModel(nn.Module):
'''
head_configs:
- model:
class!: TandaElectra
pretrained!: ${student_model_path}/best
base_model: electra
classifier: classifier
- model:
class!: TandaElectra
pretrained!: ${student_model_path}/best
base_model: electra
classifier: classifier
- model:
class!: TandaElectra
pretrained!: ${student_model_path}/best
base_model: electra
classifier: classifier
num_shared_blocks: 11
requires_head_classifier: True
uses_softmax_avg: True
final_classifier_config:
'''
def __init__(self, head_configs=None, num_shared_blocks=11, requires_head_classifier=True,
final_classifier_config=None, start_ckpt_file_path=None, save_ckpt_file_name='cerberus_model.pt',
freezes_all_except_classifier=False, uses_softmax_avg=True):
super().__init__()
if head_configs is None:
head_configs = [
{'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
'base_model': 'electra', 'classifier': 'classifier'},
{'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
'base_model': 'electra', 'classifier': 'classifier'},
{'model': {'pretrained_model_name_or_path': 'google/electra-base-discriminator'},
'base_model': 'electra', 'classifier': 'classifier'}
]

self.num_shared_blocks = num_shared_blocks
self.requires_head_classifier = requires_head_classifier
models = [AutoModelForSequenceClassification.from_pretrained(**head_config['model'])
for head_config in head_configs]
core_config = head_configs[0]
core_base_model = getattr(models[0], core_config['base_model'])
self.core_base_model = core_base_model
self.embeddings = core_base_model.embeddings
self.embeddings_project =\
core_base_model.embeddings_project if hasattr(core_base_model, 'embeddings_project') else None
        self.cerberus_shared_encoder = \
            CerberusSharedEncoder(core_base_model.encoder.layer[:num_shared_blocks], core_base_model.config)
module_list = list()
        for model, head_config in zip(models, head_configs):
            base_model = getattr(model, head_config['base_model'])
            classifier = getattr(model, head_config['classifier']) if requires_head_classifier else None
            cerberus_head = CerberusHead(base_model.config, num_shared_blocks,
                                         base_model.encoder.layer[num_shared_blocks:], classifier)
            module_list.append(cerberus_head)

self.cerberus_heads = nn.ModuleList(module_list)
self.final_classifier = \
CerberusClassifier(**final_classifier_config) if final_classifier_config is not None else None
self.ckpt_file_name = save_ckpt_file_name
if start_ckpt_file_path is not None and os.path.isfile(start_ckpt_file_path):
self.load_state_dict(torch.load(start_ckpt_file_path, map_location='cpu'), strict=False)
print('Loaded model weights at {}'.format(start_ckpt_file_path))
elif start_ckpt_file_path is not None:
            print('Start checkpoint `{}` was given, but the file was not found'.format(start_ckpt_file_path))

self.freezes_all_except_classifier = freezes_all_except_classifier
self.uses_softmax_avg = uses_softmax_avg

def sub_forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds,
labels, output_attentions, output_hidden_states):
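        '''Embeds the inputs, runs the shared encoder once, then returns one output per Cerberus head.'''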

output_attentions = \
output_attentions if output_attentions is not None else self.core_base_model.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.core_base_model.config.output_hidden_states
)

if input_ids is not None and inputs_embeds is not None:
raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError('You have to specify either input_ids or inputs_embeds')

device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

extended_attention_mask = self.core_base_model.get_extended_attention_mask(attention_mask, input_shape, device)
head_mask = self.core_base_model.get_head_mask(head_mask, self.core_base_model.config.num_hidden_layers)

hidden_states = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)

if hasattr(self, 'embeddings_project') and self.embeddings_project is not None:
hidden_states = self.embeddings_project(hidden_states)

hidden_states, all_hidden_states, all_attentions = self.cerberus_shared_encoder(
hidden_states=hidden_states,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states
)
outputs_list = list()
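        # Feed the shared representation through each head's remaining encoder blocks (and classifier, if any)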
for cerberus_head in self.cerberus_heads:
outputs = cerberus_head(hidden_states, all_hidden_states, all_attentions,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
outputs_list.append(outputs[0])
return outputs_list

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None):
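        '''
        Returns the (softmax-)averaged logits of the heads when no final classifier is configured, the list of
        per-head outputs when the heads have no classifiers, or the final classifier's logits computed over the
        concatenated [CLS] features otherwise.
        '''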

if self.freezes_all_except_classifier:
with torch.no_grad():
outputs_list = self.sub_forward(input_ids, attention_mask, token_type_ids, position_ids, head_mask,
inputs_embeds, labels, output_attentions, output_hidden_states)
else:
outputs_list = self.sub_forward(input_ids, attention_mask, token_type_ids, position_ids, head_mask,
inputs_embeds, labels, output_attentions, output_hidden_states)

        if self.final_classifier is None:
            if self.requires_head_classifier:
                if self.uses_softmax_avg:
                    return sum(torch.softmax(outputs, dim=1) for outputs in outputs_list) / len(outputs_list)
                return sum(outputs_list) / len(outputs_list)
            return outputs_list

        # Take each head's [CLS] (first token) representation; this path assumes the heads return sequence
        # hidden states rather than logits (i.e., requires_head_classifier=False)
        for i in range(len(outputs_list)):
            outputs_list[i] = outputs_list[i][:, 0, :]
logits = self.final_classifier(torch.cat(outputs_list, dim=1))
return logits

def save_pretrained(self, ckpt_dir_path):
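        '''Saves this model's state_dict to `ckpt_dir_path`, creating the directory if needed.'''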
if os.path.isfile(ckpt_dir_path):
print('Provided path ({}) should be a directory, not a file'.format(ckpt_dir_path))
return

Path(ckpt_dir_path).mkdir(parents=True, exist_ok=True)
        # If the model is wrapped (e.g., by DistributedDataParallel), save the underlying model, not the wrapper
model_to_save = self.module if hasattr(self, 'module') else self

ckpt_file_path = os.path.join(ckpt_dir_path, self.ckpt_file_name)
torch.save(model_to_save.state_dict(), ckpt_file_path)
print('Model weights saved in {}'.format(ckpt_file_path))
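

if __name__ == '__main__':
    # Minimal usage sketch, assuming the default head configs (three google/electra-base-discriminator
    # heads sharing the first 11 encoder blocks). The tokenizer choice, example sentence, and checkpoint
    # directory below are illustrative; loading the pretrained weights requires network access.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('google/electra-base-discriminator')
    model = CerberusModel()
    model.eval()

    batch = tokenizer(['Is this sentence an answer to the question?'], return_tensors='pt')
    with torch.no_grad():
        # With the defaults (requires_head_classifier=True, uses_softmax_avg=True, no final classifier),
        # forward returns the softmax-averaged logits of the three heads.
        probs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                      token_type_ids=batch.get('token_type_ids'))
    print(probs.shape)  # (batch_size, num_labels)

    # Save the combined model's weights as cerberus_model.pt under the given directory
    model.save_pretrained('./cerberus_ckpt')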
