This repository contains an optimized implementation of Retentive Networks Meet Vision Transformers (RMT) for the ImageNet classification task, augmented with advanced regularization techniques:
- SAM – Sharpness-Aware Minimization
- SADT – Sharpness-Aware Distilled Teachers
- CutMix – Patch-level data augmentation
- RMT model benchmarked on ImageNet with and without regularization
- Multi-GPU and distributed training support
- PyTorch implementation with modular code structure
- Performance logging and checkpointing (excluded from repo due to size)
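To make the CutMix augmentation above concrete, here is a minimal NumPy sketch of the standard CutMix recipe (paste a random box from a shuffled batch, then mix the labels in proportion to the pasted area). The function name `cutmix` and its signature are illustrative; the repository's actual implementation may differ.

```python
import numpy as np

def cutmix(images, labels, alpha=1.0, rng=None):
    """Patch-level CutMix: paste a random box from a shuffled copy of the
    batch and return both label sets plus the mixing coefficient lam."""
    rng = rng or np.random.default_rng(0)
    n, c, h, w = images.shape
    lam = rng.beta(alpha, alpha)          # sampled mixing ratio
    perm = rng.permutation(n)             # pairing via a shuffled batch
    # Box side lengths scale with sqrt(1 - lam) so its area is ~(1 - lam)
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = rng.integers(h), rng.integers(w)
    y1, y2 = np.clip([cy - cut_h // 2, cy + cut_h // 2], 0, h)
    x1, x2 = np.clip([cx - cut_w // 2, cx + cut_w // 2], 0, w)
    mixed = images.copy()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lam from the box that was actually pasted (after clipping)
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return mixed, labels, labels[perm], lam
```

The training loss is then `lam * loss(pred, labels_a) + (1 - lam) * loss(pred, labels_b)`.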
| Path | Description |
|---|---|
| `classification_release/` | Core model logic, dataloaders, training routines |
| `main_sam.py` | SAM training script |
| `main_sadt.py` | SADT training script |
| `train_multigpu.py` | Multi-GPU trainer |
| `checkpoints_*/` | [Ignored] Checkpoints directory |
| `*.sh` | Launcher scripts |
- Python ≥ 3.8
- PyTorch ≥ 1.10
- CUDA-capable GPU
- (Optional) Multi-GPU setup via `torch.distributed.launch`
Install dependencies:

```bash
pip install -r requirements.txt
```
## Train with Single GPU (SAM)

Pass the `--model` flag for the RMT variant you want to use (e.g. `RMT_T`):

```bash
python main_sam.py \
    --data-path /2tb/Farid_image_classification/ILSVRC2017_CLS-LOC/ILSVRC/Data/CLS-LOC \
    --data-set IMNET \
    --model RMT_T \
    --epochs 300 \
    --batch-size 64 \
    --lr 1e-3 \
    --output_dir checkpoints_RMT_T_SAM \
    --device cuda \
    --use-sam \
    --sam-rho 0.05 \
    --sam-adaptive
```
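For intuition on what `--use-sam`, `--sam-rho`, and `--sam-adaptive` control, here is a minimal NumPy sketch of SAM's two-step update: first perturb the weights by `rho` toward the local worst case, then descend using the gradient taken at the perturbed point. `sam_step` and `grad_fn` are illustrative names, not the repository's API.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05, adaptive=False):
    """One SAM update on a flat parameter vector.
    grad_fn(w) returns the loss gradient at w."""
    g = grad_fn(w)
    if adaptive:
        # Adaptive SAM scales the perturbation element-wise by |w|
        scale = np.abs(w)
        e = rho * (scale ** 2) * g / (np.linalg.norm(scale * g) + 1e-12)
    else:
        e = rho * g / (np.linalg.norm(g) + 1e-12)
    g_adv = grad_fn(w + e)   # gradient at the sharpness-maximizing point
    return w - lr * g_adv
```

In the repo's PyTorch setting this corresponds to two forward/backward passes per batch, which roughly doubles the per-step cost.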
Multi-GPU training with SAM:

```bash
chmod +x launch_sam_multigpu.sh
bash launch_sam_multigpu.sh
```

Multi-GPU training with SADT:

```bash
chmod +x launch_sadt_multigpu.sh
bash launch_sadt_multigpu.sh
```
## Train with Single GPU (SADT)
```bash
python main_sadt.py \
    --data-path /2tb/Farid_image_classification/ILSVRC2017_CLS-LOC/ILSVRC/Data/CLS-LOC \
    --data-set IMNET \
    --model RMT_T \
    --epochs 300 \
    --batch-size 64 \
    --lr 1e-3 \
    --output_dir checkpoints_RMT_T_SADT \
    --device cuda \
    --use-sadt \
    --sadt-noise-std 1e-4 \
    --sadt-aux-weight 0.5 \
    --sadt-temperature 4.0
```
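As a rough illustration of how the `--sadt-*` flags might interact, here is a hypothetical NumPy sketch of a distillation-style objective: cross-entropy on the labels plus a temperature-scaled KL term against a noise-perturbed teacher. This is an assumption about the objective's shape, not the actual loss in `main_sadt.py`; the names `sadt_loss`, `aux_weight`, and the choice to inject noise into the teacher logits are all illustrative.

```python
import numpy as np

def softmax(z, t=1.0):
    z = z / t
    z = z - z.max(axis=-1, keepdims=True)   # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sadt_loss(student_logits, teacher_logits, labels,
              temperature=4.0, aux_weight=0.5, noise_std=1e-4, rng=None):
    """Hypothetical SADT-style objective: hard-label cross-entropy plus a
    KL distillation term against a Gaussian-perturbed teacher."""
    rng = rng or np.random.default_rng(0)
    noisy_teacher = teacher_logits + rng.normal(0.0, noise_std,
                                                teacher_logits.shape)
    p_t = softmax(noisy_teacher, temperature)
    p_s = softmax(student_logits, temperature)
    kl = np.sum(p_t * (np.log(p_t + 1e-12) - np.log(p_s + 1e-12)),
                axis=-1).mean()
    ce = -np.log(softmax(student_logits)[np.arange(len(labels)), labels]
                 + 1e-12).mean()
    # T^2 rescaling is the standard distillation convention (Hinton et al.)
    return ce + aux_weight * (temperature ** 2) * kl
```

Here `--sadt-noise-std` maps to `noise_std`, `--sadt-aux-weight` to `aux_weight`, and `--sadt-temperature` to `temperature`.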
- Inspired by Retentive Networks Meet Vision Transformers (RMT) (Fan et al., 2024)
- SAM from Sharpness-Aware Minimization (Foret et al., 2021)
- SADT from recent vision regularization works (Fahim & Boutellier, 2022)