In [1]:
%env WANDB_NOTEBOOK_NAME Train dev.ipynb
%env CUDA_DEVICE_ORDER PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES 0
%env PYTORCH_CUDA_ALLOC_CONF backend:cudaMallocAsync

env: WANDB_NOTEBOOK_NAME=Train dev.ipynb
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0
env: PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync


In [2]:
import torch
# PyTorch otherwise emits: "UserWarning: TensorFloat32 tensor cores for float32
# matrix multiplication available but not enabled. Consider setting
# `torch.set_float32_matmul_precision('high')` for better performance."
# Follow that advice and opt in to the 'high' precision mode.
torch.set_float32_matmul_precision(precision="high")

In [3]:
from mp_20_utils import load_all_data
# Device that the model and the pad tensors are placed on in later cells.
device = "cuda"
# Dataset name forwarded to load_all_data.
dataset = "mp_20_biternary"

# load_all_data returns nine values; unpack them one per line so each name
# is easy to find.
(
    datasets_pd,
    torch_datasets,
    site_to_ids,
    element_to_ids,
    spacegroup_to_ids,
    max_len,
    max_enumeration,
    enumeration_stop,
    enumeration_pad,
) = load_all_data(dataset=dataset)

# Show the scalar settings that the model cell relies on.
print(max_len, max_enumeration, enumeration_stop, enumeration_pad)

20 7 8 9


In [4]:
from cascade_transformer.model import CascadeTransformer
from wyckoff_transformer import WyckoffTrainer
from tokenization import PAD_TOKEN, MASK_TOKEN
# Not all 230 space groups are present in the data, so the start-token
# vocabulary is sized from the mapping that was actually built.
n_space_groups = len(spacegroup_to_ids)
# Embedding doesn't support uint8, hence 64-bit integer token ids.
dtype = torch.int64
cascade_order = ("elements", "symmetry_sites", "symmetry_sites_enumeration")

# The STOP and PAD enumeration ids must directly follow max_enumeration;
# the MASK id is placed right after them.
assert enumeration_stop == max_enumeration + 1
assert enumeration_pad == max_enumeration + 2
enumeration_mask = max_enumeration + 3
assert torch.iinfo(dtype).max > enumeration_mask


def _pad_token_tensor(token_id):
    """Wrap a pad token id as a tensor of `dtype` on `device`."""
    return torch.tensor(token_id, dtype=dtype, device=device)


# One (N_i, d_i, pad_i) triple per cascade level, listed in cascade_order.
# NOTE(review): the reason for `64 - 1` and for `None` as the last d_i is not
# visible in this notebook; see CascadeTransformer for their meaning.
cascade = (
    (len(element_to_ids), 64, _pad_token_tensor(element_to_ids[PAD_TOKEN])),
    (len(site_to_ids), 64 - 1, _pad_token_tensor(site_to_ids[PAD_TOKEN])),
    # N is enumeration_mask + 1, i.e. one more than the largest id defined above.
    (enumeration_mask + 1, None, _pad_token_tensor(enumeration_pad)),
)

model = CascadeTransformer(
    n_start=n_space_groups,
    cascade=cascade,
    n_head=4,
    d_hid=256,
    n_layers=4,
    dropout=0.1,
    use_mixer=True,
)
model = model.to(device)
# torch.compile is deliberately left disabled: our dynamic discard of
# predicting PAD calls for frequent recompilation.
# model = torch.compile(model)

In [5]:
import torch
# (pad id, mask id) for every cascade field, kept side by side so the two
# lookup dicts below cannot drift apart.
special_token_ids = {
    "elements": (element_to_ids[PAD_TOKEN], element_to_ids[MASK_TOKEN]),
    "symmetry_sites": (site_to_ids[PAD_TOKEN], site_to_ids[MASK_TOKEN]),
    "symmetry_sites_enumeration": (enumeration_pad, enumeration_mask),
}
pad_dict = {field: pad_id for field, (pad_id, _) in special_token_ids.items()}
mask_dict = {field: mask_id for field, (_, mask_id) in special_token_ids.items()}

# Positional arguments are kept in the order WyckoffTrainer is called with in
# the original cell; "spacegroup_number" is passed as the sixth argument.
trainer = WyckoffTrainer(
    model,
    torch_datasets,
    pad_dict,
    mask_dict,
    cascade_order,
    "spacegroup_number",
    max_len,
    device,
    dtype=dtype,
)
# Long run: 40000 epochs with val_period=20.
trainer.train(epochs=40000, val_period=20)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Tracking run with wandb version 0.17.0


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/kna/WyckoffTransformer/wandb/run-20240524_160417-eya1llbi[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mgraceful-wave-239[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/eya1llbi[0m


Epoch 20 val_loss_epoch 110.37109375 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt




Epoch 60 val_loss_epoch 51.28900909423828 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 120 val_loss_epoch 49.35118865966797 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 320 val_loss_epoch 46.471858978271484 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 760 val_loss_epoch 45.8333740234375 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 1080 val_loss_epoch 44.10159683227539 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 1200 val_loss_epoch 42.67142868041992 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 1220 val_loss_epoch 41.50479507446289 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 1320 val_loss_epoch 40.442752838134766 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 2580 val_loss_epoch 40.442298889160156 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 6600 val_loss_epoch 40.22306823730469 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 6640 val_loss_epoch 40.21381759643555 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 6660 val_loss_epoch 39.729244232177734 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 6920 val_loss_epoch 39.422401428222656 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 6980 val_loss_epoch 39.25120544433594 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 7000 val_loss_epoch 39.17072677612305 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 7080 val_loss_epoch 39.04907989501953 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 7100 val_loss_epoch 38.69624710083008 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 7180 val_loss_epoch 38.376766204833984 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 8400 val_loss_epoch 38.33898162841797 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 8440 val_loss_epoch 37.9315185546875 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 8580 val_loss_epoch 37.91562271118164 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 8620 val_loss_epoch 37.639583587646484 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 8780 val_loss_epoch 37.156009674072266 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 9500 val_loss_epoch 36.99886703491211 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11340 val_loss_epoch 36.71635437011719 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11580 val_loss_epoch 36.5906982421875 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11660 val_loss_epoch 36.52017593383789 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11680 val_loss_epoch 36.41949462890625 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11740 val_loss_epoch 36.40169143676758 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 11780 val_loss_epoch 35.936771392822266 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 13620 val_loss_epoch 35.8924674987793 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 14180 val_loss_epoch 35.81734848022461 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 14600 val_loss_epoch 35.8140983581543 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 15120 val_loss_epoch 35.790870666503906 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 15300 val_loss_epoch 35.759098052978516 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 15400 val_loss_epoch 35.68293380737305 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 15800 val_loss_epoch 35.66883087158203 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 16280 val_loss_epoch 35.59356689453125 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


Epoch 16340 val_loss_epoch 35.57943344116211 saved to checkpoints/2024-05-24_16-04-14/best_model_params.pt


[34m[1mwandb[0m: - 1.831 MB of 1.831 MB uploaded

[34m[1mwandb[0m: \ 1.831 MB of 1.839 MB uploaded

[34m[1mwandb[0m: | 1.202 MB of 2.842 MB uploaded

[34m[1mwandb[0m: / 2.842 MB of 2.842 MB uploaded

[34m[1mwandb[0m: - 2.842 MB of 2.842 MB uploaded

[34m[1mwandb[0m: \ 2.842 MB of 2.842 MB uploaded

[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:             epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
[34m[1mwandb[0m: known_cascade_len ██▅▅▁█▁█▅▁████▅▁█▅▅▅▅▁█▅██▁▁▅█▅████▁██▁▁
[34m[1mwandb[0m:     known_seq_len █▃▂█▅▇▇▄▃▆▂▄▆▂▂▇▆▇▃▆▄▆▁▂▂▅▆▁▇▇▃▆▃▁▂▇█▆▅▄
[34m[1mwandb[0m:                lr █▇▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  train_loss_batch ▁▁▂▁▁▁▁▁▁▁▂▁▁▁▅▁▁▁▁▁▁▁▂▅▁▁▁█▁▁▁▁▁▂▂▁▁▁▁▁
[34m[1mwandb[0m:  train_loss_epoch ▂█▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:    val_loss_epoch ▂█▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:             epoch 40000
[34m[1mwandb[0m: known_cascade_len 2
[34m[1mwandb[0m:     known_seq_len 6
[34m[1mwandb[0m:                lr 0.0
[34m[1mwandb[0m:  train_loss_batch 1099.88013
[34m[1mwandb[0m:  train_loss_epoch 35.75783
[34m[1mwandb[0m:    val_loss_epoch 35.65189
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33mgraceful-wave-239[0m at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer/runs/eya1llbi[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/kazeev/WyckoffTransformer[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240524_160417-eya1llbi/logs[0m
