In [7]:
import argparse
import os
import sys
import time

import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import geoopt

# If your mps/ package is local, ensure it’s on sys.path or installed in editable mode:
sys.path.append("../../")

from mps.simple_mps import SimpleMPS
from mps.tpcp_mps import MPSTPCP, ManifoldType
from mps.StiefelOptimizers import StiefelAdam, StiefelSGD
from mps.radam import RiemannianAdam

# Enable inline plotting in Jupyter
%matplotlib inline


In [8]:
class SyntheticDataset(torch.utils.data.Dataset):
    """
    Each sample is an n-dimensional vector where the first element is either 0 or 1,
    and the remaining n-1 entries are 0. The label is the first element.
    We embed each scalar x to a 2-d vector [x, 1-x] and then L2-normalize.
    """
    def __init__(self, n: int, num_samples: int = 10000, seed: int | None = None):
        """
        Args:
            n: Dimension of the vector.
            num_samples: Number of samples in the dataset.
            seed: Random seed (optional).
        """
        self.n = n
        self.num_samples = num_samples
        if seed is not None:
            np.random.seed(seed)
        # Randomly assign labels (0 or 1) for each sample
        self.labels = np.random.randint(0, 2, size=num_samples).astype(np.int64)
    
    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # The label is either 0 or 1
        l = self.labels[index]
        # Create a vector of length n with first element l and rest zeros
        x = torch.zeros(self.n, dtype=torch.float64)
        x[0] = float(l)
        # Embed each scalar: map x -> [x, 1-x]
        x_embedded = torch.stack([x, 1 - x], dim=-1)  # shape: [n, 2]
        # Normalize each site so that the two entries sum to 1
        x_embedded = x_embedded / (x_embedded.sum(dim=-1, keepdim=True).clamp(min=1e-8))
        # The target label is l
        return x_embedded, torch.tensor(l, dtype=torch.int64)


In [9]:
def loss_batch(outputs, labels):
    """
    Summed binary cross-entropy style loss for outputs in [0, 1].

    ``outputs[i]`` is read as the probability of label 0. For label 0 the
    per-sample term is ``-log(outputs[i])``; for any other label it is
    ``-log(1 - outputs[i])``. A 1e-8 offset guards against log(0).

    Args:
        outputs: 1-D tensor of per-sample probabilities of label 0.
        labels: 1-D tensor (or sequence) of labels, same length as ``outputs``.

    Returns:
        float64 tensor of shape ``[1]`` holding the SUM (not the mean) of the
        per-sample terms. It is differentiable w.r.t. ``outputs``.
    """
    device = outputs.device
    labels = torch.as_tensor(labels, device=device)
    # Vectorized selection instead of a per-sample Python loop, which forced a
    # host/device sync for every `labels[i] == 0` and `isnan` check.
    prob = torch.where(labels == 0, outputs, 1 - outputs)
    # Accumulate in float64, as the per-sample loop did.
    per_sample = -torch.log(prob + 1e-8).to(torch.float64)
    nan_mask = torch.isnan(per_sample)
    if nan_mask.any():
        # Report only the first offending sample instead of every later index.
        i = int(nan_mask.nonzero()[0])
        print(f"NaN in loss at index {i}, prob={prob[i]}, output={outputs[i]}, label={labels[i]}")
    return per_sample.sum().reshape(1)

def calculate_accuracy(outputs, labels):
    """
    Fraction of samples whose thresholded prediction equals the true label.

    ``outputs`` is read as the probability of label 0 (the same convention as
    ``loss_batch``): a value below 0.5 predicts label 1, anything else label 0.

    Returns:
        Scalar float tensor: number of correct predictions / number of labels.
    """
    predicted_label = (outputs < 0.5).float()
    is_correct = predicted_label.eq(labels)
    n_correct = is_correct.float().sum()
    return n_correct / labels.numel()


In [10]:
# Configuration for debugging
n = 100           # Use a smaller input dimension for faster runs
num_data = 1000   # Smaller dataset for quick debugging
batch_size = 64
data_seed = 42    # Seeds both the label draw and the DataLoader shuffle order

synthetic_dataset = SyntheticDataset(n=n, num_samples=num_data, seed=data_seed)
# A dedicated, seeded generator makes the batch order reproducible across
# kernel restarts without touching torch's global RNG.
loader_generator = torch.Generator().manual_seed(data_seed)
dataloader = torch.utils.data.DataLoader(
    synthetic_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=loader_generator,
)

print("Dataset prepared:")
print("Number of samples:", len(synthetic_dataset))
print("Batch size:", batch_size)


Dataset prepared:
Number of samples: 1000
Batch size: 64


In [11]:
# Training a SimpleMPS model (demo)
N = n
d = l = 2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
smps_epochs = 20  # Number of SimpleMPS training epochs (debug setting)

smps = SimpleMPS(
    N,
    2,  # NOTE(review): presumably the number of output classes — confirm against SimpleMPS
    d,
    l,
    layers=2,
    device=device,
    dtype=torch.float64,
    optimize="greedy",
)
logsoftmax = torch.nn.LogSoftmax(dim=-1)
nnloss = torch.nn.NLLLoss(reduction="mean")
opt_smps = torch.optim.Adam(smps.parameters(), lr=0.001)

smps_losses = []
smps.train()
# The message is derived from the same constant the loop uses, so it cannot drift.
print(f"\n=== Training SimpleMPS for {smps_epochs} epochs (debug mode)... ===")
for epoch in range(smps_epochs):
    total_loss_smps = 0.0
    total_samples_smps = 0
    total_correct_smps = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        # DataLoader yields [batch, n, 2]; reorder to [n, batch, 2] for SimpleMPS.
        data = data.permute(1, 0, 2)
        opt_smps.zero_grad()
        log_probs = logsoftmax(smps(data))  # shape: [batch_size, 2]
        loss = nnloss(log_probs, target)
        loss.backward()
        opt_smps.step()

        bs = target.size(0)
        # NLLLoss(reduction="mean") is a per-sample mean, so weight by batch size.
        total_loss_smps += loss.item() * bs
        total_samples_smps += bs

        preds = log_probs.argmax(dim=-1)
        total_correct_smps += (preds == target).float().sum().item()

    epoch_loss = total_loss_smps / total_samples_smps
    epoch_accuracy = total_correct_smps / total_samples_smps
    smps_losses.append(epoch_loss)
    print(f"[SimpleMPS] Epoch {epoch+1} | Loss: {epoch_loss:.6f} | Accuracy: {epoch_accuracy:.2%}")

print("SimpleMPS training complete.")


Path is not set, setting...
Found the path
Initialized MPS with random matrices

=== Training SimpleMPS for 3 epochs (debug mode)... ===
[SimpleMPS] Epoch 1 | Loss: 0.693003 | Accuracy: 51.00%
[SimpleMPS] Epoch 2 | Loss: 0.688807 | Accuracy: 51.00%
[SimpleMPS] Epoch 3 | Loss: 0.622389 | Accuracy: 54.60%
[SimpleMPS] Epoch 4 | Loss: 0.444156 | Accuracy: 51.00%
[SimpleMPS] Epoch 5 | Loss: 0.421245 | Accuracy: 54.50%
[SimpleMPS] Epoch 6 | Loss: 0.389004 | Accuracy: 54.10%
[SimpleMPS] Epoch 7 | Loss: 0.355695 | Accuracy: 72.50%
[SimpleMPS] Epoch 8 | Loss: 0.507542 | Accuracy: 51.00%
[SimpleMPS] Epoch 9 | Loss: 0.416707 | Accuracy: 57.50%
[SimpleMPS] Epoch 10 | Loss: 0.484801 | Accuracy: 51.00%
[SimpleMPS] Epoch 11 | Loss: 0.400196 | Accuracy: 67.90%
[SimpleMPS] Epoch 12 | Loss: 0.352454 | Accuracy: 77.20%
[SimpleMPS] Epoch 13 | Loss: 0.394807 | Accuracy: 69.70%
[SimpleMPS] Epoch 14 | Loss: 0.426074 | Accuracy: 58.90%
[SimpleMPS] Epoch 15 | Loss: 0.330534 | Accuracy: 100.00%
[SimpleMPS] Epoc

In [52]:
# Reload tpcp_mps so edits to mps/tpcp_mps.py are picked up without restarting
# the kernel. After the reload, go through the module (tpcp_mps.MPSTPCP,
# tpcp_mps.ManifoldType): names bound earlier with `from ... import` still
# refer to the pre-reload objects.
from mps import tpcp_mps
from importlib import reload

reload(tpcp_mps)

# Build the TPCP model on the EXACT manifold, move it to `device`, switch it
# to training mode, then initialise it from the trained SimpleMPS.
exact_manifold = tpcp_mps.ManifoldType.EXACT
tpcp = tpcp_mps.MPSTPCP(
    N,
    K=1,
    d=2,
    with_pros=False,
    with_identity=True,
    manifold=exact_manifold,
)
tpcp.to(device)
tpcp.train()
tpcp.set_canonical_mps(smps, set_r=True)
print("TPCP model built and canonical form set from SimpleMPS.")

# Only the Kraus-operator parameters are handed to the Riemannian optimizer.
kraus_params = tpcp.kraus_ops.parameters()
opt_tpcp = RiemannianAdam(kraus_params, lr=0.01)
print("TPCP optimizer set up.")


TPCP model built and canonical form set from SimpleMPS.
TPCP optimizer set up.


In [62]:
# Initial TPCP weights: column 0 is 1 and column 1 is 0 at every site, i.e. a
# weight rate W[:, 1] / W.sum(1) of 0. The tensor is created on `device`, the
# device the model was moved to, so it matches the model's parameters.
W = torch.zeros(tpcp.L, 2, dtype=torch.float64, device=device)
W[:, 0] = 1
W[:, 1] = 0
tpcp.initialize_W(W)

def accuracy(outputs, target):
    """
    Fraction of samples whose sign-thresholded prediction matches ``target``.

    A sample is predicted as label 1 when ``outputs < 0`` and as label 0
    otherwise.

    NOTE(review): this threshold (0) differs from the 0.5 threshold used by
    ``calculate_accuracy``, and this function is not called in the visible
    cells. Confirm which kind of outputs (raw scores vs. probabilities) it is
    meant for.
    """
    predicted = (outputs < 0).float()
    matches = predicted == target.float()
    return matches.float().sum() / target.numel()

def to_probs(outputs):
    """
    Normalize scores along the last axis so that each row sums to 1.

    Args:
        outputs: tensor of scores; the last axis is normalized.

    Returns:
        Tensor of the same shape as ``outputs`` with each last-axis slice
        divided by its sum. No guard against a zero sum.
    """
    totals = outputs.sum(dim=-1).unsqueeze(-1)  # trailing axis restored for broadcasting
    return outputs / totals

# Sanity check on one batch: evaluate the freshly initialised TPCP model and
# report the accuracy of the normalized class-0 probability.
data, target = next(iter(dataloader))
# Move the batch to the device the model was moved to (as the SimpleMPS cell
# does); otherwise CPU inputs would be mixed with a CUDA model when available.
data, target = data.to(device), target.to(device)

out = tpcp(data)
probs = to_probs(out)

calculate_accuracy(probs[:, 0], target)

tensor(1.)

In [None]:
# Overfit sanity check: take a few optimizer steps on the single batch drawn
# above. Each step runs backward() and opt_tpcp.step(), so the loss should
# change between iterations; a steadily decreasing loss indicates that
# gradients reach the Kraus operators.
max_epochs = 5
overfit_losses = []
batch_x, batch_y = data.to(device), target.to(device)
for step in range(max_epochs):
    opt_tpcp.zero_grad()
    out = tpcp(batch_x)
    probs = to_probs(out)
    loss = loss_batch(probs[:, 0], batch_y)  # summed over the batch
    loss.backward()
    opt_tpcp.step()
    # Same re-projection that the main training loop applies after each step.
    tpcp.proj_stiefel(check_on_manifold=True, print_log=False, rtol=1e-3)
    overfit_losses.append(loss.item())
    print(f"[TPCP overfit] Step {step + 1}/{max_epochs} | Loss: {loss.item():.6f}")


In [None]:
# Configuration for TPCP training (using fewer epochs and a subset of weight values for debugging)
max_epochs = 5
min_epochs = 2
conv_threshold = 1e-4  # stop once the per-sample epoch loss changes by less than this
log_steps = 5

w_values = [0, 0.5, 1]  # Use a few weight rates for debugging

# NOTE(review): only W is re-initialised for each w_. The Kraus operators and the
# optimizer state carry over from the previous w_ run, so later runs do not start
# from the same point. Confirm this is intended.
for w_ in w_values:
    # Re-initialize weight parameter W for current w_
    W = torch.zeros(tpcp.L, 2, dtype=torch.float64, device=device)
    W[:, 0] = 1
    W[:, 1] = w_
    tpcp.initialize_W(W)

    print(f"\n=== Training TPCP with initial weight rate w={w_:.1f} (max_epochs={max_epochs}) ===")
    epoch = 0
    prev_epoch_loss = None

    while epoch < max_epochs:
        epoch_loss_sum = 0.0
        epoch_acc_sum = 0.0
        total_samples = 0
        t0 = time.time()
        for step, (data, target) in enumerate(dataloader):
            data = data.to(device)
            target = target.to(device)
            bs = target.size(0)

            opt_tpcp.zero_grad()
            outputs = tpcp(data)
            if torch.isnan(outputs).any():
                print(f"NaN detected in outputs at step {step}, skipping batch.")
                continue

            # tpcp(data) yields per-class scores (presumably [batch, 2]), which the
            # single-batch check normalizes with to_probs. loss_batch and
            # calculate_accuracy expect the 1-D probability of label 0.
            probs0 = to_probs(outputs)[:, 0]
            loss_val = loss_batch(probs0, target)  # summed over the batch
            if torch.isnan(loss_val).any():
                print("NaN detected in batch loss, skipping batch.")
                continue

            loss_val.backward()
            opt_tpcp.step()
            tpcp.proj_stiefel(check_on_manifold=True, print_log=False, rtol=1e-3)

            # loss_batch already sums over the batch, so no extra `* bs` here.
            epoch_loss_sum += loss_val.item()
            total_samples += bs
            batch_acc = calculate_accuracy(probs0.detach(), target)
            epoch_acc_sum += batch_acc.item() * bs

            if (step + 1) % log_steps == 0:
                # Report the per-sample mean so it is comparable with Avg Loss.
                print(f"[TPCP::w={w_:.1f}] Epoch {epoch+1}, Step {step+1}/{len(dataloader)} | Batch Loss: {loss_val.item() / bs:.6f} | Batch Accuracy: {batch_acc.item():.2%}")

        if total_samples == 0:
            print("No samples processed in this epoch, skipping.")
            # Still count the epoch; otherwise a run where every batch is
            # skipped would loop forever.
            epoch += 1
            continue

        avg_loss = epoch_loss_sum / total_samples
        avg_acc = epoch_acc_sum / total_samples
        current_weight_rate = (tpcp.W[:, 1] / torch.sum(tpcp.W, dim=1)).mean().item() if hasattr(tpcp, "W") else w_

        elapsed = time.time() - t0
        print(f"[TPCP::w={w_:.1f}] Epoch {epoch+1} | Avg Loss: {avg_loss:.6f}, Avg Acc: {avg_acc:.2%}, Weight Rate: {current_weight_rate:.4f} | Time: {elapsed:.2f}s")

        if epoch >= min_epochs and prev_epoch_loss is not None:
            if abs(avg_loss - prev_epoch_loss) < conv_threshold:
                print(f"Convergence reached: Loss change below threshold for epoch {epoch+1}.")
                break

        prev_epoch_loss = avg_loss
        epoch += 1

print("\nTPCP training complete.")
