Skip to content

Commit

Permalink
calculate the FLOPS and parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
RebornForPower committed Nov 11, 2021
1 parent 10f55c8 commit 556909d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
23 changes: 14 additions & 9 deletions basicsr/archs/focalir_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@


import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from basicsr.utils.registry import ARCH_REGISTRY

from basicsr.archs.arch_util import to_2tuple, trunc_normal_
from basicsr.utils.registry import ARCH_REGISTRY
from thop import profile as hp


def drop_path(x, drop_prob: float = 0., training: bool = False):
Expand Down Expand Up @@ -1444,13 +1447,12 @@ def profile(model, inputs):


if __name__ == '__main__':
img_hsize = 224
img_wsize = 224
img_hsize = 320
img_wsize = 180
x = torch.rand(1, 3, img_hsize, img_wsize).cuda()
model = FocalIR(img_size=(img_hsize, img_wsize), upscale=2, embed_dim=96, depths=[6, 6, 6, 6], drop_path_rate=0.2,
focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all",
num_heads=[4, 4, 4, 4],
focal_windows=[7, 7, 7, 7], window_size=4, use_shift=False).cuda()
model = FocalIR(img_size=(img_hsize, img_wsize), upscale=4, in_chans=3, embed_dim=60, depths=[6, 6, 6, 6], drop_path_rate=0.2,
focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all",num_heads=[6, 6, 6, 6],
focal_windows=[7, 5, 3, 1], mlp_ratio=2, upsampler='pixelshuffle', window_size=4, resi_connection='1conv', use_shift=False).cuda()

model.eval()

Expand All @@ -1460,6 +1462,9 @@ def profile(model, inputs):
#n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
#print(f"number of params: {n_parameters}")

y = model(x)
print(y.shape)
flops, params = hp(model, inputs=(x,))

print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
print("params=", str(params / 1e6) + '{}'.format("M"))

#profile(model, x)
27 changes: 17 additions & 10 deletions basicsr/archs/swinir_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from thop import profile as hp

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import to_2tuple, trunc_normal_
from basicsr.archs.arch_util import to_2tuple, trunc_normal_


def drop_path(x, drop_prob: float = 0., training: bool = False):
Expand Down Expand Up @@ -935,11 +936,13 @@ def flops(self):

if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
window_size = 4
# height = (1024 // upscale // window_size + 1) * window_size
# width = (720 // upscale // window_size + 1) * window_size
height = 320
width = 180
model = SwinIR(
upscale=2,
upscale=4,
img_size=(height, width),
window_size=window_size,
img_range=1.,
Expand All @@ -948,9 +951,13 @@ def flops(self):
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)

# print(model)
# print(height, width, model.flops() / 1e9)
model.eval()
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
# x = model(x)
# print(x.shape)
flops, params = hp(model, inputs=(x,))

print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
print("params=", str(params / 1e6) + '{}'.format("M"))
2 changes: 1 addition & 1 deletion options/train/FocalIR/train_focalIR_SRx4_scratch.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# general settings
name: train_FocalIR_SRx4_s48g96_DIV2K
name: train_test_FocalIR_SRx4_s48g96_DIV2K
model_type: FocalIRModel
scale: 4
num_gpu: 1
Expand Down

0 comments on commit 556909d

Please sign in to comment.