In [99]:
import torch
import math
from torch import nn
from torch.nn.modules import activation
import torch.nn.functional as F
import math
import TAR_transformer

In [100]:
class PositionalEncoding(nn.Module):

    def __init__(self, seq_len, d_model, dropout = 0.1):
        super(PositionalEncoding, self).__init__()
        max_len = max(5000, seq_len)
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term)[: , 0 : -1]
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    # Input: seq_len x batch_size x dim, Output: seq_len, batch_size, dim
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
    
    
class Permute(torch.nn.Module):
    def forward(self, x):
        return x.permute(1, 0)
    
    

class MultitaskTransformerModel(nn.Module):

    def __init__(self, task_type, device, nclasses, seq_len, batch, input_size, emb_size, nhead, nhid, nhid_tar, nhid_task, nlayers, dropout = 0.1):
        super(MultitaskTransformerModel, self).__init__()
        # from torch.nn import TransformerEncoder, TransformerEncoderLayer
        
        self.trunk_net = nn.Sequential(
            nn.Linear(input_size, emb_size),
#             nn.BatchNorm1d(batch),
            nn.LayerNorm(emb_size),
            PositionalEncoding(seq_len, emb_size, dropout),
            nn.LayerNorm(emb_size),
#             nn.BatchNorm1d(batch)
        )
        
        # encoder_layers = transformer_encoder_class.TransformerEncoderLayer(emb_size, nhead, nhid, out_channel, filter_height, filter_width, dropout)
        # encoder_layers = TransformerEncoderLayer(emb_size, nhead, nhid, dropout)
        # self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        
        encoder_layers = TAR_transformer.TransformerEncoderLayer(emb_size, nhead, nhid, dropout)
        self.transformer_encoder = TAR_transformer.TransformerEncoder(encoder_layers, nlayers, device)
        
        self.batch_norm = nn.BatchNorm1d(batch)
        
        self.layer_norm = nn.LayerNorm(emb_size)
        
        # Task-aware Reconstruction Layers
        self.tar_net = nn.Sequential(
            nn.Linear(emb_size, nhid_tar),
            nn.BatchNorm1d(batch),
            nn.Linear(nhid_tar, nhid_tar),
            nn.BatchNorm1d(batch),
            nn.Linear(nhid_tar, input_size),
        )

        if task_type == 'classification':
            # Classification Layers
            self.class_net = nn.Sequential(
                nn.Linear(emb_size, nhid_task),
                nn.ReLU(),
                Permute(),
#                 nn.BatchNorm1d(batch),
                Permute(),
                nn.Dropout(p = 0.3),
                nn.Linear(nhid_task, nhid_task),
                nn.ReLU(),
                Permute(),
#                 nn.BatchNorm1d(batch),
                Permute(),
                nn.Dropout(p = 0.3),
                nn.Linear(nhid_task, nclasses)
            )
        else:
            # Regression Layers
            self.reg_net = nn.Sequential(
                nn.Linear(emb_size, nhid_task),
                nn.ReLU(),
                Permute(),
                nn.BatchNorm1d(batch),
                Permute(),
                nn.Linear(nhid_task, nhid_task),
                nn.ReLU(),
                Permute(),
                nn.BatchNorm1d(batch),
                Permute(),
                nn.Linear(nhid_task, 1),
            )
            

        
    def forward(self, x):
        x = torch.tensor(x, dtype=torch.float32)
        x = self.trunk_net(x.permute(2, 0, 1))
        x, attn = self.transformer_encoder(x)
#         x = self.batch_norm(x)
        # x : seq_len x batch x emb_size
        output = self.class_net(x[-1])
        return output, attn


"\ndevice = 'cuda:2'    \nlr, dropout = 0.01, 0.01\nnclasses, seq_len, batch, input_size = 12, 5, 11, 10\nemb_size, nhid, nhead, nlayers = 32, 128, 2, 3\nnhid_tar, nhid_task = 128, 128\ntask_type = 'regression'\nmodel = MultitaskTransformerModel(task_type, device, nclasses, seq_len, batch, input_size, emb_size, nhead, nhid, nhid_tar, nhid_task, nlayers, dropout = 0.1).to(device)\nx = torch.randn(batch, seq_len, input_size) * 50\nx = torch.as_tensor(x).float()\nprint(x.shape)\n(out_tar, attn_tar), (out_task, attn_task) = model(torch.as_tensor(x, device = device), 'reconstruction'), model(torch.as_tensor(x, device = device), task_type)\nprint(out_tar.shape)\nprint(attn_tar.shape)\nprint(out_task.shape)\nprint(attn_task.shape)\n"

In [101]:
task_type = "classification"
device = "cuda"
nclasses = 3
seq_len = 1500
batch_size = 64
input_size = 1500
emb_size = 512
nhead = 8
nhid = 64
nhid_tar = 1024
nhid_task = 128
nlayers = 2


learning_rate = 0.001
num_epochs = 80
use_cuda = True
early_stop = 3
min_delta = -0.025
min_epochs = 30
num_workers = 8

model = MultitaskTransformerModel(task_type, device, nclasses, seq_len, batch, input_size, emb_size, nhead, nhid, nhid_tar, nhid_task, nlayers, dropout = 0.1)

In [102]:
x = torch.load("Rat2_DF_2.pt")
x = x.unsqueeze(0)
x = x.cuda()
model = model.cuda()

In [103]:
x.shape

torch.Size([1, 1500, 56])

In [104]:
model(x)

  x = torch.tensor(x, dtype=torch.float32)


(tensor([[-0.2644,  0.0107, -0.0934]], device='cuda:0',
        grad_fn=<AddmmBackward0>),
 tensor([[[0.0355, 0.0453, 0.0484,  ..., 0.0236, 0.0283, 0.0341],
          [0.0254, 0.0365, 0.0385,  ..., 0.0302, 0.0337, 0.0319],
          [0.0270, 0.0385, 0.0390,  ..., 0.0342, 0.0402, 0.0451],
          ...,
          [0.0314, 0.0388, 0.0406,  ..., 0.0420, 0.0411, 0.0381],
          [0.0358, 0.0336, 0.0420,  ..., 0.0361, 0.0312, 0.0341],
          [0.0306, 0.0337, 0.0389,  ..., 0.0391, 0.0324, 0.0357]]],
        device='cuda:0', grad_fn=<AddBackward0>))

In [55]:
x.permute(2, 0, 1).shape

torch.Size([56, 1, 1500])

In [37]:
torch.permute

<function torch._VariableFunctionsClass.permute>

In [86]:
model

MultitaskTransformerModel(
  (trunk_net): Sequential(
    (0): Linear(in_features=1500, out_features=128, bias=True)
    (1): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=

In [88]:
model = model.eval

In [89]:
model

<bound method Module.eval of MultitaskTransformerModel(
  (trunk_net): Sequential(
    (0): Linear(in_features=1500, out_features=128, bias=True)
    (1): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (3): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
 