Meta-Learning - MAML

This example trains a model with MAML (Model-Agnostic Meta-Learning), a meta-learning algorithm, on the Omniglot dataset of handwritten characters from many different alphabets.

In this context, the goal of meta-learning is to train a 'meta'-model on many different tasks so that it can quickly adapt to a new task from only a few samples (few-shot learning). If you are new to meta-learning, have a look at this short introduction video.
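The core idea of MAML can be sketched without any deep-learning library. The toy below is a minimal sketch, assuming a hypothetical family of 1-D regression tasks y = a * x (not Omniglot) and a model with a single weight `w`. It shows MAML's two nested loops: an inner step that adapts the shared initialization to one task's support set, and an outer step that updates that initialization using each adapted model's loss on the task's query set, differentiating through the inner step (second-order MAML). The names `sample_task`, `adapt` and `meta_step` are illustrative, not taken from the example scripts.

```python
import random


def mse_and_grad(w, xs, ys):
    """MSE of the toy 1-D model y_hat = w * x, and its derivative w.r.t. w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


def sample_task(rng, n_support=10, n_query=10):
    """Hypothetical task family: y = a * x with a random slope a in [1, 3]."""
    a = rng.uniform(1.0, 3.0)

    def points(n):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
        return xs, [a * x for x in xs]

    # A few "shots" to adapt on (support) and held-out points to evaluate on (query).
    return points(n_support), points(n_query)


def adapt(w, support, inner_lr):
    """Inner loop: one gradient step on the task's support set."""
    _, g = mse_and_grad(w, *support)
    return w - inner_lr * g


def meta_step(w, tasks, inner_lr, outer_lr):
    """Outer loop: move the shared init w so that the *adapted* weights do well on query sets."""
    meta_grad = 0.0
    for support, query in tasks:
        w_adapted = adapt(w, support, inner_lr)
        _, g_query = mse_and_grad(w_adapted, *query)
        # Chain rule through the inner step: d(w_adapted)/dw = 1 - inner_lr * L_s''(w),
        # and for this model L_s''(w) = 2 * mean(x^2) over the support set.
        xs_s = support[0]
        curvature = 2 * sum(x * x for x in xs_s) / len(xs_s)
        meta_grad += g_query * (1 - inner_lr * curvature)
    return w - outer_lr * meta_grad / len(tasks)


rng = random.Random(0)
w = 0.0
for _ in range(300):
    meta_batch = [sample_task(rng) for _ in range(16)]
    w = meta_step(w, meta_batch, inner_lr=0.5, outer_lr=0.5)

# Few-shot adaptation to a task the meta-learner has never seen.
support, query = sample_task(rng)
before, _ = mse_and_grad(w, *query)
after, _ = mse_and_grad(adapt(w, support, inner_lr=0.5), *query)
print(f"meta-learned init w = {w:.2f}; query loss before/after one step: {before:.3f} / {after:.3f}")
```

For this toy family, the meta-learned initialization settles near the average slope (about 2), and a single inner step on a new task's support set lowers its query loss. In the actual example the model is a neural network trained on Omniglot, so the second-order term would come from autograd rather than the hand-derived chain rule used here.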

We provide two versions of the code. The first is implemented in raw PyTorch and contains a fair amount of boilerplate for distributed training. The second uses Lightning Fabric to accelerate and scale the model.

Tip: You can easily inspect the difference between the two files with:

sdiff train_torch.py train_fabric.py

Requirements

pip install lightning learn2learn cherry-rl 'gym<=0.22'

Run

Raw PyTorch:

torchrun --nproc_per_node=2 --standalone train_torch.py
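`torchrun` starts `--nproc_per_node` copies of the script and tells each copy who it is through environment variables such as `RANK`, `LOCAL_RANK` and `WORLD_SIZE`. Reading these and wiring them into `torch.distributed` (for example via `init_process_group`) is part of the boilerplate the raw PyTorch version carries. The stdlib-only sketch below reads them with single-process defaults. The `tasks_for_rank` helper, which splits a meta-batch of tasks evenly across processes, is a hypothetical illustration and not necessarily how `train_torch.py` divides its work.

```python
import os


def dist_info():
    """Read the process identity that torchrun exports; default to a single process."""
    rank = int(os.environ.get("RANK", 0))              # global index of this process
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # index on this node (e.g. device id)
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
    return rank, local_rank, world_size


def tasks_for_rank(meta_batch_size, rank, world_size):
    """Hypothetical helper: give each process an equal, disjoint share of the meta-batch."""
    if meta_batch_size % world_size != 0:
        raise ValueError("meta_batch_size must be divisible by world_size")
    per_rank = meta_batch_size // world_size
    return list(range(rank * per_rank, (rank + 1) * per_rank))


rank, local_rank, world_size = dist_info()
print(f"rank {rank}/{world_size} handles tasks {tasks_for_rank(32, rank, world_size)}")
```

Under `torchrun --nproc_per_node=2`, the two processes would see `RANK=0` and `RANK=1` and print disjoint halves of the 32 tasks; run directly with `python`, the defaults make it behave as rank 0 of 1.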

Accelerated using Lightning Fabric:

fabric run train_fabric.py --devices 2 --strategy ddp --accelerator cpu

References