<center><img src="./img/img_0.PNG"  width="1000" height="240"/></center>


This notebook demonstrate an apllication of SimMTM, a simple self-supervised learning framework for time series modeling. Self-supervised learning is a learning paradigm that allows model to learn a good representation from the input data itself. The learned representation will be beneficial to some downstream tasks such as forecasting, classification and outlier detection. 

Self-supervised learning has a lof of success and achieves state-of-the-art performance in some domains, especially in the image domain. In this demo, we will show a self-supervised learning method, SimMTM, in the time-series domain. SimMTM adopts both masked modeling and contrastive modeling to learn a good representation of the input data. By using the learned representation and finetuning it, we achieve a significant improvement compared to the model without self-supervised learning. 

In [1]:
import os


print(os.getcwd())

/home/shamvinc/ssl_time_series/SSL-Bootcamp/masked_modelling


In [2]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload


In [3]:
import copy
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import Tensor, nn
from torch.nn.modules import BatchNorm1d, Dropout, Linear, MultiheadAttention
from tqdm import tqdm

# Masked Modeling

Self-supervision via a ‘pretext task’ on input data combined with finetuning on labeled data is widely used for improving model performance in language and computer
vision. One of the popular self-supervision tasks on language data is masked modeling. Masking modeling is to mask some of the input entries randomly and predict those masked entries by using unmasked entries. By masked modeling, the model can learn the relationship through different features and different timesteps. 

<img src="./img/img_1.PNG"  width="900" height="240"/>

<img src="./img/img_2.PNG"  width="900" height="240"/>

# Masking Choice
### Random Masking

Random Masking is not a good choice to learn a good representation because the model can simply learn to take the average from the neighbour values. 



<img src="./img/img_3.PNG" width="600"/>

### Geometric Masking

Instead, we choose to use the geometric masking method, which is to mask a sequence of the input data randomly. The length of the sequence is followed by a geometric distribution. In this case, the model requires to recover a masked sequence from other unmasked input data. We suggest the expected length of a masked sequence is a half of the whole time series sequence.

In [4]:
def geom_noise_mask_single(L: int, lm: int, masking_ratio: float) -> Tensor:
    """
    Randomly create a boolean mask of length `L`, consisting of subsequences of average length lm, masking with 0s a `masking_ratio`
    proportion of the sequence L. The length of masking subsequences and intervals follow a geometric distribution.

    Args:
    ----
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked.

    Returns:
    -------
        (L,) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    """
    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masking sequence stopping. parameter of geometric distribution.
    p_u = (
        p_m * masking_ratio / (1 - masking_ratio)
    )  # probability of each unmasked sequence stopping. parameter of geometric distribution.
    p = [p_m, p_u]

    # Start in state 0 with masking_ratio probability
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, 1 means not masking
    for i in range(L):
        keep_mask[i] = state  # here it happens that state and masking value corresponding to state are identical
        if np.random.rand() < p[state]:
            state = 1 - state

    return keep_mask

# SimMTM ultilizes both contrastive learning and mask modeling to learn the data representation.
## 1 - Contrastive Learning

when we mask the input time series data, we create many masked views of the input data. We expect that the distance between two views of the same time series sequence is minimized while maximizing the distance between two different sequences.

<img src="./img/img_5.png"/>

## The contrastive loss is the following: (Eq. 8 in the paper)

<center><img src="./img/img_6.PNG"/><center/>

In [5]:
def demo_contrastive_loss(s: Tensor, batch_size: int, tau: float = 0.05) -> Tensor:
    s = s.squeeze(-1)

    B = s.shape[0]
    v = s.reshape(B, -1)

    norm_v = torch.norm(v, p=2, dim=-1).unsqueeze(-1)
    v = v / norm_v
    u = torch.transpose(v, 0, 1)

    R = torch.matmul(v, u)

    R = torch.exp(R / tau)  # (batch + mask size) x (batch + mask size)

    # number of masks
    M = B // batch_size
    mask = torch.eye(batch_size, device=R.device).repeat_interleave(M, dim=0).repeat_interleave(M, dim=1)

    denom = R * (torch.ones_like(R) - torch.eye(R.shape[0], device=R.device))

    denom = R.sum(-1).unsqueeze(-1)

    loss = torch.log(R / denom)

    loss = (loss * (mask - torch.eye(R.shape[0], device=R.device))).sum(1) / (M - 1)  # except no masked unit
    loss = loss.mean(0)

    return -loss

## 2 - Masked Modeling

SimMTM proposes to recover a time serie by the weighted sum of multiple masked points, which eases the reconstruction task by assembling ruined but complementary temporal variations.

<img src="./img/img_4.png"/>

## Model Components

In [6]:
class LearnablePositionalEncoding(nn.Module):
    """Learnable Positional Encoding.

    Args:
    ----
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=1024).

    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024) -> None:
        """Init of LearnablePositionalEncoding."""
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Each position gets its own embedding
        # Since indices are always 0 ... max_len, we don't have to do a look-up
        self.pe = nn.Parameter(torch.empty(max_len, 1, d_model))  # requires_grad automatically set to True
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of LearnablePositionalEncoding.

        Args:
        ----
        x: The sequence fed to the positional encoder model.

        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerBatchNormEncoderLayer(nn.modules.Module):
    """Transformer encoder layer block.

    Args:
    ----
        d_model: the number of expected features in the input.
        nhead: the number of heads in the multiheadattention models.
        dim_feedforward: the dimension of the feedforward network model.
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        """Init of TransformerBatchNormEncoderLayer."""
        super(TransformerBatchNormEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = BatchNorm1d(d_model, eps=1e-5)  # normalizes each feature across batch samples and time steps
        self.norm2 = BatchNorm1d(d_model, eps=1e-5)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = F.gelu

    def __setstate__(self, state: dict) -> None:
        """Set state for batch statistics."""
        if "activation" not in state:
            state["activation"] = F.relu
        super(TransformerBatchNormEncoderLayer, self).__setstate__(state)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Pass the input through the encoder layer.

        Args:
        ----
        src:
            The sequence to the encoder layer (required).
        src_mask:
            The mask for the src sequence (optional).
        src_key_padding_mask:
            The mask for the src keys per batch (optional).

        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)  # (seq_len, batch_size, d_model)
        src = src.permute(1, 2, 0)  # (batch_size, d_model, seq_len)
        # src = src.reshape([src.shape[0], -1])  # (batch_size, seq_length * d_model)
        src = self.norm1(src)
        src = src.permute(2, 0, 1)  # restore (seq_len, batch_size, d_model)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)  # (seq_len, batch_size, d_model)
        src = src.permute(1, 2, 0)  # (batch_size, d_model, seq_len)
        src = self.norm2(src)
        src = src.permute(2, 0, 1)  # restore (seq_len, batch_size, d_model)
        return src

## Transformer with SimMTM

In [7]:
class DemoSimMTMTransformerEncoder(nn.Module):
    r"""SimMTM Transformer
    Args:
        max_len: max input sequence length.
        feat_dim: input feature dimensions.
        out_len: output sequence length.
        out_dim: output feature dimensions.
        d_model: representation dimensions.
        n_heads: number of transformer heads.
        num_layers: number of transformer layers.
        dim_feedforward: hidden layer dimensions.
        dropout: dropout rate.
        temporal_unit: default number of masked views.
    """

    def __init__(
        self,
        max_len: int,
        feat_dim: int,
        out_len: int,
        out_dim: int,
        d_model: int = 16,
        n_heads: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 32,
        dropout: float = 0.2,
        temporal_unit: int = 3,
    ) -> None:
        super(DemoSimMTMTransformerEncoder, self).__init__()

        self.max_len = max_len
        self.d_model = d_model
        self.n_heads = n_heads

        self.tau = 0.05
        self.mask_length = max_len // 2
        self.mask_rate = 0.5

        self.project_inp = nn.Linear(feat_dim, d_model)
        self.projector_layer = nn.Linear(max_len, 1)
        self.pos_enc1 = LearnablePositionalEncoding(d_model, dropout=dropout, max_len=max_len)
        self.pos_enc2 = LearnablePositionalEncoding(d_model, dropout=dropout, max_len=out_len)

        self.act = F.gelu

        # encoder_layer = nn.TransformerEncoderLayer(d_model, self.n_heads, dim_feedforward, dropout, activation='gelu')
        encoder_layer = TransformerBatchNormEncoderLayer(
            d_model, self.n_heads, dim_feedforward, dropout, activation="gelu"
        )

        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.output_layer = nn.Linear(d_model, feat_dim)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout1d(dropout)

        # self.predict_layer1 = nn.Conv1d(d_model, 512, 5, stride=1)
        self.predict_layer1 = nn.Linear(max_len, out_len)
        self.predict_layer2 = nn.Linear(d_model, out_dim)
        # self.bn = nn.BatchNorm1d(d_model)

        self.feat_dim = feat_dim

        self.temporal_unit = temporal_unit

        self.w1 = torch.nn.parameter.Parameter(data=torch.ones(1), requires_grad=True)
        self.w2 = torch.nn.parameter.Parameter(data=torch.ones(1), requires_grad=True)

    def forward(self, X: Tensor, N: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Reconstruct the input and create the projected output of X.

        Args:
        ----
            X: (batch_size, seq_length, feat_dim) torch tensor of original input

        Returns:
        -------
            output: (batch_size, seq_length, feat_dim)
            s: (batch_size, d_model, 1)
        """
        _x = X

        # Create masked views of the input X
        if N is None:
            N = self.temporal_unit
        for _i in range(N):
            mask = geom_noise_mask_single(X.shape[0] * X.shape[1] * X.shape[2], self.mask_length, self.mask_rate)
            mask = mask.reshape(X.shape[0], X.shape[1], X.shape[2])
            mask = torch.from_numpy(mask).to(X.device)
            x_masked = mask * X
            _x = torch.cat([_x, x_masked], axis=-1)  # [batch_size, seq_length, feat_dim * temporal_unit]

        _x = _x.reshape(X.shape[0] * (N + 1), X.shape[1], X.shape[2])

        inp = _x.permute(1, 0, 2)
        inp = self.project_inp(inp) * np.sqrt(
            self.d_model
        )  # [seq_length, batch_size, d_model] project input vectors to d_model dimensional space
        inp = self.pos_enc1(inp)  # add positional encoding

        output = self.transformer_encoder(inp)  # (seq_length, batch_size, d_model)
        output = self.act(output)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = output.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        output = self.dropout1(output)

        z_hat, _s = self.project(output, self.tau, N)
        # Most probably defining a Linear(d_model,feat_dim) vectorizes the operation over (seq_length, batch_size).
        output = self.output_layer(z_hat)  # (batch_size, seq_length, feat_dim)

        return output, _s

    def project(self, z: Tensor, tau: float, N: int) -> Tuple[Tensor, Tensor]:
        """
        Output a weighted average of z.

        Args:
        ----
            X: (batch_size, seq_length, feat_dim) torch tensor of original input

        Returns:
        -------
            z_hat: (batch_size, seq_length, d_model)
            s: (batch_size, d_model, 1)
        """
        _z = z.transpose(1, 2)  # [batch_size, d_model, seq_length]
        _s = s = self.projector_layer(_z)  # [batch_size, d_model, 1]

        if self.training:
            mask = torch.ones(1, self.d_model, 1).to(z.device)
            mask = self.dropout3(mask)
            s = s * mask
            s = s + torch.randn(s.shape).to(z.device) * 1e-2

        s = s.squeeze(-1)
        B = s.shape[0]
        v = s.reshape(B, -1)

        norm_v = torch.norm(v, p=2, dim=-1).unsqueeze(-1)
        v = v / norm_v
        u = torch.transpose(v, 0, 1)

        R = torch.matmul(v, u)

        R = torch.exp(R / tau)  # (batch + mask size) x (batch + mask size)
        R = R * (
            torch.ones_like(R) - torch.eye(R.shape[0], device=R.device)
        )  # zero out the weight of no masked component
        R = R / R.sum(-1).unsqueeze(-1)
        M = N + 1
        R = R[::M]  # extract every no mask unit # (batch size) x (batch + mask size)

        z_hat = (R.unsqueeze(-1).unsqueeze(-1).detach() * z.unsqueeze(0)).sum(1)
        return z_hat, _s

    def predict(self, X: Tensor) -> Tensor:
        """
        Predict an output given X.

        Args:
        ----
            z: (batch_size, seq_length, d_model) torch tensor of representations of input
            tau: temperture of similarity matrix

        Returns:
        -------
            output: (batch_size, out_seq_len, out_dim)
        """
        # permute because pytorch convention for transformers is [seq_length, batch_size, feat_dim]. padding_masks [batch_size, feat_dim]
        inp = X.permute(1, 0, 2)
        inp = self.project_inp(inp) * np.sqrt(
            self.d_model
        )  # [seq_length, batch_size, d_model] project input vectors to d_model dimensional space
        inp = self.pos_enc1(inp)  # add positional encoding
        # NOTE: logic for padding masks is reversed to comply with definition in MultiHeadAttention, TransformerEncoderLayer

        output = self.transformer_encoder(inp)
        output = output.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        # output = self.dropout1(output)

        output = output.transpose(1, 2)  # (batch_size, d_model, seq_length)
        output = self.predict_layer1(output)
        # output = self.act(output)

        output = output.transpose(1, 2)  # (batch_size, seq_length, d_model)
        output = output.permute(1, 0, 2)
        output = self.pos_enc2(output)
        output = output.permute(1, 0, 2)
        output = self.dropout2(output)
        output = self.predict_layer2(output)
        return output

# Data Loading and Preparation

In this demo, we use a benchmask time series dataset called BeijingPM25Quality.
This dataset is part of the Monash, UEA & UCR time series regression repository. http://tseregression.org/

The goal of this dataset is to predict PM2.5 air quality in the city of Beijing. This dataset contains 17532 time series with 9 dimensions.  This includes hourly air pollutants measurments (SO2, NO2, CO and O3), temperature, pressure, dew point, rainfall and windspeed measurments from 12 nationally controlled air quality monitoring sites. The air-quality data are from the Beijing Municipal Environmental Monitoring Center. The meteorological data in each air-quality site are matched with the nearest weather station from the China Meteorological Administration. The time period is from March 1st, 2013 to February 28th, 2017. 

In [8]:
data = pd.read_csv("./datasets/BeijingPM25Quality/train_x.csv", index_col=0)
test_data = pd.read_csv("./datasets/BeijingPM25Quality/test_x.csv", index_col=0)
_data = data

In [9]:
data

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8
0,4.0,7.0,300.0,77.0,-0.7,1023.0,-18.8,0.0,4.4
0,4.0,7.0,300.0,77.0,-1.1,1023.2,-18.2,0.0,4.7
0,5.0,10.0,300.0,73.0,-1.1,1023.5,-18.2,0.0,5.6
0,11.0,11.0,300.0,72.0,-1.4,1024.5,-19.4,0.0,3.1
0,12.0,12.0,300.0,72.0,-2.0,1025.2,-19.5,0.0,2.0
...,...,...,...,...,...,...,...,...,...
11917,27.0,96.0,3300.0,9.0,-1.4,1026.3,-8.6,0.0,1.0
11917,34.0,99.0,3700.0,9.0,-2.5,1026.2,-8.4,0.0,1.3
11917,31.0,95.0,3100.0,9.0,-2.7,1025.8,-8.0,0.0,0.9
11917,40.0,99.0,4200.0,13.0,-3.5,1025.5,-7.6,0.0,0.4


In [10]:
# Standard Normalization
normalizer = StandardScaler()
data[:] = normalizer.fit_transform(data)
test_data[:] = normalizer.transform(test_data)

In [11]:
data

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8
0,-0.578723,-1.262339,-0.827454,0.331285,-1.354746,1.282597,-1.671897,-0.077887,2.197101
0,-0.578723,-1.262339,-0.827454,0.331285,-1.390913,1.302251,-1.626944,-0.077887,2.439034
0,-0.536462,-1.177416,-0.827454,0.261866,-1.390913,1.331731,-1.626944,-0.077887,3.164833
0,-0.282896,-1.149109,-0.827454,0.244512,-1.418038,1.430000,-1.716850,-0.077887,1.148725
0,-0.240636,-1.120801,-0.827454,0.244512,-1.472287,1.498788,-1.724342,-0.077887,0.261637
...,...,...,...,...,...,...,...,...,...
11917,0.393278,1.257031,1.802584,-0.848838,-1.418038,1.606883,-0.907691,-0.077887,-0.544806
11917,0.689105,1.341954,2.153256,-0.848838,-1.517495,1.597057,-0.892707,-0.077887,-0.302873
11917,0.562322,1.228724,1.627248,-0.848838,-1.535578,1.557749,-0.862738,-0.077887,-0.625450
11917,0.942670,1.341954,2.591596,-0.779419,-1.607911,1.528269,-0.832769,-0.077887,-1.028672


In [12]:
max_len = 24
out_size = 6
out_dim = 9

model = DemoSimMTMTransformerEncoder(
    max_len=18,
    feat_dim=data.shape[1],
    out_len=out_size,
    out_dim=out_dim,
    d_model=64,
    n_heads=4,
    num_layers=1,
    dim_feedforward=64,
)

device = "cuda"
model.to(device)
model.tau = 0.05

model.mask_length = max_len // 2
model.mask_ratio = 0.5
model.mask_views = 3

model.contrastive_views = 2
init_model = copy.deepcopy(model)

In [13]:
optimizer_finetune = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [14]:
from torch.utils.data import DataLoader


batch_size = 64


train_indices, val_indices = train_test_split(np.array(data.index.unique()), test_size=0.2)
test_indices = np.array(test_data.index.unique())

train_dataloader = DataLoader(train_indices, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_indices, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_indices, batch_size=batch_size, shuffle=True)

print(train_indices.shape, val_indices.shape, test_indices.shape, np.array(data.index.unique()).shape)

(9534,) (2384,) (5048,) (11918,)


# Self-Supervised Learning Training Loop

In [15]:
i: int = 0
max_epoch: int = 50
best_loss: float = 1.0e10
best_epoch: int = 0
device = "cuda"
loss_fn = nn.MSELoss()
best_model = copy.deepcopy(model)


while i < max_epoch:
    train_loss: Dict[str, List[float]] = {"loss": [], "loss_mse": [], "loss_con": []}
    progress_bar = tqdm(train_dataloader)

    for IDs in progress_bar:
        model.train()
        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float()

        X = X.reshape(-1, max_len, X.shape[-1])
        X = X[:, :18, :]

        #         _X = []
        #         idx = torch.randint(low=0, high=6, size=(X.shape[0],))
        #         for j in range(18):
        #             _X.append(X[torch.arange(X.shape[0]),idx+j,:].unsqueeze(1))

        #         X = torch.cat(_X,dim=1)

        # X = X[:, :, -1:]

        pred, _ = model(X, 3)  # (batch_size, padded_length, feat_dim)

        loss_mse = loss_fn(pred, X)

        _, s = model(X, 1)  # (batch_size, padded_length, feat_dim)
        loss_con = demo_contrastive_loss(s, X.shape[0])

        loss = 0.1 * loss_mse + loss_con

        optimizer_finetune.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0)
        optimizer_finetune.step()
        # import ipdb; ipdb.set_trace()
        progress_bar.set_description(
            "Epoch {0} - Training loss: {1:.2f} - MSE loss: {2:.2f} - Contrastive loss: {3:.2f}".format(
                i,
                loss.cpu().detach().numpy().item(),
                loss_mse.cpu().detach().numpy().item(),
                loss_con.cpu().detach().numpy().item(),
            )
        )
        train_loss["loss"].append(loss.item())
        train_loss["loss_mse"].append(loss_mse.item())
        train_loss["loss_con"].append(loss_con.item())

    with torch.no_grad():
        val_loss: Dict[str, List[float]] = {"loss": [], "loss_mse": [], "loss_con": []}
        for IDs in val_dataloader:
            model.eval()
            X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
            X = X.float()
            X = X.reshape(-1, max_len, X.shape[-1])
            X = X[:, :18, :]
            # X = X[:, :, -1:]

            #             _X = []
            #             idx = torch.randint(low=0, high=6, size=(X.shape[0],))
            #             for j in range(18):
            #                 _X.append(X[torch.arange(X.shape[0]),idx+j,:].unsqueeze(1))

            #             X = torch.cat(_X,dim=1)

            # X = X[:, :, -1:]

            pred, _ = model(X, 3)  # (batch_size, padded_length, feat_dim)

            loss_mse = loss_fn(pred, X)

            _, s = model(X, 1)  # (batch_size, padded_length, feat_dim)
            loss_con = demo_contrastive_loss(s, X.shape[0])

            loss = 0.1 * loss_mse + loss_con

            val_loss["loss"].append(loss.item())
            val_loss["loss_mse"].append(loss_mse.item())
            val_loss["loss_con"].append(loss_con.item())

        if torch.tensor(val_loss["loss"]).mean().item() < best_loss:
            best_loss = torch.tensor(val_loss["loss"]).mean()
            best_model = copy.deepcopy(model)
            best_epoch = i

        progress_bar.write(
            "Epoch {0} - Training loss: {1:.2f} {2:.2f} {3:.2f} - Validation loss: {4:.2f} {5:.2f} {6:.2f}".format(
                i,
                torch.tensor(train_loss["loss"]).mean().item(),
                torch.tensor(train_loss["loss_mse"]).mean().item(),
                torch.tensor(train_loss["loss_con"]).mean().item(),
                torch.tensor(val_loss["loss"]).mean().item(),
                torch.tensor(val_loss["loss_mse"]).mean().item(),
                torch.tensor(val_loss["loss_con"]).mean().item(),
            )
        )
    i += 1


tqdm.write("Best Epoch {} - Best Validation loss: {}".format(best_epoch, best_loss))

Epoch 0 - Training loss: 2.86 - MSE loss: 0.54 - Contrastive loss: 2.81: 100%|█████████████████████████████████████████████████| 149/149 [00:08<00:00, 18.13it/s]


Epoch 0 - Training loss: 5.66 0.66 5.59 - Validation loss: 2.51 0.59 2.45


Epoch 1 - Training loss: 2.61 - MSE loss: 0.47 - Contrastive loss: 2.56: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.45it/s]


Epoch 1 - Training loss: 2.71 0.54 2.66 - Validation loss: 2.21 0.53 2.16


Epoch 2 - Training loss: 2.57 - MSE loss: 0.47 - Contrastive loss: 2.52: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 18.78it/s]


Epoch 2 - Training loss: 2.48 0.51 2.43 - Validation loss: 2.04 0.51 1.99


Epoch 3 - Training loss: 2.16 - MSE loss: 0.41 - Contrastive loss: 2.12: 100%|█████████████████████████████████████████████████| 149/149 [00:06<00:00, 22.08it/s]


Epoch 3 - Training loss: 2.34 0.50 2.29 - Validation loss: 1.90 0.50 1.85


Epoch 4 - Training loss: 2.13 - MSE loss: 0.45 - Contrastive loss: 2.08: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 21.10it/s]


Epoch 4 - Training loss: 2.24 0.49 2.19 - Validation loss: 1.81 0.51 1.76


Epoch 5 - Training loss: 2.07 - MSE loss: 0.43 - Contrastive loss: 2.02: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.20it/s]


Epoch 5 - Training loss: 2.16 0.49 2.11 - Validation loss: 1.71 0.48 1.66


Epoch 6 - Training loss: 2.01 - MSE loss: 0.36 - Contrastive loss: 1.98: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.19it/s]


Epoch 6 - Training loss: 2.10 0.49 2.05 - Validation loss: 1.65 0.48 1.60


Epoch 7 - Training loss: 1.81 - MSE loss: 0.35 - Contrastive loss: 1.78: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.84it/s]


Epoch 7 - Training loss: 2.04 0.49 1.99 - Validation loss: 1.61 0.48 1.57


Epoch 8 - Training loss: 2.10 - MSE loss: 0.42 - Contrastive loss: 2.05: 100%|█████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.71it/s]


Epoch 8 - Training loss: 2.01 0.49 1.96 - Validation loss: 1.57 0.48 1.53


Epoch 9 - Training loss: 1.85 - MSE loss: 0.57 - Contrastive loss: 1.80: 100%|█████████████████████████████████████████████████| 149/149 [00:06<00:00, 22.06it/s]


Epoch 9 - Training loss: 1.97 0.48 1.92 - Validation loss: 1.54 0.51 1.49


Epoch 10 - Training loss: 2.08 - MSE loss: 0.52 - Contrastive loss: 2.03: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.94it/s]


Epoch 10 - Training loss: 1.94 0.48 1.89 - Validation loss: 1.50 0.47 1.45


Epoch 11 - Training loss: 1.85 - MSE loss: 0.43 - Contrastive loss: 1.81: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 21.90it/s]


Epoch 11 - Training loss: 1.93 0.48 1.88 - Validation loss: 1.51 0.47 1.46


Epoch 12 - Training loss: 1.72 - MSE loss: 0.43 - Contrastive loss: 1.68: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 23.44it/s]


Epoch 12 - Training loss: 1.90 0.48 1.86 - Validation loss: 1.49 0.47 1.44


Epoch 13 - Training loss: 2.06 - MSE loss: 0.63 - Contrastive loss: 2.00: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 22.18it/s]


Epoch 13 - Training loss: 1.89 0.48 1.84 - Validation loss: 1.49 0.47 1.44


Epoch 14 - Training loss: 1.78 - MSE loss: 0.36 - Contrastive loss: 1.74: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.38it/s]


Epoch 14 - Training loss: 1.86 0.48 1.81 - Validation loss: 1.47 0.47 1.42


Epoch 15 - Training loss: 1.74 - MSE loss: 0.44 - Contrastive loss: 1.70: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.63it/s]


Epoch 15 - Training loss: 1.84 0.48 1.79 - Validation loss: 1.46 0.47 1.41


Epoch 16 - Training loss: 1.72 - MSE loss: 0.42 - Contrastive loss: 1.68: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.39it/s]


Epoch 16 - Training loss: 1.84 0.48 1.79 - Validation loss: 1.45 0.46 1.40


Epoch 17 - Training loss: 1.82 - MSE loss: 0.49 - Contrastive loss: 1.77: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.31it/s]


Epoch 17 - Training loss: 1.82 0.48 1.78 - Validation loss: 1.44 0.46 1.39


Epoch 18 - Training loss: 1.69 - MSE loss: 0.40 - Contrastive loss: 1.65: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 23.02it/s]


Epoch 18 - Training loss: 1.80 0.48 1.75 - Validation loss: 1.42 0.46 1.38


Epoch 19 - Training loss: 1.70 - MSE loss: 0.33 - Contrastive loss: 1.66: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.38it/s]


Epoch 19 - Training loss: 1.79 0.48 1.75 - Validation loss: 1.41 0.46 1.37


Epoch 20 - Training loss: 1.84 - MSE loss: 0.43 - Contrastive loss: 1.80: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.70it/s]


Epoch 20 - Training loss: 1.78 0.48 1.74 - Validation loss: 1.42 0.46 1.37


Epoch 21 - Training loss: 1.76 - MSE loss: 0.43 - Contrastive loss: 1.72: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.15it/s]


Epoch 21 - Training loss: 1.78 0.48 1.73 - Validation loss: 1.43 0.46 1.38


Epoch 22 - Training loss: 1.71 - MSE loss: 0.44 - Contrastive loss: 1.66: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.63it/s]


Epoch 22 - Training loss: 1.77 0.48 1.72 - Validation loss: 1.42 0.46 1.37


Epoch 23 - Training loss: 1.75 - MSE loss: 0.42 - Contrastive loss: 1.71: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.83it/s]


Epoch 23 - Training loss: 1.77 0.48 1.72 - Validation loss: 1.41 0.46 1.36


Epoch 24 - Training loss: 1.61 - MSE loss: 0.58 - Contrastive loss: 1.56: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.86it/s]


Epoch 24 - Training loss: 1.76 0.48 1.71 - Validation loss: 1.39 0.46 1.35


Epoch 25 - Training loss: 1.78 - MSE loss: 0.67 - Contrastive loss: 1.71: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.90it/s]


Epoch 25 - Training loss: 1.75 0.48 1.71 - Validation loss: 1.39 0.47 1.35


Epoch 26 - Training loss: 1.70 - MSE loss: 0.36 - Contrastive loss: 1.67: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 21.17it/s]


Epoch 26 - Training loss: 1.75 0.48 1.70 - Validation loss: 1.39 0.46 1.35


Epoch 27 - Training loss: 1.76 - MSE loss: 0.43 - Contrastive loss: 1.71: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 21.34it/s]


Epoch 27 - Training loss: 1.73 0.48 1.69 - Validation loss: 1.38 0.46 1.34


Epoch 28 - Training loss: 1.68 - MSE loss: 0.43 - Contrastive loss: 1.64: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.92it/s]


Epoch 28 - Training loss: 1.74 0.48 1.69 - Validation loss: 1.39 0.46 1.34


Epoch 29 - Training loss: 1.84 - MSE loss: 0.57 - Contrastive loss: 1.78: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.57it/s]


Epoch 29 - Training loss: 1.72 0.48 1.68 - Validation loss: 1.36 0.45 1.32


Epoch 30 - Training loss: 1.67 - MSE loss: 0.50 - Contrastive loss: 1.62: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.62it/s]


Epoch 30 - Training loss: 1.73 0.48 1.68 - Validation loss: 1.38 0.45 1.33


Epoch 31 - Training loss: 1.75 - MSE loss: 0.47 - Contrastive loss: 1.70: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 21.33it/s]


Epoch 31 - Training loss: 1.72 0.48 1.68 - Validation loss: 1.37 0.46 1.33


Epoch 32 - Training loss: 1.72 - MSE loss: 0.61 - Contrastive loss: 1.66: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.10it/s]


Epoch 32 - Training loss: 1.73 0.48 1.68 - Validation loss: 1.37 0.45 1.33


Epoch 33 - Training loss: 1.81 - MSE loss: 0.70 - Contrastive loss: 1.74: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.75it/s]


Epoch 33 - Training loss: 1.72 0.48 1.67 - Validation loss: 1.38 0.46 1.34


Epoch 34 - Training loss: 1.66 - MSE loss: 1.15 - Contrastive loss: 1.55: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.63it/s]


Epoch 34 - Training loss: 1.72 0.48 1.67 - Validation loss: 1.37 0.46 1.32


Epoch 35 - Training loss: 1.69 - MSE loss: 0.46 - Contrastive loss: 1.64: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.87it/s]


Epoch 35 - Training loss: 1.70 0.48 1.66 - Validation loss: 1.37 0.46 1.33


Epoch 36 - Training loss: 1.87 - MSE loss: 0.49 - Contrastive loss: 1.82: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 21.19it/s]


Epoch 36 - Training loss: 1.71 0.48 1.66 - Validation loss: 1.37 0.46 1.32


Epoch 37 - Training loss: 1.73 - MSE loss: 0.41 - Contrastive loss: 1.69: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.15it/s]


Epoch 37 - Training loss: 1.70 0.48 1.65 - Validation loss: 1.39 0.45 1.35


Epoch 38 - Training loss: 1.82 - MSE loss: 0.48 - Contrastive loss: 1.77: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 21.16it/s]


Epoch 38 - Training loss: 1.70 0.47 1.65 - Validation loss: 1.35 0.45 1.31


Epoch 39 - Training loss: 1.76 - MSE loss: 0.43 - Contrastive loss: 1.72: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.26it/s]


Epoch 39 - Training loss: 1.70 0.48 1.65 - Validation loss: 1.36 0.45 1.32


Epoch 40 - Training loss: 1.85 - MSE loss: 1.27 - Contrastive loss: 1.73: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.68it/s]


Epoch 40 - Training loss: 1.69 0.48 1.65 - Validation loss: 1.36 0.45 1.32


Epoch 41 - Training loss: 1.72 - MSE loss: 0.44 - Contrastive loss: 1.68: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.73it/s]


Epoch 41 - Training loss: 1.69 0.47 1.65 - Validation loss: 1.35 0.45 1.31


Epoch 42 - Training loss: 1.69 - MSE loss: 0.44 - Contrastive loss: 1.65: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.04it/s]


Epoch 42 - Training loss: 1.69 0.47 1.64 - Validation loss: 1.36 0.45 1.31


Epoch 43 - Training loss: 1.68 - MSE loss: 0.40 - Contrastive loss: 1.64: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 22.36it/s]


Epoch 43 - Training loss: 1.69 0.47 1.64 - Validation loss: 1.35 0.45 1.31


Epoch 44 - Training loss: 1.68 - MSE loss: 0.43 - Contrastive loss: 1.63: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 22.79it/s]


Epoch 44 - Training loss: 1.68 0.47 1.64 - Validation loss: 1.34 0.45 1.30


Epoch 45 - Training loss: 1.60 - MSE loss: 0.39 - Contrastive loss: 1.56: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.98it/s]


Epoch 45 - Training loss: 1.69 0.47 1.64 - Validation loss: 1.35 0.45 1.31


Epoch 46 - Training loss: 1.53 - MSE loss: 0.38 - Contrastive loss: 1.50: 100%|████████████████████████████████████████████████| 149/149 [00:06<00:00, 21.63it/s]


Epoch 46 - Training loss: 1.68 0.47 1.64 - Validation loss: 1.36 0.45 1.31


Epoch 47 - Training loss: 1.70 - MSE loss: 0.35 - Contrastive loss: 1.66: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 20.43it/s]


Epoch 47 - Training loss: 1.68 0.47 1.63 - Validation loss: 1.34 0.45 1.30


Epoch 48 - Training loss: 1.71 - MSE loss: 0.41 - Contrastive loss: 1.67: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.47it/s]


Epoch 48 - Training loss: 1.68 0.47 1.63 - Validation loss: 1.34 0.45 1.30


Epoch 49 - Training loss: 1.81 - MSE loss: 0.56 - Contrastive loss: 1.75: 100%|████████████████████████████████████████████████| 149/149 [00:07<00:00, 19.74it/s]


Epoch 49 - Training loss: 1.68 0.47 1.64 - Validation loss: 1.35 0.46 1.31
Best Epoch 48 - Best Validation loss: 1.3419822454452515


# Finetune Training Loop



In [16]:
finetune_model = copy.deepcopy(best_model)
optimizer = torch.optim.AdamW(finetune_model.parameters(), lr=1e-3)

batch_size = 64
train_dataloader = DataLoader(train_indices, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_indices, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_indices, batch_size=batch_size, shuffle=True)

In [17]:
i = 0
max_epoch = 50
best_loss = 1.0e10
best_finetune_model = copy.deepcopy(best_model)
best_epoch = 0
device = "cuda"
finetune_model.to(device)
while i < max_epoch:
    ft_train_loss = []
    progress_bar = tqdm(train_dataloader)

    for IDs in progress_bar:
        finetune_model.train()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :]
        X = X[:, :18, :]

        pred = finetune_model.predict(X)
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)

        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(finetune_model.parameters(), max_norm=4.0)
        optimizer.step()

        progress_bar.set_description("Epoch {} - Training loss: {:.2f}".format(i, loss))
        ft_train_loss.append(loss.item())

    with torch.no_grad():
        ft_val_loss = []
        for IDs in val_dataloader:
            finetune_model.eval()

            X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
            X = X.float().reshape(-1, max_len, X.shape[-1])
            targets = X[:, 18:24, :]
            X = X[:, :18, :]

            pred = finetune_model.predict(X.float())
            pred = pred.reshape(X.shape[0], out_size, -1)
            loss = loss_fn(pred, targets)

            ft_val_loss.append(loss.item())

        if torch.tensor(ft_val_loss).mean().item() < best_loss:
            best_loss = torch.tensor(ft_val_loss).mean().item()
            best_finetune_model = copy.deepcopy(finetune_model)
            best_epoch = i

    progress_bar.write(
        "Epoch {} - Training loss: {:.2f} - Validation loss: {:.2f}".format(
            i, torch.tensor(ft_train_loss).mean().item(), torch.tensor(ft_val_loss).mean().item()
        )
    )
    i += 1


tqdm.write("Best Epoch {} - Best Validation loss: {}".format(best_epoch, best_loss))

Epoch 0 - Training loss: 0.31: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.65it/s]


Epoch 0 - Training loss: 0.70 - Validation loss: 0.48


Epoch 1 - Training loss: 0.32: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.76it/s]


Epoch 1 - Training loss: 0.51 - Validation loss: 0.45


Epoch 2 - Training loss: 0.61: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.01it/s]


Epoch 2 - Training loss: 0.48 - Validation loss: 0.44


Epoch 3 - Training loss: 0.26: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.19it/s]


Epoch 3 - Training loss: 0.46 - Validation loss: 0.43


Epoch 4 - Training loss: 0.40: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.00it/s]


Epoch 4 - Training loss: 0.45 - Validation loss: 0.42


Epoch 5 - Training loss: 0.43: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.06it/s]


Epoch 5 - Training loss: 0.45 - Validation loss: 0.41


Epoch 6 - Training loss: 0.24: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.67it/s]


Epoch 6 - Training loss: 0.44 - Validation loss: 0.42


Epoch 7 - Training loss: 0.37: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 44.47it/s]


Epoch 7 - Training loss: 0.44 - Validation loss: 0.42


Epoch 8 - Training loss: 0.79: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 46.83it/s]


Epoch 8 - Training loss: 0.43 - Validation loss: 0.41


Epoch 9 - Training loss: 0.24: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 45.41it/s]


Epoch 9 - Training loss: 0.43 - Validation loss: 0.41


Epoch 10 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.86it/s]


Epoch 10 - Training loss: 0.43 - Validation loss: 0.41


Epoch 11 - Training loss: 0.26: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.81it/s]


Epoch 11 - Training loss: 0.43 - Validation loss: 0.41


Epoch 12 - Training loss: 0.84: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.21it/s]


Epoch 12 - Training loss: 0.43 - Validation loss: 0.40


Epoch 13 - Training loss: 0.23: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.52it/s]


Epoch 13 - Training loss: 0.43 - Validation loss: 0.40


Epoch 14 - Training loss: 0.26: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.45it/s]


Epoch 14 - Training loss: 0.43 - Validation loss: 0.41


Epoch 15 - Training loss: 0.35: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.65it/s]


Epoch 15 - Training loss: 0.42 - Validation loss: 0.40


Epoch 16 - Training loss: 0.26: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.50it/s]


Epoch 16 - Training loss: 0.42 - Validation loss: 0.40


Epoch 17 - Training loss: 0.52: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.07it/s]


Epoch 17 - Training loss: 0.42 - Validation loss: 0.40


Epoch 18 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 44.41it/s]


Epoch 18 - Training loss: 0.42 - Validation loss: 0.41


Epoch 19 - Training loss: 0.34: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.84it/s]


Epoch 19 - Training loss: 0.42 - Validation loss: 0.40


Epoch 20 - Training loss: 0.29: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.78it/s]


Epoch 20 - Training loss: 0.42 - Validation loss: 0.40


Epoch 21 - Training loss: 0.29: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.55it/s]


Epoch 21 - Training loss: 0.42 - Validation loss: 0.40


Epoch 22 - Training loss: 0.34: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.79it/s]


Epoch 22 - Training loss: 0.42 - Validation loss: 0.41


Epoch 23 - Training loss: 0.33: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.34it/s]


Epoch 23 - Training loss: 0.42 - Validation loss: 0.40


Epoch 24 - Training loss: 0.21: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.89it/s]


Epoch 24 - Training loss: 0.42 - Validation loss: 0.40


Epoch 25 - Training loss: 0.32: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.05it/s]


Epoch 25 - Training loss: 0.42 - Validation loss: 0.41


Epoch 26 - Training loss: 0.35: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.88it/s]


Epoch 26 - Training loss: 0.42 - Validation loss: 0.40


Epoch 27 - Training loss: 0.39: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.00it/s]


Epoch 27 - Training loss: 0.42 - Validation loss: 0.40


Epoch 28 - Training loss: 0.41: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.04it/s]


Epoch 28 - Training loss: 0.42 - Validation loss: 0.39


Epoch 29 - Training loss: 0.31: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.83it/s]


Epoch 29 - Training loss: 0.42 - Validation loss: 0.40


Epoch 30 - Training loss: 0.51: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.70it/s]


Epoch 30 - Training loss: 0.42 - Validation loss: 0.42


Epoch 31 - Training loss: 0.29: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.02it/s]


Epoch 31 - Training loss: 0.42 - Validation loss: 0.40


Epoch 32 - Training loss: 0.25: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.12it/s]


Epoch 32 - Training loss: 0.42 - Validation loss: 0.40


Epoch 33 - Training loss: 0.43: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.37it/s]


Epoch 33 - Training loss: 0.42 - Validation loss: 0.44


Epoch 34 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.57it/s]


Epoch 34 - Training loss: 0.42 - Validation loss: 0.45


Epoch 35 - Training loss: 0.29: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.95it/s]


Epoch 35 - Training loss: 0.42 - Validation loss: 0.39


Epoch 36 - Training loss: 0.72: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.69it/s]


Epoch 36 - Training loss: 0.41 - Validation loss: 0.39


Epoch 37 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.37it/s]


Epoch 37 - Training loss: 0.41 - Validation loss: 0.40


Epoch 38 - Training loss: 0.28: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.79it/s]


Epoch 38 - Training loss: 0.41 - Validation loss: 0.40


Epoch 39 - Training loss: 0.27: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.69it/s]


Epoch 39 - Training loss: 0.41 - Validation loss: 0.40


Epoch 40 - Training loss: 0.65: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.74it/s]


Epoch 40 - Training loss: 0.41 - Validation loss: 0.40


Epoch 41 - Training loss: 0.27: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.47it/s]


Epoch 41 - Training loss: 0.41 - Validation loss: 0.40


Epoch 42 - Training loss: 0.55: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.93it/s]


Epoch 42 - Training loss: 0.42 - Validation loss: 0.40


Epoch 43 - Training loss: 0.28: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.94it/s]


Epoch 43 - Training loss: 0.42 - Validation loss: 0.39


Epoch 44 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.30it/s]


Epoch 44 - Training loss: 0.41 - Validation loss: 0.39


Epoch 45 - Training loss: 0.74: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.96it/s]


Epoch 45 - Training loss: 0.41 - Validation loss: 0.41


Epoch 46 - Training loss: 0.25: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.63it/s]


Epoch 46 - Training loss: 0.42 - Validation loss: 0.42


Epoch 47 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.25it/s]


Epoch 47 - Training loss: 0.42 - Validation loss: 0.39


Epoch 48 - Training loss: 1.54: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.25it/s]


Epoch 48 - Training loss: 0.41 - Validation loss: 0.39


Epoch 49 - Training loss: 0.83: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.05it/s]


Epoch 49 - Training loss: 0.41 - Validation loss: 0.41
Best Epoch 47 - Best Validation loss: 0.389247328042984


In [18]:
ft_test_loss = []
with torch.no_grad():
    for IDs in test_dataloader:
        best_finetune_model.eval()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :]
        X = X[:, :18, :]

        pred = best_finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)

        ft_test_loss.append(loss.item())


print("Test MSE loss: {}".format(np.mean(ft_test_loss)))
print("Test RMSE loss: {}".format(np.sqrt(np.mean(ft_test_loss))))

Test MSE loss: 0.39245657679400864
Test RMSE loss: 0.6264635478573424


In [19]:
ft_test_loss = []
with torch.no_grad():
    for IDs in test_dataloader:
        best_finetune_model.eval()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :] * torch.tensor(normalizer.scale_).to(device) + torch.tensor(normalizer.mean_).to(
            device
        )
        X = X[:, :18, :]

        pred = best_finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1) * torch.tensor(normalizer.scale_).to(device) + torch.tensor(
            normalizer.mean_
        ).to(device)
        loss = loss_fn(pred, targets)

        ft_test_loss.append(loss.item())


print("Test MSE loss: {}".format(np.mean(ft_test_loss)))
print("Test RMSE loss: {}".format(np.sqrt(np.mean(ft_test_loss))))
best_finetune_model_simmtm = copy.deepcopy(best_finetune_model)

Test MSE loss: 25944.05754264833
Test RMSE loss: 161.07159135815456


# Training Loop without SimMTM


In [20]:
finetune_model = copy.deepcopy(init_model)
optimizer = torch.optim.AdamW(finetune_model.parameters(), lr=1e-3)

In [21]:
i = 0
max_epoch = 50
best_loss = 1.0e10
best_finetune_model = copy.deepcopy(init_model)
best_epoch = 0
device = "cuda"
finetune_model.to(device)
while i < max_epoch:
    ft_train_loss = []
    progress_bar = tqdm(train_dataloader)

    for IDs in progress_bar:
        finetune_model.train()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :]
        X = X[:, :18, :]

        pred = finetune_model.predict(X)
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)

        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(finetune_model.parameters(), max_norm=4.0)
        optimizer.step()

        progress_bar.set_description("Epoch {} - Training loss: {:.2f}".format(i, loss))
        ft_train_loss.append(loss.item())

    with torch.no_grad():
        ft_val_loss = []
        for IDs in val_dataloader:
            finetune_model.eval()

            X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
            X = X.float().reshape(-1, max_len, X.shape[-1])
            targets = X[:, 18:24, :]
            X = X[:, :18, :]

            pred = finetune_model.predict(X.float())
            pred = pred.reshape(X.shape[0], out_size, -1)
            loss = loss_fn(pred, targets)

            ft_val_loss.append(loss.item())

        if torch.tensor(ft_val_loss).mean().item() < best_loss:
            best_loss = torch.tensor(ft_val_loss).mean().item()
            best_finetune_model = copy.deepcopy(finetune_model)
            best_epoch = i

    progress_bar.write(
        "Epoch {} - Training loss: {:.2f} - Validation loss: {:.2f}".format(
            i, torch.tensor(ft_train_loss).mean().item(), torch.tensor(ft_val_loss).mean().item()
        )
    )
    i += 1


tqdm.write("Best Epoch {} - Best Validation loss: {}".format(best_epoch, best_loss))

Epoch 0 - Training loss: 0.47: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.24it/s]


Epoch 0 - Training loss: 0.71 - Validation loss: 0.48


Epoch 1 - Training loss: 0.35: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.81it/s]


Epoch 1 - Training loss: 0.51 - Validation loss: 0.46


Epoch 2 - Training loss: 0.31: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.48it/s]


Epoch 2 - Training loss: 0.48 - Validation loss: 0.46


Epoch 3 - Training loss: 0.50: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.62it/s]


Epoch 3 - Training loss: 0.47 - Validation loss: 0.44


Epoch 4 - Training loss: 0.30: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.17it/s]


Epoch 4 - Training loss: 0.46 - Validation loss: 0.43


Epoch 5 - Training loss: 0.25: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.05it/s]


Epoch 5 - Training loss: 0.45 - Validation loss: 0.48


Epoch 6 - Training loss: 0.34: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.86it/s]


Epoch 6 - Training loss: 0.45 - Validation loss: 0.42


Epoch 7 - Training loss: 0.68: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 38.97it/s]


Epoch 7 - Training loss: 0.44 - Validation loss: 0.42


Epoch 8 - Training loss: 0.56: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 38.93it/s]


Epoch 8 - Training loss: 0.44 - Validation loss: 0.42


Epoch 9 - Training loss: 0.32: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.60it/s]


Epoch 9 - Training loss: 0.44 - Validation loss: 0.41


Epoch 10 - Training loss: 0.58: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.65it/s]


Epoch 10 - Training loss: 0.44 - Validation loss: 0.41


Epoch 11 - Training loss: 1.10: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 47.17it/s]


Epoch 11 - Training loss: 0.43 - Validation loss: 0.41


Epoch 12 - Training loss: 0.55: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 47.27it/s]


Epoch 12 - Training loss: 0.43 - Validation loss: 0.41


Epoch 13 - Training loss: 0.44: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 45.21it/s]


Epoch 13 - Training loss: 0.43 - Validation loss: 0.41


Epoch 14 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 44.80it/s]


Epoch 14 - Training loss: 0.43 - Validation loss: 0.41


Epoch 15 - Training loss: 0.20: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 44.70it/s]


Epoch 15 - Training loss: 0.43 - Validation loss: 0.40


Epoch 16 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 44.78it/s]


Epoch 16 - Training loss: 0.43 - Validation loss: 0.42


Epoch 17 - Training loss: 0.70: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.82it/s]


Epoch 17 - Training loss: 0.43 - Validation loss: 0.41


Epoch 18 - Training loss: 0.45: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.78it/s]


Epoch 18 - Training loss: 0.43 - Validation loss: 0.44


Epoch 19 - Training loss: 0.20: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.04it/s]


Epoch 19 - Training loss: 0.43 - Validation loss: 0.40


Epoch 20 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.88it/s]


Epoch 20 - Training loss: 0.43 - Validation loss: 0.41


Epoch 21 - Training loss: 0.22: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.81it/s]


Epoch 21 - Training loss: 0.43 - Validation loss: 0.40


Epoch 22 - Training loss: 0.60: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.71it/s]


Epoch 22 - Training loss: 0.43 - Validation loss: 0.40


Epoch 23 - Training loss: 0.27: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.33it/s]


Epoch 23 - Training loss: 0.43 - Validation loss: 0.40


Epoch 24 - Training loss: 0.81: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.36it/s]


Epoch 24 - Training loss: 0.42 - Validation loss: 0.40


Epoch 25 - Training loss: 0.24: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.60it/s]


Epoch 25 - Training loss: 0.43 - Validation loss: 0.40


Epoch 26 - Training loss: 0.32: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.29it/s]


Epoch 26 - Training loss: 0.43 - Validation loss: 0.40


Epoch 27 - Training loss: 0.54: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.42it/s]


Epoch 27 - Training loss: 0.42 - Validation loss: 0.40


Epoch 28 - Training loss: 0.31: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.50it/s]


Epoch 28 - Training loss: 0.43 - Validation loss: 0.40


Epoch 29 - Training loss: 0.26: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.98it/s]


Epoch 29 - Training loss: 0.42 - Validation loss: 0.40


Epoch 30 - Training loss: 0.29: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.43it/s]


Epoch 30 - Training loss: 0.42 - Validation loss: 0.40


Epoch 31 - Training loss: 1.44: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.36it/s]


Epoch 31 - Training loss: 0.42 - Validation loss: 0.40


Epoch 32 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.81it/s]


Epoch 32 - Training loss: 0.42 - Validation loss: 0.40


Epoch 33 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.05it/s]


Epoch 33 - Training loss: 0.42 - Validation loss: 0.40


Epoch 34 - Training loss: 0.33: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.94it/s]


Epoch 34 - Training loss: 0.42 - Validation loss: 0.40


Epoch 35 - Training loss: 0.25: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.90it/s]


Epoch 35 - Training loss: 0.42 - Validation loss: 0.40


Epoch 36 - Training loss: 0.21: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.23it/s]


Epoch 36 - Training loss: 0.42 - Validation loss: 0.40


Epoch 37 - Training loss: 0.25: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.90it/s]


Epoch 37 - Training loss: 0.42 - Validation loss: 0.40


Epoch 38 - Training loss: 0.27: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.96it/s]


Epoch 38 - Training loss: 0.42 - Validation loss: 0.41


Epoch 39 - Training loss: 0.37: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 43.77it/s]


Epoch 39 - Training loss: 0.42 - Validation loss: 0.40


Epoch 40 - Training loss: 0.59: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.31it/s]


Epoch 40 - Training loss: 0.42 - Validation loss: 0.40


Epoch 41 - Training loss: 0.18: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 42.06it/s]


Epoch 41 - Training loss: 0.42 - Validation loss: 0.40


Epoch 42 - Training loss: 0.26: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.50it/s]


Epoch 42 - Training loss: 0.42 - Validation loss: 0.39


Epoch 43 - Training loss: 0.76: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.21it/s]


Epoch 43 - Training loss: 0.42 - Validation loss: 0.40


Epoch 44 - Training loss: 0.18: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.29it/s]


Epoch 44 - Training loss: 0.42 - Validation loss: 0.39


Epoch 45 - Training loss: 0.25: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.54it/s]


Epoch 45 - Training loss: 0.42 - Validation loss: 0.49


Epoch 46 - Training loss: 0.27: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 41.40it/s]


Epoch 46 - Training loss: 0.42 - Validation loss: 0.40


Epoch 47 - Training loss: 0.65: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 40.78it/s]


Epoch 47 - Training loss: 0.42 - Validation loss: 0.40


Epoch 48 - Training loss: 0.30: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 38.98it/s]


Epoch 48 - Training loss: 0.42 - Validation loss: 0.40


Epoch 49 - Training loss: 0.28: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:03<00:00, 39.72it/s]


Epoch 49 - Training loss: 0.42 - Validation loss: 0.40
Best Epoch 44 - Best Validation loss: 0.39469093084335327


In [22]:
ft_test_loss = []
with torch.no_grad():
    for IDs in test_dataloader:
        best_finetune_model.eval()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :]
        X = X[:, :18, :]

        pred = best_finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1)
        loss = loss_fn(pred, targets)

        ft_test_loss.append(loss.item())


print("Test MSE loss: {}".format(np.mean(ft_test_loss)))
print("Test RMSE loss: {}".format(np.sqrt(np.mean(ft_test_loss))))

Test MSE loss: 0.4030115777933145
Test RMSE loss: 0.6348319287758882


In [23]:
ft_test_loss = []
with torch.no_grad():
    for IDs in test_dataloader:
        best_finetune_model.eval()

        X = torch.tensor(data.loc[IDs].to_numpy()).to(device)
        X = X.float().reshape(-1, max_len, X.shape[-1])
        targets = X[:, 18:24, :] * torch.tensor(normalizer.scale_).to(device) + torch.tensor(normalizer.mean_).to(
            device
        )
        X = X[:, :18, :]

        pred = best_finetune_model.predict(X.float())
        pred = pred.reshape(X.shape[0], out_size, -1) * torch.tensor(normalizer.scale_).to(device) + torch.tensor(
            normalizer.mean_
        ).to(device)
        loss = loss_fn(pred, targets)

        ft_test_loss.append(loss.item())


print("Test MSE loss: {}".format(np.mean(ft_test_loss)))
print("Test RMSE loss: {}".format(np.sqrt(np.mean(ft_test_loss))))

Test MSE loss: 28813.008208023184
Test RMSE loss: 169.74394895849213


In [24]:
# X = torch.tensor(data.loc[next(iter(test_dataloader))].to_numpy()).to(device)
# X = X.float().reshape(-1, max_len, X.shape[-1])
# target = X[:, 18:24, :]
# X = X[:, :18, :]
# pred_1 = best_finetune_model.predict(X.float())
# pred_2 = best_finetune_model_simmtm.predict(X.float())

In [25]:
# i = 4
# j = 1
# plt.plot(torch.cat((X[i, :, j], pred_1[i, :, j])).cpu().detach().numpy(), ls="--", label='no simmtm')
# plt.plot(torch.cat((X[i, :, j], pred_2[i, :, j])).cpu().detach().numpy(), ls="--", label='simmtm')
# plt.plot(torch.cat((X[i, :, j], target[i, :, j])).cpu().detach().numpy())
# plt.legend()

Reference:
1. https://arxiv.org/abs/2302.00861
2. https://github.com/gzerveas/mvts_transformer

Normalized Results

No Pretrain

Test MSE loss: 0.401656836271286

Test MSE loss: 0.3942021131515503

Test MSE loss: 0.4026849865913391

Test MSE loss: 0.3934531509876251

Test MSE loss: 0.4011250138282776

Pretrain

Test MSE loss: 0.39376363158226013

Test MSE loss: 0.3938564658164978

Test MSE loss: 0.39758536219596863

Test MSE loss: 0.3879236876964569

Test MSE loss: 0.3867727518081665

Unnormalized Results

No Pretrain

Test MSE loss: 28955.188077749768

Test MSE loss: 26232.348384247733

Test MSE loss: 29180.849135569904

Test MSE loss: 25597.15184640446

Test MSE loss: 28796.784927283206

Pretrain

Test MSE loss: 25913.2302627422

Test MSE loss: 26665.712613352596

Test MSE loss: 26484.948554457686

Test MSE loss: 26331.686181724996

Test MSE loss: 25527.132474246446