In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [15]:
class QuantumTransformer(nn.Module):
    def __init__(self, d_model=128, num_heads=8, num_layers=4, dim_feedforward=256, dropout=0.1, seq_length=100, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """
        Transformer model for quantum wavefunction propagation.
        - d_model: Dimension of token embeddings.
        - num_heads: Number of attention heads.
        - num_layers: Number of Transformer layers.
        - dim_feedforward: Size of the feedforward network.
        - dropout: Dropout rate.
        - seq_length: Number of spatial grid points.
        """
        super(QuantumTransformer, self).__init__()
        
        self.d_model = d_model
        self.seq_length = seq_length

        # Input embedding (wavefunction + potential energy)
        self.embedding = nn.Linear(3, d_model)  # (Re(Ψ), Im(Ψ), V) → d_model
        
        # Positional encoding (adds spatial dependence)
        self.positional_encoding = self.create_positional_encoding(seq_length, d_model).to(device)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layer: Predicts next wavefunction (Re and Im parts)
        self.output_layer = nn.Linear(d_model, 2)  # Output (Re(Ψ), Im(Ψ))

    def forward(self, psi_real, psi_imag, potential):
        """
        Forward pass of the Transformer.
        Input:
            - psi_real: Real part of wavefunction (batch, seq_length)
            - psi_imag: Imaginary part of wavefunction (batch, seq_length)
            - potential: Potential energy (batch, seq_length)
        Output:
            - Next-step wavefunction (batch, seq_length, 2) → (Re(Ψ), Im(Ψ))
        """
        # Stack inputs to create feature vector at each spatial position
        x = torch.stack([psi_real, psi_imag, potential], dim=-1)  # Shape: (batch, seq_length, 3)

        # Apply linear embedding
        x = self.embedding(x)  # Shape: (batch, seq_length, d_model)

        # Add positional encoding
        x = x + self.positional_encoding[:x.shape[1], :].to(x.device)

        # Transformer Encoder
        x = self.transformer_encoder(x)

        # Output layer: Predict (Re(Ψ), Im(Ψ)) at t + Δt
        x = self.output_layer(x)

        return x


In [27]:
import torch
import torch.nn as nn
import math

class QuantumTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, dim_feedforward, dropout, seq_length, n_grid, device):
        super(QuantumTransformer, self).__init__()
        
        self.embedding = nn.Linear(3, d_model)  # (Re(Ψ), Im(Ψ), V) → d_model

        # Positional Encoding
        self.positional_encoding = self.create_positional_encoding(seq_length * n_grid, d_model).to(device)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layer (maps d_model back to (Re(Ψ), Im(Ψ)))
        self.output_layer = nn.Linear(d_model, 2)

    def create_positional_encoding(self, full_seq_length, d_model):
        """
        Generate a positional encoding tensor of shape (1, full_seq_length, d_model).
        """
        pos_encoding = torch.zeros(full_seq_length, d_model)
        position = torch.arange(0, full_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term)
        
        return pos_encoding.unsqueeze(0)  # Shape: (1, full_seq_length, d_model)

    def forward(self, psi_real, psi_imag, potential):
        # Stack input features along last dimension
        x = torch.stack((psi_real, psi_imag, potential), dim=-1)  # Shape: (batch, seq_length, n_grid, 3)
        
        # Flatten the n_grid dimension into the sequence dimension
        batch_size, seq_length, n_grid, _ = x.shape
        x = x.view(batch_size, seq_length * n_grid, 3)

        # Apply embedding
        x = self.embedding(x)  # Shape: (batch, seq_length * n_grid, d_model)
        
        # Add positional encoding (matching seq_length)
        x = x + self.positional_encoding[:, :x.shape[1], :].to(x.device)

        # Transformer Encoder
        x = self.transformer_encoder(x)

        # Apply output layer
        x = self.output_layer(x)  # Shape: (batch, seq_length * n_grid, 2)

        # Reshape back to (batch, seq_length, n_grid, 2)
        x = x.view(batch_size, seq_length, n_grid, 2)
        
        return x


In [28]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class QuantumWaveDataset(Dataset):
    def __init__(self, h5_file):
        """
        Custom PyTorch Dataset for quantum wavefunction propagation.
        - h5_file: Path to the .h5 file containing dataset_X and dataset_y.
        """
        with h5py.File(h5_file, "r") as f:
            self.X = torch.tensor(f["dataset_X"][:], dtype=torch.float32)  # Shape (num_trajectories, sequence_length, n_grid * 3)
            self.y = torch.tensor(f["dataset_y"][:], dtype=torch.float32)  # Shape (num_trajectories, sequence_length, n_grid * 2)

        # Extract num_trajectories, sequence_length, and n_grid from the shape
        self.num_trajectories, self.sequence_length, n_grid_3 = self.X.shape
        self.n_grid = n_grid_3 // 3  # Since last dimension is n_grid * 3

        # Reshape to (num_trajectories, sequence_length, n_grid, features)
        self.X = self.X.view(self.num_trajectories, self.sequence_length, self.n_grid, 3)
        self.y = self.y.view(self.num_trajectories, self.sequence_length, self.n_grid, 2)

    def __len__(self):
        return self.num_trajectories

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]  # Each sample is (sequence_length, n_grid, features)

# Load dataset
h5_path = "./DataNew/ngrid64_19932025.h5"  # Change this to your actual path
dataset = QuantumWaveDataset(h5_path)

# Create DataLoader for training
batch_size = 10
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantumTransformer(seq_length=dataset.sequence_length, d_model=128, num_heads=8, num_layers=6, dim_feedforward=512, n_grid=64, dropout=0.1, device=device).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)




In [30]:
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        # Split inputs into components
        psi_real = X[..., 0]  # (batch_size, seq_length, n_grid)
        psi_imag = X[..., 1]  # (batch_size, seq_length, n_grid)
        potential = X[..., 2] # (batch_size, seq_length, n_grid)

        optimizer.zero_grad()
        output = model(psi_real, psi_imag, potential)  # Output shape: (batch_size, seq_length, n_grid, 2)

        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.6f}")


Epoch 1/50, Loss: 0.513420
Epoch 2/50, Loss: 0.346202
Epoch 3/50, Loss: 0.281521
Epoch 4/50, Loss: 0.246034
Epoch 5/50, Loss: 0.212651
Epoch 6/50, Loss: 0.181405
Epoch 7/50, Loss: 0.158136
Epoch 8/50, Loss: 0.145092
Epoch 9/50, Loss: 0.142349
Epoch 10/50, Loss: 0.142825
Epoch 11/50, Loss: 0.142172
Epoch 12/50, Loss: 0.138241
Epoch 13/50, Loss: 0.132381
Epoch 14/50, Loss: 0.128892
Epoch 15/50, Loss: 0.127421
Epoch 16/50, Loss: 0.127325
Epoch 17/50, Loss: 0.128202
Epoch 18/50, Loss: 0.128331
Epoch 19/50, Loss: 0.126717
Epoch 20/50, Loss: 0.124991
Epoch 21/50, Loss: 0.123014
Epoch 22/50, Loss: 0.122126
Epoch 23/50, Loss: 0.122030
Epoch 24/50, Loss: 0.122354
Epoch 25/50, Loss: 0.122478
Epoch 26/50, Loss: 0.122432
Epoch 27/50, Loss: 0.121363
Epoch 28/50, Loss: 0.120367
Epoch 29/50, Loss: 0.119609
Epoch 30/50, Loss: 0.119496
Epoch 31/50, Loss: 0.119381
Epoch 32/50, Loss: 0.119749
Epoch 33/50, Loss: 0.119398
Epoch 34/50, Loss: 0.119012
Epoch 35/50, Loss: 0.118339
Epoch 36/50, Loss: 0.118146
E