In [1]:
%env WANDB_NOTEBOOK_NAME Train 4 head no mixer.ipynb
%env CUDA_DEVICE_ORDER PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES 1
%env PYTORCH_CUDA_ALLOC_CONF backend:cudaMallocAsync

env: WANDB_NOTEBOOK_NAME=Train 4 head no mixer.ipynb
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1
env: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync


In [2]:
import torch
# PyTorch emits a UserWarning when TensorFloat32 tensor cores are available for
# float32 matrix multiplication but not enabled, and suggests calling
# `torch.set_float32_matmul_precision('high')` for better performance.
# Follow that suggestion here.
torch.set_float32_matmul_precision("high")

In [3]:
from mp_20_utils import load_all_data
# Run configuration: target device and the dataset name handed to the loader.
device = "cuda"
dataset = "mp_20_biternary"

# load_all_data returns a 9-tuple; unpack it one name per line for readability.
(
    datasets_pd,
    torch_datasets,
    site_to_ids,
    element_to_ids,
    spacegroup_to_ids,
    max_len,
    max_enumeration,
    enumeration_stop,
    enumeration_pad,
) = load_all_data(dataset=dataset)

# Show the sequence-length / enumeration constants used by the cells below.
print(max_len, max_enumeration, enumeration_stop, enumeration_pad)

20 7 8 9


In [4]:
from cascade_transformer.model import CascadeTransformer
from wyckoff_transformer import WyckoffTrainer
from tokenization import PAD_TOKEN, MASK_TOKEN
# Token dtype: the original author notes that Embedding does not accept uint8
# indices, so int64 is used for every token tensor.
dtype = torch.int64

# Number of distinct space-group ids in the mapping (per the original note, not
# all 230 space groups occur in the data).
n_space_groups = len(spacegroup_to_ids)

# Field names of the cascade, in cascade order.
cascade_order = ("elements", "symmetry_sites", "symmetry_sites_enumeration")

# STOP and PAD must sit directly after the largest enumeration value; MASK is
# placed right after PAD and has to be representable in `dtype`.
assert enumeration_stop == max_enumeration + 1
assert enumeration_pad == max_enumeration + 2
enumeration_mask = max_enumeration + 3
assert enumeration_mask < torch.iinfo(dtype).max

# One (N_i, d_i, pad_i) triple per cascade field, in the same order as
# `cascade_order`. Presumably N_i is the vocabulary size and d_i the embedding
# width (None for the last field) -- confirm against CascadeTransformer.
# pad_i is the PAD id as a 0-dim tensor on `device`.
cascade = tuple(
    (n_tokens, n_dims, torch.tensor(pad_id, dtype=dtype, device=device))
    for n_tokens, n_dims, pad_id in (
        (len(element_to_ids), 64, element_to_ids[PAD_TOKEN]),
        (len(site_to_ids), 64 - 1, site_to_ids[PAD_TOKEN]),
        (enumeration_mask + 1, None, enumeration_pad),
    )
)

model = CascadeTransformer(
    n_start=n_space_groups,
    cascade=cascade,
    n_head=4,
    d_hid=256,
    n_layers=4,
    dropout=0.1,
    use_mixer=False,
)
model = model.to(device)
# torch.compile is intentionally left off: the dynamic discard of PAD
# predictions would trigger frequent recompilation.
# model = torch.compile(model)

In [5]:
import torch
# PAD token id for each cascade field, keyed by the field name.
pad_dict = dict(
    elements=element_to_ids[PAD_TOKEN],
    symmetry_sites=site_to_ids[PAD_TOKEN],
    symmetry_sites_enumeration=enumeration_pad,
)
# MASK token id for each cascade field, keyed the same way.
mask_dict = dict(
    elements=element_to_ids[MASK_TOKEN],
    symmetry_sites=site_to_ids[MASK_TOKEN],
    symmetry_sites_enumeration=enumeration_mask,
)

# Positional arguments are kept in the original order because the
# WyckoffTrainer signature is not visible from this notebook.
trainer = WyckoffTrainer(
    model,
    torch_datasets,
    pad_dict,
    mask_dict,
    cascade_order,
    "spacegroup_number",
    max_len,
    device,
    dtype=dtype,
)

# Long run: 40000 epochs with val_period=20 (the recorded log only reports
# checkpoints at epochs that are multiples of 20).
trainer.train(epochs=40000, val_period=20)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Tracking run with wandb version 0.17.0


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/kna/WyckoffTransformer/wandb/run-20240524_160400-u1gfkpji[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mfrosty-river-238[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/u1gfkpji[0m


Epoch 20 val_loss_epoch 86.28123474121094 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt




Epoch 40 val_loss_epoch 54.89169692993164 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 140 val_loss_epoch 49.71947479248047 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 300 val_loss_epoch 45.63631057739258 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1140 val_loss_epoch 44.860931396484375 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1160 val_loss_epoch 43.69648361206055 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1220 val_loss_epoch 38.825927734375 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1460 val_loss_epoch 38.034454345703125 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1560 val_loss_epoch 38.033111572265625 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1780 val_loss_epoch 36.998779296875 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 1880 val_loss_epoch 36.805824279785156 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 2200 val_loss_epoch 35.67499542236328 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 2580 val_loss_epoch 35.660030364990234 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 2720 val_loss_epoch 34.49345397949219 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 4060 val_loss_epoch 34.33827590942383 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 4080 val_loss_epoch 34.13359451293945 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 4480 val_loss_epoch 33.72593307495117 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 5460 val_loss_epoch 33.699668884277344 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 5520 val_loss_epoch 33.052310943603516 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 5560 val_loss_epoch 32.82109451293945 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 6440 val_loss_epoch 32.78449630737305 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 6960 val_loss_epoch 32.652732849121094 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 7420 val_loss_epoch 32.47675323486328 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 7440 val_loss_epoch 32.1143913269043 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 7560 val_loss_epoch 32.059532165527344 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 7840 val_loss_epoch 31.922780990600586 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 8680 val_loss_epoch 31.88934898376465 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10360 val_loss_epoch 31.872167587280273 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10380 val_loss_epoch 31.811370849609375 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10420 val_loss_epoch 31.722091674804688 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10540 val_loss_epoch 31.645763397216797 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10560 val_loss_epoch 31.60563850402832 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 10600 val_loss_epoch 31.516510009765625 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 12760 val_loss_epoch 31.504451751708984 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 12840 val_loss_epoch 31.49676513671875 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 12860 val_loss_epoch 31.455251693725586 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 12980 val_loss_epoch 31.439817428588867 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 13080 val_loss_epoch 31.398475646972656 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


Epoch 13100 val_loss_epoch 31.374202728271484 saved to checkpoints/2024-05-24_16-03-59/best_model_params.pt


[34m[1mwandb[0m: - 2.771 MB of 2.771 MB uploaded

[34m[1mwandb[0m: \ 2.771 MB of 2.779 MB uploaded

[34m[1mwandb[0m: | 0.483 MB of 2.779 MB uploaded

[34m[1mwandb[0m: / 2.779 MB of 2.779 MB uploaded

[34m[1mwandb[0m: - 2.779 MB of 2.779 MB uploaded

[34m[1mwandb[0m: \ 2.779 MB of 2.779 MB uploaded

[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:             epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
[34m[1mwandb[0m: known_cascade_len ██▁▅▅▅█▁▅▅█▅▁▁▅█▁█▁▅▁▅▅▁█▅▅█▁▅█▅▁▁▁█▅█▁▁
[34m[1mwandb[0m:     known_seq_len █▁▆▄▁▄▆▆▅▃▇█▄▂█▅█▆█▅█▃▆▇▃▅▅▇▇▂▃▅▂▃▁█▆▅▁▆
[34m[1mwandb[0m:                lr █▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  train_loss_batch ▁▄▁▁▆▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▄▁█▁▁▁█▁
[34m[1mwandb[0m:  train_loss_epoch █▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:    val_loss_epoch █▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:             epoch 40000
[34m[1mwandb[0m: known_cascade_len 1
[34m[1mwandb[0m:     known_seq_len 12
[34m[1mwandb[0m:                lr 0.0
[34m[1mwandb[0m:  train_loss_batch 37.45678
[34m[1mwandb[0m:  train_loss_epoch 31.40739
[34m[1mwandb[0m:    val_loss_epoch 31.56444
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33mfrosty-river-238[0m at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/u1gfkpji[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240524_160400-u1gfkpji/logs[0m
