In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


# Positional Encoding with inhomogeneous spacing
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding driven by explicit time stamps.

    The sine/cosine arguments are built from the ``times`` tensor passed to
    ``forward`` rather than from integer positions, so irregularly spaced
    samples receive encodings that reflect their actual spacing.

    Args:
        d_model: Size of the last dimension of the tensors being encoded.
        max_len: Unused; kept so callers that already pass it keep working.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x, times):
        """Return ``x`` with the time-based encoding added.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            times: Tensor broadcastable to (batch, seq_len) holding the time
                stamp of every element of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        device = x.device
        batch_size, seq_len, _ = x.shape
        pe = torch.zeros(batch_size, seq_len, self.d_model, device=device)

        # One frequency per sine/cosine pair: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device).float()
            * (-np.log(10000.0) / self.d_model))

        angles = times.unsqueeze(-1) * div_term
        pe[:, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine slot fewer than sine slots, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, :, 1::2] = torch.cos(angles[:, :, :self.d_model // 2])

        return x + pe


# Transformer Model
class TransformerModel(nn.Module):
    """Causally masked Transformer encoder for sequences with time stamps.

    Inputs are projected to ``d_model`` by a linear layer, combined with a
    time-based positional encoding, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks under a causal (subsequent) mask and
    projected to ``output_dim`` by a final linear layer.

    Args:
        input_dim: Number of features per time step of the input.
        d_model: Width of the Transformer.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
        output_dim: Number of features per time step of the output.
    """

    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward,
                 output_dim):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        # Lazily built causal mask, cached between forward calls.
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead,
                                                    dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers,
                                                         num_layers)
        self.encoder = nn.Linear(input_dim, d_model)
        self.decoder = nn.Linear(d_model, output_dim)
        self.init_weights()

    def init_weights(self):
        """Re-initialise the input/output projection weights uniformly."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, times):
        """Run the model on a batch of sequences.

        Args:
            src: Tensor of shape (batch, seq_len, input_dim).
            times: Time stamps forwarded to the positional encoder.

        Returns:
            Tensor of shape (batch, seq_len, output_dim).
        """
        seq_len = src.size(1)
        # Rebuild the cached mask when the sequence length changes and also
        # when the input lives on a different device than the cached mask;
        # otherwise a stale mask would be fed to the encoder.
        if (self.src_mask is None
                or self.src_mask.size(0) != seq_len
                or self.src_mask.device != src.device):
            mask = self._generate_square_subsequent_mask(seq_len)
            self.src_mask = mask.to(src.device)

        src = self.encoder(src)
        src = self.pos_encoder(src, times)
        # The encoder layers use the default (seq_len, batch, d_model) layout,
        # hence the transposes around the call.
        output = self.transformer_encoder(src.transpose(0, 1), self.src_mask)
        output = self.decoder(output.transpose(0, 1))
        return output

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask


# Training function
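# The model sees every point of a sequence except the last one
# (input_seq = batch_data[:, :-1, :], i.e. the first 99 of 100 points) and is
# trained to predict the same sequence shifted one step ahead
# (target_seq = batch_data[:, 1:, :]), so output[:, t] is the one-step-ahead
# prediction of point t + 1.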
def train(model, train_data, train_times, val_data, val_times, num_epochs,
          batch_size, lr):
    """Train ``model`` on one-step-ahead prediction and record the losses.

    For every sequence the model receives all points but the last and is
    scored (MSE) against the same sequence shifted one step ahead. Training
    uses Adam with gradient-norm clipping at 0.5 and a ``StepLR`` schedule
    that multiplies the learning rate by 0.1 every 50 epochs. After each
    epoch the whole validation set is evaluated in a single forward pass.

    Args:
        model: Callable as ``model(inputs, times)``; must expose
            ``parameters()``, ``train()`` and ``eval()``.
        train_data: Tensor indexed as (sequence, step, feature).
        train_times: Tensor indexed as (sequence, step) with the time stamps
            belonging to ``train_data``.
        val_data: Validation counterpart of ``train_data``.
        val_times: Validation counterpart of ``train_times``.
        num_epochs: Number of passes over ``train_data``.
        batch_size: Number of sequences per optimisation step.
        lr: Initial learning rate for Adam.

    Returns:
        Tuple ``(train_losses, val_losses)`` with one float per epoch. The
        training loss is the mean of the per-batch losses of that epoch.

    Raises:
        ValueError: If ``train_data`` contains no sequences.
    """
    num_train = train_data.shape[0]
    if num_train == 0:
        raise ValueError("train_data must contain at least one sequence")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for start in range(0, num_train, batch_size):
            batch_data = train_data[start:start + batch_size]
            batch_times = train_times[start:start + batch_size]
            input_seq = batch_data[:, :-1, :]
            input_times = batch_times[:, :-1]
            target_seq = batch_data[:, 1:, :]

            optimizer.zero_grad()
            output = model(input_seq, input_times)
            loss = criterion(output, target_seq)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Divide by the number of batches actually processed. Using
        # ``num_train // batch_size`` ignores a trailing partial batch and is
        # zero whenever there are fewer sequences than ``batch_size``.
        train_loss = total_loss / num_batches
        train_losses.append(train_loss)

        model.eval()
        with torch.no_grad():
            val_output = model(val_data[:, :-1, :], val_times[:, :-1])
            val_loss = criterion(val_output, val_data[:, 1:, :]).item()
        val_losses.append(val_loss)

        scheduler.step()

        print(
            f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return train_losses, val_losses

In [None]:
# Data Generation with inhomogeneous spacing
def generate_data(batch_size, seq_length, num_features, min_gap=0.01,
                  max_gap=0.5):
    """Generate noisy sinusoids sampled at irregularly spaced time stamps.

    For each sequence the gaps between consecutive samples are drawn uniformly
    from ``[min_gap, max_gap)`` and the resulting time stamps are rescaled so
    they run from 0 to 1. Every feature is a sine (even index) or cosine (odd
    index) of those time stamps with random frequency, phase and amplitude,
    plus Gaussian noise with standard deviation 0.1. Random numbers come from
    the global ``np.random`` state.

    Args:
        batch_size: Number of sequences to generate.
        seq_length: Number of samples per sequence.
        num_features: Number of features per sample.
        min_gap: Lower bound of the gap between consecutive samples.
        max_gap: Upper bound of the gap between consecutive samples.

    Returns:
        Tuple ``(data, times)`` of float32 tensors with shapes
        ``(batch_size, seq_length, num_features)`` and
        ``(batch_size, seq_length)``.
    """
    data = np.zeros((batch_size, seq_length, num_features))
    times = np.zeros((batch_size, seq_length))

    for i in range(batch_size):
        # Draw seq_length gaps in one call. The last gap is never used; it is
        # still drawn so the random stream is consumed exactly as if one gap
        # were sampled per time step.
        gaps = np.random.uniform(min_gap, max_gap, size=seq_length)
        if seq_length > 1:
            # Sample j sits at the sum of the first j gaps; sample 0 stays at 0.
            times[i, 1:] = np.cumsum(gaps[:-1])
            span = times[i, -1]
            # Normalize to [0, 1]. A zero span (all gaps zero) is left as is
            # to avoid dividing by zero and filling the series with NaN; the
            # same applies to single-sample sequences, whose only time is 0.
            if span > 0:
                times[i] /= span

        for k in range(num_features):
            freq = np.random.rand() * 2 + 0.1
            phase = np.random.rand() * 2 * np.pi
            amplitude = np.random.rand() * 2
            wave = np.sin if k % 2 == 0 else np.cos
            data[i, :, k] = amplitude * wave(freq * times[i] + phase)

    noise = np.random.randn(*data.shape) * 0.1
    data += noise

    return torch.FloatTensor(data), torch.FloatTensor(times)

In [None]:
# function to plot example data
def plot_example_data(data, times, num_examples=3):
    """Plot the first ``num_examples`` series, one subplot per series.

    Every feature of a series is drawn as a line with markers at the sample
    positions, against that series' own time stamps.

    Args:
        data: Array-like indexed as (series, step, feature).
        times: Array-like indexed as (series, step) with the time stamps of
            ``data``.
        num_examples: Maximum number of series to draw; capped at the number
            of series available in ``data``.
    """
    num_features = data.shape[2]
    # Never ask for more rows than there are series to draw.
    num_examples = min(num_examples, data.shape[0])
    if num_examples < 1:
        return

    # squeeze=False always returns a 2-D array of axes, so indexing works for
    # num_examples == 1 as well (plt.subplots would otherwise return a bare
    # Axes object).
    fig, axs = plt.subplots(num_examples, 1, figsize=(12, 4 * num_examples),
                            sharex=True, squeeze=False)
    axs = axs[:, 0]
    fig.suptitle('Example Time Series with Inhomogeneous Spacing', fontsize=16)

    for i in range(num_examples):
        example_data = data[i]
        example_times = times[i]

        for j in range(num_features):
            axs[i].plot(example_times, example_data[:, j],
                        label=f'Feature {j + 1}', alpha=0.7)
            # Markers show where the samples actually are. They carry no label
            # so that enabling the legend does not list each feature twice.
            axs[i].scatter(example_times, example_data[:, j], alpha=0.7)

        axs[i].set_ylabel('Value')
        #axs[i].legend()
        axs[i].grid(True)

    axs[-1].set_xlabel('Time')
    plt.tight_layout()
    plt.show()


In [None]:
# Seed both random number generators so every run starts from the same state
# (generate_data draws from numpy's global generator).
np.random.seed(42)
torch.manual_seed(42)

# Size of the synthetic data set. Note that batch_size is the number of
# sequences generated here, not the mini-batch size used for training.
batch_size, seq_length, num_features = 1000, 100, 1
data, times = generate_data(
    batch_size=batch_size,
    seq_length=seq_length,
    num_features=num_features,
)

# Draw a few of the generated series as a visual sanity check.
plot_example_data(data=data, times=times)

In [None]:
def multi_step_predict(model, input_seq, input_times, prediction_window,
                       future_times=None):
    """Predict with ``model`` in eval mode and without gradient tracking.

    Two modes are supported:

    * ``future_times is None`` (default): the model is run once on the input
      and the last ``prediction_window`` steps of its output are returned. For
      a model trained on one-step-ahead targets, only the final step of that
      slice lies beyond the input; the others are predictions of points that
      are already part of ``input_seq``.
    * ``future_times`` given: an autoregressive rollout. The model's last
      output step is taken as the next point, appended to the sequence together
      with its time stamp from ``future_times`` and the model is run again,
      ``prediction_window`` times in total. This requires the model's output
      feature size to equal its input feature size.

    Args:
        model: Callable as ``model(seq, times)`` exposing ``eval()``.
        input_seq: Tensor of shape (batch, steps, features).
        input_times: Tensor of shape (batch, steps) with the time stamps of
            ``input_seq``.
        prediction_window: Number of steps to return.
        future_times: Optional tensor of shape (batch, >= prediction_window - 1)
            holding the time stamps of the points being forecast.

    Returns:
        Tensor of shape (batch, prediction_window, features).
    """
    model.eval()
    with torch.no_grad():
        if future_times is None:
            output = model(input_seq, input_times)
            return output[:, -prediction_window:, :]

        if prediction_window <= 0:
            # Nothing to forecast: return an empty window.
            return input_seq[:, :0, :]

        seq, seq_times = input_seq, input_times
        forecasts = []
        for step in range(prediction_window):
            # The last output step is the prediction of the point that follows
            # the current sequence.
            next_point = model(seq, seq_times)[:, -1:, :]
            forecasts.append(next_point)
            if step + 1 < prediction_window:
                # Feed the forecast back in, stamped with its own time.
                seq = torch.cat([seq, next_point], dim=1)
                seq_times = torch.cat(
                    [seq_times, future_times[:, step:step + 1]], dim=1)
        return torch.cat(forecasts, dim=1)


# Modified function to plot multiple predictions
def plot_multiple_predictions(model, val_data, val_times, num_examples=3,
                              input_window=80, prediction_window=20):
    """Plot forecasts of ``model`` against the held-out continuation.

    For each of the first ``num_examples`` sequences, the first
    ``input_window`` points are given to the model and the following
    ``prediction_window`` points are forecast one at a time: each forecast is
    appended to the history, stamped with the true time of the point it stands
    for, before the next step is predicted. Input, true continuation and
    forecasts are drawn in one figure per sequence, one subplot per feature.

    Args:
        model: Model passed through to ``multi_step_predict``.
        val_data: Tensor indexed as (sequence, step, feature).
        val_times: Tensor indexed as (sequence, step) with the time stamps of
            ``val_data``.
        num_examples: Maximum number of sequences to plot; capped at the
            number of sequences available.
        input_window: Number of leading points shown to the model.
        prediction_window: Number of following points to forecast; shortened
            when the sequence ends earlier.
    """
    def _as_numpy(tensor):
        # Works for tensors on any device and with or without grad history.
        return tensor.detach().cpu().numpy()

    num_examples = min(num_examples, val_data.shape[0])
    for i in range(num_examples):
        window_end = input_window + prediction_window
        input_seq = val_data[i:i + 1, :input_window, :]
        input_times = val_times[i:i + 1, :input_window]
        target_seq = val_data[i:i + 1, input_window:window_end, :]
        target_times = val_times[i:i + 1, input_window:window_end]

        # Autoregressive rollout. Taking the last `prediction_window` outputs
        # of a single forward pass would mostly return predictions of points
        # inside the input window, which do not belong at `target_times`.
        history, history_times = input_seq, input_times
        forecasts = []
        for step in range(target_times.shape[1]):
            next_point = multi_step_predict(model, history, history_times, 1)
            forecasts.append(next_point)
            history = torch.cat([history, next_point], dim=1)
            history_times = torch.cat(
                [history_times, target_times[:, step:step + 1]], dim=1)
        if not forecasts:
            # The sequence has no points after the input window.
            continue
        prediction = torch.cat(forecasts, dim=1)

        num_features = input_seq.shape[2]
        fig, axs = plt.subplots(num_features, 1, figsize=(12, 4),
                                squeeze=False)
        for j in range(num_features):
            ax = axs[j, 0]
            ax.plot(_as_numpy(input_times[0]), _as_numpy(input_seq[0, :, j]),
                    label='Input', alpha=0.7)
            ax.plot(_as_numpy(target_times[0]), _as_numpy(target_seq[0, :, j]),
                    label='True', alpha=0.7)
            ax.scatter(_as_numpy(target_times[0]),
                       _as_numpy(prediction[0, :, j]),
                       label='Predicted', alpha=0.7, c='teal')
            ax.legend()
            ax.set_title(f'Example {i + 1}, Feature {j + 1}')
        fig.tight_layout()
        plt.show()

In [None]:
# Training parameters
num_epochs = 20
train_batch_size = 32
lr = 0.001
# Only used when plotting forecasts; train() does not take them.
input_window = 80
prediction_window = 20

# NOTE(review): no visible cell defines train_data / val_data / model, so this
# cell failed with NameError on a fresh kernel. The 80/20 split and the model
# hyper-parameters below are assumptions - adjust as needed.
train_size = int(0.8 * data.shape[0])
train_data, val_data = data[:train_size], data[train_size:]
train_times, val_times = times[:train_size], times[train_size:]

model = TransformerModel(input_dim=num_features, d_model=64, nhead=4,
                         num_layers=2, dim_feedforward=128,
                         output_dim=num_features)

# Train the model. train() accepts exactly eight arguments; passing
# input_window and prediction_window as well raised a TypeError.
train_losses, val_losses = train(model, train_data, train_times, val_data,
                                 val_times, num_epochs, train_batch_size, lr)
