In [1]:
import torch
import torch.nn as nn
from transformers import PatchTSTConfig, PatchTSTModel


  from .autonotebook import tqdm as notebook_tqdm


# Patch TST Huggingface

In [82]:
context_length=200
num_input_channels=3
patch_length=8
config = PatchTSTConfig(
            num_input_channels=num_input_channels,
            context_length=context_length,
            patch_length=patch_length,
            use_cls_token=True,  # Ensure CLS token usage is set correctly
            # Additional configurations can be added here if needed
        )
base_model = PatchTSTModel(config)


In [83]:
base_model = base_model.cuda()

# Example Run

In [84]:
past_values = torch.randn(32, context_length, num_input_channels).cuda()  # (batch_size, context_length, num_input_channels)
output = base_model(past_values)

In [85]:
output.keys()

odict_keys(['last_hidden_state', 'loc', 'scale', 'patch_input'])

In [78]:
output['last_hidden_state'].shape

torch.Size([32, 3, 10, 128])

In [None]:
model = CustomPatchTSTClassifier(
    num_input_channels=10,
    context_length=128,
    patch_length=16,
    num_classes_per_category=3,
    num_categories=4,
    use_cls_token=True
)

# Example input tensor
past_values = torch.randn(32, 128, 10)  # (batch_size, context_length, num_input_channels)

# Forward pass
logits = model(past_values) 

False

# Custom PatchTSTClassifier

In [129]:

class CustomPatchTSTClassifier(nn.Module):
    def __init__(self, num_input_channels, context_length, patch_length,
                 num_classes_per_category=3, num_categories=4, use_cls_token=True):
        super().__init__()

        # Configure PatchTSTModel
        self.config = PatchTSTConfig(
            num_input_channels=num_input_channels,
            context_length=context_length,
            patch_length=patch_length,
            use_patch_pe=True,        # Ensure positional encoding is used
            use_vars_per_channel=True # Depending on your data
        )
        self.base_model = PatchTSTModel(self.config)

        # Retrieve hidden size from the base model configuration
        hidden_size = self.config.hidden_size

        # Single classification head for all categories
        self.fc = nn.Linear(hidden_size, num_categories * num_classes_per_category)

        # Option to use CLS token or mean pooling
        self.use_cls_token = use_cls_token

        # Store other parameters
        self.num_classes_per_category = num_classes_per_category
        self.num_categories = num_categories

    def forward(self, past_values):
        # Get hidden states from PatchTSTModel
        outputs = self.base_model(past_values=past_values)
        hidden_states = outputs.last_hidden_state  # Shape: (batch_size, seq_len, hidden_size)
        print(f'hidden_state from patchtst base model is {hidden_states.shape}')

        # Use CLS token state or mean pooling based on the hyperparameter
        if self.use_cls_token:
            output_state = hidden_states[:,:,0, :]  # Assuming CLS token is at position 0
            output_state = output_state.mean(dim=1)
        else:
            output_state = hidden_states.mean(dim=2).mean(dim=1) 

        print(f'pre-head shape {output_state.shape}')
        # Generate logits for all categories
        logits = self.fc(output_state)  # Shape: (batch_size, num_categories * num_classes_per_category)
        print(f'post_head shape {output_state.shape}')
        # Reshape logits to (batch_size, num_categories, num_classes_per_category)
        logits = logits.view(-1, self.num_categories, self.num_classes_per_category)

        return logits  # Returns tensor of logits


In [130]:
context_length=200
num_input_channels=3
num_classes_per_category=3
num_categories=4
patch_length=8
custom_model= CustomPatchTSTClassifier(num_input_channels=num_input_channels,
                                 context_length=context_length,
                                 patch_length=patch_length,
                                 #hidden_size=hidden_size,
                                 num_classes_per_category=num_classes_per_category,
                                 num_categories=num_categories,
                                    use_cls_token=True)
custom_model = custom_model.cuda()
past_values = torch.randn(32, context_length, num_input_channels).cuda()  # (batch_size, context_length, num_input_channels)
output = custom_model(past_values)
probabilities = torch.softmax(output, dim=-1)

hidden_state from patchtst base model is torch.Size([32, 3, 193, 128])
pre-head shape torch.Size([32, 128])
post_head shape torch.Size([32, 128])


In [134]:
probabilities = torch.softmax(output, dim=-1)


In [136]:
probabilities.shape

torch.Size([32, 4, 3])