diff --git a/classification/lib/models/cifar/ShuffleNetv1.py b/classification/lib/models/cifar/ShuffleNetv1.py deleted file mode 100644 index 8e5cd24..0000000 --- a/classification/lib/models/cifar/ShuffleNetv1.py +++ /dev/null @@ -1,138 +0,0 @@ -'''ShuffleNet in PyTorch. -See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ShuffleBlock(nn.Module): - def __init__(self, groups): - super(ShuffleBlock, self).__init__() - self.groups = groups - - def forward(self, x): - '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' - N,C,H,W = x.size() - g = self.groups - return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) - - -class Bottleneck(nn.Module): - def __init__(self, in_planes, out_planes, stride, groups, is_last=False): - super(Bottleneck, self).__init__() - self.is_last = is_last - self.stride = stride - - mid_planes = int(out_planes/4) - g = 1 if in_planes == 24 else groups - self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) - self.bn1 = nn.BatchNorm2d(mid_planes) - self.shuffle1 = ShuffleBlock(groups=g) - self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) - self.bn2 = nn.BatchNorm2d(mid_planes) - self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) - self.bn3 = nn.BatchNorm2d(out_planes) - - self.shortcut = nn.Sequential() - if stride == 2: - self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.shuffle1(out) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - res = self.shortcut(x) - preact = torch.cat([out, res], 1) if self.stride == 2 else out+res - out = F.relu(preact) - # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) - if self.is_last: - return out, preact - else: - return out - - -class ShuffleNet(nn.Module): - def __init__(self, cfg, num_classes=10): - super(ShuffleNet, self).__init__() - out_planes = cfg['out_planes'] - num_blocks = cfg['num_blocks'] - groups = cfg['groups'] - - self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(24) - self.in_planes = 24 - self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) - self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) - self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) - self.linear = nn.Linear(out_planes[2], num_classes) - - def _make_layer(self, out_planes, num_blocks, groups): - layers = [] - for i in range(num_blocks): - stride = 2 if i == 0 else 1 - cat_planes = self.in_planes if i == 0 else 0 - layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, - stride=stride, - groups=groups, - is_last=(i == num_blocks - 1))) - self.in_planes = out_planes - return nn.Sequential(*layers) - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.bn1) - feat_m.append(self.layer1) - feat_m.append(self.layer2) - feat_m.append(self.layer3) - return feat_m - - def get_bn_before_relu(self): - raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') - - def forward(self, x, is_feat=False, preact=False): - out = F.relu(self.bn1(self.conv1(x))) - f0 = out - out, f1_pre = self.layer1(out) - f1 = out - out, 
f2_pre = self.layer2(out) - f2 = out - out, f3_pre = self.layer3(out) - f3 = out - out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) - f4 = out - out = self.linear(out) - - if is_feat: - if preact: - return [f0, f1_pre, f2_pre, f3_pre, f4], out - else: - return [f0, f1, f2, f3, f4], out - else: - return out - - -def ShuffleV1(**kwargs): - cfg = { - 'out_planes': [240, 480, 960], - 'num_blocks': [4, 8, 4], - 'groups': 3 - } - return ShuffleNet(cfg, **kwargs) - - -if __name__ == '__main__': - - x = torch.randn(2, 3, 32, 32) - net = ShuffleV1(num_classes=100) - import time - a = time.time() - feats, logit = net(x, is_feat=True, preact=True) - b = time.time() - print(b - a) - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) diff --git a/classification/lib/models/cifar/ShuffleNetv2.py b/classification/lib/models/cifar/ShuffleNetv2.py deleted file mode 100644 index bd0821b..0000000 --- a/classification/lib/models/cifar/ShuffleNetv2.py +++ /dev/null @@ -1,210 +0,0 @@ -'''ShuffleNetV2 in PyTorch. -See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ShuffleBlock(nn.Module): - def __init__(self, groups=2): - super(ShuffleBlock, self).__init__() - self.groups = groups - - def forward(self, x): - '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' - N, C, H, W = x.size() - g = self.groups - return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) - - -class SplitBlock(nn.Module): - def __init__(self, ratio): - super(SplitBlock, self).__init__() - self.ratio = ratio - - def forward(self, x): - c = int(x.size(1) * self.ratio) - return x[:, :c, :, :], x[:, c:, :, :] - - -class BasicBlock(nn.Module): - def __init__(self, in_channels, split_ratio=0.5, is_last=False): - super(BasicBlock, self).__init__() - self.is_last = is_last - self.split = SplitBlock(split_ratio) - in_channels = int(in_channels * split_ratio) - self.conv1 = nn.Conv2d(in_channels, in_channels, - kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(in_channels) - self.conv2 = nn.Conv2d(in_channels, in_channels, - kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) - self.bn2 = nn.BatchNorm2d(in_channels) - self.conv3 = nn.Conv2d(in_channels, in_channels, - kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(in_channels) - self.shuffle = ShuffleBlock() - - def forward(self, x): - x1, x2 = self.split(x) - out = F.relu(self.bn1(self.conv1(x2))) - out = self.bn2(self.conv2(out)) - preact = self.bn3(self.conv3(out)) - out = F.relu(preact) - # out = F.relu(self.bn3(self.conv3(out))) - preact = torch.cat([x1, preact], 1) - out = torch.cat([x1, out], 1) - out = self.shuffle(out) - if self.is_last: - return out, preact - else: - return out - - -class DownBlock(nn.Module): - def __init__(self, in_channels, out_channels): - super(DownBlock, self).__init__() - mid_channels = out_channels // 2 - # left - self.conv1 = nn.Conv2d(in_channels, in_channels, - kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) - self.bn1 = nn.BatchNorm2d(in_channels) - self.conv2 = nn.Conv2d(in_channels, mid_channels, - kernel_size=1, bias=False) - self.bn2 = nn.BatchNorm2d(mid_channels) - # right - self.conv3 = nn.Conv2d(in_channels, mid_channels, - kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(mid_channels) - self.conv4 = nn.Conv2d(mid_channels, mid_channels, - kernel_size=3, stride=2, padding=1, 
groups=mid_channels, bias=False) - self.bn4 = nn.BatchNorm2d(mid_channels) - self.conv5 = nn.Conv2d(mid_channels, mid_channels, - kernel_size=1, bias=False) - self.bn5 = nn.BatchNorm2d(mid_channels) - - self.shuffle = ShuffleBlock() - - def forward(self, x): - # left - out1 = self.bn1(self.conv1(x)) - out1 = F.relu(self.bn2(self.conv2(out1))) - # right - out2 = F.relu(self.bn3(self.conv3(x))) - out2 = self.bn4(self.conv4(out2)) - out2 = F.relu(self.bn5(self.conv5(out2))) - # concat - out = torch.cat([out1, out2], 1) - out = self.shuffle(out) - return out - - -class ShuffleNetV2(nn.Module): - def __init__(self, net_size, num_classes=10): - super(ShuffleNetV2, self).__init__() - out_channels = configs[net_size]['out_channels'] - num_blocks = configs[net_size]['num_blocks'] - - # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, - # stride=1, padding=1, bias=False) - self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(24) - self.in_channels = 24 - self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) - self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) - self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) - self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], - kernel_size=1, stride=1, padding=0, bias=False) - self.bn2 = nn.BatchNorm2d(out_channels[3]) - self.linear = nn.Linear(out_channels[3], num_classes) - - def _make_layer(self, out_channels, num_blocks): - layers = [DownBlock(self.in_channels, out_channels)] - for i in range(num_blocks): - layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) - self.in_channels = out_channels - return nn.Sequential(*layers) - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.bn1) - feat_m.append(self.layer1) - feat_m.append(self.layer2) - feat_m.append(self.layer3) - return feat_m - - def get_bn_before_relu(self): - raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') - - def forward(self, x, is_feat=False, preact=False): - out = F.relu(self.bn1(self.conv1(x))) - # out = F.max_pool2d(out, 3, stride=2, padding=1) - f0 = out - out, f1_pre = self.layer1(out) - f1 = out - out, f2_pre = self.layer2(out) - f2 = out - out, f3_pre = self.layer3(out) - f3 = out - out = F.relu(self.bn2(self.conv2(out))) - out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) - f4 = out - out = self.linear(out) - if is_feat: - if preact: - return [f0, f1_pre, f2_pre, f3_pre, f4], out - else: - return [f0, f1, f2, f3, f4], out - else: - return out - - -configs = { - 0.2: { - 'out_channels': (40, 80, 160, 512), - 'num_blocks': (3, 3, 3) - }, - - 0.3: { - 'out_channels': (40, 80, 160, 512), - 'num_blocks': (3, 7, 3) - }, - - 0.5: { - 'out_channels': (48, 96, 192, 1024), - 'num_blocks': (3, 7, 3) - }, - - 1: { - 'out_channels': (116, 232, 464, 1024), - 'num_blocks': (3, 7, 3) - }, - 1.5: { - 'out_channels': (176, 352, 704, 1024), - 'num_blocks': (3, 7, 3) - }, - 2: { - 'out_channels': (224, 488, 976, 2048), - 'num_blocks': (3, 7, 3) - } -} - - -def ShuffleV2(**kwargs): - model = ShuffleNetV2(net_size=1, **kwargs) - return model - - -if __name__ == '__main__': - net = ShuffleV2(num_classes=100) - x = torch.randn(3, 3, 32, 32) - import time - a = time.time() - feats, logit = net(x, is_feat=True, preact=True) - b = time.time() - print(b - a) - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) diff --git a/classification/lib/models/cifar/__init__.py 
b/classification/lib/models/cifar/__init__.py deleted file mode 100644 index 1720555..0000000 --- a/classification/lib/models/cifar/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 -from .resnetv2 import ResNet50 -from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 -from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn -from .mobilenetv2 import mobile_half -from .ShuffleNetv1 import ShuffleV1 -from .ShuffleNetv2 import ShuffleV2 - -model_dict = { - 'resnet8': resnet8, - 'resnet14': resnet14, - 'resnet20': resnet20, - 'resnet32': resnet32, - 'resnet44': resnet44, - 'resnet56': resnet56, - 'resnet110': resnet110, - 'resnet8x4': resnet8x4, - 'resnet32x4': resnet32x4, - 'ResNet50': ResNet50, - 'wrn_16_1': wrn_16_1, - 'wrn_16_2': wrn_16_2, - 'wrn_40_1': wrn_40_1, - 'wrn_40_2': wrn_40_2, - 'vgg8': vgg8_bn, - 'vgg11': vgg11_bn, - 'vgg13': vgg13_bn, - 'vgg16': vgg16_bn, - 'vgg19': vgg19_bn, - 'MobileNetV2': mobile_half, - 'ShuffleV1': ShuffleV1, - 'ShuffleV2': ShuffleV2, -} diff --git a/classification/lib/models/cifar/classifier.py b/classification/lib/models/cifar/classifier.py deleted file mode 100644 index 167ddb6..0000000 --- a/classification/lib/models/cifar/classifier.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import print_function - -import torch.nn as nn - - -######################################### -# ===== Classifiers ===== # -######################################### - -class LinearClassifier(nn.Module): - - def __init__(self, dim_in, n_label=10): - super(LinearClassifier, self).__init__() - - self.net = nn.Linear(dim_in, n_label) - - def forward(self, x): - return self.net(x) - - -class NonLinearClassifier(nn.Module): - - def __init__(self, dim_in, n_label=10, p=0.1): - super(NonLinearClassifier, self).__init__() - - self.net = nn.Sequential( - nn.Linear(dim_in, 200), - nn.Dropout(p=p), - nn.BatchNorm1d(200), - nn.ReLU(inplace=True), - nn.Linear(200, n_label), - ) - - def forward(self, x): - return self.net(x) diff --git a/classification/lib/models/cifar/mobilenetv2.py b/classification/lib/models/cifar/mobilenetv2.py deleted file mode 100644 index 6bfe9fa..0000000 --- a/classification/lib/models/cifar/mobilenetv2.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -MobileNetV2 implementation used in - -""" - -import torch -import torch.nn as nn -import math - -__all__ = ['mobilenetv2_T_w', 'mobile_half'] - -BN = None - - -def conv_bn(inp, oup, stride): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) - - -def conv_1x1_bn(inp, oup): - return nn.Sequential( - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) - - -class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - self.blockname = None - - self.stride = stride - assert stride in [1, 2] - - self.use_res_connect = self.stride == 1 and inp == oup - - self.conv = nn.Sequential( - # pw - nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), - nn.BatchNorm2d(inp * expand_ratio), - nn.ReLU(inplace=True), - # dw - nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), - nn.BatchNorm2d(inp * expand_ratio), - nn.ReLU(inplace=True), - # pw-linear - nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ) - self.names = ['0', '1', '2', '3', '4', '5', '6', 
'7'] - - def forward(self, x): - t = x - if self.use_res_connect: - return t + self.conv(x) - else: - return self.conv(x) - - -class MobileNetV2(nn.Module): - """mobilenetV2""" - def __init__(self, T, - feature_dim, - input_size=32, - width_mult=1., - remove_avg=False): - super(MobileNetV2, self).__init__() - self.remove_avg = remove_avg - - # setting of inverted residual blocks - self.interverted_residual_setting = [ - # t, c, n, s - [1, 16, 1, 1], - [T, 24, 2, 1], - [T, 32, 3, 2], - [T, 64, 4, 2], - [T, 96, 3, 1], - [T, 160, 3, 2], - [T, 320, 1, 1], - ] - - # building first layer - assert input_size % 32 == 0 - input_channel = int(32 * width_mult) - self.conv1 = conv_bn(3, input_channel, 2) - - # building inverted residual blocks - self.blocks = nn.ModuleList([]) - for t, c, n, s in self.interverted_residual_setting: - output_channel = int(c * width_mult) - layers = [] - strides = [s] + [1] * (n - 1) - for stride in strides: - layers.append( - InvertedResidual(input_channel, output_channel, stride, t) - ) - input_channel = output_channel - self.blocks.append(nn.Sequential(*layers)) - - self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 - self.conv2 = conv_1x1_bn(input_channel, self.last_channel) - - # building classifier - self.classifier = nn.Sequential( - # nn.Dropout(0.5), - nn.Linear(self.last_channel, feature_dim), - ) - - H = input_size // (32//2) - self.avgpool = nn.AvgPool2d(H, ceil_mode=True) - - self._initialize_weights() - print(T, width_mult) - - def get_bn_before_relu(self): - bn1 = self.blocks[1][-1].conv[-1] - bn2 = self.blocks[2][-1].conv[-1] - bn3 = self.blocks[4][-1].conv[-1] - bn4 = self.blocks[6][-1].conv[-1] - return [bn1, bn2, bn3, bn4] - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.blocks) - return feat_m - - def forward(self, x, is_feat=False, preact=False): - - out = self.conv1(x) - f0 = out - - out = self.blocks[0](out) - out = self.blocks[1](out) - f1 = out - out = self.blocks[2](out) - f2 = out - out = self.blocks[3](out) - out = self.blocks[4](out) - f3 = out - out = self.blocks[5](out) - out = self.blocks[6](out) - f4 = out - - out = self.conv2(out) - - if not self.remove_avg: - out = self.avgpool(out) - out = out.view(out.size(0), -1) - f5 = out - out = self.classifier(out) - - if is_feat: - return [f0, f1, f2, f3, f4, f5], out - else: - return out - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() - - -def mobilenetv2_T_w(T, W, feature_dim=100): - model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) - return model - - -def mobile_half(num_classes): - return mobilenetv2_T_w(6, 0.5, num_classes) - - -if __name__ == '__main__': - x = torch.randn(2, 3, 32, 32) - - net = mobile_half(100) - - feats, logit = net(x, is_feat=True, preact=True) - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) - - for m in net.get_bn_before_relu(): - if isinstance(m, nn.BatchNorm2d): - print('pass') - else: - print('warning') - diff --git a/classification/lib/models/cifar/resnet.py b/classification/lib/models/cifar/resnet.py deleted file mode 100644 index e5d9a27..0000000 --- a/classification/lib/models/cifar/resnet.py +++ /dev/null @@ -1,256 +0,0 @@ -from __future__ import absolute_import - -'''Resnet for cifar dataset. -Ported form -https://github.com/facebook/fb.resnet.torch -and -https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py -(c) YANG, Wei -''' -import torch.nn as nn -import torch.nn.functional as F -import math - - -__all__ = ['resnet'] - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): - super(BasicBlock, self).__init__() - self.is_last = is_last - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - preact = out - out = F.relu(out) - if self.is_last: - return out, preact - else: - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): - super(Bottleneck, self).__init__() - self.is_last = is_last - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - preact = out - out = F.relu(out) - if self.is_last: - return out, preact - else: - return out - - -class ResNet(nn.Module): - - def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): - super(ResNet, self).__init__() - # Model type specifies number of 
layers for CIFAR-10 model - if block_name.lower() == 'basicblock': - assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' - n = (depth - 2) // 6 - block = BasicBlock - elif block_name.lower() == 'bottleneck': - assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' - n = (depth - 2) // 9 - block = Bottleneck - else: - raise ValueError('block_name shoule be Basicblock or Bottleneck') - - self.inplanes = num_filters[0] - self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, - bias=False) - self.bn1 = nn.BatchNorm2d(num_filters[0]) - self.relu = nn.ReLU(inplace=True) - self.layer1 = self._make_layer(block, num_filters[1], n) - self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) - self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) - self.avgpool = nn.AvgPool2d(8) - self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = list([]) - layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) - - return nn.Sequential(*layers) - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.bn1) - feat_m.append(self.relu) - feat_m.append(self.layer1) - feat_m.append(self.layer2) - feat_m.append(self.layer3) - return feat_m - - def get_bn_before_relu(self): - if isinstance(self.layer1[0], Bottleneck): - bn1 = self.layer1[-1].bn3 - bn2 = self.layer2[-1].bn3 - bn3 = self.layer3[-1].bn3 - elif isinstance(self.layer1[0], BasicBlock): - bn1 = self.layer1[-1].bn2 - bn2 = self.layer2[-1].bn2 - bn3 = self.layer3[-1].bn2 - else: - raise NotImplementedError('ResNet unknown block error !!!') - - return [bn1, bn2, bn3] - - def forward(self, x, is_feat=False, preact=False): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) # 32x32 - f0 = x - - x, f1_pre = self.layer1(x) # 32x32 - f1 = x - x, f2_pre = self.layer2(x) # 16x16 - f2 = x - x, f3_pre = self.layer3(x) # 8x8 - f3 = x - - x = self.avgpool(x) - x = x.view(x.size(0), -1) - f4 = x - x = self.fc(x) - - if is_feat: - if preact: - return [f0, f1_pre, f2_pre, f3_pre, f4], x - else: - return [f0, f1, f2, f3, f4], x - else: - return x - - -def resnet8(**kwargs): - return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet14(**kwargs): - return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet20(**kwargs): - return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet32(**kwargs): - return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet44(**kwargs): - return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet56(**kwargs): - return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet110(**kwargs): - return ResNet(110, 
[16, 16, 32, 64], 'basicblock', **kwargs) - - -def resnet8x4(**kwargs): - return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) - - -def resnet32x4(**kwargs): - return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) - - -if __name__ == '__main__': - import torch - - x = torch.randn(2, 3, 32, 32) - net = resnet8x4(num_classes=20) - feats, logit = net(x, is_feat=True, preact=True) - - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) - - for m in net.get_bn_before_relu(): - if isinstance(m, nn.BatchNorm2d): - print('pass') - else: - print('warning') diff --git a/classification/lib/models/cifar/resnetv2.py b/classification/lib/models/cifar/resnetv2.py deleted file mode 100644 index bc03eaf..0000000 --- a/classification/lib/models/cifar/resnetv2.py +++ /dev/null @@ -1,198 +0,0 @@ -'''ResNet in PyTorch. -For Pre-activation ResNet, see 'preact_resnet.py'. -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. arXiv:1512.03385 -''' -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1, is_last=False): - super(BasicBlock, self).__init__() - self.is_last = is_last - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - preact = out - out = F.relu(out) - if self.is_last: - return out, preact - else: - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1, is_last=False): - super(Bottleneck, self).__init__() - self.is_last = is_last - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion * planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - preact = out - out = F.relu(out) - if self.is_last: - return out, preact - else: - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): - super(ResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = 
self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.linear = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.bn1) - feat_m.append(self.layer1) - feat_m.append(self.layer2) - feat_m.append(self.layer3) - feat_m.append(self.layer4) - return feat_m - - def get_bn_before_relu(self): - if isinstance(self.layer1[0], Bottleneck): - bn1 = self.layer1[-1].bn3 - bn2 = self.layer2[-1].bn3 - bn3 = self.layer3[-1].bn3 - bn4 = self.layer4[-1].bn3 - elif isinstance(self.layer1[0], BasicBlock): - bn1 = self.layer1[-1].bn2 - bn2 = self.layer2[-1].bn2 - bn3 = self.layer3[-1].bn2 - bn4 = self.layer4[-1].bn2 - else: - raise NotImplementedError('ResNet unknown block error !!!') - - return [bn1, bn2, bn3, bn4] - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for i in range(num_blocks): - stride = strides[i] - layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x, is_feat=False, preact=False): - out = F.relu(self.bn1(self.conv1(x))) - f0 = out - out, f1_pre = self.layer1(out) - f1 = out - out, f2_pre = self.layer2(out) - f2 = out - out, f3_pre = self.layer3(out) - f3 = out - out, f4_pre = self.layer4(out) - f4 = out - out = self.avgpool(out) - out = out.view(out.size(0), -1) - f5 = out - out = self.linear(out) - if is_feat: - if preact: - return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out] - else: - return [f0, f1, f2, f3, f4, f5], out - else: - return out - - -def ResNet18(**kwargs): - return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - - -def ResNet34(**kwargs): - return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) - - -def ResNet50(**kwargs): - return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - - -def ResNet101(**kwargs): - return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) - - -def ResNet152(**kwargs): - return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) - - -if __name__ == '__main__': - net = ResNet18(num_classes=100) - x = torch.randn(2, 3, 32, 32) - feats, logit = net(x, is_feat=True, preact=True) - - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) - - for m in net.get_bn_before_relu(): - if isinstance(m, nn.BatchNorm2d): - print('pass') - else: - print('warning') diff --git a/classification/lib/models/cifar/util.py b/classification/lib/models/cifar/util.py deleted file mode 100644 index 90293f1..0000000 --- a/classification/lib/models/cifar/util.py +++ /dev/null @@ -1,290 +0,0 @@ -from __future__ import print_function - -import torch.nn as nn -import math - - -class 
Paraphraser(nn.Module): - """Paraphrasing Complex Network: Network Compression via Factor Transfer""" - def __init__(self, t_shape, k=0.5, use_bn=False): - super(Paraphraser, self).__init__() - in_channel = t_shape[1] - out_channel = int(t_shape[1] * k) - self.encoder = nn.Sequential( - nn.Conv2d(in_channel, in_channel, 3, 1, 1), - nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.Conv2d(in_channel, out_channel, 3, 1, 1), - nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.Conv2d(out_channel, out_channel, 3, 1, 1), - nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - ) - self.decoder = nn.Sequential( - nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1), - nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1), - nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1), - nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - ) - - def forward(self, f_s, is_factor=False): - factor = self.encoder(f_s) - if is_factor: - return factor - rec = self.decoder(factor) - return factor, rec - - -class Translator(nn.Module): - def __init__(self, s_shape, t_shape, k=0.5, use_bn=True): - super(Translator, self).__init__() - in_channel = s_shape[1] - out_channel = int(t_shape[1] * k) - self.encoder = nn.Sequential( - nn.Conv2d(in_channel, in_channel, 3, 1, 1), - nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.Conv2d(in_channel, out_channel, 3, 1, 1), - nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - nn.Conv2d(out_channel, out_channel, 3, 1, 1), - nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), - nn.LeakyReLU(0.1, inplace=True), - ) - - def forward(self, f_s): - return self.encoder(f_s) - - -class Connector(nn.Module): - """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons""" - def __init__(self, s_shapes, t_shapes): - super(Connector, self).__init__() - self.s_shapes = s_shapes - self.t_shapes = t_shapes - - self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) - - @staticmethod - def _make_conenctors(s_shapes, t_shapes): - assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' - connectors = [] - for s, t in zip(s_shapes, t_shapes): - if s[1] == t[1] and s[2] == t[2]: - connectors.append(nn.Sequential()) - else: - connectors.append(ConvReg(s, t, use_relu=False)) - return connectors - - def forward(self, g_s): - out = [] - for i in range(len(g_s)): - out.append(self.connectors[i](g_s[i])) - - return out - - -class ConnectorV2(nn.Module): - """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)""" - def __init__(self, s_shapes, t_shapes): - super(ConnectorV2, self).__init__() - self.s_shapes = s_shapes - self.t_shapes = t_shapes - - self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) - - def _make_conenctors(self, s_shapes, t_shapes): - assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' - t_channels = [t[1] for t in t_shapes] - s_channels = [s[1] for s in s_shapes] - connectors = nn.ModuleList([self._build_feature_connector(t, s) - for t, s in zip(t_channels, s_channels)]) - return connectors - - 
@staticmethod - def _build_feature_connector(t_channel, s_channel): - C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(t_channel)] - for m in C: - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - return nn.Sequential(*C) - - def forward(self, g_s): - out = [] - for i in range(len(g_s)): - out.append(self.connectors[i](g_s[i])) - - return out - - -class ConvReg(nn.Module): - """Convolutional regression for FitNet""" - def __init__(self, s_shape, t_shape, use_relu=True): - super(ConvReg, self).__init__() - self.use_relu = use_relu - s_N, s_C, s_H, s_W = s_shape - t_N, t_C, t_H, t_W = t_shape - if s_H == 2 * t_H: - self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) - elif s_H * 2 == t_H: - self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) - elif s_H >= t_H: - self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) - else: - raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H)) - self.bn = nn.BatchNorm2d(t_C) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.conv(x) - if self.use_relu: - return self.relu(self.bn(x)) - else: - return self.bn(x) - - -class Regress(nn.Module): - """Simple Linear Regression for hints""" - def __init__(self, dim_in=1024, dim_out=1024): - super(Regress, self).__init__() - self.linear = nn.Linear(dim_in, dim_out) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = x.view(x.shape[0], -1) - x = self.linear(x) - x = self.relu(x) - return x - - -class Embed(nn.Module): - """Embedding module""" - def __init__(self, dim_in=1024, dim_out=128): - super(Embed, self).__init__() - self.linear = nn.Linear(dim_in, dim_out) - self.l2norm = Normalize(2) - - def forward(self, x): - x = x.view(x.shape[0], -1) - x = self.linear(x) - x = self.l2norm(x) - return x - - -class LinearEmbed(nn.Module): - """Linear Embedding""" - def __init__(self, dim_in=1024, dim_out=128): - super(LinearEmbed, self).__init__() - self.linear = nn.Linear(dim_in, dim_out) - - def forward(self, x): - x = x.view(x.shape[0], -1) - x = self.linear(x) - return x - - -class MLPEmbed(nn.Module): - """non-linear embed by MLP""" - def __init__(self, dim_in=1024, dim_out=128): - super(MLPEmbed, self).__init__() - self.linear1 = nn.Linear(dim_in, 2 * dim_out) - self.relu = nn.ReLU(inplace=True) - self.linear2 = nn.Linear(2 * dim_out, dim_out) - self.l2norm = Normalize(2) - - def forward(self, x): - x = x.view(x.shape[0], -1) - x = self.relu(self.linear1(x)) - x = self.l2norm(self.linear2(x)) - return x - - -class Normalize(nn.Module): - """normalization layer""" - def __init__(self, power=2): - super(Normalize, self).__init__() - self.power = power - - def forward(self, x): - norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) - out = x.div(norm) - return out - - -class Flatten(nn.Module): - """flatten module""" - def __init__(self): - super(Flatten, self).__init__() - - def forward(self, feat): - return feat.view(feat.size(0), -1) - - -class PoolEmbed(nn.Module): - """pool and embed""" - def __init__(self, layer=0, dim_out=128, pool_type='avg'): - super().__init__() - if layer == 0: - pool_size = 8 - nChannels = 16 - elif layer == 1: - pool_size = 8 - nChannels = 16 - elif layer == 2: - pool_size = 6 - nChannels = 32 - elif layer == 3: - pool_size = 4 - nChannels = 64 - elif layer == 4: - pool_size = 1 - nChannels = 64 - else: - raise NotImplementedError('layer not supported: {}'.format(layer)) - - self.embed = nn.Sequential() - if layer <= 3: - if pool_type == 'max': - self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) - elif pool_type == 'avg': - self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) - - self.embed.add_module('Flatten', Flatten()) - self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out)) - self.embed.add_module('Normalize', Normalize(2)) - - def forward(self, x): - return self.embed(x) - - -if __name__ == '__main__': - import torch - - g_s = [ - torch.randn(2, 16, 16, 16), - torch.randn(2, 32, 8, 8), - torch.randn(2, 64, 4, 4), - ] - g_t = [ - torch.randn(2, 32, 16, 16), - torch.randn(2, 64, 8, 8), - torch.randn(2, 128, 4, 4), - ] - s_shapes = [s.shape for s in g_s] - t_shapes = [t.shape for t in g_t] - - net = ConnectorV2(s_shapes, t_shapes) - out = net(g_s) - for f in out: - print(f.shape) diff --git a/classification/lib/models/cifar/vgg.py b/classification/lib/models/cifar/vgg.py deleted file mode 100644 index b7bd5fe..0000000 --- a/classification/lib/models/cifar/vgg.py +++ /dev/null @@ -1,236 +0,0 @@ -'''VGG for CIFAR10. FC layers are removed. 
-(c) YANG, Wei -''' -import torch.nn as nn -import torch.nn.functional as F -import math - - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] - - -model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', -} - - -class VGG(nn.Module): - - def __init__(self, cfg, batch_norm=False, num_classes=1000): - super(VGG, self).__init__() - self.block0 = self._make_layers(cfg[0], batch_norm, 3) - self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) - self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) - self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) - self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) - - self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) - self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) - self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) - self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) - self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) - # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) - - self.classifier = nn.Linear(512, num_classes) - self._initialize_weights() - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.block0) - feat_m.append(self.pool0) - feat_m.append(self.block1) - feat_m.append(self.pool1) - feat_m.append(self.block2) - feat_m.append(self.pool2) - feat_m.append(self.block3) - feat_m.append(self.pool3) - feat_m.append(self.block4) - feat_m.append(self.pool4) - return feat_m - - def get_bn_before_relu(self): - bn1 = self.block1[-1] - bn2 = self.block2[-1] - bn3 = self.block3[-1] - bn4 = self.block4[-1] - return [bn1, bn2, bn3, bn4] - - def forward(self, x, is_feat=False, preact=False): - h = x.shape[2] - x = F.relu(self.block0(x)) - f0 = x - x = self.pool0(x) - x = self.block1(x) - f1_pre = x - x = F.relu(x) - f1 = x - x = self.pool1(x) - x = self.block2(x) - f2_pre = x - x = F.relu(x) - f2 = x - x = self.pool2(x) - x = self.block3(x) - f3_pre = x - x = F.relu(x) - f3 = x - if h == 64: - x = self.pool3(x) - x = self.block4(x) - f4_pre = x - x = F.relu(x) - f4 = x - x = self.pool4(x) - x = x.view(x.size(0), -1) - f5 = x - x = self.classifier(x) - - if is_feat: - if preact: - return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x - else: - return [f0, f1, f2, f3, f4, f5], x - else: - return x - - @staticmethod - def _make_layers(cfg, batch_norm=False, in_channels=3): - layers = [] - for v in cfg: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) - if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] - else: - layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v - layers = layers[:-1] - return nn.Sequential(*layers) - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() - - -cfg = { - 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], - 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], - 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], - 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], - 'S': [[64], [128], [256], [512], [512]], -} - - -def vgg8(**kwargs): - """VGG 8-layer model (configuration "S") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['S'], **kwargs) - return model - - -def vgg8_bn(**kwargs): - """VGG 8-layer model (configuration "S") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['S'], batch_norm=True, **kwargs) - return model - - -def vgg11(**kwargs): - """VGG 11-layer model (configuration "A") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['A'], **kwargs) - return model - - -def vgg11_bn(**kwargs): - """VGG 11-layer model (configuration "A") with batch normalization""" - model = VGG(cfg['A'], batch_norm=True, **kwargs) - return model - - -def vgg13(**kwargs): - """VGG 13-layer model (configuration "B") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['B'], **kwargs) - return model - - -def vgg13_bn(**kwargs): - """VGG 13-layer model (configuration "B") with batch normalization""" - model = VGG(cfg['B'], batch_norm=True, **kwargs) - return model - - -def vgg16(**kwargs): - """VGG 16-layer model (configuration "D") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['D'], **kwargs) - return model - - -def vgg16_bn(**kwargs): - """VGG 16-layer model (configuration "D") with batch normalization""" - model = VGG(cfg['D'], batch_norm=True, **kwargs) - return model - - -def vgg19(**kwargs): - """VGG 19-layer model (configuration "E") - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(cfg['E'], **kwargs) - return model - - -def vgg19_bn(**kwargs): - """VGG 19-layer model (configuration 'E') with batch normalization""" - model = VGG(cfg['E'], batch_norm=True, **kwargs) - return model - - -if __name__ == '__main__': - import torch - - x = torch.randn(2, 3, 32, 32) - net = vgg19_bn(num_classes=100) - feats, logit = net(x, is_feat=True, preact=True) - - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) - - for m in net.get_bn_before_relu(): - if isinstance(m, nn.BatchNorm2d): - print('pass') - else: - print('warning') diff --git a/classification/lib/models/cifar/wrn.py b/classification/lib/models/cifar/wrn.py deleted file mode 100644 index 72a7e10..0000000 --- a/classification/lib/models/cifar/wrn.py +++ /dev/null @@ -1,170 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -""" -Original Author: Wei Yang -""" - -__all__ = ['wrn'] - - -class BasicBlock(nn.Module): - def __init__(self, in_planes, out_planes, stride, dropRate=0.0): - super(BasicBlock, self).__init__() - self.bn1 = nn.BatchNorm2d(in_planes) - self.relu1 = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - 
self.bn2 = nn.BatchNorm2d(out_planes) - self.relu2 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, - padding=1, bias=False) - self.droprate = dropRate - self.equalInOut = (in_planes == out_planes) - self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, - padding=0, bias=False) or None - - def forward(self, x): - if not self.equalInOut: - x = self.relu1(self.bn1(x)) - else: - out = self.relu1(self.bn1(x)) - out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) - if self.droprate > 0: - out = F.dropout(out, p=self.droprate, training=self.training) - out = self.conv2(out) - return torch.add(x if self.equalInOut else self.convShortcut(x), out) - - -class NetworkBlock(nn.Module): - def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): - super(NetworkBlock, self).__init__() - self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) - - def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): - layers = [] - for i in range(nb_layers): - layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) - return nn.Sequential(*layers) - - def forward(self, x): - return self.layer(x) - - -class WideResNet(nn.Module): - def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): - super(WideResNet, self).__init__() - nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] - assert (depth - 4) % 6 == 0, 'depth should be 6n+4' - n = (depth - 4) // 6 - block = BasicBlock - # 1st conv before any network block - self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, - padding=1, bias=False) - # 1st block - self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) - # 2nd block - self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) - # 3rd block - self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) - # global average pooling and classifier - self.bn1 = nn.BatchNorm2d(nChannels[3]) - self.relu = nn.ReLU(inplace=True) - self.fc = nn.Linear(nChannels[3], num_classes) - self.nChannels = nChannels[3] - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.bias.data.zero_() - - def get_feat_modules(self): - feat_m = nn.ModuleList([]) - feat_m.append(self.conv1) - feat_m.append(self.block1) - feat_m.append(self.block2) - feat_m.append(self.block3) - return feat_m - - def get_bn_before_relu(self): - bn1 = self.block2.layer[0].bn1 - bn2 = self.block3.layer[0].bn1 - bn3 = self.bn1 - - return [bn1, bn2, bn3] - - def forward(self, x, is_feat=False, preact=False): - out = self.conv1(x) - f0 = out - out = self.block1(out) - f1 = out - out = self.block2(out) - f2 = out - out = self.block3(out) - f3 = out - out = self.relu(self.bn1(out)) - out = F.avg_pool2d(out, 8) - out = out.view(-1, self.nChannels) - f4 = out - out = self.fc(out) - if is_feat: - if preact: - f1 = self.block2.layer[0].bn1(f1) - f2 = self.block3.layer[0].bn1(f2) - f3 = self.bn1(f3) - return [f0, f1, f2, f3, f4], out - else: - return out - - -def wrn(**kwargs): - """ - Constructs a Wide Residual Networks. 
- """ - model = WideResNet(**kwargs) - return model - - -def wrn_40_2(**kwargs): - model = WideResNet(depth=40, widen_factor=2, **kwargs) - return model - - -def wrn_40_1(**kwargs): - model = WideResNet(depth=40, widen_factor=1, **kwargs) - return model - - -def wrn_16_2(**kwargs): - model = WideResNet(depth=16, widen_factor=2, **kwargs) - return model - - -def wrn_16_1(**kwargs): - model = WideResNet(depth=16, widen_factor=1, **kwargs) - return model - - -if __name__ == '__main__': - import torch - - x = torch.randn(2, 3, 32, 32) - net = wrn_40_2(num_classes=100) - feats, logit = net(x, is_feat=True, preact=True) - - for f in feats: - print(f.shape, f.min().item()) - print(logit.shape) - - for m in net.get_bn_before_relu(): - if isinstance(m, nn.BatchNorm2d): - print('pass') - else: - print('warning') diff --git a/classification/lib/models/darts_model.py b/classification/lib/models/darts_model.py deleted file mode 100644 index 0432b33..0000000 --- a/classification/lib/models/darts_model.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import torch.nn as nn -from .operations import DARTSCell, AuxiliaryHead - - -def gen_darts_model(net_cfg, dataset='imagenet', drop_rate=0., drop_path_rate=0., auxiliary_head=False, **kwargs): - if dataset.lower() == 'imagenet': - dataset = 'imagenet' - elif dataset.lower() in ['cifar', 'cifar10', 'cifar100']: - dataset = 'cifar' - model = DARTSModel(net_cfg, dataset, drop_rate, drop_path_rate, auxiliary_head=auxiliary_head) - return model - - -class DARTSModel(nn.Module): - def __init__(self, net_cfg, dataset='imagenet', drop_rate=0., drop_path_rate=0., auxiliary_head=False): - super(DARTSModel, self).__init__() - self.drop_rate = drop_rate - self.drop_path_rate = drop_path_rate - cell_normal = eval(net_cfg['genotype']['normal']) - cell_reduce = eval(net_cfg['genotype']['reduce']) - init_channels = net_cfg.get('init_channels', 48) - layers = net_cfg.get('layers', 14) - cell_multiplier = net_cfg.get('cell_multiplier', 4) - num_classes = net_cfg.get('num_classes', 1000) - - reduction_layers = [layers // 3, layers * 2 // 3] - C = init_channels - - if dataset == 'imagenet': - C_curr = C - self.stem0 = nn.Sequential( - nn.Conv2d(3, C_curr // 2, kernel_size=3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(C_curr // 2), - nn.ReLU(inplace=True), - nn.Conv2d(C_curr // 2, C_curr, 3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(C_curr), - ) - - self.stem1 = nn.Sequential( - nn.ReLU(inplace=True), - nn.Conv2d(C_curr, C_curr, 3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(C_curr), - ) - elif dataset == 'cifar': - stem_multiplier = 3 - C_curr = C * stem_multiplier - self.stem0 = nn.Sequential( - nn.Conv2d(3, C_curr, 3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(C_curr), - ) - self.stem1 = nn.Identity() - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - - # cell blocks - self.nas_cells = nn.Sequential() - reduction_prev = dataset == 'imagenet' - for layer_idx in range(layers): - s = 1 - cell_arch = cell_normal - if layer_idx in reduction_layers: - s = 2 - C_curr *= 2 - cell_arch = cell_reduce - cell = DARTSCell(cell_arch, C_prev_prev, C_prev, C_curr, stride=s, reduction_prev=reduction_prev) - self.nas_cells.add_module('cell_{}'.format(layer_idx), cell) - reduction_prev = (s == 2) - C_prev_prev, C_prev = C_prev, C_curr * cell_multiplier - if auxiliary_head and layer_idx == 2 * layers // 3: - C_to_auxiliary = C_prev - - self.pool = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - - if auxiliary_head: - 
object.__setattr__(self, 'module_to_auxiliary', cell) - self.auxiliary_head = nn.Sequential( - nn.ReLU(inplace=False), - AuxiliaryHead(C_to_auxiliary, num_classes, avg_pool_stride=2 if dataset=='imagenet' else 3) - ) - - def get_classifier(self): - return self.classifier - - def forward(self, x): - s0 = self.stem0(x) - s1 = self.stem1(s0) - for cell in self.nas_cells: - s0, s1 = s1, cell(s0, s1, self.drop_path_rate) - x = self.pool(s1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = x.view(x.size(0), -1) - return self.classifier(x) - diff --git a/classification/lib/models/dyn_mamba_simple.py b/classification/lib/models/dyn_mamba_simple.py deleted file mode 100644 index 8bd5d81..0000000 --- a/classification/lib/models/dyn_mamba_simple.py +++ /dev/null @@ -1,832 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. - -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -def direct_tokens(x, w=7, w_first=False): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - if w_first: - x = x.view(B, Hg, w, Wg, w, C).permute(0, 3, 1, 4, 2, 5).reshape(B, L, C) - else: - x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - -def reverse_tokens(x, w=7, w_first=False): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - if w_first: - x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C) - else: - x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - - -class DynamicScan(nn.Module): - def __init__(self, dim, hidden_dim=96, window_size=2): - super().__init__() - self.window_size = window_size - self.num_tokens = window_size**2 - self.tokens = nn.Parameter(torch.zeros(1, 1, self.num_tokens, dim)) - - def forward(self, x): - B, L, D = x.shape - x = x.view(B, -1, self.num_tokens, D) - attn = self.tokens.expand(B, x.shape[1], -1, -1) @ x.transpose(-2, -1) # [B, -1, N, N] - # attn = F.gumbel_softmax(attn, hard=True) - attn = attn.softmax(-1) - new_x = (attn @ x).view(B, L, D) - return attn, new_x - - def reverse(self, x, attn): - B, L, D = x.shape - x = x.view(B, -1, self.num_tokens, D) - ori_x = attn.transpose(-2, -1) @ x - return ori_x.view(B, L, D) - - -class MultiScan(nn.Module): - - CHOICES = ('h', 'h_flip', 'v', 'v_flip', 'w2', 'w2_flip', 'w7', 'w7_flip') - - def __init__(self, dim): - super().__init__() - self.choices = MultiScan.CHOICES - self.norms = nn.ModuleList([nn.LayerNorm(dim, elementwise_affine=False) for _ in self.choices]) - self.weights = nn.Parameter(1e-3 * torch.randn(len(self.choices), 1, 1, 1)) - self._iter = 0 - - def forward(self, xs): - weights = 
self.weights.softmax(0) - xs = [norm(x) for norm, x in zip(self.norms, xs)] - xs = torch.stack(xs) * weights - x = xs.sum(0) - if self._iter % 200 == 0 and torch.distributed.get_rank() == 0: - print(weights.detach().view(-1).tolist()) - self._iter += 1 - return x - - def multi_scan(self, x): - """ - Input @x: shape [B, L, D] - """ - xs = [] - for direction in self.choices: - xs.append(self.scan(x, direction)) - return xs - - def multi_reverse(self, xs): - new_xs = [] - for x, direction in zip(xs, self.choices): - new_xs.append(self.reverse(x, direction)) - return new_xs - - def scan(self, x, direction='h'): - """ - Input @x: shape [B, L, D] - Return torch.Tensor: shape [B, L, D] - """ - B, L, D = x.shape - H = W = int(L ** 0.5) - if direction == 'h': - return x - elif direction == 'h_flip': - return x.flip([1]) - elif direction == 'v': - return x.view(B, H, W, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'v_flip': - return x.view(B, H, W, D).transpose(1, 2).reshape(B, L, D).flip([1]) - elif direction == 'w2': - return direct_tokens(x, w=2, w_first=False) - elif direction == 'w2_flip': - return direct_tokens(x, w=2, w_first=False).flip([1]) - elif direction == 'w7': - return direct_tokens(x, w=7, w_first=False) - elif direction == 'w7_flip': - return direct_tokens(x, w=7, w_first=False).flip([1]) - else: - raise RuntimeError(f'Direction {direction} not found.') - - def reverse(self, x, direction='h'): - """ - Input @x: shape [B, L, D] - Return torch.Tensor: shape [B, L, D] - """ - B, L, D = x.shape - H = W = int(L ** 0.5) - if direction == 'h': - return x - elif direction == 'h_flip': - return x.flip([1]) - elif direction == 'v': - return x.view(B, W, H, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'v_flip': - return x.flip([1]).view(B, W, H, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'w2': - return reverse_tokens(x, w=2, w_first=False) - elif direction == 'w2_flip': - return reverse_tokens(x.flip([1]), w=2, w_first=False) - elif direction == 'w7': - return reverse_tokens(x, w=7, w_first=False) - elif direction == 'w7_flip': - return reverse_tokens(x.flip([1]), w=7, w_first=False) - else: - raise RuntimeError(f'Direction {direction} not found.') - - -class WindowScan(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, window_size=2, w_first=False): - B, L, C = x.shape - H = W = int(L ** 0.5) - ctx.shape = (B, C, H, W) - ctx.window_size = window_size - ctx.w_first = w_first - - # H = W = int(L ** 0.5) - # Hg = Wg = H // w - # if w_first: - # x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C) - # else: - # x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return direct_tokens(x, window_size, w_first) - - @staticmethod - def backward(ctx, grad: torch.Tensor): - return reverse_tokens(grad, ctx.window_size, ctx.w_first), None, None - # out: (b, k, d, l) - B, C, H, W = ctx.shape - L = H * W - ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) - y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) - return y.view(B, -1, H, W) - - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - # self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = 
nn.Linear(reduce_channels, in_channels) - # self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - # x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - # s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn #* s_attn # [B, N, C] - return ori_x * attn - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - # self.conv1d = nn.Conv1d( - # in_channels=self.d_inner, - # out_channels=self.d_inner, - # bias=conv_bias, - # kernel_size=d_conv, - # groups=self.d_inner, - # padding=d_conv - 1, - # **factory_kwargs, - # ) - - self.activation = "silu" - self.act = nn.SiLU() - - # self.x_proj = nn.Linear( - # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - # ) - # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # # Initialize special dt projection to preserve variance at initialization - # dt_init_std = self.dt_rank**-0.5 * dt_scale - # if dt_init == "constant": - # nn.init.constant_(self.dt_proj.weight, dt_init_std) - # elif dt_init == "random": - # nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - # else: - # raise NotImplementedError - - # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - # dt = torch.exp( - # torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - # + math.log(dt_min) - # ).clamp(min=dt_init_floor) - # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - # inv_dt = dt + torch.log(-torch.expm1(-dt)) - # with torch.no_grad(): - # self.dt_proj.bias.copy_(inv_dt) - # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # self.dt_proj.bias._no_reinit = True - - - self.multi_scan = MultiScan(self.d_inner) - '''new for search''' - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - for i in range(len(self.multi_scan.choices)): - setattr(self, f'A_log_{i}', nn.Parameter(A_log)) - getattr(self, f'A_log_{i}')._no_weight_decay = True - - conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - setattr(self, f'conv1d_{i}', conv1d) - - x_proj = nn.Linear( - self.d_inner, 
self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - setattr(self, f'x_proj_{i}', x_proj) - - dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - dt_proj.bias._no_reinit = True - - setattr(self, f'dt_proj_{i}', dt_proj) - - D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - D._no_weight_decay = True - setattr(self, f'D_{i}', D) - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - return - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - - - '''c''' - A_c = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_c_log = torch.log(A_c) # Keep A_b_log in fp32 - self.A_c_log = nn.Parameter(A_c_log) - self.A_c_log._no_weight_decay = True - - self.conv1d_c = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_c = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_c = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_c = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_c._no_weight_decay = True - - - '''d''' - A_d = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - 
).contiguous() - A_d_log = torch.log(A_d) # Keep A_b_log in fp32 - self.A_d_log = nn.Parameter(A_d_log) - self.A_d_log._no_weight_decay = True - - self.conv1d_d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_d = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_d = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_d = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_d._no_weight_decay = True - - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - # self.dyn_scan_a = DynamicScan(self.d_inner * 2) - # self.dyn_scan_b = DynamicScan(self.d_inner * 2) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - - xz = self.in_proj(hidden_states) - - # A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - xs = self.multi_scan.multi_scan(xz) - outs = [] - for i, xz in enumerate(xs): - xz = rearrange(xz, "b l d -> b d l") - A = -torch.exp(getattr(self, f'A_log_{i}').float()) - conv1d = getattr(self, f'conv1d_{i}') - x_proj = getattr(self, f'x_proj_{i}') - dt_proj = getattr(self, f'dt_proj_{i}') - D = getattr(self, f'D_{i}') - - out = mamba_inner_fn_no_out_proj( - xz, - conv1d.weight, - conv1d.bias, - x_proj.weight, - dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - D.float(), - delta_bias=dt_proj.bias.float(), - delta_softplus=True, - ) - outs.append(rearrange(out, "b d l -> b l d")) - - outs = self.multi_scan.multi_reverse(outs) - outs = [self.attn(out) for out in outs] - out = self.multi_scan(outs) - out = F.linear(out, self.out_proj.weight, self.out_proj.bias) - - return out - - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - A_c = -torch.exp(self.A_c_log.float()) - A_d = -torch.exp(self.A_d_log.float()) - - xz_w2 = direct_tokens(xz, 2) - # attn_a, xz_a = self.dyn_scan_a(xz_w2) - xz_a = rearrange(xz_w2, "b l d -> b d l") - # xz_b = rearrange(direct_tokens(xz, 2, True), "b l d -> b d l") - # attn_b, xz_b = self.dyn_scan_b(xz_w2) - # xz_b = rearrange(xz_b, "b l d -> b d l") - - xz_b = xz_a.flip([-1]) - xz_c = rearrange(xz, "b l d -> b d l") - xz_d = xz_c.flip([-1]) - # xz_d = rearrange(direct_tokens(xz, 14, True), "b l d -> b d l") - - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - 
delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - out_c = mamba_inner_fn_no_out_proj( - xz_c, - self.conv1d_c.weight, - self.conv1d_c.bias, - self.x_proj_c.weight, - self.dt_proj_c.weight, - A_c, - None, - None, - self.D_c.float(), - delta_bias=self.dt_proj_c.bias.float(), - delta_softplus=True, - ) - out_d = mamba_inner_fn_no_out_proj( - xz_d, - self.conv1d_d.weight, - self.conv1d_d.bias, - self.x_proj_d.weight, - self.dt_proj_d.weight, - A_d, - None, - None, - self.D_d.float(), - delta_bias=self.dt_proj_d.bias.float(), - delta_softplus=True, - ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b.flip([-1]), "b d l -> b l d") # [b l d] - # out = self.dyn_scan_a.reverse(out, attn_a) - out = reverse_tokens(out, 2) - # out_b = self.dyn_scan_b.reverse(out_b, attn_b) - out_b = reverse_tokens(out_b, 2) - out_c = rearrange(out_c, "b d l -> b l d") # [b l d] - out_d = rearrange(out_d.flip([-1]), "b d l -> b l d") # [b l d] - - out = self.attn(out) - out_b = self.attn(out_b) - out_c = self.attn(out_c) - out_d = self.attn(out_d) - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b + out_c + out_d, self.out_proj.weight, self.out_proj.bias) - # out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
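A shape-only sketch of the split performed just below in the slow (non-fused) path, where every token is projected once and the result is cut into the per-step dt and the input-dependent SSM matrices B and C; d_model=192 here is assumed purely for illustration, not taken from the file:

import math
import torch

d_model, d_state = 192, 16                              # d_model is an assumed example value
dt_rank = math.ceil(d_model / 16)                       # the "auto" rule used in __init__ above -> 12
x = torch.randn(4 * 196, dt_rank + 2 * d_state)         # stands in for x_dbl = self.x_proj(...) below
dt, B, C = torch.split(x, [dt_rank, d_state, d_state], dim=-1)
assert dt.shape[-1] == dt_rank and B.shape[-1] == C.shape[-1] == d_state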
- x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # 
dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_bak0214.py b/classification/lib/models/dyn_mamba_simple_bak0214.py deleted file mode 100644 index 1c7b09c..0000000 --- a/classification/lib/models/dyn_mamba_simple_bak0214.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
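The direct_tokens / reverse_tokens helpers in dyn_mamba_simple.py above reorder a square grid of tokens into window-major order and back, and the two are exact inverses of each other. A minimal check, assuming those two functions are imported from that file and using a 4x4 grid with w=2:

import torch

x = torch.arange(16, dtype=torch.float32).view(1, 16, 1)   # one 4x4 token grid, C=1
y = direct_tokens(x, w=2)                                   # tokens visited window by window
print(y.view(4, 4))                                         # each row is one 2x2 window: [0,1,4,5], [2,3,6,7], ...
assert torch.equal(reverse_tokens(y, w=2), x)               # the permutation is exactly invertible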
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - 
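The dt_proj bias initialization a few lines above uses the closed-form inverse of softplus, x = y + log(-expm1(-y)), so that F.softplus(bias) lands back inside [dt_min, dt_max]. A small standalone check of that identity with the default range used above:

import math
import torch
import torch.nn.functional as F

dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4
dt = torch.exp(torch.rand(8) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))            # inverse of softplus
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)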
self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - # dynamic - d_dyn = self.d_model - self.dyn_proj = nn.Sequential( - nn.Linear(self.d_model, d_dyn, bias=True, **factory_kwargs), - nn.LayerNorm(d_dyn), - nn.GELU(), - nn.Linear(d_dyn, d_dyn, bias=True, **factory_kwargs), - ) - self.dyn_params = nn.Parameter(torch.randn(1, 196, d_dyn) * 0.01) - self.dyn_proj_b = nn.Sequential( - nn.Linear(self.d_model, d_dyn, bias=True, **factory_kwargs), - nn.LayerNorm(d_dyn), - nn.GELU(), - nn.Linear(d_dyn, d_dyn, bias=True, **factory_kwargs), - ) - self.dyn_params_b = nn.Parameter(torch.randn(1, 196, d_dyn) * 0.01) - trunc_normal_(self.dyn_params, std=0.02) - trunc_normal_(self.dyn_params_b, std=0.02) - self._iter = 0 - global is_first - if is_first: - self._first = True - is_first = False - else: - self._first = False - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - dyn_xz = hidden_states - new_xz = hidden_states - dyn_xz_proj = self.dyn_proj(dyn_xz) - scale = self.dyn_params.shape[-1] ** -0.5 - probs_a = (self.dyn_params.expand(batch, -1, -1) * scale) @ dyn_xz_proj.transpose(-2, -1) # [b l l] - probs_a = probs_a.softmax(-1) - xz_a = probs_a @ new_xz # [b l d] - xz_a = rearrange( - self.in_proj.weight @ rearrange(xz_a, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - - dyn_xz_proj_b = self.dyn_proj_b(dyn_xz) - probs_b = (self.dyn_params_b.expand(batch, -1, -1) * scale) @ dyn_xz_proj_b.transpose(-2, -1) # [b l l] - probs_b = probs_b.softmax(-1) - xz_b = probs_b @ new_xz # [b l d] - xz_b = rearrange( - self.in_proj.weight @ rearrange(xz_b, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - - # We do matmul and transpose BLH -> HBL at the same time - # xz = rearrange( - # self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), - # "d (b l) -> b d l", - # l=seqlen, - # ) - if self.in_proj.bias is not None: - # xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") - xz_a = xz_a + rearrange(self.in_proj.bias.to(dtype=xz_a.dtype), "d -> d 1") - xz_b = xz_b + rearrange(self.in_proj.bias.to(dtype=xz_b.dtype), "d -> d 1") - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to 
avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - - """ - # dynamic - # print(xz.shape) - # new_xz = rearrange(xz, "b d l -> b l d") - dyn_xz = hidden_states - - weight_a = self.dyn_proj(dyn_xz) - xz_a = weight_a.transpose(-2, -1) * xz - - weight_b = self.dyn_proj_b(dyn_xz) - xz_b = weight_b.transpose(-2, -1) * xz - """ - - if self._first and self._iter % 200 == 0 and torch.distributed.get_rank() == 0: - print(probs_a[0, 5]) - print(probs_b[0, 5]) - self._iter += 1 - """ - dyn_xz_proj = self.dyn_proj(dyn_xz) - # dyn_xz_proj = F.linear(dyn_xz, self.dyn_proj.weight, self.dyn_proj.bias) - scale = self.dyn_params.shape[-1] ** -0.5 - probs_a = (self.dyn_params.expand(batch, -1, -1) * scale) @ dyn_xz_proj.transpose(-2, -1) # [b l l] - probs_a = probs_a.softmax(-1) - # probs_a = F.gumbel_softmax(self.dyn_pred(self.dyn_params).expand(batch, -1, -1), dim=-1, hard=True) - xz_a = probs_a @ new_xz # [b l d] - xz_a = rearrange(xz_a, "b l d -> b d l") - - dyn_xz_proj_b = self.dyn_proj_b(dyn_xz) - # dyn_xz_proj_b = F.linear(dyn_xz, self.dyn_proj_b.weight, self.dyn_proj_b.bias) - probs_b = (self.dyn_params_b.expand(batch, -1, -1) * scale) @ dyn_xz_proj_b.transpose(-2, -1) # [b l l] - probs_b = probs_b.softmax(-1) - # probs_b = F.gumbel_softmax(self.dyn_pred_b(self.dyn_params_b).expand(batch, -1, -1), dim=-1, hard=True) - xz_b = probs_b @ new_xz # [b l d] - xz_b = rearrange(xz_b, "b l d -> b d l") - # print(xz_a.shape) - """ - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", 
"swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = 
(batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_bak0215.py b/classification/lib/models/dyn_mamba_simple_bak0215.py deleted file mode 100644 index ede4ef2..0000000 --- a/classification/lib/models/dyn_mamba_simple_bak0215.py +++ /dev/null @@ -1,530 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
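The bak0214 variant above replaces fixed scan orders with a learned soft routing: a row-stochastic [L, L] matrix gathers tokens before the selective scan, and its transpose scatters the scan outputs back to the original positions. A shape-only sketch with stand-in tensors (every name here is a placeholder, not the module's real parameters):

import torch

B, L, D = 2, 196, 64                                   # 196 tokens (14x14), as assumed by dyn_params
params = torch.randn(1, L, D) * 0.01                   # plays the role of dyn_params
proj_x = torch.randn(B, L, D)                          # plays the role of dyn_proj(hidden_states)
probs = ((params.expand(B, -1, -1) * D ** -0.5) @ proj_x.transpose(-2, -1)).softmax(-1)  # [B, L, L]
x = torch.randn(B, L, D)
gathered = probs @ x                                   # soft re-ordering fed to the scan
scattered = probs.transpose(-2, -1) @ gathered         # routed back to the original token order
assert gathered.shape == scattered.shape == torch.Size([B, L, D])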
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -# def direct_tokens(x, w=7) - - - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, 
device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - # dynamic - d_dyn = self.d_model - self.dyn_proj = nn.Sequential( - nn.Linear(self.d_model * 2, d_dyn, bias=True, **factory_kwargs), - nn.LayerNorm(d_dyn), - nn.GELU(), - nn.Linear(d_dyn, d_dyn, bias=True, **factory_kwargs), - nn.Linear(d_dyn, 196) - ) - self.dyn_proj_b = nn.Sequential( - nn.Linear(self.d_model * 2, d_dyn, bias=True, **factory_kwargs), - nn.LayerNorm(d_dyn), - nn.GELU(), - nn.Linear(d_dyn, d_dyn, bias=True, **factory_kwargs), - nn.Linear(d_dyn, 196) - ) - self.ln_a = nn.LayerNorm(self.d_model) - self.ln_b = nn.LayerNorm(self.d_model) - self._iter = 0 - global is_first - if is_first: - self._first = True - is_first = False - else: - self._first = False - self.dyn_pos_embed = nn.Parameter(torch.zeros(1, 196, self.d_model * 2)) - trunc_normal_(self.dyn_pos_embed, std=0.02) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - global_x = hidden_states.mean(1, keepdim=True) - dyn_xz = torch.cat([hidden_states, global_x.expand(-1, seqlen, -1)], 2) - dyn_xz = dyn_xz + self.dyn_pos_embed - dyn_xz_proj = self.dyn_proj(dyn_xz).softmax(-1) - probs_a = dyn_xz_proj.transpose(-2, -1) - xz_a = probs_a @ hidden_states # [b l d] - xz_a = self.ln_a(xz_a) - # print(hidden_states.shape, probs_a.shape, xz_a.shape, self.in_proj.weight.shape) - xz_a = rearrange( - self.in_proj.weight @ rearrange(xz_a, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - - dyn_xz_proj_b = self.dyn_proj_b(dyn_xz).softmax(-1) - probs_b = dyn_xz_proj_b.transpose(-2, -1) - xz_b = probs_b @ hidden_states # [b l d] - xz_b = self.ln_b(xz_b) - xz_b = rearrange( - self.in_proj.weight @ rearrange(xz_b, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - - # We do matmul and transpose BLH -> HBL at the same time - # xz = rearrange( - # self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), - # "d (b l) -> b d l", - # l=seqlen, - # ) - if self.in_proj.bias is not None: - # xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") - xz_a = xz_a + rearrange(self.in_proj.bias.to(dtype=xz_a.dtype), "d -> d 1") - xz_b = xz_b + rearrange(self.in_proj.bias.to(dtype=xz_b.dtype), "d -> d 1") - - 
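The rearrange expression above folds the in_proj matmul and the [B, L, D] -> [B, D, L] transpose into a single pass, which is what the commented-out block also spells out. A small equivalence check against a plain nn.Linear (bias omitted for brevity; the real layer adds it afterwards as a "d -> d 1" broadcast):

import torch
import torch.nn as nn
from einops import rearrange

B, L, d_model, d_inner = 2, 4, 8, 16
in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
x = torch.randn(B, L, d_model)
xz = rearrange(in_proj.weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=L)
assert torch.allclose(xz, in_proj(x).transpose(1, 2), atol=1e-5)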
A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - - """ - # dynamic - # print(xz.shape) - # new_xz = rearrange(xz, "b d l -> b l d") - dyn_xz = hidden_states - - weight_a = self.dyn_proj(dyn_xz) - xz_a = weight_a.transpose(-2, -1) * xz - - weight_b = self.dyn_proj_b(dyn_xz) - xz_b = weight_b.transpose(-2, -1) * xz - """ - - if self._first and self._iter % 200 == 0 and torch.distributed.get_rank() == 0: - print(probs_a[0, 5]) - print(probs_b[0, 5]) - self._iter += 1 - """ - dyn_xz_proj = self.dyn_proj(dyn_xz) - # dyn_xz_proj = F.linear(dyn_xz, self.dyn_proj.weight, self.dyn_proj.bias) - scale = self.dyn_params.shape[-1] ** -0.5 - probs_a = (self.dyn_params.expand(batch, -1, -1) * scale) @ dyn_xz_proj.transpose(-2, -1) # [b l l] - probs_a = probs_a.softmax(-1) - # probs_a = F.gumbel_softmax(self.dyn_pred(self.dyn_params).expand(batch, -1, -1), dim=-1, hard=True) - xz_a = probs_a @ new_xz # [b l d] - xz_a = rearrange(xz_a, "b l d -> b d l") - - dyn_xz_proj_b = self.dyn_proj_b(dyn_xz) - # dyn_xz_proj_b = F.linear(dyn_xz, self.dyn_proj_b.weight, self.dyn_proj_b.bias) - probs_b = (self.dyn_params_b.expand(batch, -1, -1) * scale) @ dyn_xz_proj_b.transpose(-2, -1) # [b l l] - probs_b = probs_b.softmax(-1) - # probs_b = F.gumbel_softmax(self.dyn_pred_b(self.dyn_params_b).expand(batch, -1, -1), dim=-1, hard=True) - xz_b = probs_b @ new_xz # [b l d] - xz_b = rearrange(xz_b, "b l d -> b d l") - # print(xz_a.shape) - """ - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b, "b d l -> b l d") # [b l d] - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute 
short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return 
conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_bak0216.py b/classification/lib/models/dyn_mamba_simple_bak0216.py deleted file mode 100644 index 67564c0..0000000 --- a/classification/lib/models/dyn_mamba_simple_bak0216.py +++ /dev/null @@ -1,576 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
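The Block class above keeps the Add -> LN -> Mixer ordering described in its docstring: the residual is accumulated first, then normalized, then mixed, and both tensors are carried on to the next block. A minimal sketch of the non-fused path with a hypothetical toy mixer (a plain Linear, not the real DynMamba):

import torch
import torch.nn as nn

norm, mixer = nn.LayerNorm(8), nn.Linear(8, 8)      # stand-ins for norm_cls / mixer_cls
hidden, residual = torch.randn(2, 4, 8), None
for _ in range(3):                                  # three stacked blocks
    residual = hidden if residual is None else hidden + residual   # Add
    hidden = mixer(norm(residual))                                 # LN -> Mixer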
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -def direct_tokens(x, w=7, pos_direction=True): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - -def reverse_tokens(x, w=7): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - # self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, in_channels) - # self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - # x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - # s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn #* s_attn # [B, N, C] - return ori_x * attn - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state 
* 2, bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - - - '''c''' - A_c = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_c_log = torch.log(A_c) # Keep A_b_log in fp32 - self.A_c_log = nn.Parameter(A_c_log) - self.A_c_log._no_weight_decay = True - - self.conv1d_c = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_c = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_c = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_c = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_c._no_weight_decay = True - - - '''d''' - A_d = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_d_log = torch.log(A_d) # Keep A_b_log in fp32 - self.A_d_log = nn.Parameter(A_d_log) - self.A_d_log._no_weight_decay = True - - self.conv1d_d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_d = 
nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_d = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_d = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_d._no_weight_decay = True - - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - - xz = self.in_proj(hidden_states) - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - A_c = -torch.exp(self.A_c_log.float()) - A_d = -torch.exp(self.A_d_log.float()) - - xz_a = rearrange(direct_tokens(xz, 2), "b l d -> b d l") - xz_b = xz_a.flip([-1]) - xz_c = rearrange(xz, "b l d -> b d l") - xz_d = xz_c.flip([-1]) - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - out_c = mamba_inner_fn_no_out_proj( - xz_c, - self.conv1d_c.weight, - self.conv1d_c.bias, - self.x_proj_c.weight, - self.dt_proj_c.weight, - A_c, - None, - None, - self.D_c.float(), - delta_bias=self.dt_proj_c.bias.float(), - delta_softplus=True, - ) - out_d = mamba_inner_fn_no_out_proj( - xz_d, - self.conv1d_d.weight, - self.conv1d_d.bias, - self.x_proj_d.weight, - self.dt_proj_d.weight, - A_d, - None, - None, - self.D_d.float(), - delta_bias=self.dt_proj_d.bias.float(), - delta_softplus=True, - ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b.flip([-1]), "b d l -> b l d") # [b l d] - out = reverse_tokens(out, 2) - out_b = reverse_tokens(out_b, 2) - out_c = rearrange(out_c, "b d l -> b l d") # [b l d] - out_d = rearrange(out_d.flip([-1]), "b d l -> b l d") # [b l d] - - out = self.attn(out) - out_b = self.attn(out_b) - out_c = self.attn(out_c) - out_d = self.attn(out_d) - - # 
F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b + out_c + out_d, self.out_proj.weight, self.out_proj.bias) - # out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = 
selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_bak0217.py b/classification/lib/models/dyn_mamba_simple_bak0217.py deleted file mode 100644 index 3f9bc1f..0000000 --- a/classification/lib/models/dyn_mamba_simple_bak0217.py +++ /dev/null @@ -1,588 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. - -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -def direct_tokens(x, w=2, pos_direction=True): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - xs = x.chunk(4, dim=2) - new_xs = [] - for idx, x in enumerate(xs): - if idx % 2 == 0: - x = x.view(B, Hg, w, Wg, w, -1).permute(0, 1, 3, 2, 4, 5).reshape(B, L, -1) - new_xs.append(x) - x = torch.cat(new_xs, 2) - return x - -def reverse_tokens(x, w=2): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - xs = x.chunk(4, dim=2) - new_xs = [] - for idx, x in enumerate(xs): - if idx % 2 == 0: - x = x.view(B, Hg, Wg, w, w, -1).permute(0, 1, 3, 2, 4, 5).reshape(B, L, -1) - new_xs.append(x) - x = torch.cat(new_xs, 2) - return x - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.25, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, in_channels) - self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - x_local = 
self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn * s_attn # [B, N, C] - return ori_x * attn - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, 
**factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - - - '''c''' - A_c = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_c_log = torch.log(A_c) # Keep A_b_log in fp32 - self.A_c_log = nn.Parameter(A_c_log) - self.A_c_log._no_weight_decay = True - - self.conv1d_c = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_c = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_c = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_c = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_c._no_weight_decay = True - - - '''d''' - A_d = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_d_log = torch.log(A_d) # Keep A_b_log in fp32 - self.A_d_log = nn.Parameter(A_d_log) - self.A_d_log._no_weight_decay = True - - self.conv1d_d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_d = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_d = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_d = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_d._no_weight_decay = True - - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - - xz = self.in_proj(hidden_states) - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - # A_c = -torch.exp(self.A_c_log.float()) - # A_d = -torch.exp(self.A_d_log.float()) - - xz_a = rearrange(direct_tokens(xz, 2), "b l d -> b d l") - xz_b = xz_a.flip([-1]) - # xz_c = rearrange(xz, "b l d -> b d l") - # xz_d = xz_c.flip([-1]) - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - 
delta_softplus=True, - ) - # out_c = mamba_inner_fn_no_out_proj( - # xz_c, - # self.conv1d_c.weight, - # self.conv1d_c.bias, - # self.x_proj_c.weight, - # self.dt_proj_c.weight, - # A_c, - # None, - # None, - # self.D_c.float(), - # delta_bias=self.dt_proj_c.bias.float(), - # delta_softplus=True, - # ) - # out_d = mamba_inner_fn_no_out_proj( - # xz_d, - # self.conv1d_d.weight, - # self.conv1d_d.bias, - # self.x_proj_d.weight, - # self.dt_proj_d.weight, - # A_d, - # None, - # None, - # self.D_d.float(), - # delta_bias=self.dt_proj_d.bias.float(), - # delta_softplus=True, - # ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b.flip([-1]), "b d l -> b l d") # [b l d] - out = reverse_tokens(out, 2) - out_b = reverse_tokens(out_b) - # out_c = rearrange(out_c, "b d l -> b l d") # [b l d] - # out_d = rearrange(out_d.flip([-1]), "b d l -> b l d") # [b l d] - - out = self.attn(out) - out_b = self.attn(out_b) - # out_c = self.attn(out_c) - # out_d = self.attn(out_d) - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = F.linear(out + out_b + out_c + out_d, self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
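# A standalone sketch (not from the deleted files) of the layout step the
# comment above describes: x_proj produces (dt, B, C) flattened over
# batch*length, and everything is rearranged so that d is the slowest-moving
# and L the fastest-moving dimension before the selective scan; dt_proj's bias
# is left out here because the kernel receives it separately as delta_bias.
# All sizes are toy values and x_proj/dt_proj are fresh layers that only
# mirror the module attributes by name.
import torch
from einops import rearrange

batch, d_inner, seqlen, dt_rank, d_state = 2, 8, 16, 4, 16
x = torch.randn(batch, d_inner, seqlen)                       # post-conv activations, (B, D, L)
x_proj = torch.nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)
dt_proj = torch.nn.Linear(dt_rank, d_inner, bias=True)

x_dbl = x_proj(rearrange(x, "b d l -> (b l) d"))              # (B*L, dt_rank + 2*d_state)
dt, B, C = torch.split(x_dbl, [dt_rank, d_state, d_state], dim=-1)
dt = dt_proj.weight @ dt.t()                                  # (d_inner, B*L)
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)              # (B, D, L)
B = rearrange(B, "(b l) n -> b n l", l=seqlen).contiguous()   # (B, d_state, L)
C = rearrange(C, "(b l) n -> b n l", l=seqlen).contiguous()   # (B, d_state, L)
assert dt.shape == (batch, d_inner, seqlen) and B.shape == (batch, d_state, seqlen)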
- x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # 
dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_bak0218.py b/classification/lib/models/dyn_mamba_simple_bak0218.py deleted file mode 100644 index 0067245..0000000 --- a/classification/lib/models/dyn_mamba_simple_bak0218.py +++ /dev/null @@ -1,591 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
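# A standalone sketch (not from the deleted files) of the per-layer decoding
# cache these modules rely on: each mixer keeps a depthwise-conv state of
# shape (B, d_inner, d_conv) and an SSM state of shape (B, d_inner, d_state),
# stored under its layer_idx in inference_params.key_value_memory_dict and
# updated in place by step(). SimpleInferenceParams is a hypothetical
# stand-in for the real inference-params container.
from dataclasses import dataclass, field

import torch


@dataclass
class SimpleInferenceParams:
    seqlen_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)


def get_states(params, layer_idx, batch, d_inner, d_conv, d_state,
               device="cpu", dtype=torch.float32):
    # Allocate zero states on first use; later calls return the cached tensors.
    if layer_idx not in params.key_value_memory_dict:
        conv_state = torch.zeros(batch, d_inner, d_conv, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch, d_inner, d_state, device=device, dtype=dtype)
        params.key_value_memory_dict[layer_idx] = (conv_state, ssm_state)
    return params.key_value_memory_dict[layer_idx]


params = SimpleInferenceParams()
conv_state, ssm_state = get_states(params, layer_idx=0, batch=2, d_inner=384, d_conv=4, d_state=16)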
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -def direct_tokens(x, w=7, pos_direction=True): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - -def reverse_tokens(x, w=7): - B, L, C = x.shape - H = W = int(L ** 0.5) - Hg = Wg = H // w - x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, in_channels) - self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn * s_attn # [B, N, C] - return ori_x * attn - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none" - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, 
bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - - d_select = 192 - self.window_select = nn.Sequential( - nn.Linear(self.d_model, d_select), - nn.LayerNorm(d_select), - nn.GELU(), - nn.Linear(d_select, 2), - ) - - - # '''c''' - # A_c = repeat( - # torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - # "n -> d n", - # d=self.d_inner, - # ).contiguous() - # A_c_log = torch.log(A_c) # Keep A_b_log in fp32 - # self.A_c_log = nn.Parameter(A_c_log) - # self.A_c_log._no_weight_decay = True - - # self.conv1d_c = nn.Conv1d( - # in_channels=self.d_inner, - # out_channels=self.d_inner, - # bias=conv_bias, - # kernel_size=d_conv, - # groups=self.d_inner, - # padding=d_conv - 1, - # **factory_kwargs, - # ) - - # self.x_proj_c = nn.Linear( - # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - # ) - # self.dt_proj_c = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # self.D_c = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - # self.D_c._no_weight_decay = True - - - # '''d''' - # A_d = repeat( - # torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - # "n -> d n", - # d=self.d_inner, - # ).contiguous() - # A_d_log = torch.log(A_d) # Keep A_b_log in fp32 - # self.A_d_log = nn.Parameter(A_d_log) - # self.A_d_log._no_weight_decay = 
True - - # self.conv1d_d = nn.Conv1d( - # in_channels=self.d_inner, - # out_channels=self.d_inner, - # bias=conv_bias, - # kernel_size=d_conv, - # groups=self.d_inner, - # padding=d_conv - 1, - # **factory_kwargs, - # ) - - # self.x_proj_d = nn.Linear( - # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - # ) - # self.dt_proj_d = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # self.D_d = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - # self.D_d._no_weight_decay = True - - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - - xz = self.in_proj(hidden_states) - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - # A_c = -torch.exp(self.A_c_log.float()) - # A_d = -torch.exp(self.A_d_log.float()) - - - # xz_b = xz_a.flip([-1]) - # xz_d = xz_c.flip([-1]) - - H = W = int(seqlen ** 0.5) - select = hidden_states.view(batch, H // 2, 2, W, -1).mean([2, 3]) # [B, H // 2, D] - select = self.window_select(select) - # select = F.gumbel_softmax(select, hard=True).view(batch, H // 2, 1, 1, -1) # [B, H // 2, 1, 1, 2] - select = F.softmax(select).view(batch, H // 2, 1, 1, -1) # [B, H // 2, 1, 1, 2] - select = select.repeat(1, 1, 2, W, 1).view(batch, seqlen, -1) - xz_a = direct_tokens(xz, 2) * select[:, :, :1] + xz * select[:, :, 1:2] - - xz_a = rearrange(xz_a, "b l d -> b d l") - xz_b = xz_a.flip([-1]) - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - # out_c = mamba_inner_fn_no_out_proj( - # xz_c, - # self.conv1d_c.weight, - # self.conv1d_c.bias, - # self.x_proj_c.weight, - # self.dt_proj_c.weight, - # A_c, - # None, - # None, - # self.D_c.float(), - # delta_bias=self.dt_proj_c.bias.float(), - # delta_softplus=True, - # ) - # out_d = mamba_inner_fn_no_out_proj( - # xz_d, - # self.conv1d_d.weight, - # self.conv1d_d.bias, - # self.x_proj_d.weight, - # self.dt_proj_d.weight, - # A_d, - # None, - # None, - # self.D_d.float(), - # delta_bias=self.dt_proj_d.bias.float(), - # delta_softplus=True, - # ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, 
"b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b.flip([-1]), "b d l -> b l d") # [b l d] - # out = reverse_tokens(out, 2) - # out_b = reverse_tokens(out_b) - - out = reverse_tokens(out, 2) * select[:, :, :1] + out * select[:, :, 1:2] - out_b = reverse_tokens(out_b) * select[:, :, :1] + out_b * select[:, :, 1:2] - # out_c = rearrange(out_c, "b d l -> b l d") # [b l d] - # out_d = rearrange(out_d.flip([-1]), "b d l -> b l d") # [b l d] - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = F.linear(out + out_b + out_c + out_d, self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
- x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # 
dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/dyn_mamba_simple_search.py b/classification/lib/models/dyn_mamba_simple_search.py deleted file mode 100644 index c8dc9cb..0000000 --- a/classification/lib/models/dyn_mamba_simple_search.py +++ /dev/null @@ -1,856 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
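# A standalone sketch (not from the deleted files) of the window-major scan
# order that direct_tokens/reverse_tokens implement throughout these
# variants: a row-major (B, H*W, C) sequence is regrouped so each w x w
# window is visited contiguously, and the inverse restores row-major order.
# This simplified version assumes H and W are divisible by w; the deleted
# search variant additionally pads the grid and offers a w_first ordering.
import torch


def window_scan(x, w, H, W):
    B, L, C = x.shape
    Hg, Wg = H // w, W // w
    return (x.view(B, Hg, w, Wg, w, C)
             .permute(0, 1, 3, 2, 4, 5)   # (B, Hg, Wg, w, w, C): windows outer, pixels inner
             .reshape(B, L, C))


def window_unscan(x, w, H, W):
    B, L, C = x.shape
    Hg, Wg = H // w, W // w
    return (x.view(B, Hg, Wg, w, w, C)
             .permute(0, 1, 3, 2, 4, 5)   # back to the row-major (B, Hg, w, Wg, w, C) grid
             .reshape(B, L, C))


tokens = torch.arange(14 * 14, dtype=torch.float32).view(1, 196, 1)
assert torch.equal(window_unscan(window_scan(tokens, 7, 14, 14), 7, 14, 14), tokens)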
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat -from timm.models.layers import trunc_normal_ -import logging - - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - - -def direct_tokens(x, w=7, H=14, W=14, w_first=False): - B, L, C = x.shape - x = x.view(B, H, W, C) - Hg, Wg = math.ceil(H / w), math.ceil(W / w) - if H % w != 0 or W % w != 0: - newH, newW = Hg * w, Wg * w - x = F.pad(x, (0, 0, 0, newW - W, 0, newH - H)) - if w_first: - x = x.view(B, Hg, w, Wg, w, C).permute(0, 3, 1, 4, 2, 5).reshape(B, -1, C) - else: - x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, -1, C) - return x - -def reverse_tokens(x, w=7, H=14, W=14, w_first=False): - B, L, C = x.shape - Hg, Wg = math.ceil(H / w), math.ceil(W / w) - if H % w != 0 or W % w != 0: - if w_first: - x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, Hg * w, Wg * w, C) - else: - x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, Hg * w, Wg * w, C) - x = x[:, :H, :W].reshape(B, -1, C) - else: - if w_first: - x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C) - else: - x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return x - - -class DynamicScan(nn.Module): - def __init__(self, dim, hidden_dim=96, window_size=2): - super().__init__() - self.window_size = window_size - self.num_tokens = window_size**2 - self.tokens = nn.Parameter(torch.zeros(1, 1, self.num_tokens, dim)) - - def forward(self, x): - B, L, D = x.shape - x = x.view(B, -1, self.num_tokens, D) - attn = self.tokens.expand(B, x.shape[1], -1, -1) @ x.transpose(-2, -1) # [B, -1, N, N] - # attn = F.gumbel_softmax(attn, hard=True) - attn = attn.softmax(-1) - new_x = (attn @ x).view(B, L, D) - return attn, new_x - - def reverse(self, x, attn): - B, L, D = x.shape - x = x.view(B, -1, self.num_tokens, D) - ori_x = attn.transpose(-2, -1) @ x - return ori_x.view(B, L, D) - - -class MultiScan(nn.Module): - - ALL_CHOICES = ('h', 'h_flip', 'v', 'v_flip', 'w2', 'w2_flip', 'w7', 'w7_flip') - - def __init__(self, dim, choices=None, token_size=(14, 14)): - super().__init__() - self.token_size = token_size - if choices is None: - self.choices = MultiScan.ALL_CHOICES - self.norms = nn.ModuleList([nn.LayerNorm(dim, elementwise_affine=False) for _ in self.choices]) - self.weights = nn.Parameter(1e-3 * torch.randn(len(self.choices), 1, 1, 1)) - self._iter = 0 - self.logger = logging.getLogger() - self.search = True - else: - self.choices = choices - self.search = False - - def forward(self, xs): - if self.search: - weights = self.weights.softmax(0) - xs = [norm(x) for norm, x in zip(self.norms, xs)] - xs = torch.stack(xs) * weights - x = xs.sum(0) - if self._iter % 200 == 0 and 
torch.distributed.get_rank() == 0: - self.logger.info(str(weights.detach().view(-1).tolist())) - self._iter += 1 - else: - x = torch.stack(xs).sum(0) - return x - - def multi_scan(self, x): - """ - Input @x: shape [B, L, D] - """ - xs = [] - for direction in self.choices: - xs.append(self.scan(x, direction)) - return xs - - def multi_reverse(self, xs): - new_xs = [] - for x, direction in zip(xs, self.choices): - new_xs.append(self.reverse(x, direction)) - return new_xs - - def scan(self, x, direction='h'): - """ - Input @x: shape [B, L, D] - Return torch.Tensor: shape [B, L, D] - """ - B, L, D = x.shape - H, W = self.token_size - if direction == 'h': - return x - elif direction == 'h_flip': - return x.flip([1]) - elif direction == 'v': - return x.view(B, H, W, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'v_flip': - return x.view(B, H, W, D).transpose(1, 2).reshape(B, L, D).flip([1]) - elif direction == 'w2': - return direct_tokens(x, w=2, H=H, W=W, w_first=False) - elif direction == 'w2_flip': - return direct_tokens(x, w=2, H=H, W=W, w_first=False).flip([1]) - elif direction == 'w7': - return direct_tokens(x, w=7, H=H, W=W, w_first=False) - elif direction == 'w7_flip': - return direct_tokens(x, w=7, H=H, W=W, w_first=False).flip([1]) - else: - raise RuntimeError(f'Direction {direction} not found.') - - def reverse(self, x, direction='h'): - """ - Input @x: shape [B, L, D] - Return torch.Tensor: shape [B, L, D] - """ - B, L, D = x.shape - H, W = self.token_size - if direction == 'h': - return x - elif direction == 'h_flip': - return x.flip([1]) - elif direction == 'v': - return x.view(B, W, H, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'v_flip': - return x.flip([1]).view(B, W, H, D).transpose(1, 2).reshape(B, L, D) - elif direction == 'w2': - return reverse_tokens(x, w=2, H=H, W=W, w_first=False) - elif direction == 'w2_flip': - return reverse_tokens(x.flip([1]), w=2, H=H, W=W, w_first=False) - elif direction == 'w7': - return reverse_tokens(x, w=7, H=H, W=W, w_first=False) - elif direction == 'w7_flip': - return reverse_tokens(x.flip([1]), w=7, H=H, W=W, w_first=False) - else: - raise RuntimeError(f'Direction {direction} not found.') - - -class WindowScan(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, window_size=2, w_first=False): - B, L, C = x.shape - H = W = int(L ** 0.5) - ctx.shape = (B, C, H, W) - ctx.window_size = window_size - ctx.w_first = w_first - - # H = W = int(L ** 0.5) - # Hg = Wg = H // w - # if w_first: - # x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C) - # else: - # x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) - return direct_tokens(x, window_size, w_first) - - @staticmethod - def backward(ctx, grad: torch.Tensor): - return reverse_tokens(grad, ctx.window_size, ctx.w_first), None, None - # out: (b, k, d, l) - B, C, H, W = ctx.shape - L = H * W - ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) - y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) - return y.view(B, -1, H, W) - - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - # self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, 
in_channels) - # self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - # x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - # s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn #* s_attn # [B, N, C] - return ori_x * attn - - -is_first = True -class DynMamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none", - directions=None, - token_size=(14, 14), - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba_type = bimamba_type - self.token_size = token_size - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - # self.conv1d = nn.Conv1d( - # in_channels=self.d_inner, - # out_channels=self.d_inner, - # bias=conv_bias, - # kernel_size=d_conv, - # groups=self.d_inner, - # padding=d_conv - 1, - # **factory_kwargs, - # ) - - self.activation = "silu" - self.act = nn.SiLU() - - # self.x_proj = nn.Linear( - # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - # ) - # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # # Initialize special dt projection to preserve variance at initialization - # dt_init_std = self.dt_rank**-0.5 * dt_scale - # if dt_init == "constant": - # nn.init.constant_(self.dt_proj.weight, dt_init_std) - # elif dt_init == "random": - # nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - # else: - # raise NotImplementedError - - # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - # dt = torch.exp( - # torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - # + math.log(dt_min) - # ).clamp(min=dt_init_floor) - # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - # inv_dt = dt + torch.log(-torch.expm1(-dt)) - # with torch.no_grad(): - # self.dt_proj.bias.copy_(inv_dt) - # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # self.dt_proj.bias._no_reinit = True - - - self.multi_scan = MultiScan(self.d_inner, choices=directions, token_size=token_size) - '''new for search''' - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - for i in range(len(self.multi_scan.choices)): - setattr(self, f'A_log_{i}', nn.Parameter(A_log)) - getattr(self, f'A_log_{i}')._no_weight_decay = True - - conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, 
- ) - setattr(self, f'conv1d_{i}', conv1d) - - x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - setattr(self, f'x_proj_{i}', x_proj) - - dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - dt_proj.bias._no_reinit = True - - setattr(self, f'dt_proj_{i}', dt_proj) - - D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - D._no_weight_decay = True - setattr(self, f'D_{i}', D) - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - return - - # S4D real initialization - A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - assert bimamba_type == "v2" - - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - - - '''c''' - A_c = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_c_log = torch.log(A_c) # Keep A_b_log in fp32 - self.A_c_log = nn.Parameter(A_c_log) - self.A_c_log._no_weight_decay = True - - self.conv1d_c = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_c = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_c = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_c = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_c._no_weight_decay = True - - - '''d''' - A_d = repeat( - torch.arange(1, self.d_state + 
1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_d_log = torch.log(A_d) # Keep A_b_log in fp32 - self.A_d_log = nn.Parameter(A_d_log) - self.A_d_log._no_weight_decay = True - - self.conv1d_d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_d = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_d = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_d = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_d._no_weight_decay = True - - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - self.attn = BiAttn(self.d_inner) - - # self.dyn_scan_a = DynamicScan(self.d_inner * 2) - # self.dyn_scan_b = DynamicScan(self.d_inner * 2) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - - xz = self.in_proj(hidden_states) - - # A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - xs = self.multi_scan.multi_scan(xz) - outs = [] - for i, xz in enumerate(xs): - xz = rearrange(xz, "b l d -> b d l") - A = -torch.exp(getattr(self, f'A_log_{i}').float()) - conv1d = getattr(self, f'conv1d_{i}') - x_proj = getattr(self, f'x_proj_{i}') - dt_proj = getattr(self, f'dt_proj_{i}') - D = getattr(self, f'D_{i}') - - out = mamba_inner_fn_no_out_proj( - xz, - conv1d.weight, - conv1d.bias, - x_proj.weight, - dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - D.float(), - delta_bias=dt_proj.bias.float(), - delta_softplus=True, - ) - outs.append(rearrange(out, "b d l -> b l d")) - - outs = self.multi_scan.multi_reverse(outs) - outs = [self.attn(out) for out in outs] - out = self.multi_scan(outs) - out = F.linear(out, self.out_proj.weight, self.out_proj.bias) - - return out - - if self.bimamba_type == "v2": - A_b = -torch.exp(self.A_b_log.float()) - A_c = -torch.exp(self.A_c_log.float()) - A_d = -torch.exp(self.A_d_log.float()) - - xz_w2 = direct_tokens(xz, 2) - # attn_a, xz_a = self.dyn_scan_a(xz_w2) - xz_a = rearrange(xz_w2, "b l d -> b d l") - # xz_b = rearrange(direct_tokens(xz, 2, True), "b l d -> b d l") - # attn_b, xz_b = self.dyn_scan_b(xz_w2) - # xz_b = rearrange(xz_b, "b l d -> b d l") - - xz_b = xz_a.flip([-1]) - xz_c = rearrange(xz, "b l d -> b d l") - xz_d = xz_c.flip([-1]) - # xz_d = rearrange(direct_tokens(xz, 14, True), "b l d -> b d l") - - - out = mamba_inner_fn_no_out_proj( - xz_a, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - # print(out.shape) - out_b = mamba_inner_fn_no_out_proj( - xz_b, - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, 
- self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - out_c = mamba_inner_fn_no_out_proj( - xz_c, - self.conv1d_c.weight, - self.conv1d_c.bias, - self.x_proj_c.weight, - self.dt_proj_c.weight, - A_c, - None, - None, - self.D_c.float(), - delta_bias=self.dt_proj_c.bias.float(), - delta_softplus=True, - ) - out_d = mamba_inner_fn_no_out_proj( - xz_d, - self.conv1d_d.weight, - self.conv1d_d.bias, - self.x_proj_d.weight, - self.dt_proj_d.weight, - A_d, - None, - None, - self.D_d.float(), - delta_bias=self.dt_proj_d.bias.float(), - delta_softplus=True, - ) - - # out = F.linear(rearrange(out + out_b, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - # out = rearrange(out, "b d l -> b l 1 d") * rearrange(probs_a, "b l k -> b l k 1") # [b l k d] - # out = out.transpose(1, 2).sum(2) # [b k d] - # out_b = rearrange(out_b, "b d l -> b l 1 d") * rearrange(probs_b, "b l k -> b l k 1") # [b l k d] - # out_b = out_b.transpose(1, 2).sum(2) # [b k d] - # out = probs_a.transpose(-2, -1) @ rearrange(out, "b d l -> b l d") # [b l d] - # out_b = probs_b.transpose(-2, -1) @ rearrange(out_b, "b d l -> b l d") # [b l d] - out = rearrange(out, "b d l -> b l d") # [b l d] - out_b = rearrange(out_b.flip([-1]), "b d l -> b l d") # [b l d] - # out = self.dyn_scan_a.reverse(out, attn_a) - out = reverse_tokens(out, 2) - # out_b = self.dyn_scan_b.reverse(out_b, attn_b) - out_b = reverse_tokens(out_b, 2) - out_c = rearrange(out_c, "b d l -> b l d") # [b l d] - out_d = rearrange(out_d.flip([-1]), "b d l -> b l d") # [b l d] - - out = self.attn(out) - out_b = self.attn(out_b) - out_c = self.attn(out_c) - out_d = self.attn(out_d) - - # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - # out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - out = F.linear(out + out_b + out_c + out_d, self.out_proj.weight, self.out_proj.bias) - # out = F.linear(out + out_b, self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
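# Shape bookkeeping for the split/rearranges below: x_dbl is (B*L, dt_rank + 2*d_state);
# dt becomes (B, d_inner, L) after dt_proj, and the input-dependent B and C are each (B, d_state, L),
# i.e. d is the slowest-moving and L the fastest-moving dimension, as selective_scan_fn expects.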
- x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # 
dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) \ No newline at end of file diff --git a/classification/lib/models/lightvit.py b/classification/lib/models/lightvit.py deleted file mode 100644 index 5bcd793..0000000 --- a/classification/lib/models/lightvit.py +++ /dev/null @@ -1,513 +0,0 @@ -import math -import torch -import torch.nn as nn -from functools import partial - -from timm.models.layers import DropPath, trunc_normal_, lecun_normal_ -from timm.models.registry import register_model - - -class ConvStem(nn.Module): - """ Image to Patch Embedding - """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - self.patch_size = patch_size - stem_dim = embed_dim // 2 - self.stem = nn.Sequential( - nn.Conv2d(in_chans, stem_dim, kernel_size=3, - stride=2, padding=1, bias=False), - nn.BatchNorm2d(stem_dim), - nn.GELU(), - nn.Conv2d(stem_dim, stem_dim, kernel_size=3, - groups=stem_dim, stride=1, padding=1, bias=False), - nn.BatchNorm2d(stem_dim), - nn.GELU(), - nn.Conv2d(stem_dim, 
stem_dim, kernel_size=3, - groups=stem_dim, stride=1, padding=1, bias=False), - nn.BatchNorm2d(stem_dim), - nn.GELU(), - nn.Conv2d(stem_dim, stem_dim, kernel_size=3, - groups=stem_dim, stride=2, padding=1, bias=False), - nn.BatchNorm2d(stem_dim), - nn.GELU(), - ) - self.proj = nn.Conv2d(stem_dim, embed_dim, - kernel_size=3, - stride=2, padding=1) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - x = self.proj(self.stem(x)) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - return x, (H, W) - - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.25, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, in_channels) - self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean(1, keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn * s_attn # [B, N, C] - return ori_x * attn - - -class BiAttnMlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.attn = BiAttn(out_features) - self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.attn(x) - x = self.drop(x) - return x - - -def window_reverse( - windows: torch.Tensor, - original_size, - window_size=(7, 7) -) -> torch.Tensor: - """ Reverses the window partition. - Args: - windows (torch.Tensor): Window tensor of the shape [B * windows, window_size[0] * window_size[1], C]. - original_size (Tuple[int, int]): Original shape. - window_size (Tuple[int, int], optional): Window size which have been applied. Default (7, 7) - Returns: - output (torch.Tensor): Folded output tensor of the shape [B, original_size[0] * original_size[1], C]. - """ - # Get height and width - H, W = original_size - # Compute original batch size - B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) - # Fold grid tensor - output = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) - output = output.permute(0, 1, 3, 2, 4, 5).reshape(B, H * W, -1) - return output - - -def get_relative_position_index( - win_h: int, - win_w: int -) -> torch.Tensor: - """ Function to generate pair-wise relative position index for each token inside the window. - Taken from Timms Swin V1 implementation. - Args: - win_h (int): Window/Grid height. - win_w (int): Window/Grid width. - Returns: - relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width]. 
- """ - coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += win_h - 1 - relative_coords[:, :, 1] += win_w - 1 - relative_coords[:, :, 0] *= 2 * win_w - 1 - return relative_coords.sum(-1) - - -class LightViTAttention(nn.Module): - def __init__(self, dim, num_tokens=1, num_heads=8, window_size=7, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.num_tokens = num_tokens - self.window_size = window_size - self.attn_area = window_size * window_size - self.scale = qk_scale or head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.kv_global = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity() - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity() - - # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) - - # Get pair-wise relative position index for each token inside the window - self.register_buffer("relative_position_index", get_relative_position_index(window_size, - window_size).view(-1)) - # Init relative positional bias - trunc_normal_(self.relative_position_bias_table, std=.02) - - def _get_relative_positional_bias( - self - ) -> torch.Tensor: - """ Returns the relative positional bias. - Returns: - relative_position_bias (torch.Tensor): Relative positional bias. 
- """ - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index].view(self.attn_area, self.attn_area, -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - return relative_position_bias.unsqueeze(0) - - def forward_global_aggregation(self, q, k, v): - """ - q: global tokens - k: image tokens - v: image tokens - """ - B, _, N, _ = q.shape - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - return x - - def forward_local(self, q, k, v, H, W): - """ - q: image tokens - k: image tokens - v: image tokens - """ - B, num_heads, N, C = q.shape - ws = self.window_size - h_group, w_group = H // ws, W // ws - - # partition to windows - q = q.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() - q = q.view(-1, num_heads, ws*ws, C) - k = k.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() - k = k.view(-1, num_heads, ws*ws, C) - v = v.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() - v = v.view(-1, num_heads, ws*ws, v.shape[-1]) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - pos_bias = self._get_relative_positional_bias() - attn = (attn + pos_bias).softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(v.shape[0], ws*ws, -1) - - # reverse - x = window_reverse(x, (H, W), (ws, ws)) - return x - - def forward_global_broadcast(self, q, k, v): - """ - q: image tokens - k: global tokens - v: global tokens - """ - B, num_heads, N, _ = q.shape - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - return x - - def forward(self, x, H, W): - B, N, C = x.shape - NT = self.num_tokens - # qkv - qkv = self.qkv(x) - q, k, v = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) - - # split img tokens & global tokens - q_img, k_img, v_img = q[:, :, NT:], k[:, :, NT:], v[:, :, NT:] - q_glb, _, _ = q[:, :, :NT], k[:, :, :NT], v[:, :, :NT] - - # local window attention - x_img = self.forward_local(q_img, k_img, v_img, H, W) - - # global aggregation - x_glb = self.forward_global_aggregation(q_glb, k_img, v_img) - - # global broadcast - k_glb, v_glb = self.kv_global(x_glb).view(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) - - x_img = x_img + self.forward_global_broadcast(q_img, k_glb, v_glb) - x = torch.cat([x_glb, x_img], dim=1) - x = self.proj(x) - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, num_tokens=1, window_size=7, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention=LightViTAttention): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = attention(dim, num_heads=num_heads, num_tokens=num_tokens, window_size=window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = BiAttnMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class ResidualMergePatch(nn.Module): - def __init__(self, dim, out_dim, num_tokens=1): - super().__init__() - self.num_tokens = num_tokens - self.norm = nn.LayerNorm(4 * dim) - self.reduction = nn.Linear(4 * dim, out_dim, bias=False) - self.norm2 = nn.LayerNorm(dim) - self.proj = nn.Linear(dim, out_dim, bias=False) - # use MaxPool3d to avoid permutations - self.maxp = nn.MaxPool3d((2, 2, 1), (2, 2, 1)) - self.res_proj = nn.Linear(dim, out_dim, bias=False) - - def forward(self, x, H, W): - global_token, x = x[:, :self.num_tokens].contiguous(), x[:, self.num_tokens:].contiguous() - B, L, C = x.shape - - x = x.view(B, H, W, C) - res = self.res_proj(self.maxp(x).view(B, -1, C)) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - x = x + res - global_token = self.proj(self.norm2(global_token)) - x = torch.cat([global_token, x], 1) - return x, (H // 2, W // 2) - - -class LightViT(nn.Module): - - def __init__(self, img_size=224, patch_size=8, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], num_layers=[2, 6, 6], - num_heads=[2, 4, 8], mlp_ratios=[8, 4, 4], num_tokens=8, window_size=7, neck_dim=1280, qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=ConvStem, norm_layer=None, - act_layer=None, weight_init=''): - super().__init__() - self.num_classes = num_classes - self.embed_dims = embed_dims - self.num_tokens = num_tokens - self.mlp_ratios = mlp_ratios - self.patch_size = patch_size - self.num_layers = num_layers - self.window_size = window_size - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) - - self.global_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dims[0])) - - stages = [] - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))] # stochastic depth decay rule - for stage, (embed_dim, num_layer, num_head, mlp_ratio) in enumerate(zip(embed_dims, num_layers, num_heads, mlp_ratios)): - blocks = [] - if stage > 0: - # downsample - blocks.append(ResidualMergePatch(embed_dims[stage-1], embed_dim, num_tokens=num_tokens)) - blocks += [ - Block( - dim=embed_dim, num_heads=num_head, num_tokens=num_tokens, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[sum(num_layers[:stage]) + i], norm_layer=norm_layer, act_layer=act_layer, attention=LightViTAttention) - for i in range(num_layer) - ] - blocks = nn.Sequential(*blocks) - stages.append(blocks) - self.stages = nn.Sequential(*stages) - - self.norm = norm_layer(embed_dim) - - self.neck = nn.Sequential( - nn.Linear(embed_dim, neck_dim), - nn.LayerNorm(neck_dim), - nn.GELU() - ) - - self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity() - self.init_weights(weight_init) - - def init_weights(self, mode=''): - assert mode 
in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. - if mode.startswith('jax'): - # leave cls token as zeros to match jax impl - named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) - else: - trunc_normal_(self.global_token, std=.02) - self.apply(_init_vit_weights) - - def _init_weights(self, m): - # this fn left here for compat with downstream users - _init_vit_weights(m) - - @torch.jit.ignore - def no_weight_decay(self): - return {'global_token', '[g]relative_position_bias_table'} - - def forward_features(self, x): - x, (H, W) = self.patch_embed(x) - global_token = self.global_token.expand(x.shape[0], -1, -1) - x = torch.cat((global_token, x), dim=1) - for stage in self.stages: - for block in stage: - if isinstance(block, ResidualMergePatch): - x, (H, W) = block(x, H, W) - elif isinstance(block, Block): - x = block(x, H, W) - else: - x = block(x) - x = self.norm(x) - x = self.neck(x) - return x.mean(1) - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - def flops(self, input_shape=(3, 224, 224)): - flops = 0 - ws = self.window_size - # stem - from lib.utils.measure import get_flops - flops += get_flops(self.patch_embed, input_shape) - H = input_shape[1] // self.patch_size - W = input_shape[2] // self.patch_size - N = self.num_tokens + H * W - # blocks - for stage in range(len(self.stages)): - embed_dim = self.embed_dims[stage] - if stage > 0: - # merge patch - # mp - reduction - flops += (H // 2) * (W // 2) * self.embed_dims[stage-1] * (4 * embed_dim) - # mp - residual - flops += (H // 2) * (W // 2) * self.embed_dims[stage-1] * embed_dim - # mp - cls proj - flops += self.num_tokens * self.embed_dims[stage-1] * embed_dim - H, W = H // 2, W // 2 - N = H * W + self.num_tokens - - for i in range(self.num_layers[stage]): - # attn - qkv (img & glb) - flops += N * embed_dim * embed_dim * 3 - # local window self-attn - flops += (H // ws) * (W // ws) * (ws * ws) * embed_dim * 2 - # global aggregation - flops += (H * W) * self.num_tokens * embed_dim * 2 - # global broadcast - flops += (H * W) * self.num_tokens * embed_dim * 2 - # attn - proj - flops += N * embed_dim * embed_dim - - # FFN - mlp - flops += (N * embed_dim * (embed_dim * self.mlp_ratios[stage])) * 2 - # FFN - biattn - attn_ratio = 0.25 - # c attn - flops += embed_dim * embed_dim * attn_ratio * 2 - # s attn - flops += N * embed_dim * embed_dim * attn_ratio + N * embed_dim * attn_ratio * 2 * 1 - # dot product - flops += N * embed_dim - # neck - neck_dim = self.neck[0].out_features - flops += N * embed_dim * neck_dim - # head - flops += neck_dim * 1000 - return flops - - -def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): - """ ViT weight initialization - * When called without n, head_bias, jax_impl args it will behave exactly the same - as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
- * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl - """ - if isinstance(module, nn.Linear): - if name.startswith('head'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - elif name.startswith('pre_logits'): - lecun_normal_(module.weight) - nn.init.zeros_(module.bias) - else: - if jax_impl: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - if 'mlp' in name: - nn.init.normal_(module.bias, std=1e-6) - else: - nn.init.zeros_(module.bias) - else: - trunc_normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif jax_impl and isinstance(module, nn.Conv2d): - # NOTE conv was left to pytorch default in my original init - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): - nn.init.zeros_(module.bias) - nn.init.ones_(module.weight) - - -@register_model -def lightvit_tiny(pretrained=False, **kwargs): - model_kwargs = dict(patch_size=8, embed_dims=[64, 128, 256], num_layers=[2, 6, 6], - num_heads=[2, 4, 8, ], mlp_ratios=[8, 4, 4], num_tokens=8, **kwargs) - model = LightViT(**model_kwargs) - return model - - -@register_model -def lightvit_small(pretrained=False, **kwargs): - model_kwargs = dict(patch_size=8, embed_dims=[96, 192, 384], num_layers=[2, 6, 6], - num_heads=[3, 6, 12, ], mlp_ratios=[8, 4, 4], num_tokens=16, **kwargs) - model = LightViT(**model_kwargs) - return model - - -@register_model -def lightvit_base(pretrained=False, **kwargs): - model_kwargs = dict(patch_size=8, embed_dims=[128, 256, 512], num_layers=[3, 8, 6], - num_heads=[4, 8, 16, ], mlp_ratios=[8, 4, 4], num_tokens=24, **kwargs) - model = LightViT(**model_kwargs) - return model diff --git a/classification/lib/models/mdconv.py b/classification/lib/models/mdconv.py deleted file mode 100644 index 0b39e4c..0000000 --- a/classification/lib/models/mdconv.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -import numpy as np - - -def split_layer(total_channels, num_groups): - split = [int(np.ceil(total_channels / num_groups)) for _ in range(num_groups)] - split[num_groups - 1] += total_channels - sum(split) - return split - - -class DepthwiseConv2D(nn.Module): - def __init__(self, in_channels, kernal_size, stride, bias=False): - super(DepthwiseConv2D, self).__init__() - padding = (kernal_size - 1) // 2 - - self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=kernal_size, padding=padding, stride=stride, groups=in_channels, bias=bias) - - def forward(self, x): - out = self.depthwise_conv(x) - return out - - -class GroupConv2D(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False): - super(GroupConv2D, self).__init__() - self.n_chunks = n_chunks - self.split_in_channels = split_layer(in_channels, n_chunks) - split_out_channels = split_layer(out_channels, n_chunks) - - if n_chunks == 1: - self.group_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias) - else: - self.group_layers = nn.ModuleList() - for idx in range(n_chunks): - self.group_layers.append(nn.Conv2d(self.split_in_channels[idx], split_out_channels[idx], kernel_size=kernel_size, bias=bias)) - - def forward(self, x): - if self.n_chunks == 1: - return self.group_conv(x) - else: - split = torch.split(x, self.split_in_channels, dim=1) - out = torch.cat([layer(s) for layer, s in zip(self.group_layers, 
split)], dim=1) - return out - - -class MDConv(nn.Module): - def __init__(self, out_channels, n_chunks, stride=1, bias=False): - super(MDConv, self).__init__() - self.n_chunks = n_chunks - self.split_out_channels = split_layer(out_channels, n_chunks) - - self.layers = nn.ModuleList() - for idx in range(self.n_chunks): - kernel_size = 2 * idx + 3 - self.layers.append(DepthwiseConv2D(self.split_out_channels[idx], kernal_size=kernel_size, stride=stride, bias=bias)) - - def forward(self, x): - split = torch.split(x, self.split_out_channels, dim=1) - out = torch.cat([layer(s) for layer, s in zip(self.layers, split)], dim=1) - return out - - -# temp = torch.randn((16, 3, 32, 32)) -# group = GroupConv2D(3, 16, n_chunks=2) -# print(group(temp).size()) diff --git a/classification/lib/models/mobilenet_v1.py b/classification/lib/models/mobilenet_v1.py deleted file mode 100644 index e9a6bf2..0000000 --- a/classification/lib/models/mobilenet_v1.py +++ /dev/null @@ -1,73 +0,0 @@ -import math -import torch.nn as nn - - -def _initialize_weight_goog(m): - # weight init as per Tensorflow Official impl - # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(0) # fan-out - init_range = 1.0 / math.sqrt(n) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() - - -class MobileNetV1(nn.Module): - def __init__(self, ch_in=3, num_classes=1000): - super(MobileNetV1, self).__init__() - - def conv_bn(inp, oup, stride): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True) - ) - - def conv_dw(inp, oup, stride): - return nn.Sequential( - # dw - nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), - nn.BatchNorm2d(inp), - nn.ReLU(inplace=True), - - # pw - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - nn.ReLU(inplace=True), - ) - - self.model = nn.Sequential( - conv_bn(ch_in, 32, 2), - conv_dw(32, 64, 1), - conv_dw(64, 128, 2), - conv_dw(128, 128, 1), - conv_dw(128, 256, 2), - conv_dw(256, 256, 1), - conv_dw(256, 512, 2), - conv_dw(512, 512, 1), - conv_dw(512, 512, 1), - conv_dw(512, 512, 1), - conv_dw(512, 512, 1), - conv_dw(512, 512, 1), - conv_dw(512, 1024, 2), - conv_dw(1024, 1024, 1), - nn.AdaptiveAvgPool2d(1) - ) - self.fc = nn.Linear(1024, num_classes) - - for m in self.modules(): - _initialize_weight_goog(m) - - def forward(self, x): - x = self.model(x) - x = x.view(-1, 1024) - x = self.fc(x) - return x diff --git a/classification/lib/models/nas_model.py b/classification/lib/models/nas_model.py deleted file mode 100644 index 14cdc78..0000000 --- a/classification/lib/models/nas_model.py +++ /dev/null @@ -1,130 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from .operations import OPS, AuxiliaryHead - - -def _initialize_weight_goog(m): - # weight init as per Tensorflow Official impl - # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): 
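# BatchNorm layers are reset to an identity affine transform (scale 1, shift 0),
# following the TF MnasNet reference init linked above.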
- m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(0) # fan-out - init_range = 1.0 / math.sqrt(n) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() - - -def _initialize_weight_default(m): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - -class NASModel(nn.Module): - def __init__(self, net_cfg, weight_init='goog', drop_rate=0.2, drop_path_rate=0.0, auxiliary_head=False, **kwargs): - super(NASModel, self).__init__() - self.drop_rate = drop_rate - self.drop_path_rate = drop_path_rate - if self.drop_path_rate != 0.: - raise NotImplementedError('Drop path is not implemented in NAS model.') - - backbone_cfg = net_cfg.pop('backbone') - self.features = nn.Sequential() - downsample_num = 0 - for layer in backbone_cfg: - if len(backbone_cfg[layer]) == 5: - stride, inp, oup, t, op = backbone_cfg[layer] - n = 1 - kwargs = {} - elif len(backbone_cfg[layer]) == 6 and isinstance(backbone_cfg[layer][-1], dict): - stride, inp, oup, t, op, kwargs = backbone_cfg[layer] - n = 1 - elif len(backbone_cfg[layer]) == 6: - n, stride, inp, oup, t, op = backbone_cfg[layer] - kwargs = {} - elif len(backbone_cfg[layer]) == 7: - n, stride, inp, oup, t, op, kwargs = backbone_cfg[layer] - else: - raise RuntimeError(f'Invalid layer configuration: {backbone_cfg[layer]}') - - for idx in range(n): - layer_ = layer + f'_{idx}' if n > 1 else layer - if isinstance(t, (list, tuple)) or isinstance(op, (list, tuple)): - # NAS supernet - if not isinstance(t, (list, tuple)): - t = [t] - if not isinstance(op, (list, tuple)): - op = [op] - from edgenn.models import ListChoice - blocks = [] - for t_ in t: - for op_ in op: - if op_ == 'id': - # add it later - continue - blocks.append(OPS[op_](inp, oup, t_, stride, kwargs)) - if 'id' in op: - blocks.append(OPS['id'](inp, oup, 1, stride, kwargs)) - self.features.add_module(layer_, ListChoice(blocks)) - else: - if t is None: - t = 1 - self.features.add_module(layer_, OPS[op](inp, oup, t, stride, kwargs)) - if stride == 2: - downsample_num += 1 - if auxiliary_head and downsample_num == 5: - # auxiliary head added after the 5-th downsampling layer - object.__setattr__(self, 'module_to_auxiliary', self.features[-1]) - C_to_auxiliary = oup - inp = oup - stride = 1 - - # build head - head_cfg = net_cfg.pop('head') - self.classifier = nn.Sequential() - for layer in head_cfg: - self.classifier.add_module(layer, nn.Linear(head_cfg[layer]['dim_in'], head_cfg[layer]['dim_out'])) - - if auxiliary_head: - self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, 1000) - - # init weight - for m in self.modules(): - if weight_init == 'goog': - _initialize_weight_goog(m) - else: - _initialize_weight_default(m) - - def get_classifier(self): - return self.classifier - - def forward(self, x): - x = self.features(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = x.view(x.size(0), -1) - return self.classifier(x) - - -def gen_nas_model(net_cfg, drop_rate=0.2, drop_path_rate=0.0, auxiliary_head=False, **kwargs): - model = NASModel( - net_cfg, - drop_rate=drop_rate, - drop_path_rate=drop_path_rate, - auxiliary_head=auxiliary_head - ) - return model - diff --git a/classification/lib/models/operations.py b/classification/lib/models/operations.py deleted 
file mode 100644 index fde1e69..0000000 --- a/classification/lib/models/operations.py +++ /dev/null @@ -1,605 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from collections import OrderedDict -from .mdconv import MDConv - - -OPS = OrderedDict() -OPS['id'] = lambda inp, oup, t, stride, kwargs: Identity(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) - -'''MixConv''' -OPS['ir_mix_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidualMixConv(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) -OPS['ir_mix_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidualMixConv(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) - -'''MobileNet V2 Inverted Residual''' -OPS['ir_3x3_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) -OPS['ir_5x5_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) -OPS['ir_7x7_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) -OPS['ir_3x3_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) -OPS['ir_5x5_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) -OPS['ir_7x7_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, - stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) -OPS['ir_3x3'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) -OPS['ir_5x5'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, - stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) -OPS['ir_7x7'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, - stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) - -# assign ops with given expand ratios -class OpWrapper: - def __init__(self, t, op_func): - self.t = t - self.op_func = op_func - def __call__(self, inp, oup, t, stride, kwargs): - return self.op_func(inp, oup, self.t, stride, kwargs) - -_t = [1, 3, 6] -new_ops = {} -for op in OPS: - if 'ir' in op and 't' not in op: - for given_t in _t: - newop = op + f'_t{given_t}' - func = OpWrapper(given_t, OPS[op]) - new_ops[newop] = func #lambda inp, oup, t, stride, kwargs: OPS[op](inp, oup, given_t, stride, kwargs) -for op in new_ops: - OPS[op] = new_ops[op] - -OPS['conv1x1'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) -OPS['conv3x3'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(in_channels=inp, out_channels=oup, 
kernel_size=3, stride=stride, **kwargs) -OPS['gavgp'] = lambda inp, oup, t, stride, kwargs: nn.AdaptiveAvgPool2d(1, **kwargs) -OPS['maxp'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(kernel_size=2, stride=stride, **kwargs) - -OPS['linear_relu'] = lambda inp, oup, t, stride, kwargs: LinearReLU(inp, oup) - -'''for NAS-Bench-Macro''' -OPS['ir_3x3_t3'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, - stride=stride, act_fn=nn.ReLU, expand_ratio=3, se_ratio=None, **kwargs) -OPS['ir_5x5_t6'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, - stride=stride, act_fn=nn.ReLU, expand_ratio=6, se_ratio=None, **kwargs) -OPS['ID'] = lambda inp, oup, t, stride, kwargs: Identity(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) - - -""" -========================== -basic operations & modules -========================== -""" - -class HSwish(nn.Module): - def __init__(self, inplace=True): - super(HSwish, self).__init__() - self.inplace = inplace - - def forward(self, x): - out = x * F.relu6(x + 3, inplace=self.inplace) / 6 - return out - - -class HSigmoid(nn.Module): - def __init__(self, inplace=True): - super(HSigmoid, self).__init__() - self.inplace = inplace - - def forward(self, x): - out = F.relu6(x + 3, inplace=self.inplace) / 6 - return out - - -class LinearReLU(nn.Module): - def __init__(self, inp, oup,): - super(LinearReLU, self).__init__() - self.fc = nn.Sequential( - nn.Linear(inp, oup, bias=True), - nn.ReLU(inplace=True)) - - def forward(self, x): - #if x.ndims != 2: - if len(x.shape) != 2: - x = x.view(x.shape[0], -1) - return self.fc(x) - - -def conv2d(in_channels, out_channels, kernel_size, stride=1, pad_type='SAME', **kwargs): - if pad_type == 'SAME' or pad_type == '': - if isinstance(kernel_size, (tuple, list)): - padding = [(kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2] - else: - padding = (kernel_size - 1) // 2 - elif pad_type == 'NONE': - padding = 0 - else: - raise NotImplementedError('Not supported padding type: {}.'.format(pad_type)) - return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, **kwargs) - - -class ConvBnAct(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, - stride=1, pad_type='SAME', act_fn=nn.ReLU, **attrs): - super(ConvBnAct, self).__init__() - for k, v in attrs.items(): - setattr(self, k, v) - self.conv = conv2d(in_channels, out_channels, kernel_size, stride=stride, pad_type=pad_type, bias=False) - self.bn1 = nn.BatchNorm2d(out_channels) - self.act = act_fn(inplace=True) - - def forward(self, x): - x = self.conv(x) - x = self.bn1(x) - x = self.act(x) - return x - - -class Identity(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, **kwargs): - super(Identity, self).__init__() - if in_channels != out_channels or stride != 1: - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=False), - nn.BatchNorm2d(out_channels) - ) - else: - self.conv = None - - def forward(self, x): - if self.conv is not None: - return self.conv(x) - else: - return x - - -class SqueezeExcite(nn.Module): - def __init__(self, in_channels, reduce_channels, act_fn=nn.ReLU, gate_fn=nn.Sigmoid): - super(SqueezeExcite, self).__init__() - self.avgp = nn.AdaptiveAvgPool2d(1) - self.conv_reduce = nn.Conv2d(in_channels, reduce_channels, 1, bias=True) - self.act_fn = act_fn(inplace=True) - self.conv_expand = nn.Conv2d(reduce_channels, 
in_channels, 1, bias=True) - self.gate_fn = gate_fn() - - def forward(self, x): - x_se = self.avgp(x) - x_se = self.conv_reduce(x_se) - x_se = self.act_fn(x_se) - x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x - - -def drop_path(x, drop_prob: float = 0., training: bool = False): - if drop_prob == 0. or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -""" -========================== -ShuffleNetV2-ops -========================== -""" -OPS['shuffle_3x3_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=3, stride=stride, activation='HSwish', use_se=True) -OPS['shuffle_5x5_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=5, stride=stride, activation='HSwish', use_se=True) -OPS['shuffle_7x7_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=7, stride=stride, activation='HSwish', use_se=True) -OPS['shuffle_x_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize='x', stride=stride, activation='HSwish', use_se=True) - - -def channel_shuffle(x): - batchsize, num_channels, height, width = x.data.size() - assert (num_channels % 4 == 0) - x = x.reshape(batchsize * num_channels // 2, 2, height * width) - x = x.permute(1, 0, 2) - x = x.reshape(2, -1, num_channels // 2, height, width) - return x[0], x[1] - - -class ShufflenetBlock(nn.Module): - - def __init__(self, inp, oup, ksize, stride, activation='ReLU', use_se=False, **kwargs): - super(ShufflenetBlock, self).__init__() - self.stride = stride - assert stride in [1, 2] - assert ksize in [3, 5, 7, 'x'] - base_mid_channels = oup // 2 - - self.base_mid_channel = base_mid_channels - self.ksize = ksize - pad = ksize // 2 if ksize != 'x' else 3 // 2 - self.pad = pad - if stride == 1: - inp = inp // 2 - outputs = oup - inp - else: - outputs = oup // 2 - - self.inp = inp - - - if ksize != 'x': - branch_main = [ - # pw - nn.Conv2d(inp, base_mid_channels, 1, 1, 0, bias=False), - nn.BatchNorm2d(base_mid_channels), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - # dw - nn.Conv2d(base_mid_channels, base_mid_channels, ksize, stride, pad, groups=base_mid_channels, bias=False), - nn.BatchNorm2d(base_mid_channels), - # pw-linear - nn.Conv2d(base_mid_channels, outputs, 1, 1, 0, bias=False), - nn.BatchNorm2d(outputs), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - ] - else: - ksize = 3 - branch_main = [ - # dw - nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), - nn.BatchNorm2d(inp), - # pw - nn.Conv2d(inp, base_mid_channels, 1, 1, 0, bias=False), - nn.BatchNorm2d(base_mid_channels), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - # dw - nn.Conv2d(base_mid_channels, base_mid_channels, 3, 1, 1, groups=base_mid_channels, bias=False), - nn.BatchNorm2d(base_mid_channels), - # pw - nn.Conv2d(base_mid_channels, base_mid_channels, 1, 1, 0, bias=False), - nn.BatchNorm2d(base_mid_channels), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - # dw - nn.Conv2d(base_mid_channels, base_mid_channels, 3, 1, 1, groups=base_mid_channels, bias=False), - nn.BatchNorm2d(base_mid_channels), - # pw - nn.Conv2d(base_mid_channels, outputs, 1, 1, 0, bias=False), - 
nn.BatchNorm2d(outputs), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - ] - if use_se: - assert activation != 'ReLU' - branch_main.append(SqueezeExcite(outputs, outputs // 4, act_fn=HSwish, gate_fn=HSigmoid)) - self.branch_main = nn.Sequential(*branch_main) - - if stride == 2: - branch_proj = [ - # dw - nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False), - nn.BatchNorm2d(inp), - # pw-linear - nn.Conv2d(inp, outputs, 1, 1, 0, bias=False), - nn.BatchNorm2d(outputs), - nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), - ] - self.branch_proj = nn.Sequential(*branch_proj) - else: - self.branch_proj = None - - def forward(self, old_x): - if self.stride == 1: - x_proj, x = channel_shuffle(old_x) - return torch.cat((x_proj, self.branch_main(x)), 1) - elif self.stride == 2: - x_proj = old_x - x = old_x - return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1) - - - -""" -========================== -DARTS-ops -========================== -""" -OPS['avg_pool_3x3'] = lambda inp, oup, t, stride, kwargs: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) -OPS['max_pool_3x3'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(3, stride=stride, padding=1) -OPS['skip_connect'] = lambda inp, oup, t, stride, kwargs: nn.Identity() if stride == 1 else FactorizedReduce(inp, oup) -OPS['sep_conv_3x3'] = lambda inp, oup, t, stride, kwargs: SepConv(inp, oup, 3, stride) -OPS['sep_conv_5x5'] = lambda inp, oup, t, stride, kwargs: SepConv(inp, oup, 5, stride) -OPS['dil_conv_3x3'] = lambda inp, oup, t, stride, kwargs: DilConv(inp, oup, 3, stride, padding=2) -OPS['dil_conv_5x5'] = lambda inp, oup, t, stride, kwargs: DilConv(inp, oup, 5, stride, padding=4) - - -class ReLUConvBN(nn.Module): - - def __init__(self, C_in, C_out, kernel_size, stride): - super(ReLUConvBN, self).__init__() - padding = (kernel_size - 1) // 2 - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), - nn.BatchNorm2d(C_out) - ) - - def forward(self, x): - return self.op(x) - - -class FactorizedReduce(nn.Module): - - def __init__(self, C_in, C_out): - super(FactorizedReduce, self).__init__() - assert C_out % 2 == 0 - self.relu = nn.ReLU(inplace=False) - self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.bn = nn.BatchNorm2d(C_out) - - def forward(self, x): - x = self.relu(x) - out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) - out = self.bn(out) - return out - - -class SepConv(nn.Module): - - def __init__(self, C_in, C_out, kernel_size, stride): - super(SepConv, self).__init__() - padding = (kernel_size - 1) // 2 - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), - nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), - nn.BatchNorm2d(C_in), - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), - nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), - nn.BatchNorm2d(C_out), - ) - - def forward(self, x): - return self.op(x) - - - -class DilConv(nn.Module): - - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation=2): - super(DilConv, self).__init__() - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.Conv2d(C_in, C_in, 
kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), - nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), - nn.BatchNorm2d(C_out), - ) - - def forward(self, x): - return self.op(x) - - -""" -========================== -blocks -========================== -""" - -class InvertedResidualMixConv(nn.Module): - '''Inverted Residual block from MobileNet V2''' - def __init__(self, in_channels, out_channels, dw_kernel_size=3, - stride=1, pad_type='', act_fn=nn.ReLU, - expand_ratio=1.0, se_ratio=0., se_gate_fn=nn.Sigmoid, - drop_connect_rate=0.0, use_residual=True, use_3x3_dw_only=False, **attrs): - super(InvertedResidualMixConv, self).__init__() - mid_channels = int(in_channels * expand_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. - self.has_residual = in_channels == out_channels and stride == 1 and use_residual - self.drop_connect_rate = drop_connect_rate - - for k, v in attrs.items(): - # for edgenn: NAS and pruning - setattr(self, k, v) - - # Point-wise convolution - if expand_ratio == 1: - self.conv_pw = nn.Sequential() - else: - self.conv_pw = nn.Sequential( - conv2d(in_channels, mid_channels, 1, 1, bias=False), - nn.BatchNorm2d(mid_channels), - act_fn(inplace=True) - ) - - use_3x3_dw_only = False - # Depth-wise convolution - if not use_3x3_dw_only: - self.conv_dw = nn.Sequential( - #conv2d(mid_channels, mid_channels, dw_kernel_size, stride, groups=mid_channels, bias=False), - MDConv(mid_channels, n_chunks=3, stride=stride, bias=False), - nn.BatchNorm2d(mid_channels), - act_fn(inplace=True) - ) - else: - conv_dw = [] - for i in range((dw_kernel_size - 3) // 2 + 1): - conv_dw.extend([ - conv2d(mid_channels, mid_channels, 3, stride if i == 0 else 1, groups=mid_channels, bias=False), - nn.BatchNorm2d(mid_channels), - ]) - conv_dw.append(act_fn(inplace=True)) - self.conv_dw = nn.Sequential(*conv_dw) - - # Squeeze-and-excitation - if self.has_se: - self.se = SqueezeExcite( - mid_channels, reduce_channels=max(1, int(mid_channels * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - - # Point-wise convolution - self.conv_pw2 = nn.Sequential( - conv2d(mid_channels, out_channels, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - ) - - def forward(self, x): - residual = x - - x = self.conv_pw(x) - x = self.conv_dw(x) - - if self.has_se: - x = self.se(x) - - x = self.conv_pw2(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_path(x, self.drop_connect_rate, self.training) - x += residual - - return x - - - - -class InvertedResidual(nn.Module): - '''Inverted Residual block from MobileNet V2''' - def __init__(self, in_channels, out_channels, dw_kernel_size=3, - stride=1, pad_type='', act_fn=nn.ReLU, - expand_ratio=1.0, se_ratio=0., se_gate_fn=nn.Sigmoid, - drop_connect_rate=0.0, use_residual=True, use_3x3_dw_only=False, **attrs): - super(InvertedResidual, self).__init__() - mid_channels = int(in_channels * expand_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. 
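        # MBConv layout: 1x1 expansion (in_channels -> in_channels * expand_ratio), depthwise
        # conv, optional squeeze-and-excitation on the expanded features, then a linear 1x1
        # projection to out_channels; e.g. in_channels=32 with expand_ratio=6 gives
        # mid_channels=192. The skip connection below is used only when stride == 1 and
        # in_channels == out_channels.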
- self.has_residual = in_channels == out_channels and stride == 1 and use_residual - self.drop_connect_rate = drop_connect_rate - - for k, v in attrs.items(): - # for edgenn: NAS and pruning - setattr(self, k, v) - - # Point-wise convolution - if expand_ratio == 1: - self.conv_pw = nn.Sequential() - else: - self.conv_pw = nn.Sequential( - conv2d(in_channels, mid_channels, 1, 1, bias=False), - nn.BatchNorm2d(mid_channels), - act_fn(inplace=True) - ) - - # Depth-wise convolution - if not use_3x3_dw_only: - self.conv_dw = nn.Sequential( - conv2d(mid_channels, mid_channels, dw_kernel_size, stride, groups=mid_channels, bias=False), - nn.BatchNorm2d(mid_channels), - act_fn(inplace=True) - ) - else: - conv_dw = [] - for i in range((dw_kernel_size - 3) // 2 + 1): - conv_dw.extend([ - conv2d(mid_channels, mid_channels, 3, stride if i == 0 else 1, groups=mid_channels, bias=False), - nn.BatchNorm2d(mid_channels), - ]) - conv_dw.append(act_fn(inplace=True)) - self.conv_dw = nn.Sequential(*conv_dw) - - # Squeeze-and-excitation - if self.has_se: - self.se = SqueezeExcite( - mid_channels, reduce_channels=max(1, int(mid_channels * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) - - # Point-wise convolution - self.conv_pw2 = nn.Sequential( - conv2d(mid_channels, out_channels, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - ) - - def forward(self, x): - residual = x - - x = self.conv_pw(x) - x = self.conv_dw(x) - - if self.has_se: - x = self.se(x) - - x = self.conv_pw2(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_path(x, self.drop_connect_rate, self.training) - x += residual - - return x - - -class DARTSCell(nn.Module): - def __init__(self, cell_arch, c_prev_prev, c_prev, c, stride=1, reduction_prev=False, steps=4): - super().__init__() - self.cell_arch = cell_arch - self.steps = steps - self.preprocess0 = FactorizedReduce(c_prev_prev, c) if reduction_prev else \ - ReLUConvBN(c_prev_prev, c, 1, stride=1) - self.preprocess1 = ReLUConvBN(c_prev, c, 1, stride=1) - - if len(cell_arch[0]) != 0 and isinstance(cell_arch[0][0], str): - # DARTS-like genotype, convert it to topo-free type - cell_arch = [[cell_arch[idx*2], cell_arch[idx*2+1]] for idx in range(len(cell_arch) // 2)] - - self.ops = nn.ModuleList() - self.inputs = [] - for step in cell_arch: - step_ops = nn.ModuleList() - step_inputs = [] - for op_name, input_idx in step: - step_ops += [OPS[op_name](c, c, None, stride if input_idx < 2 else 1, {})] - step_inputs.append(input_idx) - self.ops += [step_ops] - self.inputs.append(step_inputs) - - def forward(self, s0, s1, drop_path_rate=0.): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - states = [s0, s1] - - for step_idx, (step_inputs, step_ops) in enumerate(zip(self.inputs, self.ops)): - step_outs = [] - for input_idx, op in zip(step_inputs, step_ops): - out = op(states[input_idx]) - if drop_path_rate > 0. 
and not isinstance(op, (FactorizedReduce, nn.Identity)): - out = drop_path(out, drop_path_rate, self.training) - step_outs.append(out) - states.append(sum(step_outs)) - - return torch.cat(states[-4:], dim=1) - - -""" -========================= -Auxiliary Heads -========================= -""" -class AuxiliaryHead(nn.Module): - - def __init__(self, C, num_classes, avg_pool_stride=2): - """with avg_pol_stride=2, assuming input size 14x14""" - super(AuxiliaryHead, self).__init__() - self.features = nn.Sequential( - nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), - nn.Conv2d(C, 128, 1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 768, 2, bias=False), - nn.BatchNorm2d(768), - nn.ReLU(inplace=True) - ) - self.classifier = nn.Linear(768, num_classes) - - def forward(self, x): - x = self.features(x) - x = self.classifier(x.view(x.size(0),-1)) - return x - - - diff --git a/classification/lib/models/operations_resnet.py b/classification/lib/models/operations_resnet.py deleted file mode 100644 index 32fd997..0000000 --- a/classification/lib/models/operations_resnet.py +++ /dev/null @@ -1,192 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Type, Any, Callable, Union, List, Optional -from torch import Tensor -from collections import OrderedDict -from .operations import OPS, conv2d, ConvBnAct, SqueezeExcite - - -'''ResNet''' -OPS['maxp_3x3'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(kernel_size=3, stride=stride, padding=1) -OPS['conv7x7'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(inp, oup, kernel_size=7, stride=stride, **kwargs) -OPS['res_3x3'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, **kwargs) -OPS['res_5x5'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, **kwargs) -OPS['res_7x7'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, **kwargs) -OPS['res_3x3_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['res_5x5_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['res_7x7_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['res_3x3_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=t, **kwargs) -OPS['res_5x5_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=t, **kwargs) -OPS['res_7x7_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=t, **kwargs) -OPS['resnext_3x3'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, **kwargs) -OPS['resnext_5x5'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, **kwargs) -OPS['resnext_7x7'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, **kwargs) -OPS['resnext_3x3_se'] = 
lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['resnext_5x5_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['resnext_7x7_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=4, **kwargs) -OPS['resnext_3x3_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=t, **kwargs) -OPS['resnext_5x5_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=t, **kwargs) -OPS['resnext_7x7_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=t, **kwargs) - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion: int = 4 - - def __init__( - self, - inplanes: int, - outplanes: int, - stride: int = 1, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - kernel_size: int = 3, - use_se: bool = False, - planes: int = None, - expansion = 4 - ) -> None: - super(Bottleneck, self).__init__() - self.expansion = expansion - if norm_layer is None: - norm_layer = nn.BatchNorm2d - - if stride != 1 or inplanes != outplanes: - self.downsample = nn.Sequential( - nn.Conv2d(inplanes, outplanes, stride=stride, kernel_size=1, bias=False), - norm_layer(outplanes), - ) - if planes is None: - planes = int(inplanes // self.expansion * 2) - else: - self.downsample = None - planes = int(inplanes // self.expansion) - - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) - self.bn1 = norm_layer(width) - self.conv2 = conv2d(width, width, kernel_size, stride, bias=False, groups=groups) - self.bn2 = norm_layer(width) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - if use_se: - self.se = SqueezeExcite(outplanes, reduce_channels=max(1, outplanes // 16)) - else: - self.se = None - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.se is not None: - out = self.se(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class ResNeXtBottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 
convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - def __init__( - self, - inplanes: int, - outplanes: int, - stride: int = 1, - groups: int = 32, - base_width: int = 4, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - kernel_size: int = 3, - use_se: bool = False, - planes: int = None, - expansion = 4, - ) -> None: - super(ResNeXtBottleneck, self).__init__() - self.expansion = expansion - - if stride != 1 or inplanes != outplanes: - self.downsample = nn.Sequential( - nn.Conv2d(inplanes, outplanes, stride=stride, kernel_size=1, bias=False), - nn.BatchNorm2d(outplanes), - ) - if planes is None: - planes = int(inplanes // self.expansion * 2 ) - else: - self.downsample = None - planes = int(inplanes // self.expansion) - - width = math.floor(planes * (base_width / 64)) * groups - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, - stride=1) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = conv2d(width, width, kernel_size=kernel_size, stride=stride, - groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(outplanes) - if use_se: - self.se = SqueezeExcite(outplanes, reduce_channels=max(1, outplanes // 16)) - else: - self.se = None - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.se is not None: - out = self.se(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - - - diff --git a/classification/lib/models/resnet.py b/classification/lib/models/resnet.py deleted file mode 100644 index 40e9ec8..0000000 --- a/classification/lib/models/resnet.py +++ /dev/null @@ -1,392 +0,0 @@ -"""resnet implemented in torchvision: -https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html -""" -import torch -from torch import Tensor -import torch.nn as nn -from typing import Type, Any, Callable, Union, List, Optional - - -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class BasicBlock(nn.Module): - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or 
base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__( - self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = False, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, 
self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - - def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, - stride: int = 1, dilate: bool = False) -> nn.Sequential: - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def _forward_impl(self, x: Tensor) -> Tensor: - # See note [TorchScript super()] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def _resnet( - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any -) -> ResNet: - model = ResNet(block, layers, **kwargs) - return model - - - -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-18 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) - - - - -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_. 
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - - - -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-50 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - - - -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) - - - - -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) - - - - -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-50 32x4d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - - - -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-101 32x8d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - - - -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-50-2 model from - `"Wide Residual Networks" `_. - - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - - - -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-101-2 model from - `"Wide Residual Networks" `_. 
- - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - diff --git a/classification/lib/models/rope.py b/classification/lib/models/rope.py deleted file mode 100644 index 488f1b9..0000000 --- a/classification/lib/models/rope.py +++ /dev/null @@ -1,145 +0,0 @@ -# -------------------------------------------------------- -# EVA-02: A Visual Representation for Neon Genesis -# Github source: https://github.com/baaivision/EVA/EVA02 -# Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) -# Licensed under The MIT License [see LICENSE for details] -# By Yuxin Fang -# -# Based on https://github.com/lucidrains/rotary-embedding-torch -# --------------------------------------------------------' - -from math import pi - -import torch -from torch import nn - -from einops import rearrange, repeat - - - -def broadcat(tensors, dim = -1): - num_tensors = len(tensors) - shape_lens = set(list(map(lambda t: len(t.shape), tensors))) - assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' - shape_len = list(shape_lens)[0] - dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*map(lambda t: list(t.shape), tensors))) - expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] - assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' - max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) - expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) - expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) - tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) - return torch.cat(tensors, dim = dim) - - - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') - - - -class VisionRotaryEmbedding(nn.Module): - def __init__( - self, - dim, - pt_seq_len, - ft_seq_len=None, - custom_freqs = None, - freqs_for = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - ): - super().__init__() - if custom_freqs: - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() - else: - raise ValueError(f'unknown modality {freqs_for}') - - if ft_seq_len is None: ft_seq_len = pt_seq_len - t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len - - freqs_h = torch.einsum('..., f -> ... f', t, freqs) - freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) - - freqs_w = torch.einsum('..., f -> ... f', t, freqs) - freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2) - - freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) - - self.register_buffer("freqs_cos", freqs.cos()) - self.register_buffer("freqs_sin", freqs.sin()) - - print('======== shape of rope freq', self.freqs_cos.shape, '========') - - def forward(self, t, start_index = 0): - rot_dim = self.freqs_cos.shape[-1] - end_index = start_index + rot_dim - assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' - t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] - t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) - return torch.cat((t_left, t, t_right), dim = -1) - - - -class VisionRotaryEmbeddingFast(nn.Module): - def __init__( - self, - dim, - pt_seq_len=16, - ft_seq_len=None, - custom_freqs = None, - freqs_for = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - ): - super().__init__() - if custom_freqs: - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() - else: - raise ValueError(f'unknown modality {freqs_for}') - - if ft_seq_len is None: ft_seq_len = pt_seq_len - t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len - - freqs = torch.einsum('..., f -> ... f', t, freqs) - freqs = repeat(freqs, '... n -> ... (n r)', r = 2) - freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) - - freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) - freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) - - self.register_buffer("freqs_cos", freqs_cos) - self.register_buffer("freqs_sin", freqs_sin) - - print('======== shape of rope freq', self.freqs_cos.shape, '========') - - def forward(self, t, freqs_cos=None, freqs_sin=None): - if freqs_cos is None: - freqs_cos = self.freqs_cos - if freqs_sin is None: - freqs_sin = self.freqs_sin - if t.shape[1] % 2 != 0: - t_spatial = t[:, 1:, :] - t_spatial = t_spatial * freqs_cos + rotate_half(t_spatial) * freqs_sin - return torch.cat((t[:, :1, :], t_spatial), dim=1) - else: - return t * freqs_cos + rotate_half(t) * freqs_sin \ No newline at end of file diff --git a/classification/lib/models/vim.py b/classification/lib/models/vim.py deleted file mode 100644 index e83e20a..0000000 --- a/classification/lib/models/vim.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) 2015-present, Facebook, Inc. -# All rights reserved. 
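# Illustrative sketch (not part of the original files): a 1-D simplification of the rotary
# embedding defined in rope.py above. rotate_half maps every adjacent channel pair
# (x1, x2) -> (-x2, x1), so t * cos + rotate_half(t) * sin rotates each pair by a
# position-dependent angle -- the same per-pair rotation applied in
# VisionRotaryEmbeddingFast.forward.
import torch
from einops import rearrange

def rotate_half_demo(x):
    # (x1, x2) -> (-x2, x1) for every adjacent channel pair
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), '... d r -> ... (d r)')

seq_len, dim = 16, 64
pos = torch.arange(seq_len).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))       # [dim/2]
angles = torch.einsum('n,f->nf', pos, inv_freq).repeat_interleave(2, -1)  # [seq_len, dim]
t = torch.randn(1, seq_len, dim)
t_rot = t * angles.cos() + rotate_half_demo(t) * angles.sin()             # rotated features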
-import torch -import torch.nn as nn -import copy -from functools import partial -from torch import Tensor -from typing import Optional - -from timm.models.vision_transformer import VisionTransformer, _cfg -from timm.models.registry import register_model -from timm.models.layers import trunc_normal_ - -from timm.models.layers import DropPath, PatchEmbed -from timm.models.vision_transformer import _load_weights - -import math - -from collections import namedtuple - -from mamba_ssm.modules.mamba_simple import Mamba -from mamba_ssm.utils.generation import GenerationMixin -from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf -from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count - - -from .dyn_mamba_simple_search import DynMamba -# from .dyn_mamba_simple import DynMamba - -from .rope import * -import random - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - -__all__ = [ - 'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224', - 'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384', -] - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0., - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - if residual is None: - residual = hidden_states - else: - residual = residual + self.drop_path(hidden_states) - - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - if residual is None: - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - else: - hidden_states, residual = fused_add_norm_fn( - self.drop_path(hidden_states), - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - -def create_block( - d_model, - ssm_cfg=None, - norm_epsilon=1e-5, - drop_path=0., - rms_norm=False, - residual_in_fp32=False, - fused_add_norm=False, - layer_idx=None, - device=None, - dtype=None, - bimamba_type="none", - directions=None, - token_size=(14, 14), -): - if ssm_cfg is None: - ssm_cfg = {} - factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(DynMamba, layer_idx=layer_idx, bimamba_type=bimamba_type, directions=directions, token_size=token_size, **ssm_cfg, **factory_kwargs) - norm_cls = partial( - nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs - ) - block = Block( - d_model, - mixer_cls, - norm_cls=norm_cls, - drop_path=drop_path, - fused_add_norm=fused_add_norm, - residual_in_fp32=residual_in_fp32, - ) - block.layer_idx = layer_idx - return block - - -# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights( - module, - n_layer, - initializer_range=0.02, # Now only used for embedding layer. - rescale_prenorm_residual=True, - n_residuals_per_layer=1, # Change to 2 if we have MLP -): - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) - - -def segm_init_weights(m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - if m.bias is not None: - nn.init.constant_(m.bias, 0) - if m.weight is not None: - nn.init.constant_(m.weight, 1.0) - - -def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - assert not with_complex - # https://github.com/state-spaces/mamba/issues/110 - flops = 9 * B * L * D * N - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - -def selective_scan_flop_jit(inputs, outputs): - print_jit_input_names(inputs) - B, D, L = inputs[0].type().sizes() - N = inputs[2].type().sizes()[1] - flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True) - return flops - - -class VisionMamba(nn.Module): - def __init__(self, - img_size=224, - patch_size=16, - depth=24, - embed_dim=192, - channels=3, - num_classes=1000, - ssm_cfg=None, - drop_rate=0., - drop_path_rate=0.1, - norm_epsilon: float = 1e-5, - rms_norm: bool = False, - initializer_cfg=None, - fused_add_norm=False, - residual_in_fp32=False, - device=None, - dtype=None, - ft_seq_len=None, - pt_hw_seq_len=14, - final_pool_type='none', - if_abs_pos_embed=False, - if_rope=False, - if_rope_residual=False, - bimamba_type="none", - if_cls_token=False, - directions=None, - **kwargs): - factory_kwargs = {"device": device, "dtype": dtype} - # add factory_kwargs into kwargs - kwargs.update(factory_kwargs) - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.final_pool_type = final_pool_type - self.if_abs_pos_embed = if_abs_pos_embed - self.if_rope = if_rope - self.if_rope_residual = if_rope_residual - self.if_cls_token = if_cls_token - self.num_tokens = 1 if if_cls_token else 0 - self.patch_size = patch_size - - # pretrain parameters - self.num_classes = num_classes - self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=channels, embed_dim=embed_dim, strict_img_size=False, dynamic_img_pad=True) - num_patches = self.patch_embed.num_patches - self.token_size = self.patch_embed.grid_size - - if if_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - - if if_abs_pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - if 
if_rope: - half_head_dim = embed_dim // 2 - if isinstance(img_size, (tuple, list)): - hw_seq_len = img_size[0] // patch_size - else: - hw_seq_len = img_size // patch_size - self.rope = VisionRotaryEmbeddingFast( - dim=half_head_dim, - pt_seq_len=pt_hw_seq_len, - ft_seq_len=hw_seq_len - ) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - - - # TODO: release this comment - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - # import ipdb;ipdb.set_trace() - inter_dpr = [0.0] + dpr - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() - # transformer blocks - if directions is None: - directions = [None] * depth - self.layers = nn.ModuleList( - [ - create_block( - embed_dim, - ssm_cfg=ssm_cfg, - norm_epsilon=norm_epsilon, - rms_norm=rms_norm, - residual_in_fp32=residual_in_fp32, - fused_add_norm=fused_add_norm, - layer_idx=i, - bimamba_type=bimamba_type, - drop_path=inter_dpr[i], - directions=directions[i], - token_size=self.token_size, - **factory_kwargs, - ) - for i in range(depth) - ] - ) - - # output head - self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( - embed_dim, eps=norm_epsilon, **factory_kwargs - ) - - self.pre_logits = nn.Identity() - - # original init - self.apply(segm_init_weights) - self.head.apply(segm_init_weights) - if if_abs_pos_embed: - trunc_normal_(self.pos_embed, std=.02) - - # mamba init - self.apply( - partial( - _init_weights, - n_layer=depth, - **(initializer_cfg if initializer_cfg is not None else {}), - ) - ) - - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return { - i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - for i, layer in enumerate(self.layers) - } - - @torch.jit.ignore - def no_weight_decay(self): - return {"pos_embed", "cls_token", "dist_token"} - - @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix=""): - _load_weights(self, checkpoint_path, prefix) - - def forward_features(self, x, inference_params=None, out_indices=None): - # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - # with slight modifications to add the dist_token - B, _, H, W = x.shape - x = self.patch_embed(x) - if self.if_cls_token: - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_token, x), dim=1) - - if self.if_abs_pos_embed: - H, W = math.ceil(H / self.patch_size), math.ceil(W / self.patch_size) - for layer in self.layers: - layer.mixer.multi_scan.token_size = (H, W) - if H != self.token_size[0] or W != self.token_size[1]: - # downstream tasks such as det and seg may have various input resolutions - pos_embed = self.resize_pos_embed(self.pos_embed, (H, W), self.token_size, 'bicubic') - if self.if_rope: - freqs_cos = self.resize_pos_embed(self.rope.freqs_cos.unsqueeze(0), (H, W), self.token_size, 'bicubic')[0] - freqs_sin = self.resize_pos_embed(self.rope.freqs_sin.unsqueeze(0), (H, W), self.token_size, 'bicubic')[0] - else: - pos_embed = self.pos_embed - freqs_cos = None - freqs_sin = None - x = x + pos_embed - x = self.pos_drop(x) - - outs = [] - - # mamba impl - residual = None - hidden_states = x - for layer_idx, layer in enumerate(self.layers): - # rope about - if self.if_rope: - hidden_states = self.rope(hidden_states, freqs_cos=freqs_cos, freqs_sin=freqs_sin) - if residual is not None and self.if_rope_residual: - 
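                    # apply the same rotary embedding to the running residual stream so its
                    # positional phase stays aligned with hidden_states (if_rope_residual)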
residual = self.rope(residual, freqs_cos=freqs_cos, freqs_sin=freqs_sin) - - hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params - ) - - if out_indices is not None and layer_idx in out_indices: - outs.append(hidden_states) - - if out_indices is not None: - assert len(outs) == len(out_indices) - return outs, (H, W) - - if not self.fused_add_norm: - if residual is None: - residual = hidden_states - else: - residual = residual + self.drop_path(hidden_states) - hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) - else: - # Set prenorm=False here since we don't need the residual - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn - hidden_states = fused_add_norm_fn( - self.drop_path(hidden_states), - self.norm_f.weight, - self.norm_f.bias, - eps=self.norm_f.eps, - residual=residual, - prenorm=False, - residual_in_fp32=self.residual_in_fp32, - ) - - # return only cls token if it exists - if self.if_cls_token: - return hidden_states[:, 0, :] - - if self.final_pool_type == 'none': - return hidden_states[:, -1, :] - elif self.final_pool_type == 'mean': - return hidden_states.mean(dim=1) - elif self.final_pool_type == 'max': - return hidden_states.max(dim=1) - elif self.final_pool_type == 'all': - return hidden_states - else: - raise NotImplementedError - - def forward(self, x, return_features=False, inference_params=None): - x = self.forward_features(x, inference_params) - if return_features: - return x - x = self.head(x) - return x - - def flops(self, input_shape=(3, 224, 224)): - supported_ops={ - "aten::silu": None, # as relu is in _IGNORED_OPS - "aten::neg": None, # as relu is in _IGNORED_OPS - "aten::exp": None, # as relu is in _IGNORED_OPS - "aten::flip": None, # as permute is in _IGNORED_OPS - "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, - "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, - "prim::PythonOp.LayerNormFn": None - } - - model = copy.deepcopy(self) - model.cuda().eval() - - input = torch.randn((1, *input_shape), device=next(model.parameters()).device) - Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) - print(unsupported) - - del model, input - return sum(Gflops.values()) * 1e9 - flops = 0 - from lib.utils.measure import get_flops - flops += get_flops(self.patch_embed, input_shape) - - L = self.patch_embed.num_patches - for layer in self.layers: - # 1 in_proj - flops += layer.mixer.in_proj.in_features * layer.mixer.in_proj.out_features * L - # 2 MambaInnerFnNoOutProj - # 2.1 causual conv1d - flops += (L + layer.mixer.d_conv - 1) * layer.mixer.d_inner * layer.mixer.d_conv - # 2.2 x_proj - flops += L * layer.mixer.x_proj_0.in_features * layer.mixer.x_proj_0.out_features - # 2.3 dt_proj - flops += L * layer.mixer.dt_proj_0.in_features * layer.mixer.dt_proj_0.out_features - # 2.4 selective scan - # select - # H = W = int(L ** 0.5) - # flops += layer.mixer.d_model * layer.mixer.d_model * H // 2 + layer.mixer.d_model * 2 * H // 2 - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - """ - D = layer.mixer.d_inner - N = layer.mixer.d_state - for i in range(4): - flops += 9 * L * D * N + 2 * D * L - # # A - # flops += D * L * N - # # B - # flops += D * L * N * 2 - # # C - # flops += (D * N + D * N) * L - # # D - # flops += D * L - # # Z - # flops += D * L - # merge - attn = layer.mixer.attn - flops += attn.global_reduce.in_features * 
attn.global_reduce.out_features - # flops += attn.local_reduce.in_features * attn.local_reduce.out_features * L - flops += attn.channel_select.in_features * attn.channel_select.out_features - # flops += attn.spatial_select.in_features * attn.spatial_select.out_features * L - # 2.5 out_proj - flops += L * layer.mixer.out_proj.in_features * layer.mixer.out_proj.out_features - - # head - flops += self.embed_dim * 1000 - return flops - - model = copy.deepcopy(self) - model.cuda().eval() - - supported_ops={ - "aten::silu": None, # as relu is in _IGNORED_OPS - "aten::neg": None, # as relu is in _IGNORED_OPS - "aten::exp": None, # as relu is in _IGNORED_OPS - "aten::flip": None, # as permute is in _IGNORED_OPS - # "prim::PythonOp.CrossScan": None, - # "prim::PythonOp.CrossMerge": None, - "prim::PythonOp.SelectiveScan": None, - "prim::PythonOp.SelectiveScanFn": None, - } - - input = torch.randn((1, *input_shape), device=next(model.parameters()).device) - Gflops, unsupported = flop_count(model=self.cuda(), inputs=(input,), supported_ops=supported_ops) - del model, input - print(Gflops) - print(unsupported) - return sum(Gflops.values()) * 1e9 - -class Backbone_LocalVisionMamba(VisionMamba): - def __init__(self, out_indices=[4, 9, 14, 19], pretrained_ckpt=None, **kwargs): - super().__init__(**kwargs) - del self.head - del self.norm_f - - self.out_indices = out_indices - for i in range(len(out_indices)): - layer = nn.LayerNorm(self.embed_dim) - layer_name = f'outnorm_{i}' - self.add_module(layer_name, layer) - - self.load_pretrained(pretrained_ckpt) - - - def load_pretrained(self, ckpt): - print(f'Load backbone state dict from {ckpt}') - state_dict = torch.load(ckpt, map_location='cpu')['state_dict'] - if 'pos_embed' in state_dict: - pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1])) - state_dict['pos_embed'] = self.resize_pos_embed( - state_dict['pos_embed'], - self.token_size, - (pos_size, pos_size), - 'bicubic' - ) - if 'rope.freqs_cos' in state_dict: - pos_size = int(math.sqrt(state_dict['rope.freqs_cos'].shape[0])) - state_dict['rope.freqs_cos'] = self.resize_pos_embed( - state_dict['rope.freqs_cos'].unsqueeze(0), - self.token_size, - (pos_size, pos_size), - 'bicubic' - )[0] - if 'rope.freqs_cos' in state_dict: - pos_size = int(math.sqrt(state_dict['rope.freqs_sin'].shape[0])) - state_dict['rope.freqs_sin'] = self.resize_pos_embed( - state_dict['rope.freqs_sin'].unsqueeze(0), - self.token_size, - (pos_size, pos_size), - 'bicubic' - )[0] - a, b = self.load_state_dict(state_dict, strict=False) - print(a, b) - - @staticmethod - def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): - from mmseg.models.utils import resize - """Resize pos_embed weights. - - Resize pos_embed using bicubic interpolate method. - Args: - pos_embed (torch.Tensor): Position embedding weights. - input_shpae (tuple): Tuple for (downsampled input image height, - downsampled input image width). - pos_shape (tuple): The resolution of downsampled origin training - image. - mode (str): Algorithm used for upsampling: - ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | - ``'trilinear'``. 
Default: ``'nearest'`` - Return: - torch.Tensor: The resized pos_embed of shape [B, L_new, C] - """ - assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' - pos_h, pos_w = pos_shape - pos_embed_weight = pos_embed - pos_embed_weight = pos_embed_weight.reshape( - 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) - pos_embed_weight = resize( - pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) - pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) - return pos_embed_weight - - def forward(self, x): - C = self.embed_dim - outs, (H, W) = self.forward_features(x, out_indices=self.out_indices) - outs = [getattr(self, f'outnorm_{i}')(o) for i, o in enumerate(outs)] - outs = [o.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous() for o in outs] - return outs - - -@register_model -def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=16, embed_dim=192, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - -@register_model -def local_vim_tiny_search(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=16, embed_dim=128, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def local_vim_tiny_searched(pretrained=False, **kwargs): - directions = ( - ['h', 'v_flip', 'w7', 'w7_flip'], - ['h_flip', 'w2_flip', 'w7', 'w7_flip'], - ['h', 'h_flip', 'v', 'w7'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h_flip', 'v', 'v_flip', 'w7'], - ['h_flip', 'v', 'v_flip', 'w2_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'v', 'v_flip', 'w2'], - ['h', 'v', 'v_flip', 'w2_flip'], - ['h', 'h_flip', 'v_flip', 'w7'], - ['h_flip', 'v', 'v_flip', 'w2'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'v', 'w2', 'w2_flip'], - ['v', 'v_flip', 'w2', 'w7'], - ['h', 'h_flip', 'v', 'w2'], - ['h', 'h_flip', 'w2_flip', 'w7'], - ['v', 'v_flip', 'w2', 'w2_flip'], - ) - model = VisionMamba( - patch_size=16, embed_dim=192, depth=20, rms_norm=False, residual_in_fp32=True, fused_add_norm=False, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - -@register_model -def local_vim_small_searched(pretrained=False, **kwargs): - directions = ( - ['h', 'v_flip', 'w7', 'w7_flip'], - ['h_flip', 'w2_flip', 'w7', 'w7_flip'], - ['h', 'h_flip', 'v', 'w7'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h_flip', 'v', 'v_flip', 'w7'], - ['h_flip', 'v', 'v_flip', 
'w2_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'v', 'v_flip', 'w2'], - ['h', 'v', 'v_flip', 'w2_flip'], - ['h', 'h_flip', 'v_flip', 'w7'], - ['h_flip', 'v', 'v_flip', 'w2'], - ['h', 'h_flip', 'v', 'v_flip'], - ['h', 'v', 'w2', 'w2_flip'], - ['v', 'v_flip', 'w2', 'w7'], - ['h', 'h_flip', 'v', 'w2'], - ['h', 'h_flip', 'w2_flip', 'w7'], - ['v', 'v_flip', 'w2', 'w2_flip'], - ) - model = VisionMamba( - patch_size=16, embed_dim=384, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=16, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", if_cls_token=True, **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def vim_tiny_patch8_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=8, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def vim_tiny_patch8_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=8, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", if_cls_token=True, **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=16, embed_dim=384, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model - - -@register_model -def vim_base_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual(pretrained=False, **kwargs): - model = VisionMamba( - patch_size=16, embed_dim=768, depth=20, rms_norm=True, residual_in_fp32=True, 
fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", **kwargs) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="to.do", - map_location="cpu", check_hash=True - ) - model.load_state_dict(checkpoint["model"]) - return model diff --git a/classification/lib/models/vmamba.py b/classification/lib/models/vmamba.py deleted file mode 100644 index a2fa0ba..0000000 --- a/classification/lib/models/vmamba.py +++ /dev/null @@ -1,1625 +0,0 @@ -import os -import time -import math -import copy -from functools import partial -from typing import Optional, Callable, Any -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from einops import rearrange, repeat -from timm.models.layers import DropPath, trunc_normal_ -from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count -DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" - -# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed) -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref -except: - pass - -# an alternative for mamba_ssm -try: - from selective_scan import selective_scan_fn as selective_scan_fn_v1 - from selective_scan import selective_scan_ref as selective_scan_ref_v1 -except: - pass - -# cross selective scan =============================== -if True: - import selective_scan_cuda_core as selective_scan_cuda - - class SelectiveScan(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1): - assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile - assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" - ctx.delta_softplus = delta_softplus - ctx.nrows = nrows - - # all in float - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if B.dim() == 3: - B = B.unsqueeze(dim=1) - ctx.squeeze_B = True - if C.dim() == 3: - C = C.unsqueeze(dim=1) - ctx.squeeze_C = True - - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, dout, *args): - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 - # u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.nrows, - ) - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None) - - class CrossScan(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor): - B, C, H, W = x.shape - ctx.shape = (B, C, H, W) - xs = x.new_empty((B, 4, C, H * W)) - xs[:, 0] = x.flatten(2, 3) - xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) - xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) - return xs - - @staticmethod - 
def backward(ctx, ys: torch.Tensor): - # out: (b, k, d, l) - B, C, H, W = ctx.shape - L = H * W - ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) - y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) - return y.view(B, -1, H, W) - - class CrossMerge(torch.autograd.Function): - @staticmethod - def forward(ctx, ys: torch.Tensor): - B, K, D, H, W = ys.shape - ctx.shape = (H, W) - ys = ys.view(B, K, D, -1) - ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) - y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) - return y - - @staticmethod - def backward(ctx, x: torch.Tensor): - # B, D, L = x.shape - # out: (b, k, d, l) - H, W = ctx.shape - B, C, L = x.shape - xs = x.new_empty((B, 4, C, L)) - xs[:, 0] = x - xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) - xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) - xs = xs.view(B, 4, C, H, W) - return xs, None, None - - def cross_selective_scan( - x: torch.Tensor=None, - x_proj_weight: torch.Tensor=None, - x_proj_bias: torch.Tensor=None, - dt_projs_weight: torch.Tensor=None, - dt_projs_bias: torch.Tensor=None, - A_logs: torch.Tensor=None, - Ds: torch.Tensor=None, - out_norm: torch.nn.Module=None, - softmax_version=False, - nrows = -1, - delta_softplus = True, - ): - B, D, H, W = x.shape - D, N = A_logs.shape - K, D, R = dt_projs_weight.shape - L = H * W - - if nrows < 1: - if D % 4 == 0: - nrows = 4 - elif D % 3 == 0: - nrows = 3 - elif D % 2 == 0: - nrows = 2 - else: - nrows = 1 - - xs = CrossScan.apply(x) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) - if x_proj_bias is not None: - x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) - - xs = xs.view(B, -1, L).to(torch.float) - dts = dts.contiguous().view(B, -1, L).to(torch.float) - As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) - Bs = Bs.contiguous().to(torch.float) - Cs = Cs.contiguous().to(torch.float) - Ds = Ds.to(torch.float) # (K * c) - delta_bias = dt_projs_bias.view(-1).to(torch.float) - - # to enable fvcore.nn.jit_analysis: inputs[i].debugName - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ys: torch.Tensor = selective_scan( - xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows, - ).view(B, K, -1, H, W) - - y = CrossMerge.apply(ys) - - if softmax_version: - y = y.softmax(y, dim=-1).to(x.dtype) - y = y.transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) - else: - y = y.transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = out_norm(y).to(x.dtype) - - return y - - -# fvcore flops ======================================= - -def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - assert not with_complex - # https://github.com/state-spaces/mamba/issues/110 - flops = 9 * B * L * D * N - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - - -def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, 
with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - import numpy as np - - # fvcore.nn.jit_handles - def get_flops_einsum(input_shapes, equation): - np_arrs = [np.zeros(s) for s in input_shapes] - optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] - for line in optim.split("\n"): - if "optimized flop" in line.lower(): - # divided by 2 because we count MAC (multiply-add counted as one flop) - flop = float(np.floor(float(line.split(":")[-1]) / 2)) - return flop - - - assert not with_complex - - flops = 0 # below code flops = 0 - if False: - ... - """ - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) - else: - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) - ys = [] - """ - - flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") - if with_Group: - flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") - else: - flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") - if False: - ... - """ - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) - else: - if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - """ - - in_for_flops = B * D * N - if with_Group: - in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") - else: - in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") - flops += L * in_for_flops - if False: - ... - """ - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) - else: - if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real * 2 - ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - """ - - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - if False: - ... 
- """ - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.to(dtype=dtype_in) - """ - - return flops - - -def print_jit_input_names(inputs): - # tensor.11, dt.1, A.1, B.1, C.1, D.1, z.1, None - try: - print("input params: ", end=" ", flush=True) - for i in range(10): - print(inputs[i].debugName(), end=" ", flush=True) - except Exception as e: - pass - print("", flush=True) - - -def selective_scan_flop_jit(inputs, outputs): - print_jit_input_names(inputs) - - # xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip) - assert inputs[0].debugName().startswith("xs") # (B, D, L) - assert inputs[1].debugName().startswith("dts") # (B, D, L) - assert inputs[2].debugName().startswith("As") # (D, N) - assert inputs[3].debugName().startswith("Bs") # (D, N) - assert inputs[4].debugName().startswith("Cs") # (D, N) - with_Group = len(inputs[3].type().sizes()) == 4 - with_D = inputs[5].debugName().startswith("Ds") - if not with_D: - with_z = len(inputs) > 5 and inputs[5].debugName().startswith("z") - else: - with_z = len(inputs) > 6 and inputs[6].debugName().startswith("z") - B, D, L = inputs[0].type().sizes() - N = inputs[2].type().sizes()[1] - flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group) - # flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group) - return flops - -# ===================================================== - -class PatchMerging2D(nn.Module): - def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) - self.norm = norm_layer(4 * dim) - - @staticmethod - def _patch_merging_pad(x: torch.Tensor): - H, W, _ = x.shape[-3:] - if (W % 2 != 0) or (H % 2 != 0): - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C - x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C - x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C - x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # ... 
H/2 W/2 4*C - return x - - def forward(self, x): - x = self._patch_merging_pad(x) - x = self.norm(x) - x = self.reduction(x) - - return x - - -DEV = False -class SS2D(nn.Module): - def __init__( - self, - # basic dims =========== - d_model=96, - d_state=16, - ssm_ratio=2, - dt_rank="auto", - # dwconv =============== - # d_conv=-1, # < 2 means no conv - d_conv=3, # < 2 means no conv - conv_bias=True, - # ====================== - dropout=0., - bias=False, - # dt init ============== - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - # ====================== - softmax_version=False, - # ====================== - **kwargs, - ): - if DEV: - d_conv = -1 - - factory_kwargs = {"device": None, "dtype": None} - super().__init__() - self.softmax_version = softmax_version - self.d_model = d_model - self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_state # 20240109 - self.d_conv = d_conv - self.expand = ssm_ratio - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - # conv ======================================= - if self.d_conv > 1: - self.conv2d = nn.Conv2d( - in_channels=self.d_inner, - out_channels=self.d_inner, - groups=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - padding=(d_conv - 1) // 2, - **factory_kwargs, - ) - self.act = nn.SiLU() - - # x proj; dt proj ============================ - self.K = 4 if not (self.forward_core == self.forward_corev1_share_ssm) else 1 - self.x_proj = [ - nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) - for _ in range(self.K) - ] - self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) - del self.x_proj - - self.dt_projs = [ - self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) - for _ in range(self.K) - ] - self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) - self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) - del self.dt_projs - - # A, D ======================================= - self.K2 = self.K if not (self.forward_core == self.forward_corev1_share_a) else 1 - self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=self.K2, merge=True) # (K * D, N) - self.Ds = self.D_init(self.d_inner, copies=self.K2, merge=True) # (K * D) - - # out proj ======================================= - if not self.softmax_version: - self.out_norm = nn.LayerNorm(self.d_inner) - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() - - @staticmethod - def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): - dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # dt_proj.bias._no_reinit = True - - return dt_proj - - @staticmethod - def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): - # S4D real initialization - A = repeat( - torch.arange(1, d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if copies > 0: - A_log = repeat(A_log, "d n -> r d n", r=copies) - if merge: - A_log = A_log.flatten(0, 1) - A_log = nn.Parameter(A_log) - A_log._no_weight_decay = True - return A_log - - @staticmethod - def D_init(d_inner, copies=-1, device=None, merge=True): - # D "skip" parameter - D = torch.ones(d_inner, device=device) - if copies > 0: - D = repeat(D, "n1 -> r n1", r=copies) - if merge: - D = D.flatten(0, 1) - D = nn.Parameter(D) # Keep in fp32 - D._no_weight_decay = True - return D - - def forward_corev0(self, x: torch.Tensor): - selective_scan = selective_scan_fn - - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.float().view(B, -1, L) # (b, k * d, l) - dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()) # (k * d, d_state) - Ds = self.Ds.float() # (k * d) - dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) - - # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 - - out_y = selective_scan( - xs, dts, - As, Bs, Cs, Ds, z=None, - delta_bias=dt_projs_bias, - delta_softplus=True, - return_last_state=False, - ).view(B, K, -1, L) - # assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 
0] + inv_y[:, 0] + wh_y + invwh_y - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = self.out_norm(y) - - return y - - def forward_corev0_seq(self, x: torch.Tensor): - selective_scan = selective_scan_fn - - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) - - xs = xs.float() # (b, k, d, l) - dts = dts.contiguous().float() # (b, k, d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state) # (k, d, d_state) - Ds = self.Ds.float().view(K, -1) # (k, d) - dt_projs_bias = self.dt_projs_bias.float().view(K, -1) # (k, d) - - # assert len(xs.shape) == 4 and len(dts.shape) == 4 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 3 and len(Ds.shape) == 2 and len(dt_projs_bias.shape) == 2 - - out_y = [] - for i in range(4): - yi = selective_scan( - xs[:, i], dts[:, i], - As[i], Bs[:, i], Cs[:, i], Ds[i], - delta_bias=dt_projs_bias[i], - delta_softplus=True, - ).view(B, -1, L) - out_y.append(yi) - out_y = torch.stack(out_y, dim=1) - assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = self.out_norm(y) - - return y - - def forward_corev1(self, x: torch.Tensor, float32=True): - # float32 should be true in training!!!! otherwise, the output of selective_scan would be inf... 
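The four-direction scan/merge that forward_corev0, forward_corev0_seq and forward_corev1 above build by hand can be summarized in a small standalone sketch (plain PyTorch; cross_scan/cross_merge are names invented here for illustration, and the real code feeds the scanned sequences through the selective-scan CUDA kernel instead of merging them unchanged):

import torch

def cross_scan(x):                               # x: (B, C, H, W)
    B, C, H, W = x.shape
    hw = x.flatten(2)                            # row-major scan order
    wh = x.transpose(2, 3).flatten(2)            # column-major scan order
    xs = torch.stack([hw, wh], dim=1)            # (B, 2, C, L)
    return torch.cat([xs, xs.flip(-1)], dim=1)   # (B, 4, C, L): plus both reversed orders

def cross_merge(ys, H, W):                       # ys: (B, 4, C, L), one result per direction
    B, K, C, L = ys.shape
    ys = ys[:, 0:2] + ys[:, 2:4].flip(-1)        # undo the two reversed scans
    wh = ys[:, 1].view(B, C, W, H).transpose(2, 3)   # undo the transpose of the WH scan
    return ys[:, 0] + wh.reshape(B, C, L)        # (B, C, L), back in row-major order

x = torch.randn(2, 8, 4, 6)
# with identity "scans", merging simply sums each position over the 4 directions
assert torch.allclose(cross_merge(cross_scan(x), 4, 6).view_as(x), 4 * x)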
- selective_scan = selective_scan_fn_v1 - - B, C, H, W = x.shape - L = H * W - - xs = torch.stack([x.flatten(2, 3), x.transpose(dim0=2, dim1=3).contiguous().flatten(2, 3)], dim=1) - xs = torch.cat([xs, torch.flip(xs, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.view(B, -1, L) # (b, k * d, l) - dts = dts.contiguous().view(B, -1, L) # (b, k * d, l) - As = -torch.exp(self.A_logs.to(torch.float)) # (k * d, d_state) - Ds = self.Ds.to(torch.float) # (k * d) - dt_projs_bias = self.dt_projs_bias.to(torch.float).view(-1) # (k * d) - - if float32: - ys: torch.Tensor = selective_scan( - xs.to(torch.float), - dts.to(torch.float), - As, - Bs.to(torch.float), - Cs.to(torch.float), - Ds, - delta_bias=dt_projs_bias, - delta_softplus=True, - ).view(B, 4, -1, L) - ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) - y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) - else: - out_y: torch.Tensor = selective_scan( - xs, dts, - As, Bs, Cs, Ds, - delta_bias=dt_projs_bias, - delta_softplus=True, - ).view(B, 4, -1, L) - # assert out_y.dtype == torch.float16 - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0].float() + inv_y[:, 0].float() + wh_y.float() + invwh_y.float() - - if self.softmax_version: - y = torch.softmax(y, dim=-1).to(x.dtype) - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - else: - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = self.out_norm(y).to(x.dtype) - - # if torch.isinf(y).any() or torch.isnan(y).any(): - # for item in [y, xs, dts, As, Bs, Cs, Ds]: - # print(torch.isinf(item).any(), torch.isnan(item).any(), item.max(), item.min()) - # import time; time.sleep(10000) - - return y - - def forward_corev1_share_ssm(self, x: torch.Tensor): - selective_scan = selective_scan_fn_v1 - - B, C, H, W = x.shape - L = H * W - - def cross_scan_2d(x): - # (B, C, H, W) => (B, K, C, H * W) with K = len([HW, WH, FHW, FWH]) - x_hwwh = torch.stack([x.flatten(2, 3), x.transpose(dim0=2, dim1=3).contiguous().flatten(2, 3)], dim=1) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - return xs - - x_dbl = torch.einsum("b d l, c d -> b c l", x.view(B, -1, L), self.x_proj_weight[0]) - # x_dbl = x_dbl + self.x_proj_bias.view(1, -1, 1) - dt, BC = torch.split(x_dbl, [self.dt_rank, 2 * self.d_state], dim=1) - dt = torch.einsum("b r l, d r -> b d l", dt, self.dt_projs_weight[0]) - x_dt_BC = torch.cat([x, dt.view(B, -1, H, W), BC.view(B, -1, H, W)], dim=1) # (b, -1, h, w) - - x_dt_BCs = cross_scan_2d(x_dt_BC) # (b, k, d, l) - xs, dts, Bs, Cs = torch.split(x_dt_BCs, [self.d_inner, self.d_inner, self.d_state, self.d_state], dim=2) - - xs = xs.contiguous().view(B, -1, L) # (b, k * d, l) - dts = dts.contiguous().view(B, -1, L) # (b, k * d, l) - As = -torch.exp(self.A_logs.float()).repeat(4, 1) # (k * d, d_state) - Ds = self.Ds.repeat(4) # (k * d) - dt_projs_bias = self.dt_projs_bias.view(-1).repeat(4) # (k * d) - - # assert len(xs.shape) == 3 and len(dts.shape) == 3 
and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 - - out_y = selective_scan( - xs, dts, - As, Bs, Cs, Ds, - delta_bias=dt_projs_bias, - delta_softplus=True, - ).view(B, 4, -1, L) - # assert out_y.dtype == torch.float16 - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0].float() + inv_y[:, 0].float() + wh_y.float() + invwh_y.float() - - if self.softmax_version: - y = torch.softmax(y, dim=-1).to(x.dtype) - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - else: - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = self.out_norm(y).to(x.dtype) - - return y - - def forward_corev1_share_a(self, x: torch.Tensor): - selective_scan = selective_scan_fn_v1 - - B, C, H, W = x.shape - L = H * W - - def cross_scan_2d(x, dim=1): - # (B, C, H, W) => (B, K, C, H * W) with K = len([HW, WH, FHW, FWH]) - x_hwwh = torch.stack([x.flatten(2, 3), x.transpose(dim0=2, dim1=3).contiguous().flatten(2, 3)], dim=dim) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=dim) # (b, k, d, l) - return xs - - K = 4 - xs = cross_scan_2d(x, dim=1) # (b, d, k, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - dts = dts + self.dt_projs_bias.to(xs.dtype).view(1, K, -1, 1) - - xs = xs.transpose(dim0=1, dim1=2).contiguous().view(B, -1, K * L) - dts = dts.transpose(dim0=1, dim1=2).contiguous().view(B, -1, K * L) - As = -torch.exp(self.A_logs.float()) # (D, N) - Ds = self.Ds.view(-1) # (D) - Bs = Bs.transpose(dim0=1, dim1=2).contiguous().view(B, 1, -1, K * L) - Cs = Cs.transpose(dim0=1, dim1=2).contiguous().view(B, 1, -1, K * L) - - # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 - # print(self.Ds.dtype, self.A_logs.dtype, self.dt_projs_bias.dtype, flush=True) # fp16, fp16, fp16 - - out_y = selective_scan( - xs, dts, - As, Bs, Cs, Ds, - delta_bias=None, - delta_softplus=True, - ).view(B, -1, 4, L) - # assert out_y.dtype == torch.float16 - - inv_y = torch.flip(out_y[:, :, 2:4], dims=[-1]).view(B, -1, 2, L) - wh_y = torch.transpose(out_y[:, :, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, :, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, :, 0].float() + inv_y[:, :, 0].float() + wh_y.float() + invwh_y.float() - - if self.softmax_version: - y = torch.softmax(y, dim=-1).to(x.dtype) - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - else: - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = self.out_norm(y).to(x.dtype) - - return y - - def forward_corev2(self, x: torch.Tensor, nrows=-1): - return cross_selective_scan( - x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, - self.A_logs, self.Ds, getattr(self, "out_norm", None), self.softmax_version, - nrows=nrows, - ) - - # forward_core = forward_core_share_ssm - # forward_core = 
forward_core_share_a - # forward_core = forward_corev1 - forward_core = forward_corev2 - # forward_core = forward_corev0 - - def forward(self, x: torch.Tensor, **kwargs): - xz = self.in_proj(x) - if self.d_conv > 1: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = x.permute(0, 3, 1, 2).contiguous() - x = self.act(self.conv2d(x)) # (b, d, h, w) - y = self.forward_core(x) - if self.softmax_version: - y = y * z - else: - y = y * F.silu(z) - else: - if self.softmax_version: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = F.silu(x) - else: - xz = F.silu(xz) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = x.permute(0, 3, 1, 2).contiguous() - y = self.forward_core(x) - y = y * z - out = self.dropout(self.out_proj(y)) - return out - - -class Permute(nn.Module): - def __init__(self, *args): - super().__init__() - self.args = args - - def forward(self, x: torch.Tensor): - return x.permute(*self.args) - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear - self.fc1 = Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class VSSBlock(nn.Module): - def __init__( - self, - hidden_dim: int = 0, - drop_path: float = 0, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - attn_drop_rate: float = 0, - d_state: int = 16, - dt_rank: Any = "auto", - ssm_ratio=2.0, - shared_ssm=False, - softmax_version=False, - use_checkpoint: bool = False, - mlp_ratio=4.0, - act_layer=nn.GELU, - drop: float = 0.0, - **kwargs, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm = norm_layer(hidden_dim) - self.op = SS2D( - d_model=hidden_dim, - dropout=attn_drop_rate, - d_state=d_state, - ssm_ratio=ssm_ratio, - dt_rank=dt_rank, - shared_ssm=shared_ssm, - softmax_version=softmax_version, - **kwargs - ) - self.drop_path = DropPath(drop_path) - - self.mlp_branch = mlp_ratio > 0 - if self.mlp_branch: - self.norm2 = norm_layer(hidden_dim) - mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, channels_first=False) - - def _forward(self, input: torch.Tensor): - x = input + self.drop_path(self.op(self.norm(input))) - if self.mlp_branch: - x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN - return x - - def forward(self, input: torch.Tensor): - if self.use_checkpoint: - return checkpoint.checkpoint(self._forward, input) - else: - return self._forward(input) - - -class VSSM(nn.Module): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0., - shared_ssm=False, - softmax_version=False, - # ========================= - drop_rate=0., - drop_path_rate=0.1, - mlp_ratio=4.0, - patch_norm=True, - norm_layer=nn.LayerNorm, - downsample_version: str = "v2", - use_checkpoint=False, - **kwargs, - ): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, 
int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.embed_dim = dims[0] - self.num_features = dims[-1] - self.dims = dims - - self.patch_embed = nn.Sequential( - nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(self.embed_dim) if patch_norm else nn.Identity()), - ) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - - if downsample_version == "v2": - downsample = self._make_downsample( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - else: - downsample = PatchMerging2D( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - depth = depths[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=downsample, - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - attn_drop_rate=attn_drop_rate, - shared_ssm=shared_ssm, - softmax_version=softmax_version, - mlp_ratio=mlp_ratio, - drop_rate=drop_rate, - )) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @staticmethod - def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm): - return nn.Sequential( - Permute(0, 3, 1, 2), - nn.Conv2d(dim, out_dim, kernel_size=2, stride=2), - Permute(0, 2, 3, 1), - norm_layer(out_dim), - ) - - @staticmethod - def _make_layer( - dim=96, - depth=2, - drop_path=[0.1, 0.1], - use_checkpoint=False, - norm_layer=nn.LayerNorm, - downsample=nn.Identity(), - # =========================== - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0.0, - shared_ssm=False, - softmax_version=False, - # =========================== - mlp_ratio=4.0, - drop_rate=0.0, - **kwargs, - ): - assert depth == len(drop_path) - blocks = [] - for d in range(depth): - blocks.append(VSSBlock( - hidden_dim=dim, - drop_path=drop_path[d], - norm_layer=norm_layer, - attn_drop_rate=attn_drop_rate, - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - shared_ssm=shared_ssm, - softmax_version=softmax_version, - use_checkpoint=use_checkpoint, - mlp_ratio=mlp_ratio, - act_layer=nn.GELU, - drop=drop_rate, - **kwargs, - )) - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier(x) - return x - - def flops(self, shape=(3, 224, 224)): - # shape = self.__input_shape__[1:] - supported_ops={ - "aten::silu": None, # as relu is in _IGNORED_OPS - "aten::neg": None, # as relu is in _IGNORED_OPS - "aten::exp": None, # as relu is in _IGNORED_OPS - "aten::flip": None, # as permute 
is in _IGNORED_OPS - "prim::PythonOp.CrossScan": None, - "prim::PythonOp.CrossMerge": None, - "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, - "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, - } - - model = copy.deepcopy(self) - model.cuda().eval() - - input = torch.randn((1, *shape), device=next(model.parameters()).device) - params = parameter_count(model)[""] - Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) - - del model, input - return sum(Gflops.values()) * 1e9 - return f"params {params} GFLOPs {sum(Gflops.values())}" - - # used to load ckpt from previous training code - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - - def check_name(src, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - return True - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - return True - return False - - def change_name(src, dst, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - state_dict[prefix + dst] = state_dict[prefix + src] - state_dict.pop(prefix + src) - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - new_k = prefix + dst + k[len(key):] - state_dict[new_k] = state_dict[k] - state_dict.pop(k) - - change_name("patch_embed.proj", "patch_embed.0") - change_name("patch_embed.norm", "patch_embed.2") - for i in range(100): - for j in range(100): - change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm") - change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op") - change_name("norm", "classifier.norm") - change_name("head", "classifier.head") - - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - -# compatible with openmmlab -class Backbone_VSSM(VSSM): - def __init__(self, patch_size=4, in_chans=3, num_classes=1000, - depths=[2, 2, 9, 2], dims=[96, 192, 384, 768], - d_state=16, ssm_ratio=2.0, attn_drop_rate=0., - drop_rate=0., drop_path_rate=0.1, mlp_ratio=4.0, - patch_norm=True, norm_layer=nn.LayerNorm, - downsample_version: str = "v2", - use_checkpoint=False, - out_indices=(0, 1, 2, 3), pretrained=None, - **kwargs, - ): - super().__init__(patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, - depths=depths, dims=dims, - d_state=d_state, ssm_ratio=ssm_ratio, attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, drop_path_rate=drop_path_rate, mlp_ratio=mlp_ratio, - patch_norm=patch_norm, norm_layer=norm_layer, - downsample_version=downsample_version, - use_checkpoint=use_checkpoint, - **kwargs) - - self.out_indices = out_indices - for i in out_indices: - layer = norm_layer(self.dims[i]) - layer_name = f'outnorm{i}' - self.add_module(layer_name, layer) - - del self.classifier - self.load_pretrained(pretrained) - - def load_pretrained(self, ckpt=None, key="model"): - if ckpt is None: - return - - try: - _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu")) - print(f"Successfully load ckpt {ckpt}") - incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False) - print(incompatibleKeys) - except Exception as e: - print(f"Failed loading checkpoint form {ckpt}: {e}") - - def forward(self, x): - def layer_forward(l, x): - x = l.blocks(x) - y = l.downsample(x) - return x, y - - x = self.patch_embed(x) - outs = [] - for i, layer in 
enumerate(self.layers): - o, x = layer_forward(layer, x) # (B, H, W, C) - if i in self.out_indices: - norm_layer = getattr(self, f'outnorm{i}') - out = norm_layer(o) - out = out.permute(0, 3, 1, 2).contiguous() - outs.append(out) - - if len(self.out_indices) == 0: - return x - - return outs - - -# ================================================== -def check_vssm_equals_vmambadp(): - try: - from _ignore.vmamba.vmamba_bak1 import VMamba2Dp - from _ignore.vmamba.vmamba_pub import VSSM - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - - # test 1 True ================================= - torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VMamba2Dp(depths=[2,2,6,2]).half().cuda() - newvss = VSSM(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VMamba2Dp(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM(depths=[2,2,6,2]).cuda() - - miss_align = 0 - for k, v in oldvss.state_dict().items(): - same = (oldvss.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - - -def check_vssm1_equals_vssm(ss2dfwd=SS2D.forward_corev0): - try: - from _ignore.vmamba.vmamba_pub import VSSM as VSSM0 - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - orifwdcore = SS2D.forward_core - SS2D.forward_core = ss2dfwd - - class VSSM_(VSSM): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0., - # ========================= - drop_rate=0., - drop_path_rate=0.1, - mlp_ratio=4.0, - patch_norm=True, - norm_layer=nn.LayerNorm, - downsample_version: str = "v2", - use_checkpoint=False, - **kwargs, - ): - nn.Module.__init__(self) - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.embed_dim = dims[0] - self.num_features = dims[-1] - self.dims = dims - - self.patch_embed = nn.Sequential( - nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(self.embed_dim) if patch_norm else nn.Identity()), - ) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - - # if downsample_version == "v2": - # downsample = self._make_downsample( - # self.dims[i_layer], - # self.dims[i_layer + 1], - # norm_layer=norm_layer, - # ) if (i_layer < self.num_layers - 1) else nn.Identity() - # else: - # downsample = 
PatchMerging2D( - # self.dims[i_layer], - # self.dims[i_layer + 1], - # norm_layer=norm_layer, - # ) if (i_layer < self.num_layers - 1) else nn.Identity() - - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - depth = depths[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=(i_layer < self.num_layers - 1), - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - attn_drop_rate=attn_drop_rate, - mlp_ratio=mlp_ratio, - drop_rate=drop_rate, - )) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - self.apply(self._init_weights) - - def _make_layer( - self, - dim=96, - depth=2, - drop_path=[0.1, 0.1], - use_checkpoint=False, - norm_layer=nn.LayerNorm, - downsample=nn.Identity(), - # =========================== - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0.0, - # =========================== - mlp_ratio=4.0, - drop_rate=0.0, - **kwargs, - ): - assert depth == len(drop_path) - blocks = [] - for d in range(depth): - blocks.append(VSSBlock( - hidden_dim=dim, - drop_path=drop_path[d], - norm_layer=norm_layer, - attn_drop_rate=attn_drop_rate, - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - use_checkpoint=use_checkpoint, - mlp_ratio=mlp_ratio, - act_layer=nn.GELU, - drop=drop_rate, - **kwargs, - )) - # blocks[d].op = SS2D0(blocks[d].op.d_model) - - - if True: # is this really applied? Yes, but been overriden later in VSSM! - def _init_weights(module: nn.Module): - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - p = p.clone().detach_() # fake init, just to keep the seed .... 
- nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - layer = nn.Sequential(*copy.deepcopy(blocks)) - layer.apply(_init_weights) - - downsample = PatchMerging2D(dim, 2*dim, norm_layer=norm_layer) if downsample else nn.Identity() - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward_backbone(self, x): - x = self.patch_embed(x) - for l in self.layers: - x = l(x) - return x - - def forward1(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier.norm(x) - # here: whether has contiguous would differ - x = self.classifier.avgpool(x.permute(0, 3, 1, 2).contiguous()).flatten(1) - x = self.classifier.head(x) - return x - - VSSM1 = partial(VSSM_, downsample_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, dt_rank="auto", d_state=16) - - # test 1 True ================================= - torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VSSM0(depths=[2,2,6,2]).half().cuda() - newvss = VSSM1(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward1(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y3 = newvss.forward(input) - print((y1 -y3).abs().sum()) # tensor(0.0008, device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VSSM0(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM1(depths=[2,2,6,2]).cuda() - - miss_align = 0 - oldvss2new = copy.deepcopy(newvss) - oldvss2new.load_state_dict(oldvss.state_dict()) - for k, v in oldvss2new.state_dict().items(): - same = (oldvss2new.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - SS2D.forward_core = orifwdcore - - -def check_profile(): - vss = VSSM(depths=[1], dims=1024).half().cuda() - input = torch.randn((128, 3, 56, 56)).half().cuda() - torch.cuda.manual_seed(0) - - self = vss - blk = self.layers[0].blocks[0] - ln_1 = blk.ln_1 - self_attention = blk.self_attention - selfa = self_attention - drop_path = blk.drop_path - input = self.patch_embed(input).detach() - - def trace_handler(prof: torch.profiler.profile): - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - # print(prof.export_chrome_trace("./tracev1.json")) - - with torch.cuda.amp.autocast(): - # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=True, with_stack=True) as prof: - with torch.profiler.profile( - with_modules=True, - with_stack=True, - profile_memory=True, - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - - # In this example with wait=1, warmup=1, active=2, repeat=1, 
- # profiler will skip the first step/iteration, - # start warming up on the second, record - # the third and the forth iterations, - # after which the trace will become available - # and on_trace_ready (when set) is called; - # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler - # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') - # used when outputting for tensorboard - ) as prof: - for iter in range(1000): - x = input - # with torch.autograd.profiler.record_function("patch_embed"): - # x = self.patch_embed(x) - - B, H, W, C = x.shape - ori = x - - with torch.autograd.profiler.record_function("VSSBlock.ln_1"): - x = ln_1(x) - - with torch.autograd.profiler.record_function("SS2D.inproj"): - xz = selfa.in_proj(x) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = x.permute(0, 3, 1, 2).contiguous() - - with torch.autograd.profiler.record_function("SS2D.dwconv2d"): - x = selfa.act(selfa.conv2d(x)) # (b, d, h, w) - # x = self.act(x) # (b, d, h, w) - - with torch.autograd.profiler.record_function("SS2D.foreward_core"): - # y = selfa.forward_corev2(x) - # y = selfa.forward_corev3(x) - y = selfa.forward_corev1(x) - # y = selfa.forward_corev1(x) - - with torch.autograd.profiler.record_function("SS2D.transpose"): - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = selfa.out_norm(y) - y = y * F.silu(z) - - with torch.autograd.profiler.record_function("SS2D.out_proj"): - out = selfa.out_proj(y) - if selfa.dropout is not None: - out = selfa.dropout(out) - - with torch.autograd.profiler.record_function("SS2D.out"): - x = ori + drop_path(out) - - with torch.autograd.profiler.record_function("backward"): - x.sum().backward() - - prof.step() - - -def load22kto1k(): - if False: - # delete relative_position_index since we always re-init it - relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] - for k in relative_position_index_keys: - del state_dict[k] - - # delete relative_coords_table since we always re-init it - relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] - for k in relative_position_index_keys: - del state_dict[k] - - # delete attn_mask since we always re-init it - attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] - for k in attn_mask_keys: - del state_dict[k] - - # bicubic interpolate relative_position_bias_table if not match - relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] - for k in relative_position_bias_table_keys: - relative_position_bias_table_pretrained = state_dict[k] - relative_position_bias_table_current = model.state_dict()[k] - L1, nH1 = relative_position_bias_table_pretrained.size() - L2, nH2 = relative_position_bias_table_current.size() - if nH1 != nH2: - logger.warning(f"Error in loading {k}, passing......") - else: - if L1 != L2: - # bicubic interpolate relative_position_bias_table if not match - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), - mode='bicubic') - state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) - - # bicubic interpolate absolute_pos_embed if not match - absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] - 
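A minimal sketch of the bicubic resize used just above when the pretrained relative_position_bias_table length does not match the current model (the same pattern is applied to absolute_pos_embed below); S1, S2 and nH are made-up example sizes and the table is a random stand-in:

import torch
import torch.nn.functional as F

S1, S2, nH = 13, 23, 6                           # source grid, target grid, number of heads
table_pretrained = torch.randn(S1 * S1, nH)      # (L1, nH) with L1 = S1 * S1

t = table_pretrained.permute(1, 0).reshape(1, nH, S1, S1)        # (1, nH, S1, S1)
t = F.interpolate(t, size=(S2, S2), mode="bicubic", align_corners=False)
table_resized = t.reshape(nH, S2 * S2).permute(1, 0)             # (L2, nH)
print(table_resized.shape)                                        # torch.Size([529, 6])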
for k in absolute_pos_embed_keys: - # dpe - absolute_pos_embed_pretrained = state_dict[k] - absolute_pos_embed_current = model.state_dict()[k] - _, L1, C1 = absolute_pos_embed_pretrained.size() - _, L2, C2 = absolute_pos_embed_current.size() - if C1 != C1: - logger.warning(f"Error in loading {k}, passing......") - else: - if L1 != L2: - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) - absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( - absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') - absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) - absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) - state_dict[k] = absolute_pos_embed_pretrained_resized - - # check classifier, if not match, then re-init classifier to zero - head_bias_pretrained = state_dict['head.bias'] - Nc1 = head_bias_pretrained.shape[0] - Nc2 = model.head.bias.shape[0] - if (Nc1 != Nc2): - if Nc1 == 21841 and Nc2 == 1000: - logger.info("loading ImageNet-22K weight to ImageNet-1K ......") - map22kto1k_path = f'data/map22kto1k.txt' - with open(map22kto1k_path) as f: - map22kto1k = f.readlines() - map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] - state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] - state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] - else: - torch.nn.init.constant_(model.head.bias, 0.) - torch.nn.init.constant_(model.head.weight, 0.) - del state_dict['head.weight'] - del state_dict['head.bias'] - logger.warning(f"Error in loading classifier head, re-init classifier head to 0") - - - -if __name__ == "__main__": - check_vssm_equals_vmambadp() - check_vssm1_equals_vssm(ss2dfwd=SS2D.forward_corev0) - check_vssm1_equals_vssm(ss2dfwd=SS2D.forward_corev0_seq) - check_vssm1_equals_vssm(ss2dfwd=SS2D.forward_core) - check_vssm1_equals_vssm(ss2dfwd=lambda *args, **kwargs: SS2D.forward_corev1(*args, **kwargs).float()) - - - diff --git a/classification/lib/models/vmamba_mobile.py b/classification/lib/models/vmamba_mobile.py deleted file mode 100644 index 50d0f74..0000000 --- a/classification/lib/models/vmamba_mobile.py +++ /dev/null @@ -1,1565 +0,0 @@ -import os -import time -import math -import copy -from functools import partial -from typing import Optional, Callable, Any -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from einops import rearrange, repeat -from timm.models.layers import DropPath, trunc_normal_ -from timm.models.registry import register_model -from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count -DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" - - -try: - "sscore acts the same as mamba_ssm" - SSMODE = "sscore" - if torch.__version__ > '2.0.0': - from selective_scan_vmamba_pt202 import selective_scan_cuda_core - else: - from selective_scan_vmamba import selective_scan_cuda_core -except Exception as e: - print(e, flush=True) - "you should install mamba_ssm to use this" - SSMODE = "mamba_ssm" - import selective_scan_cuda - # from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref - - -# fvcore flops ======================================= - -def flops_selective_scan_fn(B=1, L=256, D=768, N=16, 
with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - assert not with_complex - # https://github.com/state-spaces/mamba/issues/110 - flops = 9 * B * L * D * N - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - - -# this is only for selective_scan_ref... -def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - import numpy as np - - # fvcore.nn.jit_handles - def get_flops_einsum(input_shapes, equation): - np_arrs = [np.zeros(s) for s in input_shapes] - optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] - for line in optim.split("\n"): - if "optimized flop" in line.lower(): - # divided by 2 because we count MAC (multiply-add counted as one flop) - flop = float(np.floor(float(line.split(":")[-1]) / 2)) - return flop - - - assert not with_complex - - flops = 0 # below code flops = 0 - - flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") - if with_Group: - flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") - else: - flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") - - in_for_flops = B * D * N - if with_Group: - in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") - else: - in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") - flops += L * in_for_flops - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - - -def print_jit_input_names(inputs): - print("input params: ", end=" ", flush=True) - try: - for i in range(10): - print(inputs[i].debugName(), end=" ", flush=True) - except Exception as e: - pass - print("", flush=True) - - -# cross selective scan =============================== - -class SelectiveScan(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1): - assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile - assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" - ctx.delta_softplus = delta_softplus - ctx.nrows = nrows - # all in float - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if B.dim() == 3: - B = B.unsqueeze(dim=1) - ctx.squeeze_B = True - if C.dim() == 3: - C = C.unsqueeze(dim=1) - ctx.squeeze_C = True - - if SSMODE == "mamba_ssm": - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) - else: - out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, dout, *args): - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - - if 
SSMODE == "mamba_ssm": - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, - False # option to recompute out_z, not used here - ) - else: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( - u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 - # u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.nrows, - ) - - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None) - - -class CrossScan(torch.autograd.Function): - # [B, C, H, W] -> [B, 4, C, H * W] (original) - # [B, C, H, W] -> [B, 4, C, H/w * W/w] - @staticmethod - def forward(ctx, x: torch.Tensor, window_size=2, index=0): # [B, C, H, W] -> [B, 4, H/w * W/w] - B, C, H, W = x.shape - ctx.shape = (B, C, H, W) - ctx.window_size = window_size - ctx.index = index - - if (W % window_size != 0) or (H % window_size != 0): - x = F.pad(x, (0, window_size - H % window_size, 0, window_size - W % window_size)) - H, W = x.shape[2:] - H = H // window_size - W = W // window_size - - xs = x.new_empty((B, 4, C, H, W)) - # print(x) - - pos = [[0, 0], [1, 0], [0, 1], [1, 1]] - - if index == 1: - pos = pos[-1:] + pos[:3] - elif index == 2: - pos = pos[-2:] + pos[:2] - elif index == 3: - pos = pos[-3:] + pos[:1] - - ctx.pos = pos - - xs[:, 0] = x[:, :, pos[0][0]::window_size, pos[0][1]::window_size] - xs[:, 1] = x[:, :, pos[1][0]::window_size, pos[1][1]::window_size].transpose(dim0=2, dim1=3) - xs[:, 2] = x[:, :, pos[2][0]::window_size, pos[2][1]::window_size] - xs[:, 3] = x[:, :, pos[3][0]::window_size, pos[3][1]::window_size].transpose(dim0=2, dim1=3) - - xs = xs.view(B, 4, C, -1) - - return xs - - @staticmethod - def backward(ctx, grad_xs: torch.Tensor): # [B, 4, H/w * W/w] -> [B, C, H, W] - - B, C, H, W = ctx.shape - window_size = ctx.window_size - index = ctx.index - - newH, newW = math.ceil(H / window_size), math.ceil(W / window_size) - grad_x = grad_xs.new_empty((B, C, newH * window_size, newW * window_size)) - - - # H, W = H // window_size, W // window_size - grad_xs = grad_xs.view(B, 4, C, newH, newW) - - - # 原Vmamba扫描四次时,每个位点的grad被累加了四次,因此输出为每个位点乘4. 
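# Translation of the note above and the one that follows: in the original VMamba,
# which scans every position four times, each position's gradient is accumulated
# four times (so the merged result is effectively x4 per position); with the
# even/odd window pairing used here, each position belongs to exactly one of the
# four scans, so the gradients stay independent and nothing is accumulated.
# A tiny sketch of that partition, assuming an even-sized square input:
import torch
x = torch.arange(16.).view(1, 1, 4, 4)       # (B, C, H, W)
ws = 2                                       # window_size
subs = [x[:, :, 0::ws, 0::ws], x[:, :, 1::ws, 0::ws],
        x[:, :, 0::ws, 1::ws], x[:, :, 1::ws, 1::ws]]
# the four strided sub-grids are disjoint and together cover every position once
assert sum(s.numel() for s in subs) == x.numel()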
- # 这里奇偶搭配每个位点的grad是独立 无累加 - xs = [grad_xs[:, 0], grad_xs[:, 1].transpose(dim0=2, dim1=3), grad_xs[:, 2], grad_xs[:, 3].transpose(dim0=2, dim1=3)] - pos = ctx.pos - - grad_x[:, :, pos[0][0]::window_size, pos[0][1]::window_size] = xs[0] - grad_x[:, :, pos[1][0]::window_size, pos[1][1]::window_size] = xs[1] - grad_x[:, :, pos[2][0]::window_size, pos[2][1]::window_size] = xs[2] - grad_x[:, :, pos[3][0]::window_size, pos[3][1]::window_size] = xs[3] - - if H != grad_x.shape[-2] or W != grad_x.shape[-1]: - grad_x = grad_x[:, :, :H, :W] - - return grad_x, None, None # 分别对应forward的传入参数 x, window_size, (ctx.window_size参与计算了) - - - -class CrossMerge(torch.autograd.Function): # [B, 4, C, H/w * W/w] -> [B, C, H*W] - @staticmethod - def forward(ctx, ys: torch.Tensor, ori_h, ori_w, window_size=2, index=0): - B, K, C, H, W = ys.shape # B, 4, C, H/w * W/w - ctx.shape = (H, W) - ctx.ori_h = ori_h - ctx.ori_w = ori_w - ctx.window_size = window_size - ctx.index = index - - - y = ys.new_empty((B, C, H * window_size, W * window_size)) - - pos = [[0, 0], [1, 0], [0, 1], [1, 1]] - - if index == 1: - pos = pos[-1:] + pos[:3] - elif index == 2: - pos = pos[-2:] + pos[:2] - elif index == 3: - pos = pos[-3:] + pos[:1] - ctx.pos = pos - - y0, y1, y2, y3 = ys[:, 0], ys[:, 1].transpose(dim0=2, dim1=3), ys[:, 2], ys[:, 3].transpose(dim0=2, dim1=3) - - y[:, :, pos[0][0]::window_size, pos[0][1]::window_size] = y0 - y[:, :, pos[1][0]::window_size, pos[1][1]::window_size] = y1 - y[:, :, pos[2][0]::window_size, pos[2][1]::window_size] = y2 - y[:, :, pos[3][0]::window_size, pos[3][1]::window_size] = y3 - - if ori_h != H or ori_w != W: - y = y[:, :, :ori_h, :ori_w].contiguous() - - y = y.view(B, C, -1) - return y - - @staticmethod - def backward(ctx, grad_x: torch.Tensor): # [B, C, H*W] -> [B, 4, C, H/w * W/w] - # B, C, L = x.shape - # out: (b, k, c, l) - H, W = ctx.shape - B, C, L = grad_x.shape - window_size = ctx.window_size - index = ctx.index - - grad_xs = grad_x.new_empty((B, 4, C, H, W)) - - grad_x = grad_x.view(B, C, ctx.ori_h, ctx.ori_w) - - if (ctx.ori_w % window_size != 0) or (ctx.ori_h % window_size != 0): - grad_x = F.pad(grad_x, (0, window_size - ctx.ori_h % window_size, 0, window_size - ctx.ori_w % window_size)) - - pos = ctx.pos - # print(x.shape, xs.shape, xs[:, 0].shape, x[:, :, ::window_size, ::window_size].shape) # torch.Size([16, 64, 56, 56]) torch.Size([16, 4, 64, 56, 56]) torch.Size([16, 64, 56, 56]) torch.Size([16, 64, 28, 28]) - - grad_xs[:, 0] = grad_x[:, :, pos[0][0]::window_size, pos[0][1]::window_size] - grad_xs[:, 1] = grad_x[:, :, pos[1][0]::window_size, pos[1][1]::window_size].transpose(dim0=2, dim1=3) - grad_xs[:, 2] = grad_x[:, :, pos[2][0]::window_size, pos[2][1]::window_size] - grad_xs[:, 3] = grad_x[:, :, pos[3][0]::window_size, pos[3][1]::window_size].transpose(dim0=2, dim1=3) - - return grad_xs, None, None, None, None - - -def cross_selective_scan( - x: torch.Tensor=None, - x_proj_weight: torch.Tensor=None, - x_proj_bias: torch.Tensor=None, - dt_projs_weight: torch.Tensor=None, - dt_projs_bias: torch.Tensor=None, - A_logs: torch.Tensor=None, - Ds: torch.Tensor=None, - out_norm: torch.nn.Module=None, - nrows = -1, - delta_softplus = True, - to_dtype=True, - window_size = 2, - index = 0, -): - # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... 
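# (The Chinese comment on CrossScan.backward's return above says, roughly: "these
# correspond to forward's inputs x, window_size; ctx.window_size took part in the
# computation".)
# CrossMerge inverts CrossScan for matching window_size and index; for a square
# input whose side is divisible by window_size the round trip reproduces x exactly.
# A minimal sanity-check sketch using the two classes defined above (index=0):
import torch
B, C, H, W = 2, 8, 8, 8
ws, idx = 2, 0
x = torch.randn(B, C, H, W)
xs = CrossScan.apply(x, ws, idx)             # (B, 4, C, (H//ws) * (W//ws))
ys = xs.view(B, 4, C, H // ws, W // ws)
y = CrossMerge.apply(ys, H, W, ws, idx)      # (B, C, H * W)
assert torch.equal(y.view(B, C, H, W), x)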
- - B, D, H, W = x.shape - D, N = A_logs.shape - K, D, R = dt_projs_weight.shape - L = H * W - - if nrows < 1: - if D % 4 == 0: - nrows = 4 - elif D % 3 == 0: - nrows = 3 - elif D % 2 == 0: - nrows = 2 - else: - nrows = 1 - # H * W - ori_h, ori_w = H, W - xs = CrossScan.apply(x, window_size, index) # [B, C, H*W] -> [B, 4, C, H//w * W//w] - # H//w * W//w - H = math.ceil(H / window_size) - W = math.ceil(W / window_size) - - L = H * W - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) # l fixed - - if x_proj_bias is not None: - x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) - - xs = xs.view(B, -1, L).to(torch.float) - dts = dts.contiguous().view(B, -1, L).to(torch.float) - As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) - Bs = Bs.contiguous().to(torch.float) - Cs = Cs.contiguous().to(torch.float) - Ds = Ds.to(torch.float) # (K * c) - delta_bias = dt_projs_bias.view(-1).to(torch.float) - - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ys: torch.Tensor = selective_scan( - xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows, - ).view(B, K, -1, H, W) - - y: torch.Tensor = CrossMerge.apply(ys, ori_h, ori_w, window_size, index) # [B, 4, C, H//w * W//w] -> [B, C, H*W] - - H = ori_h - W = ori_w - L = H * W - - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - - -def selective_scan_flop_jit(inputs, outputs): - print_jit_input_names(inputs) - B, D, L = inputs[0].type().sizes() - N = inputs[2].type().sizes()[1] - flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True) - return flops - - -# ===================================================== - -class PatchMerging2D(nn.Module): - def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) - self.norm = norm_layer(4 * dim) - - @staticmethod - def _patch_merging_pad(x: torch.Tensor): - H, W, _ = x.shape[-3:] - if (W % 2 != 0) or (H % 2 != 0): - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C - x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C - x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C - x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C - return x - - def forward(self, x): - x = self._patch_merging_pad(x) - x = self.norm(x) - x = self.reduction(x) - - return x - - -scan_index = 0 - - -class SS2D(nn.Module): - def __init__( - self, - # basic dims =========== - d_model=96, - d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - dt_rank="auto", - act_layer=nn.SiLU, - # dwconv =============== - d_conv=3, # < 2 means no conv - conv_bias=True, - # ====================== - dropout=0.0, - bias=False, - # dt init ============== - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - simple_init=False, - # ====================== - forward_type="v2", - # ====================== - window_size=2, - **kwargs, - ): - """ - ssm_rank_ratio would be used in the future... 
- """ - factory_kwargs = {"device": None, "dtype": None} - super().__init__() - d_expand = int(ssm_ratio * d_model) - d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand - self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank - self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state # 20240109 - self.d_conv = d_conv - - self.window_size = window_size - - global scan_index - self.index = scan_index % 4 - scan_index += 1 - - # disable z act ====================================== - self.disable_z_act = forward_type[-len("nozact"):] == "nozact" - if self.disable_z_act: - forward_type = forward_type[:-len("nozact")] - - # softmax | sigmoid | norm =========================== - if forward_type[-len("softmax"):] == "softmax": - forward_type = forward_type[:-len("softmax")] - self.out_norm = nn.Softmax(dim=1) - elif forward_type[-len("sigmoid"):] == "sigmoid": - forward_type = forward_type[:-len("sigmoid")] - self.out_norm = nn.Sigmoid() - else: - self.out_norm = nn.LayerNorm(d_inner) - - # forward_type ======================================= - self.forward_core = dict( - v0=self.forward_corev0, - v0_seq=self.forward_corev0_seq, - v1=self.forward_corev2, - v2=self.forward_corev2, - share_ssm=self.forward_corev0_share_ssm, - share_a=self.forward_corev0_share_a, - ).get(forward_type, self.forward_corev2) - self.K = 4 if forward_type not in ["share_ssm"] else 1 - self.K2 = self.K if forward_type not in ["share_a"] else 1 - - # in proj ======================================= - self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs) - self.act: nn.Module = act_layer() - - # conv ======================================= - if self.d_conv > 1: - self.conv2d = nn.Conv2d( - in_channels=d_expand, - out_channels=d_expand, - groups=d_expand, - bias=conv_bias, - kernel_size=d_conv, - padding=(d_conv - 1) // 2, - **factory_kwargs, - ) - - # rank ratio ===================================== - self.ssm_low_rank = False - if d_inner < d_expand: - self.ssm_low_rank = True - self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs) - self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs) - - # x proj ============================ - self.x_proj = [ - nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) - for _ in range(self.K) - ] - self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) - del self.x_proj - - # dt proj ============================ - self.dt_projs = [ - self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) - for _ in range(self.K) - ] - self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) - self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) - del self.dt_projs - - # A, D ======================================= - self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True) # (K * D, N) - self.Ds = self.D_init(d_inner, copies=self.K2, merge=True) # (K * D) - - # out proj ======================================= - self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs) - self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() - - if simple_init: - # simple init dt_projs, A_logs, Ds - self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner))) - self.A_logs = nn.Parameter(torch.randn((self.K2 * d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 - self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank))) - self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner))) - - @staticmethod - def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): - dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # dt_proj.bias._no_reinit = True - - return dt_proj - - @staticmethod - def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): - # S4D real initialization - A = repeat( - torch.arange(1, d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if copies > 0: - A_log = repeat(A_log, "d n -> r d n", r=copies) - if merge: - A_log = A_log.flatten(0, 1) - A_log = nn.Parameter(A_log) - A_log._no_weight_decay = True - return A_log - - @staticmethod - def D_init(d_inner, copies=-1, device=None, merge=True): - # D "skip" parameter - D = torch.ones(d_inner, device=device) - if copies > 0: - D = repeat(D, "n1 -> r n1", r=copies) - if merge: - D = D.flatten(0, 1) - D = nn.Parameter(D) # Keep in fp32 - D._no_weight_decay = True - return D - - # only used to run previous version - def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False): - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.float().view(B, -1, L) # (b, k * d, l) - dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()) # (k * d, d_state) - Ds = self.Ds.float() # (k * d) - dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 
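# dt_init above draws dt log-uniformly from [dt_min, dt_max] and stores its
# inverse-softplus as the projection bias, so F.softplus(bias) recovers a value
# in the target range on the first forward pass. A small standalone check of
# that identity (illustrative only, not part of the original file):
import math
import torch
import torch.nn.functional as F
dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4
dt = torch.exp(torch.rand(256) * (math.log(dt_max) - math.log(dt_min))
               + math.log(dt_min)).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))   # inverse of softplus
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)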
- - # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 - - out_y = selective_scan( - xs, dts, - As, Bs, Cs, Ds, - delta_bias=dt_projs_bias, - delta_softplus=True, - ).view(B, K, -1, L) - # assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = self.out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - - # only has speed difference with v0 - def forward_corev0_seq(self, x: torch.Tensor, to_dtype=False, channel_first=False): - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.float() # (b, k, d, l) - dts = dts.contiguous().float() # (b, k, d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state) # (k, d, d_state) - Ds = self.Ds.float().view(K, -1) # (k, d) - dt_projs_bias = self.dt_projs_bias.float().view(K, -1) # (k, d) - - # assert len(xs.shape) == 4 and len(dts.shape) == 4 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 3 and len(Ds.shape) == 2 and len(dt_projs_bias.shape) == 2 - - out_y = [] - for i in range(4): - yi = selective_scan( - xs[:, i], dts[:, i], - As[i], Bs[:, i], Cs[:, i], Ds[i], - delta_bias=dt_projs_bias[i], - delta_softplus=True, - ).view(B, -1, L) - out_y.append(yi) - out_y = torch.stack(out_y, dim=1) - assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = self.out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - - - def forward_corev0_share_ssm(self, x: torch.Tensor, channel_first=False): - """ - we may conduct this ablation later, but not with v0. - """ - ... - - def forward_corev0_share_a(self, x: torch.Tensor, channel_first=False): - """ - we may conduct this ablation later, but not with v0. - """ - ... 
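# forward_corev0 / forward_corev0_seq above build their K=4 scan orders by
# stacking the row-major view with its transpose and appending the two reversed
# sequences; the merge step then undoes the flip/transpose before summing.
# A tiny sketch of how the four orderings relate (hypothetical 2x3 map):
import torch
B, C, H, W = 1, 1, 2, 3
L = H * W
x = torch.arange(float(L)).view(B, C, H, W)
x_hwwh = torch.stack([x.view(B, -1, L),
                      x.transpose(2, 3).contiguous().view(B, -1, L)],
                     dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)   # (B, 4, C, L)
# xs[:, 0]: row-major scan, xs[:, 1]: column-major scan,
# xs[:, 2] / xs[:, 3]: the same two sequences traversed in reverse
assert torch.equal(xs[:, 2], torch.flip(xs[:, 0], dims=[-1]))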
- - def forward_corev2(self, x: torch.Tensor, nrows=-1, channel_first=False, window_size=2): - nrows = 1 - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - if self.ssm_low_rank: - x = self.in_rank(x) - x = cross_selective_scan( - x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, - self.A_logs, self.Ds, getattr(self, "out_norm", None), - nrows=nrows, delta_softplus=True, window_size=window_size, index=self.index - ) - if self.ssm_low_rank: - x = self.out_rank(x) - return x - - def forward(self, x: torch.Tensor, **kwargs): - xz = self.in_proj(x) - if self.d_conv > 1: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - if not self.disable_z_act: - z = self.act(z) - x = x.permute(0, 3, 1, 2).contiguous() - x = self.act(self.conv2d(x)) # (b, d, h, w) - else: - if self.disable_z_act: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = self.act(x) - else: - xz = self.act(xz) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - y = self.forward_core(x, channel_first=(self.d_conv > 1), window_size=self.window_size) - y = y * z - out = self.dropout(self.out_proj(y)) - return out - - -class Permute(nn.Module): - def __init__(self, *args): - super().__init__() - self.args = args - - def forward(self, x: torch.Tensor): - return x.permute(*self.args) - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear - self.fc1 = Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class VSSBlock(nn.Module): - def __init__( - self, - hidden_dim: int = 0, - drop_path: float = 0, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - # ============================= - ssm_d_state: int = 16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank: Any = "auto", - ssm_act_layer=nn.SiLU, - ssm_conv: int = 3, - ssm_conv_bias=True, - ssm_drop_rate: float = 0, - ssm_simple_init=False, - forward_type="v2", - # ============================= - mlp_ratio=4.0, - mlp_act_layer=nn.GELU, - mlp_drop_rate: float = 0.0, - # ============================= - use_checkpoint: bool = False, - window_size=2, - **kwargs, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm = norm_layer(hidden_dim) - self.op = SS2D( - d_model=hidden_dim, - d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - dt_rank=ssm_dt_rank, - act_layer=ssm_act_layer, - # ========================== - d_conv=ssm_conv, - conv_bias=ssm_conv_bias, - # ========================== - dropout=ssm_drop_rate, - # bias=False, - # ========================== - # dt_min=0.001, - # dt_max=0.1, - # dt_init="random", - # dt_scale="random", - # dt_init_floor=1e-4, - simple_init=ssm_simple_init, - # ========================== - forward_type=forward_type, - window_size=window_size, - ) - self.drop_path = DropPath(drop_path) - - self.mlp_branch = mlp_ratio > 0 - if self.mlp_branch: - self.norm2 = norm_layer(hidden_dim) - mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, 
channels_first=False) - - def _forward(self, input: torch.Tensor): - x = input + self.drop_path(self.op(self.norm(input))) - if self.mlp_branch: - x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN - return x - - def forward(self, input: torch.Tensor): - if self.use_checkpoint: - return checkpoint.checkpoint(self._forward, input) - else: - return self._forward(input) - - -class VSSM(nn.Module): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - ssm_d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank="auto", - ssm_act_layer="silu", - ssm_conv=3, - ssm_conv_bias=True, - ssm_drop_rate=0.0, - ssm_simple_init=False, - forward_type="v2", - # ========================= - mlp_ratio=4.0, - mlp_act_layer="gelu", - mlp_drop_rate=0.0, - # ========================= - drop_path_rate=0.1, - patch_norm=True, - norm_layer="LN", - downsample_version: str = "v2", # "v1", "v2", "v3" - patchembed_version: str = "v1", # "v1", "v2" - use_checkpoint=False, - **kwargs, - ): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.num_features = dims[-1] - self.dims = dims - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - _NORMLAYERS = dict( - ln=nn.LayerNorm, - bn=nn.BatchNorm2d, - ) - - _ACTLAYERS = dict( - silu=nn.SiLU, - gelu=nn.GELU, - relu=nn.ReLU, - sigmoid=nn.Sigmoid, - ) - - if norm_layer.lower() in ["ln"]: - norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()] - - if ssm_act_layer.lower() in ["silu", "gelu", "relu"]: - ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()] - - if mlp_act_layer.lower() in ["silu", "gelu", "relu"]: - mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()] - - _make_patch_embed = dict( - v1=self._make_patch_embed, - v2=self._make_patch_embed_v2, - ).get(patchembed_version, None) - self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer) - - _make_downsample = dict( - v1=PatchMerging2D, - v2=self._make_downsample, - v3=self._make_downsample_v3, - none=(lambda *_, **_k: None), - ).get(downsample_version, None) - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - downsample = _make_downsample( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=downsample, - # ================= - ssm_d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - ssm_dt_rank=ssm_dt_rank, - ssm_act_layer=ssm_act_layer, - ssm_conv=ssm_conv, - ssm_conv_bias=ssm_conv_bias, - ssm_drop_rate=ssm_drop_rate, - ssm_simple_init=ssm_simple_init, - forward_type=forward_type, - # ================= - mlp_ratio=mlp_ratio, - mlp_act_layer=mlp_act_layer, - mlp_drop_rate=mlp_drop_rate, - )) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module): - if isinstance(m, 
nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - # used in building optimizer - # @torch.jit.ignore - # def no_weight_decay(self): - # return {} - - # used in building optimizer - # @torch.jit.ignore - # def no_weight_decay_keywords(self): - # return {} - - @staticmethod - def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): - return nn.Sequential( - nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(embed_dim) if patch_norm else nn.Identity()), - ) - - @staticmethod - def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): - assert patch_size == 4 - return nn.Sequential( - nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1), - (Permute(0, 2, 3, 1) if patch_norm else nn.Identity()), - (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), - (Permute(0, 3, 1, 2) if patch_norm else nn.Identity()), - nn.GELU(), - nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), - Permute(0, 2, 3, 1), - (norm_layer(embed_dim) if patch_norm else nn.Identity()), - ) - - @staticmethod - def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm): - return nn.Sequential( - Permute(0, 3, 1, 2), - nn.Conv2d(dim, out_dim, kernel_size=2, stride=2), - Permute(0, 2, 3, 1), - norm_layer(out_dim), - ) - - @staticmethod - def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm): - return nn.Sequential( - Permute(0, 3, 1, 2), - nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), - Permute(0, 2, 3, 1), - norm_layer(out_dim), - ) - - @staticmethod - def _make_layer( - dim=96, - drop_path=[0.1, 0.1], - use_checkpoint=False, - norm_layer=nn.LayerNorm, - downsample=nn.Identity(), - # =========================== - ssm_d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank="auto", - ssm_act_layer=nn.SiLU, - ssm_conv=3, - ssm_conv_bias=True, - ssm_drop_rate=0.0, - ssm_simple_init=False, - forward_type="v2", - # =========================== - mlp_ratio=4.0, - mlp_act_layer=nn.GELU, - mlp_drop_rate=0.0, - window_size=2, - **kwargs, - ): - depth = len(drop_path) - blocks = [] - for d in range(depth): - blocks.append(VSSBlock( - hidden_dim=dim, - drop_path=drop_path[d], - norm_layer=norm_layer, - ssm_d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - ssm_dt_rank=ssm_dt_rank, - ssm_act_layer=ssm_act_layer, - ssm_conv=ssm_conv, - ssm_conv_bias=ssm_conv_bias, - ssm_drop_rate=ssm_drop_rate, - ssm_simple_init=ssm_simple_init, - forward_type=forward_type, - mlp_ratio=mlp_ratio, - mlp_act_layer=mlp_act_layer, - mlp_drop_rate=mlp_drop_rate, - use_checkpoint=use_checkpoint, - window_size=window_size, - )) - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier(x) - return x - - def flops(self, shape=(3, 224, 224)): - return 0 - # shape = self.__input_shape__[1:] - supported_ops={ - "aten::silu": None, # as relu is in _IGNORED_OPS - "aten::neg": None, # as relu is in _IGNORED_OPS - "aten::exp": None, # as relu is in _IGNORED_OPS - "aten::flip": None, # as permute is in _IGNORED_OPS - 
"prim::PythonOp.CrossScan": None, - "prim::PythonOp.CrossMerge": None, - "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, - "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, - } - - model = copy.deepcopy(self) - model.cuda().eval() - - input = torch.randn((1, *shape), device=next(model.parameters()).device) - params = parameter_count(model)[""] - Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) - - del model, input - return sum(Gflops.values()) * 1e9 - return f"params {params} GFLOPs {sum(Gflops.values())}" - - # used to load ckpt from previous training code - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - - def check_name(src, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - return True - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - return True - return False - - def change_name(src, dst, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - state_dict[prefix + dst] = state_dict[prefix + src] - state_dict.pop(prefix + src) - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - new_k = prefix + dst + k[len(key):] - state_dict[new_k] = state_dict[k] - state_dict.pop(k) - - change_name("patch_embed.proj", "patch_embed.0") - change_name("patch_embed.norm", "patch_embed.2") - for i in range(100): - for j in range(100): - change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm") - change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op") - change_name("norm", "classifier.norm") - change_name("head", "classifier.head") - - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - -# compatible with openmmlab -class Backbone_VSSM(VSSM): - def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer=nn.LayerNorm, **kwargs): - kwargs.update(norm_layer=norm_layer) - super().__init__(**kwargs) - - self.out_indices = out_indices - for i in out_indices: - layer = norm_layer(self.dims[i]) - layer_name = f'outnorm{i}' - self.add_module(layer_name, layer) - - del self.classifier - self.load_pretrained(pretrained) - - def load_pretrained(self, ckpt=None, key="model"): - if ckpt is None: - return - - try: - _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu")) - print(f"Successfully load ckpt {ckpt}") - incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False) - print(incompatibleKeys) - except Exception as e: - print(f"Failed loading checkpoint form {ckpt}: {e}") - - def forward(self, x): - def layer_forward(l, x): - x = l.blocks(x) - y = l.downsample(x) - return x, y - - x = self.patch_embed(x) - outs = [] - for i, layer in enumerate(self.layers): - o, x = layer_forward(layer, x) # (B, H, W, C) - if i in self.out_indices: - norm_layer = getattr(self, f'outnorm{i}') - out = norm_layer(o) - out = out.permute(0, 3, 1, 2).contiguous() - outs.append(out) - - if len(self.out_indices) == 0: - return x - - return outs - - -# ================================================== -def check_vssm_equals_vmambadp(): - try: - from _ignore.vmamba.vmamba_bak1 import VMamba2Dp - from _ignore.vmamba.vmamba_pub import VSSM - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - - # test 1 True ================================= - 
torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VMamba2Dp(depths=[2,2,6,2]).half().cuda() - newvss = VSSM(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VMamba2Dp(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM(depths=[2,2,6,2]).cuda() - - miss_align = 0 - for k, v in oldvss.state_dict().items(): - same = (oldvss.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - - -def check_vssm1_equals_vssm(forward_type="v0"): - try: - from _ignore.vmamba.vmamba_pub import VSSM as VSSM0 - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - - class VSSM_(VSSM): - @staticmethod - def _make_layer(*args, **kwargs): - layer = VSSM._make_layer(*args, **kwargs) - dim = kwargs.get("dim", None) - norm_layer = kwargs.get("norm_layer", None) - downsample = kwargs.get("downsample", None) - blocks = layer.blocks - - if True: # is this really applied? Yes, but been overriden later in VSSM! - def _init_weights(module: nn.Module): - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - p = p.clone().detach_() # fake init, just to keep the seed .... 
- nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - blks = nn.Sequential(*copy.deepcopy(blocks)) - blks.apply(_init_weights) - - downsample = PatchMerging2D(dim, 2*dim, norm_layer=norm_layer) if downsample is None else nn.Identity() - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward_backbone(self, x): - x = self.patch_embed(x) - for l in self.layers: - x = l(x) - return x - - def forward1(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier.norm(x) - # here: whether has contiguous would differ - x = self.classifier.avgpool(x.permute(0, 3, 1, 2).contiguous()).flatten(1) - x = self.classifier.head(x) - return x - - # only has initial difference - VSSM1 = partial(VSSM, downsample_version="v1", patchembed_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, ssm_rank_ratio=2.0, forward_type=forward_type) - VSSM.forward_backbone = VSSM_.forward_backbone - VSSM.forward1 = VSSM_.forward1 - # expected to be all the same - VSSM1 = partial(VSSM_, downsample_version="none", patchembed_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, ssm_rank_ratio=2.0, forward_type=forward_type) - - # test 1 True ================================= - torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VSSM0(depths=[2,2,6,2]).half().cuda() - newvss = VSSM1(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward1(input) - print((y1 -y2).abs().sum()) # tensor(2.5988e-05, device='cuda:0', grad_fn=) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y3 = newvss.forward(input) - print((y1 -y3).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VSSM0(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM1(depths=[2,2,6,2]).cuda() - - miss_align = 0 - oldvss2new = copy.deepcopy(newvss) - oldvss2new.load_state_dict(oldvss.state_dict()) - for k, v in oldvss2new.state_dict().items(): - same = (oldvss2new.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - - -def check_profile(): - vss = VSSM(depths=[1], dims=1024).half().cuda() - input = torch.randn((128, 3, 56, 56)).half().cuda() - torch.cuda.manual_seed(0) - - self = vss - blk = self.layers[0].blocks[0] - ln_1 = blk.ln_1 - self_attention = blk.self_attention - selfa = self_attention - drop_path = blk.drop_path - input = self.patch_embed(input).detach() - - def trace_handler(prof: torch.profiler.profile): - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - # print(prof.export_chrome_trace("./tracev1.json")) - - with torch.cuda.amp.autocast(): - # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, 
record_shapes=False, profile_memory=True, with_stack=True) as prof: - with torch.profiler.profile( - with_modules=True, - with_stack=True, - profile_memory=True, - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - - # In this example with wait=1, warmup=1, active=2, repeat=1, - # profiler will skip the first step/iteration, - # start warming up on the second, record - # the third and the forth iterations, - # after which the trace will become available - # and on_trace_ready (when set) is called; - # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler - # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') - # used when outputting for tensorboard - ) as prof: - for iter in range(1000): - x = input - # with torch.autograd.profiler.record_function("patch_embed"): - # x = self.patch_embed(x) - - B, H, W, C = x.shape - ori = x - - with torch.autograd.profiler.record_function("VSSBlock.ln_1"): - x = ln_1(x) - - with torch.autograd.profiler.record_function("SS2D.inproj"): - xz = selfa.in_proj(x) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = x.permute(0, 3, 1, 2).contiguous() - - with torch.autograd.profiler.record_function("SS2D.dwconv2d"): - x = selfa.act(selfa.conv2d(x)) # (b, d, h, w) - # x = self.act(x) # (b, d, h, w) - - with torch.autograd.profiler.record_function("SS2D.foreward_core"): - # y = selfa.forward_corev2(x) - # y = selfa.forward_corev3(x) - y = selfa.forward_corev1(x) - # y = selfa.forward_corev1(x) - - with torch.autograd.profiler.record_function("SS2D.transpose"): - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = selfa.out_norm(y) - y = y * F.silu(z) - - with torch.autograd.profiler.record_function("SS2D.out_proj"): - out = selfa.out_proj(y) - if selfa.dropout is not None: - out = selfa.dropout(out) - - with torch.autograd.profiler.record_function("SS2D.out"): - x = ori + drop_path(out) - - with torch.autograd.profiler.record_function("backward"): - x.sum().backward() - - prof.step() - - -class MobileVSSM(VSSM): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0., - shared_ssm=False, - softmax_version=False, - # ========================= - drop_rate=0., - drop_path_rate=0.1, - mlp_ratio=4.0, - patch_norm=True, - norm_layer=nn.LayerNorm, - downsample_version: str = "v2", - use_checkpoint=False, - window_size=2, - **kwargs, - ): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.embed_dim = dims[0] - self.num_features = dims[-1] - self.dims = dims - - self.patch_embed = nn.Sequential( - nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(self.embed_dim) if patch_norm else nn.Identity()), - ) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - - if downsample_version == "v2": - downsample = self._make_downsample( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - else: - downsample = 
PatchMerging2D( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - depth = depths[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=downsample, - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - attn_drop_rate=attn_drop_rate, - shared_ssm=shared_ssm, - softmax_version=softmax_version, - mlp_ratio=mlp_ratio, - drop_rate=drop_rate, - window_size=window_size - )) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - - self.apply(self._init_weights) - - # Optionally, you can modify other layers or add new ones as per your design - - def forward(self, x: torch.Tensor): - # Ensure to incorporate the modified downsampling step at the correct stage - # print(x.shape) #[1, 3, 224, 224] - x = self.patch_embed(x) #[1, 3, 224, 224] -> [1, 56, 56, 96] - # print("forward: (self.patch_embed(x))", x.shape) - # x = self.modified_downsampling(x) # Apply the modified downsampling - - # Continue with the forward pass as in the original VSSM or with modifications - for layer in self.layers: - x = layer(x) - # print("After x = layer(x), x.shape: ", x.shape) - x = self.classifier(x) - # print("After self.classifier(x): ", x.shape) - return x - - -@register_model -def mobile_vssm_tiny(*args, drop_path_rate=0.1, **kwargs): - return MobileVSSM(dims=[96, 192, 384, 768], depths=[2, 2, 9, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate) - -@register_model -def mobile_vssm_micro(*args, drop_path_rate=0.1, **kwargs): - return MobileVSSM(dims=[64, 128, 256, 512], depths=[2, 2, 4, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate) - -@register_model -def mobile_vssm_nano(*args, drop_path_rate=0.1, **kwargs): - return MobileVSSM(dims=[48, 96, 192, 384], depths=[2, 2, 4, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate) \ No newline at end of file diff --git a/classification/lib/models/vmamba_mobile_fusion.py b/classification/lib/models/vmamba_mobile_fusion.py deleted file mode 100644 index ea424fc..0000000 --- a/classification/lib/models/vmamba_mobile_fusion.py +++ /dev/null @@ -1,1629 +0,0 @@ -import os -import time -import math -import copy -from functools import partial -from typing import Optional, Callable, Any -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from einops import rearrange, repeat -from timm.models.layers import DropPath, trunc_normal_ -from timm.models.registry import register_model -from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count -DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" -from lib.models.operations import InvertedResidual - -try: - "sscore acts the same as mamba_ssm" - SSMODE = "sscore" - if torch.__version__ > '2.0.0': - from selective_scan_vmamba_pt202 import selective_scan_cuda_core - else: - from selective_scan_vmamba import selective_scan_cuda_core -except Exception as e: - print(e, flush=True) - "you should install mamba_ssm to use this" - SSMODE = "mamba_ssm" - import 
selective_scan_cuda - # from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref - - -# fvcore flops ======================================= - -def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - assert not with_complex - # https://github.com/state-spaces/mamba/issues/110 - flops = 9 * B * L * D * N - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - - -# this is only for selective_scan_ref... -def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): - """ - u: r(B D L) - delta: r(B D L) - A: r(D N) - B: r(B N L) - C: r(B N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - ignores: - [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] - """ - import numpy as np - - # fvcore.nn.jit_handles - def get_flops_einsum(input_shapes, equation): - np_arrs = [np.zeros(s) for s in input_shapes] - optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] - for line in optim.split("\n"): - if "optimized flop" in line.lower(): - # divided by 2 because we count MAC (multiply-add counted as one flop) - flop = float(np.floor(float(line.split(":")[-1]) / 2)) - return flop - - - assert not with_complex - - flops = 0 # below code flops = 0 - - flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") - if with_Group: - flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") - else: - flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") - - in_for_flops = B * D * N - if with_Group: - in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") - else: - in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") - flops += L * in_for_flops - if with_D: - flops += B * D * L - if with_Z: - flops += B * D * L - return flops - - -def print_jit_input_names(inputs): - print("input params: ", end=" ", flush=True) - try: - for i in range(10): - print(inputs[i].debugName(), end=" ", flush=True) - except Exception as e: - pass - print("", flush=True) - - -# cross selective scan =============================== - -class SelectiveScan(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1): - assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile - assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" - ctx.delta_softplus = delta_softplus - ctx.nrows = nrows - # all in float - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if B.dim() == 3: - B = B.unsqueeze(dim=1) - ctx.squeeze_B = True - if C.dim() == 3: - C = C.unsqueeze(dim=1) - ctx.squeeze_C = True - - if SSMODE == "mamba_ssm": - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) - else: - out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ctx.save_for_backward(u, delta, A, B, C, D, 
delta_bias, x) - return out - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, dout, *args): - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - - if SSMODE == "mamba_ssm": - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, - False # option to recompute out_z, not used here - ) - else: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( - u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 - # u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.nrows, - ) - - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None) - - -class CrossScan(torch.autograd.Function): - # [B, C, H, W] -> [B, 4, C, H * W] (original) - # [B, C, H, W] -> [B, 4, C, H/w * W/w] - @staticmethod - def forward(ctx, x: torch.Tensor, window_size=2): # [B, C, H, W] -> [B, 4, H/w * W/w] - B, C, org_h, org_w = x.shape - ctx.shape = (B, C, org_h, org_w) - ctx.window_size = window_size - - # if (W % window_size != 0) or (H % window_size != 0): - # x = F.pad(x, (0, window_size - H % window_size, 0, window_size - W % window_size)) - # print("x after padding: ", x.shape) - # H, W = x.shape[2:] - - if org_w % window_size != 0: - pad_w = window_size - org_w % window_size - x = F.pad(x, (0, pad_w, 0, 0)) - W = x.shape[3] - - if org_h % window_size != 0: - pad_h = window_size - org_h % window_size - x = F.pad(x, (0, 0, 0, pad_h)) - H = x.shape[2] - - H = H // window_size - W = W // window_size - - xs = x.new_empty((B, 4, C, H*W)) - - - xs[:, 0] = x[:, :, ::window_size, ::window_size].contiguous().view(B, C, -1) - xs[:, 1] = x.transpose(dim0=2, dim1=3)[:, :, ::window_size, 1::window_size].contiguous().view(B, C, -1) - xs[:, 2] = x[:, :, ::window_size, 1::window_size].contiguous().view(B, C, -1) - xs[:, 3] = x.transpose(dim0=2, dim1=3)[:, :, 1::window_size, 1::window_size].contiguous().view(B, C, -1) - - xs = xs.view(B, 4, C, -1) - return xs - - @staticmethod - def backward(ctx, grad_xs: torch.Tensor): # [B, 4, H/w * W/w] -> [B, C, H, W] - - B, C, org_h, org_w = ctx.shape - window_size = ctx.window_size - - newH, newW = math.ceil(org_h / window_size), math.ceil(org_w / window_size) - grad_x = grad_xs.new_empty((B, C, newH * window_size, newW * window_size)) - - - # H, W = H // window_size, W // window_size - grad_xs = grad_xs.view(B, 4, C, newH, newW) - - grad_x[:, :, ::window_size, ::window_size] = grad_xs[:, 0].reshape(B, C, newH, newW) - grad_x[:, :, 1::window_size, ::window_size] = grad_xs[:, 1].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3) - grad_x[:, :, ::window_size, 1::window_size] = grad_xs[:, 2].reshape(B, C, newH, newW) - grad_x[:, :, 1::window_size, 1::window_size] = grad_xs[:, 3].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3) - - if org_h != grad_x.shape[-2] or org_w != grad_x.shape[-1]: - grad_x = grad_x[:, :, :org_h, :org_w] - - return grad_x, None - -class CrossMerge(torch.autograd.Function): # [B, 4, C, H/w * W/w] -> [B, C, H*W] - @staticmethod - def forward(ctx, ys: torch.Tensor, ori_h: int, ori_w: int, window_size=2): - B, K, C, L = ys.shape - H, W = math.ceil(ori_h / window_size), math.ceil(ori_w / window_size) - ctx.shape = (H, W) - ctx.ori_h = ori_h - ctx.ori_w = ori_w - ctx.window_size = window_size - - # print("ori_h, ori_w H W 
window_size: ", ori_h, ori_w, H, W, window_size) - - new_h = H * window_size - new_w = W * window_size - - y = ys.new_empty((B, C, new_h, new_w)) - - # print("After new_empty y.shape", y.shape, ys.shape) - - y[:, :, ::window_size, ::window_size] = ys[:, 0].reshape(B, C, H, W) - y[:, :, 1::window_size, ::window_size] = ys[:, 1].reshape(B, C, W, H).transpose(dim0=2, dim1=3) - y[:, :, ::window_size, 1::window_size] = ys[:, 2].reshape(B, C, H, W) - y[:, :, 1::window_size, 1::window_size] = ys[:, 3].reshape(B, C, W, H).transpose(dim0=2, dim1=3) - - if ori_h != new_h or ori_w != new_w: - y = y[:, :, :ori_h, :ori_w].contiguous() - - y = y.view(B, C, -1) - return y - - @staticmethod - def backward(ctx, grad_x: torch.Tensor): # [B, C, H*W] -> [B, 4, C, H/w * W/w] - # out: (b, k, c, l) - H, W = ctx.shape - B, C, L = grad_x.shape - window_size = ctx.window_size - - grad_x = grad_x.view(B, C, ctx.ori_h, ctx.ori_w) - - if ctx.ori_w % window_size != 0: - pad_w = window_size - ctx.ori_w % window_size - grad_x = F.pad(grad_x, (0, pad_w, 0, 0)) - W = grad_x.shape[3] - - if ctx.ori_h % window_size != 0: - pad_h = window_size - ctx.ori_h % window_size - grad_x = F.pad(grad_x, (0, 0, 0, pad_h)) - H = grad_x.shape[2] - B, C, H, W = grad_x.shape - H = H // window_size - W = W // window_size - grad_xs = grad_x.new_empty((B, 4, C, H*W)) - # print("grad_x.shape, grad_xs.shape", grad_x.shape, grad_xs.shape) # grad_x.shape, grad_xs.shape torch.Size([2, 1536, 26, 34]) torch.Size([2, 4, 1536, 884]) - - grad_xs[:, 0] = grad_x[:, :, ::window_size, ::window_size].reshape(B, C, -1) # - grad_xs[:, 1] = grad_x.transpose(dim0=2, dim1=3)[:, :, ::window_size, 1::window_size].reshape(B, C, -1) - grad_xs[:, 2] = grad_x[:, :, ::window_size, 1::window_size].reshape(B, C, -1) - grad_xs[:, 3] = grad_x.transpose(dim0=2, dim1=3)[:, :, 1::window_size, 1::window_size].reshape(B, C, -1) - - return grad_xs, None, None, None # - - -def cross_selective_scan( - x: torch.Tensor=None, - x_proj_weight: torch.Tensor=None, - x_proj_bias: torch.Tensor=None, - dt_projs_weight: torch.Tensor=None, - dt_projs_bias: torch.Tensor=None, - A_logs: torch.Tensor=None, - Ds: torch.Tensor=None, - out_norm: torch.nn.Module=None, - nrows = -1, - delta_softplus = True, - to_dtype=True, - window_size = 2, -): - # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... 
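# This CrossScan/CrossMerge pair pads H and W up to the next multiple of
# window_size before the strided split, so each scan direction carries
# ceil(H/ws) * ceil(W/ws) tokens, and the merge/backward paths crop back to the
# original size. (Note that, as written, H and W are only assigned inside the
# padding branches of CrossScan.forward, so the unconditional floor-divisions
# there appear to assume both dims needed padding.) A quick check of the padding
# arithmetic with made-up sizes:
import math
H, W, ws = 27, 34, 2                      # hypothetical feature-map size and window
pad_h = (ws - H % ws) % ws                # 0 when already divisible, as in the code above
pad_w = (ws - W % ws) % ws
assert (H + pad_h) // ws == math.ceil(H / ws)
assert (W + pad_w) // ws == math.ceil(W / ws)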
- - B, D, H, W = x.shape - D, N = A_logs.shape - K, D, R = dt_projs_weight.shape - L = H * W - - if nrows < 1: - if D % 4 == 0: - nrows = 4 - elif D % 3 == 0: - nrows = 3 - elif D % 2 == 0: - nrows = 2 - else: - nrows = 1 - # H * W - ori_h, ori_w = H, W - - xs = CrossScan.apply(x, window_size) # [B, C, H*W] -> [B, 4, C, H//w * W//w] - # H//w * W//w - H = math.ceil(H / window_size) - W = math.ceil(W / window_size) - - L = H * W - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) # l fixed - - if x_proj_bias is not None: - x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) - - xs = xs.view(B, -1, L).to(torch.float) - dts = dts.contiguous().view(B, -1, L).to(torch.float) - As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) - Bs = Bs.contiguous().to(torch.float) - Cs = Cs.contiguous().to(torch.float) - Ds = Ds.to(torch.float) # (K * c) - delta_bias = dt_projs_bias.view(-1).to(torch.float) - - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - ys: torch.Tensor = selective_scan( - xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows, - ).view(B, K, -1, L) - # print(ys.shape, ori_h, ori_w, window_size, H, W) - ori_h, ori_w = int(ori_h), int(ori_w) - y = CrossMerge.apply(ys, ori_h, ori_w, window_size) # [B, 4, C, H//w * W//w] -> [B, C, H*W] - - H = ori_h - W = ori_w - L = H * W - - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - -def selective_scan_flop_jit(inputs, outputs): - print_jit_input_names(inputs) - B, D, L = inputs[0].type().sizes() - N = inputs[2].type().sizes()[1] - flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False, with_Group=True) - return flops - - -# ===================================================== - -class PatchMerging2D(nn.Module): - def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) - self.norm = norm_layer(4 * dim) - - @staticmethod - def _patch_merging_pad(x: torch.Tensor): - H, W, _ = x.shape[-3:] - if (W % 2 != 0) or (H % 2 != 0): - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C - x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C - x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C - x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C - return x - - def forward(self, x): - x = self._patch_merging_pad(x) - x = self.norm(x) - x = self.reduction(x) - - return x - - -class SS2D(nn.Module): - def __init__( - self, - # basic dims =========== - d_model=96, - d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - dt_rank="auto", - act_layer=nn.SiLU, - # dwconv =============== - d_conv=3, # < 2 means no conv - conv_bias=True, - # ====================== - dropout=0.0, - bias=False, - # dt init ============== - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - simple_init=False, - # ====================== - forward_type="v2", - # ====================== - window_size=2, - **kwargs, - ): - """ - ssm_rank_ratio would be used in the future... 
- """ - factory_kwargs = {"device": None, "dtype": None} - super().__init__() - d_expand = int(ssm_ratio * d_model) - d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand - self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank - self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state # 20240109 - self.d_conv = d_conv - - self.window_size = window_size - - # disable z act ====================================== - self.disable_z_act = forward_type[-len("nozact"):] == "nozact" - if self.disable_z_act: - forward_type = forward_type[:-len("nozact")] - - # softmax | sigmoid | norm =========================== - if forward_type[-len("softmax"):] == "softmax": - forward_type = forward_type[:-len("softmax")] - self.out_norm = nn.Softmax(dim=1) - elif forward_type[-len("sigmoid"):] == "sigmoid": - forward_type = forward_type[:-len("sigmoid")] - self.out_norm = nn.Sigmoid() - else: - self.out_norm = nn.LayerNorm(d_inner) - - # forward_type ======================================= - self.forward_core = dict( - v0=self.forward_corev0, - v0_seq=self.forward_corev0_seq, - v1=self.forward_corev2, - v2=self.forward_corev2, - share_ssm=self.forward_corev0_share_ssm, - share_a=self.forward_corev0_share_a, - ).get(forward_type, self.forward_corev2) - self.K = 4 if forward_type not in ["share_ssm"] else 1 - self.K2 = self.K if forward_type not in ["share_a"] else 1 - - # in proj ======================================= - self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs) - self.act: nn.Module = act_layer() - - # conv ======================================= - if self.d_conv > 1: - self.conv2d = nn.Conv2d( - in_channels=d_expand, - out_channels=d_expand, - groups=d_expand, - bias=conv_bias, - kernel_size=d_conv, - padding=(d_conv - 1) // 2, - **factory_kwargs, - ) - - # rank ratio ===================================== - self.ssm_low_rank = False - if d_inner < d_expand: - self.ssm_low_rank = True - self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs) - self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs) - - # x proj ============================ - self.x_proj = [ - nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs) - for _ in range(self.K) - ] - self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) - del self.x_proj - - # dt proj ============================ - self.dt_projs = [ - self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) - for _ in range(self.K) - ] - self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) - self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) - del self.dt_projs - - # A, D ======================================= - self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True) # (K * D, N) - self.Ds = self.D_init(d_inner, copies=self.K2, merge=True) # (K * D) - - # out proj ======================================= - self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs) - self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() - - if simple_init: - # simple init dt_projs, A_logs, Ds - self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner))) - self.A_logs = nn.Parameter(torch.randn((self.K2 * d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 - self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank))) - self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner))) - - @staticmethod - def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): - dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # dt_proj.bias._no_reinit = True - - return dt_proj - - @staticmethod - def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): - # S4D real initialization - A = repeat( - torch.arange(1, d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if copies > 0: - A_log = repeat(A_log, "d n -> r d n", r=copies) - if merge: - A_log = A_log.flatten(0, 1) - A_log = nn.Parameter(A_log) - A_log._no_weight_decay = True - return A_log - - @staticmethod - def D_init(d_inner, copies=-1, device=None, merge=True): - # D "skip" parameter - D = torch.ones(d_inner, device=device) - if copies > 0: - D = repeat(D, "n1 -> r n1", r=copies) - if merge: - D = D.flatten(0, 1) - D = nn.Parameter(D) # Keep in fp32 - D._no_weight_decay = True - return D - - # only used to run previous version - def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False): - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.float().view(B, -1, L) # (b, k * d, l) - dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()) # (k * d, d_state) - Ds = self.Ds.float() # (k * d) - dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 
- - # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 - - out_y = selective_scan( - xs, dts, - As, Bs, Cs, Ds, - delta_bias=dt_projs_bias, - delta_softplus=True, - ).view(B, K, -1, L) - # assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = self.out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - - # only has speed difference with v0 - def forward_corev0_seq(self, x: torch.Tensor, to_dtype=False, channel_first=False): - def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): - return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - B, C, H, W = x.shape - L = H * W - K = 4 - - x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) - xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) - - x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) - # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) - dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) - dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) - - xs = xs.float() # (b, k, d, l) - dts = dts.contiguous().float() # (b, k, d, l) - Bs = Bs.float() # (b, k, d_state, l) - Cs = Cs.float() # (b, k, d_state, l) - - As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state) # (k, d, d_state) - Ds = self.Ds.float().view(K, -1) # (k, d) - dt_projs_bias = self.dt_projs_bias.float().view(K, -1) # (k, d) - - # assert len(xs.shape) == 4 and len(dts.shape) == 4 and len(Bs.shape) == 4 and len(Cs.shape) == 4 - # assert len(As.shape) == 3 and len(Ds.shape) == 2 and len(dt_projs_bias.shape) == 2 - - out_y = [] - for i in range(4): - yi = selective_scan( - xs[:, i], dts[:, i], - As[i], Bs[:, i], Cs[:, i], Ds[i], - delta_bias=dt_projs_bias[i], - delta_softplus=True, - ).view(B, -1, L) - out_y.append(yi) - out_y = torch.stack(out_y, dim=1) - assert out_y.dtype == torch.float - - inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) - wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) - y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y - y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) - y = self.out_norm(y).view(B, H, W, -1) - - return (y.to(x.dtype) if to_dtype else y) - - - def forward_corev0_share_ssm(self, x: torch.Tensor, channel_first=False): - """ - we may conduct this ablation later, but not with v0. - """ - ... - - def forward_corev0_share_a(self, x: torch.Tensor, channel_first=False): - """ - we may conduct this ablation later, but not with v0. - """ - ... 
- - def forward_corev2(self, x: torch.Tensor, nrows=-1, channel_first=False, window_size=2): - nrows = 1 - if not channel_first: - x = x.permute(0, 3, 1, 2).contiguous() - if self.ssm_low_rank: - x = self.in_rank(x) - x = cross_selective_scan( - x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, - self.A_logs, self.Ds, getattr(self, "out_norm", None), - nrows=nrows, delta_softplus=True, window_size=window_size - ) - if self.ssm_low_rank: - x = self.out_rank(x) - return x - - def forward(self, x: torch.Tensor, **kwargs): - xz = self.in_proj(x) - if self.d_conv > 1: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - if not self.disable_z_act: - z = self.act(z) - x = x.permute(0, 3, 1, 2).contiguous() - x = self.act(self.conv2d(x)) # (b, d, h, w) - else: - if self.disable_z_act: - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = self.act(x) - else: - xz = self.act(xz) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - y = self.forward_core(x, channel_first=(self.d_conv > 1), window_size=self.window_size) - y = y * z - out = self.dropout(self.out_proj(y)) - return out - - -class Permute(nn.Module): - def __init__(self, *args): - super().__init__() - self.args = args - - def forward(self, x: torch.Tensor): - return x.permute(*self.args) - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear - self.fc1 = Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class SqueezeExcite(nn.Module): - def __init__(self, in_channels, reduce_channels, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super(SqueezeExcite, self).__init__() - self.avgp = nn.AdaptiveAvgPool2d(1) - self.conv_reduce = nn.Conv2d(in_channels, reduce_channels, 1, bias=True) - self.act_fn = act_fn() - self.conv_expand = nn.Conv2d(reduce_channels, in_channels, 1, bias=True) - self.gate_fn = gate_fn() - - def forward(self, x): - x_se = self.avgp(x) - x_se = self.conv_reduce(x_se) - x_se = self.act_fn(x_se) - x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x - -class BiAttn(nn.Module): - def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): - super().__init__() - reduce_channels = int(in_channels * act_ratio) - self.norm = nn.LayerNorm(in_channels) - self.global_reduce = nn.Linear(in_channels, reduce_channels) - # self.local_reduce = nn.Linear(in_channels, reduce_channels) - self.act_fn = act_fn() - self.channel_select = nn.Linear(reduce_channels, in_channels) - # self.spatial_select = nn.Linear(reduce_channels * 2, 1) - self.gate_fn = gate_fn() - - def forward(self, x): - ori_x = x - x = self.norm(x) - x_global = x.mean([1, 2], keepdim=True) - x_global = self.act_fn(self.global_reduce(x_global)) - # x_local = self.act_fn(self.local_reduce(x)) - - c_attn = self.channel_select(x_global) - c_attn = self.gate_fn(c_attn) # [B, 1, C] - # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) - # s_attn = self.gate_fn(s_attn) # [B, N, 1] - - attn = c_attn #* s_attn # [B, N, C] - out = ori_x * attn - return out - - -class 
VSSBlock(nn.Module): - def __init__( - self, - hidden_dim: int = 0, - drop_path: float = 0, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - # ============================= - ssm_d_state: int = 16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank: Any = "auto", - ssm_act_layer=nn.SiLU, - ssm_conv: int = 3, - ssm_conv_bias=True, - ssm_drop_rate: float = 0, - ssm_simple_init=False, - forward_type="v2", - # ============================= - mlp_ratio=4.0, - mlp_act_layer=nn.GELU, - mlp_drop_rate: float = 0.0, - # ============================= - use_checkpoint: bool = False, - window_size=2, - **kwargs, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm = norm_layer(hidden_dim) - self.op = SS2D( - d_model=hidden_dim, - d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - dt_rank=ssm_dt_rank, - act_layer=ssm_act_layer, - # ========================== - d_conv=ssm_conv, - conv_bias=ssm_conv_bias, - # ========================== - dropout=ssm_drop_rate, - # bias=False, - # ========================== - # dt_min=0.001, - # dt_max=0.1, - # dt_init="random", - # dt_scale="random", - # dt_init_floor=1e-4, - simple_init=ssm_simple_init, - # ========================== - forward_type=forward_type, - window_size=window_size, - ) - self.conv_branch = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1, groups=hidden_dim), - nn.BatchNorm2d(hidden_dim), - nn.GELU(), - nn.Conv2d(hidden_dim, hidden_dim, 1) - ) - self.se = BiAttn(hidden_dim) #SqueezeExcite(hidden_dim, hidden_dim // 8) - self.drop_path = DropPath(drop_path) - - self.mlp_branch = mlp_ratio > 0 - if self.mlp_branch: - self.norm2 = norm_layer(hidden_dim) - mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False) - - def _forward(self, input: torch.Tensor): - x = self.norm(input) - x_ssm = self.op(x) - x_conv = self.conv_branch(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) - x = self.se(x_ssm) + self.se(x_conv) - x = input + self.drop_path(x) - if self.mlp_branch: - x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN - return x - - def forward(self, input: torch.Tensor): - if self.use_checkpoint: - return checkpoint.checkpoint(self._forward, input) - else: - return self._forward(input) - - -class VSSM(nn.Module): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - ssm_d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank="auto", - ssm_act_layer="silu", - ssm_conv=3, - ssm_conv_bias=True, - ssm_drop_rate=0.0, - ssm_simple_init=False, - forward_type="v2", - # ========================= - mlp_ratio=4.0, - mlp_act_layer="gelu", - mlp_drop_rate=0.0, - # ========================= - drop_path_rate=0.1, - patch_norm=True, - norm_layer="LN", - downsample_version: str = "v2", # "v1", "v2", "v3" - patchembed_version: str = "v1", # "v1", "v2" - use_checkpoint=False, - **kwargs, - ): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.num_features = dims[-1] - self.dims = dims - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - _NORMLAYERS = dict( - ln=nn.LayerNorm, - bn=nn.BatchNorm2d, - ) - - _ACTLAYERS = dict( - 
silu=nn.SiLU, - gelu=nn.GELU, - relu=nn.ReLU, - sigmoid=nn.Sigmoid, - ) - - if norm_layer.lower() in ["ln"]: - norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()] - - if ssm_act_layer.lower() in ["silu", "gelu", "relu"]: - ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()] - - if mlp_act_layer.lower() in ["silu", "gelu", "relu"]: - mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()] - - _make_patch_embed = dict( - v1=self._make_patch_embed, - v2=self._make_patch_embed_v2, - ).get(patchembed_version, None) - self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer) - - _make_downsample = dict( - v1=PatchMerging2D, - v2=self._make_downsample, - v3=self._make_downsample_v3, - none=(lambda *_, **_k: None), - ).get(downsample_version, None) - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - downsample = _make_downsample( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=downsample, - # ================= - ssm_d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - ssm_dt_rank=ssm_dt_rank, - ssm_act_layer=ssm_act_layer, - ssm_conv=ssm_conv, - ssm_conv_bias=ssm_conv_bias, - ssm_drop_rate=ssm_drop_rate, - ssm_simple_init=ssm_simple_init, - forward_type=forward_type, - # ================= - mlp_ratio=mlp_ratio, - mlp_act_layer=mlp_act_layer, - mlp_drop_rate=mlp_drop_rate, - )) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - # used in building optimizer - # @torch.jit.ignore - # def no_weight_decay(self): - # return {} - - # used in building optimizer - # @torch.jit.ignore - # def no_weight_decay_keywords(self): - # return {} - - @staticmethod - def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): - return nn.Sequential( - nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(embed_dim) if patch_norm else nn.Identity()), - ) - - @staticmethod - def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): - assert patch_size == 4 - return nn.Sequential( - nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1), - (Permute(0, 2, 3, 1) if patch_norm else nn.Identity()), - (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), - (Permute(0, 3, 1, 2) if patch_norm else nn.Identity()), - nn.GELU(), - nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), - Permute(0, 2, 3, 1), - (norm_layer(embed_dim) if patch_norm else nn.Identity()), - ) - - @staticmethod - def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm): - return nn.Sequential( - Permute(0, 3, 1, 2), - nn.Conv2d(dim, 
out_dim, kernel_size=2, stride=2), - Permute(0, 2, 3, 1), - norm_layer(out_dim), - ) - - @staticmethod - def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm): - return nn.Sequential( - Permute(0, 3, 1, 2), - nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), - Permute(0, 2, 3, 1), - norm_layer(out_dim), - ) - - @staticmethod - def _make_layer( - dim=96, - drop_path=[0.1, 0.1], - use_checkpoint=False, - norm_layer=nn.LayerNorm, - downsample=nn.Identity(), - # =========================== - ssm_d_state=16, - ssm_ratio=2.0, - ssm_rank_ratio=2.0, - ssm_dt_rank="auto", - ssm_act_layer=nn.SiLU, - ssm_conv=3, - ssm_conv_bias=True, - ssm_drop_rate=0.0, - ssm_simple_init=False, - forward_type="v2", - # =========================== - mlp_ratio=4.0, - mlp_act_layer=nn.GELU, - mlp_drop_rate=0.0, - window_size=2, - **kwargs, - ): - depth = len(drop_path) - blocks = [] - for d in range(depth): - blocks.append(VSSBlock( - hidden_dim=dim, - drop_path=drop_path[d], - norm_layer=norm_layer, - ssm_d_state=ssm_d_state, - ssm_ratio=ssm_ratio, - ssm_rank_ratio=ssm_rank_ratio, - ssm_dt_rank=ssm_dt_rank, - ssm_act_layer=ssm_act_layer, - ssm_conv=ssm_conv, - ssm_conv_bias=ssm_conv_bias, - ssm_drop_rate=ssm_drop_rate, - ssm_simple_init=ssm_simple_init, - forward_type=forward_type, - mlp_ratio=mlp_ratio, - mlp_act_layer=mlp_act_layer, - mlp_drop_rate=mlp_drop_rate, - use_checkpoint=use_checkpoint, - window_size=window_size, - )) - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier(x) - return x - - def flops(self, shape=(3, 224, 224)): - supported_ops={ - "aten::silu": None, # as relu is in _IGNORED_OPS - "aten::neg": None, # as relu is in _IGNORED_OPS - "aten::exp": None, # as relu is in _IGNORED_OPS - "aten::flip": None, # as permute is in _IGNORED_OPS - "prim::PythonOp.CrossScan": None, - "prim::PythonOp.CrossMerge": None, - "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, - "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, - } - - model = copy.deepcopy(self) - model.cuda().eval() - - input = torch.randn((1, *shape), device=next(model.parameters()).device) - params = parameter_count(model)[""] - Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) - - del model, input - return sum(Gflops.values()) * 1e9 - return f"params {params} GFLOPs {sum(Gflops.values())}" - - # used to load ckpt from previous training code - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - - def check_name(src, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - return True - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - return True - return False - - def change_name(src, dst, state_dict: dict = state_dict, strict=False): - if strict: - if prefix + src in list(state_dict.keys()): - state_dict[prefix + dst] = state_dict[prefix + src] - state_dict.pop(prefix + src) - else: - key = prefix + src - for k in list(state_dict.keys()): - if k.startswith(key): - new_k = prefix + dst + k[len(key):] - state_dict[new_k] = state_dict[k] - state_dict.pop(k) - - change_name("patch_embed.proj", "patch_embed.0") - change_name("patch_embed.norm", "patch_embed.2") - for i in range(100): - for j in range(100): - 
change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm") - change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op") - change_name("norm", "classifier.norm") - change_name("head", "classifier.head") - - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - - - - -# ================================================== -def check_vssm_equals_vmambadp(): - try: - from _ignore.vmamba.vmamba_bak1 import VMamba2Dp - from _ignore.vmamba.vmamba_pub import VSSM - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - - # test 1 True ================================= - torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VMamba2Dp(depths=[2,2,6,2]).half().cuda() - newvss = VSSM(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VMamba2Dp(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM(depths=[2,2,6,2]).cuda() - - miss_align = 0 - for k, v in oldvss.state_dict().items(): - same = (oldvss.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - - -def check_vssm1_equals_vssm(forward_type="v0"): - try: - from _ignore.vmamba.vmamba_pub import VSSM as VSSM0 - except: - print("original VSSM and VMamba2Dp not found.", flush=True) - return - - class VSSM_(VSSM): - @staticmethod - def _make_layer(*args, **kwargs): - layer = VSSM._make_layer(*args, **kwargs) - dim = kwargs.get("dim", None) - norm_layer = kwargs.get("norm_layer", None) - downsample = kwargs.get("downsample", None) - blocks = layer.blocks - - if True: # is this really applied? Yes, but been overriden later in VSSM! - def _init_weights(module: nn.Module): - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - p = p.clone().detach_() # fake init, just to keep the seed .... 
- nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - blks = nn.Sequential(*copy.deepcopy(blocks)) - blks.apply(_init_weights) - - downsample = PatchMerging2D(dim, 2*dim, norm_layer=norm_layer) if downsample is None else nn.Identity() - - return nn.Sequential(OrderedDict( - blocks=nn.Sequential(*blocks,), - downsample=downsample, - )) - - def forward_backbone(self, x): - x = self.patch_embed(x) - for l in self.layers: - x = l(x) - return x - - def forward1(self, x: torch.Tensor): - x = self.patch_embed(x) - for layer in self.layers: - x = layer(x) - x = self.classifier.norm(x) - # here: whether has contiguous would differ - x = self.classifier.avgpool(x.permute(0, 3, 1, 2).contiguous()).flatten(1) - x = self.classifier.head(x) - return x - - # only has initial difference - VSSM1 = partial(VSSM, downsample_version="v1", patchembed_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, ssm_rank_ratio=2.0, forward_type=forward_type) - VSSM.forward_backbone = VSSM_.forward_backbone - VSSM.forward1 = VSSM_.forward1 - # expected to be all the same - VSSM1 = partial(VSSM_, downsample_version="none", patchembed_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, ssm_rank_ratio=2.0, forward_type=forward_type) - - # test 1 True ================================= - torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time()) - oldvss = VSSM0(depths=[2,2,6,2]).half().cuda() - newvss = VSSM1(depths=[2,2,6,2]).half().cuda() - newvss.load_state_dict(oldvss.state_dict()) - input = torch.randn((12, 3, 224, 224)).half().cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward_backbone(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward_backbone(input) - print((y1 -y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y1 = oldvss.forward(input) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y2 = newvss.forward1(input) - print((y1 -y2).abs().sum()) # tensor(2.5988e-05, device='cuda:0', grad_fn=) - torch.manual_seed(0); torch.cuda.manual_seed(0) - with torch.cuda.amp.autocast(): - y3 = newvss.forward(input) - print((y1 -y3).abs().sum()) # tensor(0., device='cuda:0', grad_fn=) - - # test 2 True ========================================== - torch.manual_seed(0); torch.cuda.manual_seed(0) - oldvss = VSSM0(depths=[2,2,6,2]).cuda() - torch.manual_seed(0); torch.cuda.manual_seed(0) - newvss = VSSM1(depths=[2,2,6,2]).cuda() - - miss_align = 0 - oldvss2new = copy.deepcopy(newvss) - oldvss2new.load_state_dict(oldvss.state_dict()) - for k, v in oldvss2new.state_dict().items(): - same = (oldvss2new.state_dict()[k] == newvss.state_dict()[k]).all() - if not same: - print(k, same) - miss_align += 1 - print("init miss align", miss_align) # init miss align 0 - - -def check_profile(): - vss = VSSM(depths=[1], dims=1024).half().cuda() - input = torch.randn((128, 3, 56, 56)).half().cuda() - torch.cuda.manual_seed(0) - - self = vss - blk = self.layers[0].blocks[0] - ln_1 = blk.ln_1 - self_attention = blk.self_attention - selfa = self_attention - drop_path = blk.drop_path - input = self.patch_embed(input).detach() - - def trace_handler(prof: torch.profiler.profile): - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - # print(prof.export_chrome_trace("./tracev1.json")) - - with torch.cuda.amp.autocast(): - # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, 
record_shapes=False, profile_memory=True, with_stack=True) as prof: - with torch.profiler.profile( - with_modules=True, - with_stack=True, - profile_memory=True, - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - - # In this example with wait=1, warmup=1, active=2, repeat=1, - # profiler will skip the first step/iteration, - # start warming up on the second, record - # the third and the forth iterations, - # after which the trace will become available - # and on_trace_ready (when set) is called; - # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler - # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') - # used when outputting for tensorboard - ) as prof: - for iter in range(1000): - x = input - # with torch.autograd.profiler.record_function("patch_embed"): - # x = self.patch_embed(x) - - B, H, W, C = x.shape - ori = x - - with torch.autograd.profiler.record_function("VSSBlock.ln_1"): - x = ln_1(x) - - with torch.autograd.profiler.record_function("SS2D.inproj"): - xz = selfa.in_proj(x) - x, z = xz.chunk(2, dim=-1) # (b, h, w, d) - x = x.permute(0, 3, 1, 2).contiguous() - - with torch.autograd.profiler.record_function("SS2D.dwconv2d"): - x = selfa.act(selfa.conv2d(x)) # (b, d, h, w) - # x = self.act(x) # (b, d, h, w) - - with torch.autograd.profiler.record_function("SS2D.foreward_core"): - # y = selfa.forward_corev2(x) - # y = selfa.forward_corev3(x) - y = selfa.forward_corev1(x) - # y = selfa.forward_corev1(x) - - with torch.autograd.profiler.record_function("SS2D.transpose"): - y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) - y = selfa.out_norm(y) - y = y * F.silu(z) - - with torch.autograd.profiler.record_function("SS2D.out_proj"): - out = selfa.out_proj(y) - if selfa.dropout is not None: - out = selfa.dropout(out) - - with torch.autograd.profiler.record_function("SS2D.out"): - x = ori + drop_path(out) - - with torch.autograd.profiler.record_function("backward"): - x.sum().backward() - - prof.step() - - -class MobileVSSM(VSSM): - def __init__( - self, - patch_size=4, - in_chans=3, - num_classes=1000, - depths=[2, 2, 9, 2], - dims=[96, 192, 384, 768], - # ========================= - d_state=16, - dt_rank="auto", - ssm_ratio=2.0, - attn_drop_rate=0., - shared_ssm=False, - softmax_version=False, - # ========================= - drop_rate=0., - drop_path_rate=0.1, - mlp_ratio=4.0, - patch_norm=True, - norm_layer=nn.LayerNorm, - downsample_version: str = "v2", - use_checkpoint=False, - window_size=2, - **kwargs, - ): - super().__init__() - self.num_classes = num_classes - self.num_layers = len(depths) - if isinstance(dims, int): - dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] - self.embed_dim = dims[0] - self.num_features = dims[-1] - self.dims = dims - - self.patch_embed = nn.Sequential( - nn.Conv2d(in_chans, self.embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), - Permute(0, 2, 3, 1), - (norm_layer(self.embed_dim) if patch_norm else nn.Identity()), - ) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - - if downsample_version == "v2": - downsample = self._make_downsample( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - else: - downsample = 
PatchMerging2D( - self.dims[i_layer], - self.dims[i_layer + 1], - norm_layer=norm_layer, - ) if (i_layer < self.num_layers - 1) else nn.Identity() - - if i_layer < 2: - self.layers.append(self._make_layer( - dim = self.dims[i_layer], - depth = depths[i_layer], - drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - use_checkpoint=use_checkpoint, - norm_layer=norm_layer, - downsample=downsample, - d_state=d_state, - dt_rank=dt_rank, - ssm_ratio=ssm_ratio, - attn_drop_rate=attn_drop_rate, - shared_ssm=shared_ssm, - softmax_version=softmax_version, - mlp_ratio=mlp_ratio, - drop_rate=drop_rate, - window_size=window_size - )) - else: - self.layers.append(nn.Sequential( - Permute(0, 3, 1, 2), - *[ - InvertedResidual(self.dims[i_layer], self.dims[i_layer], expand_ratio=4, se_ratio=0.125, drop_connect_rate=dpr[sum(depths[:i_layer]) + i]) - for i in range(depths[i_layer]) - ], - Permute(0, 2, 3, 1), - downsample)) - - self.classifier = nn.Sequential(OrderedDict( - norm=norm_layer(self.num_features), # B,H,W,C - permute=Permute(0, 3, 1, 2), - avgpool=nn.AdaptiveAvgPool2d(1), - flatten=nn.Flatten(1), - head=nn.Linear(self.num_features, num_classes), - )) - - self.apply(self._init_weights) - - # Optionally, you can modify other layers or add new ones as per your design - - def forward(self, x: torch.Tensor): - # Ensure to incorporate the modified downsampling step at the correct stage - # print(x.shape) #[1, 3, 224, 224] - x = self.patch_embed(x) #[1, 3, 224, 224] -> [1, 56, 56, 96] - # print("forward: (self.patch_embed(x))", x.shape) - # x = self.modified_downsampling(x) # Apply the modified downsampling - - # Continue with the forward pass as in the original VSSM or with modifications - for idx, layer in enumerate(self.layers): - x = layer(x) - # if idx < 2: - # x = layer(x) - # elif idx == 2: - # x = layer(x.permute(0, 3, 1, 2)) - # elif idx == 3: - # x = layer(x) - # x = x.permute(0, 2, 3, 1) - - # print("After x = layer(x), x.shape: ", x.shape) - x = self.classifier(x) - # print("After self.classifier(x): ", x.shape) - return x - - -# compatible with openmmlab -class Backbone_MobileVSSM(MobileVSSM): - def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer=nn.LayerNorm, **kwargs): - kwargs.update(norm_layer=norm_layer) - super().__init__(**kwargs) - - self.out_indices = out_indices - for i in out_indices: - layer = norm_layer(self.dims[i]) - layer_name = f'outnorm{i}' - self.add_module(layer_name, layer) - - del self.classifier - if pretrained is not None: - self.load_pretrained(pretrained) - - def load_pretrained(self, ckpt=None, key="state_dict"): - if ckpt is None: - return - - try: - _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu")) - state_dict = _ckpt[key] - print(f"Successfully load ckpt {ckpt}") - incompatibleKeys = self.load_state_dict(state_dict, strict=False) - print(incompatibleKeys) - except Exception as e: - print(f"Failed loading checkpoint form {ckpt}: {e}") - - def forward(self, x): - def layer_forward(l, x): - if hasattr(l, 'blocks'): - x = l.blocks(x) - y = l.downsample(x) - else: - x = l[:-1](x) - y = l[-1](x) - return x, y - - x = self.patch_embed(x) - outs = [] - for i, layer in enumerate(self.layers): - o, x = layer_forward(layer, x) # (B, H, W, C) - if i in self.out_indices: - norm_layer = getattr(self, f'outnorm{i}') - out = norm_layer(o) - out = out.permute(0, 3, 1, 2).contiguous() - outs.append(out) - - if len(self.out_indices) == 0: - return x - - return outs - - -@register_model -def mobile_vssm_fusion_tiny(*args, 
drop_path_rate=0.1, **kwargs):
-    # forward any extra keyword arguments (e.g. num_classes) to the model
-    return MobileVSSM(dims=[96, 192, 384, 768], depths=[2, 2, 9, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate, **kwargs)
-
-@register_model
-def mobile_vssm_fusion_micro(*args, drop_path_rate=0.1, **kwargs):
-    return MobileVSSM(dims=[64, 128, 256, 512], depths=[2, 2, 4, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate, **kwargs)
-
-@register_model
-def mobile_vssm_fusion_nano(*args, drop_path_rate=0.1, **kwargs):
-    return MobileVSSM(dims=[48, 96, 192, 384], depths=[2, 2, 4, 2], d_state=16, mlp_ratio=0, downsample_version='v1', drop_path_rate=drop_path_rate, **kwargs)
\ No newline at end of file
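
For reference, a minimal round-trip sketch for the windowed CrossScan/CrossMerge pair defined in the diff above. The import path `mobile_vssm` is hypothetical; only these two autograd Functions are exercised, so no selective-scan CUDA extension or GPU is required. With window_size=2 the four scan branches sample the four pixel offsets of each 2x2 window exactly once, so the merge reconstructs the input exactly.

import torch

# Hypothetical module path; adjust to wherever CrossScan/CrossMerge actually live.
from mobile_vssm import CrossScan, CrossMerge

B, C, H, W, w = 2, 8, 14, 14, 2
x = torch.randn(B, C, H, W)

xs = CrossScan.apply(x, w)          # [B, C, H, W]           -> [B, 4, C, (H/w)*(W/w)]
y = CrossMerge.apply(xs, H, W, w)   # [B, 4, C, (H/w)*(W/w)] -> [B, C, H*W]

# Each branch is scattered back to its own sub-pixel offset, so for
# window_size == 2 the merge is an exact inverse of the scan.
assert torch.equal(y.view(B, C, H, W), x)
print("round trip OK:", tuple(xs.shape), "->", tuple(y.shape))

Building the full model (e.g. via mobile_vssm_fusion_tiny()) additionally requires the selective_scan_cuda / selective_scan_cuda_core extension used by SelectiveScan, which this sketch deliberately avoids.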