In [7]:
import cv2
import numpy as np
import os
import math
from tqdm import tqdm
import pickle

## Define functions

计算一个图片的透视畸变的强度

In [8]:
# 计算一个图片的透视畸变的强度
def image_perspective_distortion(image_path):
    # 读取图像
    image = cv2.imread(image_path)
    
    # 将图像转换为灰度
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # 使用Canny边缘检测
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    # 使用霍夫变换检测直线
    lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)
    
    if lines is None:
        return 0
    # 计算直线的斜率和截距
    slopes = []
    intercepts = []
    for line in lines:
        rho, theta = line[0]
        a = np.cos(theta)
        b = np.sin(theta)
        x0 = a * rho
        y0 = b * rho
        x1 = int(x0 + 1000 * (-b))
        y1 = int(y0 + 1000 * (a))
        x2 = int(x0 - 1000 * (-b))
        y2 = int(y0 - 1000 * (a))

        # 避免竖直线导致的除零错误
        if x1 != x2:
            slope = (y2 - y1) / (x2 - x1)
            intercept = y1 - slope * x1
            slopes.append(slope)
            intercepts.append(intercept)

    # 计算消失点的平均位置
    if slopes:
        avg_slope = np.mean(slopes)
        avg_intercept = np.mean(intercepts)
        vanishing_point_x = int(image.shape[1] / 2)
        vanishing_point_y = int(avg_slope * vanishing_point_x + avg_intercept)

        # 计算透视畸变的强度（消失点到图像中心的距离）
        distortion_strength = np.linalg.norm([vanishing_point_x - image.shape[1] / 2, vanishing_point_y - image.shape[0] / 2])

        return distortion_strength
    else:
        return 0

# 使用示例
image_path = 'data/mini-imagenet/images/dugong/n0207436700000408.jpg'
print("透视扭曲程度: "+str(image_perspective_distortion(image_path)))

透视扭曲程度: 71.0


计算数据集的valueList（并保存）， 加载给定路径的valueList

In [9]:
def get_valueList(path):
    path_class = os.listdir(path)
    valueList = []
    for i in tqdm(path_class):
        class_i = os.path.join(path, i)
        
        # 遍历类别下的每张图片
        distortion_strength = 0
        for image_name in os.listdir(class_i):
            image_path = os.path.join(class_i, image_name)
            distortion_strength += image_perspective_distortion(image_path)
            
        
        mean_distortion_strength = distortion_strength / len(os.listdir(class_i))    
        valueList.append(mean_distortion_strength)   
    
    # 创建一个文件夹叫valueList，如果它不存在的话
    valueList_folder = 'valueList_for_perpectiveDistortion'
    os.makedirs(valueList_folder, exist_ok=True)
    
    # 将结果保存
    start_index = path.find('data/') + len('data/')
    end_index = path.find('/', start_index)
    result_string = path[start_index:end_index]  
    valueList_file_path = os.path.join(valueList_folder, f'{result_string}_valueList.pkl')
    with open(valueList_file_path, 'wb') as f:
        pickle.dump(valueList, f)
    
    return valueList


def load_valueList(path):
    # 加载 valueList
    valueList_folder = 'valueList_for_perpectiveDistortion'
    start_index = path.find('data/') + len('data/')
    end_index = path.find('/', start_index)
    result_string = path[start_index:end_index]  
    valueList_file_path = os.path.join(valueList_folder, f'{result_string}_valueList.pkl')
    
    with open(valueList_file_path, 'rb') as f:
        loaded_valueList = pickle.load(f)    
    return loaded_valueList

给定两个数据集的valueList, 计算两个数据集之间的 perspective_distortion 关系 (由heuristic决定计算方式)

In [31]:
# 计算两个数据集之间的 perspective_distortion 关系
def dataset_perspective_distortion(source_valueList,target_valueList,alpha,heuristic):
     # 创建 权重矩阵
    weight_matrix = np.ones((len(source_valueList),len(target_valueList)))  
    weight_matrix = weight_matrix/np.sum(weight_matrix) 
    
    return heuristic(source_valueList,target_valueList,weight_matrix,alpha = alpha)
    
# 越大说明比源域的 透视畸变 更大
def fun1(list1,list2,weight_matrix,alpha=0.01):
    num = 0
    for i in range(len(list1)):
        for j in range(len(list2)):
            
            num += weight_matrix[i][j] * (list1[i]-list2[j])

    # return math.e ** (-1*alpha*tmp)    
    return num/(len(list1)*len(list2))

# 越大说明比源域的 透视畸变 更小
def fun2(list1,list2,weight_matrix,alpha=0.01):
    num = 0
    for i in range(len(list1)):
        for j in range(len(list2)):
            
            num += weight_matrix[i][j] * (list1[i]-list2[j])

    return math.e ** (-1*alpha*num)    

## 越大说明与源域的 透视畸变 越相似
def fun3(list1,list2,weight_matrix,alpha=0.01):
    num = 0
    for i in range(len(list1)):
        for j in range(len(list2)):
            
            num += weight_matrix[i][j] * abs(list1[i]-list2[j])

    return math.e ** (-1*alpha*num)   


## 不考虑 list1（即源域）， 越大说明透视畸变 越小
def fun4(list1,list2,weight_matrix,alpha=0.01):
    num = 0
    
    weight_matrix = np.ones((1,len(list2)))
    weight_matrix = weight_matrix/np.sum(weight_matrix)
    
    for j in range(len(list2)):
        
        num += weight_matrix[0][j] * list2[j]

    return math.e ** (-1*alpha*num)  

## Compute

In [11]:
# CUB
CUB_path = 'data/CUB_200_2011/images/'

# Stanford Cars
Cars_path = 'data/stanford_cars/images/'

# EuroSAT
EuroSAT_path = 'data/EuroSAT_RGB/'

# Plant disease
Plant_path = 'data/Plant_disease/images'

# mini-imagenet
mini_path = 'data/mini-imagenet/images/'

In [12]:
# mini_valueList = get_valueList(mini_path)

# CUB_valueList = get_valueList(CUB_path)
# Cars_valueList = get_valueList(Cars_path)
# EuroSAT_valueList = get_valueList(EuroSAT_path)
# Plant_valueList = get_valueList(Plant_path)

mini_valueList = load_valueList(mini_path)
CUB_valueList = load_valueList(CUB_path)
Cars_valueList = load_valueList(Cars_path)
EuroSAT_valueList = load_valueList(EuroSAT_path)
Plant_valueList = load_valueList(Plant_path)

100%|██████████| 100/100 [1:31:02<00:00, 54.63s/it]
100%|██████████| 200/200 [10:26<00:00,  3.13s/it]
100%|██████████| 196/196 [32:55<00:00, 10.08s/it]
100%|██████████| 10/10 [00:36<00:00,  3.64s/it]
100%|██████████| 8/8 [00:28<00:00,  3.52s/it]


In [33]:
print(dataset_perspective_distortion(mini_valueList,CUB_valueList,0.01,fun1))
print(dataset_perspective_distortion(mini_valueList,Cars_valueList,0.01,fun1))
print(dataset_perspective_distortion(mini_valueList,EuroSAT_valueList,0.01,fun1))
print(dataset_perspective_distortion(mini_valueList,Plant_valueList,0.01,fun1))

0.0013065639645103288
0.002145576998879036
0.09880474463064076
0.05879522719673957


In [34]:
print(dataset_perspective_distortion(mini_valueList,CUB_valueList,0.01,fun2))
print(dataset_perspective_distortion(mini_valueList,Cars_valueList,0.01,fun2))
print(dataset_perspective_distortion(mini_valueList,EuroSAT_valueList,0.01,fun2))
print(dataset_perspective_distortion(mini_valueList,Plant_valueList,0.01,fun2))

0.7700400188888097
0.6566966469015908
0.3723029232355211
0.6247761723834998


In [35]:
print(dataset_perspective_distortion(mini_valueList,CUB_valueList,0.01,fun3))
print(dataset_perspective_distortion(mini_valueList,Cars_valueList,0.01,fun3))
print(dataset_perspective_distortion(mini_valueList,EuroSAT_valueList,0.01,fun3))
print(dataset_perspective_distortion(mini_valueList,Plant_valueList,0.01,fun3))

0.5892485969376481
0.5990748203065243
0.3723029232355211
0.539139459396691


In [36]:
print(dataset_perspective_distortion(mini_valueList,CUB_valueList,0.01,fun4))
print(dataset_perspective_distortion(mini_valueList,Cars_valueList,0.01,fun4))
print(dataset_perspective_distortion(mini_valueList,EuroSAT_valueList,0.01,fun4))
print(dataset_perspective_distortion(mini_valueList,Plant_valueList,0.01,fun4))

0.483484581761772
0.5669322026065907
0.99999880000072
0.5958973676156077
