In [1]:
import os

# Dataset to evaluate. Its last path component names every per-dataset output dir.
dataset_dir = r'/root/autodl-tmp/dataset/anime_face_all'
# normpath first so a trailing '/' does not yield an empty name.
train_dataset_name = os.path.basename(os.path.normpath(dataset_dir))

model_dir = f'/root/autodl-tmp/trained_models/{train_dataset_name}'
generated_img_dir = f'/root/autodl-tmp/generated_img/{train_dataset_name}'
pure_train_img_dir = f'/root/autodl-tmp/pure_train_img/{train_dataset_name}'

generate_epoch = 2    # sampling rounds per model
generate_batch = 64   # images per round -> generate_epoch * generate_batch images per model
num_of_model = 100    # upper bound on the number of model directories to evaluate

# Model sub-directories are named by their inference-step count (the generation
# cell parses them with int()). os.listdir() returns entries in arbitrary order,
# so keep only numeric names and sort them numerically before truncating.
# Otherwise the evaluated subset, and its order, could change between runs.
model_files = sorted(
    (name for name in os.listdir(model_dir) if name.isdigit()),
    key=int,
)[:num_of_model]
model_files

['1000']

# 保存训练图片

将训练集中的每张图片转换为 RGB、缩放到 128×128 后保存为 PNG，作为后续计算 FID 的参考图片集。

In [None]:
from datasets import load_from_disk
from tqdm import tqdm

# Side length in pixels of the saved reference images.
# NOTE(review): presumably chosen to match the DDPM sample resolution — confirm.
TRAIN_IMAGE_SIZE = 128

# Write every training image as <index>.png so the folder can be used as the
# reference directory for pytorch_fid.
dataset_dict = load_from_disk(dataset_dir)
dataset = dataset_dict['train']

os.makedirs(pure_train_img_dir, exist_ok=True)

for i in tqdm(range(len(dataset))):
    image = dataset[i]['image']
    # Convert to RGB *before* resizing. PIL's resize() always falls back to
    # nearest-neighbour for palette ("P") and bilevel ("1") images, so resizing
    # first would give those images blockier, lower-quality thumbnails.
    image = image.convert('RGB').resize((TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE))
    image.save(os.path.join(pure_train_img_dir, f'{i}.png'))

# 生成图片

对每个模型目录（目录名即推理步数）生成 `generate_epoch × generate_batch` 张图片，保存到对应的生成目录，并记录采样耗时。

In [2]:
import json
import time
from datetime import datetime, timedelta

import torch
from diffusers import DDPMPipeline

# Seed for the sampling noise so reruns produce the same images (and hence FID).
generate_seed = 0

generate_time = {}           # inference steps -> sampling duration string
generate_img_dir_list = []   # one output directory per evaluated model

for model_file in model_files:
    # The model directory name is the number of inference steps to sample with.
    num_inference_steps = int(model_file)

    current_generate_img_dir = os.path.join(generated_img_dir, model_file)
    generate_img_dir_list.append(current_generate_img_dir)
    os.makedirs(current_generate_img_dir, exist_ok=True)

    # Keep the directory name intact; use a separate variable for the full path.
    model_path = os.path.join(model_dir, model_file)
    pipeline = DDPMPipeline.from_pretrained(model_path)
    pipeline.to('cuda')
    # Re-seed per model so every checkpoint starts from the same noise sequence.
    generator = torch.Generator(device='cuda').manual_seed(generate_seed)

    generate_imgs = []
    start = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    for _ in range(generate_epoch):
        images = pipeline(
            batch_size=generate_batch,
            generator=generator,
            num_inference_steps=num_inference_steps,
        ).images
        generate_imgs.extend(images)
    elapsed = time.perf_counter() - start
    # Same "H:MM:SS.ffffff" format as str(datetime - datetime), which convert_time parses.
    generate_time[num_inference_steps] = str(timedelta(seconds=elapsed))

    for idx, image in enumerate(generate_imgs):
        image.save(os.path.join(current_generate_img_dir, f'{idx}.png'))

print(json.dumps(generate_time))
print(generate_img_dir_list)

Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

{"1000": "0:08:25.447387"}
['/root/autodl-tmp/generated_img/anime_face_all/1000']


# 计算 FID

用 `pytorch_fid` 比较训练图片目录与某个生成图片目录。下例为 oxford_flower 数据集、100 步推理的结果；下方代码单元会对每个生成目录依次执行同样的命令：

```bash
python -m pytorch_fid /root/autodl-tmp/pure_train_img/oxford_flower /root/autodl-tmp/generated_img/oxford_flower/100
```

In [None]:
import json
import subprocess
import sys

# Generated-image dir -> raw pytorch_fid stdout (e.g. "FID:  123.45\n").
fid_list = {}
for current_dir in generate_img_dir_list:
    # sys.executable runs pytorch_fid in the kernel's own environment; a bare
    # 'python' may resolve to a different interpreter on PATH.
    completed = subprocess.run(
        [sys.executable, '-m', 'pytorch_fid', pure_train_img_dir, current_dir],
        capture_output=True,
        encoding='utf-8',
    )
    # Fail loudly instead of silently recording an empty/partial result.
    if completed.returncode != 0:
        raise RuntimeError(
            f'pytorch_fid failed for {current_dir} '
            f'(exit code {completed.returncode}):\n{completed.stderr}'
        )
    fid_list[current_dir] = completed.stdout
json.dumps(fid_list)

In [2]:
import json

# Recorded results pasted from earlier runs of the generation and FID cells.
# For each dataset:
#   <name>_time : inference steps -> sampling duration string, apparently as
#                 printed by the generation cell ("H:MM:SS.ffffff").
#   <name>_fid  : first assigned the raw FID-cell output (generated-image dir ->
#                 pytorch_fid stdout), then rebound on the very next line to
#                 inference steps -> FID rounded to 2 decimals. Only the rounded
#                 dict stays bound after this cell runs.
# NOTE(review): the literal "\\n" suffix in the raw strings is presumably the
# escaped newline from copying the json.dumps output — confirm.
#
# Conversion used to derive the rounded dicts from the raw ones
# (fid[6:-2] strips the "FID:  " prefix and the two-character "\\n" suffix):
# tmp = {}
# for key in anime_face_fid:
#     inference_step = key.split('/')[-1]
#     fid = anime_face_fid[key]
#     tmp[int(inference_step)] = round(float(fid[6:-2]), 2)
# tmp = dict(sorted(tmp.items()))
# json.dumps(tmp)

# oxford_flower
oxford_flower_time = {"100": "0:00:50.620381", "1000": "0:08:26.066021", "1100": "0:09:16.761789", "1200": "0:10:07.138548", "1300": "0:10:57.510178", "1400": "0:11:48.029013", "1500": "0:12:38.722301", "1600": "0:13:28.928998", "1700": "0:14:19.622817", "1800": "0:15:10.646142", "1900": "0:16:01.634605", "200": "0:01:41.190131", "2000": "0:16:52.135326", "300": "0:02:31.768917", "400": "0:03:22.217893", "500": "0:04:13.146280", "600": "0:05:03.682271", "700": "0:05:54.193482", "800": "0:06:44.669052", "900": "0:07:35.356787"}
oxford_flower_fid = {"/root/autodl-tmp/generated_img/oxford_flower/100": "FID:  369.8776792058484\\n", "/root/autodl-tmp/generated_img/oxford_flower/1000": "FID:  142.55629752929576\\n", "/root/autodl-tmp/generated_img/oxford_flower/1100": "FID:  146.50291275199405\\n", "/root/autodl-tmp/generated_img/oxford_flower/1200": "FID:  141.07067301339657\\n", "/root/autodl-tmp/generated_img/oxford_flower/1300": "FID:  127.78704059762066\\n", "/root/autodl-tmp/generated_img/oxford_flower/1400": "FID:  140.60170796390918\\n", "/root/autodl-tmp/generated_img/oxford_flower/1500": "FID:  132.09421200167884\\n", "/root/autodl-tmp/generated_img/oxford_flower/1600": "FID:  145.09557986584895\\n", "/root/autodl-tmp/generated_img/oxford_flower/1700": "FID:  139.22624071648505\\n", "/root/autodl-tmp/generated_img/oxford_flower/1800": "FID:  134.85808825226536\\n", "/root/autodl-tmp/generated_img/oxford_flower/1900": "FID:  139.89393895547383\\n", "/root/autodl-tmp/generated_img/oxford_flower/200": "FID:  208.7701907995583\\n", "/root/autodl-tmp/generated_img/oxford_flower/2000": "FID:  144.0710941787677\\n", "/root/autodl-tmp/generated_img/oxford_flower/300": "FID:  176.45799993799216\\n", "/root/autodl-tmp/generated_img/oxford_flower/400": "FID:  150.35638167998042\\n", "/root/autodl-tmp/generated_img/oxford_flower/500": "FID:  125.7740520129102\\n", "/root/autodl-tmp/generated_img/oxford_flower/600": "FID:  116.52463441937329\\n", "/root/autodl-tmp/generated_img/oxford_flower/700": "FID:  117.45831168501707\\n", "/root/autodl-tmp/generated_img/oxford_flower/800": "FID:  122.94940736170577\\n", "/root/autodl-tmp/generated_img/oxford_flower/900": "FID:  133.5546878883864\\n"}
# Rounded per-step FID; overwrites the raw dict above.
oxford_flower_fid = {"100": 369.88, "200": 208.77, "300": 176.46, "400": 150.36, "500": 125.77, "600": 116.52, "700": 117.46, "800": 122.95, "900": 133.55, "1000": 142.56, "1100": 146.5, "1200": 141.07, "1300": 127.79, "1400": 140.6, "1500": 132.09, "1600": 145.1, "1700": 139.23, "1800": 134.86, "1900": 139.89, "2000": 144.07}

# smithsonian_butterfly: sampling times, raw pytorch_fid output per generated-image
# dir, then the rounded per-step FID dict (which overwrites the raw one).
smithsonian_butterfly_time = {"100": "0:00:50.759654", "200": "0:01:41.229890", "300": "0:02:31.938889", "400": "0:03:22.673759", "500": "0:04:13.229171", "600": "0:05:03.903153", "700": "0:05:54.435045", "800": "0:06:45.308930", "900": "0:07:35.241919", "1000": "0:08:26.510326", "1100": "0:09:16.844281", "1200": "0:10:07.483987", "1300": "0:10:57.921574", "1400": "0:11:48.652164", "1500": "0:12:39.161810", "1600": "0:13:29.898042", "1700": "0:14:20.534768", "1800": "0:15:11.547640", "1900": "0:16:02.031062", "2000": "0:16:52.207837"}
smithsonian_butterfly_fid = {"/root/autodl-tmp/generated_img/smithsonian_butterfly/100": "FID:  464.2724543721416\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/200": "FID:  414.5440216196572\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/300": "FID:  436.98839668147986\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/400": "FID:  371.7241466508707\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/500": "FID:  349.5808687373876\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/600": "FID:  325.8778825706849\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/700": "FID:  316.47545036237176\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/800": "FID:  317.7606031161963\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/900": "FID:  314.48618132691195\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1000": "FID:  313.2163031170472\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1100": "FID:  319.56438650257286\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1200": "FID:  316.08401619145786\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1300": "FID:  314.73984079334906\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1400": "FID:  314.3764072779564\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1500": "FID:  314.95715263423733\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1600": "FID:  315.8355223364883\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1700": "FID:  308.777842887879\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1800": "FID:  315.87201668660316\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/1900": "FID:  309.6636257749532\\n", "/root/autodl-tmp/generated_img/smithsonian_butterfly/2000": "FID:  321.5550042998717\\n"}
# Rounded per-step FID; overwrites the raw dict above.
smithsonian_butterfly_fid = {"100": 464.27, "200": 414.54, "300": 436.99, "400": 371.72, "500": 349.58, "600": 325.88, "700": 316.48, "800": 317.76, "900": 314.49, "1000": 313.22, "1100": 319.56, "1200": 316.08, "1300": 314.74, "1400": 314.38, "1500": 314.96, "1600": 315.84, "1700": 308.78, "1800": 315.87, "1900": 309.66, "2000": 321.56}

# anime face: sampling times, raw pytorch_fid output per generated-image dir,
# then the rounded per-step FID dict (which overwrites the raw one).
anime_face_time = {"100": "0:00:50.658891", "200": "0:01:41.059113", "300": "0:02:31.702856", "400": "0:03:22.221767", "500": "0:04:12.774080", "600": "0:05:03.319826", "700": "0:05:53.889197", "800": "0:06:44.373179", "900": "0:07:34.993387", "1000": "0:08:25.538796", "1100": "0:09:16.068998", "1200": "0:10:06.696925", "1300": "0:10:57.223643", "1400": "0:11:47.775845", "1500": "0:12:38.344527", "1600": "0:13:28.913110", "1700": "0:14:19.436052", "1800": "0:15:10.068166", "1900": "0:16:00.703198", "2000": "0:16:51.148078"}
anime_face_fid = {"/root/autodl-tmp/generated_img/anime_face/100": "FID:  313.4714258936905\\n", "/root/autodl-tmp/generated_img/anime_face/200": "FID:  270.7401079593031\\n", "/root/autodl-tmp/generated_img/anime_face/300": "FID:  229.20210225700657\\n", "/root/autodl-tmp/generated_img/anime_face/400": "FID:  119.71961035097232\\n", "/root/autodl-tmp/generated_img/anime_face/500": "FID:  100.95760680481968\\n", "/root/autodl-tmp/generated_img/anime_face/600": "FID:  95.54623592856649\\n", "/root/autodl-tmp/generated_img/anime_face/700": "FID:  96.3206619396052\\n", "/root/autodl-tmp/generated_img/anime_face/800": "FID:  96.50408440971233\\n", "/root/autodl-tmp/generated_img/anime_face/900": "FID:  109.63730507537554\\n", "/root/autodl-tmp/generated_img/anime_face/1000": "FID:  114.21639966805992\\n", "/root/autodl-tmp/generated_img/anime_face/1100": "FID:  114.28623965508461\\n", "/root/autodl-tmp/generated_img/anime_face/1200": "FID:  115.73106375887062\\n", "/root/autodl-tmp/generated_img/anime_face/1300": "FID:  107.03419744180826\\n", "/root/autodl-tmp/generated_img/anime_face/1400": "FID:  117.40251490755358\\n", "/root/autodl-tmp/generated_img/anime_face/1500": "FID:  111.05104751389939\\n", "/root/autodl-tmp/generated_img/anime_face/1600": "FID:  113.85158925703797\\n", "/root/autodl-tmp/generated_img/anime_face/1700": "FID:  117.59748429172805\\n", "/root/autodl-tmp/generated_img/anime_face/1800": "FID:  111.05559241833654\\n", "/root/autodl-tmp/generated_img/anime_face/1900": "FID:  118.79752567379126\\n", "/root/autodl-tmp/generated_img/anime_face/2000": "FID:  109.04190712543053\\n"}
# Rounded per-step FID; overwrites the raw dict above.
anime_face_fid = {"100": 313.47, "200": 270.74, "300": 229.2, "400": 119.72, "500": 100.96, "600": 95.55, "700": 96.32, "800": 96.5, "900": 109.64, "1000": 114.22, "1100": 114.29, "1200": 115.73, "1300": 107.03, "1400": 117.4, "1500": 111.05, "1600": 113.85, "1700": 117.6, "1800": 111.06, "1900": 118.8, "2000": 109.04}

In [None]:
# One LaTeX row per inference-step count:
#   <steps> & <oxford FID> & <butterfly FID> & <anime FID> \\ \hline
rows = []
for step in range(100, 2001, 100):
    key = str(step)
    cells = [
        key,
        str(oxford_flower_fid[key]),
        str(smithsonian_butterfly_fid[key]),
        str(anime_face_fid[key]),
    ]
    rows.append(' & '.join(cells) + ' \\\\\n\\hline\n')
table_str = ''.join(rows)
print(table_str)

In [4]:
from datetime import timedelta

def convert_time(s, scale=10):
    """Parse a duration string and scale it.

    Args:
        s: Duration formatted like ``str(timedelta)``, e.g. "0:00:50.620381".
            The fractional-seconds part is optional and is truncated (not
            rounded) before scaling, matching the original tables.
        scale: Multiplier applied to the parsed duration. Defaults to 10, the
            factor used for the published tables.
            NOTE(review): presumably extrapolates the measured
            generate_epoch * generate_batch images to a 10x larger sample — confirm.

    Returns:
        Tuple ``(str(scaled timedelta), whole minutes of the scaled duration)``.
    """
    hours, minutes, rest = s.split(":")
    # partition() also handles "25" (no fraction), where split(".") unpacking failed.
    seconds = rest.partition(".")[0]
    result = timedelta(hours=int(hours), minutes=int(minutes), seconds=int(seconds)) * scale
    # total_seconds() instead of .seconds so durations of a day or more keep their days.
    return str(result), int(result.total_seconds()) // 60

# LaTeX rows "<steps> & <scaled sampling time> \\ \hline" for oxford_flower.
time_rows = []
for step in range(100, 2001, 100):
    key = str(step)
    scaled_time = convert_time(oxford_flower_time[key])[0]
    time_rows.append(key + ' & ' + str(scaled_time) + ' \\\\\n\\hline\n')
table_str = ''.join(time_rows)
print(table_str)

100 & 0:08:20 \\
\hline
200 & 0:16:50 \\
\hline
300 & 0:25:10 \\
\hline
400 & 0:33:40 \\
\hline
500 & 0:42:10 \\
\hline
600 & 0:50:30 \\
\hline
700 & 0:59:00 \\
\hline
800 & 1:07:20 \\
\hline
900 & 1:15:50 \\
\hline
1000 & 1:24:20 \\
\hline
1100 & 1:32:40 \\
\hline
1200 & 1:41:10 \\
\hline
1300 & 1:49:30 \\
\hline
1400 & 1:58:00 \\
\hline
1500 & 2:06:20 \\
\hline
1600 & 2:14:40 \\
\hline
1700 & 2:23:10 \\
\hline
1800 & 2:31:40 \\
\hline
1900 & 2:40:10 \\
\hline
2000 & 2:48:40 \\
\hline



In [None]:
# "(steps,FID)(steps,FID)..." coordinate string for the anime-face FID curve.
# NOTE(review): presumably pasted into a pgfplots `coordinates {...}` list — confirm.
# Previously the result was built in `tmp` while `line_chart_str` stayed empty.
line_chart_str = ''.join(
    f'({step},{anime_face_fid[str(step)]})' for step in range(100, 2001, 100)
)
print(line_chart_str)

In [None]:
# Per-dataset summary: LaTeX table rows plus "(x,y)" coordinate strings for the
# time and FID curves. Uses convert_time from the earlier timing-table cell
# (the identical duplicate definition that used to live here was removed).


def _parse_fid(value):
    """Return an FID score as a float.

    Accepts either a number (the rounded dicts) or the raw pytorch_fid stdout
    string such as 'FID:  369.8776792058484\\n', with a real newline or the
    escaped two-character form found in the pasted data.
    """
    if isinstance(value, str):
        value = value.replace('FID:', '').replace('\\n', '').strip()
    return float(value)


# Distinct names so the computed results in `fid_list` (FID cell) are not clobbered.
time_source = oxford_flower_time
fid_source = oxford_flower_fid
experiment_result = {}

for key, raw_fid in fid_source.items():
    # Keys are either a bare step count ("100") or a generated-image dir ending in it.
    inference_step = key.split('/')[-1]
    experiment_result[int(inference_step)] = {
        'time': time_source[inference_step],
        # Previously fid[6:-2], which raised TypeError on the rounded float dicts.
        'fid': _parse_fid(raw_fid),
    }

experiment_result = dict(sorted(experiment_result.items()))

table_str = ''
figure_str = ''
fid_figure_str = ''

for inference_step, entry in experiment_result.items():
    # `time_str`, not `time`, to avoid shadowing the `time` module.
    time_str, time_minutes = convert_time(entry['time'])
    fid = str(round(entry['fid'], 2))
    table_str += f'{inference_step} & {time_str} & {fid} \\\\\n\\hline\n'
    figure_str += f'({inference_step},{time_minutes})'
    fid_figure_str += f'({inference_step},{fid})'

# table_str and figure_str are also available for the timing table / time curve.
print(fid_figure_str)