In [11]:
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
from accelerate import Accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

from m2_utilities.load_data import load_trajectories
from m2_utilities.preprocessor import load_and_preprocess
from m2_utilities.qwen import load_qwen
from m2_utilities.lora import apply_lora, process_sequences
from m2_utilities.flops import compute_flops

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
model, tokenizer = load_qwen()

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

In [3]:
lora_rank = 4
apply_lora(model, lora_rank)

### Loading Trajectories and Stringifying

In [4]:
ALPHA = 1.5
DECIMALS = 2
texts = load_and_preprocess("data/lotka_volterra_data.h5", ALPHA, DECIMALS)

n_train = 700
train_texts = texts[:n_train]
val_texts = texts[n_train:]

In [5]:
train_input_ids = process_sequences(train_texts, tokenizer)
val_input_ids = process_sequences(val_texts, tokenizer)

In [6]:
print(train_input_ids.shape)

torch.Size([2800, 512])


In [12]:
batch_size = 4
learning_rate = 1e-5

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=learning_rate
)
train_dataset = TensorDataset(train_input_ids)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

accelerator = Accelerator()
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


### Training

In [17]:
total_flops = 0.0
model.train()
steps = 0
while steps < 10000:
    progress_bar = tqdm(train_loader, desc=f"Steps {steps}")
    for (batch,) in train_loader:
        optimizer.zero_grad()
        outputs = model(batch, labels=batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        steps += 1
        
#         progress_bar.set_postfix(loss=loss.item())
        if steps > 10000:
            break
        total_flops += batch_size * compute_flops(512, backpropagate=True)
        print(f"{total_flops:.2e}")

model.eval()


Steps 0:   0%|          | 0/700 [00:11<?, ?it/s, loss=0.587]


6.78e+12
1.36e+13
2.03e+13
2.71e+13
3.39e+13
4.07e+13
4.74e+13
5.42e+13
6.10e+13
6.78e+13
7.46e+13
8.13e+13
8.81e+13
9.49e+13
1.02e+14
1.08e+14
1.15e+14
1.22e+14
1.29e+14
1.36e+14
1.42e+14
1.49e+14
1.56e+14
1.63e+14
1.69e+14
1.76e+14
1.83e+14
1.90e+14
1.97e+14
2.03e+14
2.10e+14
2.17e+14
2.24e+14
2.30e+14
2.37e+14
2.44e+14
2.51e+14
2.58e+14
2.64e+14
2.71e+14
2.78e+14
2.85e+14
2.91e+14
2.98e+14
3.05e+14
3.12e+14
3.19e+14
3.25e+14
3.32e+14
3.39e+14
3.46e+14
3.52e+14
3.59e+14
3.66e+14
3.73e+14
3.80e+14
3.86e+14
3.93e+14
4.00e+14
4.07e+14
4.13e+14
4.20e+14
4.27e+14
4.34e+14
4.41e+14
4.47e+14
4.54e+14
4.61e+14
4.68e+14
4.74e+14
4.81e+14
4.88e+14
4.95e+14
5.02e+14
5.08e+14
5.15e+14
5.22e+14
5.29e+14
5.35e+14
5.42e+14
5.49e+14
5.56e+14
5.63e+14
5.69e+14
5.76e+14
5.83e+14
5.90e+14
5.96e+14
6.03e+14
6.10e+14
6.17e+14
6.24e+14
6.30e+14
6.37e+14
6.44e+14
6.51e+14
6.57e+14
6.64e+14
6.71e+14
6.78e+14
6.85e+14
6.91e+14
6.98e+14
7.05e+14
7.12e+14
7.18e+14
7.25e+14
7.32e+14
7.39e+14
7.46e+14
7.52e+14
7

Steps 0:   0%|          | 0/700 [03:58<?, ?it/s]s]


4.75e+15
4.76e+15
4.76e+15
4.77e+15
4.78e+15
4.79e+15
4.79e+15
4.80e+15
4.81e+15
4.81e+15
4.82e+15
4.83e+15
4.83e+15
4.84e+15
4.85e+15
4.85e+15
4.86e+15
4.87e+15
4.87e+15
4.88e+15
4.89e+15
4.89e+15
4.90e+15
4.91e+15
4.91e+15
4.92e+15
4.93e+15
4.93e+15
4.94e+15
4.95e+15
4.95e+15
4.96e+15
4.97e+15
4.98e+15
4.98e+15
4.99e+15
5.00e+15
5.00e+15
5.01e+15
5.02e+15
5.02e+15
5.03e+15
5.04e+15
5.04e+15
5.05e+15
5.06e+15
5.06e+15
5.07e+15
5.08e+15
5.08e+15
5.09e+15
5.10e+15
5.10e+15
5.11e+15
5.12e+15
5.12e+15
5.13e+15
5.14e+15
5.14e+15
5.15e+15
5.16e+15
5.16e+15
5.17e+15
5.18e+15
5.19e+15
5.19e+15
5.20e+15
5.21e+15
5.21e+15
5.22e+15
5.23e+15
5.23e+15
5.24e+15
5.25e+15
5.25e+15
5.26e+15
5.27e+15
5.27e+15
5.28e+15
5.29e+15
5.29e+15
5.30e+15
5.31e+15
5.31e+15
5.32e+15
5.33e+15
5.33e+15
5.34e+15
5.35e+15
5.35e+15
5.36e+15
5.37e+15
5.37e+15
5.38e+15
5.39e+15
5.40e+15
5.40e+15
5.41e+15
5.42e+15
5.42e+15
5.43e+15
5.44e+15
5.44e+15
5.45e+15
5.46e+15
5.46e+15
5.47e+15
5.48e+15
5.48e+15
5.49e+15
5.50e+15
5


Steps 700:   0%|          | 0/700 [03:58<?, ?it/s]][A


9.50e+15
9.50e+15
9.51e+15
9.52e+15
9.52e+15
9.53e+15
9.54e+15
9.54e+15
9.55e+15
9.56e+15
9.56e+15
9.57e+15
9.58e+15
9.58e+15
9.59e+15
9.60e+15
9.60e+15
9.61e+15
9.62e+15
9.62e+15
9.63e+15
9.64e+15
9.65e+15
9.65e+15
9.66e+15
9.67e+15
9.67e+15
9.68e+15
9.69e+15
9.69e+15
9.70e+15
9.71e+15
9.71e+15
9.72e+15
9.73e+15
9.73e+15
9.74e+15
9.75e+15
9.75e+15
9.76e+15
9.77e+15
9.77e+15
9.78e+15
9.79e+15
9.79e+15
9.80e+15
9.81e+15
9.81e+15
9.82e+15
9.83e+15
9.83e+15
9.84e+15
9.85e+15
9.86e+15
9.86e+15
9.87e+15
9.88e+15
9.88e+15
9.89e+15
9.90e+15
9.90e+15
9.91e+15
9.92e+15
9.92e+15
9.93e+15
9.94e+15
9.94e+15
9.95e+15
9.96e+15
9.96e+15
9.97e+15
9.98e+15
9.98e+15
9.99e+15
1.00e+16
1.00e+16
1.00e+16
1.00e+16
1.00e+16
1.00e+16
1.00e+16
1.00e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.01e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1.02e+16
1

Steps 1400:   0%|          | 0/700 [03:58<?, ?it/s]


1.42e+16
1.42e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.43e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.44e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.45e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.46e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.47e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.48e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.49e+16
1.50e+16
1.50e+16
1.50e+16
1.50e+16
1.50e+16
1.50e+16
1


Steps 2100:   0%|          | 0/700 [03:58<?, ?it/s][A


1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.90e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.91e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.92e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.93e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.94e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.95e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.96e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1.97e+16
1

Steps 2800:   0%|          | 0/700 [03:58<?, ?it/s]


2.37e+16
2.37e+16
2.37e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.38e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.39e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.40e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.41e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.42e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.43e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.44e+16
2.45e+16
2.45e+16
2.45e+16
2.45e+16
2


Steps 3500:   0%|          | 0/700 [03:58<?, ?it/s][A


2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.85e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.86e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16
2.87e+16


KeyboardInterrupt: 