In [1]:
import torch
import pytorch_lightning as pl
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10
from src.models.softmoe_lightning import LightningVitSoftMoE

In [2]:
# Data: the CIFAR-10 train split is used for training and the official test
# split doubles as the validation set (no separate held-out split is made here).
data_root = "./data"
batch_size = 32
# Lightning warns (see the output below) that 0 workers may bottleneck loading.
# Raise this to load batches in background processes; 0 keeps loading in the
# main kernel process, which is the conservative choice in a Windows/Jupyter setup.
num_workers = 0

# ToTensor is stateless, so one instance can serve both datasets.
to_tensor = ToTensor()
train_dataset = CIFAR10(root=data_root, train=True, download=True, transform=to_tensor)
val_dataset = CIFAR10(root=data_root, train=False, download=True, transform=to_tensor)

# persistent_workers keeps worker processes alive across epochs instead of
# re-spawning them each epoch; DataLoader only accepts it when num_workers > 0.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Number of training batches per epoch (len of a DataLoader counts batches).
n_train_batches = len(train_loader)
n_train_batches

1563

In [4]:
# --- Reproducibility ----------------------------------------------------------
# Seed Python, NumPy and torch RNGs so weight init and loader shuffling repeat.
seed = 42
pl.seed_everything(seed)

# --- Input geometry, inferred from one sample ---------------------------------
# ToTensor yields a (C, H, W) tensor; images are assumed square (CIFAR-10 is 32x32).
image = train_dataset[0][0]
channels = image.shape[0]
image_size = image.shape[1]
patch_size = image_size // 4  # 4 patches per side -> 4x4 grid of patches
num_classes = len(train_dataset.classes)  # 10 for CIFAR-10

# --- Model hyper-parameters ---------------------------------------------------
dim = 64
depth = 6
heads = 8
dim_head = 64
num_experts = 16
num_slots = 8
# NOTE(review): 16 equals the patch count (image_size // patch_size) ** 2 for this
# geometry; presumably that is what num_tokens means — confirm against
# LightningVitSoftMoE before changing patch_size.
num_tokens = 16

# --- Optimisation schedule ----------------------------------------------------
max_epochs = 30
# NOTE(review): 1e-5 is low for a ViT trained from scratch — confirm it is intended.
learning_rate = 1e-5
warmup_epochs = int(max_epochs * 0.1)  # first 10% of epochs (3 for 30 epochs)
warmup_steps = len(train_loader) * warmup_epochs  # measured in training batches

model = LightningVitSoftMoE(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    num_experts=num_experts,
    num_slots=num_slots,
    num_tokens=num_tokens,
    channels=channels,
    dim_head=dim_head,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
)

In [5]:
# TensorBoard event files go under ./logs/ViTSoftMoE/version_N.
tensorboard_logger = pl.loggers.TensorBoardLogger("logs", name="ViTSoftMoE")

# Train on every visible GPU with fp16 mixed precision. Without CUDA,
# device_count() would be 0 (an invalid `devices` value) and fp16 AMP is not
# accepted on CPU, so fall back to a single CPU device in full fp32 instead.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    logger=tensorboard_logger,
    accelerator="gpu" if use_gpu else "cpu",
    devices=torch.cuda.device_count() if use_gpu else 1,
    max_epochs=max_epochs,
    precision="16-mixed" if use_gpu else "32-true",
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | ViTSoftMoE       | 2.4 M  | train
1 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.618     Total estimated model params size (MB)
642       Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

C:\Users\fede_\.conda\envs\MoE\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

C:\Users\fede_\.conda\envs\MoE\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1563/1563 [04:12<00:00,  6.18it/s, v_num=4]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/313 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/313 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 1/313 [00:00<00:14, 21.74it/s][A
Validation DataLoader 0:   1%|          | 2/313 [00:00<00:16, 18.51it/s][A
Validation DataLoader 0:   1%|          | 3/313 [00:00<00:16, 18.40it/s][A
Validation DataLoader 0:   1%|▏         | 4/313 [00:00<00:16, 19.00it/s][A
Validation DataLoader 0:   2%|▏         | 5/313 [00:00<00:15, 19.41it/s][A
Validation DataLoader 0:   2%|▏         | 6/313 [00:00<00:15, 19.23it/s][A
Validation DataLoader 0:   2%|▏         | 7/313 [00:00<00:16, 18.86it/s][A
Validation DataLoader 0:   3%|▎         | 8/313 [00:00<00:15, 19.16it/s][A
Validation DataLoader 0:   3%|▎         | 9/313 [00:00<00:16, 18.88it/s][A
Validation DataLoader 0:   3%|▎         | 10/313 [00:00<00:15, 19.08it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 1563/1563 [04:07<00:00,  6.32it/s, v_num=4]
