# U-Cast Tutorial
Set-up instructions: this notebook gives a tutorial on the high-dimensional time series forecasting task supported by U-Cast.

1. Install Python 3.10, then install the dependencies with either of the following commands.

In [None]:
pip install -r requirements.txt

or

In [None]:
conda env create -f environment.yaml

2. Package Import

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

3. U-Cast Construction

U-Cast is an efficient model that captures channel correlations via learning latent hierarchical structures. U-Cast also introduces a full-rank regularization
term to encourage disentanglement and improve the learning of structured representations.

In the following section, we take a detailed look at U-Cast. The figure below illustrates the architecture.

![U-Cast](../pic/U-Cast.png)

U-Cast consists of three main components: (1) downsampling, (2) full-rank regularization, and (3) upsampling.

Downsampling (HierarchicalLatentQueryNetwork)

In [None]:
class HierarchicalLatentQueryNetwork(nn.Module):
    """Downsampling stage: a stack of latent-query attention layers.

    Every level shrinks the number of channel tokens by ``reduction_ratio``
    (integer division, never below 1) and is followed by a ``LayerNorm``.
    The normalised output of each level is kept so that it can be reused
    later as a skip connection.

    Args:
        orig_channels: Number of channels before any reduction.
        time_dim: Feature width of the tokens fed to the first level.
        num_layers: Number of reduction levels.
        head_dim: Per-head width of each attention layer.
        reduction_ratio: Divisor applied to the channel count at each level.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability handed to each attention layer.
    """

    def __init__(self, orig_channels, time_dim, num_layers, head_dim, reduction_ratio=16, num_heads=1, dropout=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # Channel count remaining after each successive reduction.
        self.latent_dims = []
        remaining = orig_channels
        for _ in range(num_layers):
            remaining = max(1, remaining // reduction_ratio)
            self.latent_dims.append(remaining)

        # Only the first level sees `time_dim` features; every later level
        # consumes the previous level's output of width head_dim * num_heads.
        width = head_dim * num_heads
        feature_dim = time_dim
        for n_latents in self.latent_dims:
            self.layers.append(LatentQueryAttention(feature_dim, n_latents, head_dim, num_heads, dropout))
            feature_dim = width
        self.norm_layers = nn.ModuleList([nn.LayerNorm(width) for _ in self.latent_dims])

    def forward(self, x, return_attn=False):
        """Run all reduction levels.

        Args:
            x: 3-D tensor; its last two axes are swapped before the first level
                (the original comment labels the swapped layout ``[B, C, T]``).
            return_attn: When true, also return each level's attention map,
                detached and moved to the CPU.

        Returns:
            ``(bottom, skip_list)`` or ``(bottom, skip_list, attn_maps)``.
            ``skip_list[0]`` is the transposed input, followed by one entry
            per level; ``bottom`` is the last entry of ``skip_list``.
        """
        _batch, _rows, _cols = x.shape  # unpacking requires a 3-D input
        current = x.transpose(1, 2)  # [B, C, T]
        pyramid = [current]
        attn_maps = [] if return_attn else None

        for attend, normalize in zip(self.layers, self.norm_layers):
            if return_attn:
                current, weights = attend(current, return_attn=True)
                attn_maps.append(weights.detach().cpu())
            else:
                current = attend(current)
            current = normalize(current)
            pyramid.append(current)

        if return_attn:
            return current, pyramid, attn_maps
        return current, pyramid

In [None]:
class LatentQueryAttention(nn.Module):
    """Multi-head cross-attention from learned latent queries to input tokens.

    The layer owns ``latent_dim`` learnable query vectors. Keys and values are
    projected from the input, so an input with ``L`` tokens is summarised into
    ``latent_dim`` output tokens of width ``head_dim * num_heads``.

    Args:
        in_dim: Feature width of the input tokens.
        latent_dim: Number of learned queries (= number of output tokens).
        head_dim: Width of each attention head.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, in_dim, latent_dim, head_dim, num_heads=1, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.latent_dim = latent_dim

        width = head_dim * num_heads
        self.latent_queries = nn.Parameter(torch.randn(latent_dim, width))
        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(in_dim, width)
        self.v_proj = nn.Linear(in_dim, width)
        self.out_proj = nn.Linear(width, width)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor, batch, length):
        """Reshape ``[batch, length, heads * head_dim]`` to ``[batch, heads, length, head_dim]``."""
        return tensor.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, return_attn=False):
        """Attend from the latent queries to ``x``.

        Args:
            x: Tensor of shape ``[B, L, in_dim]``.
            return_attn: When true, also return the attention weights
                (after dropout), shaped ``[B, num_heads, latent_dim, L]``.

        Returns:
            Tensor of shape ``[B, latent_dim, num_heads * head_dim]``, or a
            tuple ``(output, attn)`` when ``return_attn`` is true.
        """
        batch, n_tokens, _ = x.shape

        # The same learned queries are shared by every sample in the batch.
        latents = self.latent_queries.unsqueeze(0).expand(batch, -1, -1)
        q = self._split_heads(self.q_proj(latents), batch, self.latent_dim)
        k = self._split_heads(self.k_proj(x), batch, n_tokens)
        v = self._split_heads(self.v_proj(x), batch, n_tokens)

        # Scaled dot-product attention over the input tokens.
        logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = self.dropout(F.softmax(logits, dim=-1))

        # Merge the heads back into a single feature axis.
        merged = torch.matmul(attn, v).transpose(1, 2).contiguous()
        merged = merged.view(batch, self.latent_dim, self.num_heads * self.head_dim)
        out = self.out_proj(merged)

        return (out, attn) if return_attn else out

Full-rank regularization

In [None]:
def covariance_loss(skip_list, lambda_cov=0.1, eps=1e-5):
    """Full-rank regulariser: mean negative log-determinant of feature covariances.

    For every tensor in ``skip_list`` except the first (the un-reduced input),
    the ``[B, C, D]`` tensor is flattened to ``[B * C, D]``, each feature is
    standardised, and the ``D x D`` covariance (plus ``eps * I``) is formed.
    The level's loss is ``-log|det(cov)| / D``; the per-level losses are
    averaged over all levels and scaled by ``lambda_cov``.

    Args:
        skip_list: List of tensors, each of shape ``[B, C, D]``. The first
            entry is treated as the input and is excluded.
        lambda_cov: Weight applied to the averaged loss.
        eps: Added to the standard deviation and to the covariance diagonal
            for numerical stability.

    Returns:
        A scalar tensor, or the float ``0.0`` (times ``lambda_cov``) when no
        level contributes.
    """
    levels = skip_list[1:]  # exclude input
    if not levels:
        return lambda_cov * 0.0

    total_loss = 0.0
    for x in levels:
        B, C, D = x.shape
        n_samples = B * C
        if n_samples < 2:
            # With a single sample the unbiased std and the 1 / (n - 1)
            # covariance are undefined and would turn the whole loss into NaN,
            # so such a level contributes nothing.
            continue

        flat = x.reshape(n_samples, D)
        standardized = (flat - flat.mean(dim=0, keepdim=True)) / (
            flat.std(dim=0, keepdim=True) + eps
        )
        cov = (standardized.T @ standardized) / (n_samples - 1)
        cov = cov + eps * torch.eye(D, device=x.device, dtype=x.dtype)

        # log|det| equals logdet for a positive-definite matrix, but stays
        # finite when rounding flips the determinant's sign on a nearly
        # singular covariance (where torch.logdet would return NaN).
        log_abs_det = torch.linalg.slogdet(cov).logabsdet

        # normalize by dimension (D) to reduce scale variance
        total_loss = total_loss + (-log_abs_det / D)

    # Average over every level, including skipped ones, so the scale of the
    # regulariser does not depend on how many levels were estimable.
    return lambda_cov * (total_loss / len(levels))

Upsampling

In [None]:
class HierarchicalUpsamplingNetwork(nn.Module):
    """Upsampling stage: walks back up the skip pyramid, one level per layer.

    Args:
        num_layers: Number of upsampling levels.
        q_in_dim: Feature width of the skip tensors; also the ``LayerNorm`` size.
        head_dim: Per-head width passed to each attention layer.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability passed to each attention layer.
    """

    def __init__(self, num_layers, q_in_dim, head_dim, num_heads=1, dropout=0.1):
        super().__init__()
        attention_stack = []
        for _ in range(num_layers):
            attention_stack.append(UpLatentQueryAttention(q_in_dim, head_dim, num_heads, dropout))
        self.layers = nn.ModuleList(attention_stack)
        self.norms = nn.ModuleList([nn.LayerNorm(q_in_dim) for _ in range(num_layers)])

    def forward(self, x_bottom, skip_list):
        """Expand ``x_bottom`` back through the skip tensors.

        The last entry of ``skip_list`` is dropped and the remaining entries
        are visited from last to first. At each level the attention layer is
        called with the skip tensor first and the running state second; the
        skip tensor is then added back as a residual before normalisation.

        Args:
            x_bottom: Starting state.
            skip_list: Skip tensors in the order the downsampling stage
                produced them.

        Returns:
            The state after the final level.
        """
        hidden = x_bottom
        finer_levels = reversed(skip_list[:-1])
        for attend, normalize, skip in zip(self.layers, self.norms, finer_levels):
            hidden = normalize(attend(skip, hidden) + skip)
        return hidden

In [None]:
For more details, please read our paper (link: https://arxiv.org/pdf/2507.15119).

4. U-Cast

In [None]:
# Skeleton of the U-Cast model class: only the method signatures are listed
# here; each body is presented in its own cell further down in this notebook.
# NOTE(review): `register_model` is not imported or defined in this notebook;
# presumably it is provided by the surrounding library -- confirm.
@register_model("UCast", paper="U-Cast: Learning Latent Hierarchical Channel Structure for High-Dimensional Time Series Forecasting", year=2024)
class Model(nn.Module):
    # Constructor; takes a single `configs` object.
    def __init__(self, configs):

    # Forecasting routine operating on the encoder input `x_enc`.
    def forecast(self, x_enc):

    # Module entry point; takes the same argument as `forecast`.
    def forward(self, x_enc):

First of all, let us focus on __init__(self, configs)

In [None]:
def __init__(self, configs):
    """Build the U-Cast layers from ``configs``.

    Fields read from ``configs``: ``task_name``, ``seq_len``, ``pred_len``,
    ``enc_in``, ``d_model``, ``alpha``, ``e_layers``,
    ``channel_reduction_ratio`` and ``dropout``.
    """
    super(Model, self).__init__()

    width = configs.d_model
    depth = configs.e_layers
    drop_p = configs.dropout

    self.task_name = configs.task_name
    self.seq_len, self.pred_len = configs.seq_len, configs.pred_len
    self.enc_in = configs.enc_in
    self.d_model = width
    self.alpha = configs.alpha  # passed to covariance_loss as its weight

    # Per-channel linear maps along the time axis: history -> d_model -> horizon.
    self.input_proj = nn.Linear(self.seq_len, width)
    self.output_proj = nn.Linear(width, self.pred_len)

    # Downsampling path: single-head attention whose head width is d_model.
    self.channel_reduction_net = HierarchicalLatentQueryNetwork(
        orig_channels=self.enc_in,
        time_dim=width,
        num_layers=depth,
        head_dim=width,
        reduction_ratio=configs.channel_reduction_ratio,
        num_heads=1,
        dropout=drop_p,
    )

    # Upsampling path with one level per downsampling level.
    self.upsample_net = HierarchicalUpsamplingNetwork(
        num_layers=depth,
        q_in_dim=width,
        head_dim=width,
        num_heads=1,
        dropout=drop_p,
    )

    # Linear layer applied to the bottom-level representation.
    self.predict_layer = nn.Linear(width, width)

In [None]:
Then, let's focus on forecast(self, x_enc)

In [None]:
def forecast(self, x_enc):
    """Produce the forecast together with the covariance regulariser.

    The input is normalised per series over axis 1, embedded, passed through
    the downsampling and upsampling paths, projected to the output length and
    finally de-normalised with the statistics computed at the start.

    Args:
        x_enc: Input tensor; axis 1 is the axis that is normalised and then
            projected by ``input_proj``.

    Returns:
        Tuple ``(dec_out, cov_loss)``; the original comments give
        ``dec_out`` the layout ``[B, pred_len, enc_in]``.
    """
    # Per-series normalisation; the mean is detached from the graph.
    means = x_enc.mean(1, keepdim=True).detach()
    centered = x_enc - means
    stdev = torch.sqrt(torch.var(centered, dim=1, keepdim=True, unbiased=False) + 1e-5)
    normed = centered / stdev

    # Embed each channel's history: [B, C, T] -> [B, C, d_model].
    tokens = self.input_proj(normed.transpose(1, 2))

    # The reduction network expects [B, d_model, C] and transposes internally.
    x_bottom, skip_list = self.channel_reduction_net(tokens.transpose(1, 2))
    cov_loss = covariance_loss(skip_list, self.alpha)

    x_up = self.upsample_net(self.predict_layer(x_bottom), skip_list)

    # Residual with the embedded input, then project: [B, enc_in, pred_len].
    dec_out = self.output_proj(x_up + tokens)
    dec_out = dec_out.transpose(1, 2)  # [B, pred_len, enc_in]

    # Undo the input normalisation.
    scale = stdev[:, 0, :].unsqueeze(1)
    shift = means[:, 0, :].unsqueeze(1)
    dec_out = dec_out * scale
    dec_out = dec_out + shift
    return dec_out, cov_loss

In [None]:
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
    """Run the model and return whatever ``forecast`` returns.

    The training and validation loops call the model as
    ``model(batch_x, batch_x_mark, dec_inp, batch_y_mark)``. The extra
    arguments are accepted so that call works, but they are ignored: only
    ``x_enc`` is forwarded.

    Args:
        x_enc: Encoder input, forwarded to ``forecast``.
        x_mark_enc: Ignored.
        x_dec: Ignored.
        x_mark_dec: Ignored.
        mask: Ignored.

    Returns:
        The result of ``self.forecast(x_enc)``.
    """
    return self.forecast(x_enc)

5. Training and Settings

5.1 Training for High-dimensional Forecasting

In [None]:
# class LongTermForecastingExperiment(BaseExperiment)
def train(self, setting: str) -> Tuple[list, dict, str]:
    """
    Execute the complete training procedure.

    Performs model training with validation monitoring, early stopping,
    and comprehensive metric tracking across all epochs. After every epoch
    the model is evaluated on both the validation and the test loader; only
    the validation loss drives early stopping and best-metric selection.

    Args:
        setting: Unique experiment setting string for checkpoint management

    Returns:
        Tuple containing:
            - all_epoch_metrics: List of metrics for each training epoch
            - best_metrics: Dictionary of best validation metrics achieved
            - best_model_path: Path to the saved best model checkpoint
              (empty string if no epoch improved on the initial value)
    """
    # Load data splits
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    test_data, test_loader = self._get_data(flag='test')

    # Setup checkpoint directory
    path = os.path.join(self.config.checkpoints, setting)

    # Initialize performance tracking variables
    all_epoch_metrics = []
    best_metrics = {
        "epoch": 0,
        "train_loss": float('inf'),
        "vali_loss": float('inf'),
        "vali_mae_loss": float('inf'),
        "test_loss": float('inf'),
        "test_mae_loss": float('inf')
    }
    best_model_path = ""

    # Reference point for the speed / remaining-time estimate; reset every 100 iterations
    time_now = time.time()

    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.config.patience, verbose=True, accelerator=self.accelerator)

    # Prepare components for distributed training with accelerator
    self.model, self.optimizer, train_loader, vali_loader, test_loader = self.accelerator.prepare(
        self.model, self.optimizer, train_loader, vali_loader, test_loader
    )

    # Main training loop
    for epoch in range(self.config.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        epoch_time = time.time()
        batch_times = []  # Track training time per batch

        # Batch training loop
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            batch_start_time = time.time()  # Record batch start time

            iter_count += 1
            self.optimizer.zero_grad()

            # Move data to device
            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float().to(self.device)
            batch_x_mark = batch_x_mark.float().to(self.device)
            batch_y_mark = batch_y_mark.float().to(self.device)

            # Prepare decoder input (for sequence-to-sequence models):
            # the first label_len steps of batch_y followed by zeros for the horizon
            dec_inp = torch.zeros_like(batch_y[:, -self.config.pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :self.config.label_len, :], dec_inp], dim=1).float().to(self.device)

            # Forward pass with automatic mixed precision.
            # The model is invoked with four positional arguments.
            with self.accelerator.autocast():
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            # Handle models that return additional loss components
            # (e.g. a (prediction, regulariser) pair)
            if isinstance(outputs, tuple):
                outputs, additional_loss = outputs
            else:
                additional_loss = 0

            # Calculate loss (only on prediction horizon); computed outside the autocast block
            batch_y = batch_y[:, -self.config.pred_len:, :].to(self.device)
            loss = self.criterion(outputs, batch_y) + additional_loss

            train_loss.append(loss.item())

            # Log progress every 100 iterations (before this batch's backward pass)
            if (i + 1) % 100 == 0:
                self.accelerator.print(f"\titers: {i+1}, epoch: {epoch+1} | loss: {loss.item():.7f}")
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.config.train_epochs - epoch) * train_steps - i)
                self.accelerator.print(f'\tspeed: {speed:.4f}s/iter; left time: {left_time:.4f}s')
                iter_count = 0
                time_now = time.time()

            # Backward pass and optimization
            self.accelerator.backward(loss)
            self.optimizer.step()

            batch_end_time = time.time()
            batch_times.append(batch_end_time - batch_start_time)  # Record batch training time

        # Calculate epoch timing statistics
        epoch_cost_time = time.time() - epoch_time
        avg_batch_time = np.mean(batch_times)
        self.accelerator.print(f"Epoch: {epoch+1} cost time: {epoch_cost_time:.2f}s")
        self.accelerator.print(f"Average batch training time: {avg_batch_time:.4f}s")

        # Evaluate model performance; validate() is reused for the test loader
        train_loss = np.average(train_loss)
        val_time = time.time()
        vali_loss, vali_mae_loss = self.validate(vali_loader)
        self.accelerator.print(f"Val cost time: {time.time() - val_time:.2f}s")
        test_time = time.time()
        test_loss, test_mae_loss = self.validate(test_loader)
        self.accelerator.print(f"Test cost time: {time.time() - test_time:.2f}s")

        # Record comprehensive epoch metrics
        epoch_metrics = {
            "epoch": epoch + 1,
            "train_loss": float(train_loss),
            "vali_loss": float(vali_loss),
            "vali_mae_loss": float(vali_mae_loss),
            "test_loss": float(test_loss),
            "test_mae_loss": float(test_mae_loss)
        }
        all_epoch_metrics.append(epoch_metrics)

        self.accelerator.print(f'Epoch: {epoch+1}, Steps: {train_steps} | Train Loss: {train_loss:.7f} Vali Loss: {vali_loss:.7f} Test Loss: {test_loss:.7f}')

        # Early stopping check (includes saving checkpoint)
        early_stopping(vali_loss, self.model, path, metrics=epoch_metrics)

        # Update best metrics if current model is better
        if vali_loss < best_metrics["vali_loss"]:
            best_metrics.update(epoch_metrics)
            best_model_path = early_stopping.get_checkpoint_path()

        # Stop if needed
        if early_stopping.early_stop:
            self.accelerator.print("Early stopping")
            break

        # Adjust learning rate according to schedule (epoch is passed 1-based)
        adjust_learning_rate(self.optimizer, epoch + 1, self.config, self.accelerator)

    return all_epoch_metrics, best_metrics, best_model_path

For the full implementation, see core/experiments/long_term_forecasting.py.

5.2 Distributed Training

In [None]:
# Prepare components for distributed training with accelerator.
# This is the same call that appears inside train(): the model, the optimizer
# and all three data loaders are passed through `accelerator.prepare`, and the
# returned objects replace the originals.
# NOTE(review): `self.accelerator` is presumably a Hugging Face Accelerate
# `Accelerator` instance -- confirm against the experiment base class.
self.model, self.optimizer, train_loader, vali_loader, test_loader = self.accelerator.prepare(
    self.model, self.optimizer, train_loader, vali_loader, test_loader
)

5.3 Early Stop

In [None]:
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. When
    the loss improves on the best value seen so far (by at least ``delta``),
    ``save_checkpoint`` is invoked and the patience counter resets; otherwise
    the counter is incremented and ``early_stop`` becomes ``True`` once it
    reaches ``patience``.

    ``save_checkpoint(val_loss, model, path, metrics)`` is called on ``self``
    but is not defined in this excerpt; it must be provided by the full class.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: Stored on the instance; not read by the code shown here.
        delta: Minimum improvement of the score required to count as better.
        accelerator: Optional object whose ``print`` method is used for
            messages instead of the built-in ``print``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, accelerator=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.accelerator = accelerator
        self.best_metrics = None  # Store best validation metrics

    def __call__(self, val_loss, model, path, metrics=None):
        """Record one validation result and update the stopping state.

        Args:
            val_loss: Validation loss for the epoch (lower is better).
            model: Model handed to ``save_checkpoint`` on improvement.
            path: Checkpoint location handed to ``save_checkpoint``.
            metrics: Optional metrics stored as ``best_metrics`` on improvement.
        """
        score = -val_loss

        # NaN is the only value that differs from itself. A NaN loss must
        # never count as an improvement: it would overwrite the best
        # checkpoint and make every later comparison against best_score False.
        loss_is_nan = bool(val_loss != val_loss)
        improved = (not loss_is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path, metrics)
            if metrics is not None:
                self.best_metrics = metrics
            self.counter = 0
            return

        self.counter += 1
        message = f'EarlyStopping counter: {self.counter} out of {self.patience}'
        if self.accelerator:
            self.accelerator.print(message)
        else:
            print(message)
        if self.counter >= self.patience:
            self.early_stop = True

5.4 Learning Rate Scheduler

In [None]:
def adjust_learning_rate(optimizer, epoch, args, accelerator=None):
    """Set the optimizer's learning rate for ``epoch`` according to ``args.lradj``.

    Supported policies:
        - ``'type1'``: halve the base rate every epoch.
        - ``'type2'``: fixed table of rates applied at epochs 2, 4, 6, 8, 10,
          15 and 20; other epochs leave the rate untouched.
        - ``'type3'``: base rate for epochs 1-2, then multiply by 0.9 per epoch.
        - ``'constant'``: always the base rate.
        - ``'TSLR'``: multiply by ``0.5 ** 0.1`` once every 20 epochs.
        - ``'3'`` / ``'4'`` / ``'5'`` / ``'6'``: base rate until epoch
          10 / 15 / 25 / 5, then 10% of it.
        - ``'TST'``: linear ramp ``base * (1 + 0.1 * epoch / train_epochs)``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim.Optimizer``.
        epoch: Integer epoch number; the training loop passes it 1-based.
        args: Configuration with ``lradj`` and ``learning_rate`` (plus
            ``train_epochs`` for the ``'TST'`` policy).
        accelerator: Optional object whose ``print`` is used for logging.

    Raises:
        ValueError: If ``args.lradj`` is not one of the supported policies.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    policy = args.lradj
    # Policies that keep the base rate and drop it to 10% from a fixed epoch on.
    drop_epoch = {'3': 10, '4': 15, '5': 25, '6': 5}

    if policy == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** (epoch - 1))}
    elif policy == 'type2':
        lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    elif policy == 'type3':
        lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.9 ** (epoch - 3))}
    elif policy == 'constant':
        lr_adjust = {epoch: args.learning_rate}
    elif policy == 'TSLR':
        lr_adjust = {epoch: args.learning_rate * ((0.5 ** 0.1) ** (epoch // 20))}
    elif policy in drop_epoch:
        lr_adjust = {epoch: args.learning_rate if epoch < drop_epoch[policy] else args.learning_rate * 0.1}
    elif policy == 'TST':
        lr_adjust = {epoch: args.learning_rate * (1.0 + 0.1 * epoch / args.train_epochs)}
    else:
        raise ValueError(f"Unsupported learning-rate schedule: lradj={policy!r}")

    # Table-driven policies ('type2') only have entries for some epochs.
    if epoch not in lr_adjust:
        return

    lr = lr_adjust[epoch]
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    message = f'Updating learning rate to {lr}'
    if accelerator is not None:
        accelerator.print(message)
    else:
        print(message)

6. Validation and Testing

In [None]:
def validate(self, vali_loader=None) -> Tuple[float, float]:
    """
    Compute MSE and MAE over a data loader.

    Squared and absolute errors are accumulated as running sums on
    ``self.device`` and reduced across processes once at the end, so the
    predictions never have to be gathered in full. The model is switched to
    eval mode for the pass and back to train mode before returning.

    Args:
        vali_loader: Loader yielding ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)``. If None, the validation loader is obtained from
            ``self._get_data(flag='val')``.

    Returns:
        Tuple of (MSE, MAE) as Python floats.
    """
    if vali_loader is None:
        _, vali_loader = self._get_data(flag='val')

    device = self.device

    # Running totals: squared error, absolute error, element count.
    sq_total = torch.tensor(0.0, device=device)
    abs_total = torch.tensor(0.0, device=device)
    n_total = torch.tensor(0.0, device=device)

    self.model.eval()
    with torch.no_grad():
        for batch in vali_loader:
            batch_x, batch_y, batch_x_mark, batch_y_mark = batch
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            batch_x_mark = batch_x_mark.float().to(device)
            batch_y_mark = batch_y_mark.float().to(device)

            # Ground truth restricted to the prediction horizon.
            target = batch_y[:, -self.config.pred_len:, :]

            # Decoder input: known label window followed by zeros for the horizon.
            label_window = batch_y[:, :self.config.label_len, :]
            dec_inp = torch.cat([label_window, torch.zeros_like(target)], dim=1).to(device)

            with self.accelerator.autocast():
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            # Models may return (prediction, extra_loss); only the prediction is scored.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            residual = outputs - target
            sq_total += residual.pow(2).sum()
            abs_total += residual.abs().sum()
            n_total += torch.tensor(residual.numel(), device=device)

    # One cross-process reduction per accumulator.
    sq_total = self.accelerator.reduce(sq_total, reduction="sum")
    abs_total = self.accelerator.reduce(abs_total, reduction="sum")
    n_total = self.accelerator.reduce(n_total, reduction="sum")

    mse = sq_total / n_total
    mae = abs_total / n_total

    self.model.train()
    return mse.item(), mae.item()

In [None]:
def test(self, setting: str, best_model_path: Optional[str] = None) -> Tuple[float, float]:
    """
    Load a checkpoint and report MSE / MAE on the test split.

    The metric computation is delegated to ``validate()``, so the numbers
    are aggregated exactly as during training-time evaluation.

    Args:
        setting: Experiment identifier string; used to build the fallback
            checkpoint file name.
        best_model_path: Path to the model checkpoint. If None, defaults to
            ``<self.config.checkpoints>/<setting>.pth``.

    Returns:
        Tuple of (MSE, MAE) on the test set.
    """
    checkpoint = best_model_path
    if checkpoint is None:
        checkpoint = os.path.join(self.config.checkpoints, f"{setting}.pth")

    self.accelerator.print(f'Loading trained model {checkpoint} for testing')

    # Weights are loaded into the unwrapped model, on the CPU first.
    self.model = self.accelerator.unwrap_model(self.model)
    state_dict = torch.load(checkpoint, map_location='cpu')
    self.model.load_state_dict(state_dict)

    # Fresh test loader; model and loader are prepared again for distributed evaluation.
    test_loader = self._get_data(flag='test')[1]
    self.model, test_loader = self.accelerator.prepare(self.model, test_loader)

    mse, mae = self.validate(test_loader)
    self.accelerator.print(f'Test MSE: {mse:.6f}, Test MAE: {mae:.6f}')
    return mse, mae