In [1]:
import os
import sys
import re
import json
import logging
import pandas as pd
import numpy as np

In [2]:
import torch
from transformers import (
    HfArgumentParser,
    set_seed,
    EvalPrediction,
    BertConfig, 
    BertTokenizer
)

from src.model.ca_mtl import CaMtl, CaMtlArguments
from src.utils.misc import MultiTaskDataArguments, Split
from src.mtl_trainer import MultiTaskTrainer, MultiTaskTrainingArguments
from src.data.mtl_dataset import MultiTaskDataset
from src.data.task_dataset import TaskDataset

logger = logging.getLogger(__name__)

# Manual Implementation

In [3]:
import os
import sys
import re
import json
import logging

import torch
from transformers import (
    HfArgumentParser,
    set_seed,
    EvalPrediction,
    BertConfig, 
    BertTokenizer
)

from src.model.ca_mtl import CaMtl, CaMtlArguments
from src.utils.misc import MultiTaskDataArguments, Split
from src.mtl_trainer import MultiTaskTrainer, MultiTaskTrainingArguments
from src.data.mtl_dataset import MultiTaskDataset
from src.data.task_dataset import TaskDataset

```
python run.py \
--model_name_or_path CA-MTL-tiny \
--data_dir /hub/CA-MTL/data \
--output_dir /hub/CA-MTL/mock_models \
--tasks D0 D1 MANC LOC SIGNT \
--overwrite_cache \
--task_data_folders D0/2021_04_08 D1/2021_04_08 MANC/2021_04_08 LOC/2021_04_08 SIGNT/2021_04_08 \
--do_train \
--do_eval \
--do_predict \
--evaluate_during_training \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--learning_rate 5e-5 \
--adam_epsilon 1e-8 \
--num_train_epochs 7 \
--warmup_steps 0 \
--save_steps 1500 \
--save_total_limit 1 \
--seed 43
```

```
python run.py \
--model_name_or_path CA-MTL-base \
--data_dir /hub/CA-MTL/data \
--output_dir /hub/CA-MTL/mock_models \
--tasks D0 D1 MANC LOC SIGNT \
--overwrite_cache \
--task_data_folders D0/2021_04_08 D1/2021_04_08 MANC/2021_04_08 LOC/2021_04_08 SIGNT/2021_04_08 \
--do_train \
--do_eval \
--do_predict \
--evaluate_during_training \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--learning_rate 5e-5 \
--adam_epsilon 1e-8 \
--num_train_epochs 7 \
--warmup_steps 0 \
--save_steps 1500 \
--save_total_limit 1 \
--seed 43
```

```
python run_inference.py \
--model_name_or_path /hub/CA-MTL/mock_models/vital-smoke-40-9000 \
--data_dir /hub/CA-MTL/data \
--output_dir /hub/CA-MTL/data/SCORED \
--overwrite_cache \
--task_data_folders TOSCORE \
--do_predict
```

```
run.py --model_name_or_path CA-MTL-tiny --data_dir /hub/CA-MTL/data --output_dir /hub/CA-MTL/mock_models --tasks D0 D1 MANC LOC SIGNT --overwrite_cache --task_data_folders D0/2021_04_08 D1/2021_04_08 MANC/2021_04_08 LOC/2021_04_08 SIGNT/2021_04_08 --do_train --do_eval --do_predict --evaluate_during_training --per_device_train_batch_size 32 --per_device_eval_batch_size 32 --learning_rate 5e-5 --adam_epsilon 1e-8 --num_train_epochs 7 --warmup_steps 0 --save_steps 10 --save_total_limit 1 --seed 43
```

Check that predictions do not change when changing the model code

In [4]:
model_dir = "/hub/CA-MTL/mock_models"
tasks = ['D0', 'D1', 'LOC', 'MANC', 'SIGNT']
orig_model_run = "leafy-sky-42"
new_model_run = "radiant-music-89"

# Per-task test predictions from two runs, keyed by task name, so the two runs
# can be compared row by row. Files are named "<task>_test_iter_<run>.tsv".
orig = {}
new = {}
for task in tasks:
    for run_name, frames_by_task in ((orig_model_run, orig), (new_model_run, new)):
        frames_by_task[task] = pd.read_csv(
            f"{model_dir}/{task}_test_iter_{run_name}.tsv", sep="\t"
        )

In [5]:
check_task = "SIGNT"
print((orig[check_task]['probability'] == new[check_task]['probability']).mean())
display(orig[check_task].head(5))
display(new[check_task].head(5))

1.0


Unnamed: 0,index,scoring_model,prediction,probability,logits
0,0,leafy-sky-42,Street Name Sign,0.999564,"{""Advance Traffic Control Sign"": -2.1757247447..."
1,1,leafy-sky-42,Street Name Sign,0.999648,"{""Advance Traffic Control Sign"": -2.0285027027..."
2,2,leafy-sky-42,Street Name Sign,0.999622,"{""Advance Traffic Control Sign"": -1.9623690843..."
3,3,leafy-sky-42,Street Name Sign,0.999646,"{""Advance Traffic Control Sign"": -2.1853635311..."
4,4,leafy-sky-42,Street Name Sign,0.993865,"{""Advance Traffic Control Sign"": -1.5310235023..."


Unnamed: 0,index,scoring_model,prediction,probability,logits
0,0,radiant-music-89,Street Name Sign,0.999564,"{""Advance Traffic Control Sign"": -2.1757247447..."
1,1,radiant-music-89,Street Name Sign,0.999648,"{""Advance Traffic Control Sign"": -2.0285027027..."
2,2,radiant-music-89,Street Name Sign,0.999622,"{""Advance Traffic Control Sign"": -1.9623690843..."
3,3,radiant-music-89,Street Name Sign,0.999646,"{""Advance Traffic Control Sign"": -2.1853635311..."
4,4,radiant-music-89,Street Name Sign,0.993865,"{""Advance Traffic Control Sign"": -1.5310235023..."


Test model deployment

In [6]:
model_run = "zesty-glitter-82"
model_args = CaMtlArguments(
    model_name_or_path=f'/hub/CA-MTL/mock_models/{model_run}-9000', 
    encoder_type="CA-MTL-tiny")

data_args = MultiTaskDataArguments(
    data_dir='/hub/CA-MTL/data',
    overwrite_cache = True,
    task_data_folders=['TOSCORE'])

training_args = MultiTaskTrainingArguments(
    output_dir='/hub/CA-MTL/data/SCORED', 
    overwrite_output_dir=False,
    do_train=False, do_eval=False,
    do_predict=True, 
)

In [7]:
logger = logging.getLogger(__name__)
def setup_logging(training_args):
    """Configure root logging and report the process/device setup.

    The main process (``local_rank`` of -1 or 0) logs at INFO; any other rank
    is limited to WARN. A single WARNING line then records the rank, device,
    GPU count, whether distributed training is active, and whether fp16 is on.

    Args:
        training_args: object exposing ``local_rank``, ``device``, ``n_gpu``
            and ``fp16`` (e.g. ``MultiTaskTrainingArguments``).
    """
    is_main_process = training_args.local_rank in [-1, 0]
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process else logging.WARN,
    )
    # local_rank == -1 means "not launched in distributed mode".
    is_distributed = training_args.local_rank != -1
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        is_distributed,
        training_args.fp16,
    )
def create_eval_datasets(mode, data_args, tokenizer, model_metadata):
    """Build one ``TaskDataset`` per task listed in the saved model metadata.

    Task ids follow the order of ``model_metadata['tasks']``. Each dataset is
    given the label list stored for its task under
    ``model_metadata['label_set']``.

    Returns:
        dict mapping task name -> ``TaskDataset`` built for ``mode``.
    """
    return {
        task_name: TaskDataset(
            task_name,
            position,
            data_args,
            tokenizer,
            mode=mode,
            label_list=model_metadata['label_set'][task_name],
        )
        for position, task_name in enumerate(model_metadata['tasks'])
    }

In [8]:
setup_logging(training_args)

set_seed(training_args.seed)
logger.info(training_args)

# Load the metadata stored with the checkpoint; `with` closes the file handle
# instead of leaking it.
with open(f"{model_args.model_name_or_path}/metadata.json", 'r') as metadata_file:
    model_metadata = json.load(metadata_file)

# update arguments with condition at model training
data_args.max_seq_length = model_metadata['max_seq_length']
num_tasks = len(model_metadata['label_set'])
# Reuse the scoring folder(s) so there is exactly one entry per task. Slicing to
# num_tasks keeps the cell idempotent: previously every re-run multiplied the
# list by num_tasks again (5 -> 25 -> 125 entries ...).
data_args.task_data_folders = (data_args.task_data_folders * num_tasks)[:num_tasks]
data_args.tasks = model_metadata['tasks']
model_args.encoder_type = model_metadata['model_name_or_path']

config = BertConfig.from_pretrained(model_args.model_name_or_path)
# transformers' TorchScript flag, set ahead of the torch.jit.script attempt further down.
config.torchscript = True

model = CaMtl.from_pretrained(
    model_args.model_name_or_path,
    model_args,
    data_args,
    config=config)

logger.info(model)

# load the tokenizer that was used when the model was trained
tokenizer = BertTokenizer.from_pretrained(
    CaMtl.get_base_model(model_metadata['model_name_or_path']),
)

05/10/2021 06:13:20 - INFO - transformers.training_args -   PyTorch: setting up devices
05/10/2021 06:13:20 - INFO - __main__ -   MultiTaskTrainingArguments(output_dir='/hub/CA-MTL/data/SCORED', overwrite_output_dir=False, do_train=False, do_eval=False, do_predict=True, evaluate_during_training=False, per_device_train_batch_size=8, per_device_eval_batch_size=8, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=1, learning_rate=5e-05, weight_decay=0.0, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3.0, max_steps=-1, warmup_steps=0, logging_dir=None, logging_first_step=False, logging_steps=500, save_steps=500, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level='O1', local_rank=-1, tpu_num_cores=None, tpu_metrics_debug=False, use_mt_uncertainty=False, uniform_mt_sampling=False, percent_of_max_data_size=1.0)
05/10/2021 06:13:20 - INFO - transformers.configuration_utils -   loading configuration file /hub/CA-MTL/mock_mo

05/10/2021 06:13:29 - INFO - transformers.tokenization_utils -   Model name 'huawei-noah/TinyBERT_General_6L_768D' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'huawei-noah/TinyBERT_General_6L_768D' is a path, a model identifier, or url to a directory containing tokenizer files.
05/10/2021 06:13:31 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/huawei-noah/TinyBERT_

In [9]:
# class TSConfig():
#     def __init__(
#         self,
#         config_dict
#     ):
#         self.architectures = config_dict.pop("architectures", None)
#         self.attention_probs_dropout_prob = config_dict.pop("attention_probs_dropout_prob", None)
#         self.cell = config_dict.pop("cell", None)
#         self.hidden_act = config_dict.pop("hidden_act", None)
#         self.hidden_dropout_prob = config_dict.pop("hidden_dropout_prob", None)
#         self.hidden_size = config_dict.pop("hidden_size", None)
#         self.initializer_range = config_dict.pop("initializer_range", None)
#         self.intermediate_size = config_dict.pop("intermediate_size", None)
#         self.layer_norm_eps = config_dict.pop("layer_norm_eps", None)
#         self.max_position_embeddings = config_dict.pop("max_position_embeddings", None)
#         self.max_seq_length = config_dict.pop("max_seq_length", None)
#         self.model_type = config_dict.pop("model_type", None)
#         self.num_attention_heads = config_dict.pop("num_attention_heads", None)
#         self.num_hidden_layers = config_dict.pop("num_hidden_layers", None)
#         self.num_tasks = config_dict.pop("num_tasks", None)
#         self.pad_token_id = config_dict.pop("pad_token_id", None)
#         self.pre_trained = config_dict.pop("pre_trained", None)
#         self.structure = config_dict.pop("structure", None)
#         self.torchscript = config_dict.pop("torchscript", True)
#         self.type_vocab_size = config_dict.pop("type_vocab_size", None)
#         self.vocab_size = config_dict.pop("vocab_size", None)

In [10]:
# ts_config = TSConfig(config.to_dict())

In [11]:
model_args

CaMtlArguments(model_name_or_path='/hub/CA-MTL/mock_models/zesty-glitter-82-9000', encoder_type='CA-MTL-tiny')

In [12]:
data_args

MultiTaskDataArguments(data_dir='/hub/CA-MTL/data', tasks=['D0', 'D1', 'MANC', 'LOC', 'SIGNT'], task_data_folders=['TOSCORE', 'TOSCORE', 'TOSCORE', 'TOSCORE', 'TOSCORE'], overwrite_cache=True, max_seq_length=256)

In [13]:
# from transformers import BertForSequenceClassification
# bert_model = torch.jit.script(BertForSequenceClassification.from_pretrained(
#     "bert-base-uncased", torchscript=True))

In [14]:
# Attempt to compile a freshly loaded model with TorchScript. As the traceback
# below shows, this currently fails: CaMtlBaseEncoder reads `self.config`, a
# BertConfig that TorchScript cannot convert to a TorchScript type.
eager_model = CaMtl.from_pretrained(
    model_args.model_name_or_path,
    model_args=model_args,
    data_args=data_args,
    config=config,
)
s_model = torch.jit.script(eager_model)

05/10/2021 06:13:31 - INFO - transformers.modeling_utils -   loading weights file /hub/CA-MTL/mock_models/zesty-glitter-82-9000/pytorch_model.bin
  " but it is a non-constant {}. Consider removing it.".format(name, hint))
  " but it is a non-constant {}. Consider removing it.".format(name, hint))


RuntimeError: 
Module 'CaMtlBaseEncoder' has no attribute 'config' (This attribute exists on the Python module, but we failed to convert Python type: 'BertConfig' to a TorchScript type.):
  File "/home/datasci/CA-MTL/src/model/encoders/ca_mtl_base.py", line 676
                )
                head_mask = head_mask.expand(
                    self.config.num_hidden_layers, -1, -1, -1, -1
                    ~~~~~~~~~~~ <--- HERE
                )
            elif head_mask.dim() == 2:


In [None]:
this = torch.Tensor([1, 2, 3, 0, 3, 4, 1, 1, 0, 0, 2, 3, 4, 5])

In [None]:
for num in this:
    print(int(num))

^^^ TODO: find a way to iterate over a tensor's values without calling `.numpy()` or using the NumPy library, e.g. by iterating over the tensor directly (as above) or with `Tensor.tolist()`.

In [None]:
this = []
for i in range(10):
    this.append(i)
that = []
for i in range(10):
    that.append(i)

In [None]:
# this.append(that)

In [None]:
(this, that)

In [None]:
(this, )

In [None]:
model.parameters()

In [None]:
next(model.parameters()).dtype

In [None]:
model

In [None]:
s_model

In [None]:
# load original test data
# load original test data scores
# load deployment test data sample
# score test sample with model
# score test sample with serialized model
# compare all

In [None]:
#     def predict(
#         self,
#         eval_dataset: Optional[Dataset] = None,
#         prediction_loss_only: Optional[bool] = None,
#         scoring_model: Optional[str] = None
#     ):
#         logging.info("*** Test ***")
#         datasets = eval_dataset or self.test_datasets
#         for task_name, test_dataset in datasets.items():
#             logger.info(task_name)
            
#             test_dataloader = self.get_test_dataloader(test_dataset)
#             test_result = self._prediction_loop(
#                 test_dataloader, description="Prediction", task_name=task_name, 
#                 mode=test_dataset.mode)
            
#             self._log(test_result.metrics)
#             for key, value in test_result.metrics.items():
#                 logger.info("  %s = %s", key, value)
                
#             softmax = torch.nn.Softmax(dim=1)
#             probs = softmax(torch.Tensor(test_result.predictions)).numpy().astype('float64')
#             logits = test_result.predictions.astype('float64')
#             output_mode = task_output_modes[task_name] 
#             if output_mode == "classification":
#                 predictions = np.argmax(logits, axis=1)
            
#             self.run_name = wandb.run.name
#             output_test_file = os.path.join(
#                 self.args.output_dir,
#                 f"{task_name}_test_iter_{self.run_name}.tsv",
#             )
#             if scoring_model is None:
#                 scoring_model = self.run_name
#             if self.is_world_master():
#                 with open(output_test_file, "w") as writer:
#                     logger.info("***** Test results {} *****".format(task_name))
#                     logger.info("***** Writing as {} *****".format(self.run_name))
#                     if output_mode == "regression":
#                         writer.write("index\tprediction\n")
#                     else:
#                         writer.write("index\tscoring_model\tprediction\tprobability\tlogits\n")
#                     for index, item in enumerate(predictions):
#                         if output_mode == "regression":
#                             writer.write("%d\t%3.3f\n" % (index, item))
#                         else:
#                             i_probs = probs[index,:]
#                             i_logits = logits[index,:]
#                             i_logits = json.dumps(dict(zip(test_dataset.get_labels(), i_logits)))
#                             writer.write(
#                                 "%d\t%s\t%s\t%3.6f\t%s\n" % (
#                                     index, scoring_model, test_dataset.get_labels()[item], 
#                                     i_probs[item], i_logits)
#                             )
                            
#     def _prediction_loop(
#         self, dataloader: DataLoader, description: str, task_name: str, mode: str,
#         prediction_loss_only: Optional[bool] = None, 
#     ) -> PredictionOutput:
#         """
#         Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
#         Works both with or without labels.
#         """

#         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

#         model = self.model
#         # multi-gpu eval
#         if self.args.n_gpu > 1:
#             model = torch.nn.DataParallel(model)
#         else:
#             model = self.model
#         # Note: in torch.distributed mode, there's no point in wrapping the model
#         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

#         batch_size = dataloader.batch_size
#         logger.info("***** Running %s *****", description)
#         logger.info("  Num examples = %d", self.num_examples(dataloader))
#         logger.info("  Batch size = %d", batch_size)
#         eval_losses: List[float] = []
#         preds: torch.Tensor = None
#         label_ids: torch.Tensor = None
#         model.eval()

#         if is_tpu_available():
#             dataloader = pl.ParallelLoader(dataloader,
#                                            [self.args.device]).per_device_loader(self.args.device)

#         for inputs in tqdm(dataloader, desc=description):
#             has_labels = any(
#                 inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

#             for k, v in inputs.items():
#                 inputs[k] = v.to(self.args.device)

#             with torch.no_grad():
#                 outputs = model(**inputs)
#                 if has_labels:
#                     step_eval_loss, logits = outputs[:2]
#                     eval_losses += [step_eval_loss.mean().item()]
#                 else:
#                     logits = outputs[0]

#             if not prediction_loss_only:
#                 if preds is None:
#                     preds = logits.detach()
#                 else:
#                     preds = torch.cat((preds, logits.detach()), dim=0)
#                 if inputs.get("labels") is not None:
#                     if label_ids is None:
#                         label_ids = inputs["labels"].detach()
#                     else:
#                         label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)

#         if self.args.local_rank != -1:
#             # In distributed mode, concatenate all results from all nodes:
#             if preds is not None:
#                 preds = self.distributed_concat(preds,
#                                                 num_total_examples=self.num_examples(dataloader))
#             if label_ids is not None:
#                 label_ids = self.distributed_concat(label_ids,
#                                                     num_total_examples=self.num_examples(dataloader))
#         elif is_tpu_available():
#             # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
#             if preds is not None:
#                 preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
#             if label_ids is not None:
#                 label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

#         # Finally, turn the aggregated tensors into numpy arrays.
#         if preds is not None:
#             preds = preds.cpu().numpy()
#         if label_ids is not None:
#             label_ids = label_ids.cpu().numpy()

#         if self.compute_metrics is not None and preds is not None and label_ids is not None:
#             metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
#         else:
#             metrics = {}
#         if len(eval_losses) > 0:
#             metrics[f"{task_name}_{mode}_loss"] = np.mean(eval_losses)

#         # Prefix all keys with {task_name}_{model}_
#         for key in list(metrics.keys()):
#             if not key.startswith(f"{task_name}_{mode}_"):
#                 metrics[f"{task_name}_{mode}_{key}"] = metrics.pop(key)

#         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)


# CBDA

In [None]:
from src.model.encoders.conditional_modules import CBDA
from src.utils.misc import Split
import math
from torch import nn

In [None]:
hidden_size = 768
max_seq_length = 256
num_blocks = hidden_size//max_seq_length

In [None]:
model_args = CaMtlArguments(
    model_name_or_path="CA-MTL-tiny", 
    encoder_type="CA-MTL-tiny")

data_args = MultiTaskDataArguments(
    data_dir='/hub/CA-MTL/data',
    tasks=['D0', 'D1', 'LOC', 'MANC', 'SIGNT'],
    overwrite_cache = True,
    task_data_folders=[
        'D0/2021_04_08', 'D1/2021_04_08', 'MANC/2021_04_08', 
        'LOC/2021_04_08', 'SIGNT/2021_04_08'])

training_args = MultiTaskTrainingArguments(
    output_dir='/hub/CA-MTL/mock_models', 
    overwrite_output_dir=False, 
    do_train=True, 
    do_eval=True, 
    do_predict=True,
    evaluate_during_training=True,
    per_device_train_batch_size=32, 
    per_device_eval_batch_size=32, 
    per_gpu_train_batch_size=None, 
    per_gpu_eval_batch_size=None, 
    gradient_accumulation_steps=1, 
    learning_rate=5e-05, 
    weight_decay=0.0,
    adam_epsilon=1e-08,
    max_grad_norm=1.0,
    num_train_epochs=7.0, 
    max_steps=-1, 
    warmup_steps=0, 
    logging_dir=None, 
    logging_first_step=False,
    logging_steps=500,
    save_steps=1500,
    save_total_limit=1, 
    no_cuda=False, 
    seed=43, 
    fp16=False,
    fp16_opt_level='O1',
    local_rank=-1, 
    tpu_num_cores=None, 
    tpu_metrics_debug=False, 
    use_mt_uncertainty=False, 
    uniform_mt_sampling=False, 
    percent_of_max_data_size=1.0)

In [None]:
config = BertConfig.from_pretrained(CaMtl.get_base_model(model_args.model_name_or_path))

model = CaMtl.from_pretrained(
    CaMtl.get_base_model(model_args.model_name_or_path),
    model_args,
    data_args,
    config=config)

logger.info(model)

# load the tokenizer that was used when the model was trained
tokenizer = BertTokenizer.from_pretrained(
    CaMtl.get_base_model(model_args.model_name_or_path),
)

In [None]:
train_dataset = MultiTaskDataset(data_args, tokenizer, limit_length=50)

# Inspect a single example (index 200). Stop as soon as it has been printed
# instead of walking through the rest of the dataset.
for i, batch in enumerate(train_dataset):
    if i == 200:
        print(batch)
        break

In [None]:
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from transformers import (
    DefaultDataCollator,
)
def get_train_dataloader(train_dataset, batch_size=32):
    """Wrap ``train_dataset`` in a DataLoader that samples in random order.

    Args:
        train_dataset: dataset to iterate over, sampled via ``RandomSampler``.
        batch_size: examples per batch. Defaults to 32, the value previously
            hard-coded here (same as ``per_device_train_batch_size`` above).

    Returns:
        A ``DataLoader`` whose batches are collated with
        ``DefaultDataCollator().collate_batch``.
    """
    sampler = RandomSampler(train_dataset)

    return DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=batch_size,
        collate_fn=DefaultDataCollator().collate_batch,
    )

train_dataloader = get_train_dataloader(train_dataset)
# Print only the second batch, then stop instead of draining the whole loader.
for i, batch in enumerate(train_dataloader):
    if i == 1:
        print(batch)
        break

In [None]:
# Identity mapping: the task id at position i of data_args.tasks -> index i.
task_id_2_task_idx = {idx: idx for idx in range(len(data_args.tasks))}
def _create_task_type(task_id):
    task_type = task_id.clone()
    unique_task_ids = torch.unique(task_type)
    unique_task_ids_list = (
        unique_task_ids.cpu().numpy()
        if unique_task_ids.is_cuda
        else unique_task_ids.numpy()
    )
    for unique_task_id in unique_task_ids_list:
        task_type[task_type == unique_task_id] = task_id_2_task_idx[
            unique_task_id
        ]
    return task_type

In [None]:
# Use the built-in next(): `.next()` is a Python-2-style alias that recent
# PyTorch DataLoader iterators no longer provide.
batch = next(iter(train_dataloader))
batch

In [None]:
task_type = _create_task_type(batch['task_id'])

# Use the same `hidden_size` for the embedding, its projection, and the CBDA
# module built below (which is constructed with `hidden_size`). Previously the
# projection used `config.hidden_size`, which only worked while both happened
# to be equal.
task_type_embeddings = nn.Embedding(len(data_args.tasks), hidden_size)
task_embedding = task_type_embeddings(task_type)

task_transformation = nn.Linear(hidden_size, hidden_size)
task_embedding = task_transformation(task_embedding)

In [None]:
task_embedding.shape

In [None]:
random_weight_matrix = nn.Parameter(
    torch.zeros(
        [max_seq_length, math.ceil(max_seq_length/num_blocks)]
    ),
    requires_grad=True,
)

In [None]:
random_weight_matrix.shape

In [None]:
cond_block_diag_attn = CBDA(
    hidden_size, math.ceil(max_seq_length/num_blocks), num_blocks
)  # d x L/N

In [None]:
attention_scores = cond_block_diag_attn(
    x_cond=task_embedding,
    x_to_film=random_weight_matrix,
)

In [None]:
attention_scores.shape

In [None]:
attention_scores

# END CBDA

In [None]:
# %%timeit -n 1 -r 10
# setup_logging(training_args)

# set_seed(training_args.seed)
# logger.info(training_args)

# model_metadata = json.load(open(f"{model_args.model_name_or_path}/metadata.json", 'r'))

# # update arguments with condition at model training
# data_args.max_seq_length = model_metadata['max_seq_length']
# num_tasks = len(model_metadata['label_set'])
# data_args.task_data_folders = data_args.task_data_folders*num_tasks
# data_args.tasks = model_metadata['tasks']
# model_args.encoder_type = model_metadata['model_name_or_path']

# config = BertConfig.from_pretrained(model_args.model_name_or_path)

# model = CaMtl.from_pretrained(
#     model_args.model_name_or_path,
#     model_args,
#     data_args,
#     config=config)

# logger.info(model)

# # load the tokenizer that was used when the model was trained
# tokenizer = BertTokenizer.from_pretrained(
#     CaMtl.get_base_model(model_metadata['model_name_or_path']),
# )

# logger.info("Training tasks: %s", ", ".join([t for t in data_args.tasks]))

# trainer = MultiTaskTrainer(
#     tokenizer,
#     data_args,
#     model=model,
#     args=training_args,
#     train_dataset=None,
#     eval_datasets=None,
#     test_datasets=create_eval_datasets(Split.test, data_args, tokenizer, model_metadata)
#     if training_args.do_predict
#     else None,
# )

# scoring_model = model_args.model_name_or_path.split("/")[-1]
# if training_args.do_predict:
#     trainer.predict(scoring_model = scoring_model)
# 1min 1s ± 256 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)