In [1]:
from transformers import AutoTokenizer, FlaxGPTNeoForCausalLM

In [4]:
from transformers import FlaxAutoModelForCausalLM

In [5]:
FlaxAutoModelForCausalLM.from_config??

[0;31mSignature:[0m [0mFlaxAutoModelForCausalLM[0m[0;34m.[0m[0mfrom_config[0m[0;34m([0m[0mconfig[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.

Note:
    Loading a model from its configuration file does **not** load the model weights. It only affects the
    model's configuration. Use :meth:`~transformers.FlaxAutoModelForCausalLM.from_pretrained` to load the model
    weights.

Args:
    config (:class:`~transformers.PretrainedConfig`):
        The model class to instantiate is selected based on the configuration class:

        - :class:`~transformers.GPT2Config` configuration class: :class:`~transformers.FlaxGPT2LMHeadModel` (OpenAI GPT-2 model)
        - :class:`~transformers.GPTNeoConfig` configuration class: :class:`~transformers.FlaxGPTNeoForCausalLM` (GPT Neo model)

Examples::

    >>> from transfo

In [10]:
model_ckpt = 'EleutherAI/gpt-neo-125M'

In [8]:
from transformers import AutoConfig

In [11]:
AutoConfig.from_pretrained(model_ckpt)

GPTNeoConfig {
  "activation_function": "gelu_new",
  "architectures": [
    "GPTNeoForCausalLM"
  ],
  "attention_dropout": 0,
  "attention_layers": [
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local"
  ],
  "attention_types": [
    [
      [
        "global",
        "local"
      ],
      6
    ]
  ],
  "bos_token_id": 50256,
  "embed_dropout": 0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neo",
  "num_heads": 12,
  "num_layers": 12,
  "resid_dropout": 0,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.9.0.dev0",
  "use_cache": true,
  "vocab_size": 50257,
  "

In [26]:
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = FlaxGPTNeoForCausalLM.from_pretrained(model_ckpt)

Downloading:   0%|          | 0.00/501M [00:00<?, ?B/s]

In [40]:
tokenizer.pad_token = tokenizer.eos_token

In [41]:
prompt = """
import torch
from torch import nn

class Model(nn.Module):
"""

In [42]:
inputs = tokenizer(prompt, return_tensors='jax')
input_ids = inputs.input_ids

In [43]:
outputs = model(**inputs)
outputs.logits.shape

(1, 19, 50257)

In [32]:
print(model.generate.__doc__)


        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        and, multinomial sampling.

        Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
        name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
        default values of those config.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:

            input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or no

In [37]:
input_ids

DeviceArray([[  198, 11748, 28034,   198,  6738, 28034,  1330,   299,
                 77,   198,   198,  4871,  9104,     7, 20471,    13,
              26796,  2599,   198]], dtype=int32)

In [44]:
out = model.generate(input_ids,
                     max_length=200, 
#                      num_beams=5,
                     pad_token_id = tokenizer.pad_token_id
                    )

In [50]:
out

FlaxGreedySearchOutput(sequences=DeviceArray([[  198, 11748, 28034,   198,  6738, 28034,  1330,   299,
                 77,   198,   198,  4871,  9104,     7, 20471,    13,
              26796,  2599,   198,   220,   220,   220,   825, 11593,
              15003,   834,     7,   944,    11,  1438,    11,  2746,
                 11, 12429, 46265, 22046,  2599,   198,   220,   220,
                220,   220,   220,   220,   220,  2208,     7, 17633,
                 11,  2116,   737,   834, 15003,   834,     7,  3672,
                 11,  2746,    11, 12429, 46265, 22046,     8,   198,
                220,   220,   220,   220,   220,   220,   220,  2116,
                 13,  3672,   796,  1438,   198,   220,   220,   220,
                220,   220,   220,   220,  2116,    13, 19849,   796,
               2746,   198,   220,   220,   220,   220,   220,   220,
                220,  2116,    13, 46265, 22046,   796,   479,    86,
              22046,   198,   220,   220,   220,   220,  

In [52]:
print(tokenizer.decode(out[0][0]))


import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, name, model, **kwargs):
        super(Model, self).__init__(name, model, **kwargs)
        self.name = name
        self.model = model
        self.kwargs = kwargs
        self.name_prefix ='model'
        self.name_suffix ='model_name'
        self.name_prefix_prefix ='model_name_prefix'
        self.name_suffix_prefix ='model_name_suffix'
        self.
