In [1]:
# DATA


import numpy as np
import os
import json
from pathlib import Path
import re
from time import sleep
from tqdm import tqdm
import warnings

import torch
from torch_geometric.data import Data, DataLoader
from torch.utils.data import random_split



from pymatgen.io.cif import CifParser
from pymatgen.analysis.local_env import CrystalNN
from pymatgen.core import Structure, Lattice, Site

import torch
import pytorch_lightning as pl

from torch_geometric.nn import global_mean_pool
from torch.optim import Adam
from torch.nn.functional import relu
from torch.nn import Module, MultiheadAttention, Linear
from torch_geometric.nn import global_mean_pool, GATConv
from torch.optim.lr_scheduler import StepLR

from transformers import GPT2Config, GPT2Model



# Registry mapping a dataset name to its root directory on disk. The loaders in this
# notebook read the ``train_gv/gvectors`` and ``train_gv/jsons`` sub-folders of that root
# (and write CIFs to ``train_gv/cifs``).
DATASETS = dict(Mo="./data/Mo")

def gvector(gvector):
    """Read the per-atom G-vector (descriptor) matrix from a binary ``.bin`` file.

    File layout, as parsed here (all little-endian):
      * uint32 version -- only version 0 is supported
      * uint16 flags   -- read to advance the stream, otherwise unused
      * uint32 n_atoms
      * uint32 g_size  -- descriptor length per atom
      * float32 payload: 1 energy value, ``n_atoms`` species values, then
        ``n_atoms * g_size`` descriptor values (row-major, one row per atom).

    Args:
        gvector: path to the binary file.

    Returns:
        float32 ``np.ndarray`` of shape ``(n_atoms, g_size)``. It is a read-only view
        over the file payload (``np.frombuffer`` on ``bytes``); copy it before mutating.

    Raises:
        ValueError: if the file version is unsupported or the payload holds fewer
            float32 values than the header announces.
    """
    def _read_uint(stream, n_bytes):
        # Decode an unsigned little-endian integer explicitly, independent of host endianness.
        return int.from_bytes(stream.read(n_bytes), byteorder='little', signed=False)

    with open(gvector, "rb") as binary_file:
        bin_version = _read_uint(binary_file, 4)
        if bin_version != 0:
            # Raise rather than exit(1): exit() would terminate the whole notebook kernel.
            raise ValueError(f"{gvector}: unsupported gvector file version {bin_version}")
        _read_uint(binary_file, 2)  # flags: consumed but not used
        n_atoms = _read_uint(binary_file, 4)
        g_size = _read_uint(binary_file, 4)
        payload = binary_file.read()

    values = np.frombuffer(payload, dtype='<f4')
    # Skip the leading energy value and the n_atoms species values.
    start = 1 + n_atoms
    stop = start + n_atoms * g_size
    if values.size < stop:
        raise ValueError(
            f"{gvector}: payload has {values.size} float32 values, expected at least {stop}"
        )
    return values[start:stop].reshape(n_atoms, g_size)


def json_to_pmg_structure(db_name, json_file, coords_are_cartesian=False):
    """Build a pymatgen ``Structure`` from one example JSON file and cache it as a CIF.

    Reads ``<DATASETS[db_name]>/train_gv/jsons/<json_file>`` and writes
    ``<DATASETS[db_name]>/train_gv/cifs/<stem>.cif`` unless that CIF already exists.

    Args:
        db_name: key into ``DATASETS``.
        json_file: file name (not a path) of the example, e.g. ``"<key>.example"``.
        coords_are_cartesian: forwarded to ``Structure``. The default (False) keeps the
            historical behaviour of treating ``atom[2]`` as fractional coordinates.
            NOTE(review): confirm the coordinate convention used by the JSON files;
            if they store Cartesian positions, pass True.

    Returns:
        ``pymatgen.core.Structure`` with one site per entry of ``json_data["atoms"]``.
    """
    dataset_root = os.path.join(DATASETS[db_name], "train_gv")
    cif_dir = os.path.join(dataset_root, "cifs")
    json_path = os.path.join(dataset_root, "jsons", json_file)
    Path(cif_dir).mkdir(parents=True, exist_ok=True)

    json_data = read_json(json_path)
    lattice = Lattice(json_data["lattice_vectors"])
    # Atom entries are indexed as [_, species, coords, ...]; only indices 1 and 2 are used.
    # Species come from the file (previously hard-coded to "Mo" for every site).
    atoms = json_data["atoms"]
    species = [atom[1] for atom in atoms]
    coords = [atom[2] for atom in atoms]
    structure = Structure(
        lattice=lattice,
        species=species,
        coords=coords,
        coords_are_cartesian=coords_are_cartesian,
    )

    cif_path = os.path.join(cif_dir, os.path.splitext(json_file)[0] + ".cif")
    if not os.path.isfile(cif_path):
        structure.to(filename=cif_path)
    return structure


def get_edge_indexes(structure):
    """Build a directed neighbour graph for ``structure`` using pymatgen's CrystalNN.

    Args:
        structure: pymatgen ``Structure``.

    Returns:
        edge_indexes: int64 tensor of shape ``(2, E)``. Row 0 holds source site indices,
            row 1 the neighbour site indices (the adjacency ``"id"`` field).
        edges: int64 tensor of shape ``(E, 3)`` holding each edge's ``to_jimage``
            (periodic-image offset of the neighbour).
    """
    nn_strategy = CrystalNN(weighted_cn=True, distance_cutoffs=(10, 20.))
    graph_dict = nn_strategy.get_bonded_structure(structure).as_dict()
    # adjacency[i] is the list of neighbour records of site i (len(adjacency) == number of sites).
    adjacency = graph_dict["graphs"]["adjacency"]

    sources, targets, images = [], [], []
    for site_idx, neighbours in enumerate(adjacency):
        for neighbour in neighbours:
            sources.append(site_idx)
            targets.append(neighbour["id"])
            images.append(neighbour["to_jimage"])

    # Explicit dtype and shape so an edgeless structure still yields (2, 0) / (0, 3)
    # integer tensors instead of 1-D float tensors.
    edge_indexes = torch.tensor([sources, targets], dtype=torch.long)
    edges = torch.tensor(images, dtype=torch.long).reshape(-1, 3)
    return edge_indexes, edges


def read_json(filename):
    """Parse the JSON document stored at ``filename`` and return the resulting object."""
    return json.loads(Path(filename).read_text())


def get_db_keys(db_name):
    """Return matching G-vector and JSON file names for every example of ``db_name``.

    Example keys are the file-name stems (text before the first ``.``) of the regular
    files in ``<DATASETS[db_name]>/train_gv/gvectors``.

    Args:
        db_name: key into ``DATASETS``.

    Returns:
        ``(gvector_keys, json_keys)``: parallel lists of ``<key>.bin`` and
        ``<key>.example`` names, sorted by key.
    """
    db_path = os.path.join(DATASETS[db_name], "train_gv", "gvectors")
    # Sorted because os.listdir order is arbitrary, while `dataset` and `get_labels`
    # call this function independently and pair features with labels by position.
    keys = sorted(
        name.split(".")[0]
        for name in os.listdir(db_path)
        if os.path.isfile(os.path.join(db_path, name))
    )
    gvector_keys = [key + ".bin" for key in keys]
    json_keys = [key + ".example" for key in keys]
    return gvector_keys, json_keys



def dataset(db_name, max_items=50):
    """Load node features and graph connectivity for the first ``max_items`` examples.

    Args:
        db_name: key into ``DATASETS``.
        max_items: number of examples to load, in ``get_db_keys`` order; ``None`` loads all.
            Use the same value as for ``get_labels`` so features and labels line up.

    Returns:
        ``(parinello, edge_indexes, edges)``: three parallel lists with one entry per
        example. They hold the ``(n_atoms, g_size)`` G-vector matrix as a tensor, the
        ``(2, E)`` edge index and the ``(E, 3)`` periodic-image offsets from
        ``get_edge_indexes``.
    """
    gvector_dir = os.path.join(DATASETS[db_name], "train_gv", "gvectors")
    gvect_keys, json_keys = get_db_keys(db_name)

    # G-vectors ("Parinello vectors"): per-atom node features for each example.
    parinello = [
        torch.tensor(gvector(os.path.join(gvector_dir, key)))
        for key in gvect_keys[:max_items]
    ]

    edge_indexes = []
    edges = []
    for json_key in tqdm(json_keys[:max_items]):
        # Build structures from the requested dataset (was hard-coded to "Mo").
        structure = json_to_pmg_structure(db_name=db_name, json_file=json_key)
        edge_index, edge_images = get_edge_indexes(structure)
        edge_indexes.append(edge_index)
        edges.append(edge_images)

    return parinello, edge_indexes, edges


def get_labels(db_name, max_items=50):
    """Return per-atom energies for the first ``max_items`` examples of ``db_name``.

    The label of each JSON example is ``data["energy"][0] / len(data["atoms"])``.

    Args:
        db_name: key into ``DATASETS``.
        max_items: number of examples, in ``get_db_keys`` order; ``None`` loads all.
            Use the same value as for ``dataset`` so labels line up with features.

    Returns:
        1-D ``torch.float`` tensor with one label per loaded example.
    """
    json_dir = os.path.join(DATASETS[db_name], "train_gv", "jsons")
    _, json_keys = get_db_keys(db_name)

    labels = []
    for json_key in json_keys[:max_items]:
        example = read_json(os.path.join(json_dir, json_key))
        total_energy = example["energy"][0]
        labels.append(total_energy / len(example["atoms"]))

    return torch.tensor(labels, dtype=torch.float)

def create_sequence_tensor(feature, seq_len):
    """Chunk ``feature`` into consecutive, non-overlapping runs of ``seq_len`` items.

    Trailing items that cannot fill a whole chunk are dropped. Each chunk is a plain
    list built by indexing ``feature`` one element at a time.
    """
    n_chunks = len(feature) // seq_len
    return [
        [feature[chunk * seq_len + offset] for offset in range(seq_len)]
        for chunk in range(n_chunks)
    ]

def in_context_data(data_loader, batch_size):
    """Turn every batch of ``data_loader`` into one in-context ``Data`` sample.

    Each incoming batch (a collated group of graphs) becomes a single ``Data`` object.
    It keeps that batch's ``x``, ``edge_index``, ``to_j`` and ``y``, and stores the
    batch's ``batch`` vector (which graph of the group each node belongs to) as
    ``config_label``. The samples are wrapped in a new, unshuffled ``DataLoader`` with
    the given ``batch_size``.
    """
    sequences = [
        Data(
            x=group.x,
            edge_index=group.edge_index,
            to_j=group.to_j,
            config_label=group.batch,
            y=group.y,
        )
        for group in data_loader
    ]
    return DataLoader(sequences, batch_size=batch_size, shuffle=False)


def data(db_name, sequence_size, batch_size, seed=None):
    """Build train/validation in-context loaders for ``db_name``.

    Pipeline:
      1. load features, connectivity and per-atom energy labels (``dataset`` /
         ``get_labels``) and wrap each example in a ``Data`` graph;
      2. split randomly 80/20 into train/validation;
      3. group each split into sequences of exactly ``sequence_size`` consecutive graphs
         (a trailing partial sequence is dropped);
      4. batch ``batch_size`` sequences per step via ``in_context_data``.

    Args:
        db_name: key into ``DATASETS``.
        sequence_size: number of graphs per in-context sequence.
        batch_size: number of sequences per batch of the returned loaders.
        seed: optional int that seeds the train/validation split for reproducibility.
            ``None`` (the default) keeps the split random, as before.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if the number of labels differs from the number of feature sets.
    """
    # Suppress warnings only while building the data, instead of for the whole kernel.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        parinello, edge_indexes, edges = dataset(db_name=db_name)
        labels = get_labels(db_name)
        if len(labels) != len(parinello):
            raise ValueError(
                f"{len(parinello)} feature sets but {len(labels)} labels for {db_name!r}"
            )

        graphs = [
            Data(x=features, edge_index=edge_index, to_j=images, y=label)
            for features, edge_index, images, label in zip(parinello, edge_indexes, edges, labels)
        ]

        train_size = int(0.8 * len(graphs))
        val_size = len(graphs) - train_size
        split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
        train_dataset, val_dataset = random_split(graphs, [train_size, val_size], **split_kwargs)

        # drop_last=True: every sequence must hold exactly `sequence_size` graphs. A
        # shorter trailing sequence breaks the per-sequence indexing the model derives
        # from `config_label` (max(config_label) + 1 graphs per sequence).
        t_loader = DataLoader(train_dataset, batch_size=sequence_size, shuffle=False, drop_last=True)
        v_loader = DataLoader(val_dataset, batch_size=sequence_size, shuffle=False, drop_last=True)

        train_loader = in_context_data(t_loader, batch_size=batch_size)
        val_loader = in_context_data(v_loader, batch_size=batch_size)

    return train_loader, val_loader

In [8]:

#MODEL

class GPT2BasedModel(Module):
    """GPT-2 backbone that maps a sequence of feature vectors to a single scalar.

    ``forward`` works in three steps:
      1. A linear read-in projects each ``n_dims`` input vector to ``n_embd``.
      2. The GPT-2 transformer runs with dropout disabled and no KV cache.
      3. A linear read-out is applied to the hidden state of the last position.
    """

    def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4):
        """
        Args:
            n_dims: size of each input vector.
            n_positions: sequence-length budget. The GPT-2 position table is sized
                ``2 * n_positions``, presumably so interleaved (x, y) sequences from
                ``_combine`` (length ``2 * points``) also fit.
            n_embd: GPT-2 hidden width.
            n_layer: number of transformer blocks.
            n_head: number of attention heads.
        """
        super().__init__()

        configuration = GPT2Config(
            n_positions=2 * n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )

        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"
        # Construction order is kept as-is; it determines how layer initialisation
        # consumes the random number generator.
        self._read_in = Linear(n_dims, n_embd)
        self._read_out = Linear(n_embd, 1)

        self._backbone = GPT2Model(configuration)

    @staticmethod
    def _combine(xs_b, ys_b):
        """Interleave inputs and targets into one sequence ``x0, y0, x1, y1, ...``.

        Args:
            xs_b: tensor of shape ``(batch, points, dim)``.
            ys_b: tensor with ``batch * points`` elements, viewed as ``(batch, points, 1)``.

        Returns:
            Tensor of shape ``(batch, 2 * points, dim)``. Each y occupies channel 0 of its
            own position and the remaining ``dim - 1`` channels are zero.

        Note: ``forward`` does not call this helper.
        """
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            dim=2,
        )
        # (batch, points, 2, dim) -> (batch, 2 * points, dim): x_i immediately followed by y_i.
        zs = torch.stack((xs_b, ys_b_wide), dim=2)
        return zs.view(bsize, 2 * points, dim)

    def forward(self, input_tensor):
        """Predict one scalar per sequence.

        Args:
            input_tensor: tensor of shape ``(batch, seq_len, n_dims)``. ``seq_len`` must
                not exceed the ``2 * n_positions`` configured for GPT-2.

        Returns:
            Tensor of shape ``(batch, 1)`` computed from the last position's hidden state.
        """
        embeddings = self._read_in(input_tensor)
        hidden = self._backbone(inputs_embeds=embeddings).last_hidden_state
        return self._read_out(hidden[:, -1, :])


class InContextGNN(pl.LightningModule):
    """GAT graph encoder followed by a GPT-2 sequence model, trained on in-context batches.

    Each batch (from ``in_context_data``) holds several sequences of graphs:
      * Every graph is encoded by two GAT layers and mean-pooled to a 64-dim vector.
      * Each sequence's vectors are fed in order to ``GPT2BasedModel``.
      * The prediction made at the last position is regressed onto the label of that
        sequence's last graph.
    """

    def __init__(self, db_name="Mo", sequence_size=10, batch_size=10):
        """
        Args:
            db_name: key into ``DATASETS`` used to build the loaders.
            sequence_size: graphs per in-context sequence.
            batch_size: sequences per training/validation batch.
        """
        super().__init__()
        # Heads are concatenated: graph1 emits 2 * 16 = 32 features per node, matching
        # graph2's input; graph2 emits 8 * 8 = 64, matching the GPT-2 read-in size.
        self.graph1 = GATConv(in_channels=160, out_channels=16, heads=2)
        self.graph2 = GATConv(in_channels=32, out_channels=8, heads=8)
        self.att1 = GPT2BasedModel(64, 128)
        # The old `readout = Linear(8, 1)` was removed: it could not consume the (.., 1)
        # output of GPT2BasedModel, which already ends in its own scalar read-out.
        self.act = relu
        self.train_loader, self.val_loader = data(db_name, sequence_size, batch_size)

    def forward(self, batch):
        """Return one prediction per sequence in ``batch``, shape ``(num_sequences, 1)``.

        Expects ``batch.batch`` to give each node's sequence index and
        ``batch.config_label`` the position of its graph inside that sequence. Every
        sequence must contain the same number of graphs.
        """
        graphs_per_sequence = int(batch.config_label.max()) + 1
        # Global graph id = sequence index * graphs per sequence + position in sequence.
        graph_ids = batch.batch * graphs_per_sequence + batch.config_label

        node_h = self.act(self.graph1(batch.x, batch.edge_index))
        node_h = self.act(self.graph2(node_h, batch.edge_index))
        graph_h = global_mean_pool(node_h, graph_ids)

        # (num_sequences * graphs_per_sequence, 64) -> (num_sequences, graphs_per_sequence, 64);
        # GPT-2 needs a 3-D (batch, seq, features) input.
        sequences = graph_h.reshape(-1, graphs_per_sequence, graph_h.size(-1))
        return self.att1(sequences)

    @staticmethod
    def _sequence_targets(batch, num_sequences):
        """Label of the last graph of each sequence, shape ``(num_sequences, 1)``."""
        # y holds one label per graph, sequence-major: (num_sequences * graphs_per_sequence,).
        return batch.y.view(num_sequences, -1)[:, -1:]

    def _step_loss(self, batch):
        """MSE between the per-sequence predictions and the last-graph labels."""
        output = self(batch)
        target = self._sequence_targets(batch, output.size(0))
        return torch.nn.functional.mse_loss(output, target)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log('train_loss', loss)
        self.log('learning_rate', self.trainer.optimizers[0].param_groups[0]['lr'])
        return loss

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=0.01)
        # Decay the learning rate 10x every 10000 epochs (interval below is 'epoch').
        scheduler = StepLR(optimizer, step_size=10000, gamma=0.1)

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'monitor': 'val_loss',
            }
        }

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log('val_loss', loss)
        return {'val_loss': loss}

In [7]:
# Stand-alone copies of the model's layers, used to step through the forward pass by
# hand in the exploration cell that applies them to the batch `d`.
# Keep the statement order: layer initialisation presumably draws from torch's global RNG,
# so reordering would change the initial weights.
graph1 = GATConv(in_channels=160, out_channels=16, heads=2)  # 160-dim node features in; 2 heads x 16 (graph2 expects 32)
graph2 = GATConv(in_channels=32, out_channels=8, heads=8)    # 8 heads x 8 -> 64, matching GPT2BasedModel(64, ...)
readout = Linear(768, 1)  # NOTE(review): not used by any visible cell
readin = Linear(64, 64)   # NOTE(review): not used by any visible cell
act = relu
gpt2_model = GPT2BasedModel(64, 128)  # 64-dim input vectors, n_positions=128

In [4]:
# In-context loaders: sequences of 4 graphs, 4 sequences per batch.
# (`test_loader` is the 20% validation split returned by `data`.)
train_loader, test_loader = data("Mo", sequence_size=4, batch_size=4)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:35<00:00,  1.43it/s]


In [5]:
# Take the first in-context training batch and display its labels
# (one per graph: sequences-per-batch x graphs-per-sequence values).
train_batches = iter(train_loader)
d = next(train_batches)
d.y


tensor([-10.6779, -10.8607, -10.8035, -10.8589, -10.9141, -10.8179, -10.5100,
        -10.9150, -10.9002, -10.9045, -10.8406, -10.8556, -10.9251, -10.8615,
        -10.8526,  -9.4712])

In [6]:
# Step through InContextGNN's forward pass by hand on the batch `d`.
graphs_per_datapoint = torch.max(d.config_label) + 1
# Global graph id = sequence index (d.batch) * graphs per sequence + position in sequence.
actual_batch_dot_product = d.batch * graphs_per_datapoint + d.config_label

graph_h1 = act(graph1(d.x, d.edge_index))
graph_h2 = act(graph2(graph_h1, d.edge_index))
graph_h = act(global_mean_pool(graph_h2, actual_batch_dot_product))

# (num_sequences * graphs_per_datapoint, 64) -> (num_sequences, graphs_per_datapoint, 64).
# The shape is derived from the batch instead of hard-coding (4, 4), so the cell also
# works for other sequence/batch sizes and for a smaller final batch.
graph_h = graph_h.reshape(-1, int(graphs_per_datapoint), graph_h.size(-1))
print(graph_h.shape)

gpt_o = gpt2_model(graph_h)
gpt_o


torch.Size([4, 4, 64])


tensor([[-0.3688],
        [-0.2766],
        [-0.2957],
        [-0.7259]], grad_fn=<AddmmBackward0>)