In [9]:
import pandas as pd

import torch
import torch.utils.benchmark as benchmark

from models import srgan, esrgan, mesrgan

In [2]:
device = torch.device("cuda:1")

In [3]:
# Define networks
generator_names = ['SRGAN', 'ESRGAN', 'MESRGAN_T2']
generator_models = [srgan, esrgan, mesrgan]
generator_trained_paths = ['trained_models/' + name + '_ALL_DATA_stage2_generator.trch' for name in generator_names]

In [4]:
batch_sizes = [1, 2, 4, 8, 16]
forward_ts = {name: {batch_size: [] for batch_size in batch_sizes} for name in generator_names}

def forward_test(generator, inp):
    with torch.no_grad():
        return generator(inp)

# run the test
package = list(zip(generator_names, generator_models, generator_trained_paths))
for batch_size in batch_sizes:
    inp = torch.randn(batch_size, 3, 128, 128)
    inp = inp.to(device)
    
    for name, model, pretrain_path in package:
        generator = model.Generator()
        generator.load_state_dict(torch.load(pretrain_path))
        generator.to(device)
        generator.eval()
        
        print(f'Batch size: {batch_size}, model: {name}')
        
        forward_ts[name][batch_size] = benchmark.Timer(
            stmt='forward_test(generator, inp)',
            setup='from __main__ import forward_test', 
            globals={'generator': generator, 'inp': inp}
        ).timeit(100).mean
        

Batch size: 1, model: SRGAN
Batch size: 1, model: ESRGAN
Batch size: 1, model: MESRGAN_T2
Batch size: 2, model: SRGAN
Batch size: 2, model: ESRGAN
Batch size: 2, model: MESRGAN_T2
Batch size: 4, model: SRGAN
Batch size: 4, model: ESRGAN
Batch size: 4, model: MESRGAN_T2
Batch size: 8, model: SRGAN
Batch size: 8, model: ESRGAN
Batch size: 8, model: MESRGAN_T2
Batch size: 16, model: SRGAN
Batch size: 16, model: ESRGAN
Batch size: 16, model: MESRGAN_T2


In [14]:
# Put into pandas DataFrame and format it properly
df = pd.DataFrame(forward_ts)
df = df.rename(columns={'MESRGAN_T2': 'MESRGAN'})
df = df.melt(var_name='Model', value_name='Time (s)', ignore_index=False)
df = df.reset_index()
df = df.rename(columns={'index': 'Batch size'})
df.to_pickle("benchmark.pkl")

In [33]:
print((1000*df.pivot(index="Batch size", columns="Model")).to_latex())

\begin{tabular}{lrrr}
\toprule
{} & \multicolumn{3}{l}{Time (s)} \\
Model &      ESRGAN &      MESRGAN &       SRGAN \\
Batch size &             &              &             \\
\midrule
1          &   33.444209 &   155.157632 &   26.442121 \\
2          &   62.940485 &   299.802770 &   50.596326 \\
4          &  126.464980 &   596.273267 &  134.314706 \\
8          &  266.134177 &  1204.886182 &  279.197597 \\
16         &  534.045201 &  2402.361290 &  561.801663 \\
\bottomrule
\end{tabular}

