diff --git a/basicsr/archs/focalir_arch.py b/basicsr/archs/focalir_arch.py
index 1c2aa6886..c3892dd5b 100644
--- a/basicsr/archs/focalir_arch.py
+++ b/basicsr/archs/focalir_arch.py
@@ -17,12 +17,15 @@
 import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from basicsr.utils.registry import ARCH_REGISTRY
+
 from basicsr.archs.arch_util import to_2tuple, trunc_normal_
+from basicsr.utils.registry import ARCH_REGISTRY
+from thop import profile as hp
 
 
 def drop_path(x, drop_prob: float = 0., training: bool = False):
@@ -1444,13 +1447,12 @@ def profile(model, inputs):
 
 if __name__ == '__main__':
-    img_hsize = 224
-    img_wsize = 224
+    img_hsize = 320
+    img_wsize = 180
     x = torch.rand(1, 3, img_hsize, img_wsize).cuda()
-    model = FocalIR(img_size=(img_hsize, img_wsize), upscale=2, embed_dim=96, depths=[6, 6, 6, 6], drop_path_rate=0.2,
-                    focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all",
-                    num_heads=[4, 4, 4, 4],
-                    focal_windows=[7, 7, 7, 7], window_size=4, use_shift=False).cuda()
+    model = FocalIR(img_size=(img_hsize, img_wsize), upscale=4, in_chans=3, embed_dim=60, depths=[6, 6, 6, 6], drop_path_rate=0.2,
+                    focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all", num_heads=[6, 6, 6, 6],
+                    focal_windows=[7, 5, 3, 1], mlp_ratio=2, upsampler='pixelshuffle', window_size=4, resi_connection='1conv', use_shift=False).cuda()
     model.eval()
@@ -1460,6 +1462,9 @@ def profile(model, inputs):
 
     #n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     #print(f"number of params: {n_parameters}")
-    y = model(x)
-    print(y.shape)
+    flops, params = hp(model, inputs=(x,))
+
+    print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
+    print("params=", str(params / 1e6) + '{}'.format("M"))
+    #profile(model, x)
diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py
index f3e9e2c54..c688a600b 100644
--- a/basicsr/archs/swinir_arch.py
+++ b/basicsr/archs/swinir_arch.py
@@ -6,9 +6,10 @@
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
+from thop import profile as hp
 
 from basicsr.utils.registry import ARCH_REGISTRY
-from .arch_util import to_2tuple, trunc_normal_
+from basicsr.archs.arch_util import to_2tuple, trunc_normal_
 
 
 def drop_path(x, drop_prob: float = 0., training: bool = False):
@@ -935,11 +936,13 @@ def flops(self):
 
 if __name__ == '__main__':
     upscale = 4
-    window_size = 8
-    height = (1024 // upscale // window_size + 1) * window_size
-    width = (720 // upscale // window_size + 1) * window_size
+    window_size = 4
+    # height = (1024 // upscale // window_size + 1) * window_size
+    # width = (720 // upscale // window_size + 1) * window_size
+    height = 320
+    width = 180
     model = SwinIR(
-        upscale=2,
+        upscale=4,
         img_size=(height, width),
         window_size=window_size,
         img_range=1.,
@@ -948,9 +951,13 @@ def flops(self):
         num_heads=[6, 6, 6, 6],
         mlp_ratio=2,
         upsampler='pixelshuffledirect')
-    print(model)
-    print(height, width, model.flops() / 1e9)
-
+    # print(model)
+    # print(height, width, model.flops() / 1e9)
+    model.eval()
     x = torch.randn((1, 3, height, width))
-    x = model(x)
-    print(x.shape)
+    # x = model(x)
+    # print(x.shape)
+    flops, params = hp(model, inputs=(x,))
+
+    print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
+    print("params=", str(params / 1e6) + '{}'.format("M"))
diff --git a/options/train/FocalIR/train_focalIR_SRx4_scratch.yml b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml
index ebd428b42..5d0c1944f 100644
--- a/options/train/FocalIR/train_focalIR_SRx4_scratch.yml
+++ b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml
@@ -1,5 +1,5 @@
 # general settings
-name: train_FocalIR_SRx4_s48g96_DIV2K
+name: train_test_FocalIR_SRx4_s48g96_DIV2K
 model_type: FocalIRModel
 scale: 4
 num_gpu: 1
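For reference, both `__main__` blocks above apply the same thop-based measurement pattern. Below is a minimal standalone sketch of that pattern, not part of the patch: it assumes `thop` (pytorch-OpCounter) is installed and that `model` is any `nn.Module` such as SwinIR or FocalIR; `report_complexity` and its default input size are illustrative names only. `thop.profile` returns an operation count (multiply-accumulates for most layers) and the parameter count, which the prints scale to G and M as in the patch.

```python
# Sketch only: mirrors the profiling calls added in the diff above.
# Assumes `thop` (pytorch-OpCounter) is installed; `model` is any nn.Module.
import torch
from thop import profile


def report_complexity(model, height=320, width=180):
    model.eval()
    x = torch.randn(1, 3, height, width)  # dummy 3-channel LR input
    with torch.no_grad():
        macs, params = profile(model, inputs=(x,))  # thop returns (op count, param count)
    print("FLOPs=", f"{macs / 1e9:.2f}G")   # labelled "FLOPs" as in the patch
    print("params=", f"{params / 1e6:.2f}M")
```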