In [1]:
import os
import math
import time
import copy
import numpy as np
import gc 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import nonlinear_dagmpn as mpn 

class SequentialMNIST(torch.utils.data.Dataset):
    """
    """
    def __init__(self, root, train=True, download=True, normalize=True):
        tfms = [transforms.ToTensor()]
        self.ds = datasets.MNIST(root=root, train=train, download=download, transform=transforms.Compose(tfms))
        self.normalize = normalize

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        x, y = self.ds[idx]             
        x = x.view(-1)                  
        if self.normalize:
            x = (x - 0.1307) / 0.3081
        x_seq = x.unsqueeze(-1)        
        return x_seq, y


def collate_seq(batch):
    """
    """
    xs, ys = zip(*batch)
    x = torch.stack(xs, dim=0)
    y = torch.tensor(ys, dtype=torch.long)
    return x, y


In [2]:
def print_cuda_tensor_shapes(limit=None, sort_by_numel=True, include_nonleaf=True):
    """
    Prints shapes (and a bit more) for all live torch tensors on CUDA.

    Notes:
    - This lists tensors that are still referenced by Python (reachable by GC).
    - It may include duplicates (views). We de-duplicate by storage data_ptr.
    """
    cuda_tensors = []
    seen = set()

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                t = obj
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                # Parameters and some wrappers
                t = obj.data
            else:
                continue

            if t.is_cuda:
                # de-dup by underlying storage pointer (works for views)
                try:
                    key = (t.untyped_storage().data_ptr(), t.storage_offset(), tuple(t.size()), str(t.dtype))
                except Exception:
                    key = (t.data_ptr(), tuple(t.size()), str(t.dtype))

                if key in seen:
                    continue
                seen.add(key)

                if (not include_nonleaf) and (t.grad_fn is not None):
                    continue

                cuda_tensors.append(t)
        except Exception:
            pass

    if sort_by_numel:
        cuda_tensors.sort(key=lambda x: x.numel(), reverse=True)

    if limit is not None:
        cuda_tensors = cuda_tensors[:limit]

    total_bytes = 0
    for i, t in enumerate(cuda_tensors, 1):
        nbytes = t.numel() * t.element_size()
        total_bytes += nbytes
        print(
            f"[{i:04d}] shape={tuple(t.shape)} dtype={t.dtype} "
            f"device={t.device} requires_grad={t.requires_grad} "
            f"bytes={nbytes/1024**2:.2f}MB"
        )

    print(f"\nCount: {len(cuda_tensors)} tensors")
    print(f"Estimated total (sum of listed tensor sizes): {total_bytes/1024**2:.2f}MB")


In [3]:
@torch.no_grad()
def evaluate(net, loader, device, chunk_size=64):
    net.eval()
    correct = 0
    total = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)  # (B,T,D)
        y = y.to(device, non_blocking=True)

        # Convert to (T,B,D)
        x_TBD = x.transpose(0, 1).contiguous()

        logits_last = net.forward_sequence_checkpointed(
            x_TBD,
            chunk_size=chunk_size,
            Ms0=None,
            run_mode="minimal",
        )

        pred = logits_last.argmax(dim=-1)
        correct += (pred == y).sum().item()
        total += y.numel()

    return correct / max(total, 1)

def count_parameter(net):
    """
    """
    trainable = [(n, p) for n, p in net.named_parameters() if p.requires_grad]
    
    # Print a readable summary
    total = 0
    for n, p in trainable:
        num = p.numel()
        total += num
        print(f"{n:50s}  shape={tuple(p.shape)}  numel={num}")
    
    print(f"\nTotal trainable parameters: {total}")

# Example usage:
# print_cuda_params(net)
from torch.cuda.amp import autocast, GradScaler

def train_sequential_mnist(
    device="cuda",
    data_root="./data",
    hidden_dim=256,
    batch_size=64,
    lr=1.0,
    eta=1e-3,
    epochs=5,
    chunk_size=28,
    use_amp=False,
):
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    # Data
    train_ds = SequentialMNIST(root=data_root, train=True, download=True, normalize=True)
    test_ds  = SequentialMNIST(root=data_root, train=False, download=True, normalize=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
                              pin_memory=True, collate_fn=collate_seq)
    test_loader  = DataLoader(test_ds, batch_size=batch_size*10, shuffle=False, num_workers=2,
                              pin_memory=True, collate_fn=collate_seq)

    net_params = {
        "n_neurons": [1] + [hidden_dim] + [10],
        "linear_embed": 100,
        "dt": 1.0,
        "activation": "tanh",
        "output_bias": True,
        "W_output_init": "xavier",
        "input_layer_add": True,
        'input_layer_add_trainable': True,
        'input_layer_bias': False,
        "output_matrix": "",

        'ml_params': {
            'bias': True,
            'mp_type': 'mult',
            'm_update_type': 'hebb_assoc',
            'eta_type': 'scalar',
            'eta_train': False,
            'lam_type': 'scalar',
            'm_time_scale': 1000,
            'lam_train': False,
            'W_freeze': False,
        },
    }
    net_params.update({
        "dag_W_gain": 0.9,
        "dag_W_diag_decay": 0.01,
        "dag_row_l1_target": 0,
    })
    net_params["ml_params"].update({
        "m_time_scale": 1000,  # can override / ensure present
        "eta": eta,
        # optionally: "eta": 1e-3,  # if you want to hard-set eta instead of eta0
    })
    
    net = mpn.DeepMultiPlasticNet(net_params, verbose=True, forzihan=False).to(device)

    count_parameter(net)

    opt = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=lr)
    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", factor=0.5, patience=2, threshold=1e-3, min_lr=1e-6, verbose=True
    )

    scaler = GradScaler(enabled=(use_amp and device.type == "cuda"))

    stats = {"test_acc": []}

    for ep in range(1, epochs + 1):
        net.train()
        t0 = time.time()
        running_loss = 0.0
        n_batches = 0

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)   # (B,T,1)
            y = y.to(device, non_blocking=True)

            # Convert to (T,B,D) for forward_sequence_checkpointed
            x_TBD = x.transpose(0, 1).contiguous()

            opt.zero_grad(set_to_none=True)

            # Forward with time checkpointing
            with autocast(enabled=(use_amp and device.type == "cuda")):#autocast(device_type="cuda", enabled=(use_amp and device.type == "cuda")):
                logits_last = net.forward_sequence_checkpointed(
                    x_TBD,
                    chunk_size=chunk_size,
                    Ms0=None,
                    run_mode="minimal",
                )
                loss = criterion(logits_last, y)

            # Backward
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                print("grad W_out:", None if net.W_out.grad is None else net.W_out.grad.norm().item())
                print("grad W:",     None if net.W.grad is None else net.W.grad.norm().item())
                print("grad W_in:", None if net.W_in.grad is None else net.W_in.grad.norm().item())
                opt.step()

            net.param_clamp()

            # Logging: 
            if (n_batches % 2) == 0:
                with torch.no_grad():
                    W0 = net.mp_layers[0].W.detach()# add .detach()? 
                    W_abs_mean = W0.abs().mean().item()

                print(f"[ep {ep} | batch {n_batches}] loss={loss.item():.4f} "
                      f"|M| |W| mean={W_abs_mean:.3e}")
                torch.cuda.synchronize()
                alloc = torch.cuda.memory_allocated() / 1024**3
                reserv = torch.cuda.memory_reserved() / 1024**3
                print(f"ep {ep} batch {n_batches} | loss={loss.item():.4f} | alloc={alloc:.2f}GB reserv={reserv:.2f}GB", flush=True)

            running_loss += float(loss.item())
            n_batches += 1
            del logits_last, loss

        train_loss = running_loss / max(n_batches, 1)
        test_acc = evaluate(net, test_loader, device=device, chunk_size=max(chunk_size, 64))
        scheduler.step(test_acc)

        dt = time.time() - t0
        current_lr = opt.param_groups[0]["lr"]
        print(f"Epoch {ep:02d}/{epochs} | lr={current_lr:.2e} | loss={train_loss:.4f} | test_acc={test_acc*100:.2f}% | dt={dt:.1f}s")

        stats["test_acc"].append(test_acc)

    return net, net_params, stats


In [4]:
trained_net, params = train_sequential_mnist(
                device="cuda",
                hidden_dim=300,
                batch_size=640,
                eta=0.9,
                lr=1e-3,
                epochs=20)

device: cuda
[Nonlinear DAG-MPN per-sample M] d_in=1, d_h=300, d_out=10, lam=0.999000, eta=0.9, h_mode=blockwise, h_block=64
W_in                                                shape=(300, 1)  numel=300
W                                                   shape=(300, 300)  numel=90000
W_out                                               shape=(10, 300)  numel=3000
b_out                                               shape=(10,)  numel=10

Total trainable parameters: 93310
grad W_out: 0.030221162363886833
grad W: 0.03092757798731327
grad W_in: 0.03363049775362015
[ep 1 | batch 0] loss=2.3043 |M| |W| mean=1.358e-02
ep 1 batch 0 | loss=2.3043 | alloc=0.01GB reserv=12.23GB
grad W_out: 0.014759446494281292
grad W: 0.014993693679571152
grad W_in: 0.016082575544714928
grad W_out: 0.031168051064014435
grad W: 0.03255341202020645
grad W_in: 0.03271140530705452
[ep 1 | batch 2] loss=2.3057 |M| |W| mean=1.362e-02
ep 1 batch 2 | loss=2.3057 | alloc=0.01GB reserv=12.24GB
grad W_out: 0.0235968045890331

ValueError: too many values to unpack (expected 2)