generated from fkodom/python-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
149 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import keras_core as keras | ||
from focalnet_keras_core.focalnet import FocalNet | ||
|
||
def Model(img_size=224, **kw) -> keras.Model: | ||
|
||
focalnet_model = FocalNet(img_size=img_size,**kw) | ||
|
||
inputs = keras.Input((img_size, img_size, 3)) | ||
outputs = focalnet_model(inputs) | ||
final_model = keras.Model(inputs, outputs ) | ||
|
||
return final_model | ||
|
||
|
||
|
||
def focalnet_tiny_srf(**kwargs): | ||
model = Model(depths=[2, 2, 6, 2], embed_dim=96, **kwargs) | ||
return model | ||
|
||
def focalnet_small_srf( **kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=96, **kwargs) | ||
return model | ||
|
||
def focalnet_base_srf(**kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=128, **kwargs) | ||
return model | ||
|
||
def focalnet_tiny_lrf(**kwargs): | ||
model = Model(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs) | ||
return model | ||
|
||
def focalnet_small_lrf(**kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs) | ||
|
||
return model | ||
|
||
def focalnet_base_lrf(**kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs) | ||
return model | ||
|
||
def focalnet_tiny_iso_16(**kwargs): | ||
model = Model(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs) | ||
return model | ||
|
||
def focalnet_small_iso_16(**kwargs): | ||
model = Model(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs) | ||
return model | ||
|
||
def focalnet_base_iso_16(**kwargs): | ||
model = Model(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs) | ||
return model | ||
|
||
# FocalNet large+ models | ||
def focalnet_large_fl3(**kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], **kwargs) | ||
return model | ||
|
||
def focalnet_large_fl4(**kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4], **kwargs) | ||
return model | ||
|
||
def focalnet_xlarge_fl3( **kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], **kwargs) | ||
return model | ||
|
||
|
||
def focalnet_xlarge_fl4( **kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4], **kwargs) | ||
return model | ||
|
||
def focalnet_huge_fl3( **kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], **kwargs) | ||
return model | ||
|
||
def focalnet_huge_fl4( **kwargs): | ||
model = Model(depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4], **kwargs) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import keras_core as keras | ||
from focalnet_keras_core.layers import * | ||
from focalnet_keras_core.blocks import * | ||
|
||
def FocalNet(img_size=224, | ||
patch_size=4, | ||
in_chans=3, | ||
num_classes=1000, | ||
embed_dim=128, | ||
depths=[2, 2, 6, 2 | ||
], | ||
mlp_ratio=4., | ||
drop_rate=0., | ||
drop_path_rate=0.1, | ||
norm_layer=keras.layers.LayerNormalization, | ||
patch_norm=True, | ||
use_checkpoint=False, | ||
focal_levels=[2, 2, 3, 2], | ||
focal_windows=[3, 2, 3, 2], | ||
use_conv_embed=False, | ||
use_layerscale=False, | ||
layerscale_value=1e-4, | ||
use_postln=False, | ||
use_postln_in_modulation=False, | ||
normalize_modulator=False): | ||
num_layers = len(depths) | ||
embed_dim = [embed_dim * (2 ** i) for i in range(num_layers)] | ||
dpr = [ops.convert_to_numpy(x) for x in ops.linspace(0., drop_path_rate, sum(depths))] # stochastic depth decay rule | ||
|
||
|
||
def _apply(x): | ||
nonlocal num_classes | ||
x, *patches_resolution = PatchEmbed( | ||
img_size=(img_size, img_size), | ||
patch_size=patch_size, | ||
# in_chans=in_chans, | ||
embed_dim=embed_dim[0], | ||
use_conv_embed=use_conv_embed, | ||
norm_layer=norm_layer if patch_norm else None, | ||
is_stem=True)(x, img_size, img_size) | ||
H, W = patches_resolution[0], patches_resolution[1] | ||
x = keras.layers.Dropout(drop_rate)(x) | ||
for i_layer in range(num_layers): | ||
x, H, W = BasicLayer(dim=embed_dim[i_layer], | ||
out_dim=embed_dim[i_layer+1] if (i_layer < num_layers - 1) else None, | ||
input_resolution=(patches_resolution[0] // (2 ** i_layer), | ||
patches_resolution[1] // (2 ** i_layer)), | ||
depth=depths[i_layer], | ||
mlp_ratio=mlp_ratio, | ||
drop=drop_rate, | ||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | ||
norm_layer=norm_layer, | ||
downsample=PatchEmbed if (i_layer < num_layers - 1) else None, | ||
focal_level=focal_levels[i_layer], | ||
focal_window=focal_windows[i_layer], | ||
use_conv_embed=use_conv_embed, | ||
use_layerscale=use_layerscale, | ||
layerscale_value=layerscale_value, | ||
use_postln=use_postln, | ||
use_postln_in_modulation=use_postln_in_modulation, | ||
normalize_modulator=normalize_modulator | ||
)(x, H, W) | ||
x = norm_layer(name='norm')(x) # B L C | ||
x = keras.layers.GlobalAveragePooling1D()(x) #28,515,442 | ||
x = keras.layers.Flatten()(x) | ||
num_classes = num_classes if num_classes > 0 else None | ||
x = keras.layers.Dense(num_classes, name='head')(x) | ||
return x | ||
|
||
return _apply |