In [1]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image

In [2]:
from network.mynetwork_uu import Unet as Uunet
from network.mynetwork import Unet
from network.mynetwork_cmp import Unet as Unet_cmp
from network.styler import Unet as Unet_styler
from loss.loss import CLIPLoss


In [3]:
from thop import profile
from torchsummary import summary


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

input_pic = torch.ones(1, 3, 512, 512).to(device)
input1 = torch.ones(1, 3, 244, 244).to(device).long()
input2 = torch.ones(1, 77).to(device).long()

In [5]:
print("clip")
model = CLIPLoss(device).to(device)
model_clip = model.model
clip_flops, clip_params = profile(model_clip, inputs=(input1, input2), verbose=False)
print('FLOPs = ' + str(clip_flops/(1000**3)) + 'G')
print('Params = ' + str(clip_params/(1000**2)) + 'M')

clip
FLOPs = 4.884529152G
Params = 84.225024M


In [6]:
print("vgg19")
model_vgg = torchvision.models.vgg19(pretrained=True).features.to(device)
vgg_flops, vgg_params = profile(model_vgg, inputs=(input_pic,), verbose=False)
print('FLOPs = ' + str(vgg_flops/(1000**3)) + 'G')
print('Params = ' + str(vgg_params/(1000**2)) + 'M')

vgg19
FLOPs = 101.9215872G
Params = 20.024384M


In [7]:
print("mine")
model_mine = Uunet(device).to(device)
mine_flops, mine_params = profile(model_mine, inputs=(input_pic,), verbose=False)
print('FLOPs = ' + str((mine_flops+clip_flops)/(1000**3)) + 'G')
print('Params = ' + str((mine_params+clip_params)/(1000**2)) + 'M')



mine
FLOPs = 56.077058048G
Params = 86.960163M


In [8]:
print("mine_u")
model_mine_u = Unet(device).to(device)
mine_u_flops, mine_u_params = profile(model_mine_u, inputs=(input_pic,), verbose=False)
print('FLOPs = ' + str((mine_u_flops+clip_flops)/(1000**3)) + 'G')
print('Params = ' + str((mine_u_params+clip_params)/(1000**2)) + 'M')

mine_u
FLOPs = 58.46990848G
Params = 86.858755M


In [9]:
# print("cmp")
# model = Unet_cmp(device).to(device)
# cmp_flops, cmp_params = profile(model, inputs=(input_pic,), verbose=False)
# print('FLOPs = ' + str((cmp_flops+clip_flops+vgg_flops)/(1000**3)) + 'G')
# print('Params = ' + str((cmp_params+clip_params+vgg_params)/(1000**2)) + 'M')

In [10]:
print("styler")
model = Unet_styler().to(device)
sty_flops, sty_params = profile(model, inputs=(input_pic,), verbose=False)
print('FLOPs = ' + str((sty_flops+clip_flops+vgg_flops)/(1000**3)) + 'G')
print('Params = ' + str((sty_params+clip_params+vgg_params)/(1000**2)) + 'M')



styler
FLOPs = 123.059044352G
Params = 104.864675M


In [11]:
print("vgg")
summary(model_vgg,input_size=input_pic.shape[1:])

vgg
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
              ReLU-2         [-1, 64, 512, 512]               0
            Conv2d-3         [-1, 64, 512, 512]          36,928
              ReLU-4         [-1, 64, 512, 512]               0
         MaxPool2d-5         [-1, 64, 256, 256]               0
            Conv2d-6        [-1, 128, 256, 256]          73,856
              ReLU-7        [-1, 128, 256, 256]               0
            Conv2d-8        [-1, 128, 256, 256]         147,584
              ReLU-9        [-1, 128, 256, 256]               0
        MaxPool2d-10        [-1, 128, 128, 128]               0
           Conv2d-11        [-1, 256, 128, 128]         295,168
             ReLU-12        [-1, 256, 128, 128]               0
           Conv2d-13        [-1, 256, 128, 128]         590,080
             ReLU-14        [-1, 25

In [12]:
print("mine")
summary(model_mine,input_size=input_pic.shape[1:])

mine
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1          [-1, 3, 514, 514]               0
            Conv2d-2         [-1, 32, 512, 512]             896
    InstanceNorm2d-3         [-1, 32, 512, 512]               0
              ReLU-4         [-1, 32, 512, 512]               0
            Conv2d-5         [-1, 32, 512, 512]           9,248
    InstanceNorm2d-6         [-1, 32, 512, 512]               0
              ReLU-7         [-1, 32, 512, 512]               0
          ResBlock-8         [-1, 32, 512, 512]               0
            Conv2d-9         [-1, 64, 512, 512]          18,496
   InstanceNorm2d-10         [-1, 64, 512, 512]               0
             ReLU-11         [-1, 64, 512, 512]               0
        MaxPool2d-12         [-1, 64, 256, 256]               0
   InstanceNorm2d-13         [-1, 64, 256, 256]               0
             ReLU-14         [-1, 

In [13]:
print("mine_u")
summary(model_mine_u,input_size=input_pic.shape[1:])


mine_u
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1          [-1, 3, 514, 514]               0
            Conv2d-2         [-1, 32, 512, 512]             896
    InstanceNorm2d-3         [-1, 32, 512, 512]               0
              ReLU-4         [-1, 32, 512, 512]               0
            Conv2d-5         [-1, 32, 512, 512]           9,248
    InstanceNorm2d-6         [-1, 32, 512, 512]               0
              ReLU-7         [-1, 32, 512, 512]               0
          ResBlock-8         [-1, 32, 512, 512]               0
   ReflectionPad2d-9         [-1, 32, 514, 514]               0
           Conv2d-10         [-1, 64, 512, 512]          18,496
   InstanceNorm2d-11         [-1, 64, 512, 512]               0
             ReLU-12         [-1, 64, 512, 512]               0
           Conv2d-13         [-1, 64, 512, 512]           2,112
        MaxPool2d-14         [-1

In [14]:
print("styler")
summary(model,input_size=input_pic.shape[1:])



styler
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 512, 512]              64
            Conv2d-2         [-1, 16, 512, 512]           2,320
    InstanceNorm2d-3         [-1, 16, 512, 512]               0
              ReLU-4         [-1, 16, 512, 512]               0
            Conv2d-5         [-1, 16, 512, 512]           2,320
    InstanceNorm2d-6         [-1, 16, 512, 512]               0
              ReLU-7         [-1, 16, 512, 512]               0
            Conv2d-8         [-1, 16, 512, 512]             272
     EncodingBlock-9         [-1, 16, 512, 512]               0
           Conv2d-10         [-1, 16, 512, 512]           2,320
   InstanceNorm2d-11         [-1, 16, 512, 512]               0
             ReLU-12         [-1, 16, 512, 512]               0
           Conv2d-13         [-1, 16, 512, 512]           2,320
   InstanceNorm2d-14         [-1