In [66]:
import cv2
import numpy as np
import os
import math
from tqdm import tqdm
import matplotlib.pyplot as plt

## Define functions

计算单张图片的透视畸变强度

In [67]:
# Compute the perspective-distortion strength of a single image.
def image_perspective_distortion(image_path, canny_low=50, canny_high=150, hough_threshold=100):
    """Estimate how strongly perspective distortion affects one image.

    Pipeline: BGR -> grayscale -> Canny edges -> standard Hough line
    transform. Every detected non-vertical line is converted to
    slope/intercept form, the slopes and intercepts are averaged, and the
    averaged line is evaluated at the horizontal image centre. The score is
    the Euclidean distance from that point to the image centre.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        canny_low: Lower hysteresis threshold passed to ``cv2.Canny``.
        canny_high: Upper hysteresis threshold passed to ``cv2.Canny``.
        hough_threshold: Accumulator threshold passed to ``cv2.HoughLines``.

    Returns:
        The distortion score (a float); 0.0 when no usable line is detected.

    Raises:
        ValueError: If ``image_path`` cannot be read as an image.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising; fail
    # here with a clear message instead of an opaque cv2.error in cvtColor.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, canny_low, canny_high, apertureSize=3)

    # Standard Hough transform: each line is returned in normal form
    # (rho, theta), here with 1-pixel / 1-degree accumulator resolution.
    lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=hough_threshold)
    if lines is None:
        return 0.0

    slopes = []
    intercepts = []
    for line in lines:
        rho, theta = line[0]
        a = np.cos(theta)
        b = np.sin(theta)
        # (x0, y0) is the foot of the normal from the origin to the line.
        x0 = a * rho
        y0 = b * rho
        # Two integer points 1000 px either side of (x0, y0) along the line.
        x1 = int(x0 + 1000 * (-b))
        y1 = int(y0 + 1000 * (a))
        x2 = int(x0 - 1000 * (-b))
        y2 = int(y0 - 1000 * (a))

        # Skip (near-)vertical lines: after int truncation they give
        # x1 == x2, and the slope would divide by zero.
        if x1 != x2:
            slope = (y2 - y1) / (x2 - x1)
            slopes.append(slope)
            intercepts.append(y1 - slope * x1)

    if not slopes:
        return 0.0

    height, width = image.shape[0], image.shape[1]
    avg_slope = np.mean(slopes)
    avg_intercept = np.mean(intercepts)

    # NOTE(review): x is pinned to the centre column, so the x-term of the
    # distance below is at most 0.5 px and the score is effectively
    # |y - height/2| of the averaged line at the centre column. An averaged
    # line is not a true vanishing point -- TODO confirm this is the
    # intended metric.
    vanishing_point_x = int(width / 2)
    vanishing_point_y = int(avg_slope * vanishing_point_x + avg_intercept)

    # Distortion strength = distance of that point from the image centre.
    return np.linalg.norm([vanishing_point_x - width / 2, vanishing_point_y - height / 2])

# Usage example: score a single image and display the result as the cell output.
image_path = 'data/mini-imagenet/images/dugong/n0207436700000409.jpg'
distortion_strength = image_perspective_distortion(image_path)
distortion_strength

计算两个数据集之间在透视畸变（perspective distortion）上的关系（基于各类别的平均畸变强度）

In [68]:
# Compare two datasets through their per-class perspective-distortion statistics.
def _mean_class_scores(dataset_path, image_metric, show_progress):
    """Return the mean ``image_metric`` score of each non-empty class directory.

    Every sub-directory of ``dataset_path`` is treated as one class. The
    regular files directly inside it are scored with ``image_metric`` and
    averaged. Non-directory entries of ``dataset_path`` and class directories
    containing no files are skipped, since an empty class has no defined mean.
    """
    class_names = sorted(
        name for name in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, name))
    )
    class_iter = tqdm(class_names, desc=dataset_path) if show_progress else class_names

    class_means = []
    for class_name in class_iter:
        class_dir = os.path.join(dataset_path, class_name)
        image_paths = [
            os.path.join(class_dir, file_name)
            for file_name in sorted(os.listdir(class_dir))
            if os.path.isfile(os.path.join(class_dir, file_name))
        ]
        if not image_paths:
            continue
        total = sum(image_metric(path) for path in image_paths)
        class_means.append(total / len(image_paths))
    return class_means


def dataset_perspective_distortion(source_path, target_path, alpha, heuristic,
                                   image_metric=None, show_progress=True):
    """Score the relationship between two datasets as ``exp(-alpha * D)``.

    Each dataset root is expected to contain one sub-directory per class.
    For every class, the mean per-image score (``image_metric``) is computed.
    ``D`` is the uniformly weighted mean of ``heuristic(s, t)`` over all
    pairs of source-class mean ``s`` and target-class mean ``t``.

    Args:
        source_path: Root directory of the source dataset.
        target_path: Root directory of the target dataset.
        alpha: Scale factor inside the exponent.
        heuristic: Callable ``heuristic(source_value, target_value)``
            returning a number. It is called once per (source class,
            target class) pair with the two per-class mean scores.
        image_metric: Per-image scoring callable taking a file path.
            Defaults to ``image_perspective_distortion``.
        show_progress: Wrap the per-class loops in a ``tqdm`` progress bar.

    Returns:
        ``exp(-alpha * D)`` as a float.

    Raises:
        ValueError: If either dataset yields no non-empty class directory.
    """
    if image_metric is None:
        image_metric = image_perspective_distortion

    source_valuelist = _mean_class_scores(source_path, image_metric, show_progress)
    target_valuelist = _mean_class_scores(target_path, image_metric, show_progress)
    if not source_valuelist or not target_valuelist:
        raise ValueError(
            f"No non-empty class directories found in {source_path!r} or {target_path!r}"
        )

    # Uniform weights normalized to sum to 1, so the weighted sum below is
    # the mean of the pairwise heuristic values.
    weight_matrix = np.ones((len(source_valuelist), len(target_valuelist)))
    weight_matrix = weight_matrix / np.sum(weight_matrix)

    # NOTE(review): heuristics written to take the whole value lists must be
    # adapted; each call now receives one source mean and one target mean.
    pairwise = np.array(
        [[heuristic(s, t) for t in target_valuelist] for s in source_valuelist],
        dtype=float,
    )
    weighted_sum = float(np.sum(weight_matrix * pairwise))

    return math.exp(-alpha * weighted_sum)

In [69]:
def fun1(list1, list2):
    """Constant baseline heuristic: ignores both arguments and always returns 1."""
    del list1, list2  # intentionally unused
    return 1

## Compute

In [70]:
# Dataset root directories, relative to the notebook's working directory.
# Bound in one tuple assignment: CUB, Stanford Cars, EuroSAT, Plant disease, mini-imagenet.
CUB_path, Cars_path, EuroSAT_path, Plant_path, mini_path = (
    'data/CUB_200_2011/images/',     # CUB-200-2011
    'data/stanford_cars/images/',    # Stanford Cars
    'data/EuroSAT_RGB/',             # EuroSAT (RGB)
    'data/Plant_disease/images',     # Plant disease
    'data/mini-imagenet/images/',    # mini-imagenet
)

In [72]:
# mini-imagenet as source vs. CUB as target, using the constant heuristic fun1.
dataset_perspective_distortion(
    source_path=mini_path,
    target_path=CUB_path,
    alpha=0.01,
    heuristic=fun1,
)

 12%|█▏        | 24/200 [01:18<09:38,  3.29s/it]


KeyboardInterrupt: 