## Digital Twin Framework with Spatiotemporal Vision Transformer for Heat Resilience

Reference: Gong, W., Ye, X., Wu, K., Jamonnak, S., Zhang, W., Yang, Y., & Huang, X. (2025). [Integrating Spatiotemporal Vision Transformer into Digital Twins for High-Resolution Heat Stress Forecasting in Campus Environments](https://arxiv.org/abs/2502.09657). arXiv preprint arXiv:2502.09657.

Note: A GPU is required to train this model. The model was originally trained on two NVIDIA A800 Tensor Core GPUs (80 GB of memory each).

![Digital twin framework](Figure1.jpg)

### Define a custom PyTorch dataset class (CustomDataset) for spatiotemporal data processing

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from torchvision import transforms
import os
import torchvision.transforms.functional as TF
import random
from PIL import Image

In [None]:
class CustomDataset(Dataset):
    """Sliding-window dataset of random square patches for UTCI forecasting.

    Three sources are loaded eagerly in ``__init__``:

    * the first four (name-sorted) rasters found in ``spatial_data_dir``, each
      min-max scaled with its own extrema and then cropped to the top-left
      ``CROP_HEIGHT`` x ``CROP_WIDTH`` window;
    * the numeric columns of ``temporal_data_csv`` (first CSV column dropped,
      missing values replaced by 0), cut into sequences of
      ``STEPS_PER_SEQUENCE`` rows and min-max scaled per column;
    * every raster in ``output_data_dir`` (name-sorted), which is scaled on
      access with the global minimum/maximum over all of them.

    Args:
        spatial_data_dir: Directory holding the static input rasters.
        temporal_data_csv: CSV file with the time series.
        output_data_dir: Directory holding the per-time-step output rasters.
        T_in: Number of input time steps per sample.
        T_out: Number of predicted time steps per sample.

    Raises:
        ValueError: If no (input, output) window fits in the available data.
    """

    # Number of consecutive rows that form one temporal sequence.
    STEPS_PER_SEQUENCE = 336
    # Side length of the square patch returned by ``__getitem__``.
    PATCH_SIZE = 64
    # Size of the top-left window kept from every spatial raster.
    CROP_HEIGHT = 3982
    CROP_WIDTH = 3739
    # Column that ``normalize_temporal_data`` divides by 360 instead of
    # min-max scaling (named "wind direction" in the original code).
    WIND_DIRECTION_INDEX = 4

    def __init__(self, spatial_data_dir, temporal_data_csv, output_data_dir, T_in, T_out):
        self.output_data = []
        self.spatial_data_max = None
        self.spatial_data_min = None

        # --- static spatial rasters -------------------------------------
        spatial_files = sorted(os.listdir(spatial_data_dir))[:4]
        spatial_layers = []
        for spatial_file in spatial_files:
            spatial_image = np.array(Image.open(os.path.join(spatial_data_dir, spatial_file))).astype(np.float32)
            spatial_image = torch.tensor(spatial_image).unsqueeze(0)
            # Scaling happens before the crop, so the extrema come from the
            # full raster. Only the last raster's extrema stay on the instance.
            spatial_image, self.spatial_data_max, self.spatial_data_min = self.maxminscaler_3d(spatial_image)
            spatial_image = TF.crop(spatial_image, top=0, left=0, height=self.CROP_HEIGHT, width=self.CROP_WIDTH)
            spatial_layers.append(spatial_image)
        self.spatial_data = torch.cat(spatial_layers, dim=0)

        # --- temporal series --------------------------------------------
        temporal_data_df = pd.read_csv(temporal_data_csv).iloc[:, 1:]
        self.temporal_data = temporal_data_df.select_dtypes(include=[np.number]).fillna(0).astype(np.float32).values

        # Column-wise extrema are taken before trailing rows are discarded.
        self.temporal_data_max = np.max(self.temporal_data, axis=0)
        self.temporal_data_min = np.min(self.temporal_data, axis=0)

        steps = self.STEPS_PER_SEQUENCE
        num_sequences = self.temporal_data.shape[0] // steps
        num_features = self.temporal_data.shape[1]
        # Drop the incomplete tail, then fold into (sequence, step, feature).
        self.temporal_data = self.temporal_data[:num_sequences * steps].reshape(num_sequences, steps, num_features)
        self.temporal_data = self.normalize_temporal_data(self.temporal_data)

        # --- output rasters ---------------------------------------------
        self.output_data_paths = [os.path.join(output_data_dir, f) for f in sorted(os.listdir(output_data_dir))]
        for output_path in self.output_data_paths:
            output_image = np.array(Image.open(output_path)).astype(np.float32)
            self.output_data.append(output_image)
        self.T_in = T_in
        self.T_out = T_out
        self.utci_max = None
        self.utci_min = None

        print("CustomDataset initialized successfully.")
        print(f"Number of spatial images: {len(spatial_files)}")
        print(f"Temporal data shape after reshape: {self.temporal_data.shape}")
        print(f"Number of output images (UTCI): {len(self.output_data_paths)}")
        print(f"T_in: {T_in}, T_out: {T_out}")

        self.compute_utci_global_max_min()
        self.num_samples = self.temporal_data.shape[0] * (steps - (self.T_in + self.T_out - 1))
        print(f"Calculated num_samples: {self.num_samples}")

        if self.num_samples <= 0:
            raise ValueError("Not enough time steps to generate input and output sequences")

    def normalize_temporal_data(self, temporal_data):
        """Min-max scale every feature column; divide the wind column by 360.

        Args:
            temporal_data: Array indexed as (sequence, step, feature).

        Returns:
            A new array of the same shape. Columns whose minimum equals their
            maximum map to 0 instead of producing NaN from a 0/0 division.
        """
        span = np.asarray(self.temporal_data_max - self.temporal_data_min)
        # A constant column has zero span; dividing by 1 leaves it at 0.
        span = np.where(span == 0, np.ones_like(span), span)
        normalized_temporal_data = (temporal_data - self.temporal_data_min) / span

        wind = self.WIND_DIRECTION_INDEX
        normalized_temporal_data[:, :, wind] = temporal_data[:, :, wind] / 360.0

        return normalized_temporal_data

    def compute_utci_global_max_min(self):
        """Store the extrema over all output rasters in ``utci_max``/``utci_min``."""
        for frame in self.output_data:
            current_max = float(np.max(frame))
            current_min = float(np.min(frame))

            if self.utci_max is None or current_max > self.utci_max:
                self.utci_max = current_max
            if self.utci_min is None or current_min < self.utci_min:
                self.utci_min = current_min

        print(f"UTCI_Max: {self.utci_max}, UTCI_Min: {self.utci_min}")

    def __len__(self):
        """Return the number of sliding windows over all temporal sequences."""
        return self.num_samples

    def _utci_window(self, start, length, top, left):
        """Return ``length`` scaled output patches starting at frame ``start``.

        The patch is cropped before it is scaled, so only PATCH_SIZE**2 values
        are touched per frame instead of the whole raster. For crops that lie
        inside the raster the values equal those of scaling first.

        Returns:
            Tensor laid out as (patch_h, patch_w, 1, length) for 2-D rasters.
        """
        size = self.PATCH_SIZE
        patches = []
        for t in range(length):
            # Explicit indexing keeps the IndexError for out-of-range frames.
            frame = torch.as_tensor(self.output_data[start + t]).unsqueeze(0)
            patch = TF.crop(frame, top=top, left=left, height=size, width=size)
            patch, _, _ = self.maxminscaler_3d(patch, self.utci_max, self.utci_min)
            patches.append(patch)
        return torch.stack(patches).permute(2, 3, 1, 0)

    def __getitem__(self, idx):
        """Return one randomly located patch for the window numbered ``idx``.

        The patch position is drawn with ``random`` on every call, so the same
        ``idx`` yields different crops across calls.

        Returns:
            Tuple ``(spatial_data_seq, output_data, temporal_data)``:
            the static layers repeated over ``T_in`` steps and concatenated
            with the scaled input UTCI patches along dim 2, the scaled target
            UTCI patches for the following ``T_out`` steps, and the
            ``T_in`` rows of scaled temporal features.
        """
        size = self.PATCH_SIZE
        h_idx = random.randint(0, self.CROP_HEIGHT - size - 1)
        w_idx = random.randint(0, self.CROP_WIDTH - size - 1)

        windows_per_sequence = self.STEPS_PER_SEQUENCE - (self.T_in + self.T_out - 1)
        sample_idx, time_idx = divmod(idx, windows_per_sequence)

        spatial_data_seq = TF.crop(self.spatial_data, top=h_idx, left=w_idx, height=size, width=size)
        spatial_data_seq = spatial_data_seq.unsqueeze(-1).repeat(1, 1, 1, self.T_in)  # [4, H, W, T_in]
        spatial_data_seq = spatial_data_seq.permute(1, 2, 0, 3)  # [H, W, 4, T_in]

        temporal_data = torch.tensor(self.temporal_data[sample_idx, time_idx:time_idx + self.T_in], dtype=torch.float32)

        # NOTE(review): output frames are addressed by time_idx only, so every
        # temporal sequence (sample_idx) is paired with the same leading
        # output frames. Confirm this matches the on-disk layout.
        output_data = self._utci_window(time_idx + self.T_in, self.T_out, h_idx, w_idx)
        utci_input = self._utci_window(time_idx, self.T_in, h_idx, w_idx)

        spatial_data_seq = torch.cat([spatial_data_seq, utci_input], dim=2)

        # Everything above is built on the CPU, so no device transfer is needed.
        return spatial_data_seq, output_data, temporal_data

    def maxminscaler_3d(self, tensor_3d, scaler_max=None, scaler_min=None, range=(0, 1)):
        """Min-max scale ``tensor_3d`` into ``range``.

        Args:
            tensor_3d: Tensor to scale.
            scaler_max: Upper reference value; defaults to ``tensor_3d.max()``.
            scaler_min: Lower reference value; defaults to ``tensor_3d.min()``.
            range: Target ``(low, high)`` interval. The name shadows the
                builtin inside this method only and is kept for callers.

        Returns:
            Tuple ``(scaled, scaler_max, scaler_min)``. When both references
            are equal the scaled values are ``low`` rather than NaN.
        """
        if scaler_max is None:
            scaler_max = tensor_3d.max()
        if scaler_min is None:
            scaler_min = tensor_3d.min()
        span = scaler_max - scaler_min
        if span == 0:
            # Constant input: avoid 0/0 and map everything to the lower bound.
            span = 1
        low, high = range[0], range[1]
        X_scaled = (tensor_3d - scaler_min) / span * (high - low) + low
        return X_scaled, scaler_max, scaler_min

### Construct a spatiotemporal Transformer-based decoder that leverages self-attention and cross-attention mechanisms to model spatial and temporal dependencies for structured sequence prediction tasks

In [None]:
import torch
import torch.nn as nn

In [None]:
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention over the second-to-last axis.

    Inputs are shaped ``(batch, ..., length, model_dim)``; every leading axis
    is treated as an independent batch entry. Key/value may have a different
    length than the query as long as their leading axes match.

    Args:
        model_dim: Feature size of inputs and outputs.
        num_heads: Number of attention heads; must divide ``model_dim``.
        mask: When True, position ``i`` of the query cannot attend to key
            positions ``j > i``.
    """

    def __init__(self, model_dim, num_heads=4, mask=False):
        super().__init__()

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask

        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"
        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)

    def _split_heads(self, projected, length):
        """Fold leading axes into one and expose heads: (N, heads, length, head_dim)."""
        return projected.reshape(-1, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: Tensor ``(batch, ..., q_len, model_dim)``.
            key: Tensor ``(batch, ..., kv_len, model_dim)``.
            value: Tensor ``(batch, ..., kv_len, model_dim)``.

        Returns:
            Tensor with the shape of ``query``.
        """
        leading_shape = query.shape[:-2]
        q_len = query.size(-2)
        # Key/value carry their own length; reusing q_len here made any
        # non-square attention fail at the reshape.
        kv_len = key.size(-2)

        Q = self._split_heads(self.FC_Q(query), q_len)
        K = self._split_heads(self.FC_K(key), kv_len)
        V = self._split_heads(self.FC_V(value), kv_len)

        # Scaled dot-product attention: (N, num_heads, q_len, kv_len)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.head_dim ** 0.5

        if self.mask:
            # Strict upper triangle marks "future" key positions.
            future = torch.triu(torch.ones(q_len, kv_len, device=query.device), diagonal=1).bool()
            attn_scores.masked_fill_(future, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (N, num_heads, q_len, head_dim)

        # Merge heads and restore the original leading axes in one step.
        attn_output = attn_output.transpose(1, 2).reshape(*leading_shape, q_len, self.model_dim)

        return self.out_proj(attn_output)


class SelfAttentionLayer(nn.Module):
    """Post-norm Transformer block applied along a selectable axis.

    The block runs self-attention followed by a two-layer ReLU MLP; each
    sub-layer is wrapped as ``LayerNorm(input + Dropout(sublayer(input)))``.

    Args:
        model_dim: Feature size of the input and output.
        feed_forward_dim: Hidden width of the MLP.
        num_heads: Number of attention heads.
        dropout: Dropout probability used after both sub-layers.
        mask: Forwarded to ``AttentionLayer``.
    """

    def __init__(self, model_dim, feed_forward_dim=2048, num_heads=4, dropout=0.1, mask=False):
        super().__init__()

        self.attn = AttentionLayer(model_dim, num_heads=num_heads, mask=mask)
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),
        )
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        """Apply the block with attention running over axis ``dim`` of ``x``.

        Args:
            x: Tensor whose last axis has size ``model_dim``.
            dim: Axis to attend over; it is swapped into the second-to-last
                position for the computation and swapped back afterwards.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention operates on the second-to-last axis, so move `dim` there.
        sequence = x.transpose(dim, -2)

        attended = self.dropout1(self.attn(sequence, sequence, sequence))
        hidden = self.ln1(sequence + attended)

        expanded = self.dropout2(self.feed_forward(hidden))
        hidden = self.ln2(hidden + expanded)

        # Undo the axis swap so the caller sees the original layout.
        return hidden.transpose(dim, -2)


class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention from a query sequence to a key/value sequence.

    Args:
        model_dim: Feature size of all inputs and of the output.
        num_heads: Number of attention heads; must divide ``model_dim``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, model_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads

        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"

        # One learned projection per role, plus the output mixing layer.
        self.query_proj = nn.Linear(model_dim, model_dim)
        self.key_proj = nn.Linear(model_dim, model_dim)
        self.value_proj = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Return the attended representation of ``query``.

        Args:
            query: Tensor ``(rows, q_len, model_dim)``.
            key: Tensor ``(rows, kv_len, model_dim)``.
            value: Tensor ``(rows, kv_len, model_dim)``.

        Returns:
            Tensor ``(rows, q_len, model_dim)``.
        """
        rows, q_len = query.size(0), query.size(1)
        per_head = (self.num_heads, self.head_dim)

        # Project, then expose the head axis: (rows, heads, length, head_dim).
        q_heads = self.query_proj(query).view(rows, q_len, *per_head).transpose(1, 2)
        k_heads = self.key_proj(key).view(rows, -1, *per_head).transpose(1, 2)
        v_heads = self.value_proj(value).view(rows, -1, *per_head).transpose(1, 2)

        # Scaled dot-product scores, normalised over the key axis.
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / self.head_dim ** 0.5
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v_heads)

        # Put heads back next to their features and merge them.
        merged = context.transpose(1, 2).contiguous().view(rows, q_len, self.model_dim)
        return self.out_proj(merged)


class DecoderLayer(nn.Module):
    """Decoder block: masked temporal self-attention, cross-attention, MLP.

    Inputs are 4-D: ``x`` is ``(batch, T_out, H_W, model_dim)`` and ``memory``
    is ``(batch, T_in, H_W, model_dim)``. Every sub-layer output is added to
    its input and passed through its own LayerNorm.

    Notes:
        * ``self.self_attn`` is a full ``SelfAttentionLayer`` (which already
          contains a residual, LayerNorm and MLP), so ``ln1`` normalises a
          second residual sum on top of it.
        * ``self.dropout`` is created but not applied in ``forward``.
    """

    def __init__(self, model_dim, num_heads, feed_forward_dim, dropout):
        super().__init__()
        self.self_attn = SelfAttentionLayer(
            model_dim,
            feed_forward_dim=feed_forward_dim,
            num_heads=num_heads,
            dropout=dropout,
            mask=True,
        )
        self.cross_attn = CrossAttentionLayer(model_dim, num_heads=num_heads, dropout=dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),
        )
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.ln3 = nn.LayerNorm(model_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory):
        """Run the three sub-layers and return a tensor shaped like ``x``.

        Args:
            x: Decoder input ``(batch, T_out, H_W, model_dim)``.
            memory: Encoder output ``(batch, T_in, H_W, model_dim)``.
        """
        batch_size, T_out, H_W, model_dim = x.size()
        _, T_in, H_W_mem, model_dim_mem = memory.size()
        assert H_W == H_W_mem, "Spatial dimensions must match between x and memory"
        assert model_dim == model_dim_mem, "Model dimensions must match"

        def per_location(tensor, steps):
            # (batch, steps, H_W, D) -> one temporal sequence per spatial cell.
            return tensor.permute(0, 2, 1, 3).reshape(batch_size * H_W, steps, model_dim)

        # 1) Masked self-attention along the time axis.
        x = self.ln1(self.self_attn(x, dim=1) + x)

        # 2) Each location's decoder steps attend to that location's memory.
        keys_values = per_location(memory, T_in)
        attended = self.cross_attn(per_location(x, T_out), keys_values, keys_values)
        attended = attended.reshape(batch_size, H_W, T_out, model_dim).permute(0, 2, 1, 3)
        x = self.ln2(attended + x)

        # 3) Position-wise feed-forward network.
        return self.ln3(self.feed_forward(x) + x)


class SpatialTemporalTransformer_Decoder(nn.Module):
    """Spatial + temporal self-attention over gridded and series inputs.

    The gridded input is embedded, attended over space and then over time;
    the series input is embedded and attended over time; both are summed and
    projected to ``C_out`` channels. The output keeps ``T_in`` time steps.

    Args:
        H, W: Grid height and width of the gridded input.
        C_in: Channels of the gridded input.
        C_temp: Features of the series input.
        T_in: Number of time steps in both inputs.
        C_out: Output channels.
        hidden_dim: Embedding size used by all attention layers.
        num_heads: Attention heads per layer.
        num_layers: Number of layers in each of the three attention stacks.
        dropout: Dropout probability inside the attention layers.
    """

    def __init__(self, H, W, C_in, C_temp, T_in, C_out,
                 hidden_dim=64, num_heads=4, num_layers=1, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.C_out = C_out
        self.T_in = T_in
        self.H = H
        self.W = W
        self.H_W = H * W
        self.C_temp = C_temp  # feature count of X_temp
        self.C_in = C_in      # channel count of X_in

        def attention_stack():
            # All three stacks share one configuration.
            return nn.ModuleList([
                SelfAttentionLayer(self.hidden_dim, feed_forward_dim=hidden_dim*4, num_heads=num_heads, dropout=dropout)
                for _ in range(num_layers)
            ])

        # Embeddings of the two inputs into the shared hidden size.
        self.fc_in = nn.Linear(self.C_in, self.hidden_dim)
        self.fc_temp = nn.Linear(self.C_temp, self.hidden_dim)

        # Attention over space and time for X_in, and over time for X_temp.
        self.spatial_attn_layers = attention_stack()
        self.temporal_attn_layers = attention_stack()
        self.temp_temporal_attn_layers = attention_stack()

        # Final projection to the output channels.
        self.output_proj = nn.Linear(self.hidden_dim, self.C_out)

    def forward(self, X_in, X_temp):
        """Compute the prediction grid.

        Args:
            X_in: Tensor ``(batch, H, W, C_in, T_in)``.
            X_temp: Tensor ``(batch, T_in, C_temp)``.

        Returns:
            Tensor ``(batch, H, W, C_out, T_in)``.
        """
        batch_size = X_in.size(0)
        H, W, T_in = self.H, self.W, self.T_in
        H_W = H * W

        # Gridded branch: time-major layout with the grid flattened.
        grid = X_in.permute(0, 4, 1, 2, 3).reshape(batch_size, T_in, H_W, self.C_in)
        grid = self.fc_in(grid)  # (batch, T_in, H*W, hidden_dim)
        for block in self.spatial_attn_layers:
            grid = block(grid, dim=2)
        for block in self.temporal_attn_layers:
            grid = block(grid, dim=1)

        # Series branch: attention along time only.
        series = self.fc_temp(X_temp)  # (batch, T_in, hidden_dim)
        for block in self.temp_temporal_attn_layers:
            series = block(series, dim=1)

        # Share each time step's series embedding with every grid cell.
        series_on_grid = series.unsqueeze(2).expand(-1, -1, H_W, -1)
        fused = grid + series_on_grid  # (batch, T_in, H*W, hidden_dim)

        projected = self.output_proj(fused)  # (batch, T_in, H*W, C_out)
        projected = projected.view(batch_size, T_in, H, W, self.C_out)
        return projected.permute(0, 2, 3, 4, 1)  # (batch, H, W, C_out, T_in)

### Train the Spatiotemporal Vision Transformer model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm  
import numpy as np
import gc

from dataloader import CustomDataset
from Transformer import SpatialTemporalTransformer_Decoder
from utils.utils import replace_w_sync_bn, CustomDataParallel

# Free unreachable Python objects, then ask PyTorch to release its unused
# cached CUDA memory before training allocates anything.
gc.collect()
torch.cuda.empty_cache()

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def split_dataset_by_time(spatial_data_dir, temporal_data_csv, output_data_dir, train_ratio=0.7, val_ratio=0.15, T_in=24, T_out=24):
    """Build the dataset and return its leading train and validation slices.

    Samples are split by index order, not shuffled: the first
    ``train_ratio`` share becomes the training subset and the following
    ``val_ratio`` share the validation subset. Whatever remains after both is
    left unused by this function.

    Args:
        spatial_data_dir: Directory with the static input rasters.
        temporal_data_csv: CSV file with the time series.
        output_data_dir: Directory with the output rasters.
        train_ratio: Fraction of samples assigned to training.
        val_ratio: Fraction of samples assigned to validation.
        T_in: Input window length passed to ``CustomDataset``.
        T_out: Output window length passed to ``CustomDataset``.

    Returns:
        Tuple ``(train_dataset, val_dataset)`` of ``Subset`` objects that
        share one underlying ``CustomDataset``.
    """
    dataset = CustomDataset(spatial_data_dir, temporal_data_csv, output_data_dir, T_in=T_in, T_out=T_out)
    total_samples = len(dataset)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)

    train_dataset = Subset(dataset, list(range(train_size)))
    val_dataset = Subset(dataset, list(range(train_size, train_size + val_size)))

    return train_dataset, val_dataset

class EarlyStopping:
    """Track a validation loss and flag when it has stopped improving.

    Call the instance with each new validation loss. ``early_stop`` turns
    True once ``patience`` consecutive calls fail to beat the best loss by
    more than ``min_delta``.

    Args:
        patience: Non-improving calls tolerated before stopping is flagged.
        min_delta: Minimum decrease that counts as an improvement.
    """

    def __init__(self, patience=10, min_delta=0.0005):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None   # lowest loss seen so far
        self.counter = 0        # consecutive calls without improvement
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        # The very first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_and_evaluate_model(model, train_loader, val_loader, num_epochs=50, learning_rate=0.0001, log_dir='logs/1010_model', patience=10, min_delta=0.0005):
    """Train ``model`` with Adam on MSE, validating and checkpointing every 10 epochs.

    Every 10th epoch (starting with epoch 0) the wrapped model's state dict
    is saved, the validation loss is computed and logged, and the early
    stopping tracker is updated. After training the unwrapped model is saved
    and MSE/RMSE are printed for both loaders.

    Args:
        model: Network called as ``model(spatial_data, temporal_data)``.
        train_loader: Loader yielding ``(spatial, target, temporal)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Maximum number of epochs.
        learning_rate: Adam learning rate.
        log_dir: TensorBoard log directory.
        patience: Number of consecutive non-improving *validation rounds*
            (one round per 10 epochs) tolerated before training stops.
        min_delta: Minimum validation-loss decrease counted as improvement.
    """
    writer = SummaryWriter(log_dir)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

    model.apply(replace_w_sync_bn)
    model = CustomDataParallel(model, 2)
    model.to(device)
    # Build the optimizer only after the model has been rewritten, wrapped and
    # moved, so it is guaranteed to reference the parameters actually trained.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    def mean_mse(loader):
        """Return the batch-averaged MSE of ``model`` over ``loader``."""
        model.eval()
        total = 0.0
        with torch.no_grad():
            for spatial_data, output_data, temporal_data in loader:
                spatial_data = spatial_data.to(device, non_blocking=True)
                output_data = output_data.to(device, non_blocking=True)
                temporal_data = temporal_data.to(device, non_blocking=True)
                outputs = model(spatial_data, temporal_data)
                total += criterion(outputs, output_data).item()
        return total / len(loader)

    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0

            for spatial_data, output_data, temporal_data in tqdm(train_loader):
                spatial_data = spatial_data.to(device, non_blocking=True)
                output_data = output_data.to(device, non_blocking=True)
                temporal_data = temporal_data.to(device, non_blocking=True)
                optimizer.zero_grad()
                outputs = model(spatial_data, temporal_data)
                loss = criterion(outputs, output_data)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            avg_train_loss = running_loss / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)

            if epoch % 10 == 0:
                # NOTE(review): this checkpoint is taken from the wrapper, so
                # its keys presumably carry the wrapper's prefix, unlike the
                # final checkpoint below which is saved from `model.module`.
                torch.save(model.state_dict(), '../output/1010_model_trained_{}.pth'.format(epoch))

                avg_val_loss = mean_mse(val_loader)
                writer.add_scalar('Loss/val', avg_val_loss, epoch)
                print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

                # Previously the tracker was created but never consulted, so
                # `patience` and `min_delta` had no effect.
                early_stopping(avg_val_loss)
                if early_stopping.early_stop:
                    print(f'Early stopping triggered at epoch {epoch + 1}.')
                    break
    finally:
        # Flush TensorBoard events even when training is interrupted.
        writer.close()

    torch.save(model.module.state_dict(), '../output/1010_model_trained.pth')
    print("Training complete, model saved.")

    def calculate_metrics(loader, dataset_type):
        """Print MSE and RMSE of the trained model on ``loader``."""
        avg_loss = mean_mse(loader)
        rmse = np.sqrt(avg_loss)
        print(f"{dataset_type} - MSE: {avg_loss:.4f}, RMSE: {rmse:.4f}")

    calculate_metrics(train_loader, "Training Set")
    calculate_metrics(val_loader, "Validation Set")


# --- Model / window hyper-parameters --------------------------------------
H = W = 64          # patch height and width fed to the model
C_in = 5            # channels of the gridded input
C_temp = 7          # features of the temporal input
T_in = T_out = 24   # input and output window lengths
C_out = 1           # predicted channels

# --- Data locations -------------------------------------------------------
spatial_data_dir = '../data/spatial_images'
temporal_data_csv = '../data/weather data.csv'
output_data_dir = '../data/output_images'

# --- Datasets and loaders -------------------------------------------------
train_dataset, val_dataset = split_dataset_by_time(
    spatial_data_dir, temporal_data_csv, output_data_dir
)

# Only the training loader shuffles; both drop the last incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,
    pin_memory=True,
    num_workers=10,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=10,
    shuffle=False,
    pin_memory=True,
    num_workers=10,
    drop_last=True,
)

# --- Model and training run -----------------------------------------------
model = SpatialTemporalTransformer_Decoder(
    H=H,
    W=W,
    C_in=C_in,
    C_temp=C_temp,
    C_out=C_out,
    T_in=T_in,
    hidden_dim=12,
    num_heads=2,
    num_layers=1,
    dropout=0.1,
)

train_and_evaluate_model(model, train_loader, val_loader, num_epochs=200, learning_rate=0.0001)

### Evaluate the model performance

In [None]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Subset
import os
import tifffile
import csv
from dataloader import CustomDataset
from Transformer import SpatialTemporalTransformer_Decoder

In [None]:
def inverse_scaler(tensor, scaler_max, scaler_min):
    """Map values scaled to [0, 1] back to the ``[scaler_min, scaler_max]`` range.

    Args:
        tensor: Scaled value(s); any type supporting ``*`` and ``+``.
        scaler_max: Value that 1 maps to.
        scaler_min: Value that 0 maps to.

    Returns:
        ``tensor`` stretched by the span and shifted by ``scaler_min``.
    """
    span = scaler_max - scaler_min
    return tensor * span + scaler_min

def save_predictions_to_csv(predictions, ground_truths, output_file, num_nodes, T_out):
    """Write one CSV row per node holding ``"prediction,ground_truth"`` cells.

    Each row starts with the 1-based node id followed by ``T_out`` cells, one
    per time step. The file has no header and is overwritten if it exists.

    Args:
        predictions: 2-D array indexed as ``[node, step]``.
        ground_truths: 2-D array with the same indexing.
        output_file: Destination path.
        num_nodes: Number of rows to write.
        T_out: Number of time steps per row.
    """
    with open(output_file, mode='w', newline='') as handle:
        csv_writer = csv.writer(handle)
        for node in range(num_nodes):
            # Both values share a cell; the csv writer quotes it for us.
            cells = [
                f"{predictions[node, step]},{ground_truths[node, step]}"
                for step in range(T_out)
            ]
            csv_writer.writerow([f"{node+1}", *cells])

def _to_time_major_maps(batch_array):
    """Convert a ``(B, H, W, 1, T)`` array into ``(B, T, H, W)``."""
    if batch_array.ndim != 5:
        raise ValueError(f"expected a 5-D (B, H, W, C, T) array, got shape {batch_array.shape}")
    batch_array = np.transpose(batch_array, (0, 4, 1, 2, 3))  # (B, T, H, W, C)
    if batch_array.shape[-1] != 1:
        raise ValueError(f"expected a single channel, got {batch_array.shape[-1]}")
    return np.squeeze(batch_array, axis=-1)


def _node_rows(sample_maps, T_out):
    """Flatten one ``(T, H, W)`` sample into ``(H*W, T_out)`` float64 rows.

    Rows are ordered row-major over the grid (``h * W + w``).
    """
    if sample_maps.shape[0] < T_out:
        raise ValueError(f"sample has {sample_maps.shape[0]} time steps, need {T_out}")
    # float64 matches the precision of converting each value with float().
    maps = sample_maps[:T_out].astype(np.float64)
    return maps.transpose(1, 2, 0).reshape(-1, T_out)


def predict_and_save(model, dataloader, output_file, num_nodes, T_out, device='cuda'):
    """Run ``model`` over ``dataloader`` and write de-normalised results to CSV.

    For every batch, one patch worth of rows (``H * W`` nodes) is stored,
    holding the de-normalised prediction and ground truth for ``T_out`` steps.
    Rows beyond ``num_nodes`` are discarded.

    Args:
        model: Network called as ``model(spatial_data, temporal_data)`` and
            returning ``(B, H, W, 1, T)``.
        dataloader: Loader over a ``Subset`` of a dataset exposing
            ``utci_max`` / ``utci_min``.
        output_file: CSV path handed to ``save_predictions_to_csv``.
        num_nodes: Total number of rows to allocate and write.
        T_out: Number of time steps stored per row.
        device: Device the batches are moved to.

    Raises:
        ValueError: If a batch is not ``(B, H, W, 1, T)`` or has fewer than
            ``T_out`` time steps.
    """
    model.eval()
    predictions_list = np.zeros((num_nodes, T_out))
    ground_truth_list = np.zeros((num_nodes, T_out))
    utci_max = dataloader.dataset.dataset.utci_max
    utci_min = dataloader.dataset.dataset.utci_min
    print(f"Max UTCI: {utci_max}, Min UTCI: {utci_min}")
    node_counter = 0

    with torch.no_grad():
        for spatial_data, output_data, temporal_data in dataloader:
            spatial_data = spatial_data.to(device)
            output_data = output_data.to(device)
            temporal_data = temporal_data.to(device)
            print(np.shape(spatial_data), np.shape(temporal_data))
            predictions = model(spatial_data, temporal_data)
            predictions_np = predictions.cpu().detach().numpy()
            output_data_np = output_data.cpu().detach().numpy()
            print(f"Predictions shape: {predictions_np.shape}")
            print(f"Output data shape before transpose: {output_data_np.shape}")

            # Both arrays get the same layout fix. The old code tested the
            # wrong axis for the predictions and never dropped their channel.
            predictions_np = _to_time_major_maps(predictions_np)
            output_data_np = _to_time_major_maps(output_data_np)

            _, _, height, width = predictions_np.shape
            nodes_per_patch = height * width
            rows_to_write = min(nodes_per_patch, num_nodes - node_counter)

            if rows_to_write > 0:
                # NOTE(review): the node offset advances once per batch, so in
                # the original loop every sample of a batch wrote to the same
                # rows and only the last one survived. That outcome is kept
                # here; confirm whether per-sample rows were intended.
                rows = slice(node_counter, node_counter + rows_to_write)
                predictions_list[rows] = inverse_scaler(
                    _node_rows(predictions_np[-1], T_out)[:rows_to_write], utci_max, utci_min
                )
                ground_truth_list[rows] = inverse_scaler(
                    _node_rows(output_data_np[-1], T_out)[:rows_to_write], utci_max, utci_min
                )

            node_counter += nodes_per_patch
    print(np.shape(ground_truth_list))
    save_predictions_to_csv(predictions_list, ground_truth_list, output_file, num_nodes, T_out)

def split_dataset_by_time(spatial_data_dir, temporal_data_csv, output_data_dir, train_ratio=0.4, val_ratio=0.3, T_in=24, T_out=24):
    """Build the dataset and return its trailing test slice.

    Samples are split by index order; everything after the first
    ``train_ratio + val_ratio`` share is returned as the test subset.

    NOTE(review): the defaults here (0.4 / 0.3) differ from the training
    cell's split (0.7 / 0.15), so this test slice starts where that cell's
    validation slice starts. Confirm the overlap is intended.

    Args:
        spatial_data_dir: Directory with the static input rasters.
        temporal_data_csv: CSV file with the time series.
        output_data_dir: Directory with the output rasters.
        train_ratio: Fraction of samples skipped as training data.
        val_ratio: Fraction of samples skipped as validation data.
        T_in: Input window length passed to ``CustomDataset``.
        T_out: Output window length passed to ``CustomDataset``.

    Returns:
        ``Subset`` covering the remaining samples.
    """
    dataset = CustomDataset(spatial_data_dir, temporal_data_csv, output_data_dir, T_in=T_in, T_out=T_out)
    total_samples = len(dataset)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)

    test_indices = list(range(train_size + val_size, total_samples))
    return Subset(dataset, test_indices)

# --- Data locations -------------------------------------------------------
spatial_data_dir = '../data/spatial_images'
temporal_data_csv = '../data/weather data.csv'
output_data_dir = '../data/output_images'

test_dataset = split_dataset_by_time(spatial_data_dir, temporal_data_csv, output_data_dir)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# --- Model hyper-parameters (must match the trained checkpoint) ------------
H = 64
W = 64
C_in = 5
T_in = 24
C_temp = 7
T_out = 24
C_out = 1
model = SpatialTemporalTransformer_Decoder(H=H, W=W, C_in=C_in, C_temp=C_temp, C_out=C_out,
                                   T_in=T_in, hidden_dim=12, num_heads=2,
                                   num_layers=1, dropout=0.1)

checkpoint = torch.load('../output/1010_model_trained_40.pth')
# Keys saved from a parallel wrapper start with "module."; strip that prefix
# only when it is present so checkpoints saved from the bare model
# (e.g. the final '1010_model_trained.pth') load unchanged as well.
new_state_dict = {
    (k[7:] if k.startswith('module.') else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(new_state_dict)
model.to('cuda')

# One patch (64 x 64 nodes) is stored per batch of the test loader.
num_nodes = 64 * 64 * len(test_loader)
output_file = '../output/1010_test_predictions.csv'
predict_and_save(model, test_loader, output_file, num_nodes, T_out)

In [None]:
def calculate_accuracy(file_name):
    """Compute and print global MAPE, MAE and RMSE from a prediction CSV.

    The file has no header; column 0 is the node id and every other cell is a
    ``"prediction,ground_truth"`` string. Blank cells are skipped.

    MAE and RMSE use every pair. MAPE skips pairs whose ground truth is 0,
    because the relative error is undefined there.

    Args:
        file_name: Path to the CSV file.

    Returns:
        Tuple ``(mape, mae, rmse)``; a metric is ``nan`` when no pair
        contributed to it.
    """
    data = pd.read_csv(file_name, header=None)
    sum_ape = sum_abs = sum_sq = 0.0
    num_ape = num_pairs = 0

    # Column 0 is the node id; all remaining cells hold "pred,gt" strings.
    for cell in data.iloc[:, 1:].to_numpy().ravel():
        # pandas reads blank cells as NaN, not as an empty string.
        if pd.isna(cell) or cell == '':
            continue
        pred, gt = map(float, cell.split(','))
        error = pred - gt

        sum_abs += abs(error)
        sum_sq += error ** 2
        num_pairs += 1

        if gt != 0:
            sum_ape += abs(error) / abs(gt)
            num_ape += 1

    global_mape = sum_ape / num_ape if num_ape != 0 else np.nan
    global_mae = sum_abs / num_pairs if num_pairs != 0 else np.nan
    global_rmse = (sum_sq / num_pairs) ** 0.5 if num_pairs != 0 else np.nan
    print(f"Global MAPE: {global_mape:.6f}, Global MAE: {global_mae:.6f}, Global RMSE: {global_rmse:.6f}")
    return global_mape, global_mae, global_rmse

# Script entry point: score the prediction CSV at the hard-coded path below
# (the same path the prediction step writes to) and print the global metrics.
if __name__ == "__main__":
    calculate_accuracy('../output/1010_test_predictions.csv')

Global MAPE: 0.051518, Global MAE: 1.638521, Global RMSE: 2.144549
