In [1]:
import pandas as pd
from pathlib import Path
MP_20_train = pd.read_csv("cdvae/data/mp_20/train.csv", index_col=0)

In [2]:
from pymatgen.io.cif import CifParser
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

In [3]:
from multiprocessing import Pool
def structure_from_cif(cif: str):
    return CifParser.from_str(cif).get_structures()[0]

def MP_to_pickle(MP_csv: Path):
    """
    Convert MP csv to pickle file
    """
    MP_csv = Path(MP_csv)
    MP_df = pd.read_csv(MP_csv, index_col=0)
    with Pool() as pool:
        MP_df["structure"] = pool.map(structure_from_cif, MP_df["cif"])
    MP_df.drop(columns=["cif"], inplace=True)
    MP_df.to_pickle(MP_csv.parent / (MP_csv.stem + ".pkl"))
    return MP_df

In [4]:
MP_20_train = MP_to_pickle("cdvae/data/mp_20/train.csv")
MP_20_test = MP_to_pickle("cdvae/data/mp_20/test.csv")
MP_20_val = MP_to_pickle("cdvae/data/mp_20/val.csv")



In [310]:
from pymatgen.core import Structure
def structure_to_sites(structure: Structure):
    analyzer = SpacegroupAnalyzer(structure)
    symmetry_dataset = analyzer.get_symmetry_dataset()
    orbit_dict = dict(zip(symmetry_dataset['crystallographic_orbits'], symmetry_dataset['site_symmetry_symbols']))
    return list(orbit_dict.values())

In [314]:
analyzer = SpacegroupAnalyzer(MP_20_test["structure"].iloc[0])

In [316]:
analyzer.get_symmetry_dataset().keys()

dict_keys(['number', 'hall_number', 'international', 'hall', 'choice', 'transformation_matrix', 'origin_shift', 'rotations', 'translations', 'wyckoffs', 'site_symmetry_symbols', 'crystallographic_orbits', 'equivalent_atoms', 'primitive_lattice', 'mapping_to_primitive', 'std_lattice', 'std_types', 'std_positions', 'std_rotation_matrix', 'std_mapping_to_primitive', 'pointgroup'])

In [6]:
from multiprocessing import Pool
for dataset in (MP_20_train, MP_20_test, MP_20_val):
    with Pool() as p:
        dataset['orbit_list'] = p.map(structure_to_sites, dataset['structure'])

spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 0 tolerance = 1.000000e-02 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 1 tolerance = 9.500000e-03 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 2 tolerance = 9.025000e-03 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 3 tolerance = 8.573750e-03 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 0 tolerance = 1.000000e-02 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 405, /project/src/pointgroup.c).
spglib: Attempt 1 tolerance = 9.500000e-03 failed(line 800, /project/src/spacegroup.c).
spglib: No point group was found (line 4

In [7]:
# pad the orbit list
max_len = max([len(x) for x in MP_20_train['orbit_list']])
for dataset in (MP_20_train, MP_20_test, MP_20_val):
    dataset['orbit_list'] = [x + ["PAD"] * (max_len - len(x)) for x in dataset['orbit_list']]
print(max_len)

20


In [8]:
from itertools import chain
all_sites = set(chain(*MP_20_train['orbit_list'].values,*MP_20_test['orbit_list'].values, *MP_20_val['orbit_list'].values))

In [9]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [12]:
word_to_idx = {word: idx for idx, word in enumerate(all_sites)}
def sites_to_tensor(sites):
    return torch.tensor([word_to_idx[site] for site in sites])
for dataset in (MP_20_train, MP_20_test, MP_20_val):
    dataset['orbit_tensor'] = dataset['orbit_list'].map(sites_to_tensor)


In [13]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

In [100]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Split the data into batches
train_data = torch.stack(tuple(MP_20_train['orbit_tensor'].values)).T.to(device)
val_data = torch.stack(tuple(MP_20_val['orbit_tensor'].values)).T.to(device)
test_data = torch.stack(tuple(MP_20_test['orbit_tensor'].values)).T.to(device)

In [128]:
batch_size = train_data.size(1)
val_batch_size = val_data.size(1)
test_batch_size = test_data.size(1)

In [191]:
get_batch(train_data[:, :4],11)[0].shape

torch.Size([8, 4])

In [157]:
train_data[:,0]

tensor([45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,  6,  6,  6,  6,  6,  6,
         6,  6], device='cuda:0')

In [158]:
train_data[:,1]

tensor([51, 68, 68, 31,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6], device='cuda:0')

In [160]:
bptt = 20
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape ``[full_seq_len, batch_size]``
        i: int

    Returns:
        tuple (data, target), where data has shape ``[seq_len, batch_size]`` and
        target has shape ``[seq_len * batch_size]``
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [162]:
ntokens = len(word_to_idx)  # size of vocabulary
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 2  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 2  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)



In [163]:
 len(train_data)

20

In [164]:
import time
bptt = max_len

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        output = model(data)
        output_flat = output.view(-1, ntokens)
        loss = criterion(output_flat, targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            output = model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

In [417]:
best_val_loss = float('inf')
epochs = 200

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model, val_data)
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
            f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()
    model.load_state_dict(torch.load(best_model_params_path))

-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  0.27s | valid loss  0.28 | valid ppl     1.32
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  0.27s | valid loss  0.28 | valid ppl     1.32
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  0.27s | valid loss  0.28 | valid ppl     1.32
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   4 | time:  0.27s | valid loss  0.28 | valid ppl     1.32
--------------------------------------------------------------------------

In [431]:
from torch.nn.functional import softmax
model.eval()
generation_size = 10000

sequences = torch.randint(0, len(word_to_idx) - 1, (1, generation_size)).to(device)
for i in range(max_len - 1):
    probas = softmax(model(sequences), -1)
    next_tokens = torch.multinomial(probas[-1], 1).T
    sequences = torch.cat((sequences, next_tokens), 0)

In [432]:
MP_20_train['orbit_list']

37228    [m, m, m, m, m, m, m, m, m, m, m, m, PAD, PAD,...
19480    [4/mmm, m2m., m2m., ..2/m, PAD, PAD, PAD, PAD,...
29624    [m-3m, m-3m, -43m, PAD, PAD, PAD, PAD, PAD, PA...
38633    [.m., .m., .m., .m., PAD, PAD, PAD, PAD, PAD, ...
10889    [.2., -4.., 1, PAD, PAD, PAD, PAD, PAD, PAD, P...
                               ...                        
37856    [3., .2, 1, PAD, PAD, PAD, PAD, PAD, PAD, PAD,...
11955    [-43m, m-3m, m-3m, PAD, PAD, PAD, PAD, PAD, PA...
26119    [4mm, -4m2, -4m2, 4mm, PAD, PAD, PAD, PAD, PAD...
30556    [6/mmm, 2mm, -6m2, -6m2, 6mm, PAD, PAD, PAD, P...
32933    [m2m, m.., 2/m.., m.., m2m, PAD, PAD, PAD, PAD...
Name: orbit_list, Length: 27136, dtype: object

In [421]:
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

In [422]:
sequences.shape

torch.Size([20, 13])

In [434]:
# sequences to word
generated_words = []
for sequence in sequences.T:
    this_word = filter(lambda x: x != "PAD", [idx_to_word[idx.item()] for idx in sequence])
    generated_words.append(frozenset(this_word))

In [435]:
generated_words

[frozenset({'-3m', '.-3.', '.m'}),
 frozenset({'-42m', '-62m'}),
 frozenset({'.-3m'}),
 frozenset({'.2/m', '23.'}),
 frozenset({'-3m', '-4m2', '-6..', '23.', '4/mmm'}),
 frozenset({'-3..', '-4m2', '2/m..', '3.m', '4mm', 'mm2'}),
 frozenset({'-1', '1'}),
 frozenset({'..2'}),
 frozenset({'-3..', '-3m', '-4m.2', '..2', '..2/m', '.2/m', '2.mm', '6mm'}),
 frozenset({'-4m.2'}),
 frozenset({'-4..',
            '-6..',
            '.-3.',
            '.3m',
            '2/m..',
            '222',
            '2mm',
            '3m.',
            '4..',
            '4/m..',
            '6mm',
            'm.m2'}),
 frozenset({'6/mmm'}),
 frozenset({'-4m2', '4/mmm', '422', 'm2m.'}),
 frozenset({'.m.', 'mmm..'}),
 frozenset({'222', '4m.m', 'm.mm'}),
 frozenset({'-3m', '2/m', 'mm2'}),
 frozenset({'m'}),
 frozenset({'2..', 'm2m', 'm2m.'}),
 frozenset({'-3..', '-42m', '..m', '4/m..'}),
 frozenset({'.-3.'}),
 frozenset({'-62m', '.m', '4mm'}),
 frozenset({'1'}),
 frozenset({'-4..', '-42m', '-62m'}),
 

In [436]:
from pyxtal.symmetry import Group

In [463]:
index = dict()
wp_index = defaultdict(dict)
for group_number in range(1, 231):
    group = Group(group_number)
    site_symmetry_symbols = []
    for wp in group.Wyckoff_positions:
        wp.get_site_symmetry()  
        site_symmetry_symbols.append(wp.site_symm)
        wp_index[group_number][wp.site_symm] = (wp.multiplicity, wp.letter)
    index[frozenset(site_symmetry_symbols)] = group_number

In [464]:
wp_index

defaultdict(dict,
            {1: {'1': (1, 'a')},
             2: {'1': (2, 'i'), '-1': (1, 'a')},
             3: {'1': (2, 'e'), '.2.': (1, 'a')},
             4: {'1': (2, 'a')},
             5: {'1': (4, 'c'), '.2.': (2, 'a')},
             6: {'1': (2, 'c'), '.m.': (1, 'a')},
             7: {'1': (2, 'a')},
             8: {'1': (4, 'b'), '.m.': (2, 'a')},
             9: {'1': (4, 'a')},
             10: {'1': (4, 'o'),
              '.m.': (2, 'm'),
              '.2.': (2, 'i'),
              '.2/m.': (1, 'a')},
             11: {'1': (4, 'f'), '.m.': (2, 'e'), '-1': (2, 'a')},
             12: {'1': (8, 'j'),
              '.m.': (4, 'i'),
              '.2.': (4, 'g'),
              '-1': (4, 'e'),
              '.2/m.': (2, 'a')},
             13: {'1': (4, 'g'), '.2.': (2, 'e'), '-1': (2, 'a')},
             14: {'1': (4, 'e'), '-1': (2, 'a')},
             15: {'1': (8, 'f'), '.2.': (4, 'e'), '-1': (4, 'a')},
             16: {'1': (4, 'u'),
              '..2': (2, 'q')

In [445]:
valid_generated = []
for word in generated_words:
    try:
        index[word]
        valid_generated.append(word)
    except KeyError:
        pass

In [475]:
example = valid_generated[9]

In [486]:
generator = pyxtal.pyxtal()

In [490]:
sites_list = []
mutliplicity = 0
for site in example:
    mutliplicity += wp_index[group_number][site][0]
    sites_list.append(str(wp_index[group_number][site][0])+wp_index[group_number][site][1])

In [494]:
mutliplicity

16

In [500]:
sites_list

['4a', '4b', '8c']

In [501]:
group_number = index[example]
generator.from_random(3, group_number, species=["C"], numIons=[mutliplicity], sites=[sites_list])

In [502]:
ase_struct = generator.to_ase()

In [507]:
import ase.visualize
ase.visualize.view(ase_struct, viewer="ngl")



HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'C'), value='All'), Dr…

In [456]:
group = Group(61)
for wp in group.Wyckoff_positions:
    wp.get_site_symmetry()

In [386]:
group.Wyckoff_positions[0].get_site_symmetry()

In [387]:
group.Wyckoff_positions[0].site_symm

'1'

In [364]:
a is None

False

In [366]:
a.site_symm

'1'

In [330]:
        # index  is hall number
        #indices[0] = position_wyckoff[index];
        #indices[1] = position_wyckoff[index + 1] - position_wyckoff[index];

In [332]:
#                 ssmdb_get_site_symmetry_symbol(site_sym_symbol,
                                               #indices_wyc[0] + i);

In [342]:
from collections import defaultdict
sites = defaultdict(list)
for hall_number in range(len(position_wyckoff) - 1):
    for i in range(spg_db.position_wyckoff[hall_number], spg_db.position_wyckoff[hall_number + 1]):
        sites[hall_number].append(spg_db.site_symmetry[i].rstrip())

In [343]:
sites

defaultdict(list,
            {0: [''],
             1: ['1'],
             2: ['1', '-1', '-1', '-1', '-1', '-1', '-1', '-1', '-1'],
             3: ['1', '2', '2', '2', '2'],
             4: ['1', '2', '2', '2', '2'],
             5: ['1', '2', '2', '2', '2'],
             6: ['1'],
             7: ['1'],
             8: ['1'],
             9: ['1', '2', '2'],
             10: ['1', '2', '2'],
             11: ['1', '2', '2'],
             12: ['1', '2', '2'],
             13: ['1', '2', '2'],
             14: ['1', '2', '2'],
             15: ['1', '2', '2'],
             16: ['1', '2', '2'],
             17: ['1', '2', '2'],
             18: ['1', 'm', 'm'],
             19: ['1', 'm', 'm'],
             20: ['1', 'm', 'm'],
             21: ['1'],
             22: ['1'],
             23: ['1'],
             24: ['1'],
             25: ['1'],
             26: ['1'],
             27: ['1'],
             28: ['1'],
             29: ['1'],
             30: ['1', 'm'],
             31:

In [348]:
from pyxtal.symmetry import Group, Wyckoff_position

In [354]:
Group(2)

-- Spacegroup --# 2 (P-1)--
2i	site symm: 1
1h	site symm: -1
1g	site symm: -1
1f	site symm: -1
1e	site symm: -1
1d	site symm: -1
1c	site symm: -1
1b	site symm: -1
1a	site symm: -1

In [347]:
s.from_random()

TypeError: unsupported format string passed to NoneType.__format__

In [327]:
len(spg_db.site_symmetry)

4096

In [None]:
pyxtal.pyxtal.from_random()

In [308]:
spglib.get_symmetry_from_database(hall_number=2)

{'rotations': array([[[ 1,  0,  0],
         [ 0,  1,  0],
         [ 0,  0,  1]],
 
        [[-1,  0,  0],
         [ 0, -1,  0],
         [ 0,  0, -1]]], dtype=int32),
 'translations': array([[0., 0., 0.],
        [0., 0., 0.]])}

In [298]:
spglib.__file__

'/home/kna/.cache/pypoetry/virtualenvs/wyckofftransformer-FeCwefly-py3.10/lib/python3.10/site-packages/spglib/__init__.py'

In [265]:
next_tokens.shape

torch.Size([13, 1])

In [284]:
sequences.shape

torch.Size([3, 13])

In [278]:
probas[-1].shape

torch.Size([13, 73])

In [221]:
start

tensor([66, 24, 32, 40, 64, 32, 63, 44, 16, 65, 18, 60, 30], device='cuda:0')

In [259]:
get_batch(train_data[:, :4],18)[0].shape

torch.Size([1, 4])

In [226]:
output

tensor([[-1.1501e-01, -1.0813e+00, -2.8277e-01, -8.8878e-01,  9.6309e-01,
         -1.4374e+00,  4.8899e+00,  5.3303e-01, -1.6858e-02, -1.2647e+00,
         -1.8436e+00, -2.0023e-01,  4.5401e-01, -1.9268e+00,  1.0780e+00,
         -7.8715e-01, -1.3936e+00,  7.1426e-01,  1.3526e+00, -8.7893e-01,
         -8.2050e-01, -3.6213e-01, -9.9722e-01,  6.1405e-01, -6.3549e-01,
         -5.5356e-01, -5.1126e-03, -9.6416e-01, -5.6719e-01,  2.0959e+00,
          3.0051e+00,  1.7029e+00,  2.6258e+00, -1.0740e+00,  7.6692e-01,
          1.1560e+00, -9.3513e-01, -2.0610e+00, -8.9891e-01,  1.2618e+00,
         -4.2852e-01, -2.1445e-01, -9.1498e-01, -8.1692e-01, -9.3644e-01,
          4.3205e-01,  1.7107e-01,  4.1742e-01, -1.6240e-01, -2.1882e-01,
          1.7070e+00,  4.0800e-01, -2.7730e-02,  1.4231e+00, -8.8708e-01,
          5.4351e-01, -4.3499e-01, -1.9454e-01,  2.6634e+00, -7.8157e-02,
          2.8401e+00,  2.6433e-01,  3.2736e-01, -1.5897e+00,  3.1882e-01,
         -1.4460e+00,  8.0376e-01, -2.