## FID

In [2]:
 
# Environment report: PyTorch / CUDA versions, GPU availability and count,
# and the directory the notebook is running from.
import os

import torch

print('torch version:', torch.__version__)
print('cuda version:', torch.version.cuda)
print('cuda available' if torch.cuda.is_available() else 'cuda unavailable')
print('muti device count:', torch.cuda.device_count())

print(os.path.abspath('.'))

torch version: 2.1.2+cu118
cuda version: 11.8
cuda available
muti device count: 4
/home/gao/haiwu/block-aigc


In [None]:
# Export the first few CIFAR-10 training images of every class as PNG files
# under ./temp/src/<class name>/<k>.png.
import os

import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

# Transform applied to every image when it is read from the dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image or NumPy array -> FloatTensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per-channel mean 0.5, std 0.5
])

# Load the CIFAR-10 training set
trainset = datasets.CIFAR10(root='./data/cifar', train=True, download=True, transform=transform)
print(trainset.data.shape)
# Create the output root if it does not exist yet
os.makedirs('./temp/src', exist_ok=True)

# One sub-folder per class
labels = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
images_per_label = 10
label_folders = []
for label_name in labels:
    label_folder = os.path.join('./temp/src', label_name)
    os.makedirs(label_folder, exist_ok=True)
    label_folders.append(label_folder)

# Number of images already written for each class
saved_counts = [0] * len(labels)

# Single pass over the integer targets: an image is only decoded when it is
# actually going to be saved, and the scan stops once every class is full.
# The resulting files are the same as scanning once per class (the first
# `images_per_label` images of each class, in dataset order).
for i, label in enumerate(trainset.targets):
    if saved_counts[label] >= images_per_label:
        continue
    data, _ = trainset[i]  # data has shape (3, 32, 32)
    img_name = os.path.join(label_folders[label], f"{saved_counts[label]}.png")
    saved_counts[label] += 1
    # Map the normalized image from [-1, 1] back to [0, 1] before writing it
    xset = data.unsqueeze(0)
    grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=1)
    save_image(grid, img_name)
    if all(count >= images_per_label for count in saved_counts):
        break

print("图片保存完成")

In [None]:
!pip install pytorch-fid

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
# fid_score comes from the pytorch-fid package installed above.
from pytorch_fid import fid_score

# Folders holding the real images and the generated images.
# NOTE(review): both paths point at the same folder, so the score should come
# out near zero; point generated_images_folder at the images to evaluate.
real_images_folder = './temp/src/bird/'
generated_images_folder = './temp/src/bird/'

# Fall back to the CPU when no GPU is visible.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# calculate_fid_given_paths(paths, batch_size, device, dims) builds its own
# Inception network, so no model is passed in. `dims` selects the Inception
# feature layer (2048 = final pooling features), not an image shape.
# NOTE(review): with only 10 images per folder the 2048-d covariance estimate
# is rank-deficient, so the value is only a smoke test of the pipeline.
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_folder, generated_images_folder],
    batch_size=10,
    device=device,
    dims=2048,
)
print('FID value:', fid_value)

## Equivalent command line: python -m pytorch_fid ./temp/src/bird ./temp/src/car --device cuda:3

## EMD

In [18]:
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from collections import defaultdict

# Load the MNIST training set (downloaded to ./data/mnist when missing)
trainset = datasets.MNIST(
    root='./data/mnist',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
print(trainset.data.shape)

torch.Size([60000, 28, 28])


In [7]:
# Partition a dataset by label
def create_label_dict(dataset):
    """Group the images of ``dataset`` by their label.

    ``dataset`` is iterated as ``(image, label)`` pairs. Returns a
    ``defaultdict(list)`` mapping every label to the list of its images,
    kept in iteration order.
    """
    grouped = defaultdict(list)
    for sample in dataset:
        img, lbl = sample
        grouped[lbl].append(img)
    return grouped
# Bucket the training set by label and report the size of every bucket.
label_dict = create_label_dict(dataset=trainset)
for label in label_dict.keys():
    print(f"Label {label} has {len(label_dict[label])} images")

Label 5 has 5421 images
Label 0 has 5923 images
Label 4 has 5842 images
Label 1 has 6742 images
Label 9 has 5949 images
Label 2 has 5958 images
Label 3 has 6131 images
Label 6 has 5918 images
Label 7 has 6265 images
Label 8 has 5851 images


In [20]:
def sample_data(label_dict, label, num_samples):
    """Randomly pick ``num_samples`` images stored under ``label``.

    Images are drawn without replacement from ``label_dict[label]``, stacked
    along a new leading dimension, and dimension 1 of the stacked tensor is
    squeezed. For the (1, 28, 28) MNIST images used in this notebook the
    result has shape (num_samples, 28, 28).
    """
    import random

    picked = random.sample(label_dict[label], num_samples)
    return torch.stack(picked).squeeze(1)
# Smoke test: draw ten images of one label and check the resulting shape.
label, num_samples = 0, 10
sampled_data = sample_data(label_dict=label_dict, label=label, num_samples=num_samples)
print(f"Sampled {len(sampled_data)} images from label {label}, shape: {sampled_data.shape}")

Sampled 10 images from label 0, shape: torch.Size([10, 28, 28])


In [31]:
# Build two small datasets with different label mixes (A: labels 1 and 2,
# B: labels 1 and 3) plus their union, to try the EMD computation on.
num_samples = 10
label1 = 1
label2 = 2
data1 = sample_data(label_dict, label1, num_samples)
labels1 = torch.full((num_samples,), label1)
data2 = sample_data(label_dict, label2, num_samples)
# Kept in its own name so that label2 stays the integer class id
labels2 = torch.full((num_samples,), label2)

build_images_a = torch.cat([data1, data2], dim=0)  # ten 1s followed by ten 2s
build_labels_a = torch.cat([labels1, labels2], dim=0)
print(build_images_a.shape, build_labels_a.shape)

label3 = 3
data3 = sample_data(label_dict, label3, num_samples)
labels3 = torch.full((num_samples,), label3)
build_images_b = torch.cat([data1, data3], dim=0)  # ten 1s followed by ten 3s
build_labels_b = torch.cat([labels1, labels3], dim=0)
print(build_images_b.shape, build_labels_b.shape)

# Union of A and B: twenty 1s (data1 appears twice), ten 2s, ten 3s
build_images_all = torch.cat([build_images_a, build_images_b], dim=0)
build_labels_all = torch.cat([build_labels_a, build_labels_b], dim=0)
print(build_images_all.shape, build_labels_all.shape)

torch.Size([20, 28, 28]) torch.Size([20])
torch.Size([20, 28, 28]) torch.Size([20])
torch.Size([40, 28, 28]) torch.Size([40])


In [54]:
import numpy as np
def _label_frequencies(labels, labels_cnt):
    """Return the relative frequency of every class id in ``labels``."""
    idx = torch.as_tensor(labels, dtype=torch.long, device="cpu").reshape(-1)
    # Reject ids outside [0, labels_cnt); negative ids would otherwise be
    # counted for the wrong class.
    if idx.numel() > 0 and (int(idx.min()) < 0 or int(idx.max()) >= labels_cnt):
        raise IndexError(f"labels must lie in [0, {labels_cnt})")
    counts = torch.bincount(idx, minlength=labels_cnt).float()
    # An empty input yields 0 / 0 = nan for every class.
    return counts / idx.numel()


def calculate_emd(labels1, labels2, labels_cnt=10):
    """Distance between the label distributions of two label collections.

    Each collection is turned into a vector of per-class relative
    frequencies; the result is the sum of absolute differences between the
    two vectors (an L1 distance, called EMD in this notebook).

    Args:
        labels1: 1-D tensor or sequence of integer class ids in
            ``[0, labels_cnt)``.
        labels2: same as ``labels1``, for the second collection.
        labels_cnt: number of classes.

    Returns:
        A 0-dim float tensor in ``[0, 2]``; ``nan`` if either input is empty.

    Raises:
        IndexError: if a class id lies outside ``[0, labels_cnt)``.
    """
    cnt1 = _label_frequencies(labels1, labels_cnt)
    cnt2 = _label_frequencies(labels2, labels_cnt)
    emd = (cnt1 - cnt2).abs().sum()
    return emd
# Compare B with A, then each of A and B with their union.
for labels_p, labels_q in (
    (build_labels_b, build_labels_a),
    (build_labels_a, build_labels_all),
    (build_labels_b, build_labels_all),
):
    emd = calculate_emd(labels_p, labels_q)
    print(f"EMD between two datasets: {emd}")

EMD between two datasets: 1.0
EMD between two datasets: 0.5
EMD between two datasets: 0.5


## Split dataset with given avg EMD

In [55]:
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from collections import defaultdict

# Load the MNIST training set (downloaded to ./data/mnist when missing)
trainset = datasets.MNIST(
    root='./data/mnist',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
print(trainset.data.shape)
def create_label_dict(dataset):
    """Group the images of ``dataset`` by their label.

    ``dataset`` is iterated as ``(image, label)`` pairs. Returns a
    ``defaultdict(list)`` mapping every label to the list of its images,
    kept in iteration order.
    """
    grouped = defaultdict(list)
    for sample in dataset:
        img, lbl = sample
        grouped[lbl].append(img)
    return grouped
# Bucket the training set by label and report the size of every bucket.
label_dict = create_label_dict(dataset=trainset)
for label in label_dict.keys():
    print(f"Label {label} has {len(label_dict[label])} images")

torch.Size([60000, 28, 28])
Label 5 has 5421 images
Label 0 has 5923 images
Label 4 has 5842 images
Label 1 has 6742 images
Label 9 has 5949 images
Label 2 has 5958 images
Label 3 has 6131 images
Label 6 has 5918 images
Label 7 has 6265 images
Label 8 has 5851 images


In [64]:
def sample_data(label_dict, label, num_samples):
    """Randomly pick ``num_samples`` images stored under ``label``.

    Images are drawn without replacement from ``label_dict[label]``, stacked
    along a new leading dimension, and dimension 1 of the stacked tensor is
    squeezed. For the (1, 28, 28) MNIST images used in this notebook the
    result has shape (num_samples, 28, 28).
    """
    import random

    picked = random.sample(label_dict[label], num_samples)
    return torch.stack(picked).squeeze(1)
def _label_frequencies(labels, labels_cnt):
    """Return the relative frequency of every class id in ``labels``."""
    idx = torch.as_tensor(labels, dtype=torch.long, device="cpu").reshape(-1)
    # Reject ids outside [0, labels_cnt); negative ids would otherwise be
    # counted for the wrong class.
    if idx.numel() > 0 and (int(idx.min()) < 0 or int(idx.max()) >= labels_cnt):
        raise IndexError(f"labels must lie in [0, {labels_cnt})")
    counts = torch.bincount(idx, minlength=labels_cnt).float()
    # An empty input yields 0 / 0 = nan for every class.
    return counts / idx.numel()


def calculate_emd(labels1, labels2, labels_cnt=10):
    """Distance between the label distributions of two label collections.

    Each collection is turned into a vector of per-class relative
    frequencies; both vectors are printed, and the result is the sum of
    absolute differences between them (an L1 distance, called EMD in this
    notebook).

    Args:
        labels1: 1-D tensor or sequence of integer class ids in
            ``[0, labels_cnt)``.
        labels2: same as ``labels1``, for the second collection.
        labels_cnt: number of classes.

    Returns:
        A 0-dim float tensor in ``[0, 2]``; ``nan`` if either input is empty.

    Raises:
        IndexError: if a class id lies outside ``[0, labels_cnt)``.
    """
    cnt1 = _label_frequencies(labels1, labels_cnt)
    cnt2 = _label_frequencies(labels2, labels_cnt)
    print(f"cnt1 = {cnt1}", f"cnt2 = {cnt2}")
    emd = (cnt1 - cnt2).abs().sum()
    return emd
# Split a class-balanced sample into two halves A and B whose label
# distributions are skewed in opposite directions by a controlled amount.
label_cnt = 10
total_samples = 1000
num_samples = total_samples // label_cnt  # images drawn per label
delta = 0.1 # avg EMD = 2*delta
more = delta * num_samples
# round() instead of int(): the float product can land just below an integer
# and would then be truncated to one image fewer than intended.
shift = int(round(more))
half = num_samples // 2
images_a = []
labels_a = []
images_b = []
labels_b = []
images_all = []
labels_all = []
for label in range(label_cnt):
    sampled_data = sample_data(label_dict, label, num_samples)
    # Labels [0, label_cnt // 2) are over-represented in A, the others in B.
    if label < label_cnt // 2:
        split = half + shift
    else:
        split = half - shift
    # A takes the first `split` images and B takes the rest, so the two parts
    # are disjoint, cover the whole sample, and every image list has exactly
    # as many entries as its label list.
    images_a.append(sampled_data[:split])
    images_b.append(sampled_data[split:])
    labels_a.append(torch.full((split,), label))
    labels_b.append(torch.full((num_samples - split,), label))
    images_all.append(sampled_data)
    labels_all.append(torch.full((num_samples,), label))
images_a = torch.cat(images_a, dim=0)
labels_a = torch.cat(labels_a, dim=0)
images_b = torch.cat(images_b, dim=0)
labels_b = torch.cat(labels_b, dim=0)
images_all = torch.cat(images_all, dim=0)
labels_all = torch.cat(labels_all, dim=0)
print(images_a.shape, labels_a.shape)
print(images_b.shape, labels_b.shape)
print(images_all.shape, labels_all.shape)
# Distance of each half from the balanced union, and their average
emd_a = calculate_emd(labels_a, labels_all)
emd_b = calculate_emd(labels_b, labels_all)
print("EMD between two a and all:", emd_a)
print("EMD between two b and all:", emd_b)
print("avg EMD:", (emd_a+emd_b)/2)

torch.Size([500, 28, 28]) torch.Size([500])
torch.Size([500, 28, 28]) torch.Size([500])
torch.Size([1000, 28, 28]) torch.Size([1000])
cnt1 = tensor([0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.0800, 0.0800, 0.0800, 0.0800,
        0.0800]) cnt2 = tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])
cnt1 = tensor([0.0800, 0.0800, 0.0800, 0.0800, 0.0800, 0.1200, 0.1200, 0.1200, 0.1200,
        0.1200]) cnt2 = tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])
EMD between two a and all: tensor(0.2000)
EMD between two b and all: tensor(0.2000)
avg EMD: tensor(0.2000)
