In [1]:
import torch
# Set the global PyTorch random seed for reproducibility
torch.manual_seed(23)

<torch._C.Generator at 0x7fdaf00f8a30>

In [2]:
from src.qwen import load_qwen
model_qwen, tokenizer = load_qwen()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


# Part 3 (a) (continued)

Plan of action:

- Preprocess the data with `load_and_preprocess` and split it into `train_texts`, `val_texts` and `val_texts_70`, containing 900, 100 and 100 systems respectively.
- The two validation sets `val_texts` and `val_texts_70` contain the same 100 systems, but:
    - in `val_texts` each system has all 100 (prey, predator) pairs;
    - in `val_texts_70` each system has only the first 70 (prey, predator) pairs.
- We train the model on the tokenised `train_texts`.
- We validate the model by predicting the remaining 30 pairs of each of the 100 systems from the tokenised `val_texts_70`.
- We then compare the predictions obtained from `val_texts_70` with the ground truth in `val_texts` (or `true_val_values`, obtained with `data_scale_split`).
- As for the untrained models, we then compute the MSE and RMSE.
- Finally, we report the loss and perplexity of each trained model.
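The relationship between `val_texts` and `val_texts_70` can be illustrated on the underlying numerical arrays. This is a minimal sketch, assuming the raw trajectories are stored as an array of shape `(n_systems, 100, 2)` (prey, predator) and that each system is serialised as comma-separated pairs joined by semicolons. The helpers `split_context_target` and `serialise_system` are hypothetical, and the real format produced by `load_and_preprocess` (scaling, precision, separators) may differ.

```python
import numpy as np


def split_context_target(trajectories, n_context=70):
    """Split (n_systems, n_steps, 2) trajectories into a context and a forecast target.

    Hypothetical helper: the context plays the role of `val_texts_70`,
    the target holds the pairs the model must forecast.
    """
    if trajectories.ndim != 3 or trajectories.shape[-1] != 2:
        raise ValueError("expected shape (n_systems, n_steps, 2)")
    context = trajectories[:, :n_context, :]
    target = trajectories[:, n_context:, :]
    return context, target


def serialise_system(pairs, decimals=2):
    """Serialise one system as 'prey,predator;prey,predator;...' (assumed format)."""
    return ";".join(f"{x:.{decimals}f},{y:.{decimals}f}" for x, y in pairs)


# Random toy data standing in for the 100 validation systems (not the real dataset)
rng = np.random.default_rng(23)
val_trajectories = rng.uniform(0.0, 5.0, size=(100, 100, 2))

context, target = split_context_target(val_trajectories)
print(context.shape, target.shape)  # (100, 70, 2) (100, 30, 2)

prompt_text = serialise_system(context[0])
print(prompt_text.count(";") + 1)  # 70 pairs in the prompt
```

Splitting on the numeric array before serialising guarantees that the 70-pair context and the 30-pair target line up exactly with the full 100-pair sequence.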

We are recommended to train the model for up to 10,000 steps, but the overall training budget is capped at $10^{17}$ FLOPs. Given the computational cost, we first proceed with fewer steps, partly to become familiar with the training procedure, before increasing the number of steps and moving to HPC.

In summary:

“We trained on 900 systems, validated on 100 full sequences for loss monitoring, and evaluated forecasting performance by generating future predictions given the first 70 steps from each validation sequence.”

All of the steps described above are implemented in `set_up_lora.py`. For FLOP estimation we can use `total_transformer_training_flops` in `flops.py`. The reader is invited to explore every file in `src`.

In [3]:
import numpy as np  # np is used explicitly below (np.exp, np.save)
from src.set_up_lora import *
from src.flops import *

After training the model, we can estimate the number of FLOPs from the number of training steps and the model dimensions.

For this part we do not want to exceed 5000 steps; otherwise we would be too close to the overall FLOP budget allowed for training.

To train the model we use the function `train_lora_model` from `set_up_lora.py`.

In [None]:
model_lora_5000, loss_lora_5000 = train_lora_model(model_qwen, tokenizer)  # default number of steps and hyperparameters (set in train_lora_model)

Steps 0:   1%|          | 9/1142 [00:49<1:44:32,  5.54s/it]


Saving the LoRA weights only (efficient checkpoints).

Save the LoRA adapter weights (not the full Qwen model).

In [None]:
# Extract LoRA-only weights
lora_state_dict = {
    name: param.detach().cpu()  # detach so only the tensor data (no autograd history) is stored
    for name, param in model_lora_5000.named_parameters()
    if param.requires_grad
}

#torch.save(lora_state_dict, "trained_lora_part_3a/lora_weight_matrices_5000.pt")
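Because the checkpoint contains only the trainable (LoRA) parameters, reloading it requires a model that already has the same LoRA layers, and `load_state_dict` must be called with `strict=False` so the missing frozen base weights are not treated as an error. The sketch below demonstrates this round trip on a hypothetical toy layer, `ToyLoRALinear`, rather than on Qwen; an in-memory buffer stands in for the `.pt` file.

```python
import io

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Hypothetical toy: frozen base linear layer plus a trainable low-rank update."""

    def __init__(self, d_in, d_out, rank=4):
        super().__init__()
        self.base = nn.Linear(d_in, d_out)
        self.base.requires_grad_(False)  # frozen "pretrained" weights
        self.lora_A = nn.Parameter(torch.randn(rank, d_in) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d_out, rank))

    def forward(self, x):
        return self.base(x) + x @ self.lora_A.T @ self.lora_B.T


model = ToyLoRALinear(8, 8)
with torch.no_grad():
    model.lora_B.add_(0.1)  # pretend training changed the adapter

# Keep only the trainable (LoRA) parameters, as in the cell above
lora_state_dict = {
    name: param.detach().cpu()
    for name, param in model.named_parameters()
    if param.requires_grad
}

buffer = io.BytesIO()  # stands in for a .pt file on disk
torch.save(lora_state_dict, buffer)
buffer.seek(0)

# Fresh model with the same (frozen) base weights, then load the adapter only
fresh = ToyLoRALinear(8, 8)
fresh.base.load_state_dict(model.base.state_dict())
result = fresh.load_state_dict(torch.load(buffer), strict=False)
print(sorted(result.missing_keys))  # ['base.bias', 'base.weight']
print(result.unexpected_keys)       # []
```

Listing `missing_keys` is a cheap sanity check: it should contain only base-model weights, never any LoRA parameter names.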

In [5]:
config = model_lora_5000.config

# Parameters
num_steps = 5000
batch_size = 4
seq_len = 512
d_model = config.hidden_size
num_heads = config.num_attention_heads
num_layers = config.num_hidden_layers
intermediate_dim = 2 * d_model  # SwiGLU
lora_rank = 4  # if using LoRA

total_flops_estimate = total_transformer_training_flops(num_steps, batch_size, seq_len, num_layers, d_model, num_heads, intermediate_dim, lora_rank)

print(f'Total number of estimated FLOPs for training LoRA with {num_steps} steps:',total_flops_estimate)

Total number of estimated FLOPs for training LoRA with 5000 steps: 5983174656000000
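Since `batch_size` and `seq_len` are fixed, the estimate should grow linearly with `num_steps`, which makes it easy to track how much of the overall $10^{17}$ budget a single run consumes. A quick sketch, assuming the per-step cost assumed by `total_transformer_training_flops` does not depend on the step index; the numbers are copied from the output above.

```python
FLOP_BUDGET = 1e17                        # overall training allowance
estimated_flops = 5_983_174_656_000_000   # printed estimate for num_steps = 5000
num_steps = 5000

# Assumption: every step costs the same, so total cost is linear in the number of steps
flops_per_step = estimated_flops / num_steps
budget_fraction = estimated_flops / FLOP_BUDGET

print(f"FLOPs per step: {flops_per_step:.4e}")
print(f"Fraction of the overall budget: {budget_fraction:.2%}")
```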


To evaluate the loss and perplexity on both the training and validation sets: `set_up_lora.py` provides `evaluate_loss_perplexity_val`, which computes the loss and perplexity on the validation set, while the training loss is returned directly by `train_lora_model` (`loss_lora_5000`), from which the training perplexity follows.

### Loss and Perplexity 
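Perplexity is the exponential of the mean per-token cross-entropy loss (in nats), which is why the training perplexity can be recovered from `loss_lora_5000` with `np.exp`. A minimal sketch; the token-weighted averaging over batches is an assumption about how a validation loop such as `evaluate_loss_perplexity_val` might aggregate its results, not a description of its actual code, and both helper names are hypothetical.

```python
import math


def token_weighted_mean_loss(batch_losses, batch_token_counts):
    """Average per-batch mean cross-entropy losses, weighting each batch by its token count."""
    total_tokens = sum(batch_token_counts)
    if total_tokens == 0:
        raise ValueError("cannot average over zero tokens")
    return sum(loss * n for loss, n in zip(batch_losses, batch_token_counts)) / total_tokens


def perplexity(mean_loss):
    """Perplexity = exp(mean per-token cross-entropy, natural log)."""
    return math.exp(mean_loss)


# Consistency check against the training loss printed below (rounded to 4 decimals there)
print(f"{perplexity(4.2634):.2f}")
```

Weighting by token count matters when batches are padded or truncated to different lengths: a plain mean of batch losses would over-weight short batches.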

In [None]:
_, val_texts, val_texts_70 = load_and_preprocess("data/lotka_volterra_data.h5")

max_steps = 5000 # CHANGE IF REQUIRED
print(f"After training with {max_steps} steps")
print(f"Training loss: {loss_lora_5000:.4f}")
perplexity_train = np.exp(loss_lora_5000)
print(f"Training perplexity: {perplexity_train:.4f}")

loss_val, ppl_val = evaluate_loss_perplexity_val(model_lora_5000, tokenizer, val_texts, 4)
print('')
print(f'Validation loss: {loss_val:.4f}')
print(f'Validation perplexity: {ppl_val:.4f}')

After training with 10 steps
Training loss: 4.2634
Training perplexity: 71.0496


Validating: 100%|██████████| 75/75 [02:09<00:00,  1.73s/it, avg_loss=4.7363]


Validation loss: 4.7363
Validation perplexity: 114.0172

### Forecasting Missing Pair Values

After training, we can use the model's predictive ability through `prediction_after_training`, also defined in `set_up_lora.py`.
Our goal is to predict the missing 30 pairs in each system of `val_texts_70` and then compare them with the full validation set, which is already set up inside `prediction_after_training`. With both sets we can evaluate the following metrics: the error within each system, the MSE and the RMSE.
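Forecasting the missing pairs amounts to autoregressive generation: the tokenised 70-pair prompt is extended one token at a time, each new token being conditioned on everything generated so far. A minimal greedy-decoding sketch, using a hypothetical `next_token_logits` callable and a toy "model" in place of the Qwen forward pass; `prediction_after_training` may instead rely on Hugging Face `generate` and/or sampling.

```python
import torch


def greedy_generate(next_token_logits, prompt_ids, max_new_tokens, stop_id=None):
    """Extend `prompt_ids` by repeatedly appending the arg-max next token.

    `next_token_logits` is a hypothetical callable mapping a 1-D LongTensor of
    token ids to a 1-D tensor of vocabulary logits for the next position.
    """
    ids = prompt_ids.clone()
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        next_id = torch.argmax(logits).reshape(1)
        ids = torch.cat([ids, next_id])
        if stop_id is not None and next_id.item() == stop_id:
            break
    return ids[len(prompt_ids):]  # only the newly generated tokens


# Toy "model": always predicts (last token + 1) modulo a vocabulary of size 10
VOCAB = 10


def toy_logits(ids):
    logits = torch.zeros(VOCAB)
    logits[(ids[-1].item() + 1) % VOCAB] = 1.0
    return logits


prompt = torch.tensor([3, 4, 5])
print(greedy_generate(toy_logits, prompt, max_new_tokens=6).tolist())  # [6, 7, 8, 9, 0, 1]
```

Each generated token requires a full forward pass over the growing sequence (without key/value caching), which is one reason generating 30 pairs for 100 systems is slow.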

In [None]:
predicted_encoded = prediction_after_training(model_lora_5000, tokenizer, val_texts_70)

Generating predictions:  16%|█▌        | 16/100 [09:35<50:21, 35.97s/it]


KeyboardInterrupt: 


### Evaluating Metrics

To evaluate the metrics above we use `decoder_and_metrics_evaluator`. This function returns the predicted outputs both as strings and as time series (both forms are used by other functions), the true values in the validation set, and the relevant metrics, i.e. the MSE, the RMSE and the error in each individual system.
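The decoding and metric steps can be sketched as follows, assuming generated text of the form `x,y;x,y;...`, arrays of shape `(n_systems, n_steps, 2)`, and per-system MSE/RMSE averaged over time steps and both species. The helpers `parse_pairs` and `forecast_metrics` are hypothetical; any scaling applied during preprocessing would have to be undone before comparing with the true values.

```python
import numpy as np


def parse_pairs(text, n_pairs=None):
    """Parse an assumed 'x,y;x,y;...' string into an (n, 2) float array.

    Fragments that are not exactly two numbers (e.g. a truncated final
    pair produced when generation stops mid-number) are skipped.
    """
    pairs = []
    for chunk in text.split(";"):
        parts = chunk.split(",")
        if len(parts) != 2:
            continue
        try:
            pairs.append((float(parts[0]), float(parts[1])))
        except ValueError:
            continue
    arr = np.array(pairs, dtype=float).reshape(-1, 2)
    return arr if n_pairs is None else arr[:n_pairs]


def forecast_metrics(predicted, true):
    """Signed errors, per-system MSE and per-system RMSE.

    Both inputs must have shape (n_systems, n_steps, 2) and be in the same units.
    """
    predicted = np.asarray(predicted, dtype=float)
    true = np.asarray(true, dtype=float)
    if predicted.shape != true.shape:
        raise ValueError(f"shape mismatch: {predicted.shape} vs {true.shape}")
    errors = predicted - true                           # error at every step and species
    mse_per_system = np.mean(errors ** 2, axis=(1, 2))  # average over steps and prey/predator
    rmse_per_system = np.sqrt(mse_per_system)
    return mse_per_system, rmse_per_system, errors


# Toy example: the third pair was cut off during generation and is ignored
predicted = np.stack([parse_pairs("1.0,2.0;3.0,4.0;5.")])
true = np.array([[[1.0, 2.0], [3.0, 5.0]]])
mse, rmse, errors = forecast_metrics(predicted, true)
print(mse, rmse)  # [0.25] [0.5]
```

If a generation returns fewer than 30 valid pairs, the shapes no longer match and `forecast_metrics` raises, which surfaces malformed outputs instead of silently misaligning time steps.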

In [None]:
predictions_decoded, predicted_output, true_values, MSE_values, RMSE_values, error_per_system = decoder_and_metrics_evaluator(predicted_encoded, tokenizer)

Saving results.

In [None]:
#np.savez("trained_lora_part_3a/predictions_decoded_trained_lora_3a.npz", *predictions_decoded)
#np.save("trained_lora_part_3a/MSE_values_3a.npy", np.array(MSE_values))  # np.save returns None, so nothing to assign
#np.save("trained_lora_part_3a/RMSE_values_3a.npy", np.array(RMSE_values))
#np.savez("trained_lora_part_3a/error_per_system_5000.npz", *error_per_system)

### Visualisation of results

`collective_plots` wraps all the functions defined in `plotting.py` into a single call.
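As a hedged illustration of the kind of figure one might expect from `collective_plots(..., system_id=0, bins=30)`, the sketch below plots the context, true and predicted trajectories for one system next to a histogram of all forecast errors. The function `plot_system_forecast`, its layout and the `(n_systems, n_steps, 2)` array shapes are assumptions, not the contents of `plotting.py`; the data are random.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_system_forecast(context, true_future, predicted_future, errors, system_id=0, bins=30):
    """Hypothetical sketch: forecast for one system plus a histogram of all errors."""
    n_ctx = context.shape[1]
    t_ctx = np.arange(n_ctx)
    t_fut = np.arange(n_ctx, n_ctx + true_future.shape[1])
    fig, (ax_ts, ax_hist) = plt.subplots(1, 2, figsize=(11, 4))
    for k, label in enumerate(["prey", "predator"]):
        ax_ts.plot(t_ctx, context[system_id, :, k], label=f"{label} (context)")
        ax_ts.plot(t_fut, true_future[system_id, :, k], label=f"{label} (true)")
        ax_ts.plot(t_fut, predicted_future[system_id, :, k], "--", label=f"{label} (predicted)")
    ax_ts.axvline(n_ctx - 0.5, color="grey", linewidth=0.8)  # start of the forecast window
    ax_ts.set(xlabel="time step", ylabel="population", title=f"System {system_id}")
    ax_ts.legend(fontsize=7)
    ax_hist.hist(np.ravel(errors), bins=bins)
    ax_hist.set(xlabel="prediction error", ylabel="count", title="Error distribution")
    fig.tight_layout()
    return fig


# Random toy data: 3 systems, 70 context steps, 30 forecast steps
rng = np.random.default_rng(0)
context = rng.uniform(0, 5, size=(3, 70, 2))
true_future = rng.uniform(0, 5, size=(3, 30, 2))
predicted_future = true_future + rng.normal(0, 0.1, size=true_future.shape)
fig = plot_system_forecast(context, true_future, predicted_future, predicted_future - true_future)
```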

In [None]:
collective_plots(predicted_encoded, tokenizer, system_id=0, bins=30)