# Chapter 01: IsingGPT – Transformer Learns Phase Transitions


We train a Transformer on equilibrium samples from the 1D Ising model and show that it spontaneously discovers the Boltzmann distribution, nearest-neighbor spin correlations, and phase-transition behavior — without ever seeing the Hamiltonian.


In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import physai

# Use STIX fonts for both regular text and math text.
plt.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
})

# Seaborn theme: plain white background, notebook-sized elements, larger fonts.
sns.set_style("white")
sns.set_context("notebook", font_scale=1.3)

### 1. Generate the data

We generate $N$ samples, each of which is a spin chain of length $L$. Each spin direction in $\{-1, +1\}$ is then represented by a token in $\{0, 1\}$.


In [None]:
# Number of spin chains drawn per temperature.
N = 15_000

# One ensemble per temperature (low, mid, high), all generated with the same
# sample count and the same number of equilibration steps.
low_T, mid_T, high_T = (
    physai.generate_ising_samples(temp=temp, n_samples=N, equilibration_steps=900)
    for temp in (0.5, 1.0, 3.0)
)

In [None]:
def plot_chains(samples: torch.Tensor, title: str, n_show: int = 10) -> None:
    """Draw the first few spin chains as a heat map, one row per sample.

    Args:
        samples: Spin configurations; expected to be 2-D with one chain per
            row and values in [-1, +1] (the color scale is pinned to that range).
        title: Figure title.
        n_show: Maximum number of chains (rows) to display. If ``samples``
            holds fewer rows, all of them are shown.
    """
    # Matplotlib needs host memory, so move the tensor off any accelerator
    # (e.g. MPS) before converting it to NumPy.
    data = samples[:n_show].cpu().numpy()
    n_rows = data.shape[0]

    plt.figure(figsize=(12, 5))

    # Diverging colormap: red = +1 (up), blue = -1 (down).
    im = plt.imshow(
        data,
        cmap='RdBu_r',
        aspect='auto',
        vmin=-1, vmax=1,
        interpolation='nearest'
    )

    plt.title(title, fontsize=18, pad=20)
    plt.xlabel('Site index (position along chain)', fontsize=14)
    plt.ylabel('Sample index', fontsize=14)

    # Colorbar labelled with the physical meaning of each extreme.
    cbar = plt.colorbar(im, ticks=[-1, 0, 1], shrink=0.8)
    cbar.ax.set_yticklabels(['↓ (-1)', '0', '↑ (+1)'])
    cbar.set_label('Spin', rotation=270, labelpad=20, fontsize=14)

    # One tick per displayed row; derived from the data so the axis stays
    # correct when fewer than ``n_show`` samples are available.
    plt.yticks(range(n_rows))
    plt.xticks(fontsize=12)

    # Remove spines for a cleaner look.
    sns.despine(left=True, bottom=True)

    plt.tight_layout()
    plt.show()


In [None]:
# Visualize the two extremes of the sampled temperature range.
plot_chains(
    samples=low_T,
    title="1D Ising Model – Low Temperature (T = 0.5)\nLarge ferromagnetic domains",
)
plot_chains(
    samples=high_T,
    title="1D Ising Model – High Temperature (T = 3.0)\nComplete paramagnetic disorder",
)

### 2. Transfer Matrix Method

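The transfer-matrix method gives the exact analytical solution of the 1D Ising model. In zero field and in the thermodynamic limit, the spin–spin correlation is $\langle \sigma_0 \sigma_r \rangle = \tanh^{|r|}(\beta J)$.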

In [None]:
def exact_spin_correlation_1d(L: int = 32, temp: float = 1.0, J: float = 1.0) -> np.ndarray:
    """Return the exact zero-field spin-spin correlation of the 1D Ising chain.

    Evaluates ``<s_0 s_r> = tanh(J / temp) ** |r|`` on a window of ``L``
    separations centered on ``r = 0``.

    Args:
        L: Number of separations to evaluate.
        temp: Temperature in units where k_B = 1; must be non-zero.
        J: Nearest-neighbor coupling constant.

    Returns:
        A 1-D NumPy array of length ``L``. Entry ``i`` holds the correlation at
        separation ``r = i - L // 2`` (so r = -16, ..., 15 for the default L).
    """
    beta = 1.0 / temp
    # Signed separations centered on r = 0. The correlation depends only on
    # |r|, so negative separations must mirror the positive ones; rolling a
    # one-sided tanh**r array would instead place tanh**(L + r) at negative r.
    separations = torch.arange(L) - L // 2
    corr = torch.tanh(torch.tensor(beta * J)) ** separations.abs()
    return corr.cpu().numpy()

# Plot the exact correlation function for several temperatures on one figure.
plt.figure(figsize=(11, 6.5))
temps = [0.4, 0.6, 0.8, 1.0, 1.5, 3.0]

# Number of separations shown; the x-axis is derived from it so the plotted
# distances always match the length of the array returned below.
L_plot = 32
distances = range(-(L_plot // 2), L_plot - L_plot // 2)

# One viridis shade per temperature. The palette was previously computed but
# never passed to plt.plot, so the curves used the default color cycle.
colors = plt.cm.viridis(torch.linspace(0.9, 0.1, len(temps)).tolist())

for T, color in zip(temps, colors):
    corr = exact_spin_correlation_1d(L=L_plot, temp=T)
    plt.plot(distances, corr, 'o-', label=rf'$T = {T}$', color=color,
             markersize=7, lw=2.5)

plt.axhline(0, color='k', lw=0.8, alpha=0.4)
plt.title(r"Exact Spin-Spin Correlation $\langle \sigma_0 \sigma_r \rangle$"
          r" in the 1D Ising Model ($h=0$)", fontsize=20, pad=25)
plt.xlabel(r"Distance $r$", fontsize=16)
plt.ylabel(r"Correlation $\langle \sigma_0 \sigma_r \rangle$", fontsize=16)
plt.legend(fontsize=13, frameon=False)
plt.grid(True, alpha=0.3)
sns.despine()
plt.tight_layout()
plt.show()

### 3. Dataset Preparation – From Spins to Tokens

Now we use the data we have generated to prepare the dataset.

In [None]:
class IsingDataset(Dataset):
    """Autoregressive next-token dataset built from ±1 spin chains."""

    def __init__(self, spins: torch.Tensor):
        """
        Map spins to binary tokens (-1 → 0, +1 → 1) and store them as int64.

        Args:
            spins: Tensor of ±1 spin values, one chain per row.
        """
        shifted = spins + 1              # -1/+1 → 0/2
        self.tokens = (shifted // 2).long()

    def __len__(self):
        """Number of chains in the dataset."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)``: the chain without its last token, and
        the same chain shifted left by one position."""
        chain = self.tokens[idx]
        inputs = chain[:-1]
        targets = chain[1:]
        return inputs, targets

In [None]:
# Pool the three temperature ensembles into a single training set.
data = torch.cat((low_T, mid_T, high_T), dim=0)
dataset = IsingDataset(data)
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True,   # discard the final incomplete batch
)

# Quick sanity summary of what was built.
print(f"Dataset size: {len(dataset):,}")
print(f"Vocabulary size: {int(dataset.tokens.max()) + 1}")
print(f"Example chain (tokens): {dataset[0][0]}")

In [None]:
# Concatenate the low / mid / high temperature samples along dim 0, in that order.
# NOTE(review): this duplicates the concatenation in the dataset-creation cell —
# presumably so the validation cell can be run on its own; confirm or drop.
data = torch.cat([low_T, mid_T, high_T], dim=0)

In [None]:
# Validate each temperature's slice of the concatenated data.
# `data` holds N chains per temperature, stacked in the same order as `temps`.
temps = [0.5, 1.0, 3.0]
for i, T in enumerate(temps):
    # Three-line banner, emitted with a single print call.
    print(f"\n{'='*60}", f"Validating T = {T}", f"{'='*60}", sep="\n")

    samples = data[slice(i * N, (i + 1) * N)]
    results = physai.validate_ising_samples(samples, temp=T, plot=True)
    physai.print_validation_report(results)

### 3. Build the Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MLP.

    Args:
        d_model: Width of the residual stream (embedding dimension).
        n_head: Number of attention heads; ``nn.MultiheadAttention`` requires
            it to divide ``d_model``.
    """

    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model,
            n_head,
            dropout=0.0,
            batch_first=True
        )
        self.ln2 = nn.LayerNorm(d_model)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, d_model).

        Position ``i`` attends only to positions ``<= i``, so the output at a
        position never depends on later tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        t = x.size(1)
        # For a boolean attn_mask, nn.MultiheadAttention treats True as
        # "NOT allowed to attend". The strictly upper triangle (future
        # positions) is therefore the part that must be True. Using the lower
        # triangle instead would hide the past and expose the future.
        causal_mask = torch.triu(
            torch.ones(t, t, device=x.device, dtype=torch.bool),
            diagonal=1
        )

        # Pre-norm + causal self-attention, added back to the residual stream.
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(
            x_norm, x_norm, x_norm,
            attn_mask=causal_mask,
            need_weights=False
        )
        x = x + attn_out

        # Pre-norm + MLP, added back to the residual stream.
        x = x + self.mlp(self.ln2(x))

        return x


In [None]:
class IsingTransformer(nn.Module):
    """Decoder-only transformer over spin tokens.

    Args:
        d_model: Embedding / residual-stream width.
        n_head: Number of attention heads per block.
        n_layer: Number of stacked ``TransformerBlock`` layers.
        block_size: Maximum sequence length (size of the learned positional
            embedding table).
        vocab_size: Number of distinct tokens; 2 for binary spins.
    """

    def __init__(self, d_model: int = 128, n_head: int = 8, n_layer: int = 6,
                 block_size: int = 31, vocab_size: int = 2):
        super().__init__()
        self.block_size = block_size
        self.d_model = d_model
        self.n_head = n_head
        self.vocab_size = vocab_size

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)

        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_head) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token tensor of shape (batch, seq) with
                ``seq <= block_size``.
            targets: Optional integer tensor with the same shape as ``idx``
                holding the token to predict at each position.

        Returns:
            ``logits`` of shape (batch, seq, vocab_size) when ``targets`` is
            None, otherwise the tuple ``(logits, loss)``.

        Raises:
            ValueError: If the sequence is longer than ``block_size``.
        """
        b, t = idx.shape
        # Fail early with a clear message instead of an opaque index error
        # from the positional-embedding lookup.
        if t > self.block_size:
            raise ValueError(
                f"Sequence length {t} exceeds block_size {self.block_size}"
            )

        # Token + position embeddings; positions broadcast over the batch.
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)

        if targets is None:
            return logits

        # Flatten (batch, seq) so every position is one classification example.
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss



### 4. Training

In [None]:
# Pick the best available backend instead of assuming Apple-silicon MPS, so
# the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = IsingTransformer(d_model=128, n_head=8, n_layer=6).to(device)

# Train on the temperature-mixed loader; the returned history is kept for
# later inspection.
history = physai.train_ising_transformer(
    model=model,
    loader=loader,
    num_steps=1_000,
    lr=1e-3,
    weight_decay=0.01,
    device=device,
    plot=True
)

### 5. Check That the Model Has Really Learned the Physics

In [None]:
model.eval()
# Build every probe on the device the model actually lives on, rather than
# assuming 'mps'.
device = next(model.parameters()).device

with torch.no_grad():
    # Test 1: fully ordered chain (low-temperature-like), all zeros
    test_ordered = torch.tensor([[0] * 31], device=device)
    logits_ordered = model(test_ordered)
    probs_ordered = F.softmax(logits_ordered, dim=-1)

    # Test 2: disordered chain (high-temperature-like)
    test_random = torch.tensor([[0, 1, 0, 1, 1, 0, 1, 0, 0, 1] * 3 + [0]], device=device)
    logits_random = model(test_random)
    probs_random = F.softmax(logits_random, dim=-1)

    print("="*60)
    print("TEST 1: Ordered sequence (all 0s, low-T like)")
    print("="*60)
    print("Last position prediction:")
    print(f"  P(next=0) = {probs_ordered[0, -1, 0].item():.4f}")
    print(f"  P(next=1) = {probs_ordered[0, -1, 1].item():.4f}")
    print("Expected: P(0) >> P(1) due to ferromagnetic correlation")

    print("\n" + "="*60)
    print("TEST 2: Random sequence (alternating, high-T like)")
    print("="*60)
    print("Last position prediction:")
    print(f"  P(next=0) = {probs_random[0, -1, 0].item():.4f}")
    print(f"  P(next=1) = {probs_random[0, -1, 1].item():.4f}")
    print("Expected: closer to 0.5/0.5 (weak correlation)")

    # Test 3: nearest-neighbor bias
    print("\n" + "="*60)
    print("TEST 3: Nearest-neighbor correlation")
    print("="*60)

    # Context of all 0s: probability that the next token is also 0
    context_0 = torch.tensor([[0]*15], device=device)
    p_next_0 = F.softmax(model(context_0), dim=-1)[0, -1, 0].item()

    # Context of all 1s: probability that the next token is also 1
    context_1 = torch.tensor([[1]*15], device=device)
    p_next_1 = F.softmax(model(context_1), dim=-1)[0, -1, 1].item()

    print("After seeing 0000...:")
    print(f"  P(next=0) = {p_next_0:.4f}")
    print("\nAfter seeing 1111...:")
    print(f"  P(next=1) = {p_next_1:.4f}")

    # Verdict: strong alignment in both directions, partial alignment, or none.
    if p_next_0 > 0.7 and p_next_1 > 0.7:
        print("\n✓✓✓ Model learned ferromagnetic correlation!")
        print("    (spins prefer to align with neighbors)")
    elif p_next_0 > 0.6 or p_next_1 > 0.6:
        print("\n✓ Model learned some correlation structure")
    else:
        print("\n⚠️ Model hasn't learned clear correlations")