### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from shared.component_logger import component_logger as logger

In [2]:
torch.manual_seed(3)
np.random.seed(3)
random.seed(3)

In [3]:
class PatchEmbedd(nn.Module):
    def __init__(self, in_dim, embed_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, embed_dim)

    def forward(self, x):
        x = self.linear(x)
        return x

In [4]:
class Attention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        n_samples, n_tokens, dim = x.shape

        # Sanity check
        if dim != self.dim:
            raise ValueError
        
        #(n_samples, seq_len + 1, 3 * dim)
        qkv = self.qkv(x)  
        
        #(n_smaples, seq_len + 1, 3, n_heads, head_dim)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        
        #(3, n_samples, n_heads, seq_len + 1, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  

        q, k, v = qkv[0], qkv[1], qkv[2]
        
        #(n_samples, n_heads, head_dim, seq_len + 1)
        k_t = k.transpose(-2, -1)  
        
        # (n_samples, n_heads, seq_len + 1, seq_len + 1)
        dp = (q @ k_t)*self.scale 
        attn = dp.softmax(dim=-1)  # (n_samples, n_heads, seq_len + 1, seq_len + 1)
        attn = self.attn_drop(attn)
        
        # (n_samples, n_heads, seq_len +1, head_dim)
        weighted_avg = attn @ v  
        
        # (n_samples, seq_len + 1, n_heads, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)  
        
        # (n_samples, seq_len + 1, dim)
        weighted_avg = weighted_avg.flatten(2)  
        
        # (n_samples, seq_len + 1, dim)
        x = self.proj(weighted_avg)  
        
        # (n_samples, seq_len + 1, dim)
        x = self.proj_drop(x)  

        return x


In [5]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        x = self.fc1(x) # (n_samples, seq_len + 1, hidden_features)
        x = self.act(x)  # (n_samples, seq_len + 1, hidden_features)
        x = self.drop(x)  # (n_samples, seq_len + 1, hidden_features)
        x = self.fc2(x)  # (n_samples, seq_len + 1, out_features)
        x = self.drop(x)  # (n_samples, seq_len + 1, out_features)

        return x

In [6]:
class Block(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=p)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_features, out_features=dim)

    def forward(self, x):
        out_norm = self.norm1(x)
        out_attn = self.attn(out_norm)
        x = x + out_attn

        out_norm = self.norm2(x)
        out_mlp = self.mlp(out_norm)
        x = x + out_mlp

        return x

In [7]:
class TimeTransformer(nn.Module):
    def __init__(self, in_dim, seq_len, embed_dim, out_dim, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0.,):
        super().__init__()
        self.patch_embed = PatchEmbedd(in_dim, embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Total number of tokens = 1 + seq_len
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=p)

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                n_heads=n_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                p=p,
                attn_p=attn_p,
            )
            for _ in range(depth)
            ])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, 28)


    def forward(self, x):
        n_samples = x.shape[0]
        x = self.patch_embed(x)

        cls_token = self.cls_token.expand(n_samples, -1, -1)  # (n_samples, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)  # (n_samples, 1 + seq_len, embed_dim)

        # Added positional embedding of the cls token + all the patches to indicate the positions. 
        x = x + self.pos_embed  # (n_samples, 1 + seq_len, embed_dim)
        x = self.pos_drop(x) # (n_samples, 1 + seq_len, embed_dim) (probability of dropping)
        
        for block in self.blocks:
            x = block(x)
            
        x = self.norm(x)
        cls_token_final = x[:, 0]  # just the CLS token
        x = self.head(cls_token_final)

        return x, cls_token_final


### Main 

In [8]:
class TimeDataset(Dataset):
    def __init__(self, csv_file, seq_len):
        self.file = csv_file
        self.data = self.file.values
        ori_data = self.data[::-1]
        # Normalize the data
        ori_data, (minimum, maximum) = self.MinMaxScaler(ori_data)
        temp_data = [] 
        # Cut data by sequence length
        for i in range(seq_len + 1, len(ori_data)):
            x = ori_data[i-seq_len:i]
            y = ori_data[i].reshape(1, -1)
            temp_data.append((x, y))
        self.temp_array = temp_data
    
    def MinMaxScaler(self, data):
        minimum, maximum = np.min(data, 0), np.max(data, 0)
        numerator = data - minimum
        denominator = maximum - minimum
        norm_data = numerator / (denominator + 1e-7)
        return norm_data, (minimum, maximum)

    def __len__(self):
        return len(self.temp_array)

    def __getitem__(self, index):
        x = torch.Tensor(self.temp_array[index][0])
        y = torch.Tensor(self.temp_array[index][1])
        return (x, y)
    


In [9]:
data = "energy"
absolute_path = "data/" + data + "_data.csv" 
model_path = "saved_model/" + data + "_time_transformer.pt"
df = pd.read_csv(absolute_path)

In [10]:
checkpoint = 1
seq_len = 24
batch_size = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
print_every = 100 # 2000 minibatches
epochs = 50
split = 0.7
train_df = df[:int(len(df)*(split))]
valid_df = df[int(len(df)*(split)):]

In [11]:
traindataset = TimeDataset(train_df, seq_len)
validdataset = TimeDataset(valid_df, seq_len)
trainloader = DataLoader(dataset=traindataset, batch_size=batch_size, shuffle=True)
validloader = DataLoader(dataset=validdataset, batch_size=batch_size, shuffle=True)

In [12]:
custom_config = {
        "in_dim": 28,
        "seq_len": seq_len,
        "embed_dim": 16,
        "depth": 24,
        "n_heads": 16,
        "qkv_bias": True,
        "mlp_ratio": 4,
        "out_dim": 28
}
net = TimeTransformer(**custom_config).to(device)

In [13]:
def get_parameters_count(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


logger.log("Number of trainable parameters: {}".format(get_parameters_count(net)))

2022-07-22 20:52:26.669192: INFO: time: <cell line: 5>: Number of trainable parameters: 80108


### Define a Loss function and optimizer

Let’s use a MSE loss and Adam.

In [14]:
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

### Train the network

This is when things start to get interesting. We simply have to loop over our data iterator, and feed the inputs to the network and optimize.

In [15]:
for epoch in range(epochs):
    train_loss = 0.0
    net.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        labels = torch.squeeze(labels, dim = 1)
        optimizer.zero_grad()
        outputs, embeddings = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    valid_loss = 0.0
    net.eval()
    with torch.no_grad():
        for i, data in enumerate(validloader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            labels = torch.squeeze(labels, dim = 1)
            outputs, embeddings = net(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
    
    if epoch%checkpoint == 0:
        torch.save({'epoch': epoch, 'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, 
                    'valid_loss': valid_loss}, model_path)
        logger.log("Epoch: {}; mse_train: {}; mse_valid: {}".format(epoch + 1, np.round(train_loss/len(trainloader), 5), np.round(valid_loss/len(validloader), 5)))

logger.log('Finished Training')

2022-07-22 20:52:45.070716: INFO: time: <cell line: 1>: Epoch: 1; mse_train: 0.19434; mse_valid: 0.08987
2022-07-22 20:53:02.101305: INFO: time: <cell line: 1>: Epoch: 2; mse_train: 0.05561; mse_valid: 0.04477
2022-07-22 20:53:19.142146: INFO: time: <cell line: 1>: Epoch: 3; mse_train: 0.02918; mse_valid: 0.03545
2022-07-22 20:53:36.277613: INFO: time: <cell line: 1>: Epoch: 4; mse_train: 0.02279; mse_valid: 0.03475
2022-07-22 20:53:53.284123: INFO: time: <cell line: 1>: Epoch: 5; mse_train: 0.02057; mse_valid: 0.03323
2022-07-22 20:54:10.435806: INFO: time: <cell line: 1>: Epoch: 6; mse_train: 0.01916; mse_valid: 0.03191
2022-07-22 20:54:27.381480: INFO: time: <cell line: 1>: Epoch: 7; mse_train: 0.01806; mse_valid: 0.03098
2022-07-22 20:54:46.001963: INFO: time: <cell line: 1>: Epoch: 8; mse_train: 0.01716; mse_valid: 0.03002
2022-07-22 20:55:04.127743: INFO: time: <cell line: 1>: Epoch: 9; mse_train: 0.01646; mse_valid: 0.02898
2022-07-22 20:55:22.284944: INFO: time: <cell line: 1>: