# POS-EGNN ELoRA

In [1]:
import math
import torch
import torch.nn as nn

from ase import Atoms as ASEAtoms
from ase.io import read
from torch_geometric.data import Data, Batch

import sys
sys.path.append('../../')
from posegnn.model import PosEGNN
from posegnn.adapter import PosEGNNLoRAModel, LoRAConfig

## Create dataset

In [2]:
# Read a single frame (index=0) from the 3BPA structure file via ase.io.read.
atoms = read("../../inputs/3BPA.xyz", index=0)

def build_data_from_ase(atoms: ASEAtoms) -> Data:
    """Pack one ASE structure into a single-graph ``Data`` object.

    Fields written:
      * ``z``     - atomic numbers as a long tensor
      * ``pos``   - atomic positions, cast to float
      * ``box``   - the cell, cast to float, with a leading axis added by ``unsqueeze(0)``
      * ``batch`` - all-zero long tensor with one entry per atom
      * ``num_graphs`` - always 1
    """
    atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)
    n_atoms = len(atomic_numbers)

    # Key order mirrors the keyword order expected by downstream cells.
    fields = {
        "z": atomic_numbers,
        "pos": torch.tensor(atoms.get_positions().tolist()).float(),
        "box": torch.tensor(atoms.get_cell().tolist()).unsqueeze(0).float(),
        "batch": torch.zeros(n_atoms, dtype=torch.long),
        "num_graphs": 1,
    }
    return Data(**fields)

# Convert the structure into a graph and wrap it in a Batch built from that one graph.
data = build_data_from_ase(atoms)
batch = Batch.from_data_list([data])

## Get LoRA model

In [3]:
# Please download checkpoint from https://huggingface.co/ibm-research/materials.pos-egnn
checkpoint_dict = torch.load('../../pytorch_model.bin', weights_only=True, map_location='cpu')
backbone = PosEGNN(checkpoint_dict["config"])
backbone.load_state_dict(checkpoint_dict["state_dict"], strict=True)

cfg = LoRAConfig(
    rank=16,
    alpha=16,                 # uses alpha = rank by default
    dropout=0.0,
    merge_on_save=True,         # saves merged weights for compatibility
    freeze_base=True,           # train only adapters
    log_skipped=True
)
model = PosEGNNLoRAModel(backbone, cfg)
print(model.lora_report())

[lora] skipped post-activation linears:
  - encoder.neighbor_embedding.combine.dense_layers.0
  - encoder.edge_embedding.edge_up.dense_layers.0
  - encoder.gata.0.gamma_s.0
  - encoder.gata.0.gamma_v.0
  - encoder.gata.0.phik_w_ra
  - encoder.gata.0.edge_attr_up.dense_layers.0
  - encoder.gata.1.gamma_s.0
  - encoder.gata.1.gamma_v.0
  - encoder.gata.1.phik_w_ra
  - encoder.gata.1.edge_attr_up.dense_layers.0
  - encoder.gata.2.gamma_s.0
  - encoder.gata.2.gamma_v.0
  - encoder.gata.2.phik_w_ra
  - encoder.gata.2.edge_attr_up.dense_layers.0
  - encoder.gata.3.gamma_s.0
  - encoder.gata.3.gamma_v.0
  - encoder.gata.3.phik_w_ra
  - encoder.eqff.0.gamma_m.0
  - encoder.eqff.1.gamma_m.0
  - encoder.eqff.2.gamma_m.0
  - encoder.eqff.3.gamma_m.0
LoRA injected - scalar layers: 48


## Helpers

In [4]:
def random_SO3(dtype, device, generator=None):
    """Sample a random 3x3 proper rotation matrix (orthogonal, determinant +1).

    Args:
        dtype: torch dtype of the returned matrix.
        device: torch device of the returned matrix.
        generator: optional ``torch.Generator`` used for the Gaussian draw. Pass a
            seeded generator (on the same device) to get a reproducible rotation;
            ``None`` keeps the previous behaviour of using the global RNG.

    Returns:
        Tensor of shape ``[3, 3]``.

    NOTE(review): QR of a Gaussian matrix without a sign correction on R is not
    exactly Haar-uniform; that is sufficient for sanity checks but should be
    confirmed before using this for sampling statistics.
    """
    A = torch.randn(3, 3, generator=generator, dtype=dtype, device=device)
    Q, _ = torch.linalg.qr(A)
    # QR yields an orthogonal Q with det = +/-1; flipping one column turns a
    # reflection (det = -1) into a rotation (det = +1).
    if torch.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]
    return Q

def rotate_atoms_data(data: Data, R: torch.Tensor) -> Data:
    """Build a new ``Data`` whose positions (and box, if present) are rotated by ``R``.

    Positions and box are right-multiplied by ``R.T``. ``z`` and ``batch`` are cloned
    when present, and ``num_graphs`` is carried over. A 2-D box gains a leading axis;
    a box with any other rank keeps its shape. A missing or ``None`` box stays ``None``.
    """
    rot_t = R.T
    rotated_pos = data.pos @ rot_t

    rotated_box = None
    if getattr(data, "box", None) is not None:
        rotated_box = data.box @ rot_t
        # A plain [3, 3] box is promoted to [1, 3, 3].
        if rotated_box.dim() == 2:
            rotated_box = rotated_box.unsqueeze(0)

    z_copy = data.z.clone() if hasattr(data, "z") else None
    batch_copy = data.batch.clone() if hasattr(data, "batch") else None
    return Data(
        z=z_copy,
        pos=rotated_pos,
        box=rotated_box,
        batch=batch_copy,
        num_graphs=getattr(data, "num_graphs", None),
    )

def act_block_lastdim(vec, R, block=3):
    """Right-multiply every length-``block`` chunk of the last axis of ``vec`` by ``R``.

    ``vec`` has shape ``[..., block * k]``. The last axis is split into ``k`` chunks,
    each chunk ``v`` is mapped to ``v @ R``, and the result is returned with the
    original shape of ``vec``. Raises ``AssertionError`` if the last axis is not
    divisible by ``block``.
    """
    original_shape = vec.shape
    width = original_shape[-1]
    assert width % block == 0, "embedding last dim is not a multiple of 3"

    # Split only the last axis: [..., width] -> [..., width // block, block].
    chunked_shape = (*original_shape[:-1], width // block, block)
    chunks = vec.view(chunked_shape)
    rotated = torch.einsum("...bi,ij->...bj", chunks, R)
    return rotated.reshape(original_shape)

@torch.no_grad()
def cosine(a, b, eps=1e-12):
    """Cosine similarity of two same-shaped tensors, treated as flat vectors.

    ``eps`` is added to the product of the norms to avoid division by zero. The
    ratio is clamped to ``[-1, 1]`` and returned as a Python float.
    """
    dot = torch.sum(a * b)
    scale = a.norm() * b.norm() + eps
    similarity = torch.clamp(dot / scale, -1, 1)
    return float(similarity)

# ---------- 0) prep ----------
# Run the checks on CPU with the model in eval mode.
device = torch.device("cpu")
model.eval()
# The trailing ';' stops the notebook from echoing the value returned by .to().
model.to(device);

## Determinism

In [5]:
# Run the same batch through the model twice without gradients and report, for
# every tensor-valued output, the largest absolute difference between the runs.
with torch.no_grad():
    o1, o2 = model(batch), model(batch)

for k, first in o1.items():
    if not torch.is_tensor(first):
        continue
    diff = (first - o2[k]).abs().max().item()
    print(f"[determinism] {k:<20} max|Δ| = {diff:.3e}")

[determinism] embedding_0          max|Δ| = 0.000e+00


## Check LoRA parameters

In [6]:
# how many layers got LoRA
# Collect the names of backbone modules that carry both LoRA factors
# (`lora_A` and `lora_B`), then show the count and the first ten names.
wrapped = [
    name
    for name, m in model.backbone.named_modules()
    if hasattr(m, "lora_A") and hasattr(m, "lora_B")
]
len_wrapped = len(wrapped)
print("wrapped layers:", len_wrapped)
print("\n".join(wrapped[:10]))

wrapped layers: 48
encoder.neighbor_embedding.distance_proj.dense_layers.0
encoder.neighbor_embedding.combine.dense_layers.1
encoder.edge_embedding.edge_up.dense_layers.1
encoder.gata.0.gamma_s.1
encoder.gata.0.q_w
encoder.gata.0.k_w
encoder.gata.0.gamma_v.1
encoder.gata.0.vecq_w
encoder.gata.0.veck_w.0
encoder.gata.0.veck_w.1


In [7]:
# Count the elements of every parameter that still requires a gradient.
total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("total trainable params:", total_trainable)

# Show the name and shape of the first LoRA parameter as a spot check.
first_lora = next(
    ((n, p) for n, p in model.named_parameters() if "lora_A" in n or "lora_B" in n),
    None,
)
if first_lora is not None:
    print(first_lora[0], tuple(first_lora[1].shape))

total trainable params: 504896
backbone.encoder.neighbor_embedding.distance_proj.dense_layers.0.lora_A.weight (16, 64)


## Rotation sanity on embeddings

In [8]:
# Fixed seed so the sampled rotation, and therefore the error printed below,
# is the same on every Restart & Run All.
torch.manual_seed(0)
R = random_SO3(dtype=batch.pos.dtype, device=batch.pos.device)
batch_R = Batch.from_data_list([rotate_atoms_data(data, R)])

# Forward the original and the rotated structure through the same model.
with torch.no_grad():
    out = model(batch)
    out_R = model(batch_R)

if "embedding_0" in out and torch.is_tensor(out["embedding_0"]):
    e = out["embedding_0"]
    eR = out_R["embedding_0"]
    print(f"[embed] shape: {tuple(e.shape)}")

    # Invariance check: the embedding should not change when the input is rotated.
    inv_err = (e - eR).abs().max().item()
    print(f"[embed] invariance max|Δ| = {inv_err:.3e}")
else:
    # Say so explicitly instead of finishing the cell without any output.
    print("[embed] no 'embedding_0' tensor in outputs. Skipping invariance check.")

[embed] shape: (27, 256, 1, 4)
[embed] invariance max|Δ| = 3.338e-06


## Energy invariance and force covariance (if energy available)

In [9]:
# energy_key = None
# for k in ["energy", "y_energy", "E", "total_energy"]:
#     if k in out:
#         energy_key = k
#         break

# if energy_key is not None:
#     # Build fresh tensors with grad for force test
#     d1 = Batch.from_data_list([build_data_from_ase(atoms)])
#     d1.pos.requires_grad_(True)
#     E1 = model(d1)[energy_key]   # scalar

#     d2 = Batch.from_data_list([rotate_atoms_data(build_data_from_ase(atoms), R)])
#     d2.pos.requires_grad_(True)
#     E2 = model(d2)[energy_key]   # scalar

#     # energy invariance
#     e_err = (E2.detach() - E1.detach()).abs().item()
#     print(f"[energy] |E(Rx) - E(x)| = {e_err:.3e}")

#     # F1/F2 below are the gradients dE/dx (forces are their negatives; the sign cancels in this check), covariance: F(Rx) = R F(x)
#     (F1,) = torch.autograd.grad(E1, d1.pos, retain_graph=False)
#     (F2,) = torch.autograd.grad(E2, d2.pos, retain_graph=False)
#     F1_equiv = F1 @ R.T
#     f_err = (F2 - F1_equiv).abs().max().item()
#     print(f"[forces] covariance max|Δ| = {f_err:.3e}")
# else:
#     print("[energy] no energy key found in outputs. Skipping energy/force checks.")

## LoRA merge correctness

In [10]:
import io
import torch

def _as_float_tensor(x):
    return x.detach().to(dtype=torch.float32)

def _compare_outputs(a, b):
    """Return the largest absolute element-wise difference between two model outputs.

    Args:
        a, b: either two output dicts or two tensors. For dicts, only keys present
            in both are compared, in sorted key order.

    Returns:
        ``(max_diff, per_key)`` where ``max_diff`` is a float and ``per_key`` maps
        each compared key to its max absolute difference. For non-dict inputs the
        mapping has the single key ``"_"``. If nothing is comparable, ``max_diff``
        is ``0.0`` and the mapping is empty.

    Dict entries where neither side is a tensor are skipped, matching the
    ``torch.is_tensor`` guards used by the other checks in this notebook. An entry
    where only one side is a tensor still raises, so a type mismatch is not hidden.
    """
    def _max_abs_diff(x, y):
        delta = (_as_float_tensor(x) - _as_float_tensor(y)).abs()
        # `.max()` raises on an empty tensor; with no elements there is no difference.
        return delta.max().item() if delta.numel() > 0 else 0.0

    if isinstance(a, dict) and isinstance(b, dict):
        diffs = {}
        for k in sorted(set(a.keys()) & set(b.keys())):
            if not (torch.is_tensor(a[k]) or torch.is_tensor(b[k])):
                continue
            diffs[k] = _max_abs_diff(a[k], b[k])
        return max(diffs.values()) if diffs else 0.0, diffs

    d = _max_abs_diff(a, b)
    return d, {"_": d}

# Merge round trip: the adapter-enabled model and a plain backbone loaded from the
# exported merged weights should produce the same outputs.
try:
    # 0) make sure dropout is off for determinism
    model.eval()

    # 1) reference with adapters ENABLED (this is what the merged backbone should match)
    model.enable_adapter()
    with torch.no_grad():
        o_ref = model(batch)

    # 2) export merged backbone (plain PosEGNN keys) into an in-memory buffer
    buf = io.BytesIO()
    merged_sd = model.state_dict_backbone(merged=True)
    torch.save(merged_sd, buf)
    buf.seek(0)

    # 3) load merged weights into a fresh plain backbone
    backbone2 = PosEGNN(checkpoint_dict["config"])
    # weights_only=True matches how the original checkpoint is loaded and avoids
    # unpickling arbitrary objects (and the torch.load warning that goes with it).
    # Assumes the exported state dict holds only tensors - TODO confirm.
    state = torch.load(buf, map_location="cpu", weights_only=True)
    missing, unexpected = backbone2.load_state_dict(state, strict=True)
    print(f"[merge] missing: {len(missing)}, unexpected: {len(unexpected)}")  # expect 0, 0

    # 4) numerics: merged backbone vs adapter-enabled model
    backbone2.eval()
    with torch.no_grad():
        o_merge = backbone2(batch)

    max_diff, per_key = _compare_outputs(o_ref, o_merge)
    print(f"[merge] max abs diff = {max_diff:.3e}")
    if isinstance(per_key, dict):
        for k, v in per_key.items():
            print(f"  {k}: {v:.3e}")

except Exception as e:
    # Deliberate best effort: report the failure and let the notebook continue.
    print(f"[merge] test failed: {e}")

[merge] missing: 0, unexpected: 0
[merge] max abs diff = 0.000e+00
  embedding_0: 0.000e+00


  state = torch.load(buf, map_location="cpu")


In [11]:
def find_nonmergeable_wrapped(mod):
    """List LoRA-wrapped layers whose wrapped base layer has a real activation.

    Args:
        mod: object exposing a ``backbone`` module. Its sub-modules are scanned.

    Returns:
        Names (relative to ``mod.backbone``) of modules that have ``base``,
        ``lora_A`` and ``lora_B`` attributes and whose ``base.activation`` is set
        to something other than ``torch.nn.Identity``.
    """
    bad = []
    # Walk the module that was passed in, not the notebook-global `model`.
    for name, m in mod.backbone.named_modules():
        if hasattr(m, "base") and hasattr(m, "lora_A") and hasattr(m, "lora_B"):
            act = getattr(m.base, "activation", None)
            if act is not None and not isinstance(act, torch.nn.Identity):
                bad.append(name)
    return bad

# Report how many LoRA-wrapped layers have a post-activation base and list up to ten.
bad = find_nonmergeable_wrapped(model)
print("[wrapped + post-act]", len(bad))
for n in bad[:10]:
    print(" ", n)

[wrapped + post-act] 0


In [12]:
def audit_nonmergeable(model):
    """Print and return wrapped layers whose base layer has a real activation.

    Scans ``model.backbone`` for modules that have a ``base`` attribute which in turn
    has a ``weight``. A layer is flagged when ``base.activation`` is set and is not a
    ``torch.nn.Identity``. Returns a list of ``(module_name, activation_class_name)``.
    """
    flagged = []
    for layer_name, layer in model.backbone.named_modules():
        base = getattr(layer, "base", None)
        if not hasattr(base, "weight"):
            continue
        act = getattr(base, "activation", None)
        if act is None or isinstance(act, torch.nn.Identity):
            continue
        flagged.append((layer_name, type(act).__name__))

    print("[audit] non-mergeable wrapped layers (have activation):")
    for layer_name, act_name in flagged:
        print(f"  {layer_name}  (activation={act_name})")
    print(f"count: {len(flagged)}")
    return flagged

# The function prints its own report; binding the returned list to `_` keeps the
# notebook from echoing it as cell output.
_ = audit_nonmergeable(model)

[audit] non-mergeable wrapped layers (have activation):
count: 0


## Check requires grad

In [13]:
# Gradient-flow check: after one backward pass, count how many LoRA parameters
# and how many non-LoRA (base) parameters received a gradient.
model.train()
try:
    # Clear every stale gradient, frozen parameters included, so the counts below
    # reflect this backward pass only.
    for _name, p in model.named_parameters():
        p.grad = None

    out_train = model(batch)

    loss = None
    if "embedding_0" in out_train and torch.is_tensor(out_train["embedding_0"]):
        loss = out_train["embedding_0"].pow(2).mean()
    else:
        # fallback: sum of the first floating-point tensor found in the outputs
        for v in out_train.values():
            if torch.is_tensor(v) and v.dtype.is_floating_point:
                loss = v.sum()
                break
    if loss is None:
        # Fail with a clear message instead of calling .backward() on a plain float.
        raise RuntimeError("[grads] no floating-point tensor in model outputs; cannot build a loss")
    loss.backward()
finally:
    # Always return to eval mode, even if the forward or backward pass fails.
    model.eval()

num_lora_grads = 0
num_base_grads = 0
for n, p in model.named_parameters():
    if p.grad is None:
        continue
    if "lora" in n.lower():
        num_lora_grads += 1
    else:
        num_base_grads += 1
print(f"[grads] LoRA params with grad: {num_lora_grads}, base params with grad: {num_base_grads}")

[grads] LoRA params with grad: 74, base params with grad: 0
