# Speech Emotion Recognition with `CNN + Transformer`

- The major architecture of this model is motivated by this paper [Self-attention for Speech Emotion Recognition](https://publications.idiap.ch/attachments/papers/2019/Tarantino_INTERSPEECH_2019.pdf)

- This is my individual notebook to understand transformer and speech emotion recognition task.

In [1]:
from google.colab import drive
drive.mount('/gdrive')
from pathlib import Path
drive_path = Path('/gdrive/Shareddrives/Dion-Account/2122WS/4-dl4slp/coding-project/ser/')

Mounted at /gdrive


In [2]:
from exp.nb_08 import *

In [3]:
drive_data_path = drive_path/'data/v1'
(drive_data_path).ls()

[PosixPath('/gdrive/Shareddrives/Dion-Account/2122WS/4-dl4slp/coding-project/ser/data/v1/train'),
 PosixPath('/gdrive/Shareddrives/Dion-Account/2122WS/4-dl4slp/coding-project/ser/data/v1/dev'),
 PosixPath('/gdrive/Shareddrives/Dion-Account/2122WS/4-dl4slp/coding-project/ser/data/v1/ser.tar-v1.gz')]

In [4]:
class ItemList(ListContainer):
    def __init__(self, items, path='.', tfms=None):
        super().__init__(items)
        self.path,self.tfms = Path(path),tfms

    def __repr__(self): return f'{super().__repr__()}\nPath: {self.path}'

    def new(self, items, cls=None):
        if cls is None: cls=self.__class__
        return cls(items, self.path, tfms=self.tfms)

    def  get(self, i): return i
    def _get(self, i): return compose(self.get(i), self.tfms)

    def __getitem__(self, idx):
        res = super().__getitem__(idx)
        if isinstance(res,list): return [self._get(o) for o in res]
        return self._get(res)


In [5]:
class AudioList(ItemList):
    @classmethod
    def from_files(cls, path, extensions = None, recurse=True, include=None, **kwargs):
        return cls(get_files(path, extensions, recurse=recurse, include=include), path, **kwargs)
    
    def get(self, fn):
        return torch.load(fn)

class Reshape():
    "transpose to [n_features, n_frames]"
    _order=12
    def __call__(self, item):
        w, h = item.shape
        return item.view(h, w)

class DummyChannel():
    "insert pseudo axis in height [n_features, 1, n_frames]"
    _order = 30
    def __call__(self, item):
        return item.unsqueeze(1)

def re_labeler(fn, pat, subcl='act'):
    assert subcl in ['act', 'val', 'all']
    if subcl=='all': return tuple(int(i) for i in re.findall(pat, str(fn)))
    else:
        return re.findall(pat, str(fn))[0] if pat == 'act' else re.findall(pat, str(fn))[1]


In [6]:
import random
def random_splitter(fn, p_valid): return random.random() < p_valid

In [7]:
class CategoryProcessor(Processor):
    "convert string to float, which was retrieved from the file name"
    def __init__(self): self.vocab=None

    def __call__(self, items):
        #The vocab is defined on the first use.
        if self.vocab is None:
            # set_trace()
            self.vocab = uniqueify(items)
            # self.otoi  = {v:k for k,v in enumerate(self.vocab)}
        return [torch.tensor(o).float() for o in items]
    def proc1(self, item):  return self.otoi[item]

    def deprocess(self, idxs):
        assert self.vocab is not None
        return [self.deproc1(idx) for idx in idxs]
    def deproc1(self, idx): return self.vocab[idx]


In [8]:
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
        #    and '_0_0' in f
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=False, include=None):
    path = Path(path)
    extensions = setify(extensions)
    extensions = {e.lower() for e in extensions}
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
            if include is not None and i==0: d[:] = [o for o in d if o in include]
            else:                            d[:] = [o for o in d if not o.startswith('.')]
            res += _get_files(p, f, extensions)
        return res
    else:
        f = [o.name for o in os.scandir(path) if o.is_file()]
        return _get_files(path, f, extensions)

In [9]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs, **kwargs))

In [10]:
train_path = drive_data_path/'train'
tfms = [Reshape(), DummyChannel()]
al=AudioList.from_files(train_path, tfms=tfms)

In [11]:
al[0].shape, al.items.__len__()

(torch.Size([26, 1, 614]), 7800)

In [12]:
def databunchify(sd, bs, c_in=None, c_out=None, **kwargs):
    dls = get_dls(sd.train, sd.valid, bs, **kwargs)
    return DataBunch(*dls, c_in=c_in, c_out=c_out)

SplitData.to_databunch = databunchify

In [13]:
label_pat = r'_(\d+)'
emotion_labeler = partial(re_labeler, pat=label_pat, subcl='all')
sd = SplitData.split_by_func(al, partial(random_splitter, p_valid=0.00))
ll = label_by_func(sd, emotion_labeler, proc_y=CategoryProcessor())

In [14]:
ll.train.y

ItemList (7800 items)
[tensor([1., 1.]), tensor([0., 0.]), tensor([0., 0.]), tensor([1., 0.]), tensor([1., 1.]), tensor([1., 0.]), tensor([1., 1.]), tensor([1., 0.]), tensor([1., 0.]), tensor([1., 1.])...]
Path: /gdrive/Shareddrives/Dion-Account/2122WS/4-dl4slp/coding-project/ser/data/v1/train

In [15]:
bs=1

In [16]:
c_in = ll.train[0][0].shape[0]
c_out = 2
data = ll.to_databunch(bs,c_in=c_in,c_out=c_out)

In [17]:
data.train_dl.batch_size, data.valid_dl.batch_size

(1, 1)

In [18]:
data.c_in, data.c_out

(26, 2)

In [22]:
xb, yb = next(iter(data.train_dl))

In [23]:
xb.shape

torch.Size([1, 26, 1, 153])

In [24]:
yb

tensor([[0., 1.]])

## Investigate label distribution

![](https://www.researchgate.net/profile/Lung-Hao-Lee-2/publication/304124018/figure/fig1/AS:374864755085312@1466386130906/Two-dimensional-valence-arousal-space.png)

In [40]:
from collections import Counter
Counter(str(i.tolist()) for i in data.train_ds.y)

Counter({'[0.0, 0.0]': 1240,
         '[0.0, 1.0]': 1023,
         '[1.0, 0.0]': 3194,
         '[1.0, 1.0]': 2343})

## Model

In [None]:
#@title
class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        varns = (self.sqrs/self.count).sub_(means*means)
        if bool(self.batch < 20): varns.clamp_min_(0.01)
        self.factor = self.mults / (varns+self.eps).sqrt()
        self.offset = self.adds - means*self.factor

    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset

In [41]:
class Optimizer():
    def __init__(self, params, steppers, **defaults):
        self.steppers = listify(steppers)
        maybe_update(self.steppers, defaults, get_defaults)
        # might be a generator
        self.param_groups = list(params)
        # ensure params is a list of lists
        if not isinstance(self.param_groups[0], list): self.param_groups = [self.param_groups]
        self.hypers = [{**defaults} for p in self.param_groups]

    def grad_params(self):
        return [(p,hyper) for pg,hyper in zip(self.param_groups,self.hypers)
            for p in pg if p.grad is not None]

    def zero_grad(self):
        for p,hyper in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        for p,hyper in self.grad_params(): compose(p, self.steppers, **hyper)
def maybe_update(os, dest, f):
    for o in os:
        for k,v in f(o).items():
            if k not in dest: dest[k] = v

class Stat():
    _defaults = {}
    def init_state(self, p): raise NotImplementedError
    def update(self, p, state, **kwargs): raise NotImplementedError

class AverageGrad(Stat):
    _defaults = dict(mom=0.9)

    def __init__(self, dampening:bool=False): self.dampening=dampening
    def init_state(self, p): return {'grad_avg': torch.zeros_like(p.grad.data)}
    def update(self, p, state, mom, **kwargs):
        state['mom_damp'] = 1-mom if self.dampening else 1.
        state['grad_avg'].mul_(mom).add_(state['mom_damp'], p.grad.data)
        return state

class AverageSqrGrad(Stat):
    _defaults = dict(sqr_mom=0.99)

    def __init__(self, dampening:bool=True): self.dampening=dampening
    def init_state(self, p): return {'sqr_avg': torch.zeros_like(p.grad.data)}
    def update(self, p, state, sqr_mom, **kwargs):
        state['sqr_damp'] = 1-sqr_mom if self.dampening else 1.
        state['sqr_avg'].mul_(sqr_mom).addcmul_(state['sqr_damp'], p.grad.data, p.grad.data)
        return state

class StepCount(Stat):
    def init_state(self, p): return {'step': 0}
    def update(self, p, state, **kwargs):
        state['step'] += 1
        return state

def debias(mom, damp, step): return damp * (1 - mom**step) / (1-mom)

class StatefulOptimizer(Optimizer):
    def __init__(self, params, steppers, stats=None, **defaults):
        self.stats = listify(stats)
        maybe_update(self.stats, defaults, get_defaults)
        super().__init__(params, steppers, **defaults)
        self.state = {}

    def step(self):
        for p,hyper in self.grad_params():
            if p not in self.state:
                #Create a state for p and call all the statistics to initialize it.
                self.state[p] = {}
                maybe_update(self.stats, self.state[p], lambda o: o.init_state(p))
            state = self.state[p]
            for stat in self.stats: state = stat.update(p, state, **hyper)
            compose(p, self.steppers, **state, **hyper)
            self.state[p] = state

def adam_step(p, lr, mom, mom_damp, step, sqr_mom, sqr_damp, grad_avg, sqr_avg, eps, **kwargs):
    debias1 = debias(mom,     mom_damp, step)
    debias2 = debias(sqr_mom, sqr_damp, step)
    p.data.addcdiv_(-lr / debias1, grad_avg, (sqr_avg/debias2).sqrt() + eps)
    return p
adam_step._defaults = dict(eps=1e-5)

def weight_decay(p, lr, wd, **kwargs):
    p.data.mul_(1 - lr*wd)
    return p
weight_decay._defaults = dict(wd=0.)

def adam_opt(xtra_step=None, **kwargs):
    return partial(StatefulOptimizer, steppers=[adam_step,weight_decay]+listify(xtra_step),
                   stats=[AverageGrad(dampening=True), AverageSqrGrad(), StepCount()], **kwargs)
    
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-1, )    

In [42]:
def get_defaults(d): return getattr(d,'_defaults',{})

In [43]:
class Callback():
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k): return getattr(self.run, k)

    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

    def __call__(self, cb_name):
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0

    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter   += 1

    def begin_epoch(self):
        self.run.n_epochs=self.epoch
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False

class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        self.in_train = False
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        try:
            self.xb,self.yb = xb,yb
            self('begin_batch')
            # print(self.xb.shape)
            self.pred = self.model(self.xb)
            self('after_pred')
            # print(self.pred.shape, self.yb.shape)
            self.loss = self.loss_func(self.pred, self.yb)
            self('after_loss')
            if not self.in_train: return
            self.loss.backward()
            self('after_backward')
            self.opt.step()
            self('after_step')
            self.opt.zero_grad()
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')

    def all_batches(self, dl):
        self.iters = len(dl)
        try:
            for xb,yb in dl: self.one_batch(xb, yb)
        except CancelEpochException: self('after_cancel_epoch')

    def fit(self, epochs, learn):
        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            self('begin_fit')
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                self('after_epoch')

        except CancelTrainException: self('after_cancel_train')
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        res = False
        for cb in sorted(self.cbs, key=lambda x: x._order):
            # print(cb_name, cb)
            res = cb(cb_name) and res
        return res

class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)
        # print stats based on thousand iteration
        if (self.run.n_iter % 1000) ==0:
            print(f"iteration: {self.run.n_iter}, accuracy:  {self.train_stats}")
    def after_epoch(self):
        print(f"epoch {self.run.n_epoch} done!")

class Recorder(Callback):
    def begin_fit(self): self.lrs,self.losses = [],[]
    def after_batch(self):
        if not self.in_train: return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.loss.detach().cpu())

    def plot_lr  (self): plt.plot(self.lrs)
    def plot_loss(self): plt.plot(self.losses)

    def plot(self, skip_last=0):
        losses = [o.item() for o in self.losses]
        n = len(losses)-skip_last
        plt.xscale('log')
        plt.plot(self.lrs[:n], losses[:n])

class ParamScheduler(Callback):
    _order=1
    def __init__(self, pname, sched_funcs):
        self.pname,self.sched_funcs = pname,listify(sched_funcs)

    def begin_batch(self):
        if not self.in_train: return
        fs = self.sched_funcs
        if len(fs)==1: fs = fs*len(self.opt.param_groups)
        pos = self.n_epochs/self.epochs
        for f,h in zip(fs,self.opt.hypers): h[self.pname] = f(pos)

class LR_Find(Callback):
    _order=1
    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter,self.min_lr,self.max_lr = max_iter,min_lr,max_lr
        self.best_loss = 1e9

    def begin_batch(self):
        if not self.in_train: return
        pos = self.n_iter/self.max_iter
        lr = self.min_lr * (self.max_lr/self.min_lr) ** pos
        for pg in self.opt.hypers: pg['lr'] = lr

    def after_step(self):
        if self.n_iter>=self.max_iter or self.loss>self.best_loss*10:
            raise CancelTrainException()
        if self.loss < self.best_loss: self.best_loss = self.loss

def get_runner(model, data, lr=0.6, cbs=None, opt_func=None, loss_func = F.cross_entropy):
    if opt_func is None: opt_func = optim.SGD
    opt = opt_func(model.parameters(), lr=lr)
    learn = Learner(model, opt, loss_func, data)
    return learn, Runner(cb_funcs=listify(cbs))

In [44]:
class Flatten(nn.Module):
    "remove last (pooled) dimension and reshape tensor tp (seq_len x d_model)"
    def __init__(self): super().__init__()
    def forward(self, x):
        return x.squeeze(-1).permute(1,0)

- Here, the Author mentioned that they use '6-block' for raw data. but in our case, as it is log mel, it's not 100% raw data.

# TO Fix the length, cnn 1d

In [45]:
class CNN1d(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential(
            nn.Conv1d(1,8, kernel_size=10, stride=1), GeneralRelu(),
            nn.Conv1d(8,16, kernel_size=10), GeneralRelu(),
            nn.Conv1d(16,32, kernel_size=5), GeneralRelu(),
            nn.Conv1d(32,64, kernel_size=5), GeneralRelu(),
            nn.Conv1d(64,128, kernel_size=3), GeneralRelu(),
            nn.AdaptiveMaxPool1d(1), #26, 128, 1
            Flatten()) #128, 26
    def forward(self, x):
        return self.module(x)

In [55]:
# dummy_filter = CNN1d()
# n_frames = [xb.shape[-1] for xb, yb in iter(data.train_dl)]
# len_dist = [sd.train[idx].shape[0] for idx, item in enumerate(sd.train)]

In [None]:
Counter(n_frames)

In [56]:
xb.shape, CNN1d()(xb.squeeze(0)).shape

(torch.Size([1, 26, 1, 67]), torch.Size([128, 26]))

- 128 : Time frame, which corresponds to `seq_len`
- 26 : d_model, which corresponds to `d_model` 

# Let us introduce transformer
- embedding is composed of two parts
    - positional encoding - equal to original one
    - embedding - learn from CNN 1d model

In [57]:
from torch import Tensor

In [None]:
!pip install -qq ipdb
from ipdb import set_trace

In [58]:
class PositionalEncoding(nn.Module):
    "Encode the position with a sinusoid."
    def __init__(self, d:int):
        super().__init__()
        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))

    def forward(self, pos:Tensor):
        inp = torch.ger(pos, self.freq)
        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)
        return enc

class TransformerEmbedding(nn.Module):
    "Embedding from CNN + positional encoding + dropout"
    def __init__(self, emb_sz:int, inp_p:float=0.):
        super().__init__()
        self.emb_sz = emb_sz
        # (seq_len x d_model)
        self.embed = CNN1d()
        self.pos_enc = PositionalEncoding(emb_sz)
        self.drop = nn.Dropout(inp_p)

    def forward(self, inp):
        # Need to insert the batch dimension as you removed it when it is used for conv1d
        # 1,       26,       1,       75 -> 128    , 26
        # bs, d_model, seq_len, n_frames -> seq_len, d_model
        inp = self.embed(inp)
        pos = torch.arange(0, inp.size(0), device=inp.device).float()     
        # reconstruct pseudo batch dimension (1 x 128 x 26)
        return self.drop(inp * math.sqrt(self.emb_sz) + self.pos_enc(pos)).unsqueeze(0)


In [59]:
dummy_inp = torch.randn(128, 26)
dummy_pos_emb = PositionalEncoding(26)
dummy_pos = torch.arange(0, dummy_inp.size(0)).float()
dummy_pos_emb(dummy_pos)

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.8415,  0.4727,  0.2401,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.9093,  0.8331,  0.4661,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [-0.6160, -0.9590, -0.8958,  ...,  0.9945,  0.9987,  0.9997],
        [ 0.3300, -0.7110, -0.7628,  ...,  0.9944,  0.9986,  0.9997],
        [ 0.9726, -0.2941, -0.5853,  ...,  0.9943,  0.9986,  0.9997]])

- `emb_sz`: 26
- Transformer-embedding: 1 x 128 x 26 (bs, seq_len, d_model)



In [60]:
def feed_forward(d_model:int, d_ff:int, ff_p:float=0., double_drop:bool=True):
    layers = [nn.Linear(d_model, d_ff), nn.ReLU()]
    if double_drop: layers.append(nn.Dropout(ff_p))
    return SequentialEx(*layers, nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model))

class MergeLayer(nn.Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False):
        super().__init__()
        self.dense=dense
    def forward(self, x):
        return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)

class SequentialEx(nn.Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig, nres.orig = None, None
            res = nres
        return res

    def __getitem__(self,i): return self.layers[i]
    def append(self,l):      return self.layers.append(l)
    def extend(self,l):      return self.layers.extend(l)
    def insert(self,i,l):    return self.layers.insert(i,l)

class MultiHeadAttention(nn.Module):
    "MutiHeadAttention."

    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True):
        super().__init__()
        # d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
        self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, q:Tensor, k:Tensor, v:Tensor, mask:Tensor=None):
        return self.ln(q + self.drop_res(self.out(self._apply_attention(q, k, v, mask=mask))))

    def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor, mask:Tensor=None):
        bs,seq_len = q.size(0),q.size(1)
        wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        attn_score = torch.matmul(wq, wk)
        if self.scale: attn_score = attn_score.div_(self.d_head ** 0.5)
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().contiguous().view(bs, seq_len, -1)

class EncoderBlock(nn.Module):
    "Encoder block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, double_drop:bool=True):
        super().__init__()
        self.mha = MultiHeadAttention(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.ff  = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)

    def forward(self, x:Tensor, mask:Tensor=None): return self.ff(self.mha(x, x, x, mask=mask))

In [61]:
from torch import nn

In [62]:
class OutLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Tarantino (2019) applied 2d-convolution here.
        self.conv = nn.Conv2d(1, 5, 20)
        self.pool = nn.AdaptiveMaxPool1d(1) # (bs, 5, 109, 1)
        self.lin_out = nn.Linear(5*109, 2) # I might need to make this two.         
    
    def forward(self, inp):
        conv_out = self.conv(inp.unsqueeze(1)) # (bs, 128, 26) -> (bs, 1, 128, 26) -> (bs, 5, 109, 7), insert dummy channel, channel 1 to 5, kerner size = 20
        pool_out = self.pool(conv_out.squeeze(0)) # (bs, 5, 109, 7) -> (5, 109, 7) -> (5, 109, 1)
        out = pool_out.squeeze(-1).unsqueeze(0) # (5, 109, 1) -> (5, 109) -> (1, 5, 109)
        return self.lin_out(out.view(1, -1)) # (1, 5*109) -> (1, 2)    

In [63]:
class CNNTransformer(nn.Module):
    "CNN Transformer model"
    
    def __init__(self, out_vsz:int=4, n_layers:int=1, n_heads:int=3, d_model:int=26, d_head:int=9, 
                 d_inner:int=3 * 9,
                 inp_p:float=0.1, resid_p:float=0.1, attn_p:float=0.1,
                 ff_p:float=0.1,
                 bias:bool=True, 
                 scale:bool=True, double_drop:bool=True):
        super().__init__()
        self.enc_emb = TransformerEmbedding(d_model, inp_p)
        self.encoder = nn.ModuleList([EncoderBlock(n_heads, d_model, d_head, d_inner, resid_p, attn_p, ff_p, bias, scale, double_drop) for _ in range(0, n_layers)])
        self.out_layer = OutLayer()
        
    def forward(self, inp):
        # torch.Size([1, 26, 1, 75]) => torch.Size([1, 128, 26])
        enc = self.enc_emb(inp.squeeze(0))
        # torch.Size([1, 128, 26])
        for enc_block in self.encoder: enc = enc_block(enc)
        
        return self.out_layer(enc)

In [64]:
m1 = CNNTransformer()

In [73]:
z1 = xb.squeeze(0); z1.shape

torch.Size([26, 1, 67])

In [74]:
z2 = m1.enc_emb(z1); z2. # embedding. i.e., cnn feature extraction -> add with PE

torch.Size([1, 128, 26])

In [75]:
# encoder block
for enc_block in m1.encoder:
    print("n_block")
    z3 = enc_block(z2)

n_block


In [76]:
z3.shape

torch.Size([1, 128, 26])

In [77]:
# out layer. projection (dim * n_frames -> n_out)
m1.out_layer(z3).shape

torch.Size([1, 2])

In [78]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2,reduction='mean'):
        super(FocalLoss, self).__init__(weight,reduction=reduction)
        self.gamma = gamma
        self.weight = weight #weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):

        ce_loss = F.cross_entropy(input, target,reduction=self.reduction,weight=self.weight) 
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        return focal_loss

In [79]:
cbfs = [partial(AvgStatsCallback,accuracy),
        CudaCallback]

In [80]:
def normalize_chan(x, mean, std):
    return (x-mean[...,None,None]) / std[...,None,None]

_m = tensor([4.6458, 5.4130, 6.5964, 6.8884, 6.8241, 7.1415, 7.2152, 7.1492, 6.6828,
        6.7434, 6.8610, 6.9456, 7.2149, 7.3144, 7.3993, 7.1714, 7.3913, 7.5860,
        7.3430, 7.3854, 7.4977, 7.4650, 7.3808, 7.0497, 6.7768, 6.4319])
_s = tensor([1.6839, 2.3302, 2.6372, 2.6552, 2.7861, 2.7495, 2.7446, 2.6085, 2.5018,
        2.3586, 2.3367, 2.3888, 2.4452, 2.4944, 2.4172, 2.4300, 2.3737, 2.4037,
        2.4891, 2.4774, 2.4399, 2.3689, 2.2110, 2.2310, 2.2802, 2.2686])
norm_ser = partial(normalize_chan, mean=_m.cuda(), std=_s.cuda())

class BatchTransformXCallback(Callback):
    _order=2
    def __init__(self, tfm): self.tfm = tfm
    def begin_batch(self):
        self.run.xb = self.tfm(self.xb).squeeze(0)
        # self.run.yb = self.yb.squeeze(0)

In [81]:
cbfs.append(partial(BatchTransformXCallback, norm_ser))

In [None]:
print(inspect.getsource(cos_1cycle_anneal))

def cos_1cycle_anneal(start, high, end):
    return [sched_cos(start, high), sched_cos(high, end)]



In [82]:
sched = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.1,0.3,0.05))

In [None]:
import inspect; inspect.getsource(init_cnn)

'def init_cnn(m, uniform=False):\n    f = init.kaiming_uniform_ if uniform else init.kaiming_normal_\n    init_cnn_(m, f)\n'

In [83]:
def get_learn_run(model, data, lr, cbs=None, opt_func=None, **kwargs):
    init_cnn(model)
    return get_runner(model, data, lr=lr, cbs=cbs, opt_func=opt_func)

In [84]:
model = CNNTransformer()

In [85]:
learn,run = get_learn_run(model, data, 0.2, cbs=cbfs+[
    partial(ParamScheduler, 'lr', sched)], opt_func=opt_func, loss_func = FocalLoss()
)

In [86]:
run.fit(10, learn)

  self.tot_mets[i] += torch.tensor(m(run.pred, run.yb)) * bn
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)


iteration: 0, accuracy:  train: [1.3878704951359675, tensor(0., device='cuda:0')]
iteration: 1000, accuracy:  train: [108.02894221163452, tensor(0.4680, device='cuda:0')]
iteration: 2000, accuracy:  train: [95.20179333410218, tensor(0.4570, device='cuda:0')]
iteration: 3000, accuracy:  train: [93.45566862327942, tensor(0.4503, device='cuda:0')]
iteration: 4000, accuracy:  train: [98.27458519985389, tensor(0.4538, device='cuda:0')]
iteration: 5000, accuracy:  train: [85.58152215710705, tensor(0.4543, device='cuda:0')]
iteration: 6000, accuracy:  train: [154.42100675528437, tensor(0.4563, device='cuda:0')]
iteration: 7000, accuracy:  train: [181.03646731785568, tensor(0.4582, device='cuda:0')]


AttributeError: ignored

In [None]:
def model_summary(run, learn, data, find_all=False):
    xb,yb = get_batch(data.valid_dl, run)
    device = next(learn.model.parameters()).device#Model may not be on the GPU yet
    xb,yb = xb.to(device),yb.to(device)
    mods = learn.model.children()
    f = lambda hook,mod,inp,out: print(f"{mod}\n{out.shape}\n")
    with Hooks(mods, f) as hooks: learn.model(xb)

In [None]:
run.cbs

[<__main__.TrainEvalCallback at 0x7f4cc4c13f90>,
 <__main__.AvgStatsCallback at 0x7f4cc4c13450>,
 <exp.nb_06.CudaCallback at 0x7f4cc7ae7a90>,
 <__main__.BatchTransformXCallback at 0x7f4cc7ae7210>,
 <__main__.ParamScheduler at 0x7f4cc4c15c90>]