#### Import libraries

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import warnings
import sys
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader

sys.path.append("../")
from pipeline import data
from pipeline.config import CONF
from pipeline.data import plots
from pipeline.data import io
from pipeline.data import inspection
from pipeline.data import preprocess

# To suppress all warnings
# NOTE(review): this blanket filter also hides pandas FutureWarnings /
# deprecation warnings and SettingWithCopyWarnings that later cells can
# trigger — consider narrowing it (e.g. by category) while developing.
warnings.filterwarnings("ignore")

# black is a code formatter (see https://github.com/psf/black).
# It will automatically format the code you write in the cells imposing consistent Python style.
%load_ext jupyter_black
# matplotlib style file (path is relative to the notebook's working directory)
# Template for style file: https://matplotlib.org/stable/tutorials/introductory/customizing.html#customizing-with-style-sheets
plt.style.use("../matplotlib_style.txt")
pd.set_option("display.max_columns", None)  # Show all columns
pd.set_option("display.expand_frame_repr", False)  # Prevent wrapping

### Load raw data

This takes about one minute on the first run.

In [None]:
if CONF.data.process_raw_data:
    # Loading is slow, so it is done at most once: the flag stored on CONF
    # records that the *_Raw frames below have already been created.
    if not CONF.data.loaded_raw_data:
        (
            Installed_Capacity_Germany_Raw,
            Prices_Europe_Raw,
            Realised_Supply_Germany_Raw,
            Realised_Demand_Germany_Raw,
            Weather_Data_Germany_Raw,
            Weather_Data_Germany_2022_Raw,
        ) = data.load_data(CONF=CONF)
        CONF.data.loaded_raw_data = True

In [None]:
if CONF.data.process_raw_data:
    # Work on deep copies so re-running the processing cells never mutates
    # the cached *_Raw frames.
    Installed_Capacity_Germany = Installed_Capacity_Germany_Raw.copy(deep=True)
    Prices_Europe = Prices_Europe_Raw.copy(deep=True)
    Realised_Supply_Germany = Realised_Supply_Germany_Raw.copy(deep=True)
    Realised_Demand_Germany = Realised_Demand_Germany_Raw.copy(deep=True)
    Weather_Data_Germany = Weather_Data_Germany_Raw.copy(deep=True)
    Weather_Data_Germany_2022 = Weather_Data_Germany_2022_Raw.copy(deep=True)

### Inspect raw data

##### Inspect missingness

In [None]:
# Write inspection reports for every raw frame, when both processing and
# inspection are enabled.
if CONF.data.process_raw_data and CONF.data.inspect:
    data.save_data_inspection(
        Installed_Capacity_Germany=Installed_Capacity_Germany,
        Prices_Europe=Prices_Europe,
        Realised_Supply_Germany=Realised_Supply_Germany,
        Realised_Demand_Germany=Realised_Demand_Germany,
        Weather_Data_Germany=Weather_Data_Germany,
        Weather_Data_Germany_2022=Weather_Data_Germany_2022,
        CONF=CONF,
        data_type="raw",
    )

##### Inspect resolution

In [None]:
if CONF.data.process_raw_data:
    # Inspect the date range and time resolution of the raw frames.
    # NOTE(review): the separate 2022 weather frame is not part of this report.
    inspection.date_range_and_resolution_dfs(
        Installed_Capacity_Germany=Installed_Capacity_Germany,
        Prices_Europe=Prices_Europe,
        Realised_Supply_Germany=Realised_Supply_Germany,
        Realised_Demand_Germany=Realised_Demand_Germany,
        Weather_Data_Germany=Weather_Data_Germany,
    )

#### Plot raw data

In [None]:
# The raw frames plotted here are only created when raw processing is enabled
# (every other raw-data cell is guarded the same way), so require both flags
# to avoid a NameError when plotting is on but raw processing is off.
if CONF.data.plot and CONF.data.process_raw_data:
    plots.plot_df(
        Installed_Capacity_Germany,
        "Installed_Capacity_Germany",
        CONF,
        processed_data=False,
    )
    plots.plot_df(Prices_Europe, "Prices_Europe", CONF, processed_data=False)
    plots.plot_df(
        Realised_Supply_Germany, "Realised_Supply_Germany", CONF, processed_data=False
    )
    plots.plot_df(
        Realised_Demand_Germany, "Realised_Demand_Germany", CONF, processed_data=False
    )
    # Weather uses its own date columns: plot against the last one and drop
    # all of them from the plotted values.
    plots.plot_df(
        Weather_Data_Germany,
        "Weather_Data_Germany",
        CONF,
        date_col=io.DATE_COLUMNS_WEATHER[-1],
        drop_date_cols=io.DATE_COLUMNS_WEATHER,
        processed_data=False,
    )

### Raw data pipeline

##### Merging weather data frames

In [None]:
if CONF.data.process_raw_data:
    # Replace the 2022 rows of the main weather frame with the dedicated 2022
    # frame: keep every non-2022 row, then append the 2022 data.
    Weather_Data_Germany = pd.concat(
        [
            Weather_Data_Germany.loc[Weather_Data_Germany["time"].dt.year != 2022],
            Weather_Data_Germany_2022,
        ],
        ignore_index=True,
    )

##### Fill NaN

In [None]:
if CONF.data.process_raw_data:
    # Run the shared NaN handling over each frame, in this fixed order.
    (
        Processed_Installed_Capacity_Germany,
        Processed_Prices_Europe,
        Processed_Realised_Supply_Germany,
        Processed_Realised_Demand_Germany,
        Processed_Weather_Data_Germany,
    ) = (
        data.process_na_values(frame, CONF)
        for frame in (
            Installed_Capacity_Germany,
            Prices_Europe,
            Realised_Supply_Germany,
            Realised_Demand_Germany,
            Weather_Data_Germany,
        )
    )

#### Aggregate weather data

In [None]:
if CONF.data.process_raw_data:
    # Aggregate the weather rows per (forecast_origin, time) pair
    # (presumably across locations — see preprocess.aggregate_weather_data).
    Processed_Weather_Data_Germany = preprocess.aggregate_weather_data(
        Processed_Weather_Data_Germany,
        ["forecast_origin", "time"],
    )

#### Decrease the time resolution of demand and supply

In [None]:
def _keep_full_hours(frame):
    """Return the rows of ``frame`` whose "Date to" is on a full hour, as hourly intervals.

    The result is an independent copy (not a view of ``frame``), so assigning
    "Date from" on it is safe. In the result, "Date from" is reset to one
    hour before "Date to".

    NOTE(review): this subsamples rather than aggregates. Only the interval
    that ends on the full hour is kept, and any sub-hourly rows of that hour
    are discarded instead of being summed or averaged. Confirm that this is
    intended for these quantities.
    """
    hourly = frame.loc[frame["Date to"].dt.minute == 0].copy()
    hourly["Date from"] = hourly["Date to"] - pd.Timedelta(hours=1)
    return hourly


if CONF.data.process_raw_data:
    Processed_Realised_Demand_Germany = _keep_full_hours(
        Processed_Realised_Demand_Germany
    )
    Processed_Realised_Supply_Germany = _keep_full_hours(
        Processed_Realised_Supply_Germany
    )

#### Increase time resolution of capacity

In [None]:
# Upsample the installed-capacity table to hourly resolution.
# A sentinel row for the last hour of 2022 is appended first, so the hourly
# index produced by ``resample`` runs through 2022-12-31 23:00. Its value
# columns are NaN and are filled by the forward fill below.
# NOTE(review): the sentinel timestamp is hard-coded to the end of 2022 —
# adjust it if the data range changes.
new_row = {
    "Date from": pd.Timestamp("2022-12-31 23:00:00"),
    "Date to": pd.Timestamp("2023-01-01 00:00:00"),
}
new_row_df = pd.DataFrame([new_row])
Processed_Installed_Capacity_Germany_hourly = (
    pd.concat([Processed_Installed_Capacity_Germany, new_row_df], ignore_index=True)
    .set_index("Date from")
    # "h" is the hourly alias; upper-case "H" is deprecated since pandas 2.2.
    .resample("h")
    .mean()
    .reset_index()
)
# Rebuild each interval's end from its new hourly start.
Processed_Installed_Capacity_Germany_hourly["Date to"] = (
    Processed_Installed_Capacity_Germany_hourly["Date from"] + pd.Timedelta(hours=1)
)
# Hours between the original (coarser) timestamps are NaN after resampling,
# so carry the last known capacity forward. ``.ffill()`` replaces the
# ``fillna(method="ffill")`` form, which is deprecated since pandas 2.1.
Processed_Installed_Capacity_Germany = Processed_Installed_Capacity_Germany_hourly.ffill()
inspection.date_range_and_resolution(
    Processed_Installed_Capacity_Germany, io.DATE_COLUMNS
)

#### Trim rows of every df to have same range

In [None]:
# Align all frames on a common time range. The weather frame drops every row
# at its earliest "time"; every other frame drops every row at its latest
# "Date to".
Processed_Weather_Data_Germany = Processed_Weather_Data_Germany.loc[
    Processed_Weather_Data_Germany["time"].ne(
        Processed_Weather_Data_Germany["time"].min()
    )
]

(
    Processed_Installed_Capacity_Germany,
    Processed_Prices_Europe,
    Processed_Realised_Supply_Germany,
    Processed_Realised_Demand_Germany,
) = (
    frame.loc[frame["Date to"].ne(frame["Date to"].max())]
    for frame in (
        Processed_Installed_Capacity_Germany,
        Processed_Prices_Europe,
        Processed_Realised_Supply_Germany,
        Processed_Realised_Demand_Germany,
    )
)

#### Patch time saving

In [None]:
# Apply preprocess.patch_time_saving (presumably a daylight-saving-time fix)
# to the price, demand and supply frames; capacity and weather are untouched.
(
    Processed_Prices_Europe,
    Processed_Realised_Demand_Germany,
    Processed_Realised_Supply_Germany,
) = (
    preprocess.patch_time_saving(frame)
    for frame in (
        Processed_Prices_Europe,
        Processed_Realised_Demand_Germany,
        Processed_Realised_Supply_Germany,
    )
)

#### Normalize data

In [None]:
if CONF.data.process_raw_data and CONF.data.normalize_data:
    # Step 1: tag rows with train/val/test marker columns (presumably so the
    # scalers below are fitted on the training period only).
    print("Split data in train, val and test")
    (
        Processed_Installed_Capacity_Germany,
        Processed_Prices_Europe,
        Processed_Realised_Supply_Germany,
        Processed_Realised_Demand_Germany,
    ) = (
        preprocess.split_data(df=frame, column_name=io.DATE_COLUMNS[-1])
        for frame in (
            Processed_Installed_Capacity_Germany,
            Processed_Prices_Europe,
            Processed_Realised_Supply_Germany,
            Processed_Realised_Demand_Germany,
        )
    )
    # Weather is split on its first date column, not its last one.
    Processed_Weather_Data_Germany = preprocess.split_data(
        df=Processed_Weather_Data_Germany, column_name=io.DATE_COLUMNS_WEATHER[0]
    )

    # Step 2: normalise every non-date feature; each call also returns the
    # scalers it used.
    print("Normalizing data")
    # NOTE(review): capacity is scaled with the *price* normalisation
    # constant, exactly like the prices frame — confirm this is intended.
    (
        Processed_Installed_Capacity_Germany,
        Processed_Installed_Capacity_Germany_Scalers,
    ) = preprocess.normalize_data(
        df=Processed_Installed_Capacity_Germany,
        ignore_features=io.DATE_COLUMNS,
        constant=CONF.data.price_normalization_constant,
    )
    (
        Processed_Prices_Europe,
        Processed_Prices_Europe_Scalers,
    ) = preprocess.normalize_data(
        df=Processed_Prices_Europe,
        ignore_features=io.DATE_COLUMNS,
        constant=CONF.data.price_normalization_constant,
    )
    (
        Processed_Realised_Supply_Germany,
        Processed_Realised_Supply_Germany_Scalers,
    ) = preprocess.normalize_data(
        df=Processed_Realised_Supply_Germany,
        ignore_features=io.DATE_COLUMNS,
    )
    (
        Processed_Realised_Demand_Germany,
        Processed_Realised_Demand_Germany_Scalers,
    ) = preprocess.normalize_data(
        df=Processed_Realised_Demand_Germany,
        ignore_features=io.DATE_COLUMNS,
    )
    # Coordinates are left unscaled along with the weather date columns.
    (
        Processed_Weather_Data_Germany,
        Processed_Weather_Data_Germany_Scalers,
    ) = preprocess.normalize_data(
        df=Processed_Weather_Data_Germany,
        ignore_features=io.DATE_COLUMNS_WEATHER + ["longitude", "latitude"],
    )

    # Step 3: the split markers are no longer needed, so drop them again.
    print("Remove train, test, val columns, again")
    Processed_Installed_Capacity_Germany = Processed_Installed_Capacity_Germany.drop(
        columns=["train", "val", "test"]
    )
    Processed_Prices_Europe = Processed_Prices_Europe.drop(
        columns=["train", "val", "test"]
    )
    Processed_Realised_Supply_Germany = Processed_Realised_Supply_Germany.drop(
        columns=["train", "val", "test"]
    )
    Processed_Realised_Demand_Germany = Processed_Realised_Demand_Germany.drop(
        columns=["train", "val", "test"]
    )
    Processed_Weather_Data_Germany = Processed_Weather_Data_Germany.drop(
        columns=["train", "val", "test"]
    )

## Inspect processed data

### Inspect missingness

In [None]:
# Write inspection reports for the processed frames. The 2022 weather data
# was concatenated into the main weather frame earlier, so it is not passed
# separately here.
if CONF.data.inspect:
    data.save_data_inspection(
        Installed_Capacity_Germany=Processed_Installed_Capacity_Germany,
        Prices_Europe=Processed_Prices_Europe,
        Realised_Supply_Germany=Processed_Realised_Supply_Germany,
        Realised_Demand_Germany=Processed_Realised_Demand_Germany,
        Weather_Data_Germany=Processed_Weather_Data_Germany,
        CONF=CONF,
        data_type="preprocessed",
    )

### Plot processed data

In [None]:
if CONF.data.plot:
    # Energy and price frames share the default date handling of plot_df.
    plots.plot_df(
        Processed_Installed_Capacity_Germany,
        "Installed_Capacity_Germany",
        CONF,
    )
    plots.plot_df(
        Processed_Prices_Europe,
        "Prices_Europe",
        CONF,
    )
    plots.plot_df(
        Processed_Realised_Supply_Germany,
        "Realised_Supply_Germany",
        CONF,
    )
    plots.plot_df(
        Processed_Realised_Demand_Germany,
        "Realised_Demand_Germany",
        CONF,
    )
    # Weather: plot against its last date column and hide all date columns.
    plots.plot_df(
        Processed_Weather_Data_Germany,
        "Weather_Data_Germany",
        CONF,
        date_col=io.DATE_COLUMNS_WEATHER[-1],
        drop_date_cols=io.DATE_COLUMNS_WEATHER,
    )

### Inspect time resolution

In [None]:
# Inspect the date range and time resolution of every processed frame.
inspection.date_range_and_resolution_dfs(
    Installed_Capacity_Germany=Processed_Installed_Capacity_Germany,
    Prices_Europe=Processed_Prices_Europe,
    Realised_Supply_Germany=Processed_Realised_Supply_Germany,
    Realised_Demand_Germany=Processed_Realised_Demand_Germany,
    Weather_Data_Germany=Processed_Weather_Data_Germany,
    processed=True,
)

### Data Unit Tests

In [None]:
# Processed frames under test, and their raw counterparts in the same order.
dfs = [
    Processed_Installed_Capacity_Germany,
    Processed_Prices_Europe,
    Processed_Realised_Supply_Germany,
    Processed_Realised_Demand_Germany,
]

raw_dfs = [
    Installed_Capacity_Germany_Raw,
    Prices_Europe_Raw,
    Realised_Supply_Germany_Raw,
    Realised_Demand_Germany_Raw,
]


# Every frame must have exactly as many rows as the first one.
def test_dataframe_lengths(dfs):
    reference = len(dfs[0])
    assert all(
        len(frame) == reference for frame in dfs[1:]
    ), "DataFrames have different lengths"


test_dataframe_lengths(dfs)


# Assert that all frames, including the processed weather frame, share the
# same most common time step.
def test_dataframe_resolutions(dfs):
    def _typical_step(timestamps):
        # Most frequent difference between consecutive timestamps.
        return timestamps.diff().dropna().mode()[0]

    expected_resolution = _typical_step(dfs[0]["Date from"])
    for frame in dfs[1:]:
        assert (
            _typical_step(frame["Date from"]) == expected_resolution
        ), "DataFrames have different time resolutions"
    assert (
        _typical_step(Processed_Weather_Data_Germany["time"]) == expected_resolution
    ), "DataFrames have different time resolutions"


test_dataframe_resolutions(dfs)


# Assert that every frame has row-for-row identical "Date to" values, and that
# the processed weather frame's "time" column matches them as well.
def test_date_to_consistency(dfs):
    date_to_values = dfs[0]["Date to"].values
    for i, df in enumerate(dfs[1:], start=1):
        # np.array_equal also handles arrays of different lengths (returns
        # False), whereas `all(a == b)` breaks when the shapes differ.
        assert np.array_equal(
            df["Date to"].values, date_to_values
        ), f"Mismatch in 'Date to' values across DataFrames {i}"
    assert np.array_equal(
        Processed_Weather_Data_Germany["time"].values, date_to_values
    ), "Weather 'time' values do not match the 'Date to' values of the other DataFrames"


test_date_to_consistency(dfs)


# Assert that no cell is missing (NaN/None) in any processed frame,
# including the processed weather frame.
def test_no_missing_cells(dfs):
    all_frames = dfs + [Processed_Weather_Data_Germany]
    for frame_idx, frame in enumerate(all_frames):
        for col in frame.columns:
            n_missing = frame[col].isna().sum()
            assert (
                n_missing == 0
            ), f"Missing values found in column '{col}' of DataFrame at index {frame_idx}"


test_no_missing_cells(dfs)


# Assert that each raw frame and its processed counterpart have the same
# columns, apart from the names listed in ``ignore_columns``.
def test_same_columns(raw_dfs, processed_dfs, ignore_columns=None):
    # Accept any iterable of column names (set, list, tuple).
    ignore_columns = set() if ignore_columns is None else set(ignore_columns)

    # zip() would silently skip unmatched frames, so check the pairing first.
    assert len(raw_dfs) == len(processed_dfs), (
        f"Expected as many processed DataFrames as raw ones, "
        f"got {len(raw_dfs)} raw and {len(processed_dfs)} processed"
    )

    for raw_df, processed_df in zip(raw_dfs, processed_dfs):
        raw_columns = set(raw_df.columns) - ignore_columns
        processed_columns = set(processed_df.columns) - ignore_columns

        assert (
            raw_columns == processed_columns
        ), f"Column mismatch between raw and processed DataFrames. Missing {raw_columns - processed_columns} or {processed_columns - raw_columns}"


# Compare raw and processed columns while ignoring these price columns, which
# differ between the two (presumably dropped during preprocessing — confirm).
test_same_columns(
    raw_dfs,
    dfs,
    ignore_columns={
        "∅ Neighbouring DE/LU [€/MWh]",
        "Hungary [€/MWh]",
        "Poland [€/MWh]",
        "DE/AT/LU [€/MWh]",
    },
)

## Data loader for supply forecasting 

##### Merge data together to a single frame

In [None]:
def add_prefix(df, prefix, exclude=None):
    """
    Return a copy of ``df`` with ``prefix`` prepended to every column name
    except the excluded (merge-key) columns.

    Args:
        df: Frame whose columns are renamed. It is not modified.
        prefix: String prepended to each non-excluded column name.
        exclude: Column names to leave unchanged. Defaults to
            ``io.DATE_COLUMNS``, the keys used to merge the frames.

    Returns:
        A new DataFrame with the renamed columns.
    """
    keep = io.DATE_COLUMNS if exclude is None else exclude
    return df.rename(
        columns={col: col if col in keep else prefix + col for col in df.columns}
    )


# Tag every frame's value columns with a source prefix so the merged frame
# has unique, traceable column names (date key columns keep their names).
Suffix_Processed_Prices_Europe = add_prefix(Processed_Prices_Europe, "prices_")
Suffix_Processed_Installed_Capacity_Germany = add_prefix(
    Processed_Installed_Capacity_Germany, "capacity_"
)
Suffix_Processed_Realised_Supply_Germany = add_prefix(
    Processed_Realised_Supply_Germany, "supply_"
)
Suffix_Processed_Realised_Demand_Germany = add_prefix(
    Processed_Realised_Demand_Germany, "demand_"
)
# Weather's "time" column is prefixed to "weather_time"; rename it to the
# shared "Date to" key so it can be joined with the other frames.
Suffix_Processed_Weather_Data_Germany = add_prefix(
    Processed_Weather_Data_Germany, "weather_"
).rename(columns={"weather_time": "Date to"})

# Inner-join everything on the shared date columns. Weather has only the end
# timestamp, so it joins on the last date column alone (assumed to be
# "Date to").
df = (
    Suffix_Processed_Prices_Europe.merge(
        Suffix_Processed_Installed_Capacity_Germany, on=io.DATE_COLUMNS, how="inner"
    )
    .merge(Suffix_Processed_Realised_Supply_Germany, on=io.DATE_COLUMNS, how="inner")
    .merge(Suffix_Processed_Realised_Demand_Germany, on=io.DATE_COLUMNS, how="inner")
    .merge(
        Suffix_Processed_Weather_Data_Germany, on=io.DATE_COLUMNS[-1], how="inner"
    )
)
df.head()

In [None]:
# Plot the merged frame; it combines five sources, hence the very large figure.
if CONF.data.plot:
    plots.plot_df(
        df,
        "Final Dataframe",
        CONF,
        figsize=(150, 30),
    )

##### Split train, val, test

In [None]:
# Add the train/val/test marker columns to the merged frame, split on its last date column.
df = preprocess.split_data(column_name=io.DATE_COLUMNS[-1], df=df)

##### Torch datasets

In [None]:
import torch
import pandas as pd
from torch.utils.data import Dataset
from pipeline.config import CONF


class TimeSeriesDataset(Dataset):
    """Sliding-window dataset over a time-ordered DataFrame.

    Sample ``i`` pairs the ``lag`` consecutive rows ``i .. i + lag - 1`` of
    ``CONF.model.features``, oldest first (so the last row is the most recent
    observation), with the ``CONF.model.targets`` values ``h`` rows after that
    last input row, for every ``h`` in ``CONF.model.horizons``.

    Rows are addressed by position, so ``df`` must already be sorted in time
    order. With list-valued features/targets, each sample is a pair of
    float32 tensors shaped (lag, n_features) and (n_horizons, n_targets);
    a DataLoader adds the leading batch dimension.
    """

    def __init__(self, df: pd.DataFrame, CONF):
        self.dataframe = df
        self.features = CONF.model.features
        self.targets = CONF.model.targets
        self.lag = CONF.model.lag
        self.horizons = CONF.model.horizons

        # The largest horizon determines how far past a window we must read.
        self.max_horizon = max(self.horizons)
        # The last input row p of a window needs p >= lag - 1 and
        # p + max_horizon <= len(df) - 1, which gives
        # len(df) - lag - max_horizon + 1 windows. Clamp at 0 so short frames
        # yield an empty dataset instead of a negative length.
        self.total_samples = max(0, len(df) - self.lag - self.max_horizon + 1)

        # Convert once up front: per-item DataFrame .loc lookups are slow
        # inside a DataLoader loop.
        self._inputs = torch.as_tensor(
            df[self.features].to_numpy(dtype=float), dtype=torch.float32
        )
        self._outputs = torch.as_tensor(
            df[self.targets].to_numpy(dtype=float), dtype=torch.float32
        )
        self._horizon_offsets = torch.as_tensor(list(self.horizons), dtype=torch.long)

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        if idx < 0:
            idx += self.total_samples
        if not 0 <= idx < self.total_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.total_samples}"
            )

        last = idx + self.lag - 1  # position of the most recent input row
        # Chronological window: oldest row first, most recent row last.
        inputs = self._inputs[idx : last + 1]
        # Targets h rows after the most recent input row, one per horizon.
        targets = self._outputs[last + self._horizon_offsets]
        return inputs, targets

#### Torch dataloaders

In [None]:
from pipeline.config import CONF

# Select each split's rows via its boolean marker column, drop the columns the
# model ignores, and renumber the rows.
# NOTE(review): reset_index() without drop=True keeps the old index as an
# "index" column in each split frame.
train_df, val_df, test_df = (
    df[df[split]].drop(CONF.model.ignore_columns, axis=1).reset_index()
    for split in ("train", "val", "test")
)

train_dataset, val_dataset, test_dataset = (
    TimeSeriesDataset(split_df, CONF=CONF) for split_df in (train_df, val_df, test_df)
)

# Only the training loader shuffles; evaluation loaders keep time order.
train_loader = DataLoader(train_dataset, batch_size=CONF.train.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONF.train.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=CONF.train.batch_size, shuffle=False)

## Training

#### Model architecture

In [None]:
import torch
import torch.nn as nn
import math
from pipeline.config import CONF


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even channels receive ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)``,
    where ``w_i = 10000 ** (-2i / d_model)``. The encoding table is indexed
    along dim 0 and broadcast over dim 1, so ``x`` is expected as
    (seq_len, batch, d_model) with ``seq_len <= max_len``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd channel than even channel,
        # so trim div_term to fit. For an even d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer moves with .to(device) and is saved in the state_dict,
        # but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


class TimeSeriesTransformer(nn.Module):
    """Transformer encoder that maps a window of features to multi-horizon forecasts.

    ``src`` is permuted with ``permute(1, 0, 2)`` before embedding, so it is
    expected as (batch, seq_len, num_features). The encoder output at the
    final sequence position feeds one linear head per horizon, and the
    result has shape (batch, len(horizons), num_targets).
    """

    def __init__(self, CONF=CONF):
        super().__init__()
        model_cfg = CONF.model
        n_features = model_cfg.num_features
        d_model = n_features * model_cfg.forward_expansion
        self.output_horizons = model_cfg.horizons

        # Lift each time step's feature vector to the model width.
        self.feature_to_embedding = nn.Linear(n_features, d_model)
        self.positional_encoder = PositionalEncoding(d_model, model_cfg.dropout)

        # Sequence-first encoder stack (batch_first left at its default).
        # NOTE: nn.TransformerEncoder clones this template layer, so the
        # template's own weights are not used in forward(). The attribute is
        # kept because its parameters are part of the saved state_dict keys.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=model_cfg.num_heads,
            dim_feedforward=4 * d_model,
            dropout=model_cfg.dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=model_cfg.num_layers
        )

        # One linear head per forecast horizon.
        self.fc_out = nn.ModuleList(
            nn.Linear(d_model, model_cfg.num_targets) for _ in self.output_horizons
        )

    def forward(self, src):
        # (batch, seq, features) -> (seq, batch, features) for the encoder.
        seq_first = src.permute(1, 0, 2)
        encoded = self.transformer_encoder(
            self.positional_encoder(self.feature_to_embedding(seq_first))
        )
        # Forecast from the representation at the final sequence position.
        last_step = encoded[-1]
        return torch.stack([head(last_step) for head in self.fc_out], dim=1)

#### Training loop

In [None]:
import torch
from torch.optim import Adam
import torch.nn as nn
from pipeline.config import CONF
from tqdm import tqdm

if CONF.train.do_train:
    # Model, optimiser and loss are all built from the configuration.
    model = TimeSeriesTransformer(CONF=CONF)
    optimizer = Adam(model.parameters(), lr=CONF.train.lr)
    criterion = nn.MSELoss()  # swap for another loss if the task calls for it

    def train(model, train_loader, optimizer, criterion, device):
        """Run one training epoch and return the mean of the per-batch losses."""
        model.train()
        loss_sum = 0
        bar = tqdm(train_loader, desc="Training", leave=False)
        for batch_inputs, batch_targets in bar:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()
            bar.set_postfix({"batch_loss": f"{batch_loss.item():.4f}"})
        bar.close()
        return loss_sum / len(train_loader)

    def validate(model, val_loader, criterion, device):
        """Evaluate without gradient tracking and return the mean of the per-batch losses."""
        model.eval()
        loss_sum = 0
        bar = tqdm(val_loader, desc="Validation", leave=False)
        with torch.no_grad():
            for batch_inputs, batch_targets in bar:
                batch_inputs = batch_inputs.to(device)
                batch_targets = batch_targets.to(device)
                batch_loss = criterion(model(batch_inputs), batch_targets)
                loss_sum += batch_loss.item()
                bar.set_postfix({"batch_loss": f"{batch_loss.item():.4f}"})
        bar.close()
        return loss_sum / len(val_loader)

    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # One training pass plus one validation pass per epoch.
    num_epochs = CONF.train.epochs
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_loss = validate(model, val_loader, criterion, device)
        print(
            f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}"
        )

    # Persist the learned weights to the current working directory.
    torch.save(model.state_dict(), "model_weights.pth")

## Evaluation