In [1]:
%cd drive/MyDrive/CGformer

/content/drive/MyDrive/CGformer


In [2]:
!pip install torch torchvision torchaudio



In [3]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m66.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [4]:
!pip install pymatgen

Collecting pymatgen
  Downloading pymatgen-2025.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Collecting bibtexparser>=1.4.0 (from pymatgen)
  Downloading bibtexparser-1.4.3.tar.gz (55 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/55.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.6/55.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting monty>=2025.1.9 (from pymatgen)
  Downloading monty-2025.3.3-py3-none-any.whl.metadata (3.6 kB)
Collecting palettable>=3.3.3 (from pymatgen)
  Downloading palettable-3.3.3-py2.py3-none-any.whl.metadata (3.3 kB)
Collecting ruamel.yaml>=0.17.0 (from pymatgen)
  Downloading ruamel_yaml-0.19.1-py3-none-any.whl.metadata (16 kB)
Collecting spglib>=2.5 (from pymatgen)
  Downloading spglib-2.7.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

In [5]:
import torch
import torch_geometric
from pymatgen.core.structure import Structure
import numpy as np
import sklearn

# Environment report: library versions and the accelerator visible to PyTorch.
print("\n".join([
    f"PyTorch: {torch.__version__}",
    f"PyTorch Geometric: {torch_geometric.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    f"Current device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}",
]))

PyTorch: 2.9.0+cu126
PyTorch Geometric: 2.7.0
CUDA available: True
Current device: NVIDIA A100-SXM4-80GB


In [11]:
def parse_poscar_string(poscar_str: str) -> dict:
    """Turn the text of a POSCAR file into a plain dictionary.

    The text is stripped, split on newlines, and read by fixed line position:

        line 0      cell name / comment
        line 1      lattice constant (float)
        lines 2-4   the three lattice vectors
        line 5      element names, whitespace separated
        line 6      atom count per element, whitespace separated
        line 7      coordinate type line (e.g. "Direct")
        line 8...   one position per atom; only the first three fields are read

    Args:
        poscar_str: full POSCAR content as a single string.

    Returns:
        dict with keys 'CellName', 'LattConst', 'Base', 'EleName', 'AtomNum',
        'AtomSum', 'LatType' and 'LattPnt'.

    Raises:
        IndexError: if the text has fewer lines/fields than the header implies.
        ValueError: if a numeric field cannot be converted.
    """
    rows = [row.strip() for row in poscar_str.strip().split('\n')]

    def _xyz(row: str) -> list:
        # Only the first three columns are coordinates; anything after is ignored.
        fields = row.split()
        return [float(fields[0]), float(fields[1]), float(fields[2])]

    # Header
    title = rows[0]
    scale = float(rows[1])

    # Lattice vectors (rows 2, 3, 4)
    vectors = [[float(tok) for tok in rows[k].split()] for k in (2, 3, 4)]

    # Species and how many atoms of each
    species = rows[5].split()
    counts = [int(tok) for tok in rows[6].split()]
    n_atoms = sum(counts)

    # Coordinate type
    mode = rows[7]

    # One coordinate row per atom, starting right after the coordinate-type line
    coords = [_xyz(rows[k]) for k in range(8, 8 + n_atoms)]

    return dict(
        CellName=title,
        LattConst=scale,
        Base=vectors,
        EleName=species,
        AtomNum=counts,
        AtomSum=n_atoms,
        LatType=mode,
        LattPnt=coords,
    )

# Before converting to tensors, resolve the string-based sublattice lookup here on the CPU, since the GPU is not suited to string processing
def poscar_to_tensors(poscar: dict, device='cpu') -> dict:
    """Build tensors and sublattice masks from a parsed POSCAR dictionary.

    Args:
        poscar: dict providing 'LattPnt', 'AtomNum', 'Base' and 'EleName'
            (the keys produced by the POSCAR parser in this notebook).
            'EleName' must contain 'Ti', 'Fe', 'O' and 'VO'.
        device: device on which every tensor is created.

    Returns:
        dict with
            'positions'     float32 tensor [N, 3]
            'atom_types'    long tensor [N]; entry is the index of the atom's
                            element within 'EleName'
            'lattice'       float32 tensor [3, 3]
            'b_site_mask'   bool tensor [N], True where the atom is Ti or Fe
            'o_site_mask'   bool tensor [N], True where the atom is O or VO
            'element_names' the 'EleName' list
            'type_map'      element name -> type index
            'atom_counts'   the 'AtomNum' list

    Raises:
        KeyError: if one of 'Ti', 'Fe', 'O', 'VO' is not in 'EleName'.
    """
    # Coordinates, one row per atom
    coords = torch.tensor(poscar['LattPnt'], dtype=torch.float32, device=device)

    # Atoms are listed element by element, so element k contributes AtomNum[k]
    # consecutive entries equal to k.
    type_per_atom = [
        species_idx
        for species_idx, n_atoms in enumerate(poscar['AtomNum'])
        for _ in range(n_atoms)
    ]
    species_of_atom = torch.tensor(type_per_atom, dtype=torch.long, device=device)

    # Cell vectors
    cell = torch.tensor(poscar['Base'], dtype=torch.float32, device=device)

    # Name -> index lookup; for the notebook's structure: Sr=0, Ti=1, Fe=2, O=3, VO=4
    names = poscar['EleName']
    name_to_idx = dict(zip(names, range(len(names))))

    def _is_either(first: str, second: str) -> torch.Tensor:
        # True where the atom belongs to one of the two named species.
        return (species_of_atom == name_to_idx[first]) | (species_of_atom == name_to_idx[second])

    b_sites = _is_either('Ti', 'Fe')
    o_sites = _is_either('O', 'VO')

    return dict(
        positions=coords,
        atom_types=species_of_atom,
        lattice=cell,
        b_site_mask=b_sites,
        o_site_mask=o_sites,
        element_names=names,
        type_map=name_to_idx,
        atom_counts=poscar['AtomNum'],
    )


@torch.no_grad()
def swap_by_idx(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Exchange, in every row, the two entries addressed by idx (no Python loop).

    Args:
        x: [batch, N] atom types
        idx: [batch, 2] positions whose values trade places

    Returns:
        A new [batch, N] tensor with the two positions exchanged; x itself is
        left untouched.
    """
    pos_i = idx[..., 0:1]  # [batch, 1]
    pos_j = idx[..., 1:2]  # [batch, 1]

    # Read both values before writing anything so neither is overwritten early.
    val_i = torch.gather(x, -1, pos_i)
    val_j = torch.gather(x, -1, pos_j)

    result = x.clone()
    result.scatter_(-1, pos_i, val_j).scatter_(-1, pos_j, val_i)
    return result


@torch.no_grad()
def sample_sublattice_swap(
    atom_types: torch.Tensor, # [batch, N]
    sublattice_mask: torch.Tensor, # [N]
    type_a: int, # first species of the pair, e.g. Ti (or O)
    type_b: int, # second species of the pair, e.g. Fe (or VO)
    scores: torch.Tensor = None # [batch, N] or None
):
    """
    Swap one type_a atom with one type_b atom inside a sublattice, per batch row.

    One site of each type is drawn per row with the Gumbel-max trick: Gumbel
    noise is added to the per-site scores and the argmax is taken over the
    sites holding the requested type. With scores=None every eligible site
    has the same score.

    Rows in which the sublattice holds no type_a atom or no type_b atom are
    left unchanged (their two returned indices are equal), matching the
    `len(...) > 0` guard of the sequential CPU implementation.

    Args:
        atom_types: [batch, N] integer atom types
        sublattice_mask: [N] bool mask selecting the sublattice sites
        type_a, type_b: the two type indices to exchange
        scores: optional [batch, N] per-site scores (unnormalized log-weights)

    Returns:
        swapped: [batch, N] atom types after the swap (new tensor)
        indices: [batch, 2] global site indices (type_a site, type_b site)
    """
    device = atom_types.device
    batch_size, _ = atom_types.shape

    sub_idx = torch.where(sublattice_mask)[0]  # global indices of the sublattice sites
    M = sub_idx.numel()

    if M == 0:
        # Nothing to swap; argmax over an empty dimension would raise.
        no_op = torch.zeros(batch_size, 2, dtype=torch.long, device=device)
        return atom_types.clone(), no_op

    sub_types = atom_types[:, sub_idx]  # [batch, M] types on the sublattice only

    is_a = (sub_types == type_a)  # [batch, M]
    is_b = (sub_types == type_b)  # [batch, M]

    if scores is None:
        sub_scores = torch.zeros(batch_size, M, device=device)
    else:
        sub_scores = scores[:, sub_idx]

    # Clamp away exact zeros so log(noise) stays finite.
    noise = torch.rand(batch_size, M, device=device).clamp(min=1e-10)
    gumbel = -torch.log(-torch.log(noise))

    # The same noise can serve both draws because the two masks are disjoint.
    perturbed = sub_scores + gumbel
    local_a = perturbed.masked_fill(~is_a, float('-inf')).argmax(dim=-1)  # [batch]
    local_b = perturbed.masked_fill(~is_b, float('-inf')).argmax(dim=-1)  # [batch]

    global_a = sub_idx[local_a]  # [batch]
    global_b = sub_idx[local_b]  # [batch]

    # If a row lacks either type, argmax over an all -inf row points at an
    # arbitrary site; collapse the pair onto one index so the swap is a no-op.
    can_swap = is_a.any(dim=-1) & is_b.any(dim=-1)
    global_b = torch.where(can_swap, global_b, global_a)

    indices = torch.stack([global_a, global_b], dim=-1)  # [batch, 2]
    swapped = swap_by_idx(atom_types, indices)
    return swapped, indices

@torch.no_grad()
def apply_n_swaps(
    atom_types: torch.Tensor,       # [batch, N]
    b_site_mask: torch.Tensor,      # [N]
    o_site_mask: torch.Tensor,      # [N]
    type_map: dict,
    n_swaps: int,
    swap_mode: str = 'both'
) -> tuple:
    """
    Run n_swaps consecutive swap steps; each step acts on the whole batch.

    Every step picks one sublattice for the entire batch: always the B
    sublattice (Ti/Fe) for swap_mode 'B-site', always the O sublattice (O/VO)
    for 'O-site', and a fair coin flip between the two for any other value.

    Args:
        atom_types: [batch, N] starting configuration (not modified)
        b_site_mask, o_site_mask: [N] sublattice masks
        type_map: element name -> type index; needs 'Ti', 'Fe', 'O', 'VO'
        n_swaps: number of swap steps
        swap_mode: 'B-site', 'O-site', or 'both'
    Returns:
        final: [batch, N] configuration after all steps
        history: list with one (sublattice label 'B' or 'O', indices) tuple per step
    """
    # Per-sublattice arguments for the sampler: (mask, first type, second type).
    swap_specs = {
        'B': (b_site_mask, type_map['Ti'], type_map['Fe']),
        'O': (o_site_mask, type_map['O'], type_map['VO']),
    }

    state = atom_types.clone()
    log = []

    for _ in range(n_swaps):
        if swap_mode == 'B-site':
            label = 'B'
        elif swap_mode == 'O-site':
            label = 'O'
        else:
            label = 'B' if torch.rand(1).item() < 0.5 else 'O'

        mask, first_type, second_type = swap_specs[label]
        state, picked = sample_sublattice_swap(state, mask, first_type, second_type)
        log.append((label, picked.clone()))

    return state, log

In [7]:
# CPU Sequential
def cpu_sequential_swap(atom_types, b_site_mask, o_site_mask, type_map, n_swaps):
    """Reference implementation: swap atoms one sample and one step at a time on the CPU.

    For each sample and each step a fair coin picks the B sublattice (Ti/Fe)
    or the O sublattice (O/VO); one atom of each of the two types is then
    chosen uniformly at random and the pair is exchanged. A step is skipped
    when the sublattice lacks either type.

    Args:
        atom_types: [batch, N] integer tensor (any device); not modified.
        b_site_mask, o_site_mask: [N] bool tensors selecting the sublattices.
        type_map: element name -> type index; needs 'Ti', 'Fe', 'O', 'VO'.
        n_swaps: number of swap steps per sample.

    Returns:
        CPU tensor [batch, N] with the swapped configurations.
    """
    # .numpy() on a CPU tensor shares its memory, so copy explicitly; otherwise
    # the swaps below would silently overwrite the caller's tensor.
    types_np = atom_types.cpu().numpy().copy()
    b_indices = np.where(b_site_mask.cpu().numpy())[0]
    o_indices = np.where(o_site_mask.cpu().numpy())[0]

    ti_idx = type_map['Ti']
    fe_idx = type_map['Fe']
    o_idx = type_map['O']
    vo_idx = type_map['VO']

    def _swap_one_pair(row, site_indices, type_a, type_b):
        # Exchange one random type_a site with one random type_b site, in place on `row`.
        sub_types = row[site_indices]
        a_pos = np.where(sub_types == type_a)[0]
        b_pos = np.where(sub_types == type_b)[0]
        if len(a_pos) > 0 and len(b_pos) > 0:
            i = site_indices[np.random.choice(a_pos)]
            j = site_indices[np.random.choice(b_pos)]
            row[i], row[j] = row[j], row[i]

    for row in types_np:  # each row is a view into types_np
        for _ in range(n_swaps):
            if np.random.random() < 0.5:
                _swap_one_pair(row, b_indices, ti_idx, fe_idx)
            else:
                _swap_one_pair(row, o_indices, o_idx, vo_idx)

    return torch.from_numpy(types_np)

In [8]:
# Base structure used by the cells below, as POSCAR text:
#   line 1      cell name
#   line 2      lattice constant
#   lines 3-5   lattice vectors (11.199 x 11.199 x 15.983)
#   line 6      species names: Sr Ti Fe O VO
#   line 7      atoms per species: 32 16 16 88 8 (160 sites in total)
#   line 8      coordinate type ("Direct")
#   remaining   one coordinate row per site, in species order
# NOTE(review): "VO" presumably denotes an oxygen vacancy kept as its own species — confirm.
poscar_str = """SrTiFeO
1.000000
11.199000 0.000000 0.000000
0.000000 11.199000 0.000000
0.000000 0.000000 15.983000
Sr Ti Fe O VO
32 16 16 88 8
Direct
0.000000 0.250000 0.125000
0.000000 0.250000 0.625000
0.000000 0.750000 0.125000
0.000000 0.750000 0.625000
0.500000 0.250000 0.125000
0.500000 0.250000 0.625000
0.500000 0.750000 0.125000
0.500000 0.750000 0.625000
0.000000 0.250000 0.375000
0.000000 0.250000 0.875000
0.000000 0.750000 0.375000
0.000000 0.750000 0.875000
0.500000 0.250000 0.375000
0.500000 0.250000 0.875000
0.500000 0.750000 0.375000
0.500000 0.750000 0.875000
0.250000 0.000000 0.125000
0.250000 0.000000 0.625000
0.250000 0.500000 0.125000
0.250000 0.500000 0.625000
0.750000 0.000000 0.125000
0.750000 0.000000 0.625000
0.750000 0.500000 0.125000
0.750000 0.500000 0.625000
0.250000 0.000000 0.375000
0.250000 0.000000 0.875000
0.250000 0.500000 0.375000
0.250000 0.500000 0.875000
0.750000 0.000000 0.375000
0.750000 0.000000 0.875000
0.750000 0.500000 0.375000
0.750000 0.500000 0.875000
0.000000 0.000000 0.000000
0.250000 0.750000 0.500000
0.750000 0.250000 0.500000
0.500000 0.000000 0.000000
0.000000 0.500000 0.000000
0.500000 0.500000 0.500000
0.750000 0.250000 0.750000
0.000000 0.500000 0.500000
0.750000 0.750000 0.500000
0.250000 0.750000 0.750000
0.250000 0.250000 0.750000
0.750000 0.750000 0.750000
0.500000 0.500000 0.000000
0.500000 0.000000 0.500000
0.000000 0.000000 0.500000
0.250000 0.250000 0.500000
0.750000 0.250000 0.000000
0.250000 0.750000 0.250000
0.500000 0.500000 0.750000
0.750000 0.750000 0.000000
0.000000 0.500000 0.250000
0.500000 0.000000 0.750000
0.250000 0.250000 0.250000
0.000000 0.500000 0.750000
0.750000 0.250000 0.250000
0.250000 0.750000 0.000000
0.000000 0.000000 0.250000
0.000000 0.000000 0.750000
0.750000 0.750000 0.250000
0.250000 0.250000 0.000000
0.500000 0.000000 0.250000
0.500000 0.500000 0.250000
0.000000 0.000000 0.125000
0.000000 0.000000 0.625000
0.000000 0.500000 0.125000
0.000000 0.500000 0.625000
0.500000 0.000000 0.125000
0.500000 0.000000 0.625000
0.500000 0.500000 0.125000
0.500000 0.500000 0.625000
0.000000 0.000000 0.375000
0.000000 0.000000 0.875000
0.000000 0.500000 0.375000
0.000000 0.500000 0.875000
0.500000 0.000000 0.375000
0.500000 0.000000 0.875000
0.500000 0.500000 0.375000
0.500000 0.500000 0.875000
0.250000 0.250000 0.375000
0.250000 0.250000 0.875000
0.250000 0.750000 0.375000
0.250000 0.750000 0.875000
0.750000 0.250000 0.375000
0.750000 0.250000 0.875000
0.750000 0.750000 0.375000
0.750000 0.750000 0.875000
0.250000 0.250000 0.125000
0.250000 0.250000 0.625000
0.250000 0.750000 0.125000
0.250000 0.750000 0.625000
0.750000 0.250000 0.125000
0.750000 0.250000 0.625000
0.750000 0.750000 0.125000
0.750000 0.750000 0.625000
0.141000 0.391000 0.250000
0.391000 0.359000 0.250000
0.641000 0.391000 0.250000
0.891000 0.359000 0.250000
0.141000 0.891000 0.250000
0.391000 0.859000 0.250000
0.641000 0.891000 0.250000
0.891000 0.859000 0.250000
0.141000 0.109000 0.500000
0.391000 0.141000 0.500000
0.641000 0.109000 0.500000
0.891000 0.141000 0.500000
0.109000 0.359000 0.500000
0.359000 0.391000 0.500000
0.609000 0.359000 0.500000
0.859000 0.391000 0.500000
0.141000 0.609000 0.500000
0.391000 0.641000 0.500000
0.641000 0.609000 0.500000
0.891000 0.641000 0.500000
0.109000 0.859000 0.500000
0.359000 0.891000 0.500000
0.609000 0.859000 0.500000
0.859000 0.891000 0.500000
0.109000 0.141000 0.750000
0.359000 0.109000 0.750000
0.609000 0.141000 0.750000
0.859000 0.109000 0.750000
0.141000 0.391000 0.750000
0.391000 0.359000 0.750000
0.641000 0.391000 0.750000
0.891000 0.359000 0.750000
0.109000 0.641000 0.750000
0.359000 0.609000 0.750000
0.609000 0.641000 0.750000
0.859000 0.609000 0.750000
0.141000 0.891000 0.750000
0.391000 0.859000 0.750000
0.641000 0.891000 0.750000
0.891000 0.859000 0.750000
0.141000 0.109000 0.000000
0.391000 0.141000 0.000000
0.641000 0.109000 0.000000
0.891000 0.141000 0.000000
0.109000 0.359000 0.000000
0.359000 0.391000 0.000000
0.609000 0.359000 0.000000
0.859000 0.391000 0.000000
0.141000 0.609000 0.000000
0.391000 0.641000 0.000000
0.641000 0.609000 0.000000
0.891000 0.641000 0.000000
0.109000 0.859000 0.000000
0.359000 0.891000 0.000000
0.609000 0.859000 0.000000
0.859000 0.891000 0.000000
0.109000 0.141000 0.250000
0.359000 0.109000 0.250000
0.609000 0.141000 0.250000
0.859000 0.109000 0.250000
0.109000 0.641000 0.250000
0.359000 0.609000 0.250000
0.609000 0.641000 0.250000
0.859000 0.609000 0.250000"""

In [9]:
def create_random_batch(poscar_str: str, batch_size: int = 1024, device=None):
    """Create a batch of configurations with random Ti/Fe and O/VO arrangements.

    The base structure is parsed once and replicated batch_size times. Within
    each sample the species on the B sublattice (Ti/Fe) and on the O
    sublattice (O/VO) are shuffled independently, so every sample keeps the
    composition of the base structure. Sites outside both sublattices are
    never changed.

    Args:
        poscar_str: POSCAR text of the base structure.
        batch_size: number of random configurations to generate.
        device: device for all tensors; None selects 'cuda' when available
            and falls back to 'cpu' otherwise.

    Returns:
        batch_types: [batch_size, N] long tensor of atom types.
        tensors: the dict returned by poscar_to_tensors for the base structure.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Parse base structure
    poscar = parse_poscar_string(poscar_str)
    tensors = poscar_to_tensors(poscar, device=device)

    base_types = tensors['atom_types']  # [N]

    # Replicate to batch
    batch_types = base_types.unsqueeze(0).repeat(batch_size, 1)  # [B, N]

    # Shuffle each sublattice for all samples at once instead of looping over the batch.
    for site_mask in (tensors['b_site_mask'], tensors['o_site_mask']):
        site_idx = torch.where(site_mask)[0]  # [M] sites of this sublattice
        # Sorting i.i.d. uniform noise row-wise yields an independent random
        # permutation of 0..M-1 for every sample.
        perm = torch.rand(batch_size, site_idx.numel(), device=base_types.device).argsort(dim=1)
        batch_types[:, site_idx] = base_types[site_idx][perm]  # [B, M]

    return batch_types, tensors

# Usage: build 10,000 randomized configurations of the base structure.
# X holds the atom types (one row per configuration); `tensors` holds the
# base-structure tensors, masks and type map used by the cells below.
X, tensors = create_random_batch(poscar_str, batch_size=10000)
print(X.shape)

torch.Size([10000, 160])


In [10]:
import time

In [13]:
# Benchmark: time n_swaps batched swap steps over every configuration in X.
n_swaps = 10000
n_samples = X.shape[0]  # read from the data so the report cannot drift from the real batch size
ti_type = tensors['type_map']['Ti']
b_indices = torch.where(tensors['b_site_mask'])[0][:16]  # first 16 B sites, for display only


def _b_site_letters(types_row):
    """Render a row of B-site types as letters: 'T' for Ti, 'F' for anything else."""
    return ''.join('T' if t == ti_type else 'F' for t in types_row.tolist())


def _grouped(text, width=4):
    """Insert a space after every `width` characters for readability."""
    return ' '.join(text[i:i + width] for i in range(0, len(text), width))


def _sync_if_cuda():
    """Wait for queued GPU work so the timer measures completed kernels (no-op on CPU)."""
    if X.is_cuda:
        torch.cuda.synchronize()


config_before = _b_site_letters(X[0, b_indices])
formatted_before = _grouped(config_before)

print(f"Before: [{formatted_before}]")

# GPU swap
_sync_if_cuda()
start = time.perf_counter()  # monotonic, high-resolution clock intended for timing
X_swapped, _ = apply_n_swaps(X, tensors['b_site_mask'],
                             tensors['o_site_mask'], tensors['type_map'],
                             n_swaps=n_swaps, swap_mode='both')
_sync_if_cuda()
elapsed = time.perf_counter() - start

# After
config_after = _b_site_letters(X_swapped[0, b_indices])
formatted_after = _grouped(config_after)

total_swaps = n_samples * n_swaps
print(f"After:  [{formatted_after}]")
print(f"\n{n_samples} samples × {n_swaps} swaps = {total_swaps:,} total swaps in {elapsed:.3f}s")
print(f"Throughput: {total_swaps/elapsed:,.0f} swaps/sec")

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
        [0, 0, 0,  ..., 3, 3, 3],
        ...,
        [0, 0, 0,  ..., 3, 3, 3],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4]], device='cuda:0')
tensor([[0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 3],
        [0, 0, 0,  ..., 3, 3, 3],
        ...,
        [0, 0, 0,  ..., 3, 3, 3],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4]], device='cuda:0')
tensor([[0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 3],
        ...,
        [0, 0, 0,  ..., 3, 3, 3],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4]], device='cuda:0')
tensor([[0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 3],
        ...,
        [0, 0, 0,  ..., 3, 3, 3],
        [0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4]], device='cuda:0')
tensor([[0, 0, 0,  ..., 3, 3, 4],
        [0, 0, 0,  ..., 3, 3, 4],
        [0