In [1]:
import jax
import numpy as np

import torch
from torch import nn, optim
from torchvision.transforms import ToTensor
import torch.nn.functional as F

import lightning as L

from molnet.torch_models import create_model
from molnet.data import input_pipeline
from configs.tests import torch_attention_test
from configs import root_dirs

In [2]:
config = torch_attention_test.get_config()
config.root_dir = root_dirs.get_root_dir()

rng = jax.random.PRNGKey(0)
datarng, rng = jax.random.split(rng)

ds = input_pipeline.get_pseudodatasets(datarng, config)
train_loader = ds['train']

In [3]:
batch = next(train_loader)
x, atom_map, xyz = batch['images'], batch['atom_map'], batch['xyz']

print(x.shape, atom_map.shape, xyz.shape)
print(x.dtype, atom_map.dtype, xyz.dtype)
print(type(x), type(atom_map), type(xyz))

(4, 1, 128, 128, 10) (4, 5, 128, 128, 21) (4, 54, 5)
float32 float32 float32
<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [4]:
device = torch.device('cpu')

class LitModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        batch = jax.tree_util.tree_map(torch.from_numpy, batch)
        x, atom_map, xyz = batch['images'], batch['atom_map'], batch['xyz']
        x = x.to(device)
        atom_map = atom_map.to(device)
        pred = self.model(x)
        z_slices = x.shape[-1]
        loss = F.mse_loss(pred, atom_map[..., -z_slices:])
        return loss
    
    def validation_step(self, batch, batch_idx):
        batch = jax.tree_util.tree_map(torch.from_numpy, batch)
        x, atom_map, xyz = batch['images'], batch['atom_map'], batch['xyz']
        x = x.to(device)
        atom_map = atom_map.to(device)
        pred = self.model(x)
        z_slices = x.shape[-1]
        loss = F.mse_loss(pred, atom_map[..., -z_slices:])
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [5]:
unet = LitModel(
    create_model(config.model)
)

In [6]:
trainer = L.Trainer(limit_train_batches=10, max_epochs=1, accelerator='cpu')

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/kurkil1/.venvs/molnet/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [7]:
ds = input_pipeline.get_pseudodatasets(datarng, config)
train_loader = ds['train']

trainer.fit(model=unet, train_dataloaders=train_loader)


  | Name  | Type          | Params | Mode 
------------------------------------------------
0 | model | AttentionUNet | 17.6 K | train
------------------------------------------------
17.6 K    Trainable params
0         Non-trainable params
17.6 K    Total params
0.070     Total estimated model params size (MB)
59        Modules in train mode
0         Modules in eval mode
/Users/kurkil1/.venvs/molnet/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s] 

  return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))


Epoch 0: 100%|██████████| 10/10 [00:14<00:00,  0.68it/s, v_num=8]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 10/10 [00:14<00:00,  0.67it/s, v_num=8]
