ViT inference #423

Draft, wants to merge 1 commit into main
78 changes: 39 additions & 39 deletions libai/layers/linear.py
@@ -115,45 +115,45 @@ def __init__(
)

    def forward(self, x):
        if dist.same_sbp(self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(0)])):
            # If the last dim of weight sbp sign is S(0), then last dim of weight.t sbp
            # sign is S(1), so the last dim of x sbp sign must be B.
            if self.weight.sbp[-1] == flow.sbp.split(0):
                x_sbp = x.sbp[:-1] + (flow.sbp.broadcast,)
                x = x.to_global(sbp=x_sbp)

            # x.grad sbp must be x.sbp, otherwise backward pass cannot be performed correctly.
            x = x.to_global(grad_sbp=x.sbp)
            x = flow.matmul(x, self.weight, transpose_b=True)

        elif dist.same_sbp(
            self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(1)])
        ):
            # If the last dim of weight sbp sign is S(1), then last dim of weight.t sbp
            # sign is S(0), so the last dim of x sbp sign must be S(ndim-1).
            if self.weight.sbp[-1] == flow.sbp.split(1):
                x_sbp = x.sbp[:-1] + (flow.sbp.split(x.ndim - 1),)
                x = x.to_global(sbp=x_sbp)
                out_sbp = x.sbp[:-1] + (flow.sbp.broadcast,)
            else:
                out_sbp = x.sbp

            x = flow.matmul(x, self.weight, transpose_b=True)
            # Change x.sbp for followup forward pass.
            # This line can be removed when sbp can be auto inferred.
            x = x.to_global(sbp=out_sbp)
        elif dist.same_sbp(
            self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
        ):
            # x.grad sbp must be x.sbp, otherwise backward pass cannot be performed correctly.
            x = x.to_global(grad_sbp=x.sbp)
            # NOTE(chengcheng): when input x is [S(0), B], there is no need to change sbp for x.
            # x = x.to_global(sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(0)]))
            x = flow.matmul(x, self.weight, transpose_b=True)
        else:
            # Not supported weight_sbp, deduce sbp and communicate with nccl automatically.
            x = flow.matmul(x, self.weight, transpose_b=True)

        # if dist.same_sbp(self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(0)])):
        #     # If the last dim of weight sbp sign is S(0), then last dim of weight.t sbp
        #     # sign is S(1), so the last dim of x sbp sign must be B.
        #     if self.weight.sbp[-1] == flow.sbp.split(0):
        #         x_sbp = x.sbp[:-1] + (flow.sbp.broadcast,)
        #         x = x.to_global(sbp=x_sbp)
        #
        #     # x.grad sbp must be x.sbp, otherwise backward pass cannot be performed correctly.
        #     x = x.to_global(grad_sbp=x.sbp)
        #     x = flow.matmul(x, self.weight, transpose_b=True)
        #
        # elif dist.same_sbp(
        #     self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(1)])
        # ):
        #     # If the last dim of weight sbp sign is S(1), then last dim of weight.t sbp
        #     # sign is S(0), so the last dim of x sbp sign must be S(ndim-1).
        #     if self.weight.sbp[-1] == flow.sbp.split(1):
        #         x_sbp = x.sbp[:-1] + (flow.sbp.split(x.ndim - 1),)
        #         x = x.to_global(sbp=x_sbp)
        #         out_sbp = x.sbp[:-1] + (flow.sbp.broadcast,)
        #     else:
        #         out_sbp = x.sbp
        #
        #     x = flow.matmul(x, self.weight, transpose_b=True)
        #     # Change x.sbp for followup forward pass.
        #     # This line can be removed when sbp can be auto inferred.
        #     x = x.to_global(sbp=out_sbp)
        # elif dist.same_sbp(
        #     self.weight.sbp, dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
        # ):
        #     # x.grad sbp must be x.sbp, otherwise backward pass cannot be performed correctly.
        #     x = x.to_global(grad_sbp=x.sbp)
        #     # NOTE(chengcheng): when input x is [S(0), B], there is no need to change sbp for x.
        #     # x = x.to_global(sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(0)]))
        #     x = flow.matmul(x, self.weight, transpose_b=True)
        # else:
        #     # Not supported weight_sbp, deduce sbp and communicate with nccl automatically.
        #     x = flow.matmul(x, self.weight, transpose_b=True)
        x = flow.matmul(x, self.weight, transpose_b=True)
        if self.bias is not None:
            if self.skip_bias_add:
                return x, self.bias
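For context, the branches above special-case the weight's 2-D SBP signature (column-split, row-split, or fully broadcast) before the matmul, while this draft falls through to a single flow.matmul and lets OneFlow deduce the SBP on its own. Below is a minimal sketch of that deduction on an illustrative 1-D, 2-GPU placement; the shapes, the placement, and the file name sbp_demo.py are assumptions for illustration only, not part of this PR.

import oneflow as flow

# Illustrative 2-GPU placement; run with:
#   python3 -m oneflow.distributed.launch --nproc_per_node 2 sbp_demo.py
P = flow.placement("cuda", ranks=[0, 1])

x = flow.randn(4, 8, placement=P, sbp=flow.sbp.broadcast)   # replicated activations
w = flow.randn(16, 8, placement=P, sbp=flow.sbp.split(0))   # column-parallel weight, S(0)

# With w sharded on dim 0, w.T is sharded on dim 1, so matmul(x, w, transpose_b=True)
# should produce an output split along its last dim; eager global execution deduces
# this and inserts any needed collectives automatically.
y = flow.matmul(x, w, transpose_b=True)
print(y.shape, y.sbp)   # expected: global shape (4, 16), split along the last dim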
2 changes: 1 addition & 1 deletion libai/layers/transformer_layer.py
@@ -155,7 +155,7 @@ def forward(
                used for incremental decoding.
        """
        # Change placement for pipeline parallelism
        hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))
        #### hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))

        # hidden_states shape: (batch_size, seq_length, hidden_size)
        if attention_mask is not None:
9 changes: 5 additions & 4 deletions libai/models/vision_transformer.py
@@ -156,13 +156,14 @@ def forward_features(self, x):
        cls_token = self.cls_token.expand(
            x.shape[0], -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)
        # cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)

        x = flow.cat((cls_token, x), dim=1)

        # position embedding
        pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
        pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
        x = self.pos_drop(x + pos_embed)
        # pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
        # pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
        x = self.pos_drop(x + self.pos_embed)

        # transformer block
        x = self.blocks(x)
27 changes: 16 additions & 11 deletions projects/MAE/configs/mae_finetune.py
@@ -38,7 +38,7 @@
n_gpus = 8

# Graph training
graph.enabled = True
graph.enabled = False

# Refine model cfg for vit training on imagenet
model.num_classes = 1000
@@ -50,12 +50,12 @@
finetune.weight_style = (
    "oneflow"  # Set "oneflow" for loading oneflow weights, set "pytorch" for loading torch weights
)
finetune.path = "/path/to/pretrained_mae_weight"
finetune.path = "/work/libai/output/vit_base_82.658"


# Refine data path to imagenet
dataloader.train.dataset[0].root = "/path/to/imagenet"
dataloader.test[0].dataset.root = "/path/to/imagenet"
dataloader.train.dataset[0].root = "/imagenet"
dataloader.test[0].dataset.root = "/imagenet"

# Add Mixup Func
dataloader.train.mixup_func = LazyCall(Mixup)(
@@ -70,9 +70,9 @@


# Refine training settings for MAE finetune
train.train_micro_batch_size = 32
train.num_accumulation_steps = 4
train.test_micro_batch_size = 32
train.train_micro_batch_size = 128
train.num_accumulation_steps = 1
train.test_micro_batch_size = 64
effective_batch_size = train.train_micro_batch_size * train.num_accumulation_steps * n_gpus

train.train_epoch = 100
@@ -82,24 +82,24 @@
train.checkpointer.save_model_after_n_epoch = 1

# Set layer decay for MAE fine-tune
train.layer_decay = 0.65
train.layer_decay = 0.75

# AMP
train.amp.enabled = True
#train.amp.enabled = True


# The base learning rate in MAE is set to 1.5e-4
# The actual learning rate should be computed by the linear scaling rule as follows:
# lr = base_lr * batch_size / 256
# In LiBai, you should adjust the actual learning rate according to your own settings
# Here we use 8 GPUs with a batch size of 128 per GPU, so the effective batch size is 1024
base_lr = 5e-4
base_lr = 1e-3
actual_lr = base_lr * effective_batch_size / 256

# Refine optim settings
optim.params._target_ = param_groups_lrd
optim.params.weight_decay = 0.05
optim.params.layer_decay = 0.65
optim.params.layer_decay = 0.75
optim.lr = actual_lr

del optim.params.clip_grad_max_norm
@@ -120,6 +120,11 @@
    min_lr=1e-6,
)

# checkpointing
#train.activation_checkpoint.enabled = True

# zero
#train.zero_optimization.enabled = True

# Distributed Settings
train.dist.pipeline_num_layers = model.depth
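As a sanity check of the linear scaling rule above, plugging in the values this config now sets (8 GPUs, micro batch size 128, 1 accumulation step, base_lr = 1e-3) gives an effective batch size of 1024 and an actual learning rate of 4e-3. A small stand-alone sketch of the same arithmetic:

n_gpus = 8
train_micro_batch_size = 128
num_accumulation_steps = 1
base_lr = 1e-3

effective_batch_size = train_micro_batch_size * num_accumulation_steps * n_gpus  # 1024
actual_lr = base_lr * effective_batch_size / 256  # 0.004
print(effective_batch_size, actual_lr)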
38 changes: 38 additions & 0 deletions projects/MAE/configs/metrics.py
@@ -0,0 +1,38 @@
""" Eval metrics and related

Hacked together by / Copyright 2020 Ross Wightman
"""
import oneflow as flow

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

def reduce_tensor(tensor):
    reduce_tensors = tensor.clone()
    flow.comm.all_reduce(reduce_tensors)
    reduce_tensors /= flow.env.get_world_size()
    return reduce_tensors
97 changes: 97 additions & 0 deletions projects/MAE/configs/oneflow_infer.py
@@ -0,0 +1,97 @@
from libai.utils.checkpoint import Checkpointer
import oneflow as flow
from projects.MAE.modeling.vit import VisionTransformer
# from flowvision.models.vision_transformer import VisionTransformer
import os
from flowvision import datasets, transforms
from metrics import AverageMeter, accuracy, reduce_tensor
import PIL

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def build_dataset(data_path):
    transform = build_transform()
    root = os.path.join(data_path, 'val')
    dataset = datasets.ImageFolder(root, transform=transform)
    return dataset


def build_transform():
    input_size = 224
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    # eval transform
    t = []
    crop_pct = 224 / 256
    size = int(input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


def evaluate(model, data_loader):
    model.eval()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    for idx, (images, target) in enumerate(data_loader):

        # compute output
        images = images.cuda()
        target = target.cuda()

        output = model(images)["prediction_scores"]
        # output = model(images)

        # measure accuracy
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)

        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        print(f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
              f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t')
    print(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')

model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    drop_path_rate=0.1,
    global_pool=True,
)
model = model.cuda()
Checkpointer(model).load('/work/libai/output/vit_base_82.658')
model.to_local()
model = flow.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False)

# state_dict = flow.load('/work/libai/vit_base_patch16_224')
# model.load_state_dict(state_dict, strict=False)

dataset_val = build_dataset('/imagenet')
# sampler_val = flow.utils.data.SequentialSampler(dataset_val)
sampler_val = flow.utils.data.distributed.DistributedSampler(dataset_val)
data_loader_val = flow.utils.data.DataLoader(
    dataset_val, sampler=sampler_val,
    batch_size=64,
    num_workers=8,
    drop_last=True
)
import time
start_time = time.time()
evaluate(model, data_loader_val)
print(f'eval time: {time.time() - start_time}')
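Since this script wraps the model in DistributedDataParallel and uses a DistributedSampler, it is presumably meant to be launched with OneFlow's distributed launcher, e.g. python3 -m oneflow.distributed.launch --nproc_per_node 8 projects/MAE/configs/oneflow_infer.py; the process count here is an assumption, and a plain single-process python run should also work since both the DDP wrapper and the sampler reduce to world size 1.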