In [1]:
from model import PatchTST

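# Training

Self-supervised pretraining: the model hides a fraction (`MASKING_RATIO`) of the input patches and is trained to reconstruct them (MSE on the selected patches). Random dummy data stands in for a real dataset of shape `(n_samples, n_channels, seq_len)`.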

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hyperparameters
N_BATCH = 32
N_CHANNELS = 7
SEQ_LEN = 512
PATCH_LEN = 12
D_MODEL = 128
N_HEADS = 16
N_LAYERS = 3
D_FF = 256
DROPOUT = 0.2
MASKING_RATIO = 0.4
LEARNING_RATE = 1e-4
EPOCHS = 1  # 100 ~ 300 for a real pretraining run
SEED = 42
CHECKPOINT_PATH = 'patchtst_pretrained_representation.pth'

# Seed torch's global RNG before anything stochastic happens. Weight init,
# the dummy data, DataLoader shuffling and dropout all draw from it (and,
# presumably, the model's random patch masking; verify in model.py).
torch.manual_seed(SEED)

# Model Initialization
model = PatchTST(
    n_channels=N_CHANNELS,
    seq_len=SEQ_LEN,
    patch_len=PATCH_LEN,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
    masking_ratio=MASKING_RATIO
)

# Dataset and DataLoader
# Dummy input of shape (n_samples, n_channels, seq_len); replace with real data.
dummy_data = torch.randn(1000, N_CHANNELS, SEQ_LEN)
dataset = TensorDataset(dummy_data)
dataloader = DataLoader(dataset, batch_size=N_BATCH, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Only the first n_patches * patch_len time steps are reconstructed; any
# trailing steps that do not fill a whole patch are dropped. This is
# loop-invariant, so compute it once.
truncated_len = model.n_patches * model.patch_len

model.train()
for epoch in range(EPOCHS):
    weighted_loss_sum = 0.0  # sum over batches of (batch MSE * patches in the loss)
    n_loss_patches = 0       # total number of patches that entered the loss this epoch
    for (batch_x,) in dataloader:
        # batch_x: (B, M, L)
        B, M, _ = batch_x.shape

        # pred_patches: (B * M, N, P)
        # bool_mask: (B * M, N) -- selects the patches scored by the loss
        # (presumably the masked ones; confirm in PatchTST.forward)
        pred_patches, bool_mask = model(batch_x)

        # Build the reconstruction target. It is a fixed target, so no gradient
        # may flow through it (matters if instance_norm has learnable params).
        with torch.no_grad():
            gt_x = batch_x[:, :, :truncated_len].reshape(B * M, 1, truncated_len)  # (B * M, 1, L_trunc)
            gt_x_norm = model.instance_norm(gt_x).squeeze(1)                         # (B * M, L_trunc)
            gt_patches = gt_x_norm.unfold(
                dimension=-1, size=model.patch_len, step=model.patch_len
            )                                                                       # (B * M, N, P)

        n_selected = int(bool_mask.sum())
        if n_selected == 0:
            # MSELoss over an empty selection is NaN; there is nothing to learn
            # from this batch, so skip it rather than poison the epoch average.
            continue

        # loss backward
        loss = criterion(pred_patches[bool_mask], gt_patches[bool_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        weighted_loss_sum += loss.item() * n_selected
        n_loss_patches += n_selected

    # True mean MSE over every scored patch in the epoch. A plain mean of batch
    # means would over-weight the smaller final batch.
    avg_loss = weighted_loss_sum / n_loss_patches if n_loss_patches else float('nan')
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.6f}", end='\r')

torch.save(model.state_dict(), CHECKPOINT_PATH)



Epoch [1/1], Loss: 1.106863

# Inference: extracting encoder representations

In [5]:
model.eval()
with torch.no_grad():
    # Take one batch and move it to whatever device holds the model's weights,
    # so the cell also works when the model lives on a GPU.
    device = next(model.parameters()).device
    sample_data = next(iter(dataloader))[0].to(device)  # sample_data: (B, M, L)
    B, M, _ = sample_data.shape

    # For a representation only the encoder output is wanted, so the encoder
    # half of the forward pass is replayed by hand (no masking is applied here):
    # truncate -> normalise each channel independently -> patchify -> project
    # -> add positional embedding -> Transformer encoder.
    x = sample_data[:, :, :model.truncated_len].reshape(B * M, 1, model.truncated_len)
    x_norm = model.instance_norm(x).squeeze(1)  # (B * M, L_trunc)
    patches = x_norm.unfold(dimension=-1, size=model.patch_len, step=model.stride)  # (B * M, N, P)
    # model.dropout should be inactive after model.eval() (if it is nn.Dropout).
    x_encoded = model.dropout(model.projection(patches) + model.pos_embedding)

    # Representation 'z' is the output of the Transformer encoder: (B * M, N, D)
    representation = model.transformer_encoder(x_encoded)

    # Global average pooling over the patch axis gives one vector per series,
    # (B * M, D), usable for downstream tasks. Then restore the channel axis,
    # (B, M, D). The embedding size is taken from the tensor itself rather than
    # the D_MODEL global, so this also works for a model loaded with another d_model.
    pooled_representation = representation.mean(dim=1).view(B, M, -1)

pooled_representation.shape