From 8de5e156cd6892ab7be47a5cbc3c126824e65a5a Mon Sep 17 00:00:00 2001
From: Muhammad Anas Raza
Date: Thu, 14 Sep 2023 23:39:54 -0400
Subject: [PATCH] initial release

---
 focalnet_keras_core/blocks.py   |  4 +-
 focalnet_keras_core/builders.py | 77 +++++++++++++++++++++++++++++++++
 focalnet_keras_core/focalnet.py | 70 ++++++++++++++++++++++++++++++
 3 files changed, 149 insertions(+), 2 deletions(-)
 create mode 100644 focalnet_keras_core/builders.py
 create mode 100644 focalnet_keras_core/focalnet.py

diff --git a/focalnet_keras_core/blocks.py b/focalnet_keras_core/blocks.py
index ca888f4..42d0748 100644
--- a/focalnet_keras_core/blocks.py
+++ b/focalnet_keras_core/blocks.py
@@ -1,5 +1,5 @@
 import keras_core as keras
-import keras.backend as K
+import keras_core.backend as K
 from focalnet_keras_core.layers import *
 
 def Mlp(hidden_features=None, dropout_rate=0., act_layer=keras.activations.gelu, out_features=None, prefix=None):
@@ -85,7 +85,7 @@ def _apply(x, H, W):
             x = keras.layers.Add()([x_alt, x])
         else:
             x_alt = norm_layer(name=f"{name}.norm2")(x)
-            x_alt = Mlp(hidden_features=dim * mlp_ratio, dropout_rate=drop, prefix=name)(x_alt)
+            x_alt = Mlp(hidden_features=int(dim * mlp_ratio), dropout_rate=drop, prefix=name)(x_alt)
             x_alt = StochasticDepth(drop_path)(x_alt)
             x = keras.layers.Add()([x_alt, x])
         x = keras.layers.Reshape((H * W, C))(x)
diff --git a/focalnet_keras_core/builders.py b/focalnet_keras_core/builders.py
new file mode 100644
index 0000000..caca3c8
--- /dev/null
+++ b/focalnet_keras_core/builders.py
@@ -0,0 +1,77 @@
+import keras_core as keras
+from focalnet_keras_core.focalnet import FocalNet
+
+def Model(img_size=224, **kw) -> keras.Model:
+
+    focalnet_model = FocalNet(img_size=img_size,**kw)
+
+    inputs = keras.Input((img_size, img_size, 3))
+    outputs = focalnet_model(inputs)
+    final_model = keras.Model(inputs, outputs )
+
+    return final_model
+
+
+
+def focalnet_tiny_srf(**kwargs):
+    model = Model(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
+    return model
+
+def focalnet_small_srf( **kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
+    return model
+
+def focalnet_base_srf(**kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
+    return model
+
+def focalnet_tiny_lrf(**kwargs):
+    model = Model(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
+    return model
+
+def focalnet_small_lrf(**kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
+
+    return model
+
+def focalnet_base_lrf(**kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
+    return model
+
+def focalnet_tiny_iso_16(**kwargs):
+    model = Model(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs)
+    return model
+
+def focalnet_small_iso_16(**kwargs):
+    model = Model(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs)
+    return model
+
+def focalnet_base_iso_16(**kwargs):
+    model = Model(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs)
+    return model
+
+# FocalNet large+ models
+def focalnet_large_fl3(**kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], **kwargs)
+    return model
+
+def focalnet_large_fl4(**kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4], **kwargs)
+    return model
+
+def focalnet_xlarge_fl3( **kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], **kwargs)
+    return model
+
+
+def focalnet_xlarge_fl4( **kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4], **kwargs)
+    return model
+
+def focalnet_huge_fl3( **kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], **kwargs)
+    return model
+
+def focalnet_huge_fl4( **kwargs):
+    model = Model(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4], **kwargs)
+    return model
\ No newline at end of file
diff --git a/focalnet_keras_core/focalnet.py b/focalnet_keras_core/focalnet.py
new file mode 100644
index 0000000..4594598
--- /dev/null
+++ b/focalnet_keras_core/focalnet.py
@@ -0,0 +1,70 @@
+import keras_core as keras
+from focalnet_keras_core.layers import *
+from focalnet_keras_core.blocks import *
+
+def FocalNet(img_size=224,
+             patch_size=4,
+             in_chans=3,
+             num_classes=1000,
+             embed_dim=128,
+             depths=[2, 2, 6, 2
+             ],
+             mlp_ratio=4.,
+             drop_rate=0.,
+             drop_path_rate=0.1,
+             norm_layer=keras.layers.LayerNormalization,
+             patch_norm=True,
+             use_checkpoint=False,
+             focal_levels=[2, 2, 3, 2],
+             focal_windows=[3, 2, 3, 2],
+             use_conv_embed=False,
+             use_layerscale=False,
+             layerscale_value=1e-4,
+             use_postln=False,
+             use_postln_in_modulation=False,
+             normalize_modulator=False):
+    num_layers = len(depths)
+    embed_dim = [embed_dim * (2 ** i) for i in range(num_layers)]
+    dpr = [ops.convert_to_numpy(x) for x in ops.linspace(0., drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+
+    def _apply(x):
+        nonlocal num_classes
+        x, *patches_resolution = PatchEmbed(
+            img_size=(img_size, img_size),
+            patch_size=patch_size,
+            # in_chans=in_chans,
+            embed_dim=embed_dim[0],
+            use_conv_embed=use_conv_embed,
+            norm_layer=norm_layer if patch_norm else None,
+            is_stem=True)(x, img_size, img_size)
+        H, W = patches_resolution[0], patches_resolution[1]
+        x = keras.layers.Dropout(drop_rate)(x)
+        for i_layer in range(num_layers):
+            x, H, W = BasicLayer(dim=embed_dim[i_layer],
+                                 out_dim=embed_dim[i_layer+1] if (i_layer < num_layers - 1) else None,
+                                 input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                   patches_resolution[1] // (2 ** i_layer)),
+                                 depth=depths[i_layer],
+                                 mlp_ratio=mlp_ratio,
+                                 drop=drop_rate,
+                                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                                 norm_layer=norm_layer,
+                                 downsample=PatchEmbed if (i_layer < num_layers - 1) else None,
+                                 focal_level=focal_levels[i_layer],
+                                 focal_window=focal_windows[i_layer],
+                                 use_conv_embed=use_conv_embed,
+                                 use_layerscale=use_layerscale,
+                                 layerscale_value=layerscale_value,
+                                 use_postln=use_postln,
+                                 use_postln_in_modulation=use_postln_in_modulation,
+                                 normalize_modulator=normalize_modulator
+                                 )(x, H, W)
+        x = norm_layer(name='norm')(x) # B L C
+        x = keras.layers.GlobalAveragePooling1D()(x) #28,515,442
+        x = keras.layers.Flatten()(x)
+        num_classes = num_classes if num_classes > 0 else None
+        x = keras.layers.Dense(num_classes, name='head')(x)
+        return x
+
+    return _apply
\ No newline at end of file
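
Usage note (not part of the patch): a minimal sketch of how the builders added in builders.py can be exercised, assuming the patch is applied, keras_core is installed, and the package's layers.py (not shown in this diff) provides the PatchEmbed, BasicLayer, StochasticDepth, and ops names that blocks.py and focalnet.py rely on:

    import numpy as np
    from focalnet_keras_core.builders import focalnet_tiny_srf

    # Build the tiny short-range-field variant: by default a 224x224x3
    # input and a 1000-way classification head (num_classes=1000).
    model = focalnet_tiny_srf()

    # Forward pass on a dummy batch; expected output shape: (1, 1000).
    dummy = np.random.rand(1, 224, 224, 3).astype("float32")
    logits = model.predict(dummy)

Each builder forwards its **kwargs to Model(), so per-instance options such as img_size or drop_path_rate can be overridden at call time, e.g. focalnet_tiny_srf(img_size=384).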