In [None]:
import os
import random
import time
import sys
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split, WeightedRandomSampler
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch_geometric.nn import GINEConv, GlobalAttention, Set2Set
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn import TransformerConv

# --------------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------------
CONFIG = dict(
    processed_dir='./processed',       # root directory handed to SpinSystemDataset
    batch_size=32768,                  # graphs per mini-batch for every DataLoader
    learning_rate=1e-4,                # AdamW learning rate
    weight_decay=1e-4,                 # AdamW weight decay
    hidden_channels=256,               # width of every hidden layer in the GNN
    num_epochs=500,                    # maximum number of training epochs
    patience=50,                       # intended early-stopping patience, in epochs
    random_seed=42,                    # seeds set_seed() and the train/val split
    best_model_path='best_gnn_model.pth',  # where the best state_dict is saved
    dropout_p=0.4,                     # dropout probability used throughout the model
    scheduler_factor=0.5,              # ReduceLROnPlateau LR multiplier
    scheduler_patience=10,             # ReduceLROnPlateau patience, in epochs
    grad_clip=1.0,                     # max gradient norm for clip_grad_norm_
    curriculum_steps=5,                # not read by get_curriculum_sampler (its ramp is continuous)
)

# --------------------------------------------------------------------------------
# Logging Setup
# --------------------------------------------------------------------------------
def setup_logging():
    """
    Route INFO-and-above log records to stdout with a timestamped format.

    A single stdout StreamHandler is installed on the root logger via
    ``logging.basicConfig``, so progress messages show up in real time in
    cluster job logs. Like any ``basicConfig`` call, this does nothing if the
    root logger already has handlers.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        datefmt='%Y-%m-%d %H:%M:%S',
        format='%(asctime)s [%(levelname)s] %(message)s',
        level=logging.INFO,
    )

# --------------------------------------------------------------------------------
# Dataset
# --------------------------------------------------------------------------------
class SpinSystemDataset(InMemoryDataset):
    """
    In-memory PyG dataset backed by a single pre-collated ``data.pt`` file.

    There are no raw files and no ``process`` step: ``<root>/processed/data.pt``
    must already exist and contain the ``(data, slices)`` pair produced by
    ``InMemoryDataset.collate``.
    """

    def __init__(self, root='.', transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the pickled ``(data, slices)`` tuple stored at *path*.

        ``weights_only=False`` is passed explicitly because the file holds
        arbitrary PyG objects rather than plain tensors. Newer torch releases
        default to ``weights_only=True``, which refuses such objects, and
        intermediate releases emit a FutureWarning when the argument is
        omitted. SECURITY: this unpickles the file, so only load files you
        produced or otherwise trust.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch versions predating the ``weights_only`` keyword
            return torch.load(path)

    @property
    def raw_file_names(self):
        # No raw inputs: the processed file is produced outside this class.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

# --------------------------------------------------------------------------------
# Utilities
# --------------------------------------------------------------------------------
def set_seed(seed):
    """
    Seed Python's, NumPy's and PyTorch's RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode picks convolution algorithms by timing them, which can
        # select different (non-bitwise-identical) kernels from run to run.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

# --------------------------------------------------------------------------------
# Custom Layers
# --------------------------------------------------------------------------------
class SizeAwareNorm(nn.Module):
    """
    LayerNorm whose output is gated element-wise by a learned function of size.

    ``forward(x, size)`` returns ``LayerNorm(x) * sigmoid(MLP(size))``, where the
    MLP maps a single size value to ``hidden_channels`` gates. ``size`` must
    line up with the rows of ``x`` and may be given as ``[N]`` or ``[N, 1]``.
    Not used by EnhancedPhysicsGNN in this file.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_channels)
        self.size_scale = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x, size):
        # The size MLP takes a trailing feature axis of width 1.
        size_column = size.unsqueeze(-1) if size.dim() == 1 else size
        gate = torch.sigmoid(self.size_scale(size_column))
        return self.norm(x) * gate

class RelativeFeatureTransform(nn.Module):
    """
    Divide per-graph embeddings by the graph's size, then apply a small MLP.

    ``forward(x, size)`` computes ``MLP(x / (size + 1e-6))`` row-wise. ``size``
    may be ``[batch_size]`` or ``[batch_size, 1]``. Not used by
    EnhancedPhysicsGNN in this file.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x, size):
        is_column = size.dim() == 2 and size.size(1) == 1
        flat_size = size.squeeze(-1) if is_column else size
        # Broadcast one divisor per row; the epsilon guards against size 0.
        divisor = flat_size.unsqueeze(-1) + 1e-6
        return self.transform(x / divisor)

# --------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------
class EnhancedPhysicsGNN(nn.Module):
    """
    Graph-level regressor producing one scalar per graph.

    Architecture, as built in ``__init__``:
      - ``init_transform``: Linear -> LayerNorm -> SiLU node embedding.
      - ``num_layers`` residual message-passing blocks that alternate GINEConv
        (even indices) and 4-head TransformerConv (odd indices), each followed
        by BatchNorm; each block's output is added to its input.
      - Two graph readouts: Set2Set (2*hidden) and attention-gated pooling
        (hidden).
      - An MLP over 8 hand-built per-graph features (hidden).
      - A final MLP mapping the 4*hidden concatenation to a single value.

    The training code in this file exponentiates the output, i.e. treats it
    as a log-entropy.
    """
    def __init__(self,
                 num_node_features,
                 edge_attr_dim,
                 hidden_channels=256,
                 num_layers=8,
                 dropout_p=0.4):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout_p = dropout_p

        # 1) Initial Node Embedding
        self.init_transform = nn.Sequential(
            nn.Linear(num_node_features, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
        )

        # 2) Build Multiple Message Passing Blocks (alternate GINE and Transformer)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            if i % 2 == 0:
                # GINE-based MLP
                mp_mlp = nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.LayerNorm(hidden_channels),
                    nn.SiLU(),
                    nn.Dropout(dropout_p),
                    nn.Linear(hidden_channels, hidden_channels)
                )
                conv = GINEConv(mp_mlp, edge_dim=edge_attr_dim)
            else:
                # Transformer-based conv: 4 heads of width hidden/4, concatenated
                # back to hidden_channels so the residual add lines up.
                conv = TransformerConv(
                    hidden_channels, hidden_channels // 4,
                    heads=4,
                    edge_dim=edge_attr_dim,
                    dropout=dropout_p,
                    beta=True
                )
            self.convs.append(conv)
            self.norms.append(BatchNorm(hidden_channels))

        # 3) Global Readouts
        self.set2set_readout = Set2Set(hidden_channels, processing_steps=4)
        self.gate_nn = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.SiLU(),
            nn.Linear(hidden_channels // 2, 1)
        )
        # NOTE(review): newer PyG releases deprecate GlobalAttention in favour of
        # torch_geometric.nn.aggr.AttentionalAggregation — confirm against the
        # installed version before switching (state_dict keys may differ).
        self.global_attention = GlobalAttention(gate_nn=self.gate_nn)

        # 4) Global Feature Transform (8 -> hidden_channels)
        self.global_transform = nn.Sequential(
            nn.Linear(8, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels, hidden_channels),
        )

        # 5) Final MLP
        combined_in_dim = 2*hidden_channels + hidden_channels + hidden_channels  # 4*hidden_channels
        self.final_mlp = nn.Sequential(
            nn.Linear(combined_in_dim, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.LayerNorm(hidden_channels // 2),
            nn.SiLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_channels // 2, 1),
        )

    @staticmethod
    def _global_features(data, dtype):
        """
        Build the [num_graphs, 8] matrix of hand-crafted per-graph features.

        Every attribute is flattened with ``reshape(-1)`` (uniformly, for all
        eight inputs) instead of ``squeeze(-1)``: squeezing a length-1 vector —
        a batch that holds a single graph — yields a 0-d tensor, which
        ``torch.stack(..., dim=1)`` cannot stack with 1-d ones. Casting to
        ``dtype`` keeps the stack homogeneous with the network's weights.
        Assumes each attribute holds exactly one value per graph — TODO confirm
        against the dataset.
        """
        def per_graph(value):
            return value.reshape(-1).to(dtype)

        Omega = per_graph(data.Omega)
        Delta = per_graph(data.Delta)
        Energy = per_graph(data.Energy)
        total_rydberg = per_graph(data.total_rydberg)
        system_size = per_graph(data.system_size)
        config_entropy = per_graph(data.config_entropy)
        rydberg_density = per_graph(data.rydberg_density)
        # Entropy normalised by log(system size); the epsilon avoids log(0).
        relative_entropy = config_entropy / torch.log(system_size + 1e-6)

        return torch.stack([
            Omega, Delta,
            Energy / system_size,
            total_rydberg / system_size,
            rydberg_density,
            system_size,
            config_entropy,
            relative_entropy,
        ], dim=1)

    def forward(self, data):
        """
        Predict one scalar per graph of the PyG batch ``data``.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr``, ``batch`` and
        the per-graph attributes read by ``_global_features``. Returns a tensor
        of shape [num_graphs].
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Initial transform
        h = self.init_transform(x)

        # Residual message passing: h <- h + BatchNorm(conv(h))
        for conv, norm in zip(self.convs, self.norms):
            h = h + norm(conv(h, edge_index, edge_attr))

        s2s = self.set2set_readout(h, batch)  # [batch_size, 2*hidden_channels]
        ga = self.global_attention(h, batch)  # [batch_size, hidden_channels]
        gf_out = self.global_transform(self._global_features(data, h.dtype))  # [batch_size, hidden_channels]

        combined = torch.cat([s2s, ga, gf_out], dim=-1)  # [batch_size, 4*hidden_channels]
        out = self.final_mlp(combined)                   # [batch_size, 1]
        return out.squeeze(-1)

# --------------------------------------------------------------------------------
# Custom Loss
# --------------------------------------------------------------------------------
class PhysicalScaleAwareLoss(nn.Module):
    """
    Size-weighted log-space regression loss plus a physical-bounds penalty.

    ``pred`` is interpreted as log S; ``S_hat = exp(pred)``. Per graph:

        (system_size / base_size) ** scaling_power * (pred - log(target + 1e-10)) ** 2
        + physics_weight * [smooth_l1(max(0, S_hat), S_hat)
                            + smooth_l1(min(min(nA, nB) * ln 2, S_hat), S_hat)]

    and the result is averaged over graphs. Because ``exp`` is never negative,
    the lower-bound term is always zero; only predictions above the
    ``min(nA, nB) * ln 2`` ceiling are penalised.
    """

    def __init__(self, base_size=4, scaling_power=1.5, physics_weight=1.0):
        super().__init__()
        self.base_size = base_size
        self.scaling_power = scaling_power
        self.physics_weight = physics_weight

    def get_entropy_bounds(self, system_size, subsystem_size):
        """
        Return ``(lower, upper)`` bounds on the von Neumann entropy S(rho_A).

        lower is all zeros; upper is ``min(nA, nB) * ln 2`` with
        ``nB = system_size - subsystem_size``.
        """
        complement = system_size - subsystem_size
        smaller_part = torch.minimum(subsystem_size.float(), complement.float())
        ln_two = torch.log(torch.tensor(2.0, device=system_size.device))
        upper = smaller_part * ln_two
        lower = torch.zeros_like(system_size, dtype=torch.float)
        return lower, upper

    def forward(self, pred, target, system_size, subsystem_size):
        # Accept [B] or [B, 1] for both size tensors.
        system_size = system_size.squeeze(-1) if system_size.dim() == 2 else system_size
        subsystem_size = subsystem_size.squeeze(-1) if subsystem_size.dim() == 2 else subsystem_size

        entropy_hat = torch.exp(pred)
        lower, upper = self.get_entropy_bounds(system_size, subsystem_size)

        below_floor = F.smooth_l1_loss(
            torch.maximum(lower, entropy_hat), entropy_hat, reduction='none'
        )
        above_ceiling = F.smooth_l1_loss(
            torch.minimum(upper, entropy_hat), entropy_hat, reduction='none'
        )
        bound_penalty = below_floor + above_ceiling

        # The epsilon keeps log() finite for zero-entropy targets.
        squared_log_error = F.mse_loss(pred, torch.log(target + 1e-10), reduction='none')
        size_factor = (system_size.float() / self.base_size) ** self.scaling_power

        per_graph = squared_log_error * size_factor + self.physics_weight * bound_penalty
        return per_graph.mean()

# --------------------------------------------------------------------------------
# Curriculum Sampler
# --------------------------------------------------------------------------------
def get_curriculum_sampler(dataset, epoch, max_epochs):
    """
    Build a with-replacement sampler that re-weights graphs by system size.

    ``progress`` ramps linearly from 0 to 1 over the first 70% of
    ``max_epochs`` and then stays at 1. Graphs with ``system_size < 8`` always
    get weight 1.0. Larger graphs get
    ``progress * (1 - (size - 8) / (max_size - 8))``, floored at 0.1. As a
    result, the largest systems remain at the 0.1 floor even once progress
    reaches 1.

    Returns a ``WeightedRandomSampler`` drawing ``len(dataset)`` indices.
    """
    sizes = [graph.system_size.item() for graph in dataset]
    largest = max(sizes)
    progress = min(1.0, epoch / (max_epochs * 0.7))

    def weight_for(size):
        if size < 8:
            return 1.0
        ramped = progress * (1.0 - (size - 8) / (largest - 8 + 1e-6))
        return max(0.1, ramped)

    return WeightedRandomSampler(
        weights=[weight_for(size) for size in sizes],
        num_samples=len(dataset),
        replacement=True,
    )

# --------------------------------------------------------------------------------
# Training Loop
# --------------------------------------------------------------------------------
def train_epoch(model, loader, optimizer, criterion, device, clip_grad=None):
    """
    Run one optimisation pass over ``loader``.

    Args:
        model: module mapping a batch to one prediction per graph.
        loader: iterable of PyG-style batches (``x``, ``batch``, ``y``,
            ``system_size``, ``num_graphs``, ``to``).
        optimizer: optimiser stepping ``model``'s parameters.
        criterion: called as ``criterion(pred, target, system_size, subsystem_size)``.
        device: device each batch is moved to.
        clip_grad: if not None, max total gradient norm passed to
            ``clip_grad_norm_``.

    Returns:
        The per-graph-weighted mean loss over all graphs processed (0.0 if
        the loader yielded nothing).
    """
    model.train()
    total_loss = 0.0
    graphs_seen = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data)

        # Subsystem size = per-graph sum of node-feature column 3 (the
        # subsystem mask). One scatter-add over all nodes replaces a
        # per-graph loop of boolean masks, which cost O(num_graphs * num_nodes)
        # plus a host sync per graph.
        subsystem_size = torch.zeros(
            data.num_graphs, dtype=data.x.dtype, device=device
        ).index_add_(0, data.batch, data.x[:, 3])

        # reshape(-1) (not squeeze) keeps a 1-D target even for a 1-graph batch.
        loss = criterion(out, data.y.reshape(-1), data.system_size, subsystem_size)
        loss.backward()
        if clip_grad is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs

    # Average over the graphs actually drawn by the loader/sampler.
    return total_loss / graphs_seen if graphs_seen else 0.0

@torch.no_grad()
def evaluate(model, loader, criterion, device, name='Eval'):
    """
    Evaluate ``model`` on ``loader``, log the results, and return metrics.

    The loss is computed with ``criterion`` exactly as in training. Predictions
    are exponentiated (the model outputs log-entropy) before the per-size
    MSE / MAE / MAPE metrics are computed.

    Returns:
        dict with ``loss`` (per-graph mean, 0.0 for an empty loader),
        ``predictions``, ``targets``, ``sizes`` (1-D numpy arrays) and
        ``size_metrics`` (``{int(size): {'mse', 'mae', 'mape'}}``).
    """
    model.eval()
    total_loss = 0.0
    graphs_seen = 0
    predictions, targets, sizes = [], [], []

    for data in loader:
        data = data.to(device)
        out = model(data)

        # Per-graph sum of node-feature column 3 via a single scatter-add
        # (replaces an O(num_graphs * num_nodes) Python loop of masks).
        subsystem_size = torch.zeros(
            data.num_graphs, dtype=data.x.dtype, device=device
        ).index_add_(0, data.batch, data.x[:, 3])

        # reshape(-1) instead of squeeze(): a 1-graph batch must stay 1-D,
        # otherwise the torch.cat below fails on a 0-d tensor.
        target = data.y.reshape(-1)
        loss = criterion(out, target, data.system_size, subsystem_size)
        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs

        predictions.append(torch.exp(out).reshape(-1).cpu())
        targets.append(target.cpu())
        sizes.append(data.system_size.reshape(-1).cpu())

    # torch.cat([]) raises, so an empty loader yields empty arrays instead.
    predictions = torch.cat(predictions).numpy() if predictions else np.empty(0)
    targets = torch.cat(targets).numpy() if targets else np.empty(0)
    sizes = torch.cat(sizes).numpy() if sizes else np.empty(0)

    size_metrics = {}
    for size_val in np.unique(sizes):
        mask = (sizes == size_val)
        size_preds = predictions[mask]
        size_targets = targets[mask]
        size_metrics[int(size_val)] = {
            'mse': mean_squared_error(size_targets, size_preds),
            'mae': mean_absolute_error(size_targets, size_preds),
            'mape': np.mean(np.abs(size_preds - size_targets) / (size_targets + 1e-10)) * 100
        }

    mean_loss = total_loss / graphs_seen if graphs_seen > 0 else 0.0

    # Log metrics for this evaluation
    logging.info(f"[{name}] Loss: {mean_loss:.6f}")
    for sz, met in size_metrics.items():
        logging.info(
            f"  Size={sz:2d} : "
            f"MSE={met['mse']:.6f} "
            f"MAE={met['mae']:.6f} "
            f"MAPE={met['mape']:.2f}%"
        )

    return {
        'loss': mean_loss,
        'predictions': predictions,
        'targets': targets,
        'sizes': sizes,
        'size_metrics': size_metrics
    }

# --------------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------------
def main():
    """
    Train on systems with system_size <= 12 and test on systems of size 14.

    Steps:
      1. Load SpinSystemDataset from CONFIG['processed_dir'].
      2. Split the <= 12 graphs 80/20 into train/validation, seeded.
      3. Train EnhancedPhysicsGNN with a curriculum-weighted sampler.
      4. Checkpoint the model with the lowest validation loss.
      5. Stop early after CONFIG['patience'] epochs without improvement.
      6. Evaluate the best checkpoint on the size-14 graphs.
    """
    setup_logging()  # ensure logging is configured
    set_seed(CONFIG['random_seed'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1) Load the dataset and materialize each graph once (the original code
    #    indexed every graph three times while filtering).
    dataset = SpinSystemDataset(root=CONFIG['processed_dir'])
    graphs = [dataset[i] for i in range(len(dataset))]

    # 2) Separate data by system size. Plain lists suffice: random_split and
    #    the PyG DataLoader only need len() and integer indexing.
    train_val_dataset = [g for g in graphs if g.system_size.item() <= 12]
    size14_dataset = [g for g in graphs if abs(g.system_size.item() - 14) < 1e-6]

    # 3) Train/val split
    train_size = int(0.8 * len(train_val_dataset))
    val_size = len(train_val_dataset) - train_size

    if train_size == 0:
        logging.info("No data found with system_size <= 12. Exiting.")
        return

    train_dataset, val_dataset = random_split(
        train_val_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(CONFIG['random_seed'])
    )

    # 4) Create model (input widths are read off one collated sample)
    sample_data = next(iter(DataLoader(train_dataset, batch_size=1)))
    model = EnhancedPhysicsGNN(
        num_node_features=sample_data.x.size(1),
        edge_attr_dim=sample_data.edge_attr.size(1),
        hidden_channels=CONFIG['hidden_channels'],
        dropout_p=CONFIG['dropout_p']
    ).to(device)

    # 5) Training components
    criterion = PhysicalScaleAwareLoss(physics_weight=0.5)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=CONFIG['scheduler_factor'],
        patience=CONFIG['scheduler_patience']
    )

    # 6) The validation loader never changes, so build it once. The training
    #    loader is rebuilt every epoch because the curriculum weights depend
    #    on the epoch.
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'])

    best_val_loss = float('inf')
    checkpoint_saved = False
    epochs_without_improvement = 0

    # 7) Main loop
    for epoch in range(CONFIG['num_epochs']):
        logging.info(f"Epoch {epoch+1}/{CONFIG['num_epochs']}")

        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            sampler=get_curriculum_sampler(train_dataset, epoch, CONFIG['num_epochs']),
        )

        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, clip_grad=CONFIG['grad_clip'])
        logging.info(f"  Training Loss: {train_loss:.6f}")

        val_metrics = evaluate(model, val_loader, criterion, device, name='Validation')
        val_loss = val_metrics['loss']

        # Scheduler step
        scheduler.step(val_loss)

        # Checkpoint on improvement; otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), CONFIG['best_model_path'])
            checkpoint_saved = True
            epochs_without_improvement = 0
            logging.info(f"  [Info] New best model saved (val_loss={best_val_loss:.6f})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= CONFIG['patience']:
                logging.info(
                    f"  [Info] Early stopping: no improvement in "
                    f"{CONFIG['patience']} epochs"
                )
                break

    # 8) Final evaluation on the size-14 subset
    if not size14_dataset:
        logging.warning("No data found with system_size == 14. Skipping test.")
        return

    size14_loader = DataLoader(size14_dataset, batch_size=CONFIG['batch_size'])
    if checkpoint_saved:
        model.load_state_dict(torch.load(CONFIG['best_model_path'], map_location=device))
    else:
        # e.g. every validation loss was NaN: don't load a missing or stale file.
        logging.warning("No checkpoint was saved during training; evaluating final weights.")
    evaluate(model, size14_loader, criterion, device, name='Size14-Test')


if __name__ == "__main__":
    main()


  self.data, self.slices = torch.load(self.processed_paths[0])



Epoch 1/50
  [Batch 000] loss=3.366929, time=7.108s
  [Batch 001] loss=1.187062, time=0.483s
  [Batch 002] loss=1.148598, time=0.459s
  [Batch 003] loss=0.986141, time=0.442s
  [Batch 004] loss=0.870196, time=0.335s
  [Batch 005] loss=0.977821, time=0.415s
  [Batch 006] loss=0.890420, time=0.465s
  [Batch 007] loss=1.017218, time=0.433s
  [Batch 008] loss=0.839870, time=0.431s
  [Batch 009] loss=0.861150, time=0.431s
  [Batch 010] loss=0.771906, time=0.303s
  [Batch 011] loss=0.908237, time=0.376s
  [Batch 012] loss=0.773724, time=0.427s
  [Batch 013] loss=0.765590, time=0.473s
  [Batch 014] loss=0.597785, time=0.478s
  [Batch 015] loss=0.772159, time=0.358s
  [Batch 016] loss=0.671009, time=0.262s
  [Batch 017] loss=0.772884, time=0.337s
  [Batch 018] loss=0.654636, time=0.538s
  [Batch 019] loss=0.814314, time=0.392s
  [Batch 020] loss=0.571000, time=0.449s
  [Batch 021] loss=0.543551, time=0.317s
  [Batch 022] loss=0.848314, time=0.293s
  [Batch 023] loss=0.838554, time=0.492s
  [B



  [Batch 000] loss=0.609838, time=0.663s
  [Batch 001] loss=0.647852, time=0.519s
  [Batch 002] loss=0.543496, time=0.468s
  [Batch 003] loss=0.388132, time=0.480s
  [Batch 004] loss=0.558100, time=0.485s
  [Batch 005] loss=0.380284, time=0.279s
  [Batch 006] loss=0.536658, time=0.326s
  [Batch 007] loss=0.535147, time=0.454s
  [Batch 008] loss=0.563443, time=0.429s
  [Batch 009] loss=0.530141, time=0.297s
  [Batch 010] loss=0.621670, time=0.355s
  [Batch 011] loss=0.390298, time=0.369s
  [Batch 012] loss=0.451335, time=0.280s
  [Batch 013] loss=0.431002, time=0.242s
  [Batch 014] loss=0.349761, time=0.277s
  [Batch 015] loss=0.466944, time=0.359s
  [Batch 016] loss=0.411623, time=0.309s
  [Batch 017] loss=0.415435, time=0.286s
  [Batch 018] loss=0.525186, time=0.361s
  [Batch 019] loss=0.395995, time=0.480s
  [Batch 020] loss=0.438428, time=0.387s
  [Batch 021] loss=0.624413, time=0.260s
  [Batch 022] loss=0.486005, time=0.509s
  [Batch 023] loss=0.614464, time=0.391s
  [Batch 024] lo



  [Batch 000] loss=0.267145, time=0.677s
  [Batch 001] loss=0.376926, time=0.388s
  [Batch 002] loss=0.320740, time=0.372s
  [Batch 003] loss=0.306412, time=0.310s
  [Batch 004] loss=0.478899, time=0.276s
  [Batch 005] loss=0.348388, time=0.398s
  [Batch 006] loss=0.345529, time=0.320s
  [Batch 007] loss=0.370452, time=0.476s
  [Batch 008] loss=0.406927, time=0.392s
  [Batch 009] loss=0.299567, time=0.267s
  [Batch 010] loss=0.348388, time=0.289s
  [Batch 011] loss=0.421946, time=0.301s
  [Batch 012] loss=0.396721, time=0.426s
  [Batch 013] loss=0.308888, time=0.252s
  [Batch 014] loss=0.388705, time=0.381s
  [Batch 015] loss=0.399593, time=0.330s
  [Batch 016] loss=0.372262, time=0.368s
  [Batch 017] loss=0.357016, time=0.385s
  [Batch 018] loss=0.334178, time=0.388s
  [Batch 019] loss=0.284477, time=0.328s
  [Batch 020] loss=0.383374, time=0.358s
  [Batch 021] loss=0.321554, time=0.461s
  [Batch 022] loss=0.322876, time=0.434s
  [Batch 023] loss=0.286461, time=0.445s
  [Batch 024] lo



  [Batch 000] loss=0.205947, time=0.685s
  [Batch 001] loss=0.276370, time=0.414s
  [Batch 002] loss=0.221187, time=0.412s
  [Batch 003] loss=0.215707, time=0.477s
  [Batch 004] loss=0.333962, time=0.244s
  [Batch 005] loss=0.304534, time=0.329s
  [Batch 006] loss=0.224144, time=0.489s
  [Batch 007] loss=0.230042, time=0.265s
  [Batch 008] loss=0.185692, time=0.367s
  [Batch 009] loss=0.168142, time=0.384s
  [Batch 010] loss=0.222292, time=0.462s
  [Batch 011] loss=0.258654, time=0.232s
  [Batch 012] loss=0.285875, time=0.447s
  [Batch 013] loss=0.149791, time=0.329s
  [Batch 014] loss=0.222346, time=0.441s
  [Batch 015] loss=0.152922, time=0.459s
  [Batch 016] loss=0.245393, time=0.297s
  [Batch 017] loss=0.170472, time=0.236s
  [Batch 018] loss=0.165756, time=0.339s
  [Batch 019] loss=0.206734, time=0.417s
  [Batch 020] loss=0.162794, time=0.201s
  [Batch 021] loss=0.220008, time=0.237s
  [Batch 022] loss=0.170315, time=0.318s
  [Batch 023] loss=0.196068, time=0.279s
  [Batch 024] lo



  [Batch 000] loss=0.190119, time=0.765s
  [Batch 001] loss=0.185925, time=0.615s
  [Batch 002] loss=0.140049, time=0.263s
  [Batch 003] loss=0.255653, time=0.321s
  [Batch 004] loss=0.213246, time=0.366s
  [Batch 005] loss=0.208204, time=0.343s
  [Batch 006] loss=0.167255, time=0.361s
  [Batch 007] loss=0.207878, time=0.450s
  [Batch 008] loss=0.189681, time=0.421s
  [Batch 009] loss=0.278139, time=0.534s
  [Batch 010] loss=0.168731, time=0.249s
  [Batch 011] loss=0.163367, time=0.459s
  [Batch 012] loss=0.164868, time=0.472s
  [Batch 013] loss=0.199090, time=0.438s
  [Batch 014] loss=0.208514, time=0.487s
  [Batch 015] loss=0.169479, time=0.357s
  [Batch 016] loss=0.144530, time=0.377s
  [Batch 017] loss=0.241863, time=0.253s
  [Batch 018] loss=0.190202, time=0.430s
  [Batch 019] loss=0.166105, time=0.260s
  [Batch 020] loss=0.168403, time=0.394s
  [Batch 021] loss=0.143363, time=0.362s
  [Batch 022] loss=0.170210, time=0.321s
  [Batch 023] loss=0.161148, time=0.348s
  [Batch 024] lo



  [Batch 000] loss=0.147665, time=0.709s
  [Batch 001] loss=0.151410, time=0.276s
  [Batch 002] loss=0.150261, time=0.258s
  [Batch 003] loss=0.175071, time=0.504s
  [Batch 004] loss=0.137552, time=0.357s
  [Batch 005] loss=0.172080, time=0.473s
  [Batch 006] loss=0.172553, time=0.473s
  [Batch 007] loss=0.144244, time=0.444s
  [Batch 008] loss=0.191412, time=0.420s
  [Batch 009] loss=0.136857, time=0.323s
  [Batch 010] loss=0.192510, time=0.337s
  [Batch 011] loss=0.137436, time=0.485s
  [Batch 012] loss=0.139768, time=0.365s
  [Batch 013] loss=0.154926, time=0.363s
  [Batch 014] loss=0.137817, time=0.444s
  [Batch 015] loss=0.141568, time=0.274s
  [Batch 016] loss=0.169491, time=0.404s
  [Batch 017] loss=0.174049, time=0.319s
  [Batch 018] loss=0.139160, time=0.289s
  [Batch 019] loss=0.122464, time=0.227s
  [Batch 020] loss=0.204966, time=0.364s
  [Batch 021] loss=0.143853, time=0.313s
  [Batch 022] loss=0.209343, time=0.427s
  [Batch 023] loss=0.168545, time=0.385s
  [Batch 024] lo



  [Batch 000] loss=0.134367, time=0.769s
  [Batch 001] loss=0.144960, time=0.419s
  [Batch 002] loss=0.133900, time=0.401s
  [Batch 003] loss=0.209837, time=0.438s
  [Batch 004] loss=0.172093, time=0.447s
  [Batch 005] loss=0.166316, time=0.366s
  [Batch 006] loss=0.251494, time=0.379s
  [Batch 007] loss=0.127275, time=0.444s
  [Batch 008] loss=0.137079, time=0.328s
  [Batch 009] loss=0.167733, time=0.359s
  [Batch 010] loss=0.178246, time=0.405s
  [Batch 011] loss=0.121714, time=0.255s
  [Batch 012] loss=0.131296, time=0.456s
  [Batch 013] loss=0.131064, time=0.254s
  [Batch 014] loss=0.148200, time=0.449s
  [Batch 015] loss=0.133569, time=0.206s
  [Batch 016] loss=0.141378, time=0.231s
  [Batch 017] loss=0.149433, time=0.307s
  [Batch 018] loss=0.215824, time=0.331s
  [Batch 019] loss=0.197950, time=0.438s
  [Batch 020] loss=0.134449, time=0.384s
  [Batch 021] loss=0.178260, time=0.381s
  [Batch 022] loss=0.199816, time=0.241s
  [Batch 023] loss=0.168748, time=0.463s
  [Batch 024] lo