In [1]:
%env WANDB_NOTEBOOK_NAME Train 1 head mixer.ipynb
%env CUDA_DEVICE_ORDER PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES 0
%env PYTORCH_CUDA_ALLOC_CONF backend:cudaMallocAsync

env: WANDB_NOTEBOOK_NAME=Train 1 head mixer.ipynb
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0
env: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync


In [2]:
import torch

# PyTorch emits a UserWarning that TensorFloat32 tensor cores are available for
# float32 matrix multiplication but not enabled, and suggests the 'high'
# precision setting for better performance; opt in here.
torch.set_float32_matmul_precision(precision='high')

In [3]:
from mp_20_utils import load_all_data

# Run configuration: target device and the name of the dataset to load.
device = 'cuda'
dataset = 'mp_20_biternary'

# load_all_data returns nine values; unpack them one per line so each name is
# easy to find when later cells refer to it.
(
    datasets_pd,
    torch_datasets,
    site_to_ids,
    element_to_ids,
    spacegroup_to_ids,
    max_len,
    max_enumeration,
    enumeration_stop,
    enumeration_pad,
) = load_all_data(dataset=dataset)

# Quick sanity check of the sequence length and the enumeration token ids.
print(max_len, max_enumeration, enumeration_stop, enumeration_pad)

20 7 8 9


In [4]:
from cascade_transformer.model import CascadeTransformer
from wyckoff_transformer import WyckoffTrainer
from tokenization import PAD_TOKEN, MASK_TOKEN

# Embedding doesn't support uint8 indices, so token ids are kept as int64.
dtype = torch.int64
cascade_order = ("elements", "symmetry_sites", "symmetry_sites_enumeration")

# Only the space groups present in the data are counted, which is not all 230.
n_space_groups = len(spacegroup_to_ids)

# Enumeration token layout: STOP must directly follow the largest enumeration
# value and PAD must directly follow STOP; MASK is given the next id after PAD.
assert enumeration_stop == max_enumeration + 1
assert enumeration_pad == max_enumeration + 2
enumeration_mask = max_enumeration + 3
assert torch.iinfo(dtype).max > enumeration_mask

# One (N_i, d_i, pad_i) triple per entry of cascade_order; each pad id is
# turned into a scalar tensor on the target device.
# NOTE(review): the enumeration size enumeration_mask + 1 presumably assumes
# ids start at 0 - confirm against the loader.
cascade = tuple(
    (size, dim, torch.tensor(pad_id, dtype=dtype, device=device))
    for size, dim, pad_id in (
        (len(element_to_ids), 64, element_to_ids[PAD_TOKEN]),
        (len(site_to_ids), 64 - 1, site_to_ids[PAD_TOKEN]),
        (enumeration_mask + 1, None, enumeration_pad),
    )
)

model = CascadeTransformer(
    n_start=n_space_groups,
    cascade=cascade,
    n_head=1,
    d_hid=256,
    n_layers=4,
    dropout=0.1,
    use_mixer=True,
)
model = model.to(device)
# Our dynamic discard of predicting PAD calls for frequent recompilation,
# so compilation stays disabled:
# model = torch.compile(model)

In [5]:
import torch

# PAD ids per cascade field. The two vocabulary-backed fields look the token up
# in their mapping; the enumeration field uses its integer PAD id directly.
pad_dict = {
    field: vocabulary[PAD_TOKEN]
    for field, vocabulary in (("elements", element_to_ids), ("symmetry_sites", site_to_ids))
}
pad_dict["symmetry_sites_enumeration"] = enumeration_pad

# MASK ids per cascade field, built the same way as the PAD ids.
mask_dict = {
    field: vocabulary[MASK_TOKEN]
    for field, vocabulary in (("elements", element_to_ids), ("symmetry_sites", site_to_ids))
}
mask_dict["symmetry_sites_enumeration"] = enumeration_mask

# Positional arguments are kept in their original order.
trainer = WyckoffTrainer(
    model,
    torch_datasets,
    pad_dict,
    mask_dict,
    cascade_order,
    "spacegroup_number",
    max_len,
    device,
    dtype=dtype,
)
trainer.train(epochs=40_000, val_period=20)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Tracking run with wandb version 0.17.0


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/kna/WyckoffTransformer/wandb/run-20240524_160358-nf7qfrff[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mdeft-plasma-237[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/nf7qfrff[0m


Epoch 20 val_loss_epoch 90.71502685546875 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt




Epoch 40 val_loss_epoch 51.163509368896484 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt


Epoch 240 val_loss_epoch 48.3653678894043 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt


Epoch 660 val_loss_epoch 47.044532775878906 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt


Epoch 840 val_loss_epoch 46.41947555541992 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt


Epoch 2680 val_loss_epoch 45.86494445800781 saved to checkpoints/2024-05-24_16-03-57/best_model_params.pt


[34m[1mwandb[0m: - 2.834 MB of 2.834 MB uploaded

[34m[1mwandb[0m: \ 0.495 MB of 2.839 MB uploaded

[34m[1mwandb[0m: | 2.777 MB of 2.839 MB uploaded

[34m[1mwandb[0m: / 2.777 MB of 2.839 MB uploaded

[34m[1mwandb[0m: - 2.839 MB of 2.839 MB uploaded

[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:             epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
[34m[1mwandb[0m: known_cascade_len ████▁▁▅▁▅███▁▁▁▅█▁▁▅▅▅█▅▅▁█▅██▅█▅█▅▅▁▁██
[34m[1mwandb[0m:     known_seq_len ▃▅▂▁▃▇▆▃▁▇▄▆▃█▂▁▁▂█▄▅▂▃▃▁▂▇▄▄▅▄▁▃▇█▅▃▆▆▅
[34m[1mwandb[0m:                lr ██▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  train_loss_batch ▁▁▂▃▂▁▁▁█▁▁▁▁▁▃█▃▄▁▁▁▄▁▁█▄▁▁▁▁▁▃▁▁▁▁▂▁▁▁
[34m[1mwandb[0m:  train_loss_epoch ▂█▂▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
[34m[1mwandb[0m:    val_loss_epoch ▂█▂▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:             epoch 40000
[34m[1mwandb[0m: known_cascade_len 2
[34m[1mwandb[0m:     known_seq_len 17
[34m[1mwandb[0m:                lr 0.0
[34m[1mwandb[0m:  train_loss_batch 142.2659
[34m[1mwandb[0m:  train_loss_epoch 48.92275
[34m[1mwandb[0m:    val_loss_epoch 48.82223
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33mdeft-plasma-237[0m at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/nf7qfrff[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240524_160358-nf7qfrff/logs[0m
