In [17]:
import polars as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
from typing import Tuple
from sklearn.metrics import root_mean_squared_error
import numpy as np

In [18]:
class LinearModel(nn.Module):
    def __init__(
            self,
            n_features: int,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(
            in_features=n_features,
            out_features=1,
        )  # Just 1 fully connected layer without activation, i.e. a linear regression.

    def forward(
        self,  
        X: torch.Tensor,  
    ) -> torch.Tensor:
        y = self.fc1(X)
        return y
        
class CustomDataset(Dataset):
    def __init__(
            self,
            X: torch.Tensor,
            y: torch.Tensor,
        ) -> None:
        self.X = X
        self.y = y

    def __len__(
            self
    ) -> int:
        return self.X.shape[0]
    
    def __getitem__(
            self,
            idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        X_item = self.X[idx,:]
        y_item = self.y[idx,:]
        return X_item, y_item
    
class CustomLossRelu(nn.Module):
    def __init__(
            self,
            threshold: float,
            weight_max_error: float = 1,
            weight_percentage_above_threshold:float = 1,
            weight_wrong_sign: float = 1,
    ) -> None:
        super().__init__()
        self.threshold = threshold
        # Normalize weights and assign them
        sum_weights = (
            weight_max_error
            + weight_percentage_above_threshold
            + weight_wrong_sign
        )
        self.weight_max_error = (
            weight_max_error / sum_weights
        )
        self.weight_percentage_above_threshold = (
            weight_percentage_above_threshold / sum_weights
        )
        self.weight_wrong_sign = (
            weight_wrong_sign / sum_weights
        )

    def forward(
            self, 
            inputs: torch.Tensor, 
            targets: torch.Tensor,
        ) -> torch.Tensor:

        residuals = targets - inputs
        # Maximum abs error
        max_error = residuals.abs().max()

        # Percentage of time above threshold value
        percentage_of_time_above_x = nn.functional.relu(
            residuals.abs()-self.threshold
        ).mean()

        # Percentage of time wrong sign
        loss_percentage_of_time_wrong_sign = nn.functional.relu(
            -targets*inputs
        ).mean()
        
        # Total loss
        total_loss = (
            self.weight_max_error * max_error
            + self.weight_percentage_above_threshold * percentage_of_time_above_x
            + self.weight_wrong_sign * loss_percentage_of_time_wrong_sign
        )
        return total_loss

In [19]:
X_train_raw = pl.read_parquet(
    "/home/thomas/repos/simplify_deployment/data/potential_features/s1/Xy_train_s1.parquet"
).select(pl.exclude("target"))

X_test_raw = pl.read_parquet(
    "/home/thomas/repos/simplify_deployment/data/potential_features/s1/Xy_test_s1.parquet"
).select(pl.exclude("target"))

y_train_raw = pl.read_parquet(
    "/home/thomas/repos/simplify_deployment/data/potential_features/s1/Xy_train_s1.parquet"
).select(pl.col(["datetime_utc", "target"]))

y_test_raw = pl.read_parquet(
    "/home/thomas/repos/simplify_deployment/data/potential_features/s1/Xy_test_s1.parquet"
).select(pl.col(["datetime_utc", "target"]))


X_scaler = StandardScaler()
y_scaler = StandardScaler()

X_train = X_scaler.fit_transform(X_train_raw.select(pl.exclude("datetime_utc")))
X_test = X_scaler.transform(X_test_raw.select(pl.exclude("datetime_utc")))

y_train = y_scaler.fit_transform(y_train_raw.select(pl.exclude("datetime_utc")))
y_test = y_scaler.transform(y_test_raw.select(pl.exclude("datetime_utc")))


In [20]:
epochs = 2000
lr = 1e-4
batch_size = 672
threshold = 117
converted_threshold = threshold / np.std(y_train_raw.select(pl.col("target")).to_numpy())

dataloader = DataLoader(
    CustomDataset(
        torch.Tensor(X_train).float(),
        torch.Tensor(y_train).float(),
    ),
    batch_size=batch_size,
    shuffle=True,
)
model_rmse = LinearModel(
    n_features=X_train.shape[1],
)
optimizer = torch.optim.Adam(
    model_rmse.parameters(),
    lr = lr,
)
criterion = CustomLossRelu(
    threshold=converted_threshold,
    weight_max_error=1,
    weight_percentage_above_threshold=1,
    weight_wrong_sign=1,   
)
for epoch in range(epochs):
    epoch_loss = 0
    model_rmse.train()
    for i,(X_batch, y_batch) in enumerate(dataloader):
        prediction = model_rmse(X_batch)
        optimizer.zero_grad()
        loss = criterion(prediction, y_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    average_loss = epoch_loss/len(dataloader)
    print(f"Average epoch loss: {average_loss}")
    print(f"Epoch {epoch} done.")

model_rmse.eval()

Average epoch loss: 2.4954645312749424
Epoch 0 done.
Average epoch loss: 2.417843277637775
Epoch 1 done.
Average epoch loss: 2.348764568567276
Epoch 2 done.
Average epoch loss: 2.329694232115379
Epoch 3 done.
Average epoch loss: 2.277625485108449
Epoch 4 done.
Average epoch loss: 2.264802050132018
Epoch 5 done.
Average epoch loss: 2.2116605180960436
Epoch 6 done.
Average epoch loss: 2.1000250165279093
Epoch 7 done.
Average epoch loss: 2.093675679885424
Epoch 8 done.
Average epoch loss: 2.07962406369356
Epoch 9 done.
Average epoch loss: 2.0649191553776083
Epoch 10 done.
Average epoch loss: 1.9957969349164228
Epoch 11 done.
Average epoch loss: 1.9446846407193403
Epoch 12 done.
Average epoch loss: 1.9289141457814436
Epoch 13 done.
Average epoch loss: 1.8798181747014706
Epoch 14 done.
Average epoch loss: 1.8702705708833842
Epoch 15 done.
Average epoch loss: 1.853994140258202
Epoch 16 done.
Average epoch loss: 1.8613128501635332
Epoch 17 done.
Average epoch loss: 1.8187897503376007
Epoch 18

LinearModel(
  (fc1): Linear(in_features=25, out_features=1, bias=True)
)

In [21]:
test_prediction = y_scaler.inverse_transform(
    model_rmse(torch.Tensor(X_test)).detach().numpy(),
)
test_prediction

array([[  32.500042],
       [ -47.884945],
       [ -18.55689 ],
       ...,
       [-113.133   ],
       [-147.4495  ],
       [   4.388327]], dtype=float32)

In [22]:
test_prediction_df = y_test_raw.with_columns(pl.Series(name="prediction_custom_3_part_loss", values = test_prediction.squeeze()))
test_prediction_df

datetime_utc,target,prediction_custom_3_part_loss
"datetime[μs, UTC]",f32,f32
2024-01-03 00:14:00 UTC,-50.097,32.500042
2024-01-03 00:29:00 UTC,41.205002,-47.884945
2024-01-03 00:44:00 UTC,10.788,-18.55689
2024-01-03 00:59:00 UTC,28.538,24.080067
2024-01-03 01:14:00 UTC,11.015,-22.084848
…,…,…
2024-01-30 22:44:00 UTC,-140.723007,-146.902588
2024-01-30 22:59:00 UTC,-105.958,-182.505753
2024-01-30 23:14:00 UTC,-7.133,-113.133003
2024-01-30 23:29:00 UTC,2.132,-147.449493


In [23]:
test_prediction_df.write_parquet(f"/home/thomas/repos/simplify_deployment/data/potential_features/mi_variables/results/3_part_loss_batch_{batch_size}_epochs_{epochs}_s1_variables.parquet")

In [24]:
root_mean_squared_error(
    y_true = test_prediction_df.select(pl.col("target")),
    y_pred = test_prediction_df.select(pl.col("prediction_custom_3_part_loss"))
)

125.954605