In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
import timm
from PIL import Image
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import cv2
import os
import random
from matplotlib.gridspec import GridSpec
from concurrent.futures import ThreadPoolExecutor
import pickle

# Font preference for all figure text: Arial first, then the generic sans-serif family.
plt.rcParams["font.family"] = ["Arial", "sans-serif"]

CNN特征批量提取与余弦相似度工具函数

In [None]:
def get_device():
    """Return the torch device to compute on: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def batch_extract_cnn_features(image_paths, model, transform, device=None, batch_size=16):
    """Extract per-layer CNN features for a list of images, one batch at a time.

    Args:
        image_paths: Sequence of image file paths.
        model: Callable mapping a stacked image batch to an iterable of
            feature tensors (one per layer) whose first dimension indexes
            the images of the batch.
        transform: Callable turning an RGB PIL image into a tensor.
        device: Device the batch is moved to; defaults to ``get_device()``.
        batch_size: Number of images per forward pass.

    Returns:
        A list with exactly one entry per input path, in input order. Each
        entry is a list of CPU tensors (one per layer), or ``None`` when the
        image could not be loaded or transformed.
    """
    device = device or get_device()
    all_features = []

    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), batch_size), desc="Feature Extraction"):
            chunk_paths = image_paths[start:start + batch_size]

            tensors = []
            for path in chunk_paths:
                try:
                    # Context manager closes the file handle once the image is decoded.
                    with Image.open(path) as img:
                        tensors.append(transform(img.convert("RGB")))
                except Exception as exc:
                    tqdm.write(f"Failed to load image: {path} ({exc})")
                    tensors.append(None)

            valid_positions = [pos for pos, tensor in enumerate(tensors) if tensor is not None]
            if not valid_positions:
                # Keep the output aligned with image_paths even when the whole
                # batch failed; callers slice the result by position.
                all_features.extend([None] * len(chunk_paths))
                continue

            stacked = torch.stack([tensors[pos] for pos in valid_positions]).to(device)
            layer_outputs = model(stacked)

            chunk_features = [None] * len(chunk_paths)
            for row, pos in enumerate(valid_positions):
                # `row` indexes the stacked (valid-only) batch, `pos` the original chunk.
                chunk_features[pos] = [layer[row].cpu() for layer in layer_outputs]

            all_features.extend(chunk_features)

    return all_features

def cosine_cnn_similarity(feat1, feat2, dim=1):
    """Return the mean cosine similarity of two feature tensors as a float.

    Tensors with more than two dimensions are flattened from dimension 1
    onward; both are then L2-normalised along ``dim``, multiplied
    element-wise, summed along ``dim`` and averaged. Returns ``None`` when
    either input is ``None``.
    """
    if feat1 is None or feat2 is None:
        return None

    def _unit(feat):
        # Collapse trailing dimensions first so normalisation sees 2-D input.
        flat = feat.flatten(1) if feat.dim() > 2 else feat
        return F.normalize(flat, p=2, dim=dim)

    unit_a = _unit(feat1)
    unit_b = _unit(feat2)
    return (unit_a * unit_b).sum(dim=dim).mean().item()

单张图像分割（SAM）

In [None]:
def segment_human(image_path, sam_model, device=None, image_size=512, output_folder="segmented_images", stability_type=None):
    """Mask an image with SAM and save the result under its original file name.

    The image is downscaled so that its longer side is at most ``image_size``,
    SAM is prompted with a single foreground point at the image centre, the
    highest-scoring mask is kept and every pixel outside it is set to zero.

    Args:
        image_path: Path of the image to segment.
        sam_model: Loaded SAM model, or ``None`` to skip segmentation.
        device: Not used by this function; kept so existing callers that pass
            it positionally keep working.
        image_size: Maximum length of the longer image side before prediction.
        output_folder: Root folder for the masked images.
        stability_type: Optional sub-folder name inside ``output_folder``.

    Returns:
        Path of the saved masked image, or ``image_path`` unchanged when
        ``sam_model`` is ``None`` or any step fails (best-effort fallback).
    """
    if sam_model is None:
        return image_path

    try:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports an unreadable or missing file by returning None.
            raise ValueError("cv2.imread could not read the image")

        h, w = image.shape[:2]
        if max(h, w) > image_size:
            scale = image_size / max(h, w)
            image = cv2.resize(image, (int(w * scale), int(h * scale)))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        predictor = SamPredictor(sam_model)
        predictor.set_image(image)

        # Re-read the size: the image may have been resized above.
        h, w = image.shape[:2]
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[w // 2, h // 2]]),
            point_labels=np.array([1]),
            multimask_output=True,
        )

        best_mask = masks[np.argmax(scores)]
        masked_image = image.copy()
        masked_image[~best_mask] = 0

        save_dir = os.path.join(output_folder, stability_type) if stability_type else output_folder
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, os.path.basename(image_path))
        Image.fromarray(masked_image).save(save_path)
        return save_path
    except Exception as exc:
        # Deliberate best-effort: fall back to the unsegmented image, but say why.
        print(f"Segmentation failed for {image_path}: {exc}; using the original image")
        return image_path

批量图像分割与缓存

In [None]:
def extract_all_segments(pairs, sam_model, stability_type="stable", device=None, image_size=512,
                         batch_size=4, cache_file=None, output_folder="segmented_images"):
    """Segment both images of every pair and return the resulting paths.

    Args:
        pairs: Sequence of two-element pairs whose items expose ``['image_path']``.
        sam_model: Loaded SAM model, or ``None`` (images are then passed through).
        stability_type: Sub-folder name for the output and label of the progress bar.
        device: Device forwarded to ``segment_human``; defaults to ``get_device()``.
        image_size: Maximum image side length forwarded to ``segment_human``.
        batch_size: Number of worker threads.
        cache_file: Optional pickle file used to load/store the result.
        output_folder: Root folder for the segmented images.

    Returns:
        A flat list of paths, two per pair, in the order
        ``[pair0[0], pair0[1], pair1[0], pair1[1], ...]``.
    """
    all_paths = [item['image_path'] for pair in pairs for item in (pair[0], pair[1])]

    if cache_file and os.path.exists(cache_file):
        try:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # cache_file at files written by this notebook.
            with open(cache_file, 'rb') as f:
                cached = pickle.load(f)
        except Exception as exc:
            print(f"Ignoring unreadable segmentation cache {cache_file}: {exc}")
        else:
            # A cache built for a different set of pairs would silently
            # misalign paths and pairs, so only accept a matching length.
            if isinstance(cached, list) and len(cached) == len(all_paths):
                return cached
            print(f"Ignoring stale segmentation cache {cache_file}: entry count does not match the pairs")

    device = device or get_device()
    progress_label = f"{stability_type.capitalize()} Segmentation" if stability_type else "Segmentation"

    # ThreadPoolExecutor rejects max_workers <= 0.
    with ThreadPoolExecutor(max_workers=max(1, batch_size)) as executor:
        futures = [
            executor.submit(segment_human, path, sam_model, device, image_size, output_folder, stability_type)
            for path in all_paths
        ]
        # Collect in submission order so results stay aligned with all_paths.
        seg_image_paths = [future.result() for future in tqdm(futures, desc=progress_label)]

    if cache_file:
        try:
            cache_dir = os.path.dirname(cache_file)
            if cache_dir:
                # dirname is '' for a bare file name; os.makedirs('') would raise.
                os.makedirs(cache_dir, exist_ok=True)
            with open(cache_file, 'wb') as f:
                pickle.dump(seg_image_paths, f)
        except Exception as exc:
            print(f"Could not write segmentation cache {cache_file}: {exc}")

    return seg_image_paths

分割结果抽样可视化

In [None]:
def sample_and_visualize(stable_seg_paths, non_stable_seg_paths, num_samples=16):
    """Plot a random sample of images from both groups in one figure.

    Stable samples fill the upper half of the grid and non-stable samples the
    lower half, eight images per row. The grid has two rows per group for up
    to 16 samples and grows when ``num_samples`` is larger. The figure is
    saved as ``sam_segmented_samples.png`` and shown.

    Args:
        stable_seg_paths: List of image paths for the stable group.
        non_stable_seg_paths: List of image paths for the non-stable group.
        num_samples: Maximum number of images drawn per group.

    Returns:
        None. Nothing is drawn when either list is empty or ``num_samples < 1``.
    """
    if not stable_seg_paths or not non_stable_seg_paths or num_samples < 1:
        return

    columns = 8
    groups = [
        ('Stable Pairs', random.sample(stable_seg_paths, min(num_samples, len(stable_seg_paths)))),
        ('Non-Stable Pairs', random.sample(non_stable_seg_paths, min(num_samples, len(non_stable_seg_paths)))),
    ]

    # Ceil-divide the larger sample count by the column count, but never use
    # fewer than two rows per group so the default layout stays 4 x 8.
    largest = max(len(samples) for _, samples in groups)
    rows_per_group = max(2, -(-largest // columns))
    total_rows = 2 * rows_per_group

    fig = plt.figure(figsize=(16, 2.5 * total_rows))
    gs = GridSpec(total_rows, columns, figure=fig)

    for group_idx, (title, samples) in enumerate(groups):
        row_offset = group_idx * rows_per_group
        for i, path in enumerate(samples):
            ax = fig.add_subplot(gs[row_offset + i // columns, i % columns])
            try:
                with Image.open(path) as img:
                    ax.imshow(img)
                ax.set_title(title)
            except Exception:
                ax.text(0.5, 0.5, 'Failed to load', ha='center', va='center')
            ax.axis('off')

    plt.suptitle('Randomly Sampled Segmented Images', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig('sam_segmented_samples.png', dpi=300, bbox_inches='tight')
    plt.show()

特征提取

In [None]:
class FeatureExtractor:
    """Single-image feature extractor backed by GLCM texture statistics, ViT or ConvNeXt."""

    # timm architecture used when ``model_name`` does not belong to the chosen method.
    _DEFAULT_MODELS = {'vit': 'vit_base_patch16_224', 'convnext': 'convnext_base'}

    def __init__(self, method='convnext', model_name='convnext_base', device=None):
        """
        Set up the extractor and, for deep methods, load the pretrained model.

        Args:
            method: Extraction method, one of 'glcm', 'vit', 'convnext'
                (case-insensitive).
            model_name: timm architecture name. It is used only when it
                contains the method name (e.g. 'convnext_tiny' for
                'convnext'); otherwise the method's default architecture
                is loaded.
            device: Device to run the model on. ``None`` selects CUDA when
                available, otherwise CPU.

        Raises:
            ValueError: If ``method`` is not a supported method.
        """
        self.method = method.lower()
        # Resolved at call time; a `device=device` default would be evaluated
        # when the class is defined, before the notebook creates `device`.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = None
        self.transform = None

        print(f"\n===== 初始化 {self.method.upper()} 特征提取器 =====")

        if self.method == 'glcm':
            print("使用GLCM (灰度共生矩阵) 特征提取")

        elif self.method in self._DEFAULT_MODELS:
            print(f"使用 {self.method.upper()} 特征提取")

            if model_name and self.method in model_name.lower():
                architecture = model_name
            else:
                architecture = self._DEFAULT_MODELS[self.method]

            # num_classes=0 removes the classification head.
            self.model = timm.create_model(architecture, pretrained=True, num_classes=0)
            self.model.to(self.device)
            self.model.eval()

            self.transform = timm.data.create_transform(
                **timm.data.resolve_data_config(self.model.pretrained_cfg)
            )

            print(f"已加载预训练 {self.method.upper()} 模型")

        else:
            raise ValueError(f"不支持的特征提取方法: {method}。支持的方法有 'glcm', 'vit', 'convnext'")

    def extract_features(self, image_path):
        """Return a 1-D feature vector for ``image_path``, or ``None`` on failure.

        Any exception raised during extraction is printed and converted to
        ``None`` so that callers can skip the image.
        """
        try:
            if self.method == 'glcm':
                return self._extract_glcm_features(image_path)
            return self._extract_deep_features(image_path)
        except Exception as e:
            print(f"特征提取失败: {image_path}, 错误: {e}")
            return None

    def _extract_glcm_features(self, image_path):
        """Return six GLCM statistics of the 224x224 greyscale image.

        The statistics (contrast, dissimilarity, homogeneity, energy,
        correlation, ASM) are each averaged over three distances and four
        angles.
        """
        # Imported here because the notebook's import cell never binds
        # `skfeature`; keeps the dependency optional for the deep methods.
        from skimage import feature as skfeature

        # scikit-image renamed grey* to gray*; support either spelling.
        graycomatrix = getattr(skfeature, 'graycomatrix', None) or skfeature.greycomatrix
        graycoprops = getattr(skfeature, 'graycoprops', None) or skfeature.greycoprops

        with Image.open(image_path) as img:
            img_array = np.array(img.convert('L').resize((224, 224)))

        distances = [1, 2, 3]
        angles = [0, np.pi/4, np.pi/2, 3*np.pi/4]
        glcm = graycomatrix(img_array, distances, angles, levels=256, symmetric=True, normed=True)

        properties = ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']
        return np.array([graycoprops(glcm, prop).mean() for prop in properties])

    def _extract_deep_features(self, image_path):
        """Return the model's feature vector for one image as a flat numpy array."""
        with Image.open(image_path) as img:
            img_tensor = self.transform(img.convert('RGB')).unsqueeze(0).to(self.device)

        with torch.no_grad():
            features = self._pool_features(self.model(img_tensor))

        return features.cpu().numpy().flatten()

    def _pool_features(self, features):
        """Reduce a model output to a 2-D tensor with one row per image.

        2-D outputs are returned unchanged. 3-D outputs are treated as
        (batch, tokens, channels) and reduced to the first token, and 4-D
        outputs as (batch, channels, H, W) and average-pooled; these two
        layouts are assumed from the original ViT/ConvNeXt branches.
        """
        if features.dim() == 3:
            return features[:, 0, :]
        if features.dim() == 4:
            return F.adaptive_avg_pool2d(features, (1, 1)).flatten(1)
        # Already pooled (what timm returns for num_classes=0 models per its
        # feature-extraction docs); indexing or pooling again would raise.
        return features

相似度计算

In [None]:
def compute_similarities(pairs, feature_method='convnext', model_name='convnext_base'):
    """Compute the cosine similarity between the two images of every pair.

    Args:
        pairs: Sequence of two-element pairs. Each element is either an image
            path or a record exposing ``['image_path']`` (the form the rest
            of this notebook uses for its pairs).
        feature_method: Method passed to ``FeatureExtractor``.
        model_name: Model name passed to ``FeatureExtractor``.

    Returns:
        ``(similarities, valid_pairs)``: the similarity of every pair whose
        two feature vectors could be extracted, and the matching
        ``(first_path, second_path)`` tuples, in input order.
    """
    print(f"\n===== 模块6: 使用 {feature_method.upper()} 计算相似度 =====")

    if not pairs:
        print("无图像对，跳过相似度计算")
        return [], []

    def _as_path(item):
        # Plain paths pass through; anything else is read as a record.
        return item if isinstance(item, (str, os.PathLike)) else item['image_path']

    # Resolve the device here rather than reading a notebook global that is
    # only created in a later cell.
    extractor = FeatureExtractor(method=feature_method, model_name=model_name, device=get_device())

    similarities = []
    valid_pairs = []

    print(f"开始计算 {len(pairs)} 对图像的相似度...")

    for i, (wood_item, clinic_item) in enumerate(pairs):
        wood_path = _as_path(wood_item)
        clinic_path = _as_path(clinic_item)
        print(f"({i+1}/{len(pairs)}) 处理: {os.path.basename(wood_path)} 和 {os.path.basename(clinic_path)}")

        wood_features = extractor.extract_features(wood_path)
        clinic_features = extractor.extract_features(clinic_path)

        if wood_features is None or clinic_features is None:
            print(f"  无法提取特征，跳过此对")
            continue

        # Add a batch dimension: F.cosine_similarity compares along dim=1.
        wood_tensor = torch.tensor(wood_features, dtype=torch.float32).unsqueeze(0)
        clinic_tensor = torch.tensor(clinic_features, dtype=torch.float32).unsqueeze(0)

        similarity = F.cosine_similarity(wood_tensor, clinic_tensor).item()
        similarities.append(similarity)
        valid_pairs.append((wood_path, clinic_path))

        print(f"  相似度: {similarity:.4f}")

    print(f"完成相似度计算: 有效对 {len(similarities)}/{len(pairs)}")

    if similarities:
        avg_similarity = np.mean(similarities)
        print(f"平均相似度: {avg_similarity:.4f}")

    return similarities, valid_pairs

多层CNN特征相似度计算

In [None]:
def compute_pair_cnn_similarities(pairs, model_name='convnext_tiny', device=None, batch_size=16, image_size=224, cache_file=None):
    """Compute a multi-layer CNN feature similarity for every image pair.

    For each pair, the cosine similarity is computed separately on each of
    four feature stages of a timm backbone and the per-stage values are
    averaged.

    Args:
        pairs: Sequence of ``(wood, clinic)`` records exposing ``['image_path']``.
        model_name: timm architecture name.
        device: Device to run on; defaults to ``get_device()``.
        batch_size: Number of images per forward pass.
        image_size: Side length of the square model input.
        cache_file: Optional pickle file used to load/store the result.

    Returns:
        ``(valid_similarities, failed_pairs)``: the similarities of all pairs
        that could be processed (no ``None`` entries), and the indices into
        ``pairs`` of those that could not. The result has the same form
        whether it is computed or loaded from the cache.
    """
    if cache_file and os.path.exists(cache_file):
        try:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # cache_file at files written by this notebook.
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            cached_similarities = data['similarities']
            cached_failed = data['failed_pairs']
        except Exception as exc:
            print(f"Ignoring unreadable similarity cache {cache_file}: {exc}")
        else:
            # The cache stores one entry per pair (None for failures), so a
            # different length means it was built for other data.
            if len(cached_similarities) == len(pairs):
                return [s for s in cached_similarities if s is not None], cached_failed
            print(f"Ignoring stale similarity cache {cache_file}: entry count does not match the pairs")

    device = device or get_device()

    wood_paths = [wood['image_path'] for wood, _ in pairs]
    clinic_paths = [clinic['image_path'] for _, clinic in pairs]
    all_paths = wood_paths + clinic_paths

    model = timm.create_model(model_name, pretrained=True, features_only=True, out_indices=[0, 1, 2, 3])
    model.to(device).eval()

    data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
    data_cfg['input_size'] = (3, image_size, image_size)
    transform = timm.data.create_transform(**data_cfg)

    # One pass over all images; the first half belongs to the wood images.
    all_features = batch_extract_cnn_features(all_paths, model, transform, device, batch_size)
    wood_features = all_features[:len(wood_paths)]
    clinic_features = all_features[len(wood_paths):]

    similarities = []
    failed_pairs = []

    for i, (wood_feat, clinic_feat) in enumerate(tqdm(zip(wood_features, clinic_features), total=len(pairs), desc="Similarity")):
        layer_similarities = []
        if wood_feat is not None and clinic_feat is not None:
            layer_similarities = [
                cosine_cnn_similarity(w_feat, c_feat)
                for w_feat, c_feat in zip(wood_feat, clinic_feat)
                if w_feat is not None
            ]

        if not layer_similarities:
            failed_pairs.append(i)
            similarities.append(None)
            continue

        similarities.append(float(np.mean(layer_similarities)))

    valid_similarities = [s for s in similarities if s is not None]

    if cache_file:
        try:
            cache_dir = os.path.dirname(cache_file)
            if cache_dir:
                # dirname is '' for a bare file name; os.makedirs('') would raise.
                os.makedirs(cache_dir, exist_ok=True)
            with open(cache_file, 'wb') as f:
                pickle.dump({'similarities': similarities, 'failed_pairs': failed_pairs}, f)
        except Exception as exc:
            print(f"Could not write similarity cache {cache_file}: {exc}")

    return valid_similarities, failed_pairs

In [None]:

# Settings shared by the cells below.
config = dict(
    batch_size=4,                          # worker threads for segmentation and CNN batch size
    image_size=192,                        # side length of the CNN input
    cnn_model='convnext_tiny',             # timm backbone used for the similarities
    sam_image_size=384,                    # longest image side fed to SAM
    use_cache=True,
    cache_dir='cache',
    seg_output_folder='segmented_images',
)

device = get_device()

# Create the output locations up front.
os.makedirs(config['seg_output_folder'], exist_ok=True)
if config['use_cache']:
    os.makedirs(config['cache_dir'], exist_ok=True)

In [None]:

df = pd.read_csv('../../datasets/data.csv')

stable_df = df[df['stability'] == 'stable']
non_stable_df = df[df['stability'] == 'non-stable']


def _build_pairs(frame):
    """Return one ``(wood_row, clinic_row)`` tuple per pair_id.

    Pairs are returned in order of first appearance of their pair_id. A
    pair_id without both a 'wood' and a 'clinic' row is skipped with a
    message, because ``.iloc[0]`` on the empty selection would raise
    IndexError.
    """
    pairs = []
    for pair_id, pair_images in frame.groupby('pair_id', sort=False):
        wood_rows = pair_images[pair_images['image_type'] == 'wood']
        clinic_rows = pair_images[pair_images['image_type'] == 'clinic']
        if wood_rows.empty or clinic_rows.empty:
            print(f"Skipping pair {pair_id}: missing wood or clinic image")
            continue
        pairs.append((wood_rows.iloc[0], clinic_rows.iloc[0]))
    return pairs


stable_pairs = _build_pairs(stable_df)
non_stable_pairs = _build_pairs(non_stable_df)

print(f"Found {len(stable_pairs)} stable pairs and {len(non_stable_pairs)} non-stable pairs")

In [None]:

# Load SAM if possible; `sam = None` makes the segmentation step pass the
# original images through unchanged.
sam_checkpoint = '../../outputs/checkpoints/proposed/sam_vit_b_01ec64.pth'
if not os.path.exists(sam_checkpoint):
    sam_checkpoint = '../../outputs/checkpoints/proposed/sam_vit_h_4b8939(1).pth'

try:
    # Imported inside the try so that a missing `segment_anything` package
    # also falls back to the original images instead of crashing the cell.
    from segment_anything import sam_model_registry, SamPredictor

    if not os.path.exists(sam_checkpoint):
        raise FileNotFoundError(f"checkpoint not found: {sam_checkpoint}")

    # The variant is read from the checkpoint file name; vit_h is the fallback.
    model_type = "vit_b" if "vit_b" in sam_checkpoint else "vit_l" if "vit_l" in sam_checkpoint else "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
except Exception as exc:
    # Report the actual cause: it may be a missing file, a missing package or
    # a loading error, not only "not found".
    print(f"SAM model could not be loaded ({exc}), using original images instead of segmentation")
    sam = None

In [None]:

# Cache files for the segmented paths (None disables caching).
stable_seg_cache = os.path.join(config['cache_dir'], 'stable_seg_paths.pkl') if config['use_cache'] else None
non_stable_seg_cache = os.path.join(config['cache_dir'], 'non_stable_seg_paths.pkl') if config['use_cache'] else None

stable_seg_paths = extract_all_segments(
    stable_pairs, sam, "stable", device,
    image_size=config['sam_image_size'],
    batch_size=config['batch_size'],
    cache_file=stable_seg_cache,
    output_folder=config['seg_output_folder']
)

non_stable_seg_paths = extract_all_segments(
    non_stable_pairs, sam, "non-stable", device,
    image_size=config['sam_image_size'],
    batch_size=config['batch_size'],
    cache_file=non_stable_seg_cache,
    output_folder=config['seg_output_folder']
)


def _with_segmented_paths(pairs, seg_paths):
    """Return new pairs of dicts whose 'image_path' is the segmented image.

    ``seg_paths`` holds two entries per pair in the order
    ``[pair0 wood, pair0 clinic, pair1 wood, ...]``; 'pair_id' and
    'image_type' are copied from the original records.
    """
    assert len(seg_paths) == 2 * len(pairs), "expected two segmented paths per pair"
    return [
        (
            {'image_path': seg_paths[2 * i], 'pair_id': wood['pair_id'], 'image_type': wood['image_type']},
            {'image_path': seg_paths[2 * i + 1], 'pair_id': clinic['pair_id'], 'image_type': clinic['image_type']},
        )
        for i, (wood, clinic) in enumerate(pairs)
    ]


stable_pairs = _with_segmented_paths(stable_pairs, stable_seg_paths)
non_stable_pairs = _with_segmented_paths(non_stable_pairs, non_stable_seg_paths)

In [None]:

# Visual check: plot a random sample of the segmented images from both groups.
sample_and_visualize(stable_seg_paths, non_stable_seg_paths)

In [None]:

def _similarity_cache_path(prefix):
    """Return the similarity cache file for one group, or None when caching is off.

    The file name includes the CNN model and input size so that results
    computed with different settings do not share a cache file.
    """
    if not config['use_cache']:
        return None
    return os.path.join(
        config['cache_dir'],
        f'{prefix}_similarities_{config["cnn_model"]}_{config["image_size"]}.pkl'
    )


stable_sim_cache = _similarity_cache_path('stable')
non_stable_sim_cache = _similarity_cache_path('non_stable')

stable_similarities, _ = compute_pair_cnn_similarities(
    stable_pairs,
    model_name=config['cnn_model'],
    device=device,
    batch_size=config['batch_size'],
    image_size=config['image_size'],
    cache_file=stable_sim_cache
)

non_stable_similarities, _ = compute_pair_cnn_similarities(
    non_stable_pairs,
    model_name=config['cnn_model'],
    device=device,
    batch_size=config['batch_size'],
    image_size=config['image_size'],
    cache_file=non_stable_sim_cache
)

if stable_similarities and non_stable_similarities:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.boxplot([stable_similarities, non_stable_similarities],
               patch_artist=True,
               showmeans=True)
    # Set the tick labels separately: the `labels=` keyword of boxplot is
    # deprecated in recent Matplotlib, and its replacement does not exist in
    # older versions.
    ax.set_xticklabels(['Stable Pairs', 'Non-Stable Pairs'])

    ax.set_title('Feature Similarities Comparison')
    ax.set_ylabel('Cosine Similarity')
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    # Boxes sit at x = 1 and x = 2; connect their means.
    ax.plot([1, 2], [np.mean(stable_similarities), np.mean(non_stable_similarities)], 'ro-', label='Mean')
    ax.legend()

    fig.tight_layout()
    fig.savefig('similarity_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()