In [7]:
import os, sys

import numpy as np
import pandas as pd 
from datasets import load_dataset

import importlib
from tqdm import tqdm
from joblib import Parallel, delayed
from copy import copy

import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModel,
    FlaxT5Model,
    FlaxAutoModelForSequenceClassification,
    HfArgumentParser,
    PretrainedConfig,
    TrainingArguments,
    is_tensorboard_available,
)

from flax.training.common_utils import get_metrics, onehot, shard


data_root = "/kaggle/input/feedback-prize-effectiveness/"
train = pd.read_csv("/kaggle/input/feedback-prize-effectiveness/train.csv")


In [8]:
sys.path.append("../configs")
cfg = copy(importlib.import_module("elu_config").cfg)
cfg.model_name_or_path = "facebook/opt-125m"

In [9]:
import jax.numpy as jnp

sys.path.append("../models/")
# FeedbackJax = importlib.import_module("gpt2_model").FeedbackJax
# Net(cfg, config_path=None, pretrained=True)


# FeedbackJax.from_text_pretrained(
#     cfg.model_name_or_path,
#     seed=cfg.seed,
#     dtype=jnp.float32,
#     text_from_pt=False,
# )

## Data loading section

In [10]:
## data collator with dynamic padding
# def train_data_collator(rng:)
import jax
import datasets
from typing import Any, Callable, Dict, Optional, Tuple

rng = jax.random.PRNGKey(1)#cfg.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any


def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: np.array(v) for k, v in batch.items()}
        #batch = shard(batch)

        yield batch

train_dataset = load_dataset("json", data_files="/kaggle/working/folds/valid_0.jsonl", split="train")
train_loader = train_data_collator(rng, train_dataset, cfg.per_device_train_batch_size)



## Model tests

In [11]:
for batch in train_loader:
    print(batch['input_ids'].shape)    
    break

(16, 512)


In [12]:
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from configuration_hybrid_clip import FeedbackJaxConfig
from flax.core.frozen_dict import FrozenDict
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
from transformers.utils import logging


logger = logging.get_logger(__name__)


from transformers.utils import ModelOutput
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling

@flax.struct.dataclass
class FeedbackJaxOutput(ModelOutput):
    """
    Args:
        logits:(`jnp.ndarray` of shape `(will see)`):
    """

    logits: jnp.ndarray = None
    text_outputs: jnp.array = None
    
    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )



class FeedbackJaxModule(nn.Module):
    config: FeedbackJaxConfig
    dtype: jnp.dtype = jnp.float32
    freeze_backbones: bool = False

    def setup(self):
        text_config = self.config.text_config

        self.projection_dim = self.config.projection_dim
        self.text_embed_dim = text_config.hidden_size

        text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class

        self.text_model = text_module(text_config, dtype=self.dtype)

        self.text_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
            use_bias=False,
        )

        self.logits = nn.Dense(3) 

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)
        logits = self.logits(text_embeds)

        return FeedbackJaxOutput(
            logits=logits,
            text_outputs=text_outputs
        )


class FeedbackJax(FlaxPreTrainedModel):
    config_class = FeedbackJaxConfig
    module_class = FeedbackJaxModule

    def __init__(
        self,
        config: FeedbackJaxConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        if input_shape is None:
            input_shape = ((1, 1), (1, 224, 224, 3))
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensor
        input_ids = jnp.zeros(input_shape[0], dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids)["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
    ):
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                deterministic=deterministic,
            )

            pooled_output = text_outputs[1]
            text_features = module.text_projection(pooled_output)
            logits = module.logits(text_features)
            return logits, text_outputs

        return self.module.apply(
            {"params": self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:

        kwargs_text = {
            argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
        }

        # remove text, vision kwargs from kwargs
        for key in kwargs_text.keys():
            del kwargs["text_" + key]

        # Load and initialize the text and vision model
        text_model = kwargs_text.pop("model", None)
        if text_model is None:
            assert (
                model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
            from transformers import FlaxAutoModel

            if "config" not in kwargs_text:
                from transformers import AutoConfig

                text_config = AutoConfig.from_pretrained(model_name_or_path)
                kwargs_text["config"] = text_config

            text_model = FlaxAutoModel.from_pretrained(model_name_or_path, *model_args, **kwargs_text)


        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = FeedbackJaxConfig.from_text_configs(text_model.config, **kwargs)

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        model.params["text_model"] = text_model.params

        return model

In [14]:
cfg.dtype = "int"
cfg.from_pt = False

cfg.freeze_backbones = False

model = FeedbackJax.from_pretrained(
    cfg.model_name_or_path,
    seed=cfg.seed,
    dtype=jnp.float32,
    text_from_pt=False,
    freeze_backbones=cfg.freeze_backbones
)
config = model.config

Some weights of the model checkpoint at facebook/opt-125m were not used when initializing FlaxOPTModel: {('decoder', 'final_layer_norm', 'scale'), ('decoder', 'final_layer_norm', 'bias')}
- This IS expected if you are initializing FlaxOPTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxOPTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some of the weights of FlaxOPTModel were initialized in float16 precision from the model checkpoint at facebook/opt-125m:
[('decoder', 'embed_positions', 'embedding'), ('decoder', 'embed_tokens', 'embedding'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 

TypeError: __call__() got an unexpected keyword argument 'token_type_ids'

In [12]:
# labels = batch.pop("labels")
outs = model(**batch)
nn.softmax(outs.logits)

DeviceArray([[nan, nan, nan]], dtype=float32)

In [17]:
sys.path.append("../shampoo")

In [19]:
from distributed_shampoo import distributed_shampoo

TypeError: distributed_shampoo() missing 2 required positional arguments: 'learning_rate' and 'block_size'