In [15]:
# Use the %pip magic explicitly so the packages are installed into the
# environment of the running kernel without relying on IPython automagic.
%pip install -r requirements.txt



In [16]:
import argparse
import os
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

In [17]:
# Total number of training epochs; also shown in the progress-bar label.
NUM_EPOCHS = 50
# Base learning rate, before it is scaled by the number of distributed workers.
LEARNING_RATE = 0.001

def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
    """Create the CIFAR-10 training and test dataloaders.

    Args:
        batch_size: Batch size handed to both dataloaders.
        coordinator: Dataset construction (including the download) happens
            inside its ``priority_execution()`` context.
            NOTE(review): presumably this lets one process download before
            the others start -- confirm against the DistCoordinator docs.
        plugin: Booster plugin whose ``prepare_dataloader`` builds the loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    # Dataset root comes from $DATA and falls back to ./data.
    dataset_root = os.environ.get("DATA", "./data")

    # Training-time augmentation: pad by 4, random flip, random 32x32 crop.
    augment = transforms.Compose(
        [
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
        ]
    )
    # The test split is only converted to tensors.
    to_tensor = transforms.ToTensor()

    def _cifar10(is_train, transform):
        # Both splits share the same root and download-on-demand behaviour.
        return torchvision.datasets.CIFAR10(
            root=dataset_root, train=is_train, transform=transform, download=True
        )

    with coordinator.priority_execution():
        train_split = _cifar10(True, augment)
        test_split = _cifar10(False, to_tensor)

    # Shuffle and drop the ragged last batch for training only.
    loaders = (
        plugin.prepare_dataloader(train_split, batch_size=batch_size, shuffle=True, drop_last=True),
        plugin.prepare_dataloader(test_split, batch_size=batch_size, shuffle=False, drop_last=False),
    )
    return loaders

@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
    """Compute top-1 accuracy over ``test_dataloader``, aggregated across all ranks.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Network to evaluate; called as ``model(images)``.
        test_dataloader: Yields ``(images, labels)`` batches.
        coordinator: Used only to restrict the summary print to the master rank.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``. Returns ``0.0``
        when no samples were seen on any rank instead of dividing by zero.
    """
    model.eval()
    # Resolve the device once and use it for both the counters and the batches,
    # so the function is not tied to CUDA specifically.
    device = get_accelerator().get_current_device()
    correct = torch.zeros(1, dtype=torch.int64, device=device)
    total = torch.zeros(1, dtype=torch.int64, device=device)
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the highest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # Accumulate on-device; avoids a host sync (.item()) on every batch.
        correct += (predicted == labels).sum()
    # Sum the per-rank counters so every rank ends up with the global totals.
    dist.all_reduce(correct)
    dist.all_reduce(total)
    num_samples = total.item()
    accuracy = correct.item() / num_samples if num_samples > 0 else 0.0
    if coordinator.is_master():
        print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %")
    return accuracy

def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Run one training pass over ``train_dataloader``.

    Args:
        epoch: Zero-based epoch index; shown one-based in the progress bar.
        model: Network being trained; put into train mode here.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        train_dataloader: Yields ``(images, labels)`` batches.
        booster: Its ``backward(loss, optimizer)`` performs the backward pass.
        coordinator: The progress bar is shown only on the master rank.
    """
    model.train()
    # Same device lookup that evaluate() uses, instead of a hard-coded .cuda().
    device = get_accelerator().get_current_device()
    with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar:
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward goes through the booster rather than loss.backward().
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()

            pbar.set_postfix({"loss": loss.item()})

In [18]:
# Single-process rendezvous settings so that launch_from_torch can initialise
# torch.distributed from inside the notebook kernel.
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['LOCAL_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

# Guarded so that re-running this cell does not try to initialise twice.
if not dist.is_initialized():
    colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()
# Scale the base rate by the world size into a separate name. Mutating
# LEARNING_RATE in place would compound the scaling every time this cell is
# re-run with more than one process.
scaled_lr = LEARNING_RATE * coordinator.world_size

booster_kwargs = {}
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin, **booster_kwargs)

train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin)

model = torchvision.models.resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=scaled_lr)
# The scheduler is stepped once per epoch below; with NUM_EPOCHS = 50 only the
# decays at epochs 20 and 40 are reached.
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)
model, optimizer, criterion, _, lr_scheduler = booster.boost(
    model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler
)

start_epoch = 0
for epoch in range(start_epoch, NUM_EPOCHS):
    train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator)
    lr_scheduler.step()

# Evaluate once, after the final epoch.
accuracy = evaluate(model, test_dataloader, coordinator)
print(f"Final Accuracy: {accuracy * 100:.2f} %")

Files already downloaded and verified
Files already downloaded and verified


Epoch [1/50]: 100%|██████████| 500/500 [00:44<00:00, 11.31it/s, loss=1.33]
Epoch [2/50]: 100%|██████████| 500/500 [00:34<00:00, 14.54it/s, loss=1.16]
Epoch [3/50]: 100%|██████████| 500/500 [00:35<00:00, 14.07it/s, loss=0.976]
Epoch [4/50]: 100%|██████████| 500/500 [00:34<00:00, 14.65it/s, loss=0.945]
Epoch [5/50]: 100%|██████████| 500/500 [00:34<00:00, 14.67it/s, loss=0.866]
Epoch [6/50]: 100%|██████████| 500/500 [00:34<00:00, 14.65it/s, loss=0.886]
Epoch [7/50]: 100%|██████████| 500/500 [00:34<00:00, 14.54it/s, loss=0.816]
Epoch [8/50]: 100%|██████████| 500/500 [00:34<00:00, 14.39it/s, loss=0.832]
Epoch [9/50]: 100%|██████████| 500/500 [00:34<00:00, 14.46it/s, loss=0.724]
Epoch [10/50]: 100%|██████████| 500/500 [00:34<00:00, 14.48it/s, loss=0.731]
Epoch [11/50]: 100%|██████████| 500/500 [00:33<00:00, 14.81it/s, loss=0.63]
Epoch [12/50]: 100%|██████████| 500/500 [00:34<00:00, 14.40it/s, loss=0.682]
Epoch [13/50]: 100%|██████████| 500/500 [00:34<00:00, 14.62it/s, loss=0.535]
Epoch [14/5

Accuracy of the model on the test images: 84.53 %
Final Accuracy: 84.53 %
