In [1]:
import torch
import torch.nn as nn
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar
from torchdyn.core import NeuralODE
from transformers import BertModel, AutoTokenizer

from TrainDatasets.oscar import OSCARDataModule

In [2]:
# Name of the pretrained checkpoint; reused further down when loading the tokenizer.
model_name = "bert-base-uncased"
# Load the pretrained BERT weights for that checkpoint.
model = BertModel.from_pretrained(model_name)
# Bare last expression so the notebook displays the module structure.
model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

Thanks to the PyTorch Lightning documentation (text-transformers example): https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/text-transformers.html

In [3]:
# Integration time points handed to the ODE model: just the two endpoints 0 and 1.
t_span = torch.linspace(start=0, end=1, steps=2)
t_span

tensor([0., 1.])

In [4]:
# Wrap the BERT module in a NeuralODE using adjoint sensitivity and the 'tsit5' solver.
# `model` is rebound here, so re-running this cell on its own wraps the wrapper again;
# re-run the cell that loads BertModel first.
# NOTE(review): NeuralODE uses the wrapped module as the ODE vector field — confirm that
# BertModel's inputs and outputs are compatible with that role.
model = NeuralODE(model, sensitivity='adjoint', solver='tsit5', interpolator=None)
# Bare last expression so the notebook displays the NeuralODE summary.
model

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


Neural ODE:
	- order: 1        
	- solver: Tsitouras45()
	- adjoint solver: Tsitouras45()        
	- tolerances: relative 0.001 absolute 0.001        
	- adjoint tolerances: relative 0.0001 absolute 0.0001        
	- num_parameters: 109482240        
	- NFE: 0.0

In [5]:
class Learner(LightningModule):
    """Lightning wrapper that trains a model by integrating it over ``t_span``.

    Args:
        t_span: Integration time points passed to ``model`` on every call.
        model: Module called as ``model(x, t_span)``; it must return a
            ``(t_eval, trajectory)`` pair.
        lr: Learning rate for the Adam optimizer. Defaults to 0.01, the value
            that used to be hard-coded in ``configure_optimizers``.
    """

    def __init__(self, t_span: torch.Tensor, model: nn.Module, lr: float = 0.01):
        super().__init__()
        self.model, self.t_span = model, t_span
        self.lr = lr

    def forward(self, x):
        """Run ``model`` on ``x`` over ``self.t_span``.

        The time span is passed explicitly so that inference integrates over
        the same interval as ``training_step`` instead of whatever the wrapped
        model falls back to when no span is given.
        """
        return self.model(x, self.t_span)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one ``(x, y)`` batch.

        Returns:
            A dict with the single key ``'loss'``.
        """
        x, y = batch
        # Go through forward() so training and inference share one code path.
        _, trajectory = self(x)
        y_hat = trajectory[-1]  # select last point of solution trajectory
        # Functional form avoids building a new loss module on every step;
        # defaults match nn.CrossEntropyLoss().
        loss = nn.functional.cross_entropy(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [6]:
# Load the tokenizer for the same checkpoint, then hand it to the OSCAR data module.
tokenizer = AutoTokenizer.from_pretrained(model_name)
dm = OSCARDataModule(tokenizer)
dm

<TrainDatasets.oscar.OSCARDataModule at 0x1c939fa0a30>

In [7]:
# Bundle the integration interval and the NeuralODE-wrapped model into the Lightning module.
learn = Learner(t_span=t_span, model=model)
learn

Learner(
  (model): Neural ODE:
  	- order: 1        
  	- solver: Tsitouras45()
  	- adjoint solver: Tsitouras45()        
  	- tolerances: relative 0.001 absolute 0.001        
  	- adjoint tolerances: relative 0.0001 absolute 0.0001        
  	- num_parameters: 109482240        
  	- NFE: 0.0
)

In [8]:
# `progress_bar_refresh_rate` is a deprecated Trainer argument; the refresh rate is
# configured on the progress-bar callback instead.
trainer = Trainer(
    min_epochs=1,
    max_epochs=5,
    callbacks=[TQDMProgressBar(refresh_rate=1)],
)
trainer.fit(learn, datamodule=dm)

  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Missing logger folder: C:\Users\weipy\OneDrive\Documents\GitHub\bert-ode\lightning_logs


Downloading builder script:   0%|          | 0.00/5.58k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/359k [00:00<?, ?B/s]


  | Name  | Type      | Params
------------------------------------
0 | model | NeuralODE | 109 M 
------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.929   Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
