# Training the SWEFormer Model

In [None]:
# Standard library
import os
import re
import time
import math
import pickle
import importlib
from typing import Any, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.io import loadmat
from scipy.stats import pearsonr
from sklearn.model_selection import train_test_split

# Local application imports
import model_builder
import optimizer_code
import dataloader
import train
import utils
# Reload the local project modules so that edits made to them while the kernel
# is running are picked up without restarting. optimizer_code is reloaded too,
# since setup_training() from it is used below.
importlib.reload(utils)
importlib.reload(optimizer_code)
importlib.reload(train)
importlib.reload(dataloader)
importlib.reload(model_builder)

In [None]:
!pip install gdown

In [None]:
import gdown

# File ID from the Google Drive link
file_id = '1XgHYFKpmK0y7xfXNNdOc87h52_X2qaRS'

# Create the download URL
url = f'https://drive.google.com/uc?export=download&id={file_id}'

# Save the file under Datasets/, which is the path the loading cell opens
# ('Datasets/Training_data_ERA5_datasets.pkl'). Create the folder first so the
# download does not fail on a fresh checkout.
os.makedirs('Datasets', exist_ok=True)
gdown.download(url, 'Datasets/Training_data_ERA5_datasets.pkl', quiet=False)

In [None]:
# Load the pickled mapping that holds all training inputs and targets.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# This file is presumably the one fetched with gdown from a public Google Drive
# link, so only load it if you trust that source.
with open('Datasets/Training_data_ERA5_datasets.pkl', 'rb') as f:
    datasets = pickle.load(f)

# Pull each stored array out of the mapping by its key.
X_tbh_tbv_emiss, X_tbh_tbv = datasets["X_tbh_tbv_emiss"], datasets["X_tbh_tbv"]
X_air_temp, X_ground_temp = datasets["X_air_temp"], datasets["X_ground_temp"]
land_cover, water_frac = datasets["land_cover"], datasets["water_frac"]
peak_swe, swe, elev = datasets["peak_swe"], datasets["swe"], datasets["elev"]

## Creating the DataLoaders for Training and Testing

In [None]:
# Build the train/test loaders (train_ratio=0.8, presumably an 80/20 split;
# batches of 64) from the arrays loaded above.
train_loader, test_loader = dataloader.prepare_training_dataloader(
    X_tbh_tbv_emiss, X_air_temp, X_ground_temp, land_cover,
    water_frac, peak_swe, train_ratio=0.8, batch_size=64, elev=elev)

# Peek at the first test batch to sanity-check tensor shapes.
# The batch's elevation tensor is bound to `elev_batch`, not `elev`. Rebinding
# `elev` would overwrite the full elevation dataset, and re-running this cell
# would then pass a single batch to prepare_training_dataloader.
data, cat, water, elev_batch, targets = next(iter(test_loader))

print("Data shape:", data.shape)                # Shape of the input data
print("Categorical input shape:", cat.shape)    # Shape of the categorical input
print("elev input shape:", elev_batch.shape)    # Shape of the elevation input
print("Targets shape:", targets.shape)          # Shape of the targets
print("Water_frac shape:", water.shape)         # Shape of the water-fraction input

# Number of batches in the test loader (assuming a torch DataLoader).
print(len(test_loader))

## Model Building: SWEFormer

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate SWEFormer; see model_builder.SWEFormer for what each argument means.
model = model_builder.SWEFormer(
    feat_dim=4,
    max_len=218,
    d_model=16,
    n_heads=4,
    num_layers=3,
    num_categories=13,
    cat_embed_dim=8,
    water_frac_dim=1,
    elev_dim=1,
    num_classes=1,
    output_concat_switch=False,
)
model = model.to(device)

# Loss function, optimizer and LR scheduler.
# NOTE(review): total_steps=100 matches epochs=100 in the training cell;
# presumably the scheduler steps once per epoch, so keep the two in sync.
criterion, optimizer, scheduler = optimizer_code.setup_training(
    model,
    loss_type='mse',
    lr=0.00005,
    weight_decay=0.01,
    betas=(0.9, 0.999),
    total_steps=100,
    gamma=0.95,
)

<div style="background-color: #fff3cd; border: 1px solid #ffeeba; padding: 10px; color: #856404; font-weight: bold;">
⚠️ <strong>WARNING:</strong> Training the model can take several hours, even on a GPU. To view and analyze the results directly, skip the following cell and use the pre-trained model instead.
</div>


In [None]:
# Create the checkpoint directory up front, in case train_model_elev does not
# create it. Otherwise saving the best model could fail only after an epoch
# of (slow) training has already run.
os.makedirs('Trained Model', exist_ok=True)

# Train with early stopping; the best model is saved to model_save_path.
train.train_model_elev(model=model,
            train_dataloader=train_loader,
            val_dataloader=test_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            epochs=100,
            device=device,
            early_stopping_patience=10,  # Stop if no improvement in 10 epochs
            min_delta=0.0001,  # Minimum improvement needed to reset early stopping
            model_save_path='Trained Model/ERA5_trained_model.pth')  # Path to save the best model