In [14]:
from transformers import GPT2Config, GPT2Model
from huggingface_hub import PyTorchModelHubMixin

class GPT2WithCustomHead(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config_or_model_name, num_labels=None):
        super().__init__()
        if isinstance(config_or_model_name, str):
            self.base_model = GPT2Model.from_pretrained(config_or_model_name)
            self.config = self.base_model.config
        else:
            self.config = config_or_model_name
            self.base_model = GPT2Model(self.config)
        self.score = nn.Linear(self.config.n_embd, num_labels or self.config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        logits = self.score(outputs.last_hidden_state[:, -1, :])
        return logits

    @classmethod
    def from_pretrained(cls, model_name, *args, **kwargs):
        config = GPT2Config.from_pretrained(model_name)
        return super().from_pretrained(model_name, config, *args, **kwargs)

# Initialize the base model
config = GPT2Config.from_pretrained('openai-community/gpt2')
config.num_labels = 2
custom_model = GPT2WithCustomHead(config)

# Save locally
custom_model.save_pretrained("my-awesome-model", config=config.to_dict())

# Push to the hub
custom_model.push_to_hub("my-awesome-model", config=config.to_dict())

# Reload
custom_model = GPT2WithCustomHead.from_pretrained("my-awesome-model")

TypeError: Object of type GPT2Config is not JSON serializable