In [3]:
import sys
# Make packages under this directory importable by appending it to the module
# search path. NOTE(review): this is a hard-coded absolute path; presumably the
# extra libraries for this environment live there — confirm before reuse elsewhere.
sys.path.append('/home/aistudio/external-libraries')

from colossalai.legacy.amp import AMP_TYPE

# Training configuration. The training cells read these values through
# `gpc.config.<NAME>` (BATCH_SIZE, NUM_EPOCHS and NUM_CLASSES are read that way).
BATCH_SIZE = 256
# NOTE(review): no visible cell reads DROP_RANK (the model cell hard-codes
# drop_rate=0.1); the name may have been meant as DROP_RATE — confirm before renaming.
DROP_RANK = 0.1
NUM_EPOCHS = 300
# The model cell reads `gpc.config.NUM_CLASSES`, which was previously undefined.
# CIFAR-10 has 10 classes.
NUM_CLASSES = 10

# Mixed-precision settings; selects the TORCH mode of AMP_TYPE.
fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

# NOTE(review): these two are presumably consumed by the Colossal-AI engine
# during initialization — confirm against the installed version.
gradient_accumulation = 16
clip_grad_norm = 1.0
# NOTE(review): no visible cell reads `dail`; the key may have been meant as
# `dali` — confirm before renaming.
dail = dict(
    gpu_aug=True,
    mixup_alpha=0.2
)

In [8]:
import colossalai
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.metric import Accuracy
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.legacy.utils import get_dataloader
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR

In [9]:
import os
import torch
# `timm` is a third-party package; it must be installed in the kernel's
# environment, otherwise this cell raises ModuleNotFoundError.
from timm.models import vit_base_patch16_224
# Fixed misspelled names: `tansforms` -> `transforms`, `trochvision` -> `torchvision`.
from torchvision import transforms
from torchvision.datasets import CIFAR10

ModuleNotFoundError: No module named 'timm'

In [None]:
# Parse command-line arguments. The launch command passes `--config <path>`,
# so this code is meant to run as a script under torchrun rather than
# interactively in a kernel.
parser = colossalai.get_default_parser()
args = parser.parse_args()

# Initialize the distributed environment from the parsed config path.
# (Previously referenced the undefined name `arg`, which raised NameError.)
# NOTE(review): the other Colossal-AI imports come from `colossalai.legacy`;
# confirm that this top-level launcher is the one that populates `gpc.config`
# in the installed version.
colossalai.launch_from_torch(config=args.config)

# Silence pre-existing loggers, then obtain the distributed logger that is
# later handed to the trainer.
disable_existing_loggers()
logger = get_dist_logger()

In [None]:
# Sanity check: show the BATCH_SIZE value held by the global context's config.
print(gpc.config.BATCH_SIZE)

In [None]:
# Build the ViT-Base (patch 16, 224x224 input) model from timm; the number of
# output classes is taken from the global config.
model = vit_base_patch16_224(
    drop_rate=0.1,
    num_classes=gpc.config.NUM_CLASSES,
)

In [None]:
# Per-channel statistics passed to `transforms.Normalize` for CIFAR-10 images.
# The mean previously read `(0,4914, 0.4822, 0.4465)`: a comma typo that
# produced a 4-element tuple instead of three per-channel means.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def build_cifar(batch_size):
    """Build the CIFAR-10 training and test dataloaders.

    The dataset root is read from the ``DATA`` environment variable; the
    training split is downloaded there if necessary.

    Args:
        batch_size: Batch size forwarded to both dataloaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    data_root = os.environ['DATA']

    # Training pipeline: crop to 224x224 (padding images that are smaller),
    # apply the CIFAR10 AutoAugment policy, convert to a tensor, normalize.
    transform_train = transforms.Compose([
        transforms.RandomCrop(224, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])
    # Test pipeline: resize to 224 with no augmentation, then the same
    # tensor conversion and normalization as training.
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    train_dataset = CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=data_root, train=False, transform=transform_test)

    # Only the training loader shuffles.
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
    return train_dataloader, test_dataloader

# Build both dataloaders using the batch size from the global config.
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE)

In [None]:
# Loss function: cross-entropy.
criterion = torch.nn.CrossEntropyLoss()

# Optimizer: LAMB over all model parameters.
optimizer = colossalai.nn.Lamb(
    model.parameters(),
    lr=1.8e-2,
    weight_decay=0.1,
)

# Learning-rate schedule with linear warmup; the total step count is taken
# from the configured number of epochs.
lr_scheduler = LinearWarmupLR(
    optimizer,
    warmup_steps=50,
    total_steps=gpc.config.NUM_EPOCHS,
)

In [None]:
# Wrap model, optimizer, loss and dataloaders into a Colossal-AI engine.
# The fourth returned value is not used here.
# NOTE(review): the other Colossal-AI imports come from `colossalai.legacy`;
# confirm that `colossalai.initialize` is callable in the installed version.
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
    model,
    optimizer,
    criterion,
    train_dataloader,
    test_dataloader,
)

trainer = Trainer(engine=engine, logger=logger)

# Hooks passed to `trainer.fit`.
hook_list = [
    hooks.LossHook(),
    # Uses the imported `Accuracy` metric. The previous `MixupAccuracy` is
    # neither defined nor imported anywhere visible, so it raised NameError.
    hooks.AccuracyHook(accuracy_func=Accuracy()),
    hooks.LogMemoryByEpochHook(logger),
    hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
    hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
]

In [None]:
# Start training for the configured number of epochs with the hooks built
# above; test_interval=1 and a progress display are requested.
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.NUM_EPOCHS,
    test_interval=1,
    hooks=hook_list,
    display_progress=True,
)

In [None]:
%%bash
# These are shell commands, not Python; the cell magic above lets the cell run
# under a Python kernel. They can equally be pasted into a terminal without it.
# DATA is the dataset root that build_cifar reads via os.environ['DATA'].
export DATA=/home/aistudio/data
# Launch the training script on a single node with one process, passing the
# config file path through --config.
torchrun --standalone --nproc_per_node=1 train_dp.py --config ./configs/config_data_parallel.py