# Temporal Understanding Through fMRI Analysis

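This notebook tests the hypothesis that temporal understanding requires mediation by physiological state, i.e. a causal chain A → B → C in which external observations (A) shape temporal understanding (C) only through physiological state transitions (B).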

## Hypothesis
- World models cannot gain a comprehensive temporal understanding from external observations alone.
- Biometric data (here, fMRI) provides insight into human temporal understanding through physiological state transitions.

In [1]:
# Imports
%matplotlib inline
import logging
import random
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

from learnedSpectrum.config import Config, DataConfig
from learnedSpectrum.models import VisionTransformerModel
from learnedSpectrum.train import train_loop, evaluate
from learnedSpectrum.visualization import TemporalUnderstandingVisualizer
from learnedSpectrum.physiological import PhysiologicalStateTracker
from learnedSpectrum.causal import CausalAnalysisModule
from learnedSpectrum.rl import TemporalStateEncoder
from learnedSpectrum.data import DatasetManager, create_dataloaders

In [2]:
# Configuration
config = Config()
data_config = DataConfig()

# Seed the global RNGs (Python `random`, NumPy, PyTorch on all devices) before any
# model is constructed or data is split/shuffled, so Restart & Run All reproduces the
# same weight initialization and sampling as far as the pipeline draws from these RNGs.
# Uses `config.SEED` when the project config defines one, otherwise a fixed default.
SEED = getattr(config, "SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Initialize wandb: start a run in the "temporal-understanding" project and record
# the key hyperparameters from `config` so runs can be compared in the dashboard.
run_config = dict(
    architecture="vit-temporal",
    dataset="fmri-learning-stages",
    epochs=config.NUM_EPOCHS,
    batch_size=config.BATCH_SIZE,
    learning_rate=config.LEARNING_RATE,
    temporal_analysis=config.TEMPORAL_ANALYSIS,
    causal_inference=config.CAUSAL_INFERENCE,
)
wandb.init(project="temporal-understanding", config=run_config)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: tawarner (tawarner-usc). Use `wandb login --relogin` to force relogin


In [4]:
# Device setup: prefer the GPU when PyTorch can see one, otherwise fall back to CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Initialize models and move each one to the selected device.
# NOTE(review): only `model` is handed to the optimizer and to train_loop in later
# cells; the two auxiliary modules below keep their initial weights — confirm intended.
model = VisionTransformerModel(config).to(device)

embed_dim = config.EMBED_DIM
physiological_tracker = PhysiologicalStateTracker(embed_dim).to(device)
causal_analyzer = CausalAnalysisModule(embed_dim, config.TEMPORAL_DIM).to(device)

In [6]:
# Initialize optimizer, cosine LR schedule and AMP gradient scaler.
# NOTE: no gradient clipping is configured here; if clipping is wanted it must happen
# in the training step (e.g. torch.nn.utils.clip_grad_norm_ after scaler.unscale_).
# NOTE(review): only `model.parameters()` are optimized, so physiological_tracker and
# causal_analyzer are never updated. `scheduler` and `scaler` are also not passed to
# train_loop in the training cell — confirm train_loop builds its own or pass them in.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.NUM_EPOCHS)
# Loss scaling only applies to CUDA mixed precision; disable it explicitly on CPU so
# the scaler is a silent no-op instead of warning that CUDA is unavailable.
scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

In [7]:
# Visualization manager: later cells call its plot_* methods with a `save_name`,
# and the final cell uploads the resulting files from config.VIS_DIR to wandb.
vis_manager = TemporalUnderstandingVisualizer(
    config.VIS_DIR,
)

In [8]:
# Build the train / validation / test splits from the dataset configuration.
data_manager = DatasetManager(config, data_config)
train_ds, val_ds, test_ds = data_manager.prepare_datasets()

# Wrap each split in a DataLoader configured from `config`.
train_loader, val_loader, test_loader = create_dataloaders(
    train_ds, val_ds, test_ds, config
)

Validating samples:   0%|          | 0/264 [00:00<?, ?file/s]

Analyzing timepoints:   0%|          | 0/184 [00:00<?, ?file/s]

Analyzing timepoints:   0%|          | 0/40 [00:00<?, ?file/s]

Analyzing timepoints:   0%|          | 0/40 [00:00<?, ?file/s]

In [9]:
# Training loop: train_loop's return value replaces `model`.
# NOTE(review): `scheduler` and `scaler` from the optimizer cell are not passed here,
# so presumably no LR schedule or AMP loss scaling is applied unless train_loop
# creates its own — confirm against train_loop's signature.
model = train_loop(
    config=config,
    model=model,
    optimizer=optimizer,
    train_dl=train_loader,
    val_dl=val_loader,
)


Epoch 1/50


  return fn(*args, **kwargs)


Epoch 1/50
Train Loss: 2.4927
Val Loss: 1.7392
Val Accuracy: 0.2000
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_2: 0.975
class_3: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 1/50 - Train Loss: 2.4927, Val Loss: 1.7392, Val Acc: 0.2000

Epoch 2/50


  return fn(*args, **kwargs)


Epoch 2/50
Train Loss: 2.4413
Val Loss: 1.6383
Val Accuracy: 0.2000
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_2: 0.975
class_3: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 2/50 - Train Loss: 2.4413, Val Loss: 1.6383, Val Acc: 0.2000

Epoch 3/50


  return fn(*args, **kwargs)


Epoch 3/50
Train Loss: 2.2974
Val Loss: 1.4924
Val Accuracy: 0.2000
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_2: 0.975
class_3: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 3/50 - Train Loss: 2.2974, Val Loss: 1.4924, Val Acc: 0.2000

Epoch 4/50


  return fn(*args, **kwargs)


Epoch 4/50
Train Loss: 2.1248
Val Loss: 1.3602
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_1: 1.000

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 4/50 - Train Loss: 2.1248, Val Loss: 1.3602, Val Acc: 0.3500

Epoch 5/50


  return fn(*args, **kwargs)


Epoch 5/50
Train Loss: 1.9631
Val Loss: 1.2930
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_1: 1.000

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 5/50 - Train Loss: 1.9631, Val Loss: 1.2930, Val Acc: 0.3500

Epoch 6/50


  return fn(*args, **kwargs)


Epoch 6/50
Train Loss: 1.8364
Val Loss: 1.2839
Val Accuracy: 0.3000
Val Balanced Accuracy: 0.2143

Prediction Distribution:
class_0: 0.375
class_1: 0.625

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 6/50 - Train Loss: 1.8364, Val Loss: 1.2839, Val Acc: 0.3000

Epoch 7/50


  return fn(*args, **kwargs)


Epoch 7/50
Train Loss: 1.7384
Val Loss: 1.2892
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_1: 1.000

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 7/50 - Train Loss: 1.7384, Val Loss: 1.2892, Val Acc: 0.3500

Epoch 8/50


  return fn(*args, **kwargs)


Epoch 8/50
Train Loss: 1.6446
Val Loss: 1.2976
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_0: 0.075
class_1: 0.900
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 8/50 - Train Loss: 1.6446, Val Loss: 1.2976, Val Acc: 0.3750

Epoch 9/50


  return fn(*args, **kwargs)


Epoch 9/50
Train Loss: 1.5755
Val Loss: 1.2899
Val Accuracy: 0.4000
Val Balanced Accuracy: 0.2991

Prediction Distribution:
class_0: 0.925
class_1: 0.050
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 9/50 - Train Loss: 1.5755, Val Loss: 1.2899, Val Acc: 0.4000

Epoch 10/50


  return fn(*args, **kwargs)


Epoch 10/50
Train Loss: 1.5346
Val Loss: 1.2759
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2634

Prediction Distribution:
class_0: 0.850
class_1: 0.125
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 10/50 - Train Loss: 1.5346, Val Loss: 1.2759, Val Acc: 0.3500

Epoch 11/50


  return fn(*args, **kwargs)


Epoch 11/50
Train Loss: 1.4777
Val Loss: 1.2714
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_0: 0.050
class_1: 0.925
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 11/50 - Train Loss: 1.4777, Val Loss: 1.2714, Val Acc: 0.3750

Epoch 12/50


  return fn(*args, **kwargs)


Epoch 12/50
Train Loss: 1.4496
Val Loss: 1.2513
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2634

Prediction Distribution:
class_0: 0.750
class_1: 0.225
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 12/50 - Train Loss: 1.4496, Val Loss: 1.2513, Val Acc: 0.3500

Epoch 13/50


  return fn(*args, **kwargs)


Epoch 13/50
Train Loss: 1.4142
Val Loss: 1.2363
Val Accuracy: 0.4000
Val Balanced Accuracy: 0.2991

Prediction Distribution:
class_0: 0.950
class_1: 0.025
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 13/50 - Train Loss: 1.4142, Val Loss: 1.2363, Val Acc: 0.4000

Epoch 14/50


  return fn(*args, **kwargs)


Epoch 14/50
Train Loss: 1.3731
Val Loss: 1.2350
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_1: 0.975
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 14/50 - Train Loss: 1.3731, Val Loss: 1.2350, Val Acc: 0.3750

Epoch 15/50


  return fn(*args, **kwargs)


Epoch 15/50
Train Loss: 1.3783
Val Loss: 1.2449
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_0: 0.975
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 15/50 - Train Loss: 1.3783, Val Loss: 1.2449, Val Acc: 0.3750

Epoch 16/50


  return fn(*args, **kwargs)


Epoch 16/50
Train Loss: 1.3510
Val Loss: 1.2296
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_1: 0.975
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 16/50 - Train Loss: 1.3510, Val Loss: 1.2296, Val Acc: 0.3750

Epoch 17/50


  return fn(*args, **kwargs)


Epoch 17/50
Train Loss: 1.3413
Val Loss: 1.2078
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.2812

Prediction Distribution:
class_0: 0.975
class_2: 0.025

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 17/50 - Train Loss: 1.3413, Val Loss: 1.2078, Val Acc: 0.3750

Epoch 18/50


  return fn(*args, **kwargs)


Epoch 18/50
Train Loss: 1.3290
Val Loss: 1.2325
Val Accuracy: 0.3500
Val Balanced Accuracy: 0.2500

Prediction Distribution:
class_1: 1.000

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 18/50 - Train Loss: 1.3290, Val Loss: 1.2325, Val Acc: 0.3500

Epoch 19/50


  return fn(*args, **kwargs)


Epoch 19/50
Train Loss: 1.2733
Val Loss: 1.2098
Val Accuracy: 0.4250
Val Balanced Accuracy: 0.4062

Prediction Distribution:
class_0: 0.100
class_1: 0.775
class_2: 0.025
class_3: 0.100

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 19/50 - Train Loss: 1.2733, Val Loss: 1.2098, Val Acc: 0.4250

Epoch 20/50


  return fn(*args, **kwargs)


Epoch 20/50
Train Loss: 1.2306
Val Loss: 1.1990
Val Accuracy: 0.3250
Val Balanced Accuracy: 0.2455

Prediction Distribution:
class_0: 0.925
class_2: 0.075

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 20/50 - Train Loss: 1.2306, Val Loss: 1.1990, Val Acc: 0.3250

Epoch 21/50


  return fn(*args, **kwargs)


Epoch 21/50
Train Loss: 1.3023
Val Loss: 1.2003
Val Accuracy: 0.4000
Val Balanced Accuracy: 0.3884

Prediction Distribution:
class_0: 0.550
class_1: 0.325
class_2: 0.025
class_3: 0.100

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 21/50 - Train Loss: 1.3023, Val Loss: 1.2003, Val Acc: 0.4000

Epoch 22/50


  return fn(*args, **kwargs)


Epoch 22/50
Train Loss: 1.2598
Val Loss: 1.1743
Val Accuracy: 0.3750
Val Balanced Accuracy: 0.3705

Prediction Distribution:
class_0: 0.050
class_1: 0.800
class_2: 0.025
class_3: 0.125

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 22/50 - Train Loss: 1.2598, Val Loss: 1.1743, Val Acc: 0.3750

Epoch 23/50


  return fn(*args, **kwargs)


Epoch 23/50
Train Loss: 1.2275
Val Loss: 1.1522
Val Accuracy: 0.3000
Val Balanced Accuracy: 0.3170

Prediction Distribution:
class_0: 0.600
class_1: 0.250
class_2: 0.050
class_3: 0.100

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 23/50 - Train Loss: 1.2275, Val Loss: 1.1522, Val Acc: 0.3000

Epoch 24/50


  return fn(*args, **kwargs)


Epoch 24/50
Train Loss: 1.1969
Val Loss: 1.1745
Val Accuracy: 0.2500
Val Balanced Accuracy: 0.2366

Prediction Distribution:
class_0: 0.775
class_1: 0.100
class_2: 0.050
class_3: 0.075

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 24/50 - Train Loss: 1.1969, Val Loss: 1.1745, Val Acc: 0.2500

Epoch 25/50


  return fn(*args, **kwargs)


Epoch 25/50
Train Loss: 1.1861
Val Loss: 1.2133
Val Accuracy: 0.4000
Val Balanced Accuracy: 0.4777

Prediction Distribution:
class_1: 0.775
class_2: 0.025
class_3: 0.200

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 25/50 - Train Loss: 1.1861, Val Loss: 1.2133, Val Acc: 0.4000

Epoch 26/50


  return fn(*args, **kwargs)


Epoch 26/50
Train Loss: 1.2182
Val Loss: 1.2047
Val Accuracy: 0.2750
Val Balanced Accuracy: 0.2098

Prediction Distribution:
class_0: 0.475
class_1: 0.400
class_2: 0.125

Label Distribution:
class_0: 0.350
class_1: 0.350
class_2: 0.200
class_3: 0.100
Epoch 26/50 - Train Loss: 1.2182, Val Loss: 1.2047, Val Acc: 0.2750

Epoch 27/50


  return fn(*args, **kwargs)


KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out test split; `test_metrics` feeds the plotting and logging cells.
# NOTE(review): physiological_tracker and causal_analyzer are not in the optimizer and
# are not passed to train_loop, so metrics derived from them presumably reflect
# untrained weights — confirm before interpreting those results.
test_loss, test_metrics = evaluate(
    model=model,
    physiological_tracker=physiological_tracker,
    causal_analyzer=causal_analyzer,
    test_dl=test_loader,
    config=config,
)

In [None]:
# Plot the final state-transition and temporal-understanding analyses.
# The save names must match the filenames the wandb upload cell reads back.
state_analysis = test_metrics['state_analysis']
vis_manager.plot_state_transitions(state_analysis, save_name='final_state_transitions')

temporal_results = test_metrics['temporal_results']
vis_manager.plot_temporal_understanding_analysis(temporal_results, save_name='final_temporal_analysis')

In [None]:
# Log final test results to wandb.
# The loss is taken from `evaluate`'s first return value (`test_loss`) rather than
# re-reading a 'loss' key from `test_metrics`, which evaluate may not populate.
wandb.log({
    "test_accuracy": test_metrics['accuracy'],
    "test_loss": test_loss,
    "test_temporal_consistency": test_metrics['temporal_consistency'],
    "causal_strength": test_metrics['causal_analysis']['mean_strength'],
    "ltc_stability": test_metrics['ltc_analysis']['stability_metric'],
    "temporal_understanding_score": test_metrics['temporal_understanding_score']
})

In [None]:
# Upload the figures written by the plotting cell to wandb as images.
figures_dir = Path(config.VIS_DIR)
figure_files = {
    "state_transitions": 'final_state_transitions.png',
    "temporal_analysis": 'final_temporal_analysis.png',
}
wandb.log({
    key: wandb.Image(str(figures_dir / filename))
    for key, filename in figure_files.items()
})