In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from fairchem.core.datasets import AseDBDataset

# ASE-DB backed validation split used throughout this notebook.
dataset_path = "../datasets/val/rattled-500-subsampled"
# Extra AseDBDataset options (see tutorial on additional configuration).
config_kwargs = {}

dataset_config = dict(src=dataset_path, **config_kwargs)
ase_dataset = AseDBDataset(config=dataset_config)

# Individual structures are fetched as ASE Atoms objects by integer index.
atoms = ase_dataset.get_atoms(0)

In [None]:
# Materialise every structure in the DB as an ASE Atoms object, in index order.
all_atoms = list(map(ase_dataset.get_atoms, range(len(ase_dataset.ids))))

In [None]:
# Per-atom regression targets: the stored forces of every structure.
labels = [structure.get_forces() for structure in all_atoms]
# Per-atom input features: wrapped Cartesian positions followed by wrapped
# scaled (fractional) positions, stacked column-wise.
X_coords = [
    np.hstack(
        (
            structure.get_positions(wrap=True),
            structure.get_scaled_positions(wrap=True),
        )
    )
    for structure in all_atoms
]
# Atomic numbers; the model uses them as embedding indices.
X_numbers = [structure.get_atomic_numbers() for structure in all_atoms]

In [None]:
# Sanity check: the first structure's per-atom arrays should agree in length.
print(*(len(per_atom[0]) for per_atom in (labels, X_coords, X_numbers)))

In [None]:
# Number of rows (atoms) in the first structure's feature array.
X_coords[0].shape[0]

In [None]:
from torch.nn.utils.rnn import pad_sequence

In [None]:
class OMat24Dataset(torch.utils.data.Dataset):
    """Map-style dataset over variable-size atomic structures.

    Each item is a 4-tuple ``(coords, numbers, forces, mask)``: the first three
    are the stored per-structure objects, and ``mask`` is an all-ones int64
    tensor with one entry per row of ``forces`` (after padding it separates
    real atoms from padding).
    """

    def __init__(self, X_coords, X_numbers, y):
        # Parallel per-structure sequences; index i refers to the same structure in each.
        self.X_coords, self.X_numbers, self.y = X_coords, X_numbers, y

    def __len__(self):
        return len(self.X_numbers)

    def __getitem__(self, idx):
        forces = self.y[idx]
        atom_mask = torch.ones(len(forces), dtype=torch.int64)
        return self.X_coords[idx], self.X_numbers[idx], forces, atom_mask


class OMat24DataLoader(torch.utils.data.DataLoader):
    """Alias of ``torch.utils.data.DataLoader`` with no extra behaviour.

    Structures have different atom counts, so callers supply a padding
    ``collate_fn`` when constructing it.
    """


def collate_fn(batch):
    """Pad a batch of variable-size structures into dense tensors.

    Args:
        batch: sequence of ``(coords, numbers, forces, mask)`` tuples, one per
            structure, e.g. as returned by ``OMat24Dataset.__getitem__``.

    Returns:
        ``(coords, numbers, forces, mask)``, each padded with zeros along the
        first (atom) axis to the largest structure in the batch and stacked with
        ``batch_first=True``. Coords and forces are float32 and numbers are
        int64. The mask is bool: non-zero input entries become True and
        padding becomes False.
    """
    X_coords, X_numbers, labels, mask = zip(*batch)

    def _pad(seqs, dtype):
        # Pad a list of per-structure arrays into one (batch, max_atoms, ...) tensor.
        # torch.as_tensor accepts numpy arrays and existing tensors alike.
        # torch.tensor() emits a copy-construct UserWarning for tensors, such as
        # the dataset's mask.
        return pad_sequence(
            [torch.as_tensor(s, dtype=dtype) for s in seqs], batch_first=True
        )

    X_coords_t = _pad(X_coords, torch.float32)
    X_numbers_t = _pad(X_numbers, torch.int64)
    labels_t = _pad(labels, torch.float32)
    mask_t = _pad(mask, torch.int64).to(torch.bool)

    return X_coords_t, X_numbers_t, labels_t, mask_t

In [None]:
class SimpleTransformer(nn.Module):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        num_heads,
        num_layers,
        dim_feedforward,
        dropout,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embedding_dim + 6, num_heads, dim_feedforward, dropout, batch_first=True
            ),
            num_layers,
        )
        self.fc = nn.Linear(embedding_dim + 6, 3)

    def forward(self, inputs, mask):
        # x = self.embedding(numbers+1)
        numbers, coords = inputs
        x = torch.cat([self.embedding(numbers), coords], dim=-1)
        x = self.transformer(x)
        x = self.fc(x)
        return x

In [None]:
# Seed the global torch RNG so the weight initialisation, and the shuffling of any
# DataLoader created afterwards, are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = SimpleTransformer(
    num_embeddings=200,  # must exceed the largest atomic number (Z <= 118)
    embedding_dim=32,
    num_heads=2,
    num_layers=3,
    dim_feedforward=128,
    dropout=0.0,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []  # per-step training loss history
# Histories for an alternative split-loss experiment; the current training loop
# does not append to them.
large_losses = []
small_losses = []

In [None]:
# Total number of scalar entries across all of the model's parameter tensors.
sum(map(torch.numel, model.parameters()))

In [None]:
from tqdm import tqdm

dataset = OMat24Dataset(X_coords, X_numbers, labels)
dataloader = OMat24DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

NUM_EPOCHS = 3
LOSS_WINDOW = 100  # number of recent steps averaged in the progress-bar readout

# Make sure training mode is active even if an evaluation cell ran before.
model.train()
for epoch in range(NUM_EPOCHS):
    to_iter = tqdm(dataloader)
    for coords, numbers, forces, mask in to_iter:
        optimizer.zero_grad()
        # Flatten the padded (batch, atoms, 3) tensors to (batch*atoms, 3) and keep
        # only real atoms, so padding rows never enter the loss.
        flat_mask = mask.flatten()
        label_forces = forces.reshape(-1, 3)[flat_mask]
        output_forces = model((numbers, coords), mask).reshape(-1, 3)[flat_mask]
        # Mean absolute error over every force component of every real atom.
        loss = torch.mean(torch.abs(label_forces - output_forces))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # Report after appending, so the readout includes this step and never
        # averages an empty list (np.mean([]) -> nan plus a RuntimeWarning).
        to_iter.set_description(f"Loss: {np.mean(losses[-LOSS_WINDOW:]):.4f}")

In [None]:
# Training curve: one point per optimizer step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="Training step", ylabel="Loss (force-component MAE)", title="Training loss")
plt.show()

In [None]:
# Evaluate the model. NOTE(review): `dataloader` iterates the same dataset the
# model was trained on, so this is training-set error, not held-out error.
absolute_errors = []  # one value per batch: mean per-atom summed |force error|
model.eval()
to_iter = tqdm(dataloader)
with torch.no_grad():
    for coords, numbers, forces, mask in to_iter:
        flat_mask = mask.flatten()
        label_forces = forces.reshape(-1, 3)[flat_mask]
        output_forces = model((numbers, coords), mask).reshape(-1, 3)[flat_mask]
        diff_vec = label_forces.flatten() - output_forces.flatten()
        # diff_vec holds 3 entries per real atom. Its mean times 3 is therefore the
        # mean over atoms of the summed absolute error of the three components.
        absolute_errors.append(torch.mean(torch.abs(diff_vec)).item() * 3)
        # Update after appending so the first readout is not np.mean([]) (nan + warning).
        to_iter.set_description(f"MAE: {np.mean(absolute_errors):.4f}")

In [None]:
# Baseline: the error obtained by always predicting zero force, computed with the
# same *3 scaling as the model evaluation for a like-for-like comparison.
# Local names differ from `label_forces` / `output_forces`, so this cell does not
# overwrite the values that later cells inspect.
pred_0_errors = []
to_iter = tqdm(dataloader)
with torch.no_grad():
    for coords, numbers, forces, mask in to_iter:
        real_forces = forces.reshape(-1, 3)[mask.flatten()]
        # With a zero prediction the error is simply |label|.
        pred_0_errors.append(torch.mean(torch.abs(real_forces)).item() * 3)
        to_iter.set_description(f"MAE: {np.mean(pred_0_errors):.4f}")

In [None]:
# Unweighted average of the per-batch errors from the evaluation cell.
print(np.asarray(absolute_errors).mean())

In [None]:
# Largest predicted force component in the most recently computed batch.
output_forces.max()

In [None]:
# Largest reference force component in the most recently computed batch.
label_forces.max()

In [None]:
# Metadata keys stored on the structure loaded as `atoms`.
atoms_info = atoms.info
atoms_info.keys()

In [None]:
# Forces stored for the structure loaded as `atoms`.
first_forces = atoms.get_forces()
first_forces

In [None]:
# Full dataset over every structure, built from the per-structure feature lists.
dataset = OMat24Dataset(X_coords=X_coords, X_numbers=X_numbers, y=labels)

In [None]:
import pickle

# Serialise the assembled dataset to disk.
dataset_cache_path = "dataset.pkl"
with open(dataset_cache_path, "wb") as cache_file:
    pickle.dump(dataset, cache_file)

In [None]:
# Reload the cached dataset. The `with` block closes the file handle, which the
# bare open() call leaked.
# pickle.load can execute arbitrary code, so only load files this notebook wrote.
# OMat24Dataset must already be defined in this session for unpickling to succeed.
with open("dataset.pkl", "rb") as cache_file:
    dataset = pickle.load(cache_file)

In [None]:
TRAIN_FRACTION = 0.8
SPLIT_SEED = 0

train_split = int(len(dataset) * TRAIN_FRACTION)
# A dedicated seeded generator makes the partition reproducible across runs.
# Without it, random_split draws from the global torch RNG.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_split, len(dataset) - train_split],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Persist each split to its own file. random_split returns Subset objects that
# hold a reference to the parent dataset, so each pickle presumably includes the
# whole dataset plus the split indices — confirm if file size matters.
for split_path, split in (
    ("train_dataset.pkl", train_dataset),
    ("test_dataset.pkl", test_dataset),
):
    with open(split_path, "wb") as split_file:
        pickle.dump(split, split_file)

In [None]:
# NOTE(review): this is 90% of the dataset size, which does not match the 0.8
# fraction used for the train/test split — confirm whether this cell is still needed.
0.9 * len(dataset)