In [7]:
import torch
import torchvision.models as models
import numpy as np
import cv2
import os
import re
from PIL import Image
import matplotlib.pyplot as plt

## 以下是适配 MSTAR 数据集的格式转换函数（解析头信息、读取原始数据、转换为 RGB 图像），如有需要可自行下载 MSTAR 数据集使用

In [10]:
# -----------------------------
# 解析 MSTAR ASCII header
# -----------------------------


def parse_mstar_header(file):
    """
    Parse the ASCII Phoenix header at the start of an MSTAR file.

    ``file`` is iterated line by line; each line must be ``bytes`` (e.g. a
    file opened in ``'rb'`` mode). Every line is decoded as UTF-8 (undecodable
    bytes are dropped) and stripped. Collection stops at the first line that
    contains ``EndofPhoenixHeader``. Iteration stops right after that line, so
    the caller can keep reading the binary payload from the same handle.

    Returns a dict mapping each ``key=value`` line's stripped key to its
    stripped value (values stay strings; only the first ``=`` splits).
    Blank lines and lines without ``=`` are ignored. If no terminator line
    exists, the whole input is consumed.
    """
    fields = {}
    for raw in file:
        text = raw.decode('utf-8', errors='ignore').strip()
        if 'EndofPhoenixHeader' in text:
            break
        # partition() yields an empty separator for lines without '=',
        # which also covers blank lines.
        name, sep, val = text.partition('=')
        if sep:
            fields[name.strip()] = val.strip()
    return fields

# -----------------------------
# 读取 MSTAR 文件
# -----------------------------


def read_mstar_file(path, use_phase=False):
    """
    Read one MSTAR file and return its data at the original size (no cropping).

    Parameters
    ----------
    path : str or os.PathLike
        MSTAR file: an ASCII Phoenix header followed by big-endian float32
        samples.
    use_phase : bool, default False
        True keeps every stored channel; False keeps only channel 0
        (the amplitude channel).

    Returns
    -------
    data : np.ndarray of float32, shape (rows, cols, channels)
        ``channels`` is 1 when ``use_phase`` is False.
    header : dict
        Header fields (string values) as returned by ``parse_mstar_header``.

    Raises
    ------
    KeyError
        The header has no ``NumberOfRows`` / ``NumberOfColumns`` entry.
    ValueError
        The sample count is not a non-zero multiple of rows * cols.
    """
    # NOTE: the default is a literal False, not the notebook-level `use_phase`
    # variable. A default naming a global is evaluated when this cell runs,
    # which raised NameError on a fresh kernel because the config cell comes later.
    with open(path, 'rb') as f:
        header = parse_mstar_header(f)
        # The payload is read from the same handle: the bytes after the header.
        data = np.fromfile(f, dtype='>f4')

    h = int(header['NumberOfRows'])
    w = int(header['NumberOfColumns'])
    plane = h * w
    # Fail with a clear message instead of an opaque reshape/index error.
    # NOTE(review): if some files carry extra bytes between the header and the
    # samples, they will surface here as a size mismatch; confirm for your data.
    if plane <= 0 or data.size == 0 or data.size % plane:
        raise ValueError(
            f"{path}: {data.size} samples do not form whole {h}x{w} planes"
        )

    # Channels are stored as consecutive rows x cols planes. Move the channel
    # axis last and convert the big-endian floats to native float32.
    data = data.reshape(-1, h, w).transpose(1, 2, 0).astype(np.float32)

    if not use_phase:
        data = data[:, :, :1]  # keep only channel 0 (amplitude), shape (h, w, 1)

    return data, header

# -----------------------------
# SAR -> 高分辨率 RGB PNG
# -----------------------------


def sar_to_rgb_png(img, upscale=1, clip_percent=(1, 99)):
    """
    Convert a SAR image into a 3-channel 8-bit grayscale image ready to be
    written as PNG (e.g. with ``cv2.imwrite``).

    Pipeline: power ``|img|**2`` -> ``log1p`` -> percentile clipping (contrast
    stretch) -> 3x3 Gaussian blur (reduces speckle) -> min-max normalisation
    to 0..255 -> bicubic upsampling -> gray copied into 3 identical channels.

    Parameters
    ----------
    img : np.ndarray
        2-D image (H, W) or 3-D image (H, W, C). For 3-D input only channel 0
        is used, which is where ``read_mstar_file`` keeps the amplitude.
    upscale : int, default 1
        Upsampling factor (>= 1); the output is (H*upscale, W*upscale, 3).
    clip_percent : (float, float), default (1, 99)
        Lower/upper percentiles of the log-power image used for clipping.

    Returns
    -------
    np.ndarray of uint8, shape (H*upscale, W*upscale, 3)
        All zeros when the clipped image has no dynamic range (e.g. nearly
        constant input), which previously divided by zero.

    Raises
    ------
    ValueError
        ``upscale`` < 1, ``img`` is not 2-D/3-D, or ``img`` contains NaN.
    """
    # Literal default: a default naming the global `upscale` would be evaluated
    # at def time, before the config cell defines it (NameError on fresh kernel).
    if upscale < 1:
        raise ValueError(f"upscale must be >= 1, got {upscale!r}")

    if img.ndim == 3:
        img = img[:, :, 0]
    if img.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {img.shape}")

    H, W = img.shape
    out_w, out_h = int(W * upscale), int(H * upscale)

    # Power + log compression.
    img_log = np.log1p(np.abs(img) ** 2)

    # Percentile clipping to boost contrast.
    vmin, vmax = np.percentile(img_log, clip_percent)
    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        # np.percentile returns NaN when any input value is NaN.
        raise ValueError("image contains NaN values; cannot normalise")
    if vmax <= vmin:
        # Clipping collapses the image to a single value: nothing to show.
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)
    img_clip = np.clip(img_log, vmin, vmax)

    # Gaussian smoothing to reduce speckle noise.
    img_smooth = cv2.GaussianBlur(img_clip, (3, 3), 0)

    # Min-max normalisation to 0..255, guarded against a zero range.
    lo, hi = img_smooth.min(), img_smooth.max()
    if hi <= lo:
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)
    img_uint8 = ((img_smooth - lo) / (hi - lo) * 255).astype(np.uint8)

    # Upsample; cv2.resize takes the target size as (width, height).
    img_large = cv2.resize(img_uint8, (out_w, out_h), interpolation=cv2.INTER_CUBIC)

    # Replicate gray into 3 channels (channel order is irrelevant for gray).
    return cv2.merge([img_large, img_large, img_large])

## 以下代码将 MSTAR 数据集中以数字为扩展名的原始文件（如 .000）批量转换为 PNG 格式

In [15]:
# -----------------------------
# Paths and options
# -----------------------------
input_root = '../data/eoc-2-vv'
output_png_root = '../data/MSTAR_eoc_2_vv'
use_phase = False   # False: keep only the amplitude channel
upscale = 1         # output upsampling factor (1 = original size)

# Optional raw-array export (replaces the previously commented-out npy code).
save_npy = False
output_npy_root = '../data/MSTAR_npy'

# Raw MSTAR files end in a purely numeric extension (e.g. .000, .015).
# Anchoring at the end skips derived files such as "*.000.png" / "*.000.JPG",
# which the unanchored pattern used to pick up and then fail on.
mstar_name_re = re.compile(r'\.\d+$')

os.makedirs(output_png_root, exist_ok=True)
if save_npy:
    os.makedirs(output_npy_root, exist_ok=True)

# -----------------------------
# Batch conversion
# -----------------------------
n_ok = 0
failed = []
for root, _dirs, files in os.walk(input_root):
    rel_path = os.path.relpath(root, input_root)
    for file in files:
        if not mstar_name_re.search(file):
            continue
        src_path = os.path.join(root, file)
        out_stem = file.replace('.', '_')

        try:
            # Pass the config explicitly so changing `use_phase` above takes effect.
            img_data, _header = read_mstar_file(src_path, use_phase=use_phase)

            if save_npy:
                npy_dir = os.path.join(output_npy_root, rel_path)
                os.makedirs(npy_dir, exist_ok=True)
                np.save(os.path.join(npy_dir, out_stem + '.npy'), img_data)

            img_rgb = sar_to_rgb_png(img_data, upscale=upscale)

            png_dir = os.path.join(output_png_root, rel_path)
            os.makedirs(png_dir, exist_ok=True)
            png_path = os.path.join(png_dir, out_stem + '.png')
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(png_path, img_rgb):
                raise OSError(f"cv2.imwrite failed: {png_path}")
            n_ok += 1

        except Exception as e:
            failed.append(src_path)
            print(f"❌ 文件 {file} 处理失败: {e}")

print(f"✅ 成功 {n_ok} 个, ❌ 失败 {len(failed)} 个")

# 以下代码用于下载所需的 ImageNet 预训练模型，如有需要取消注释即可

In [None]:
# Optional cell: builds torchvision backbones with ImageNet (IMAGENET1K_V1)
# pretrained weights. Uncomment only the model(s) you need; each assignment
# overwrites `model`.
# NOTE(review): the torch.hub.set_dir path below is an absolute, user-specific
# path. Presumably it controls where downloaded weights are cached. Change it
# to a directory on your machine (or leave that line commented) before use.
# torch.hub.set_dir('/home/suxin/SCKansformer-main/model_pth')
# model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
# model = models.efficientnet_v2_m(
#     weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1
# )
# model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# model = models.resnet50(
#     weights=models.ResNet50_Weights.IMAGENET1K_V1
# )
# model = models.resnet101(
#     weights=models.ResNet101_Weights.IMAGENET1K_V1
# )