In [2]:
# Re-import edited Python modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import os

import pandas as pd
from IPython.display import HTML, display

In [4]:
# Model name -> run version for every model in the comparison table.
# The names are used verbatim as directory names under ROOT_PATH by
# read_test_result, so they must match the on-disk spelling.
# NOTE(review): "Wave2Vec2" is absent from the saved results table below —
# check whether its directory is spelled differently.
MODELs_VERSIONS = dict(
    [
        ("LCNN", 0),
        ("RawNet2", 0),
        ("RawGAT", 0),
        ("Wave2Vec2", 0),
        ("WaveLM", 0),
        ("LibriSeVoc", 0),
        ("AudioClip", 0),
        ("Wav2Clip", 0),
        ("AASIST", 0),
        ("SFATNet", 1),
        ("ASDG", 1),
        ("Ours/ResNet", 0),
    ]
)
# Parallel views (insertion order) over names and versions.
models, versions = MODELs_VERSIONS.keys(), MODELs_VERSIONS.values()

In [5]:
def get_total_params(file_path):
    """Return the "Total params" figure from a saved model-summary text file.

    The summaries read by this notebook put the value before the label
    (e.g. ``"682 K     Total params"`` -> ``"682 K"``). A value printed after
    the label (``"Total params: 11,689,512"``) is also accepted.

    Args:
        file_path: Path to the summary text file (e.g. ``model.txt``).

    Returns:
        The value as a whitespace-normalised string, or ``None`` if no line
        contains ``'Total params'`` followed or preceded by a value.
    """
    label = 'Total params'
    with open(file_path, 'r') as file:
        for line in file:
            pos = line.find(label)
            if pos < 0:
                continue
            # Value-first layout: take every token left of the label. Taking a
            # fixed two tokens returns "500 Total" when the count has no K/M suffix.
            before = line[:pos].split()
            if before:
                return ' '.join(before)
            # Label-first layout: take what follows, dropping the ':' separator.
            after = line[pos + len(label):].strip().lstrip(':').split()
            if after:
                return ' '.join(after)
    return None  # no usable 'Total params' line

In [6]:
# Root of the saved runs; override with the DF_AUDIO_ROOT environment variable.
ROOT_PATH = os.environ.get("DF_AUDIO_ROOT", "/home/ay/data/DATA/1-model_save/00-Deepfake/1-df-audio")
def read_test_result(model, task='ASV2021_inner', version=-1, metric_prefix="test", file_name="metrics"):
    """Collect throughput/FLOP metrics and the parameter count for one model run.

    Reads ``{ROOT_PATH}/{model}/{task}/version_{version}/{file_name}.csv`` and,
    for each metric column, takes the value at that column's last non-NaN row.
    The parameter count is parsed from ``model.txt`` in the same directory.

    Args:
        model: Model directory name under ``ROOT_PATH`` (may contain '/').
        task: Task/dataset sub-directory.
        version: Run index; the directory read is ``version_{version}``.
        metric_prefix: Currently unused; kept for call compatibility.
        file_name: CSV file stem (without ``.csv``).

    Returns:
        A dict with the three metric columns plus ``'model'`` and ``'param'``,
        or ``None`` if the CSV does not exist. A metric column with no
        non-NaN value maps to ``None``; ``'param'`` is ``None`` when
        ``model.txt`` is missing or has no ``Total params`` line.
    """
    save_path = f"{ROOT_PATH}/{model}/{task}/version_{version}"
    csv_path = os.path.join(save_path, f"{file_name}.csv")

    if not os.path.exists(csv_path):
        print("Warning!!!! cannot find: ", csv_path)
        return None

    data = pd.read_csv(csv_path)

    keys = ['train/device/samples_per_sec', 'validate/device/samples_per_sec', 'flops_train']
    res = {}
    for key in keys:
        # Rows where a metric was not logged are presumably NaN, so use the most
        # recent logged value; an all-NaN column gives None instead of raising.
        last = data[key].last_valid_index()
        res[key] = None if last is None else data[key].loc[last]

    res['model'] = model
    model_txt = os.path.join(save_path, "model.txt")
    if os.path.exists(model_txt):
        res['param'] = get_total_params(model_txt)
    else:
        # Don't abort the whole collection loop over one missing summary file.
        print("Warning!!!! cannot find: ", model_txt)
        res['param'] = None

    return res

In [7]:
# Spot-check a single model before collecting all of them.
res = read_test_result('AASIST')
assert res is not None, "AASIST metrics.csv not found under ROOT_PATH"
res.keys()

dict_keys(['train/device/samples_per_sec', 'validate/device/samples_per_sec', 'flops_train', 'model', 'param'])

In [8]:
# Iterate the config dict directly: a later cell rebinds the bare name `models`
# to `torchvision.models`, so `zip(models, versions)` breaks on re-run.
# NOTE(review): the configured version is not passed to read_test_result, so
# every model is read from its default `version_-1` directory even though
# SFATNet/ASDG are configured as version 1 — confirm this is intended.
records = []
for model_name in MODELs_VERSIONS:
    res = read_test_result(model_name)
    if res is not None:
        records.append(res)
# Build the frame once instead of growing it row by row.
data = pd.DataFrame(records, columns=['train/device/samples_per_sec', 'validate/device/samples_per_sec', 'flops_train', 'model', 'param'])



In [9]:
# print (not display) so the raw LaTeX source can be copied into the paper.
print(
    data
    .set_index('model')
    [['param', 'flops_train', 'train/device/samples_per_sec', 'validate/device/samples_per_sec']]
    .to_latex()
)

\begin{tabular}{llrrr}
\toprule
 & param & flops_train & train/device/samples_per_sec & validate/device/samples_per_sec \\
model &  &  &  &  \\
\midrule
LCNN & 682 K & 0.255296 & 503.529886 & 1734.575040 \\
RawNet2 & 17.7 M & 1.199694 & 562.006842 & 1332.467762 \\
RawGAT & 440 K & 13.685228 & 89.333039 & 258.348062 \\
WaveLM & 94.4 M & 20.791399 & 257.359792 & 840.903328 \\
LibriSeVoc & 17.7 M & 1.199700 & 485.366909 & 1525.989388 \\
AudioClip & 134 M & 3.268036 & 557.176202 & 1942.107983 \\
Wav2Clip & 11.7 M & 1.346389 & 876.699565 & 2060.610695 \\
AASIST & 297 K & 7.100886 & 158.955672 & 461.367023 \\
SFATNet & 81.4 M & 16.303932 & 364.646279 & 923.170679 \\
ASDG & 1.1 M & 0.341719 & 1005.621651 & 1213.257123 \\
Ours/ResNet & 22.5 M & 3.211223 & 233.299355 & 1841.332878 \\
\bottomrule
\end{tabular}



In [10]:
import copy
import torch
import torchvision.models as models
from torchtnt.utils.flops import FlopTensorDispatchMode

# Count FLOPs of torchvision's ResNet-18 on one 1x3x224x224 input with torchtnt,
# taking separate snapshots for the forward and the backward pass.
module = models.resnet18()
module_input = torch.randn(1, 3, 224, 224)
with FlopTensorDispatchMode(module) as ftdm:
    # Forward: reduce to a scalar so .backward() below needs no gradient argument.
    # (Named fwd_scalar so the earlier `res` results dict is not clobbered.)
    fwd_scalar = module(module_input).mean()
    # Deep-copy: the live counters are reset and refilled below.
    flops_forward = copy.deepcopy(ftdm.flop_counts)

    # Reset so the next snapshot contains backward-pass counts only.
    ftdm.reset()
    fwd_scalar.backward()
    flops_backward = copy.deepcopy(ftdm.flop_counts)
# Key '' appears to be the root module, i.e. whole-model per-op totals.
print(flops_forward[''])

defaultdict(<class 'int'>, {'convolution.default': 1813561344, 'addmm.default': 512000})


In [15]:
from lightning.fabric.utilities.throughput import measure_flops

# Create the model and input on the meta device: shapes are tracked but no
# weights or activations are materialised, so counting is cheap.
with torch.device("meta"):
    resnet_meta = models.resnet18()
    x = torch.randn(1, 3, 224, 224)


def model_fwd():
    """Forward pass counted by measure_flops (no loss_fn, so forward only)."""
    return resnet_meta(x)


fwd_flops = measure_flops(resnet_meta, model_fwd)

In [14]:
# Halving Lightning's count reproduces torchtnt's forward total above
# (convolution + addmm = 1814073344), i.e. measure_flops apparently counts
# two FLOPs per multiply-accumulate where torchtnt counts one.
fwd_flops * 0.5

1814073344.0