# Transform data


Converts the raw datasets to the extended XYZ (`.extxyz`) format so that they are compatible with the CryinGAN repository.

In [None]:
import os
import ase
from ase.io import read, write
import numpy as np
import shutil
from pathlib import Path
from tqdm import tqdm

%cd ..
from src.load_data import get_descriptors
from src.utils import load_raw_data, read_raw_sample

from CCGAN.tools import BatchDistance

%cd -

In [None]:
def makedir_if_not_exists(path):
    """Create the directory ``path`` if it does not already exist.

    Missing parent directories are created as well, so a nested path is
    created in a single call.

    Args:
        path: Directory to create, as a ``str`` or ``pathlib.Path``.

    Returns:
        None. Failures are reported on stdout instead of being raised, so
        the call stays best-effort.
    """
    if os.path.isdir(path):
        return
    print("Creating directory {}".format(path))
    try:
        # makedirs creates every missing ancestor and then the directory
        # itself; exist_ok avoids a race if it appears in the meantime.
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % path)
    return

# Root folder of the raw samples to convert.
# An alternative root, "../data/raw/crystal/Sq", was previously assigned here
# and immediately overridden; set `path` to it to load that data instead.
path = Path("../data/raw/samples")

# Packing fractions to load, from low to very high:
#   0.70  low
#   0.80  mid (an earlier note listed 0.78 for this regime -- TODO confirm)
#   0.84  high
#   0.86  very high
phis = [0.70, 0.80, 0.84, 0.86]

files, dataframe, metadata = load_raw_data(path, phi=phis, subpath="")

In [None]:
import pandas as pd
from ase.units import Bohr

max_files = np.inf  # NOTE: Limit the amount of data to speed up training


def _build_atoms(xyz, r, columns, L, info):
    """Build a 2D-periodic ASE ``Atoms`` object for one configuration.

    Args:
        xyz: DataFrame holding a "class" column and the coordinate columns
            named in ``columns``.
        r: DataFrame of per-particle radii, attached as the "rmt" array.
        columns: Coordinate column names, in the order used as (x, y, z).
        L: Box side length, used for both in-plane cell vectors.
        info: Metadata dictionary stored in ``Atoms.info``.

    Returns:
        The constructed ``ase.Atoms`` object.
    """
    atoms = ase.Atoms(
        numbers=xyz["class"].values,
        positions=xyz[columns].values + [L / 2, L / 2, 0],  # NOTE: Displace the system
        cell=[L, L, 0],  # NOTE: 2D system
        pbc=[True, True, False],  # NOTE: 2D system
        info=dict(info),  # fresh dict per frame
    )
    # NOTE: Radius is not right
    atoms.new_array('rmt', r.values)
    return atoms


# Map every raw input file to the folder that receives its converted output
# (same path with "raw" replaced by "processed", minus the file name).
input_paths = list(files)
output_paths = {
    input_path: Path(str(input_path).replace("raw", "processed")).parent
    for input_path in input_paths
}

# Start from a clean slate: the loop below appends to its output files, so
# leftovers from a previous run would otherwise be duplicated.
for stale_folder in output_paths.values():
    if stale_folder.is_dir():
        print("Removing folder {}".format(stale_folder))
        shutil.rmtree(stale_folder)


coords_all = []

for i, file in tqdm(enumerate(files), total=len(files)):
    # Stop before reading once exactly `max_files` files have been converted.
    if i >= max_files:
        break

    dataframe, metadata = read_raw_sample(file)

    output_folder = output_paths[file]
    output_folder.mkdir(parents=True, exist_ok=True)

    output_samples = output_folder / "samples.extxyz"
    output_metadata = output_folder / "metadata.csv"
    radius_file = output_folder / "radius.csv"

    xyz = dataframe[["class", "x", "y"]].reset_index(drop=True)
    r = dataframe[["r"]].reset_index(drop=True)
    xyz["z"] = 0  # 2D data: put every particle in the z = 0 plane

    N = metadata.iloc[0, 0]  # N particles
    L = metadata["L"].iloc[0]

    phi, sample = dataframe.index.unique()[0]
    phi_value = float(phi.split("-")[-1])

    info = {"phi": phi_value, "sample": sample, "N": N, "L": L}

    # Original configuration plus a copy with x and y swapped
    # (presumably data augmentation -- TODO confirm).
    atoms = _build_atoms(xyz, r, ["x", "y", "z"], L, info)
    atoms_flipped = _build_atoms(xyz, r, ["y", "x", "z"], L, info)

    # Save to xyz file
    # NOTE: This is the format used by ASE
    with open(output_samples, "a+") as f:
        for frame in (atoms, atoms_flipped):
            coords_all.append(frame.get_scaled_positions())
            write(f, frame, format="extxyz", append=True)

    # One metadata row per input file (the extxyz file gets two frames).
    pd.DataFrame(
        {
            "phi": [phi_value],
            "sample": [sample],
            "N": [N],
            "L": [L],
        }
    ).to_csv(
        output_metadata,
        header=False,
        index=False,
        sep="\t",
        mode="a+",
    )


display(dataframe.head(10))
metadata

In [None]:
from ase.visualize import view

# Read every frame of one converted file back in as a sanity check.
slab_from_file = read(
    "../data/processed/samples/phi-0.70/samples.extxyz",
    index=":",
    format="extxyz",
)

# Show the stored metadata of the first frame and render it inline.
first_frame = slab_from_file[0]
print(first_frame.info)
view(first_frame, viewer="x3d")

Reading the converted file back and rendering the first frame works.

## Add batch distances

In [None]:
import torch

# Serve the collected scaled coordinates in fixed-order batches of 256.
prep_dataloader = torch.utils.data.DataLoader(
    coords_all,
    batch_size=256,
    shuffle=False,
)

In [None]:
# Inspect the cell of the current `atoms` object (presumably the last one
# built by the conversion loop -- verify the cell order if re-running).
atoms.get_cell()

In [None]:
n_neighbors = 3
train_data = []
# NOTE: This is the lattice matrix
lattice = atoms.get_cell()[:]
# NOTE(review): the lattice and atom count come from the single `atoms` object
# currently in scope, yet they are applied to every batch. This presumably
# assumes all samples share the same box and particle number -- TODO confirm.
n_atoms_total = len(atoms)

# Choose the compute device once; it cannot change between batches.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

for batch_coords in tqdm(prep_dataloader, total=len(prep_dataloader)):
    batch_size = batch_coords.shape[0]

    # Add a singleton channel axis and cast to float32 before moving to device.
    coords = batch_coords.view(batch_size, 1, n_atoms_total, 3).float().to(device)

    distance_builder = BatchDistance(coords, n_neighbors=n_neighbors, lat_matrix=lattice)
    # Keep the results on the CPU so they do not accumulate on the accelerator.
    train_data.append(distance_builder.append_dist().cpu())

# train_data = torch.cat(train_data)