In [1]:
import importlib
import numpy as np
import torch
import torch.nn as nn
import Model_Structure_pkg.CNN_Transformer_Model.model.CNN_transformer as CNN_transformer
import Model_Structure_pkg.CNN_Transformer_Model.model.decoder as decoder
import Model_Structure_pkg.CNN_Transformer_Model.model.encoder as encoder

# Reload all submodules in correct dependency order
importlib.reload(decoder)
importlib.reload(encoder)
importlib.reload(CNN_transformer)


<module 'Model_Structure_pkg.CNN_Transformer_Model.model.CNN_transformer' from '/Volumes/rvmartin2/Active/s.siyuan/Projects/Daily_PM25_DL_2024/code/Training_Validation_Estimation/PM25/v0.3.0/Model_Structure_pkg/CNN_Transformer_Model/model/CNN_transformer.py'>

In [2]:
import importlib
import Model_Structure_pkg.Transformer_Model
import Model_Structure_pkg.CNN_Transformer_Model.model.decoder as decoder
import Model_Structure_pkg.CNN_Transformer_Model.model.encoder as encoder
import Model_Structure_pkg.CNN_Transformer_Model.model.CNN_transformer as CNN_Transformer

# Reload all submodules in correct dependency order
importlib.reload(decoder)
importlib.reload(encoder)
importlib.reload(CNN_Transformer)

## Prepare some datasets as targets and inputs for the Transformer model
CNN_input_dim = 4  # Number of input features (e.g., channels)
input_dim = 5  # Number of input features (e.g., channels)
trg_dim = 1    # Number of target features (e.g., PM2.5)
d_model = 64   # Dimension of the model (hidden size)
n_head = 8     # Number of attention heads
ffn_hidden = 256  # Dimension of the feed-forward network hidden layer
num_layers = 2  # Number of encoder/decoder layers
max_len = 40  # Maximum length of the input sequence
drop_prob = 0.1  # Dropout probability

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer_model = CNN_Transformer.CNN_Transformer(CNN_input_dim, input_dim,trg_dim, d_model, n_head, ffn_hidden, num_layers, max_len, drop_prob,device)

#prepare some input daata and target data batchs
batch_size = 16  # Number of samples in a batch
seq_length = 32  # Length of each input sequence
CNN_input_data = torch.randn(batch_size, seq_length, CNN_input_dim,11,11).to(device)  # Input data shape: (batch_size, seq_length, input_dim)
input_data = torch.randn(batch_size, seq_length, input_dim).to(device)  # Input data shape: (batch_size, seq_length, input_dim)
input_data = input_data + torch.Tensor(np.arange(seq_length)[:, np.newaxis]).to(device)  # Add a positional encoding
input_sum = torch.sum(input_data, dim=-1, keepdim=True)  # Sum over the last dimension

#Create a target_data tensor for learning, suitable for transformer
target_data = torch.Tensor(torch.square(input_sum)+0.32*input_sum+0.02).to(device)  # Target data shape: (batch_size, seq_length, trg_dim)

print('size of target_data:', target_data.size())
## select some random elements in target_data to nan values
nan_indices = np.random.choice(seq_length, size=int(seq_length * 0.1), replace=False)  # 10% of the target data will be NaN
target_data[:, nan_indices, :] = float('nan')  # Set selected indices to NaN

# Forward pass through the Transformer model
output = transformer_model(CNN_input_data, input_data, target_data)
# Output shape: (batch_size, seq_length, trg_dim)
print("Output shape:", output.shape)
# Check if the output shape matches the expected target shape
if output.shape == target_data.shape:
    print("Output shape matches target shape.")
else:
    print("Output shape does not match target shape.")






Transformer initialized with parameters:
CNN Input Dimension: 4, Transformer Input Dimension: 517, Target Dimension: 1, d_model: 64, n_head: 8, ffn_hidden: 256, num_layers: 2, max_len: 40, drop_prob: 0.1
size of target_data: torch.Size([16, 32, 1])
Output shape: torch.Size([16, 32, 1])
Output shape matches target shape.


In [3]:


# Try to train a transformer model with the input and target data
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)  # Adam
# Training loop
num_epochs = 17  # Number of epochs for training
def r2_score(y_true, y_pred, mask):
    """
    Computes the R^2 score.
    """
    y_true = y_true[mask]
    y_pred = y_pred[mask]
    ss_total = ((y_true - y_true.mean()) ** 2).sum()
    ss_residual = ((y_true - y_pred) ** 2).sum()
    return 1 - (ss_residual / ss_total)

def masked_mse_loss(predictions, targets, mask):
    """
    Computes the mean squared error loss with a mask.
    
    predictions: (B, T, D)
    targets:     (B, T, D)
    mask:        (B, T) or (B, T, 1) with 1 for valid, 0 for invalid
    """
    # Ensure mask is broadcastable
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    mask = mask.expand_as(targets).float()  # (B, T, D)
    squared_error = (predictions - targets) ** 2
    #print('squared_error:', squared_error, 'mask:', mask)  # Debugging shapes
    #print('predictions:', predictions)
    masked_loss = squared_error * mask
    #print('squared_error:', squared_error[0,:,:],'mask:', mask[0,:,:], 'masked_loss:', masked_loss[0,:,:])
    loss = masked_loss.sum() / mask.sum().clamp(min=1e-8)  # avoid divide by zero
    return loss

for epoch in range(num_epochs):
    transformer_model.train()  # Set the model to training mode
    optimizer.zero_grad()  # Zero the gradients
    output = transformer_model(CNN_input_data,input_data, target_data)  # Forward pass
    #print('output:', output)
    
    mask = ~torch.isnan(target_data)  # Create a mask for valid target data (not NaN)
    filled_target_data = torch.nan_to_num(target_data, nan=0.0)  # Fill NaN values in target data with 0.0
    loss = masked_mse_loss(output, filled_target_data, mask)  # Compute the loss
    
    print('Epoch:', epoch, 'Loss:', loss.item())
    loss.backward()  # Backward pass
    optimizer.step()  # Update the model parameters
    # Calculate the R2 of output and target
    r2 = r2_score(filled_target_data.detach().cpu().numpy(), output.detach().cpu().numpy(), mask.detach().cpu().numpy())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, R2: {r2:.4f}")  # Print the loss for each epoch

# Save the trained model
#torch.save(transformer_model.state_dict(), 'transformer_model.pth') 
# Load the trained model
#transformer_model.load_state_dict(torch.load('transformer_model.pth'))
transformer_model.eval()  # Set the model to evaluation mode
# Test the model with a new input

cnn_test_input = torch.randn(batch_size, seq_length, CNN_input_dim,11,11).to(device)  # New CNN input data for testing
test_input = torch.randn(batch_size, seq_length, input_dim).to(device)  # New input data for testing
test_target = target_data = torch.sum(torch.square(test_input)+0.32*test_input+0.02, dim=-1).to(device).unsqueeze(-1)  # Target data shape: (batch_size, seq_length, trg_dim)
test_output = transformer_model(cnn_test_input, test_input)  # Forward pass with the test input
mask = ~torch.isnan(test_output)  # Create a mask for valid test output (not NaN)
r2 = r2_score(test_target.detach().cpu().numpy(), test_output.detach().cpu().numpy(), mask.detach().cpu().numpy())
print(f"Test R2 Score: {r2:.4f}")  # Print the R2 score for the test output

print("Test output shape:", test_output.shape)  # Output shape should be (batch_size, seq_length, trg_dim)
# Check if the test output shape matches the expected target shape
if test_output.shape == target_data.shape:
    print("Test output shape matches target shape.")
else:
    print("Test output shape does not match target shape.") 

Epoch: 0 Loss: 126517384.0
Epoch [1/17], Loss: 126517384.0000, R2: -1.2093
Epoch: 1 Loss: 126488896.0
Epoch [2/17], Loss: 126488896.0000, R2: -1.2088
Epoch: 2 Loss: 126471688.0
Epoch [3/17], Loss: 126471688.0000, R2: -1.2085
Epoch: 3 Loss: 126461840.0
Epoch [4/17], Loss: 126461840.0000, R2: -1.2083
Epoch: 4 Loss: 126455648.0
Epoch [5/17], Loss: 126455648.0000, R2: -1.2082
Epoch: 5 Loss: 126451936.0
Epoch [6/17], Loss: 126451936.0000, R2: -1.2081
Epoch: 6 Loss: 126449968.0
Epoch [7/17], Loss: 126449968.0000, R2: -1.2081
Epoch: 7 Loss: 126448696.0
Epoch [8/17], Loss: 126448696.0000, R2: -1.2081
Epoch: 8 Loss: 126447392.0
Epoch [9/17], Loss: 126447392.0000, R2: -1.2081
Epoch: 9 Loss: 126446352.0
Epoch [10/17], Loss: 126446352.0000, R2: -1.2080
Epoch: 10 Loss: 126445288.0
Epoch [11/17], Loss: 126445288.0000, R2: -1.2080
Epoch: 11 Loss: 126444144.0
Epoch [12/17], Loss: 126444144.0000, R2: -1.2080
Epoch: 12 Loss: 126442936.0
Epoch [13/17], Loss: 126442936.0000, R2: -1.2080
Epoch: 13 Loss: 12