Select the CUDA device (GPU) used by this notebook

In [1]:
# Pin the GPU for this kernel. These variables are set in the first cell,
# i.e. before jax is imported in the next cell.
# NOTE: keep comments on their own lines; text after a `%env` value would
# become part of the value.
# Set the CUDA device enumeration order to PCI_BUS_ID.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
# Make only the device with index 3 visible to CUDA.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [2]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.mlqe import dataloader

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [3]:
from argparse import Namespace

# Experiment configuration, written as a key/value mapping and unpacked into a
# Namespace so the attributes keep the order in which they are listed here.
args = Namespace(**{
    # --- general setup ---
    'seed': 0,
    # one of: "no_teacher", "static_teacher", "learnable_teacher"
    'setup': 'static_teacher',
    'task': 'mlqe',
    'task_params': dict(eval_lp='ro-en'),
    'arch': 'xlm-r',
    'tokenizer': 'xlm-r',

    # --- student explainer ---
    'explainer': 'attention_explainer',
    'explainer_params': dict(
        normalize_head_coeffs='sparsemax',
        normalizer_fn='softmax',
        aggregator_idx='mean',
        aggregator_dim='row',
        layer_idx=-1,
        head_idx=None,
    ),

    # --- teacher explainer ---
    'teacher_explainer': 'attention_explainer',
    'teacher_explainer_params': dict(
        normalize_head_coeffs='sparsemax',
        normalizer_fn='softmax',
        aggregator_idx='mean',
        aggregator_dim='row',
        init_fn='uniform',
        layer_idx=9,
        head_idx=5,
    ),
    'initialize_embeddings': True,
    'num_examples': 4100,

    # --- optimisation of the student ---
    'max_len': 256,
    'num_epochs': None,
    'kld_coeff': 5,
    'patience': 5,
    'optimizer': 'sgd',
    'learning_rate': 1e-5,
    'warmup_steps': 4000,
    'batch_size': 16,
    'clip_grads': 1.0,
    'weight_decay': 0.01,

    # --- meta-learning ---
    'meta_interval': 1,
    'meta_warmup': 0,
    'metaoptimizer': 'adamw',
    'meta_lr': 1e-3,
    'meta_explicit': False,
    'num_resets': 0,

    # --- checkpoint directories ---
    'teacher_dir': 'data/mlqe-xlmr-explainer/teacher_dir',
    'model_dir': None,
    'explainer_dir': None,
    'teacher_explainer_dir': None,

    # --- logging / outputs ---
    'wandb': None,
    'log_teacher_params': None,
    'save_test_outputs': None,
})

In [4]:
# Create dummy inputs for model instantiation: one int32 array per model
# input, all shaped (batch_size, max_len).
# NOTE(review): presumably only shapes/dtypes matter for initialisation;
# confirm against the model-loading code.
input_ids = jnp.ones((args.batch_size, args.max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    # Segment ids: one id per token, same shape as input_ids. Previously this
    # held a 1-D arange over the sequence length (i.e. position indices),
    # which had the wrong shape and meaning for a segment-id input.
    "token_type_ids": jnp.zeros_like(input_ids),
    # Position ids: 0..max_len-1 repeated for every row of the batch.
    # Previously this was all ones, i.e. swapped with token_type_ids.
    "position_ids": jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
    ),
}
# Display the input shape as a quick sanity check.
dummy_inputs['input_ids'].shape

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


(16, 256)

In [5]:
# Import the training entry point and launch training with the configuration
# held in `args`.
# NOTE(review): in the recorded output this call did not complete; it raised
# RuntimeError: RESOURCE_EXHAUSTED (GPU out of memory while allocating ~3.5GiB)
# inside jit(train_step_with_teacher) during the first epoch. A smaller
# batch_size / max_len may be needed on this device -- TODO confirm.
from meta_expl.train import main
main(args)

You are using a model of type xlm-roberta to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at xlm-roberta-base were not used when initializing FlaxRobertaForSequenceClassification: {('roberta', 'pooler', 'dense', 'kernel'), ('lm_head', 'dense', 'bias'), ('lm_head', 'layer_norm', 'kernel'), ('lm_head', 'bias'), ('lm_head', 'dense', 'kernel'), ('lm_head', 'layer_norm', 'bias'), ('roberta', 'pooler', 'dense', 'bias'), ('lm_head', 'decoder', 'kernel')}
- This IS expected if you are initializing FlaxRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceC

Epoch: 1/-1


2022-02-21 13:26:37.598818: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.50GiB (rounded to 3755229440)requested by op 
2022-02-21 13:26:37.601067: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] *__******************_***********************_________________*****_________________________________
2022-02-21 13:26:37.608288: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3755229344 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.07GiB
              constant allocation:       112B
        maybe_live_out allocation:    1.04GiB
     preallocated temp allocation:    3.50GiB
  preallocated temp fragmentation:   35.92MiB (1.00%)
                 total allocation:    6.60GiB
          

RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3755229344 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.07GiB
              constant allocation:       112B
        maybe_live_out allocation:    1.04GiB
     preallocated temp allocation:    3.50GiB
  preallocated temp fragmentation:   35.92MiB (1.00%)
                 total allocation:    6.60GiB
              total fragmentation:  969.68MiB (14.34%)
Peak buffers:
	Buffer 1:
		Size: 732.43MiB
		Entry Parameter Subshape: f32[250002,768]
		==========================

	Buffer 2:
		Size: 732.43MiB
		Entry Parameter Subshape: f32[250002,768]
		==========================

	Buffer 3:
		Size: 732.43MiB
		Operator: op_name="jit(train_step_with_teacher)/jit(main)/add" source_file="/home/mtreviso/meta-expl/env/lib/python3.8/site-packages/optax/_src/update.py" source_line=43
		XLA Label: fusion
		Shape: f32[250002,768]
		==========================

	Buffer 4:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 5:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

	Buffer 6:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 7:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

	Buffer 8:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 9:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

	Buffer 10:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 11:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

	Buffer 12:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 13:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

	Buffer 14:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[4096,3072]
		==========================

	Buffer 15:
		Size: 48.00MiB
		XLA Label: custom-call
		Shape: f32[16,12,256,256]
		==========================

