<a href="https://colab.research.google.com/github/BeArnab96/beingarnab.github.io/blob/gh-pages/Block_Recurrent_TimesNet_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
     44.6/44.6 kB 1.2 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.7.0


In [2]:
import numpy as np
import pickle
import pandas as pd
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, random
import matplotlib.pyplot as plt
from scipy.ndimage import median_filter
from einops import rearrange, repeat, einsum, reduce
from einops.layers.torch import Rearrange
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader

from glob import glob
import natsort
import random
import re

Conv Blocks

In [3]:
class Inception_Block_V1(nn.Module):
    """Bank of parallel square convolutions whose outputs are averaged.

    Branch ``i`` is a ``Conv2d`` with kernel size ``2 * i + 1`` and padding
    ``i``, so every branch preserves the spatial size of its input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every branch.
        final_shape: accepted for signature parity with the other conv
            blocks; not used here.
        num_kernels: number of parallel branches.
        kernel_size: accepted but not used here.
        patch_size: accepted but not used here.
        init_weight: when true, re-initialise the convolutions with
            Kaiming-normal weights and zero biases.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=7, patch_size=4,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        branches = [
            nn.Conv2d(in_channels, out_channels, kernel_size=int(2 * idx + 1), padding=int(idx))
            for idx in range(self.num_kernels)
        ]
        self.kernels = nn.ModuleList(branches)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every Conv2d."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply every branch to ``x`` and return the element-wise mean of the results."""
        branch_outputs = [self.kernels[idx](x) for idx in range(self.num_kernels)]
        stacked = torch.stack(branch_outputs, dim=-1)
        return stacked.mean(-1)

class Inception_Block_V2(nn.Module):
    """Bank of parallel 1-D (row and column) convolutions whose outputs are averaged.

    For every ``i`` in ``range(num_kernels // 2)`` the block holds one
    ``1 x (2i + 3)`` and one ``(2i + 3) x 1`` convolution, plus a final
    ``1 x 1`` convolution. All branches are padded so the spatial size of the
    input is preserved.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every branch.
        final_shape: accepted for signature parity with the other conv
            blocks; not used here.
        num_kernels: controls the number of branches
            (``2 * (num_kernels // 2) + 1`` convolutions are created).
        kernel_size: accepted but not used here.
        patch_size: accepted but not used here.
        init_weight: when true, re-initialise the convolutions with
            Kaiming-normal weights and zero biases.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=7, patch_size=4,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        kernels = []
        for i in range(self.num_kernels // 2):
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))
        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(kernels)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every Conv2d."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply every branch to ``x`` and return the element-wise mean of the results."""
        # Iterate over the branches that actually exist. Indexing with
        # range(num_kernels + 1) overruns the list when num_kernels is odd.
        res_list = [kernel(x) for kernel in self.kernels]
        return torch.stack(res_list, dim=-1).mean(-1)

class Residual(nn.Module):
    """Skip-connection wrapper: returns the input plus ``fn`` applied to the input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return x + branch

class ConvNeXT_Block(nn.Module):
    """Patchify -> depthwise conv -> LayerNorm (+ skip) -> 1x1 bottleneck -> adaptive pool.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels after patchifying (and of the output).
        final_shape: target ``(H, W)`` of the adaptive average pool;
            ``(10, 10)`` when ``None``.
        num_kernels: accepted for signature parity; not used here.
        kernel_size: kernel size of the depthwise convolution.
        patch_size: kernel size and stride of the patchifying convolution.
        init_weight: accepted for signature parity; not used here.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=7, patch_size=4,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.final_shape = final_shape if final_shape is not None else (10, 10)

        # Non-overlapping patches: stride equals the kernel size.
        self.patchifying_conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)

        # One filter per channel (groups == channels), spatial size preserved.
        self.depth_conv = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels,
                                    padding='same')
        self.layernorm = nn.LayerNorm(out_channels)

        expanded = int(out_channels * 4)
        self.mixer_bottleneck = nn.Sequential(nn.Conv2d(out_channels, expanded, kernel_size=1),
                                              nn.GELU(),
                                              nn.Conv2d(int(4 * out_channels), out_channels, kernel_size=1))

        self.pool = nn.AdaptiveAvgPool2d(self.final_shape)

    def forward(self, x):
        """Return the pooled block output for a ``(B, C, H, W)`` input."""
        patch_emb = self.patchifying_conv(x)
        # LayerNorm normalises the last axis, so move channels last around it.
        channels_last = self.depth_conv(patch_emb).permute(0, 2, 3, 1)
        normed = self.layernorm(channels_last).permute(0, 3, 1, 2)
        mixed = self.mixer_bottleneck(normed + patch_emb)
        return self.pool(mixed)

class ResNeXT_Block(nn.Module):
    """Residual block: 1x1 conv -> depthwise conv -> 1x1 conv, added to the input.

    Each convolution is followed by GELU and BatchNorm. The last 1x1
    convolution maps back to ``in_channels`` so the skip connection is valid.

    Args:
        in_channels: number of channels of the input (and of the output).
        out_channels: number of channels used inside the block.
        final_shape: accepted for signature parity; not used here.
        num_kernels: accepted for signature parity; not used here.
        kernel_size: kernel size of the depthwise convolution.
        patch_size: accepted for signature parity; not used here.
        init_weight: accepted for signature parity; not used here.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=3, patch_size=4,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Not referenced in forward(); kept so the module's state_dict keys do not change.
        self.bnorm1 = nn.BatchNorm2d(out_channels)

        expand = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same')
        self.conv1 = nn.Sequential(expand, nn.GELU(), nn.BatchNorm2d(out_channels))

        depthwise = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels,
                              padding='same')
        self.conv2 = nn.Sequential(depthwise, nn.GELU(), nn.BatchNorm2d(out_channels))

        project = nn.Conv2d(out_channels, in_channels, kernel_size=1, padding='same')
        self.conv3 = nn.Sequential(project, nn.GELU(), nn.BatchNorm2d(in_channels))
        # self.pool = nn.AdaptiveAvgPool2d(self.final_shape)

    def forward(self, x):
        """Return ``x`` plus the output of the three-stage convolution stack."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        hidden = self.conv3(hidden)
        return x + hidden

class ConvMix_Block(nn.Module):
    """Patchify -> residual depthwise (spatial) mixing -> 1x1 (channel) mixing -> adaptive pool.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels after patchifying (and of the output).
        final_shape: target ``(H, W)`` of the adaptive average pool;
            ``(10, 10)`` when ``None``.
        num_kernels: accepted for signature parity; not used here.
        kernel_size: kernel size of the depthwise spatial-mixing convolution.
        patch_size: kernel size and stride of the patchifying convolution.
        init_weight: accepted for signature parity; not used here.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=None, kernel_size=9, patch_size=7,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.patch_size = patch_size

        self.final_shape = final_shape if final_shape is not None else (10, 10)

        # Patch embedding: conv with stride == kernel, then GELU and BatchNorm.
        self.patchifying_conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.act = nn.GELU()
        self.bnorm1 = nn.BatchNorm2d(out_channels)

        depthwise_stage = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels, padding='same'),
            nn.GELU(),
            nn.BatchNorm2d(out_channels))
        self.spatial_mixer = nn.Sequential(Residual(depthwise_stage))

        pointwise = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.channel_mixer = nn.Sequential(pointwise, nn.GELU(), nn.BatchNorm2d(out_channels))

        self.pool = nn.AdaptiveAvgPool2d(self.final_shape)

    def forward(self, x):
        """Return the pooled block output for a ``(B, C, H, W)`` input."""
        embedded = self.patchifying_conv(x)
        embedded = self.bnorm1(self.act(embedded))
        mixed = self.channel_mixer(self.spatial_mixer(embedded))
        return self.pool(mixed)

class ConvNeXT_multiscale_shared(nn.Module):
    """Multi-scale ConvNeXt-style block whose kernel branches are shared across patch sizes.

    The input is patchified once per entry of ``patch_size``. Every patch
    map is passed through the same ``num_kernels // 2`` depthwise/bottleneck
    branches, the branch outputs are averaged and added to the patch map, and
    the result is pooled to ``final_shape``. The pooled maps of all scales
    are averaged.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels after patchifying (and of the output).
        final_shape: target ``(H, W)`` of the adaptive average pool;
            ``(10, 10)`` when ``None``.
        num_kernels: ``num_kernels // 2`` branches are created.
        kernel_size: accepted for signature parity; not used here.
        patch_size: sequence of patch sizes, one patchifying conv each.
        init_weight: when true, re-initialise the convolutions with
            Kaiming-normal weights and zero biases.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=None,
                 patch_size=[2, 4, 6], init_weight=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.patch_size = patch_size
        self.final_shape = final_shape if final_shape is not None else (10, 10)

        # One patchifying stem (conv with stride == kernel, BatchNorm, GELU) per scale.
        patchifiers = []
        for idx in range(len(patch_size)):
            size = patch_size[idx]
            patchifiers.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=size, stride=size),
                nn.BatchNorm2d(out_channels),
                nn.GELU()))

        # Branch j: depthwise 1 x (2j+3) then (2j+3) x 1, BatchNorm, 4x inverse bottleneck.
        branches = []
        for j in range(num_kernels // 2):
            span = int(2 * j + 3)
            half = j + 1
            branches.append(nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=[1, span], groups=out_channels,
                          padding=[0, half]),
                nn.Conv2d(out_channels, out_channels, kernel_size=[span, 1], groups=out_channels,
                          padding=[half, 0]),
                nn.BatchNorm2d(out_channels),
                nn.Conv2d(out_channels, 4 * out_channels, kernel_size=1),
                nn.GELU(),
                nn.Conv2d(4 * out_channels, out_channels, kernel_size=1)))

        self.patch_module_list = nn.ModuleList(patchifiers)
        self.kernel_module_list = nn.ModuleList(branches)

        self.pool = nn.AdaptiveAvgPool2d(self.final_shape)

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every Conv2d."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the average over scales of the pooled, residual branch outputs."""
        _batch, _channels, _height, _width = x.shape  # requires a 4-D input

        n_branches = self.num_kernels // 2
        pooled_per_scale = []
        for scale_idx in range(len(self.patch_size)):
            patches = self.patch_module_list[scale_idx](x)
            branch_outs = [self.kernel_module_list[b](patches) for b in range(n_branches)]
            fused = torch.stack(branch_outs, dim=-1).mean(-1) + patches
            pooled_per_scale.append(self.pool(fused))

        return torch.stack(pooled_per_scale, dim=-1).mean(-1)

class ConvNeXT_multiscale_independent(nn.Module):
    """Multi-scale ConvNeXt-style block with a separate set of kernel branches per patch size.

    For every entry of ``patch_size`` the block owns a ``ModuleList`` whose
    element 0 is the patchifying stem and whose elements ``1..num_kernels // 2``
    are depthwise/bottleneck branches. Per scale, the branch outputs are
    averaged, added to the patch map and pooled to ``final_shape``; the pooled
    maps of all scales are then averaged.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels after patchifying (and of the output).
        final_shape: target ``(H, W)`` of the adaptive average pool;
            ``(10, 10)`` when ``None``.
        num_kernels: ``num_kernels // 2`` branches are created per scale.
        kernel_size: accepted for signature parity; not used here.
        patch_size: sequence of patch sizes, one scale each.
        init_weight: when true, re-initialise the convolutions with
            Kaiming-normal weights and zero biases.
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=None,
                 patch_size=[2, 4, 6], init_weight=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.patch_size = patch_size
        self.final_shape = final_shape if final_shape is not None else (10, 10)

        per_scale = []
        for idx in range(len(patch_size)):
            size = patch_size[idx]
            # Element 0: patchifying stem (conv with stride == kernel, BatchNorm, GELU).
            stages = [nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=size, stride=size),
                nn.BatchNorm2d(out_channels),
                nn.GELU())]
            # Elements 1..: depthwise 1 x (2j+3) then (2j+3) x 1, BatchNorm, 4x inverse bottleneck.
            for j in range(num_kernels // 2):
                span = int(2 * j + 3)
                half = j + 1
                stages.append(nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=[1, span], groups=out_channels,
                              padding=[0, half]),
                    nn.Conv2d(out_channels, out_channels, kernel_size=[span, 1], groups=out_channels,
                              padding=[half, 0]),
                    nn.BatchNorm2d(out_channels),
                    nn.Conv2d(out_channels, 4 * out_channels, kernel_size=1),
                    nn.GELU(),
                    nn.Conv2d(4 * out_channels, out_channels, kernel_size=1)))
            per_scale.append(nn.ModuleList(stages))
        self.kernel_module_list = nn.ModuleList(per_scale)

        self.pool = nn.AdaptiveAvgPool2d(self.final_shape)

        # self.dim_proj = nn.Conv2d(out_channels*)

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) weights and zero biases for every Conv2d."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the average over scales of the pooled, residual branch outputs."""
        _batch, _channels, _height, _width = x.shape  # requires a 4-D input

        n_branches = self.num_kernels // 2
        pooled_per_scale = []
        for scale_idx in range(len(self.patch_size)):
            scale_modules = self.kernel_module_list[scale_idx]
            patches = scale_modules[0](x)
            branch_outs = [scale_modules[b + 1](patches) for b in range(n_branches)]
            fused = torch.stack(branch_outs, dim=-1).mean(-1) + patches
            pooled_per_scale.append(self.pool(fused))

        return torch.stack(pooled_per_scale, dim=-1).mean(-1)

class SwinTransformer_Block(nn.Module):
    """Unfinished Swin-Transformer-style block: only the sub-modules are created.

    No ``forward`` is defined, so calling an instance falls through to
    ``nn.Module``'s default and raises ``NotImplementedError``.

    Args:
        in_channels: stored on the instance; not used by any sub-module yet.
        out_channels: feature width of the projections and layer norms.
        final_shape: accepted for signature parity; not used here.
        num_kernels: accepted for signature parity; not used here.
        kernel_size: accepted for signature parity; not used here.
        patch_size: stored on the instance; also determines the default shift.
        shift_size: ``(patch_size // 2, patch_size // 2)`` when ``None``.
        dropout: stored as ``attn_dropout`` and ``proj_drop`` (plain floats).
    """

    def __init__(self, in_channels, out_channels, final_shape=None, num_kernels=6, kernel_size=7, patch_size=7,
                 shift_size=None, dropout=0.1):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first nn.Linear assignment raises AttributeError.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.shift_size = (patch_size // 2, patch_size // 2) if shift_size is None else shift_size
        self.attn_dropout = dropout
        self.proj_drop = dropout
        self.projection = nn.Linear(out_channels, out_channels)

        self.softmax = nn.Softmax(dim=-1)
        self.qkv_projection = nn.Linear(out_channels, int(out_channels * 3))

        # LayerNorm requires its normalized_shape; the features handled by the
        # projections above have width out_channels.
        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)

TimesNet feature extractor

In [6]:
class RoPE(nn.Module):
    """Rotary position embedding driven by explicit (possibly irregular) timestamps.

    Consecutive feature pairs ``(x[2i], x[2i + 1])`` are rotated by the angle
    ``timestamp * theta[i]``.

    Changes relative to the previous implementation:
      * the class is no longer wrapped in ``@torch.no_grad()``; decorating a
        class is deprecated and replaces the class object with a function;
      * the rotated copy of ``x`` is now used, so the result is
        ``x * cos + rotate(x) * sin`` instead of ``x * sin + x * cos``;
      * ``theta`` is a non-persistent buffer, so it follows ``.to(device)``
        without adding a key to ``state_dict``.

    NOTE(review): the frequencies are ``10000 ** (+2i / dmodel)``; the usual
    RoPE formulation uses a negative exponent. Left as-is, confirm intent.

    Args:
        dmodel: feature width of the inputs; must be even.

    Raises:
        ValueError: if ``dmodel`` is odd.
    """

    def __init__(self, dmodel):
        super().__init__()
        if dmodel % 2 != 0:
            raise ValueError(f"RoPE requires an even dmodel, got {dmodel}")
        self.dmodel = dmodel
        dims = torch.arange(self.dmodel // 2)
        pair_freqs = 10000 ** (2 * dims / dmodel)
        # Each frequency is shared by the two members of a feature pair:
        # [f0, f0, f1, f1, ...].
        self.register_buffer('theta', torch.repeat_interleave(pair_freqs, 2), persistent=False)

    def forward(self, x, timestamp):
        """Rotate ``x`` of shape ``(B, T, D)`` by angles derived from ``timestamp`` of shape ``(B, T)``."""
        B, T, D = x.shape
        _, _ = timestamp.shape  # timestamp must be 2-D

        # rotate(x): every pair (a, b) becomes (-b, a), interleaved back to width D.
        rotated = torch.stack((-x[:, :, 1::2], x[:, :, 0::2]), dim=-1).reshape(B, T, D)

        angles = timestamp.reshape(B, T, 1).float() * self.theta.reshape(1, 1, -1).float()
        return x * torch.cos(angles) + rotated * torch.sin(angles)


class TaskEmbedding(nn.Module):
    """Single linear layer that maps task features of width ``din`` to width ``dmodel``."""

    def __init__(self, din, dmodel):
        super().__init__()
        self.dmodel = dmodel
        self.din = din
        self.embedding = nn.Linear(din, dmodel)

    def forward(self, task_feats):
        """Return the linear projection of ``task_feats``."""
        projected = self.embedding(task_feats)
        return projected



In [8]:
def FFT_for_Period(x, k=2):
    """Pick up to ``k`` dominant, distinct periods of ``x`` from its FFT amplitudes.

    Frequencies are ranked by amplitude averaged over batch and channels. A
    frequency is kept only if the period ``T // frequency`` it implies has not
    been selected already. The zero-frequency (DC) bin is never selected: it
    has no period and would otherwise cause a division by zero when ``k``
    exceeds the number of distinct periods available.

    Args:
        x: tensor of shape ``[B, T, C]``.
        k: maximum number of periods to return.

    Returns:
        A tuple ``(period, weight)`` where ``period`` is an int64 NumPy array
        of at most ``k`` periods and ``weight`` is a ``(B, len(period))``
        tensor with the channel-averaged amplitude of each selected frequency.
        Fewer than ``k`` entries are returned when not enough distinct
        periods exist.
    """
    # [B, T, C]
    T = x.shape[1]
    amplitude = abs(torch.fft.rfft(x, dim=1))
    # find period by amplitudes
    frequency_list = amplitude.mean(0).mean(-1)
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, len(frequency_list))
    ranked_freqs = top_list.detach().cpu().numpy().tolist()

    actual_list = []
    period_list = []
    for freq in ranked_freqs:
        if len(actual_list) >= k:
            break
        if freq == 0:
            # The DC bin carries no period; skipping it avoids T // 0.
            continue
        candidate = T // freq
        if candidate in period_list:
            continue
        period_list.append(candidate)
        actual_list.append(freq)

    # Explicit dtype keeps the index array integral even when it is empty.
    actual_list = np.asarray(actual_list, dtype=np.int64)
    period = np.asarray(period_list, dtype=np.int64)
    return period, amplitude.mean(-1)[:, actual_list]


class TimesBlock(nn.Module):
    """TimesNet block: fold the series by its dominant periods, apply a 2-D conv net, re-aggregate.

    ``tblock_configs`` keys read by this class:
      * ``'cycle_len'``: expected sequence length ``T``;
      * ``'top_k'``: number of periods requested from ``FFT_for_Period``;
      * ``'conv_type'``: one of ``'InceptionV1'``, ``'InceptionV2'``,
        ``'ConvNeXT'``, ``'ResNeXT'``, ``'ConvMix'``, ``'ConvNeXT_MS'``;
      * ``'dmodel'``: stored on the instance;
      * ``'conv_configs'[conv_type]['configs']``: list of positional-argument
        tuples ``(in_c, out_c, final_shape, num_kernels, kernel_size,
        patch_size, init_weight)``; two entries for the stacked types, one
        for ``'ConvNeXT_MS'``;
      * ``'final_shape'``: ``(h, w)`` of the pooled conv output; only read for
        the pooling conv types.

    Raises:
        ValueError: if ``conv_type`` is not one of the supported names
            (previously the block was built without ``conv_net`` and failed
            later in ``forward``).
    """

    # Conv types whose output is pooled to a fixed shape and therefore needs
    # a linear projection back to the sequence length.
    _POOLING_TYPES = ('ConvNeXT', 'ConvNeXT_MS', 'ConvMix')

    def __init__(self, tblock_configs):
        super().__init__()

        self.seq_len = tblock_configs['cycle_len']
        self.k = tblock_configs['top_k']
        self.conv_type = tblock_configs['conv_type']
        self.dmodel = tblock_configs['dmodel']

        # Types built as block -> GELU -> block from two config entries.
        stacked_blocks = {
            'InceptionV1': Inception_Block_V1,
            'InceptionV2': Inception_Block_V2,
            'ConvNeXT': ConvNeXT_Block,
            'ResNeXT': ResNeXT_Block,
            'ConvMix': ConvMix_Block,
        }
        if self.conv_type not in stacked_blocks and self.conv_type != 'ConvNeXT_MS':
            raise ValueError(f"Unsupported conv_type: {self.conv_type!r}")

        layer_cfgs = tblock_configs['conv_configs'][self.conv_type]['configs']

        if self.conv_type == 'ConvNeXT_MS':
            self.conv_net = ConvNeXT_multiscale_shared(*layer_cfgs[0])
        else:
            block_cls = stacked_blocks[self.conv_type]
            first_cfg, second_cfg = layer_cfgs[0], layer_cfgs[1]
            self.conv_net = nn.Sequential(block_cls(*first_cfg),
                                          nn.GELU(),
                                          block_cls(*second_cfg))

        if self.conv_type in self._POOLING_TYPES:
            h, w = tblock_configs['final_shape']
            self.time_projection = nn.Linear(int(h * w), self.seq_len)

    def forward(self, x):
        """Return ``x`` plus the amplitude-weighted sum of the per-period conv outputs.

        Args:
            x: tensor of shape ``(B, T, D)`` with ``T == cycle_len``.
        """
        B, T, D = x.size()
        assert (T == self.seq_len), 'Time dimensions do not match'

        period_list, period_weight = FFT_for_Period(x, self.k)

        # FFT_for_Period may return fewer than k periods; iterate over what
        # was actually found instead of indexing up to k.
        if len(period_list) == 0:
            # Nothing to aggregate: only the residual path remains.
            return x

        res = []
        for period in period_list:
            period = int(period)
            # Zero-pad the time axis up to the next multiple of the period.
            if self.seq_len % period != 0:
                length = ((self.seq_len // period) + 1) * period
                # new_zeros matches the dtype and device of x.
                padding = x.new_zeros((B, length - self.seq_len, D))
                out = torch.cat([x, padding], dim=1)
            else:
                length = self.seq_len
                out = x
            # Fold: (B, length, D) -> (B, D, length // period, period).
            out = out.reshape(B, length // period, period, D).permute(0, 3, 1, 2).contiguous()

            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv_net(out)
            # Unfold back to (B, time, D).
            out = out.permute(0, 2, 3, 1).reshape(B, -1, D)

            if self.conv_type in self._POOLING_TYPES:
                # Pooled outputs have h * w time steps; project them to seq_len.
                out = self.time_projection(out.permute(0, 2, 1)).permute(0, 2, 1)

            res.append(out[:, :self.seq_len, :])

        res = torch.stack(res, dim=-1)
        # Adaptive aggregation: softmax over the selected periods, broadcast over (T, D).
        period_weight = F.softmax(period_weight, dim=1)
        res = torch.sum(res * period_weight[:, None, None, :], -1)
        # residual connection
        return res + x


class TimesNetModel(nn.Module):
    """Two parallel TimesBlock stacks, one for observation and one for decision sequences.

    ``configs`` keys read here: ``'cycle_len'``, ``'dmodel'``, ``'obs_din'``,
    ``'dec_din'``, ``'task_dim'``, ``'e_layers'`` and ``'emb_dropout'``. The
    whole dict is also passed to every ``TimesBlock``.

    Fixes relative to the previous implementation:
      * with ``tsk_vars=None`` the zero task embedding was ``(B, 1, D)`` and
        then unsqueezed again, which broadcast the sum to a 4-D
        ``(B, B, T, D)`` tensor; it was also always created on the CPU;
      * the normalisation used in-place ``/=`` on a tensor that ``torch.var``
        needs for its backward pass.
    """

    def __init__(self, configs):
        super().__init__()

        self.seq_len = configs['cycle_len']
        self.dmodel = configs['dmodel']
        self.obs_din = configs['obs_din']
        self.dec_din = configs['dec_din']
        self.task_dim = configs['task_dim']

        self.obs_model = nn.ModuleList([TimesBlock(configs)
                                        for _ in range(configs['e_layers'])])

        self.dec_model = nn.ModuleList([TimesBlock(configs)
                                        for _ in range(configs['e_layers'])])

        self.obs_val_embedding = nn.Linear(self.obs_din, self.dmodel)
        self.rel_obs_emb = RoPE(self.dmodel)  # nn.Linear(self.seq_len, self.dmodel)

        self.dec_val_embedding = nn.Linear(self.dec_din, self.dmodel)
        self.rel_dec_emb = RoPE(self.dmodel)  # nn.Linear(self.seq_len, self.dmodel)

        self.glob_fec_emb = nn.Embedding(10000, self.dmodel)

        self.dropout = nn.Dropout(configs['emb_dropout'])

        self.layer = configs['e_layers']

        self.layer_norm_obs = nn.LayerNorm(configs['dmodel'])
        self.layer_norm_dec = nn.LayerNorm(configs['dmodel'])

        self.task_embedding = TaskEmbedding(self.task_dim, self.dmodel)

    @staticmethod
    def _stationarize(values):
        """Return ``(normalised, means, stdev)`` computed over the time axis (dim 1)."""
        means = values.mean(1, keepdim=True).detach()
        centered = values - means
        stdev = torch.sqrt(torch.var(centered, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place division: `centered` is saved by torch.var for backward.
        return centered / stdev, means, stdev

    def forward(self, obs_vars, dec_vars, tsk_vars=None):
        """Encode both sequences and return ``(obs_out, obs_td, dec_out, dec_td)``.

        Args:
            obs_vars: ``(B, T, F_obs)`` tensor. The last feature is cast to int
                and used (at time step 0) as an index into ``glob_fec_emb``;
                the second-to-last is a per-step time difference; the rest
                are fed to ``obs_val_embedding``.
            dec_vars: ``(B, T, F_dec)`` tensor laid out like ``obs_vars``.
            tsk_vars: optional task features for ``task_embedding``; a zero
                task embedding is used when ``None``.
                Presumably shaped ``(B, task_dim)`` - TODO confirm.
        """
        # dec_vars.shape = (B, T, 4)
        # obs_vars.shape = (B, T, 6)
        B, T, _ = obs_vars.shape

        obs_td = obs_vars[:, :, -2]
        dec_td = dec_vars[:, :, -2]

        obs_fec = obs_vars[:, 0, -1].int()
        dec_fec = dec_vars[:, 0, -1].int()

        obs_val = self.obs_val_embedding(obs_vars[:, :, :-2])
        dec_val = self.dec_val_embedding(dec_vars[:, :, :-2])

        if tsk_vars is None:
            # (B, 1, D) zeros on the same device/dtype as the embeddings so
            # the sum below broadcasts over time only.
            tsk_emb = obs_val.new_zeros((B, 1, self.dmodel))
        else:
            tsk_emb = self.task_embedding(tsk_vars).unsqueeze(1)

        obs_time_cum = torch.cumsum(obs_td, dim=-1)  # cumulative sum of time difference
        dec_time_cum = torch.cumsum(dec_td, dim=-1)  # cumulative sum of time difference

        # Series Stationarization
        obs_enc, obs_means, obs_stdev = self._stationarize(obs_val)
        dec_enc, dec_means, dec_stdev = self._stationarize(dec_val)

        # Value_embedding + global_pos_embedding
        obs_glob = self.glob_fec_emb(obs_fec).reshape(B, -1, self.dmodel)
        dec_glob = self.glob_fec_emb(dec_fec).reshape(B, -1, self.dmodel)
        obs_emb = self.dropout(self.rel_obs_emb(obs_enc, obs_time_cum) + obs_glob) + tsk_emb
        dec_emb = self.dropout(self.rel_dec_emb(dec_enc, dec_time_cum) + dec_glob) + tsk_emb

        # Apply TimesBlocks
        for i in range(self.layer):
            obs_emb = self.layer_norm_obs(self.obs_model[i](obs_emb))
            dec_emb = self.layer_norm_dec(self.dec_model[i](dec_emb))

        # De-stationarization: the (B, 1, D) statistics broadcast over time.
        obs_out = obs_emb * obs_stdev + obs_means
        dec_out = dec_emb * dec_stdev + dec_means

        return obs_out, obs_td, dec_out, dec_td

Block Recurrent Transformer

In [9]:
class AttentionLayer(nn.Module):
    """Multi-head projection wrapper around an attention callable.

    Queries, keys and values are projected to ``num_heads * dmodel`` features
    (each head keeps the full ``dmodel`` width), handed to ``attention`` and
    the result is projected back to ``dmodel``.

    Args:
        attention: callable invoked as ``attention(query, key, value, attn_mask)``
            and returning ``(values, attention_weights)``.
        num_heads: number of heads.
        dmodel: input/output feature width and per-head width.
    """

    def __init__(self, attention, num_heads, dmodel):
        super().__init__()

        self.att_ = attention
        self.num_heads = num_heads
        self.dmodel = dmodel
        self.datt = int(dmodel * num_heads)

        self.q_emb = nn.Linear(self.dmodel, self.datt)
        self.v_emb = nn.Linear(self.dmodel, self.datt)
        self.k_emb = nn.Linear(self.dmodel, self.datt)
        self.out = nn.Linear(self.datt, self.dmodel)

    def forward(self, q, k, v, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attn)``; ``tau`` and ``delta`` are accepted but unused.

        ``k``/``v`` may be 3-D ``(B, S, dmodel)`` or 4-D ``(B, S, T, dmodel)``;
        in the 4-D case the heads axis is moved in front of ``T``.
        """
        B, L, _ = q.shape
        heads, width = self.num_heads, self.dmodel
        key_rank = len(k.shape)

        query = self.q_emb(q).reshape(B, L, heads, width)

        if key_rank == 3:
            _, S, _ = k.shape
            key = self.k_emb(k).reshape(B, S, heads, width)
            value = self.v_emb(v).reshape(B, S, heads, width)
        elif key_rank == 4:
            _, S, T, _ = k.shape
            key = self.k_emb(k).reshape(B, S, T, heads, width).permute(0, 1, 3, 2, 4)
            value = self.v_emb(v).reshape(B, S, T, heads, width).permute(0, 1, 3, 2, 4)

        attended, attn = self.att_(query, key, value, attn_mask)

        merged = attended.reshape(B, L, self.datt)
        return self.out(merged), attn

class FullAttention(nn.Module):
    """Softmax attention with an optional learned scale and optional causal mask.

    Args:
        trainable_scale: when true, queries and keys are L2-normalised and the
            scores are multiplied by a learned scalar initialised to
            ``log2(scale**2 - scale)`` (``scale`` defaults to 100). When
            false, scores are multiplied by ``1 / sqrt(D)``.
        scale: initial value used for the learned scale (see above); ignored
            when ``trainable_scale`` is false.
        output_attention: when true, ``forward`` also returns the attention
            weights; otherwise ``None`` is returned in their place.
        mask_flag: when true, apply a lower-triangular (causal) mask.
        dropout: dropout probability applied to the attention weights.

    Fix relative to the previous implementation: the causal mask used to zero
    the masked scores with ``tril`` before the softmax, which still gives
    future positions a weight of ``exp(0)``. Masked scores are now set to
    ``-inf`` so they receive exactly zero weight.

    NOTE(review): with ``trainable_scale=True`` queries/keys are normalised
    along ``dim=2``, which is the heads axis for the ``(B, L, H, D)`` layout
    used here; cosine attention usually normalises the feature axis. Left
    unchanged, confirm intent.
    """

    def __init__(self, trainable_scale=True, scale=None, output_attention=False, mask_flag=False, dropout=0.1):

        super().__init__()

        self.scale_train = trainable_scale
        self.output_attention = output_attention
        self.mask_flag = mask_flag
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        if trainable_scale:
            init_len = 100 if scale is None else scale
            self.scale = nn.Parameter(torch.log2(torch.tensor(init_len ** 2 - init_len).float()))
        else:
            self.scale = None

    def forward(self, query, key, value, mask=None):
        """Return ``(V, A)`` where ``A`` is ``None`` unless ``output_attention`` is set.

        Args:
            query: ``(B, L, H, D)``.
            key, value: ``(B, S, H, D)`` or ``(B, S, H, T, D)``. In the 5-D
                case scores and outputs are averaged over the ``T`` axis.
            mask: accepted for interface compatibility; not used (masking is
                controlled by ``mask_flag``).

        Raises:
            ValueError: if ``value`` is neither 4-D nor 5-D.
        """
        B, L, H, D = query.shape
        value_rank = len(value.shape)
        if value_rank not in (4, 5):
            raise ValueError(f"value must be 4-D or 5-D, got {value_rank}-D")
        S = value.shape[1]

        if self.scale is None:
            scale = 1 / math.sqrt(D)
        else:
            scale = self.scale
            query = F.normalize(query, dim=2)
            key = F.normalize(key, dim=2)

        if value_rank == 4:
            scores = torch.einsum('bqhd,bkhd->bhqk', query, key)
        else:
            scores = torch.einsum('bqhd,bkhtd->bhqkt', query, key).mean(-1)

        scores = scores * scale

        if self.mask_flag:
            # Causal mask: position i may only attend to positions <= i.
            allowed = torch.tril(torch.ones(L, S, device=scores.device)).bool()
            scores = scores.masked_fill(~allowed, float('-inf'))

        A = self.dropout(self.softmax(scores))
        if value_rank == 4:
            V = torch.einsum('bhqk,bkhd->bqhd', A, value)
        else:
            V = torch.einsum('bhqk,bkhtd->bqhdt', A, value).mean(-1)

        if self.output_attention:
            return (V.contiguous(), A)
        return (V.contiguous(), None)

In [10]:
class Block_Recurrent_Cell(nn.Module):
    """Recurrent cell that updates an output embedding and a recurrent state.

    The "vertical" path combines self-attention over the input with
    cross-attention from the input to the recurrent state and produces the
    cell output. The "horizontal" path combines self-attention over the
    recurrent state with cross-attention from the state to the input and
    feeds one or two LSTM-style gates that produce the next state.

    Args:
        attention_details: dict with ``'num_heads'`` and ``'layers'``.
            ``'layers'`` holds four attention callables used, in order, for
            vertical self, vertical cross, horizontal self and horizontal
            cross attention. Each is called as ``att(q, k, v)`` with
            ``(B, T, num_heads, dout)`` tensors and must return
            ``(output, attention_weights)``.
        din: feature size of the input state.
        dout: feature size of the recurrent state and of the cell output.
        gate_config: ``'LSTM'`` or ``'Fixed'``. Layers are created for both,
            but ``forward`` only implements the ``'LSTM'`` equations.
        gate_type: ``'dual'``, ``'single'`` or ``'skip'``.
        vert_activation: ``'relu'``, ``'leaky-relu'`` or ``'gelu'``.
        hor_activation: same choices; only used by ``'dual'`` and ``'single'``.
        d_ff: hidden width of the MLPs; defaults to ``4 * dout``.
        dropout: dropout probability.

    Raises:
        ValueError: on an unknown ``gate_type`` or a required activation name
            that is not recognised.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'leaky-relu': nn.LeakyReLU, 'gelu': nn.GELU}
    _GATE_TYPES = ('dual', 'single', 'skip')

    def __init__(self, attention_details, din, dout, gate_config='LSTM', gate_type='dual', vert_activation='leaky-relu',
                 hor_activation='relu', d_ff=None, dropout=0.1):
        super().__init__()

        if gate_type not in self._GATE_TYPES:
            raise ValueError(f"gate_type must be one of {self._GATE_TYPES}, got {gate_type!r}")
        if vert_activation not in self._ACTIVATIONS:
            raise ValueError(f"unknown vert_activation {vert_activation!r}")
        # The horizontal activation is only consumed by hor_MLP, which the
        # 'skip' variant does not build, so only reject it when it is needed.
        if hor_activation not in self._ACTIVATIONS and gate_type != 'skip':
            raise ValueError(f"unknown hor_activation {hor_activation!r}")

        if din != dout:
            self.transform_inp = nn.Linear(din, dout)

        self.din = din
        self.dmodel = dout

        self.d_ff = int(4 * self.dmodel) if d_ff is None else d_ff

        self.gate_config = gate_config
        self.gate_type = gate_type

        attention_list = attention_details['layers']
        self.num_heads = attention_details['num_heads']
        inner = int(self.num_heads * self.dmodel)

        # Query projections: "s" variants query the state, "e" variants the input.
        self.qs_hor = nn.Linear(dout, inner)
        self.qs_vert = nn.Linear(din, inner)
        self.qe_hor = nn.Linear(dout, inner)
        self.qe_vert = nn.Linear(din, inner)

        # Key/value projections of the state (s) and of the input (e).
        self.ks_emb = nn.Linear(dout, inner)
        self.ke_emb = nn.Linear(din, inner)
        self.vs_emb = nn.Linear(dout, inner)
        self.ve_emb = nn.Linear(din, inner)

        self.vert_self_att = attention_list[0]
        self.vert_cross_att = attention_list[1]
        self.hor_self_att = attention_list[2]
        self.hor_cross_att = attention_list[3]

        self.vert_act = self._ACTIVATIONS[vert_activation]()

        self.vert_projection = nn.Linear(int(2 * self.dmodel), self.dmodel)
        self.vert_MLP = nn.Sequential(nn.Linear(self.dmodel, self.d_ff),
                                      self.vert_act,
                                      nn.Dropout(dropout),
                                      nn.Linear(self.d_ff, self.dmodel))

        if hor_activation in self._ACTIVATIONS:
            self.hor_act = self._ACTIVATIONS[hor_activation]()

        num_gates = 2 if gate_type == 'dual' else 1
        if gate_config == 'LSTM':
            for index in range(1, num_gates + 1):
                self._add_lstm_gate(index)
        elif gate_config == 'Fixed':
            for index in range(1, num_gates + 1):
                self._add_fixed_gate(index)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        if gate_type == 'dual':

            self.hor_projection = nn.Linear(int(2 * self.dmodel), self.dmodel)
            self.hor_MLP = nn.Sequential(nn.Linear(self.dmodel, self.d_ff),
                                         self.hor_act,
                                         nn.Dropout(dropout),
                                         nn.Linear(self.d_ff, self.dmodel))
            self.hor_norm1 = nn.LayerNorm(self.dmodel)
            self.hor_norm2 = nn.LayerNorm(self.dmodel)

        elif gate_type == 'single':
            self.hor_MLP = nn.Sequential(nn.Linear(int(2 * self.dmodel), self.d_ff),
                                         self.hor_act,
                                         nn.Dropout(dropout),
                                         nn.Linear(self.d_ff, self.dmodel))
            self.hor_norm1 = nn.LayerNorm(self.dmodel)

        else:  # 'skip'
            self.hor_projection = nn.Linear(int(2 * self.dmodel), self.dmodel)
            self.hor_norm1 = nn.LayerNorm(self.dmodel)

        self.dropout = nn.Dropout(dropout)

        self.vert_norm1 = nn.LayerNorm(self.dmodel)
        self.vert_norm2 = nn.LayerNorm(self.dmodel)

        # Output projections that merge the attention heads back to dmodel.
        self.vert_satt_MLP = nn.Linear(inner, self.dmodel)
        self.vert_catt_MLP = nn.Linear(inner, self.dmodel)
        self.hor_satt_MLP = nn.Linear(inner, self.dmodel)
        self.hor_catt_MLP = nn.Linear(inner, self.dmodel)

    def _add_lstm_gate(self, index):
        """Create the input/forget/candidate layers of LSTM gate ``index``."""
        inp_gate = nn.Linear(self.dmodel, self.dmodel)
        forget_gate = nn.Linear(self.dmodel, self.dmodel)
        z_emb = nn.Linear(self.dmodel, self.dmodel)
        setattr(self, f'inp_gate{index}', inp_gate)
        setattr(self, f'forget_gate{index}', forget_gate)
        setattr(self, f'z_emb{index}', z_emb)
        for layer in (inp_gate, forget_gate):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.1)
            torch.nn.init.normal_(layer.bias, mean=0.0, std=0.1)

    def _add_fixed_gate(self, index):
        """Create the candidate layer and learnable bias of fixed gate ``index``."""
        setattr(self, f'z_emb{index}', nn.Linear(self.dmodel, self.dmodel))
        setattr(self, f'bg{index}', nn.Parameter(torch.randn(1, 1, self.dmodel)))

    def _lstm_gate(self, state, hidden, index):
        """Apply LSTM gate ``index``: ``state * forget + input * candidate``."""
        candidate = self.tanh(getattr(self, f'z_emb{index}')(hidden))
        # The -1 / +1 offsets bias the gates towards keeping the old state.
        inp = self.sigmoid(getattr(self, f'inp_gate{index}')(hidden) - 1)
        forget = self.sigmoid(getattr(self, f'forget_gate{index}')(hidden) + 1)
        return state * forget + inp * candidate

    def forward(self, input_state, recurrent_state):
        """Run one cell step.

        Args:
            input_state: ``(B, T, din)`` tensor.
            recurrent_state: ``(B, T, dout)`` tensor.

        Returns:
            ``(vert_outputs, hor_outputs)`` where ``vert_outputs`` is
            ``(cell_output, self_attention, cross_attention)`` and
            ``hor_outputs`` is ``(next_state, self_attention, cross_attention)``.

        Raises:
            NotImplementedError: if ``gate_config`` is not ``'LSTM'``.
        """
        if self.gate_config != 'LSTM':
            raise NotImplementedError(
                f"Block_Recurrent_Cell.forward only implements gate_config='LSTM', got {self.gate_config!r}")

        B, T, Din = input_state.shape

        _, _, Dout = recurrent_state.shape

        if Din != Dout:
            residual = self.transform_inp(input_state)
        else:
            residual = input_state

        heads_shape = (B, T, self.num_heads, self.dmodel)

        # Each key/value projection feeds two attention calls; compute it once.
        input_keys = self.ke_emb(input_state).reshape(heads_shape)
        input_values = self.ve_emb(input_state).reshape(heads_shape)
        state_keys = self.ks_emb(recurrent_state).reshape(heads_shape)
        state_values = self.vs_emb(recurrent_state).reshape(heads_shape)

        # Vertical Direction

        vert_self, vert_sa = self.vert_self_att(self.qe_vert(input_state).reshape(heads_shape),
                                                input_keys,
                                                input_values)
        vert_self = self.vert_satt_MLP(vert_self.reshape(B, T, -1))

        vert_cross, vert_ca = self.vert_cross_att(self.qs_vert(input_state).reshape(heads_shape),
                                                  state_keys,
                                                  state_values)
        vert_cross = self.vert_catt_MLP(vert_cross.reshape(B, T, -1))

        vert_embed = torch.cat((vert_self, vert_cross), dim=-1)
        vert_embed = self.vert_norm1(self.dropout(self.vert_projection(vert_embed)) + residual)
        vert_op = self.vert_norm2(self.vert_MLP(vert_embed) + vert_embed)

        # Horizontal Direction

        hor_self, hor_sa = self.hor_self_att(self.qs_hor(recurrent_state).reshape(heads_shape),
                                             state_keys,
                                             state_values)
        hor_self = self.hor_satt_MLP(hor_self.reshape(B, T, -1))

        hor_cross, hor_ca = self.hor_cross_att(self.qe_hor(recurrent_state).reshape(heads_shape),
                                               input_keys,
                                               input_values)
        # Use the horizontal cross-attention projection here; routing this
        # through vert_catt_MLP would leave hor_catt_MLP unused and untrained.
        hor_cross = self.hor_catt_MLP(hor_cross.reshape(B, T, -1))

        hor_embed = torch.cat((hor_self, hor_cross), dim=-1)

        if self.gate_type == 'dual':
            hor_embed = self.hor_norm1(self.hor_projection(hor_embed))
            next_state1 = self._lstm_gate(recurrent_state, hor_embed, 1)
            next_state_emb = self.hor_norm2(self.hor_MLP(next_state1))
            final_next_state = self._lstm_gate(next_state1, next_state_emb, 2)

        elif self.gate_type == 'single':
            hor_embed = self.hor_norm1(self.hor_MLP(hor_embed))
            final_next_state = self._lstm_gate(recurrent_state, hor_embed, 1)

        else:  # 'skip'
            hor_embed = self.hor_norm1(self.hor_projection(hor_embed))
            final_next_state = self._lstm_gate(recurrent_state, hor_embed, 1)

        vert_outputs = (vert_op, vert_sa, vert_ca)
        hor_outputs = (final_next_state, hor_sa, hor_ca)

        return vert_outputs, hor_outputs

In [11]:
class Block_Recurrent_Transformer_Layer(nn.Module):
    """Applies one shared ``Block_Recurrent_Cell`` across a sequence of cells.

    Args:
        cell_config: dict providing ``'brt_din'``, ``'brt_dout'``,
            ``'gate_config'``, ``'gate_type'``, ``'vert_activation'``,
            ``'hor_activation'``, ``'d_ff'`` (``None`` selects ``4 * brt_dout``)
            and ``'attention_details'``.
        max_state_len: accepted for interface compatibility; not used.
        dropout: dropout probability forwarded to the cell.
        return_sequence: if True ``forward`` returns the results of every
            cell, otherwise only those of the last cell.
    """

    def __init__(self, cell_config, max_state_len = 10000, dropout=0.1, return_sequence = True):
        super().__init__()

        self.din = cell_config['brt_din']
        self.dout = cell_config['brt_dout']
        self.dmodel = self.dout

        self.return_sequence = return_sequence

        self.cell_gate_config = cell_config['gate_config']
        self.cell_gate_type = cell_config['gate_type']
        self.cell_vert_act = cell_config['vert_activation']
        self.cell_hor_act = cell_config['hor_activation']
        self.cell_dff = int(4*self.dmodel) if cell_config['d_ff'] is None else cell_config['d_ff']
        self.attention_details = cell_config['attention_details']

        self.state_rope = RoPE(self.dmodel)

        self.BRCell = Block_Recurrent_Cell(self.attention_details,
                                           self.din,
                                           self.dmodel,
                                           self.cell_gate_config,
                                           self.cell_gate_type,
                                           self.cell_vert_act,
                                           self.cell_hor_act,
                                           self.cell_dff, dropout)

    def forward(self, x_seq, initial_state = None):
        """Run the shared cell over every entry of the ``num_cells`` axis.

        Args:
            x_seq: ``(B, num_cells, T, D)`` tensor.
            initial_state: optional ``(B, T, dmodel)`` tensor; a standard
                normal state is drawn when omitted.

        Returns:
            ``(outputs, states, vert_attentions, hor_attentions)``. With
            ``return_sequence`` the tensors are ``(B, num_cells, T, dmodel)``
            and the attention entries are per-cell lists of
            ``(self_attention, cross_attention)`` pairs; otherwise only the
            last cell's entries are returned.
        """
        B, num_cells, T, Din = x_seq.shape
        device = x_seq.device

        if initial_state is not None:
            _, T_state, _ = initial_state.shape
            assert T_state==T, 'state and input time dimensions do not match'
        else:
            # Follow the input's dtype/device so the layer also works off-CPU
            # and with non-default floating point types.
            initial_state = torch.randn((B, T, self.dmodel), dtype=x_seq.dtype, device=device)

        # The position indices are identical for every cell; build them once.
        positions = torch.arange(T, device=device).reshape(1, -1).repeat(B, 1)

        state = self.state_rope(initial_state, positions)

        outputs = []
        states = []
        vert_att_list = []
        hor_att_list = []

        for cell in range(num_cells):

            next_output, next_state = self.BRCell(x_seq[:, cell, :, :], state)

            outputs.append(next_output[0])
            # The stored state is the raw cell state, before RoPE is applied.
            states.append(next_state[0])

            vert_att_list.append((next_output[1], next_output[2]))
            hor_att_list.append((next_state[1], next_state[2]))

            state = self.state_rope(next_state[0], positions)

        if outputs:
            output_list = torch.stack(outputs, dim=1)
            state_list = torch.stack(states, dim=1)
        else:
            # No cells: keep the (B, 0, T, dmodel) shape of the results.
            output_list = x_seq.new_empty((B, 0, T, self.dmodel))
            state_list = x_seq.new_empty((B, 0, T, self.dmodel))

        if self.return_sequence:
            return output_list, state_list, vert_att_list, hor_att_list
        else:
            return output_list[:, -1, :, :], state_list[:, -1, :, :], vert_att_list[-1], hor_att_list[-1]

Block Causal Decoder

In [12]:
class DecoderLayer(nn.Module):
    """Decoder block: masked self-attention and cross-attention, then a 1x1-conv MLP.

    ``configs`` must provide ``'dec_att_heads'``, ``'dmodel'``, ``'dec_d_ff'``
    (a falsy value selects ``4 * dmodel``), ``'num_cells'``, ``'dec_dropout'``
    and ``'dec_act'`` (``'relu'``; anything else selects GELU).
    """

    def __init__(self, configs):
        super().__init__()

        self.att_heads = configs['dec_att_heads']
        self.dmodel = configs['dmodel']
        requested_ff = configs['dec_d_ff']
        d_ff = requested_ff if requested_ff else 4 * self.dmodel

        def build_attention():
            # Both attention layers share the same configuration.
            core = FullAttention(trainable_scale=True,
                                 scale=configs['num_cells'],
                                 mask_flag=True)
            return AttentionLayer(core, num_heads=self.att_heads, dmodel=self.dmodel)

        self.self_attn = build_attention()
        self.cross_attn = build_attention()

        # The two normalised branches are concatenated, hence 2 * dmodel inputs.
        self.conv1 = nn.Conv1d(in_channels=int(2 * self.dmodel), out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=self.dmodel, kernel_size=1)
        self.norm1 = nn.LayerNorm(self.dmodel)
        self.norm2 = nn.LayerNorm(self.dmodel)
        self.norm3 = nn.LayerNorm(self.dmodel)
        self.dropout = nn.Dropout(configs['dec_dropout'])
        if configs['dec_act'] == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

    def forward(self, Q, state):
        """Return the updated query embedding, same shape as ``Q``."""
        self_out, _ = self.self_attn(Q, Q, Q)
        self_out = self.dropout(self_out)

        cross_out, _ = self.cross_attn(Q, state, state)
        cross_out = self.dropout(cross_out)

        self_branch = self.norm1(self_out + Q)
        cross_branch = self.norm2(cross_out + Q)
        fused = torch.concat((self_branch, cross_branch), dim=-1)

        # Conv1d expects channels on dim 1, so swap it with the last dim.
        hidden = self.conv1(fused.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        projected = self.conv2(hidden).transpose(-1, 1)

        return self.norm3(projected + Q)


class Decoder(nn.Module):
    """Stack of ``DecoderLayer`` blocks followed by LayerNorm and a linear head.

    ``configs`` must provide ``'dmodel'``, ``'num_decoder_layers'`` and
    ``'dec_dout'`` in addition to the keys read by ``DecoderLayer``.
    """

    def __init__(self, configs):
        super().__init__()

        self.dmodel = configs['dmodel']
        self.num_layers = configs['num_decoder_layers']
        self.dout = configs['dec_dout']

        self.layers = nn.ModuleList([DecoderLayer(configs) for _ in range(self.num_layers)])

        self.norm = nn.LayerNorm(self.dmodel)
        self.projection = nn.Linear(self.dmodel, self.dout)

    def forward(self, x, cross):
        """Decode ``x`` against ``cross`` and project the result to ``dout`` features.

        ``x`` must be 3-D and ``cross`` 4-D; other ranks raise ``ValueError``
        from the shape unpacking below.
        """
        # Unpacking doubles as a rank check on both inputs.
        _, _, _ = x.shape
        _, _, _, _ = cross.shape

        hidden = x
        for block in self.layers:
            hidden = block(hidden, cross)

        if self.norm is not None:
            hidden = self.norm(hidden)

        if self.projection is not None:
            hidden = self.projection(hidden)

        return hidden

In [20]:
class Encoder_SSL(nn.Module):
    """Feature extractor + recurrent encoder + output MLP.

    ``model_config`` must provide ``'num_cells'``, ``'dmodel'``, ``'dout'`` and
    ``'recurrent_encoder_layers'`` here, plus whatever ``TimesNetModel`` and
    ``Block_Recurrent_Transformer_Layer`` read from the same dict.
    """

    def __init__(self, model_config):
        super().__init__()

        self.num_cells = model_config['num_cells']
        self.dmodel = model_config['dmodel']
        self.dout = model_config['dout']
        self.encoder_layers = model_config['recurrent_encoder_layers']

        # Shared feature extractor applied to every cell.
        self.feat_ex = TimesNetModel(model_config)

        self.encoder_layer = Block_Recurrent_Transformer_Layer(model_config)

        self.obs_rope = RoPE(self.dmodel)
        self.dec_rope = RoPE(self.dmodel)

        hidden_width = int(4 * self.dout)
        self.out_MLP = nn.Sequential(nn.Linear(self.dmodel, hidden_width),
                                     nn.LeakyReLU(0.05),
                                     nn.Linear(hidden_width, self.dout))

    def forward(self, obs, dec, tsk = None, init_states=None):
        """Encode observation and control sequences.

        Args:
            obs: ``(B, L, T, Dobs)`` tensor (L cells of T steps each).
            dec: 4-D tensor reshaped alongside ``obs`` to ``(B*L, T, -1)``.
            tsk: optional ``(B, Dtsk)`` tensor, repeated for each of the L cells.
            init_states: optional initial state forwarded to the recurrent
                encoder layer.

        Returns:
            ``(output, states)``: the MLP-projected encoder outputs and the
            states returned by the recurrent encoder layer.
        """
        B, L, T, _ = obs.shape
        # Unpacking checks that dec is 4-D.
        _, _, _, _ = dec.shape

        flat_batch = B * L
        flat_obs = obs.reshape(flat_batch, T, -1)
        flat_dec = dec.reshape(flat_batch, T, -1)

        if tsk is None:
            flat_tsk = None
        else:
            # Unpacking checks that tsk is 2-D.
            _, _ = tsk.shape
            flat_tsk = tsk.unsqueeze(1).expand(-1, L, -1).reshape(flat_batch, -1)

        obs_feat, obs_td, dec_feat, dec_td = self.feat_ex(flat_obs, flat_dec, flat_tsk)

        # RoPE positions are the running sum of the third/fourth feat_ex outputs.
        obs_pos = torch.cumsum(obs_td, dim = -1)
        dec_pos = torch.cumsum(dec_td, dim = -1)
        obs_emb = self.obs_rope(obs_feat, obs_pos).reshape(B, L, T, -1)
        dec_emb = self.dec_rope(dec_feat, dec_pos).reshape(B, L, T, -1)

        # Concatenate both streams along the feature axis.
        encoder_in = torch.concat((obs_emb, dec_emb), dim = -1)

        ops, states, _, _ = self.encoder_layer(encoder_in, init_states)
        output = self.out_MLP(ops)

        return output, states

In [25]:
class Model(nn.Module):
    """Top-level model: ``Encoder_SSL`` followed by a ``Decoder`` over capacity inputs.

    ``model_config`` must provide ``'num_cells'``, ``'dmodel'``, ``'cap_dim'``
    and ``'cycle_len'`` here, plus the keys read by ``Encoder_SSL`` and
    ``Decoder``.
    """

    def __init__(self, model_config):
        super().__init__()

        self.num_cells = model_config['num_cells']
        self.dmodel = model_config['dmodel']
        self.d_cap = model_config['cap_dim']
        self.cycle_len = model_config['cycle_len']

        self.SSL_encoder = Encoder_SSL(model_config)
        self.decoder = Decoder(model_config)

        self.cap_rope_emb = RoPE(self.dmodel)
        # States are flattened to cycle_len * dmodel features before RoPE.
        self.state_rope_emb = RoPE(int(self.dmodel * self.cycle_len))

        self.cap_val_emb = nn.Linear(self.d_cap, self.dmodel)

    def forward(self, obs, dec, cap, tsk = None, initial_state=None):
        """Run the encoder, embed ``cap`` and decode against the encoder states.

        Args:
            obs: 4-D tensor ``(B, Nc, L, Do)``.
            dec: 4-D tensor passed to the encoder together with ``obs``.
            cap: 3-D tensor whose last feature is an index (cast to int) and
                whose remaining features are embedded by ``cap_val_emb``.
            tsk: optional tensor forwarded to the encoder.
            initial_state: optional initial state forwarded to the encoder.

        Returns:
            ``(cap_next, ssl_output, ssl_states)``.
        """
        B, Nc, L, _ = obs.shape
        # Unpacking checks the ranks of dec (4-D) and cap (3-D).
        _, _, _, _ = dec.shape
        _, _, _ = cap.shape

        ssl_output, ssl_states = self.SSL_encoder(obs, dec, tsk, initial_state)

        cycle_index = cap[:, :, -1].int()
        cap_values = cap[:, :, 0:-1]
        cap_emb = self.cap_rope_emb(self.cap_val_emb(cap_values), cycle_index)

        # Flatten each cell's state and rotate it with the index shifted by one.
        flat_states = ssl_states.reshape(B, Nc, -1)
        rotated_states = self.state_rope_emb(flat_states, cycle_index - 1)
        decoder_memory = rotated_states.reshape(B, Nc, L, -1)

        cap_next = self.decoder(cap_emb, decoder_memory)

        return cap_next, ssl_output, ssl_states

In [26]:
# Spatial size used in the (h, w) entries of the conv block configs below.
h = 10
w = 10

# Single flat configuration dict shared by Model and all of its sub-modules;
# the indented comments group the keys by the module that reads them.
model_config = {
                # Model:
                  # 'num_cells' is also used as the attention `scale` in DecoderLayer.
                  'num_cells':50,'dmodel':32,'cap_dim':1,'cycle_len':100,

                  # Encoder_SSL:
                    'dout':3,'recurrent_encoder_layers':1,

                    # TimesNetModel:
                      'obs_din':4, 'dec_din':2, 'emb_dropout':0.1, 'e_layers':2,'task_dim':11,

                      # TimeBlock:
                        'top_k':5, 'conv_type':'InceptionV2', 'final_shape':(10,10),
                        'conv_configs':{'InceptionV1':{ 'num_conv_blocks':2,
                                                      'conv_layer_activation':'gelu',
                                                      'configs': [(32, 128, None, 6, None, None, True),
                                                                  (128, 32, None, 6, None, None, True)]},
                                      'InceptionV2':{ 'num_conv_blocks':2,
                                                      'conv_layer_activation':'gelu',
                                                      'configs': [(32, 128, None, 6, None, None, True),
                                                                  (128, 32, None, 6, None, None, True)]},
                                      'ConvNeXT':{ 'num_conv_blocks':2,
                                                    'conv_layer_activation':'gelu',
                                                    'configs': [(32, 128, (h,w), None, 7, 4, True),
                                                                (128, 32, (h,w), None, 7, 4, True)]},
                                      'ResNeXT':{'num_conv_blocks':2,
                                                    'conv_layer_activation':'gelu',
                                                    'configs': [(32, 128, None, None, 3, None, True),
                                                                (128, 32, None, None, 3, None, True)]},
                                      'ConvMix':{'num_conv_blocks':2,
                                                    'conv_layer_activation':'gelu',
                                                    'configs': [(32, 128, (h,w), None, 7, 8, True),
                                                                (128, 32, (h,w), None, 7, 2, True)]},
                                      'ConvNeXT_MS':{'num_conv_blocks':1,
                                                    'conv_layer_activation':'gelu',
                                                    'configs': [(32, 32, (h,w), 6, 8, [2,5], True)]  }
                                      },

                    # Block_Recurrent_Transformer_Layer:
                        # 'd_ff': None makes the layer fall back to 4 * brt_dout.
                        # NOTE(review): brt_din (64) looks like 2 * dmodel, matching the
                        # obs/dec concatenation in Encoder_SSL -- confirm against RoPE/TimesNet output widths.
                        'brt_din':64, 'brt_dout':32, 'gate_config':'LSTM', 'gate_type':'single','vert_activation':'gelu','hor_activation':'gelu','d_ff':None,
                        # The four attention modules are consumed by Block_Recurrent_Cell in
                        # this order: vertical self, vertical cross, horizontal self, horizontal cross.
                        'attention_details': {'num_heads':8,
                                              'layers':[FullAttention(scale = None,
                                                                      output_attention = False,
                                                                      mask_flag = False,
                                                                      dropout = 0.1),
                                                        FullAttention(scale = None,
                                                                      output_attention = False,
                                                                      mask_flag = False,
                                                                      dropout = 0.1),
                                                        FullAttention(scale = None,
                                                                      output_attention = False,
                                                                      mask_flag = False,
                                                                      dropout = 0.1),
                                                        FullAttention(scale = None,
                                                                      output_attention = False,
                                                                      mask_flag = False,
                                                                      dropout = 0.1)],
                                              },

                  # Decoder:
                    'num_decoder_layers':2, 'dec_dout':1,

                    # Decoder_layer:
                      # 'dec_d_ff': None makes DecoderLayer fall back to 4 * dmodel.
                      'dec_att_heads':8, 'dec_d_ff':None, 'dec_dropout': 0.2, 'dec_act':'gelu'
                }

In [27]:
# Build the full model (SSL encoder + decoder) from the configuration dict.
battery_model = Model(model_config)

In [29]:
# Smoke test: one batch of 50 cells x 100 steps of random data, with the last
# feature of obs/dec/cap overwritten by the cell index 0..49.
cell_index = torch.arange(50).int()
per_step_index = cell_index.reshape(1, 50, 1).expand(1, 50, 100)

obs = torch.randn(1, 50, 100, 6)
obs[..., -1] = per_step_index
dec = torch.randn(1, 50, 100, 4)
dec[..., -1] = per_step_index
tsk = torch.randn(1, 11)
cap = torch.randn(1, 50, 2)
cap[..., -1] = cell_index.reshape(1, 50)
init_state = torch.randn(1, 100, 32)

cnext, ssl_op, ssl_st = battery_model(obs, dec, cap, tsk, init_state)

In [31]:
# Count only parameters that receive gradients, so the figure matches the
# "trainable" wording of the message even if some parameters are frozen.
num_parameters = sum(p.numel() for p in battery_model.parameters() if p.requires_grad)

print("The number of trainable parameters is", num_parameters)

The number of trainable parameters is 1662780
