In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
import pickle
from sklearn.preprocessing import MinMaxScaler
import cv2

# Create the directory that extracted features are saved into. Feature files
# are written beneath the 'data' root (data/features/images/...), so the
# directory has to be created there rather than relative to the working
# directory.
os.makedirs(os.path.join('data', 'features', 'images'), exist_ok=True)

# Load the per-sample metadata table.
metadata_df = pd.read_csv('data/metadata.csv')

# Hyperspectral cube dataset
class HyperspectralDataset(Dataset):
    """Loads raw float32 hyperspectral cubes referenced by a metadata table.

    Each row of ``metadata_df`` must provide a ``'float_data路径'`` entry: the
    path (joined onto ``DATA_ROOT``) of a binary file made of a fixed-size
    header followed by float32 samples. Every item is returned as a cube whose
    bands were individually min-max normalised and resized to ``TARGET_SIZE``.
    """

    DATA_ROOT = 'data'            # directory the metadata paths are joined onto
    HEADER_BYTES = 32768          # header skipped before the samples (8192 float32 values)
    CUBE_SHAPE = (520, 696, 128)  # shape the flat sample buffer is reshaped to
    TARGET_SIZE = (224, 224)      # (height, width) every band is resized to

    def __init__(self, metadata_df):
        """Store the metadata table; one dataset item per table row."""
        self.metadata_df = metadata_df
        # NOTE: __getitem__ does not apply this transform; the attribute is
        # kept so existing code that reads `dataset.transform` keeps working.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata_df)

    def _load_cube(self, float_path):
        """Read the float32 cube stored after the header of ``float_path``.

        Raises:
            ValueError: if the file is too small to hold the header plus a
                full cube of ``CUBE_SHAPE`` float32 values.
        """
        expected_size = int(np.prod(self.CUBE_SHAPE))
        required_bytes = self.HEADER_BYTES + expected_size * np.dtype(np.float32).itemsize
        actual_bytes = os.path.getsize(float_path)
        if actual_bytes < required_bytes:
            # Fail early with the offending path instead of an opaque reshape error.
            raise ValueError(
                f"Hyperspectral file {float_path} is too small: "
                f"{actual_bytes} bytes, expected at least {required_bytes}"
            )

        # Skip the header and read exactly one cube; trailing bytes are ignored.
        float_data = np.fromfile(
            float_path,
            dtype=np.float32,
            count=expected_size,
            offset=self.HEADER_BYTES,
        )
        # NOTE(review): this reshape assumes the band index varies fastest in
        # the file (pixel-interleaved layout) — confirm against the file format.
        return float_data.reshape(self.CUBE_SHAPE)

    def __getitem__(self, idx):
        """Return the normalised, resized cube for the row at position ``idx``.

        Returns:
            dict with
              'image': float32 array of shape TARGET_SIZE + (number of bands,),
              'idx':   the positional index that was requested,
              'path':  the joined path the cube was read from.

        Raises:
            ValueError: if the file cannot hold a complete cube.
        """
        row = self.metadata_df.iloc[idx]
        float_path = os.path.join(self.DATA_ROOT, row['float_data路径'])

        hyperspectral_image = self._load_cube(float_path)
        height, width, num_bands = hyperspectral_image.shape
        target_h, target_w = self.TARGET_SIZE

        # Resize band by band so the number of bands is preserved.
        resized_image = np.zeros((target_h, target_w, num_bands), dtype=np.float32)
        # fit_transform refits on every call, so a single scaler can be reused.
        scaler = MinMaxScaler()
        for band in range(num_bands):
            band_image = hyperspectral_image[:, :, band]
            # Scale each band to [0, 1] before resizing.
            band_image_normalized = scaler.fit_transform(
                band_image.reshape(-1, 1)
            ).reshape(height, width)
            # cv2.resize takes the output size as (width, height).
            resized_image[:, :, band] = cv2.resize(band_image_normalized, (target_w, target_h))

        return {
            'image': resized_image,
            'idx': idx,
            'path': float_path
        }

# Select the compute device (GPU when available).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the pretrained ViT-B/16.
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)
# Drop the classification head by replacing it with an identity mapping.
# Wrapping the child modules in nn.Sequential would feed the conv_proj output
# straight into the encoder, skipping the patch flattening and class-token
# step, so a forward pass would fail. With an identity head, model(x) returns
# the class-token embedding.
model.heads = torch.nn.Identity()
model = model.to(device)
model.eval()

# Build the data loader: one sample per batch, in metadata order.
dataset = HyperspectralDataset(metadata_df)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)


Using device: cuda


In [None]:
# Make sure the metadata table has a 'feature_path' column (initially empty)
# to hold the location of each sample's saved feature file.
if 'feature_path' not in metadata_df:
    metadata_df['feature_path'] = None

# ViT-based feature extractor
class ViTFeatureExtractor:
    """Exposes the class-token embedding of a ViT model as an image feature.

    The wrapped ``model`` must provide ``conv_proj``, ``class_token`` and an
    ``encoder`` with ``pos_embedding``, ``dropout``, ``layers`` and ``ln``.
    """

    def __init__(self, model):
        """Keep the model and build the image preprocessing pipeline."""
        # Preprocessing: to tensor, resize to 224x224, then normalise each
        # channel with the fixed mean/std values below.
        preprocessing_steps = [
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.model = model
        self.transform = transforms.Compose(preprocessing_steps)

    def extract_features(self, image, device):
        """Run one image through the ViT and return its class-token embedding.

        Args:
            image: input accepted by ``self.transform``.
            device: torch device the preprocessed batch is moved to.

        Returns:
            NumPy array holding the class-token row of the encoder output; the
            leading dimension is the batch of size 1 added here.
        """
        # Preprocess and add the batch dimension.
        batch = self.transform(image).unsqueeze(0).to(device)

        vit = self.model
        encoder = vit.encoder
        with torch.no_grad():
            # Patch embedding, then flatten the spatial grid into a sequence.
            patch_grid = vit.conv_proj(batch)
            patch_tokens = patch_grid.flatten(2).transpose(1, 2)

            # Prepend the class token to every sequence in the batch.
            batch_size = patch_tokens.shape[0]
            class_tokens = vit.class_token.expand(batch_size, -1, -1)
            tokens = torch.cat((class_tokens, patch_tokens), dim=1)

            # Add the position embedding, then dropout -> layers -> final norm.
            tokens = tokens + encoder.pos_embedding
            encoded = encoder.ln(encoder.layers(encoder.dropout(tokens)))

            # The class token sits at sequence position 0.
            class_embedding = encoded[:, 0]

        return class_embedding.cpu().numpy()

# Initialise the feature extractor with a pretrained ViT-B/16. The model is
# switched to eval mode so that layers which behave differently in training
# (e.g. dropout) are deterministic during feature extraction.
feature_extractor = ViTFeatureExtractor(
    model=vit_b_16(weights=ViT_B_16_Weights.DEFAULT).to(device).eval()
)

root = 'data'
feature_dir = 'features/images'

# The output directory is the same for every sample, so create it once.
os.makedirs(os.path.join(root, feature_dir), exist_ok=True)

# The dataset yields positional row indices (it reads rows with .iloc), so
# results are written back positionally rather than by index label.
feature_col = metadata_df.columns.get_loc('feature_path')

# Iterate over all samples and extract features.
for data in tqdm(dataloader, desc="提取特征"):
    idx = data['idx'].item()
    float_path = data['path'][0]
    # The DataLoader collates the cube into a tensor; convert it back to NumPy.
    hyperspectral_image = np.asarray(data['image'][0])  # shape: (224, 224, 128)

    # Extract one feature vector per band.
    features = []
    for band_idx in range(hyperspectral_image.shape[2]):
        band_image = hyperspectral_image[:, :, band_idx]

        # Replicate the single band into 3 channels to match the ViT input.
        rgb_image = np.stack([band_image] * 3, axis=2)
        pil_image = Image.fromarray((rgb_image * 255).astype(np.uint8))

        features.append(feature_extractor.extract_features(pil_image, device))

    # Each extract_features call returns a (1, D) array, so the stacked result
    # has shape (bands, 1, D) — (128, 1, 768) for ViT-B/16 — not (bands, D).
    features = np.array(features)

    # Build the feature file path. splitext strips only the final extension,
    # so file names that contain extra dots do not collapse onto one name.
    base_name = os.path.splitext(os.path.basename(float_path))[0]
    feature_path = f"{feature_dir}/{base_name}_features.pkl"
    full_feature_path = os.path.join(root, feature_path)

    # Save the features.
    with open(full_feature_path, 'wb') as f:
        pickle.dump(features, f)

    # Record the root-relative feature path in the metadata table.
    metadata_df.iat[idx, feature_col] = feature_path

# Save the updated metadata. NOTE: this overwrites the input metadata file.
metadata_df.to_csv(os.path.join(root, 'metadata.csv'), index=False)
print("特征提取完成，更新后的元数据已保存到 data/metadata.csv")

提取特征: 100%|██████████| 182/182 [26:12<00:00,  8.64s/it]

特征提取完成，更新后的元数据已保存到 data/metadata.csv



