In [None]:
from fastai.basics import *
from fastcore.all import *
from reformer_fastai.all import *
import time

In [None]:
class ConfigBase:
    """Empty placeholder base class for configuration objects; defines no members."""

In [None]:
class SyntConfig:
    """Keyword-argument configuration for building an `LSHLM` on the synthetic task.

    Starts from the defaults in `_d`; keyword arguments passed to the constructor
    override matching keys, and unknown keys are reported with a message and ignored.
    """
    # Class-level defaults. Treated as read-only: every instance works on its
    # own copy (made in __init__) so overrides never leak between configs.
    _d = {
        'vocab_sz':128,
        'd_model':256,
        'n_layers':1,
        'n_heads':4,
        'd_ff':256,
        'attn_dropout':0.0,
        'ff_dropout':0.0,
        'emb_dropout':0.0,
        'bucket_size':64,
        'max_seq_len':1024,
        'random_state':123,
        'use_lsh':True,
        'n_hashes':4
    }
    
    @delegates(LSHLM)
    def __init__(self, **kwargs):
        # Copy the defaults into an instance attribute. Writing into the class
        # dict would change the defaults of every other SyntConfig, and the
        # values would not be in the instance __dict__, so `save` would pickle
        # no parameters at all.
        self._d = dict(self._d)
        for k, v in kwargs.items():
            if k in self._d: self._d[k] = v
            else: print(f'Parameter {k} is not accepted by LSHLM. Skipped')
    
    def __repr__(self):
        "Header line plus one `name value` row per parameter."
        s = "LSHLM config \n" + '-'*20
        s += ''.join(f'\n{k:16}{v}' for k,v in self._d.items())
        return s
    
    def dict(self):
        "Return the parameter dict itself (not a copy); mutating it changes this config."
        return self._d
    
    def save(self, fn, add_tstmp=False):
        """Pickle this config to `fn`.

        If `add_tstmp`, a UTC timestamp suffix `_DD_MM_YYYY_HH:MM` is appended to the
        file name. The name is built with an f-string, so a path-like `fn` works too.
        """
        if add_tstmp:
            # NOTE(review): ':' is not a legal file-name character on Windows.
            tstmp = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            fn = f'{fn}{tstmp}'
        save_pickle(fn, self)
    
    @classmethod
    def from_file(cls, fn):
        """Load a config previously written by `save` and return it unchanged.

        No type check is performed on the unpickled object. Only load trusted
        files: unpickling can execute arbitrary code.
        """
        return load_pickle(fn)

In [None]:
# Build a config with all default parameters and display its repr.
config = SyntConfig()
config

LSHLM config 
--------------------
vocab_sz        128
d_model         256
n_layers        1
n_heads         4
d_ff            256
attn_dropout    0.0
ff_dropout      0.0
emb_dropout     0.0
bucket_size     64
max_seq_len     1024
random_state    123
use_lsh         True
n_hashes        4

In [None]:
# Plain-dict view of the parameters (this is what `LSHLM.from_config` unpacks as kwargs).
config.dict()

{'vocab_sz': 128,
 'd_model': 256,
 'n_layers': 1,
 'n_heads': 4,
 'd_ff': 256,
 'attn_dropout': 0.0,
 'ff_dropout': 0.0,
 'emb_dropout': 0.0,
 'bucket_size': 64,
 'max_seq_len': 1024,
 'random_state': 123,
 'use_lsh': True,
 'n_hashes': 4}

In [None]:
# Pickle the config to the relative path 'test' (no timestamp suffix).
config.save('test')

In [None]:
# Reload the pickled config and check that the save/load round trip kept every parameter.
config2 = SyntConfig.from_file('test')
assert config2.dict() == config.dict(), 'config changed after save/load round trip'
config2

LSHLM config 
--------------------
vocab_sz        128
d_model         256
n_layers        1
n_heads         4
d_ff            256
attn_dropout    0.0
ff_dropout      0.0
emb_dropout     0.0
bucket_size     64
max_seq_len     1024
random_state    123
use_lsh         True
n_hashes        4

In [None]:
# Default LSHLM hyper-parameters for the synthetic task, as a plain dict.
# NOTE(review): these values duplicate `SyntConfig._d`; keep the two in sync.
SyntheticConfig = dict(
    vocab_sz=128,
    d_model=256,
    n_layers=1,
    n_heads=4,
    d_ff=256,
    attn_dropout=0.0,
    ff_dropout=0.0,
    emb_dropout=0.0,
    bucket_size=64,
    max_seq_len=1024,
    random_state=123,
    use_lsh=True,
    n_hashes=4,
)

In [None]:
@patch(cls_method=True)
def from_config(cls:LSHLM, config):
    """Build an instance of `cls` from a configuration.

    `config` may be either an object exposing `.dict()` (e.g. a `SyntConfig`) or a
    plain mapping of keyword arguments (e.g. the `SyntheticConfig` dict); its items
    are passed to the constructor as keyword arguments.
    The `cls:LSHLM` annotation tells fastcore's `patch` which class to attach to.
    """
    from collections.abc import Mapping
    # Mappings are used directly; anything else must provide a `.dict()` method.
    params = config if isinstance(config, Mapping) else config.dict()
    return cls(**params)

In [None]:
# Instantiate the model from the default config; the cell output is the module tree.
LSHLM.from_config(config)

LSHLM(
  (emb): TransformerEmbedding(
    (emb): Embedding(128, 256)
    (dropout): Dropout(p=0.0, inplace=False)
    (pos_enc): AbsolutePositionalEmbedding(
      (emb): Embedding(1024, 256)
    )
  )
  (encoder): LSHEncoder(
    (layers): ModuleList(
      (0): LSHEncoderBlock(
        (attn): PostNorm(
          (sublayer): Residual(
            (sublayer): ReformerAttentionV2(
              (in_proj): SharedQKAttnInProj(
                (to_qk): Linear(in_features=256, out_features=256, bias=False)
                (to_v): Linear(in_features=256, out_features=256, bias=False)
              )
              (lsh_attn): LSHAttention(
                (dropout): Dropout(p=0.0, inplace=False)
                (dropout_for_hash): Dropout(p=0.0, inplace=False)
              )
              (full_attn): ScaledDotProdAttention(
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (out_proj): Linear(in_features=256, out_features=256, bias=False)
              (