In [None]:
import math
import time

import torch
from torch import nn
from torchvision import datasets, transforms

### Setting

In [None]:
# Optimization Setting
# AUTOCAST_FLAG is passed to torch.autocast(enabled=...) in the train/eval steps and,
# on CUDA, also enables the GradScaler.
AUTOCAST_FLAG = False
# COMPILE_FLAG enables torch.compile on the model (compilation is disabled when False).
COMPILE_FLAG = False

# Data && Model Setting
## Data Resize Shape
size = 16 # initial Resize target; later values come from stage_schedule['size'] (16 or 32)
## RandAugment
num_ops = 2
magnitude = 5 # initial value; later values come from stage_schedule['magnitude'] (5 or 10)
## Model
n = 3 # number of residual blocks per stage of the ResNet
num_classes = 100
p = 0.1 # initial dropout probability; later values come from stage_schedule['p'] (0.1 or 0.2)
## DataLoader
batch_size = 512
eval_batch_size = 1000
num_workers = 2
pin_memory = True
drop_last = True
persistent_workers = False

# Training Setting
epochs = 160 # total number of training epochs
## SGD optimizer
lr = 1e-1
momentum = 0.9
weight_decay = 1e-4
## MultiStepLR scheduler
milestones = [80, 120] # epochs at which the learning rate is multiplied by gamma
gamma = 0.1
## stage scheduler
### modify image resize, randaugment magnitude, dropout probability
### 'milestones': the stage advances when (epoch + 1) is in this list
### 'counter':    index of the current stage; updated in place by progressive_learning_scheduler
### 'len':        number of stages; the counter wraps back to 0 when it reaches this value
### 'size' / 'magnitude' / 'p': per-stage values, indexed by 'counter'
stage_schedule = {
    'milestones': [40, 80, 100, 120, 140],
    'counter': 0,
    'len': 2,
    'size': [16, 32],
    'magnitude': [5, 10],
    'p': [0.1, 0.2]
}

In [None]:
# Directory handed to the torchvision dataset constructors.
root = '~/.pytorch/datasets/'

# Run on the highest-indexed CUDA device when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{torch.cuda.device_count() - 1}')
    # Make the chosen GPU the current device and let cuDNN benchmark its kernels.
    torch.cuda.set_device(device)
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')
print(f'Device: {device}, Type: {device.type}')

### Module

In [None]:
def load_data(
    root: str,
    batch_size: int,
    eval_batch_size: int,
    num_workers: int = 2,
    size: int = 32,
    num_ops: int = 2,
    magnitude: int = 9,
    pin_memory: bool = False,
    drop_last: bool = False,
    persistent_workers: bool = False
):
    '''Build the CIFAR-100 train/test datasets and their DataLoaders.

    Train pipeline: RandomCrop(32, padding=4) -> Resize(size) -> RandomHorizontalFlip
    -> ToTensor -> Normalize. Eval pipeline: Resize(size) -> ToTensor -> Normalize.

    Args:
        root: Dataset directory passed to ``datasets.CIFAR100`` (downloaded if missing).
        batch_size: Batch size of the (shuffled) training loader.
        eval_batch_size: Batch size of the test loader.
        num_workers: Worker count for both loaders.
        size: Target passed to ``transforms.Resize`` in both pipelines.
        num_ops: RandAugment ``num_ops``; unused while RandAugment is commented out.
        magnitude: RandAugment ``magnitude``; unused while RandAugment is commented out.
        pin_memory: Forwarded to both loaders.
        drop_last: Forwarded to the training loader only.
        persistent_workers: Forwarded to both loaders.

    Returns:
        ``(dataset, dataloader)``: two dicts, each with the keys ``'train'`` and ``'test'``.

    Reference (conventional per-channel statistics, as mean / std over 255):
        CIFAR-10:  mean [125.3, 123.0, 113.9], std [63.0, 62.1, 66.7]
                   -> Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.244, 0.262])
        CIFAR-100: mean [129.3, 124.1, 112.4], std [68.2, 65.4, 70.4]
                   -> Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        Conventional train transform: RandomCrop(32, padding=4, padding_mode='constant'),
        RandomHorizontalFlip(), ToTensor(), Normalize(mean, std).
    '''
    # CIFAR-100 channel statistics (see the reference in the docstring).
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='constant'),
        transforms.Resize(size=size),
        transforms.RandomHorizontalFlip(),
        #transforms.RandAugment(num_ops=num_ops, magnitude=magnitude),
        transforms.ToTensor(),
        normalize
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
        normalize
    ])

    train_set = datasets.CIFAR100(
        root=root, train=True, download=True, transform=train_transform
    )
    test_set = datasets.CIFAR100(
        root=root, train=False, download=True, transform=eval_transform
    )

    # Options shared by both loaders; shuffling and drop_last apply to training only.
    shared_options = dict(
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
        **shared_options,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=eval_batch_size,
        **shared_options,
    )

    dataset = {'train': train_set, 'test': test_set}
    dataloader = {'train': train_loader, 'test': test_loader}
    return dataset, dataloader

In [None]:
# Models
## Original Model
class BasicBlock(nn.Module):
    '''Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Number of channels produced by both convolutions.
        down: If True, the first convolution and the shortcut use stride 2.

    The shortcut is the identity when the input can be added to the residual
    branch as-is. When ``down`` is True or ``inplanes != planes`` it is a
    1x1 convolution + BatchNorm stored as ``self.downsample`` (the attribute
    only exists in that case, so state_dict keys for existing configurations
    are unchanged).
    '''
    def __init__(self, inplanes: int, planes: int, down: bool = False) -> None:
        super().__init__()
        self.down = down
        # A projection is needed whenever the identity would not match the
        # residual branch in resolution (down) or in channel count.
        self.project = down or inplanes != planes
        self.conv1 = (
            nn.Conv2d(inplanes, planes, 3, stride=2, padding=1, bias=False) if down
            else nn.Conv2d(inplanes, planes, 3, padding='same', bias=False)
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding='same', bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        if self.project:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, 1, stride=2 if down else 1, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return ``relu(residual(x) + shortcut(x))``.'''
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        identity = self.downsample(x) if self.project else x
        out += identity
        out = self.relu(out)
        return out

class BottleNeck(nn.Module):
    '''Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions, each followed by BatchNorm.

    NOTE: this block is not used by the models in this file and has not been
    validated in training yet.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Number of channels inside the bottleneck (first two convolutions).
        outplanes: Number of channels produced by the last convolution.
        down: If True, the 3x3 convolution and the shortcut use stride 2.

    The shortcut is the identity when the input can be added to the residual
    branch as-is. When ``down`` is True or ``inplanes != outplanes`` it is a
    1x1 convolution + BatchNorm stored as ``self.downsample`` (the attribute
    only exists in that case).
    '''
    def __init__(self, inplanes: int, planes: int, outplanes: int, down: bool = False) -> None:
        super().__init__()
        self.down = down
        # A projection is needed whenever the identity would not match the
        # residual branch in resolution (down) or in channel count.
        self.project = down or inplanes != outplanes
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = (
            nn.Conv2d(planes, planes, 3, stride=2, padding=1, bias=False) if down
            else nn.Conv2d(planes, planes, 3, padding='same', bias=False)
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, outplanes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU(inplace=True)
        if self.project:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, outplanes, 1, stride=2 if down else 1, bias=False),
                nn.BatchNorm2d(outplanes)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return ``relu(residual(x) + shortcut(x))``.'''
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        identity = self.downsample(x) if self.project else x
        out += identity
        out = self.relu(out)
        return out

class CIFAR_ResNet(nn.Module):
    '''ResNet built from a 3x3 stem, three stages of ``n`` BasicBlocks and a linear classifier.

    The stem and the first stage use 16 channels. The second and third stages
    each open with a stride-2 BasicBlock that doubles the channel count.
    The classifier is global average pooling -> flatten -> dropout -> linear.

    Args:
        n: Number of BasicBlocks per stage.
        num_classes: Number of output logits.
        p: Dropout probability applied before the final linear layer.
    '''
    def __init__(self, n: int = 3, num_classes: int = 100, p: float = 0.2) -> None:
        super().__init__()
        self.n = n
        self.inplanes = 16
        self.planes = self.inplanes
        self.stem = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, padding='same', bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU(inplace=True)
        )
        self.layer1 = self._make_stage()
        self.layer2 = self._make_stage(down=True)
        self.layer3 = self._make_stage(down=True)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p=p, inplace=True),
            nn.Linear(self.planes, num_classes)
        )
        # Explicit initialization, overriding the PyTorch defaults.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _make_stage(self, down: bool = False) -> nn.Sequential:
        '''Build one stage of ``self.n`` BasicBlocks.

        With ``down=True`` the first block downsamples and doubles the channel
        count; ``self.inplanes`` / ``self.planes`` are updated to track this.
        '''
        layers = []
        for i in range(self.n):
            if i == 0 and down:
                self.inplanes = self.planes
                self.planes *= 2
                layers.append(BasicBlock(self.inplanes, self.planes, down=True))
            else:
                layers.append(BasicBlock(self.planes, self.planes))
        return nn.Sequential(*layers)

    def adjust_dropout(self, p: float = 0.2):
        '''Replace the classifier dropout with one using probability ``p``.

        The new layer inherits the train/eval mode of the layer it replaces.
        A freshly built ``nn.Dropout`` is always in training mode, so without
        this step a call made while the model is in eval mode would leave
        dropout active during evaluation.
        '''
        previous = self.classifier[2]
        replacement = nn.Dropout(p=p, inplace=True)  # the constructor validates p
        replacement.train(previous.training)
        self.classifier[2] = replacement

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return the class logits for a batch of images.'''
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.classifier(x)
        return x

## Lazy Model
### Built from lazy modules, without the explicit weight and bias initialization
### used by the original model, so the results might be worse than the original model's.
class LazyBasicBlock(nn.Module):
    '''Lazy variant of the two-convolution residual block.

    Uses ``nn.LazyConv2d`` / ``nn.LazyBatchNorm2d``, so only the output
    channel count has to be given.

    Args:
        planes: Number of channels produced by both 3x3 convolutions.
        down: If True, the first convolution uses stride 2 and the shortcut is
            a stride-2 1x1 convolution + BatchNorm; otherwise the shortcut is
            the identity.
    '''
    def __init__(self, planes: int, down: bool = False) -> None:
        super().__init__()
        self.down = down
        if down:
            self.conv1 = nn.LazyConv2d(planes, 3, stride=2, padding=1, bias=False)
        else:
            self.conv1 = nn.LazyConv2d(planes, 3, padding='same', bias=False)
        self.bn1 = nn.LazyBatchNorm2d()
        self.conv2 = nn.LazyConv2d(planes, 3, padding='same', bias=False)
        self.bn2 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU(inplace=True)
        if down:
            shortcut_conv = nn.LazyConv2d(planes, 1, stride=2, bias=False)
            self.downsample = nn.Sequential(shortcut_conv, nn.LazyBatchNorm2d())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return ``relu(residual(x) + shortcut(x))``.'''
        # Residual branch: conv-bn-relu followed by conv-bn.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # Shortcut branch: projected only when downsampling.
        shortcut = self.downsample(x) if self.down else x
        residual += shortcut
        return self.relu(residual)

class LazyBottleNeck(nn.Module):
    '''Lazy variant of the bottleneck residual block (1x1 -> 3x3 -> 1x1).

    Uses ``nn.LazyConv2d`` / ``nn.LazyBatchNorm2d``, so only output channel
    counts have to be given.

    Args:
        planes: Number of channels produced by the first two convolutions.
        outplanes: Number of channels produced by the last convolution.
        down: If True, the 3x3 convolution uses stride 2 and the shortcut is a
            stride-2 1x1 convolution + BatchNorm; otherwise the shortcut is
            the identity.
    '''
    def __init__(self, planes: int, outplanes: int, down: bool = False) -> None:
        super().__init__()
        self.down = down
        self.conv1 = nn.LazyConv2d(planes, 1, bias=False)
        self.bn1 = nn.LazyBatchNorm2d()
        if down:
            self.conv2 = nn.LazyConv2d(planes, 3, stride=2, padding=1, bias=False)
        else:
            self.conv2 = nn.LazyConv2d(planes, 3, padding='same', bias=False)
        self.bn2 = nn.LazyBatchNorm2d()
        self.conv3 = nn.LazyConv2d(outplanes, 1, bias=False)
        self.bn3 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU(inplace=True)
        if down:
            shortcut_conv = nn.LazyConv2d(outplanes, 1, stride=2, bias=False)
            self.downsample = nn.Sequential(shortcut_conv, nn.LazyBatchNorm2d())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return ``relu(residual(x) + shortcut(x))``.'''
        # Residual branch: two conv-bn-relu steps followed by conv-bn.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))
        # Shortcut branch: projected only when downsampling.
        shortcut = self.downsample(x) if self.down else x
        residual += shortcut
        return self.relu(residual)

class Lazy_CIFAR_ResNet(nn.Module):
    '''Lazy-module version of the ResNet: stem, three stages of ``n`` LazyBasicBlocks, classifier.

    The stem and the first stage use 16 channels. The second and third stages
    each open with a stride-2 LazyBasicBlock that doubles the channel count.
    No explicit weight initialization is applied.

    Args:
        n: Number of LazyBasicBlocks per stage.
        num_classes: Number of output logits.
        p: Dropout probability applied before the final linear layer.
    '''
    def __init__(self, n: int = 3, num_classes: int = 100, p: float = 0.2) -> None:
        super().__init__()
        self.n = n
        self.planes = 16
        self.stem = nn.Sequential(
            nn.LazyConv2d(self.planes, kernel_size=3, padding='same', bias=False),
            nn.LazyBatchNorm2d(),
            nn.ReLU(inplace=True)
        )
        self.layer1 = self._make_stage()
        self.layer2 = self._make_stage(down=True)
        self.layer3 = self._make_stage(down=True)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p=p, inplace=True),
            nn.LazyLinear(num_classes)
        )

    def _make_stage(self, down: bool = False) -> nn.Sequential:
        '''Build one stage of ``self.n`` LazyBasicBlocks.

        With ``down=True`` the first block downsamples and ``self.planes`` is
        doubled before it is built.
        '''
        layers = []
        for i in range(self.n):
            if i == 0 and down:
                self.planes *= 2
                layers.append(LazyBasicBlock(self.planes, down=True))
            else:
                layers.append(LazyBasicBlock(self.planes))
        return nn.Sequential(*layers)

    def adjust_dropout(self, p: float = 0.2):
        '''Replace the classifier dropout with one using probability ``p``.

        The new layer inherits the train/eval mode of the layer it replaces.
        A freshly built ``nn.Dropout`` is always in training mode, so without
        this step a call made while the model is in eval mode would leave
        dropout active during evaluation.
        '''
        previous = self.classifier[2]
        replacement = nn.Dropout(p=p, inplace=True)  # the constructor validates p
        replacement.train(previous.training)
        self.classifier[2] = replacement

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return the class logits for a batch of images.'''
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.classifier(x)
        return x

In [None]:
def train_step(model, dataset, dataloader, criterion, optimizer, scaler, AUTOCAST_FLAG=False):
    '''Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module to train; it is switched to training mode.
        dataset: Kept for backward compatibility. It is no longer used as the
            accuracy denominator, because a loader with ``drop_last=True``
            does not visit every sample of the dataset.
        dataloader: Iterable of batches whose first two items are inputs and labels.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler (``scale`` / ``step`` / ``update`` are called).
        AUTOCAST_FLAG: Whether the forward pass and loss run under ``torch.autocast``.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of correctly classified samples among the samples actually
        processed. Both are ``0.0`` when the loader yields no batch.
    '''
    # Batches go to the device the model lives on, not to a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    loss_sum, correct, num_batches, num_samples = 0.0, 0, 0, 0
    model.train()
    for data in dataloader:
        # load data
        inputs = data[0].to(run_device, non_blocking=True)
        labels = data[1].to(run_device, non_blocking=True)
        # compute
        optimizer.zero_grad()
        with torch.autocast(run_device.type, enabled=AUTOCAST_FLAG):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # record
        predict_labels = outputs.argmax(dim=1)
        loss_sum += loss.item()
        correct += (labels == predict_labels).sum().item()
        num_batches += 1
        num_samples += labels.size(0)
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    record_loss = loss_sum / max(num_batches, 1)
    record_acc = correct / max(num_samples, 1)
    return record_loss, record_acc

def eval_step(model, dataset, dataloader, criterion, AUTOCAST_FLAG=False):
    '''Evaluate ``model`` on one pass over ``dataloader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dataset: Kept for backward compatibility. Accuracy is computed over
            the samples actually yielded by ``dataloader``, which equals
            ``len(dataset)`` whenever the loader covers the whole dataset.
        dataloader: Iterable of batches whose first two items are inputs and labels.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        AUTOCAST_FLAG: Whether the forward pass and loss run under ``torch.autocast``.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of correctly classified samples. Both are ``0.0`` when the
        loader yields no batch.
    '''
    # Batches go to the device the model lives on, not to a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    loss_sum, correct, num_batches, num_samples = 0.0, 0, 0, 0
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            # load data
            inputs = data[0].to(run_device, non_blocking=True)
            labels = data[1].to(run_device, non_blocking=True)
            # compute
            with torch.autocast(run_device.type, enabled=AUTOCAST_FLAG):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            # record
            predict_labels = outputs.argmax(dim=1)
            loss_sum += loss.item()
            correct += (labels == predict_labels).sum().item()
            num_batches += 1
            num_samples += labels.size(0)
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    record_loss = loss_sum / max(num_batches, 1)
    record_acc = correct / max(num_samples, 1)
    return record_loss, record_acc

In [None]:
def progressive_learning_scheduler(
    epoch: int,
    stage_schedule: dict,
    model: nn.Module,
    dataset,
    dataloader,
    **kwargs
):
    '''Advance the progressive-learning stage when a milestone epoch is reached.

    When ``epoch + 1`` is listed in ``stage_schedule['milestones']`` the stage
    counter is incremented in place (wrapping to 0 once it equals
    ``stage_schedule['len']``), the data is reloaded through ``load_data``
    with the new stage's ``size`` and ``magnitude``, and the model's dropout
    probability is updated. Otherwise nothing is changed.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        stage_schedule: Dict with the keys 'milestones', 'counter', 'len',
            'size', 'magnitude' and 'p'; 'counter' is mutated in place.
        model: Model exposing ``adjust_dropout(p=...)``.
        dataset: Current datasets, returned unchanged when no milestone is hit.
        dataloader: Current loaders, returned unchanged when no milestone is hit.
        **kwargs: Remaining keyword arguments forwarded to ``load_data``.

    Returns:
        ``(dataset, dataloader, size, magnitude, p)`` for the active stage.
    '''
    def stage_values():
        # Look up the per-stage settings selected by the current counter.
        index = stage_schedule['counter']
        return (
            stage_schedule['size'][index],
            stage_schedule['magnitude'][index],
            stage_schedule['p'][index],
        )

    # Not a milestone: hand back the current data and settings untouched.
    if (epoch + 1) not in stage_schedule['milestones']:
        return (dataset, dataloader, *stage_values())

    # Milestone: move to the next stage, cycling back to the first one.
    stage_schedule['counter'] += 1
    if stage_schedule['counter'] == stage_schedule['len']:
        stage_schedule['counter'] = 0
    stage = stage_schedule['counter']
    dataset, dataloader = load_data(
        size=stage_schedule['size'][stage],
        magnitude=stage_schedule['magnitude'][stage],
        **kwargs
    )
    model.adjust_dropout(p=stage_schedule['p'][stage])
    return (dataset, dataloader, *stage_values())

def timed(fn):
    '''Call ``fn()`` and return ``(result, elapsed_seconds)``.

    With CUDA available the call is timed with CUDA events and the device is
    synchronized before the elapsed time is read. Without CUDA (the notebook's
    CPU fallback) ``torch.cuda.Event`` cannot be used, so wall-clock time from
    ``time.perf_counter`` is used instead.

    Args:
        fn: Zero-argument callable to run.

    Returns:
        A tuple of ``fn``'s return value and the elapsed time in seconds.
    '''
    if not torch.cuda.is_available():
        begin = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - begin

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before reading the timing.
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000

### Main

In [None]:
# Build the initial datasets and loaders from the configuration values.
# The training loop rebinds `dataset` / `dataloader` with whatever
# progressive_learning_scheduler returns.
dataset, dataloader = load_data(
    root=root,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    num_workers=num_workers,
    size=size,
    num_ops=num_ops,
    magnitude=magnitude,
    pin_memory=pin_memory,
    drop_last=drop_last,
    persistent_workers=persistent_workers
)

In [None]:
# Network depth is 6n + 2 layers:
#   n            = [3, 5, 7, 9, 18]
#   cifar_resnet = [20, 32, 44, 56, 110]
model = CIFAR_ResNet(n=n, num_classes=num_classes, p=p).to(device)

In [None]:
# Gradient scaling is only active when autocast is requested on a CUDA device.
scaler = torch.cuda.amp.GradScaler(
    enabled=bool(device.type == 'cuda' and AUTOCAST_FLAG)
)
# Available compile modes: 'default', 'reduce-overhead', 'max-autotune'.
# Compilation is disabled unless COMPILE_FLAG is set.
model = torch.compile(
    model,
    mode='default',
    fullgraph=True,
    disable=not COMPILE_FLAG,
)

In [None]:
# Loss, SGD optimizer and step-wise learning-rate decay, all driven by the
# values from the configuration cell.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=gamma
)

In [None]:
# Main loop. Each epoch: train (timed), evaluate, report, step the LR scheduler,
# then let the stage scheduler update the data pipeline and dropout.
for epoch in range(epochs):
    # train: only the training pass is wrapped in timed(), so time_cost excludes evaluation
    result, time_cost = timed(
        lambda: train_step(
            model, dataset['train'], dataloader['train'],
            criterion, optimizer, scaler,
            AUTOCAST_FLAG
        )
    )
    train_loss, train_acc = result
    # eval
    test_loss, test_acc = eval_step(
        model, dataset['test'], dataloader['test'], criterion, AUTOCAST_FLAG)
    # print results: the lr / shape / magnitude / dropout values shown are the ones
    # used during this epoch, because both schedulers are stepped after printing
    print('----')
    print(f'epoch {epoch}')
    print(f'AUTOCAST: {AUTOCAST_FLAG}, COMPILE: {COMPILE_FLAG}')
    print(f'time_cost: {time_cost}')
    print(f'batch_size: {batch_size}')
    print(f'lr: {scheduler.get_last_lr()} / shape: {size} / magnitude: {magnitude} / dropout: {p}')
    print(f'train_loss: {train_loss}, train_acc: {train_acc}')
    print(f'test_loss: {test_loss}, test_acc: {test_acc}')
    print('----')
    # learning-rate scheduler
    scheduler.step()
    # NOTE(review): the two `#'''` markers below presumably act as a toggle; removing the
    # leading `#` from both turns the stage-scheduler call into a string literal.
    #'''
    # adjust image size, randaugment magnitude and dropout rate at stage milestones;
    # the globals dataset, dataloader, size, magnitude and p are rebound to the returned values
    dataset, dataloader, size, magnitude, p = progressive_learning_scheduler(
        epoch, stage_schedule, model, dataset, dataloader,
        root=root,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        num_workers=num_workers,
        num_ops=num_ops,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=persistent_workers
    )
    #'''