# Imports and Data

In [67]:
# Third-party

import xarray as xr

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch import Trainer


# Local imports

from weather_data_class import WeatherData
from torch_weather_model import TorchWeatherModel

ds = xr.open_dataset('data_850/2022_850_SA_coarsen.nc')
ds.load()

# Data setup

In [68]:
# Third-party imports
import xarray as xr
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from datetime import datetime

import cartopy
import cartopy.crs as ccrs

from torch.utils.data import DataLoader, Dataset

from IPython.display import HTML

import torch

from typing import Tuple

class WeatherData_Test(Dataset):

    """
    A dataset class for preparing wind speed data for machine learning models.

    Attributes:
        dataset (xr.Dataset): The xarray dataset containing wind speed data.
        window_size (int): The size of the window for creating features.
        steps (int): The number of forecasting steps.
        use_forcings (bool): Flag to indicate whether to use forcings.
        features (np.ndarray): Array of feature data.
        targets (np.ndarray): Array of target data.
        forcings (np.ndarray): Array of forcing data.
        time_values (np.ndarray): Array of time values corresponding to features.
        min_value (float): Minimum wind speed value for normalization.
        max_value (float): Maximum wind speed value for normalization.
        mean_value (float): Mean wind speed value for normalization.
        std_value (float): Standard deviation of wind speed for normalization.
        X_train (np.ndarray): Training features.
        X_test (np.ndarray): Testing features.
        y_train (np.ndarray): Training targets.
        y_test (np.ndarray): Testing targets.
        F_train (np.ndarray): Training forcings.
        F_test (np.ndarray): Testing forcings.
        X_train_t (torch.Tensor): Normalized training features as tensors.
        y_train_t (torch.Tensor): Normalized training targets as tensors.
        X_test_t (torch.Tensor): Normalized testing features as tensors.
        y_test_t (torch.Tensor): Normalized testing targets as tensors.
        F_train_t (torch.Tensor): Training forcings as tensors.
        F_test_t (torch.Tensor): Testing forcings as tensors.
    """

    def __init__(self, dataset: xr.Dataset, window_size: int = 24, steps: int = 3, auto: bool = True, use_forcings: bool = True, intervals: int = 1, data_split: str = 'train'):

        """
        Initializes the WeatherData object.

        Args:
            dataset (xr.Dataset): The xarray dataset containing wind speed data.
            window_size (int): The size of the window for creating features. Default is 24.
            steps (int): The number of forecasting steps. Default is 3.
            auto (bool): Flag to automatically window and normalize data. Default is False.
            use_forcings (bool): Flag to indicate whether to use forcings. Default is False.
        """
        
        self.dataset = dataset
        self.window_size = window_size
        self.steps = steps
        self.calculate_wind_speed()
        self.dataset = self.dataset.sortby('latitude')

        self.min_value = self.dataset.wspd.min().item()
        self.max_value = self.dataset.wspd.max().item()

        self.mean_value = self.dataset.wspd.mean().item()
        self.std_value = self.dataset.wspd.std().item()

        self.use_forcings = use_forcings

        self.model = None

        # MLP input size
        self.input_size = self.window_size * self.dataset.latitude.size * self.dataset.longitude.size
        self.forcing_size = 2  
        self.output_size = 1 * self.dataset.latitude.size * self.dataset.longitude.size 

        self.data_split = data_split

        if auto:
            if intervals > 1:
                self.time_intervals(intervals)
            self.split_data()    
            self.normalize_data()    

    def __len__(self) -> int:
        """
        Returns the number of samples based on how many windows of size `window_size + steps`
        can fit into the dataset for the specified split.

        Returns:
            int: The number of valid windows that can fit into the specified dataset split.
        """
        if self.data_split == 'train':
            dataset_length = len(self.X_train)
        elif self.data_split == 'val':
            dataset_length = len(self.X_val)
        elif self.data_split == 'test':
            dataset_length = len(self.X_test)
        else:
            raise ValueError("data_split must be 'train', 'val', or 'test'")
        
        # Calculate how many windows of size `window_size + steps` fit into the dataset
        total_window_size = self.window_size + self.steps
        num_windows = dataset_length - total_window_size + 1  # Ensure correct fit
        
        return max(0, num_windows)  # Ensure the result is not negative

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Retrieves a sample from the specified dataset split.

        Args:
            idx (int): The index of the sample to retrieve.
            data_split (str): The dataset split ('train', 'val', 'test'). Default is 'train'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing features, forcings, and target.
        """
        if self.data_split == 'train':
            return self.X_train_t[idx:idx+self.window_size], self.F_train_t[idx + self.window_size], self.X_train_t[idx + self.window_size:idx + self.window_size + self.steps]
        elif self.data_split == 'val':
            return self.X_val_t[idx:idx+self.window_size], self.F_val_t[idx + self.window_size], self.X_val_t[idx + self.window_size:idx + self.window_size + self.steps]
        elif self.data_split == 'test':
            return self.X_test_t[idx:idx+self.window_size], self.F_test_t[idx + self.window_size], self.X_test_t[idx + self.window_size:idx + self.window_size + self.steps]
        else:
            raise ValueError("data_split must be 'train', 'val', or 'test'")

    def time_intervals(self, intervals: int = 3) -> None:
    
        """
        Subsets the dataset based on the specified time intervals. Only happens once and then dataset is saved as a .nc file.

        Args:
            intervals (int): The time intervals for subsetting. Default is 3.

        Returns:
            None: Updates the dataset in place.
        """

        self.dataset = self.dataset.sel(time=slice(None, None, intervals))

    def subset_data(self, coarsen: int = 1) -> None:

        """
        Subsets the dataset based on the specified coarsening factor. Only happens once and then dataset is saved as a .nc file.

        Args:
            coarsen (int): The coarsening factor for subsetting. Default is 1.

        Returns:
            None: Updates the dataset in place.
        """

        if coarsen > 1:
            lat_slice = slice(1, 33, coarsen)
            lon_slice = slice(3, 67, coarsen)
        else:
            lat_slice = slice(1, 33)  
            lon_slice = slice(3, 67)

        self.dataset = self.dataset.isel(latitude=lat_slice, longitude=lon_slice)

    def calculate_wind_speed(self) -> None:
        """
        Calculates wind speed from u and v components and adds it to the dataset.

        Returns:
            None: Updates the dataset in place with the wind speed variable.
        """

        self.dataset['wspd'] = np.sqrt(self.dataset.u**2 + self.dataset.v**2).astype(np.float32)
        self.dataset.attrs['wspd_units'] = 'm/s'
        # self.dataset['wdir'] = np.arctan2(self.dataset.v, self.dataset.u) * 180 / np.pi
        # self.dataset.attrs['wdir_units'] = 'degrees'

    def split_data(self, test_size: float = 0.1, val_size: float = 0.2, random_state: int = 42) -> None:
        """
        Splits the data into training, validation, and testing sets.

        Args:
            test_size (float): Proportion of the dataset to include in the test split. Default is 0.2.
            val_size (float): Proportion of the dataset to include in the validation split from the training set. Default is 0.1.
            random_state (int): Random seed for reproducibility. Default is 42.

        Returns:
            None: Updates the instance attributes with training, validation, and testing sets.
        """
        
        
        # Create a numpy array of the dataset
        data = self.dataset.wspd.values
        forcings = np.stack([self.dataset.time.dt.hour.values, self.dataset.time.dt.month.values], axis=-1)
        time_values = self.dataset.time.values

        # Split the data into train, validation, and test sets

        self.X_train, self.X_test, self.F_train, self.F_test, self.T_train, self.T_test = train_test_split(data, forcings, time_values, test_size=test_size, shuffle=False)

        self.X_train, self.X_val, self.F_train, self.F_val, self.T_train, self.T_val = train_test_split(self.X_train, self.F_train, self.T_train, test_size=val_size, shuffle=False)

    def normalize_data(self) -> None:
        """
        Normalizes the training, validation, and testing data using mean and standard deviation.

        Returns:
            None: Updates the instance attributes with normalized data as tensors.
        """

        self.X_train_t = (self.X_train - self.mean_value) / self.std_value

        self.X_val_t = (self.X_val - self.mean_value) / self.std_value

        self.X_test_t = (self.X_test - self.mean_value) / self.std_value

        # Convert to tensors
        self.X_train_t = torch.tensor(self.X_train_t).float()

        self.X_val_t = torch.tensor(self.X_val_t).float()

        self.X_test_t = torch.tensor(self.X_test_t).float()

        # Convert forcings to tensors
        self.F_train_t = torch.tensor(self.F_train).float()
        self.F_val_t = torch.tensor(self.F_val).float()
        self.F_test_t = torch.tensor(self.F_test).float()

    def plot_from_data(self, seed: int = 0, frame_rate: int = 16, levels: int = 10) -> HTML:
        """
        Plots features and targets from the windowed arrays for visualization.

        Args:
            seed (int): Seed for reproducibility in selecting samples. Default is 0.
            frame_rate (int): The frame rate for the animation. Default is 16.
            levels (int): Number of contour levels for the plot. Default is 10.

        Returns:
            HTML: An HTML object representing the animation.
        """
        bounds = [self.dataset.longitude.min().item(), self.dataset.longitude.max().item(), self.dataset.latitude.min().item(), self.dataset.latitude.max().item()]
        features = self.X_test[seed:seed + self.window_size]
        targets = self.X_test[seed + self.window_size:seed + self.window_size + self.steps]
        time_features = self.T_test[seed:seed + self.window_size]
        time_targets = self.T_test[seed + self.window_size:seed + self.window_size + self.steps]

        time_features = pd.to_datetime(time_features)
        time_targets = pd.to_datetime(time_targets)

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        vmin = min(features.min().item(), targets.min().item())
        vmax = max(features.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()


        feat = axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        axs[1].set_title('Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[i], levels=levels, vmin=vmin, vmax = vmax)

            axs[0].set_title(f'Window {i} - {time_features[i].strftime("%Y-%m-%d %H:%M:%S")}')
            if self.steps > 1:
                axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[i % self.steps], levels=levels, vmin=vmin, vmax = vmax)
                axs[1].set_title(f'Target - {time_targets[i % self.steps].strftime("%Y-%m-%d %H:%M:%S")}')
            # return pcm

            
        frames = features.shape[0]

        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        plt.close(fig)

        return HTML(ani.to_jshtml())
    
    def assign_model(self, model: nn.Module) -> None:
        """
        Assign a model to the class instance.

        Args:
            model (nn.Module): A PyTorch model to assign for training and prediction.
        """
        self.model = model

    def load_model(self, file_path: str) -> None:
        """
        Load a model from a file.

        Args:
            file_path (str): Path to load the model from.
        """

        self.model.load_state_dict(torch.load(file_path, map_location=self.device, weights_only=True))
        self.model.to(self.device)
        self.model.eval()

    def predict(self, X: torch.Tensor, F: torch.Tensor) -> np.ndarray:
        """
        Predict output based on input data.

        Args:
            X (torch.Tensor): Input data for prediction.
            F (torch.Tensor): Forcings data, such as hour and month (if used).

        Returns:
            np.ndarray: Model predictions.
        """

        self.model.eval()
        with torch.no_grad():
            X = torch.tensor(X).float()
            F = torch.tensor(F).float()
            if self.use_forcings:
                return self.model(X, F).numpy()
            else:
                return self.model(X).numpy()

    def autoregressive_predict(self, X: torch.Tensor, F: torch.Tensor, rollout_steps: int, unnormalize: bool = True, verbose: bool = False) -> np.ndarray:
        """
        Perform autoregressive predictions for multiple time steps.

        Args:
            X (torch.Tensor): Input data for prediction.
            F (torch.Tensor): Forcings data, such as hour and month.
            rollout_steps (int): Number of future steps to predict.
            unnormalize (bool): Whether to unnormalize the predictions. Default is True.
            verbose (bool): Whether to print intermediate shapes for debugging. Default is False.

        Returns:
            np.ndarray: Predictions for each time step.
        """

        self.model.eval()
        with torch.no_grad():
            
            # X = torch.tensor(X).float()
            F = torch.tensor(F).float()
            
            predictions = []

            current_input = X#.to(self.device)
            current_F = F#.to(self.device)
            
            for step in range(rollout_steps):
                
                if self.use_forcings:
                    next_pred = self.model(current_input, current_F).cpu().numpy()
                else:
                    try:
                        next_pred = self.model(current_input).cpu().numpy()
                    except:
                        next_pred = self.model(current_input).numpy()
                
                predictions.append(next_pred)
                
                next_pred_tensor = torch.tensor(next_pred).float()#.to(self.device) 

                if verbose:
                    print(current_input.shape, next_pred_tensor.shape)

                current_input = torch.cat((current_input[:, 1:], next_pred_tensor), dim=1)#.to(self.device)

                hour = current_F[0, 0].item()  # Extract the hour
                month = current_F[0, 1].item()  # Extract the month
                
                hour += 1
                if hour == 24:
                    hour = 0
                
                current_F = torch.tensor([[hour, month]]).float()#.to(self.device)

            predictions = np.array(predictions).reshape(rollout_steps, self.dataset.sizes['latitude'], self.dataset.sizes['longitude'])

            # Unnromalize the predictions
            if unnormalize:
                predictions = predictions * self.std_value + self.mean_value
            
            return predictions
        
    def plot_pred_target(self, seed: int = 0, frame_rate: int = 16, levels: int = 10) -> HTML:
        """
        Plot the predictions and targets with animations.

        Args:
            seed (int): Seed to select the test data for plotting. Default is 0.
            frame_rate (int): Frame rate for animation. Default is 16.
            levels (int): Number of contour levels for plots. Default is 10.

        Returns:
            HTML: An HTML object containing the animation of predictions and targets.
        """

        bounds = [self.dataset.longitude.min().item(), self.dataset.longitude.max().item(), self.dataset.latitude.min().item(), self.dataset.latitude.max().item()]
        targets = self.X_test[seed + self.window_size:seed + self.window_size + self.steps]
        time_values = self.T_test[seed + self.window_size:seed + self.window_size + self.steps]

        time_values = pd.to_datetime(time_values)

        predictions = self.autoregressive_predict(self.X_test_t[seed:seed + self.window_size].unsqueeze(0), self.F_test_t[seed + self.window_size].unsqueeze(0), self.steps)

        fig, axs = plt.subplots(2, 3, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        vmin = min(predictions.min().item(), targets.min().item())
        vmax = max(predictions.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs.flatten()[:-1]:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()

        ax_last = fig.add_subplot(2, 3, 6)

        pred = axs[0, 0].contourf(self.dataset.longitude, self.dataset.latitude, predictions[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[0, 1].contourf(self.dataset.longitude, self.dataset.latitude, targets[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())

        error = (predictions[0] - targets[0,0].squeeze()) 

        err = axs[0, 2].contourf(self.dataset.longitude, self.dataset.latitude, error.squeeze(), levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')

        perc_error = error / targets[0,0].squeeze() * 100
        perc_error = np.clip(perc_error, -100, 100)
        rmse = np.sqrt(error**2)

        perr = axs[1, 0].contourf(self.dataset.longitude, self.dataset.latitude, perc_error, levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')
        rms = axs[1, 1].contourf(self.dataset.longitude, self.dataset.latitude, rmse, levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')
        ax_last.scatter(targets[0].flatten(), predictions[0].flatten(), c=error, cmap='coolwarm')

        fig.colorbar(pred, ax=axs[0, 0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[0, 1], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(err, ax=axs[0, 2], orientation='vertical', label='Percentage Error (%)')
        fig.colorbar(perr, ax=axs[1, 0], orientation='vertical', label='Percentage Error (%)')
        fig.colorbar(rms, ax=axs[1, 1], orientation='vertical', label='Root Mean Squared Error (m/s)')

        ax_last.set_xlabel("Observed Wind Speed (m/s)")
        ax_last.set_ylabel("Forecasted Wind Speed (m/s)")

        def animate(i):
            for ax in axs.flatten()[:-1]:
                ax.clear()
                ax.coastlines()
            
            ax_last.clear()
            ax_last.set_xlabel("Observed Wind Speed (m/s)")
            ax_last.set_ylabel("Forecasted Wind Speed (m/s)")

            axs[0, 0].contourf(self.dataset.longitude, self.dataset.latitude, predictions[i], levels=levels, vmin=vmin, vmax = vmax)
            axs[0, 1].contourf(self.dataset.longitude, self.dataset.latitude, targets[i], levels=levels, vmin=vmin, vmax = vmax)
            
            error =  (predictions[i] - targets[i].squeeze())
            axs[0, 2].contourf(self.dataset.longitude, self.dataset.latitude, error, levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')
            
            perc_error = error / targets[i % self.steps].squeeze() * 100
            perc_error = np.clip(perc_error, -100, 100)
            rmse = np.sqrt(error**2)

            axs[1, 0].contourf(self.dataset.longitude, self.dataset.latitude, perc_error, levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')
            axs[1, 1].contourf(self.dataset.longitude, self.dataset.latitude, rmse, levels=levels, transform=ccrs.PlateCarree(), cmap='coolwarm')
            ax_last.scatter(targets[i].flatten(), predictions[i].flatten(), c=error, cmap='coolwarm')

            axs[0, 0].set_title(f'Prediction {i} - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')  
            axs[0, 1].set_title(f'Target - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')
            axs[0, 2].set_title(f'Error - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')
            axs[1, 0].set_title(f'Percentage Error - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')
            axs[1, 1].set_title(f'Root Mean Squared Error - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')
            ax_last.set_title(f'Error Scatter Plot - {time_values[i].strftime("%Y-%m-%d %H:%M:%S")}')

        frames = predictions.shape[0]

        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        plt.close(fig)

        return HTML(ani.to_jshtml())


# Model Setup

In [75]:
class SimpleMLP(L.LightningModule):
    def __init__(self, input_size, forcing_size, output_size, steps = 1, lat = 32, lon = 64):
        super(SimpleMLP, self).__init__()
        self.save_hyperparameters()

        self.fc1 = nn.Linear(input_size + forcing_size, 1024)  
        self.fc2 = nn.Linear(1024, 512) 
        self.fc3 = nn.Linear(512, output_size) 

        self.loss_fn = nn.MSELoss()

        self.steps = steps

        self.lat = lat
        self.lon = lon

    def forward(self, X, F):

        batch_size = X.size(0)
        X = X.view(batch_size, -1)  

        inputs = torch.cat((X, F), dim=1)  

        x = torch.relu(self.fc1(inputs)) 
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  

        return x.view(-1, 1, 16, 32)
    
    def training_step(self, batch, batch_idx):
        x, F, y = batch  

        loss = self.auto_rollout(x, F, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True) 
        return loss

    def validation_step(self, batch, batch_idx):
        x, F, y = batch  
 
        loss = self.auto_rollout(x, F, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True) 
        return loss
    
    def auto_rollout(self, x, F, y):
        
        cumulative_loss = 0.0
        current_input = x.clone()
        current_F = F.clone()


        for step in range(self.steps): 
            print('Step: ', step)
            
            y_hat = self(current_input, current_F)

            loss = self.loss_fn(y_hat, y[:, step].reshape(-1, 1, self.lat, self.lon))
            cumulative_loss += loss  
            
            current_input = torch.cat((current_input[:, 1:], y_hat), dim=1)

            hour = current_F[:, 0]
            month = current_F[:, 1]
            
            hour = (hour + 1) % 24

            current_F = torch.stack((hour, month), dim=1).float()

        return cumulative_loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.002)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}

# Training

In [76]:
window_size = 5
steps = 5
batch_size = 128
intervals = 3

model_class = WeatherData_Test(ds, 
                               window_size=window_size, 
                               steps=steps, 
                               intervals=intervals)

In [77]:
ar_steps = 1

model = SimpleMLP(model_class.input_size, 
                  model_class.forcing_size, 
                  model_class.output_size,
                  steps=ar_steps)

if ar_steps > 1:
    model = SimpleMLP.load_from_checkpoint('checkpoints/best_model_1_steps.ckpt')

train_loader = DataLoader(WeatherData_Test(ds, 
                                            window_size=window_size, 
                                            steps=ar_steps, 
                                            intervals=intervals, 
                                            data_split='train'), 
                                            batch_size=batch_size, 
                                            shuffle=True)

val_loader = DataLoader(WeatherData_Test(ds,
                                            window_size=window_size,
                                            steps=ar_steps,
                                            intervals=intervals,
                                            data_split='val'), 
                                            batch_size=batch_size, 
                                            shuffle=False)

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",      
    dirpath="checkpoints/",  
    filename=f"best_model_{ar_steps}_steps",
    save_top_k=1,            
    mode="min"              
)

early_stopping_callback = EarlyStopping(
    monitor="val_loss",      
    patience=10,             
    mode="min",              
    verbose=True
)


trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stopping_callback]
)

trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
c:\Users\23603526\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:653: Checkpoint directory C:\Users\23603526\Documents\GitHub\1_WindSpeedForecasting\checkpoints exists and is not empty.

  | Name    | Type    | Params
------------------------------------
0 | fc1     | Linear  | 2.6 M 
1 | fc2     | Linear  | 524 K 
2 | fc3     | Linear  | 262 K 
3 | loss_fn | MSELoss | 0     
------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.648    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\23603526\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
  return F.mse_loss(input, target, reduction=self.reduction)


Step:  0


RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 3

In [72]:
model_class.model = model

seed = 74

model_class.plot_pred_target(seed=seed, frame_rate=4, levels=10)

  F = torch.tensor(F).float()
