In [1]:
import json
import os
import re
import torch
import pickle
import numpy as np

import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

from PIL import Image
from PIL import ImageFile
# Let PIL decode image files whose data stream ends early (truncated files)
# instead of raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's decompression-bomb pixel-count limit so very large images open
# without a DecompressionBombWarning/Error (presumably because some source
# radiographs are very large -- confirm; only safe for trusted image files).
Image.MAX_IMAGE_PIXELS = None

In [2]:
# Annotation JSON holding the 'train' / 'val' / 'test' sample lists.
ann_path: str = "/root/autodl-tmp/mimic_cxr/mimic_annotation_promptmrg.json"
# Image root; every annotation 'image_path' entry is joined onto this directory.
images_dir: str = "/root/autodl-tmp/mimic_cxr/"

In [3]:
# Text cleaning
def clean_report_mimic_cxr(report):
    """Normalise a raw MIMIC-CXR report into ' . '-separated lower-case sentences.

    Steps:
      1. Flatten newlines to spaces.
      2. Collapse runs of '_', ' ' and '.'.
      3. Strip list numbering such as "1. " and ". 2. ".
      4. Lower-case and split into sentences on '. '.
    Each sentence then has quotes, slashes, backslashes and the punctuation in
    the character class below removed. Sentences that end up empty are dropped.

    Args:
        report: raw report text.

    Returns:
        str: the cleaned sentences joined with ' . ' and terminated by ' .'
        (an empty report yields ' .').
    """
    def report_cleaner(t):
        # Flatten, collapse repeated '_' / ' ' / '.', drop list numbering,
        # then split into sentences.
        return t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')

    def sent_cleaner(t):
        t = t.replace('"', '').replace('/', '').replace('\\', '').replace("'", '').strip().lower()
        # Raw string: avoids the invalid-escape SyntaxWarning of the old literal
        # while producing the identical pattern.
        # NOTE: inside the class, ':-\[' is a character *range* (':' .. '['), so
        # ';<=>?@' (and A-Z, already lower-cased) are removed as well. Kept
        # unchanged so the tokenisation matches previously processed data.
        return re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t)

    # Clean each sentence once. The previous filter compared the cleaned str
    # against [] (always unequal), so empty sentences leaked into the output
    # as '.  .'; they are dropped here instead.
    cleaned = (sent_cleaner(sent) for sent in report_cleaner(report))
    tokens = [sent for sent in cleaned if sent]
    return ' . '.join(tokens) + ' .'

def my_pre_caption(caption, max_words=100):
    """Clean a raw report and keep at most ``max_words`` space-separated tokens.

    Tokens come from splitting the cleaned text on single spaces, so the
    ' . ' sentence separators count toward the limit as well.

    Args:
        caption: raw report text.
        max_words: maximum number of tokens to keep.

    Returns:
        str: the cleaned (and possibly truncated) report.
    """
    words = clean_report_mimic_cxr(caption).split(' ')
    # split(' ') / ' '.join round-trip exactly, so short captions come back unchanged.
    return ' '.join(words[:max_words])

In [4]:
# Prompt tokens for the numeric classification labels: a label value i is
# rendered as SCORES[i] (token names suggest blank / positive / negative /
# uncertain -- confirm against the annotation source).
SCORES = ['[BLA]', '[POS]', '[NEG]', '[UNC]']

class generation_train(Dataset):
    """Training split of the MIMIC-CXR report-generation dataset.

    Each item is ``(image, caption, cls_labels, clip_memory)``:

    * ``image``: ``transform`` applied to the RGB image at ``image_path[0]``.
    * ``caption``: one ``SCORES`` token per entry of ``labels`` (the label
      value is used as the index), a space, then the cleaned report truncated
      to ``max_words`` tokens.
    * ``cls_labels``: ``labels`` as an int64 tensor.
    * ``clip_memory``: float32 rows of the CLIP text-feature matrix selected by
      the first ``num_clip_memory`` entries of ``clip_indices``.

    Annotations are kept in order only up to the first entry whose image file
    is missing. This is intentional: a partially copied image directory still
    yields a contiguous subset.

    Args:
        transform: callable applied to the RGB ``PIL.Image``.
        image_root: directory the annotation ``image_path`` entries are relative to.
        ann_root: path of the annotation JSON; its ``'train'`` list is used.
        max_words: maximum number of report tokens kept.
        clip_features_path: JSON file containing the CLIP text-feature matrix.
        num_clip_memory: maximum number of CLIP feature rows returned per item.
    """

    # Class-level default so instances pickled before this attribute existed
    # still work in __getitem__ after being reloaded.
    num_clip_memory = 21

    def __init__(self, transform, image_root, ann_root, max_words=100,
                 clip_features_path='/root/autodl-tmp/mimic_cxr/clip_text_features.json',
                 num_clip_memory=21):
        # Context manager so the annotation file is closed deterministically.
        with open(ann_root, 'r') as f:
            self.annotation = json.load(f)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.num_clip_memory = num_clip_memory
        self.all_ann = self.annotation['train']
        self.ann = []  # subset of annotations whose images are present

        for ann in self.all_ann:
            full_path = os.path.join(self.image_root, ann['image_path'][0])
            if not os.path.exists(full_path):
                # Stop at the first missing image; everything after it is skipped.
                break
            self.ann.append(ann)

        with open(clip_features_path, 'r') as f:
            # float32 up front: __getitem__ returns float32 anyway, and this
            # halves the in-memory (and pickled) size of the feature matrix.
            self.clip_features = np.asarray(json.load(f), dtype=np.float32)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        # Context manager closes the file handle even if decoding fails.
        with Image.open(os.path.join(self.image_root, ann['image_path'][0])) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        cls_labels = ann['labels']
        # Label-token prompt: each label value indexes SCORES.
        prompt = ' '.join(SCORES[l] for l in cls_labels) + ' '
        caption = prompt + my_pre_caption(ann['report'], self.max_words)
        cls_labels = torch.from_numpy(np.array(cls_labels)).long()
        clip_indices = ann['clip_indices'][:self.num_clip_memory]
        clip_memory = torch.from_numpy(self.clip_features[clip_indices]).float()

        return image, caption, cls_labels, clip_memory
    
class generation_eval(Dataset):
    """Validation / test split of the MIMIC-CXR report-generation dataset.

    Each item is ``(image, caption, cls_labels, clip_memory)``:

    * ``image``: ``transform`` applied to the RGB image at ``image_path[0]``.
    * ``caption``: the cleaned report truncated to ``max_words`` tokens (no
      label prompt, unlike the training split).
    * ``cls_labels``: ``labels`` as an int64 tensor.
    * ``clip_memory``: float32 rows of the CLIP text-feature matrix selected by
      the first ``num_clip_memory`` entries of ``clip_indices``.

    Annotations are kept in order only up to the first entry whose image file
    is missing. This is intentional: a partially copied image directory still
    yields a contiguous subset.

    Args:
        transform: callable applied to the RGB ``PIL.Image``.
        image_root: directory the annotation ``image_path`` entries are relative to.
        ann_root: path of the annotation JSON.
        max_words: maximum number of report tokens kept.
        split: key of the annotation JSON to use (e.g. ``'val'`` or ``'test'``).
        clip_features_path: JSON file containing the CLIP text-feature matrix.
        num_clip_memory: maximum number of CLIP feature rows returned per item.
    """

    # Class-level default so instances pickled before this attribute existed
    # still work in __getitem__ after being reloaded.
    num_clip_memory = 21

    def __init__(self, transform, image_root, ann_root, max_words=100, split='val',
                 clip_features_path='/root/autodl-tmp/mimic_cxr/clip_text_features.json',
                 num_clip_memory=21):
        # Context manager so the annotation file is closed deterministically.
        with open(ann_root, 'r') as f:
            self.annotation = json.load(f)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.num_clip_memory = num_clip_memory
        self.all_ann = self.annotation[split]
        self.ann = []  # subset of annotations whose images are present

        for ann in self.all_ann:
            full_path = os.path.join(self.image_root, ann['image_path'][0])
            if not os.path.exists(full_path):
                # Stop at the first missing image; everything after it is skipped.
                break
            self.ann.append(ann)

        with open(clip_features_path, 'r') as f:
            # float32 up front: __getitem__ returns float32 anyway, and this
            # halves the in-memory (and pickled) size of the feature matrix.
            self.clip_features = np.asarray(json.load(f), dtype=np.float32)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        # Context manager closes the file handle even if decoding fails.
        with Image.open(os.path.join(self.image_root, ann['image_path'][0])) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        caption = my_pre_caption(ann['report'], self.max_words)
        # .long() to match generation_train: NumPy's default integer dtype is
        # platform-dependent (int32 on Windows before NumPy 2.0).
        cls_labels = torch.from_numpy(np.array(ann['labels'])).long()
        clip_indices = ann['clip_indices'][:self.num_clip_memory]
        clip_memory = torch.from_numpy(self.clip_features[clip_indices]).float()

        return image, caption, cls_labels, clip_memory

In [5]:
# Build the normalised train / val / test datasets.
def create_dataset(image_dir, ann_path):
    """Create the train, val and test datasets from one annotation file.

    Every split resizes to 336x336, converts to a tensor and normalises with
    the ImageNet channel mean/std. Only the training split adds a random
    rotation (degrees=5) as augmentation.

    Args:
        image_dir: image root passed to the dataset classes.
        ann_path: annotation JSON passed to the dataset classes.

    Returns:
        tuple: (train_dataset, val_dataset, test_dataset).
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def build_pipeline(*augmentations):
        # Resize -> optional augmentations -> tensor -> normalise.
        return transforms.Compose([
            transforms.Resize((336, 336)),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    transform_train = build_pipeline(transforms.RandomRotation(degrees=5))
    transform_test = build_pipeline()

    return (
        generation_train(transform_train, image_dir, ann_path),
        generation_eval(transform_test, image_dir, ann_path, split='val'),
        generation_eval(transform_test, image_dir, ann_path, split='test'),
    )

In [6]:
# Build the three splits: train with random-rotation augmentation, val/test without.
train_dataset, val_dataset, test_dataset = create_dataset(images_dir, ann_path)

In [7]:
# Save
def save_datasets(train_dataset, val_dataset, test_dataset, save_path="/root/autodl-tmp/mimic_cxr/datasets.pkl"):
    """Pickle the three dataset splits, plus their sizes, into a single file.

    The file holds a dict with keys ``'train'``, ``'val'``, ``'test'`` and
    ``'info'`` (the split sizes). The dataset objects are pickled whole, so
    every attribute they carry ends up in the file. For this notebook's
    classes that includes the annotations, transforms and CLIP feature array.

    The pickle is first written to ``save_path + '.tmp'`` and then moved into
    place with ``os.replace``. A failure mid-write (e.g. a full disk) therefore
    cannot leave a truncated file at ``save_path`` or clobber a previous good
    copy. The parent directory is created if it does not exist.

    Args:
        train_dataset, val_dataset, test_dataset: objects supporting ``len()``.
        save_path: destination pickle file.
    """
    datasets = {
        'train': train_dataset,
        'val': val_dataset,
        'test': test_dataset,
        'info': {
            'train_size': len(train_dataset),
            'val_size': len(val_dataset),
            'test_size': len(test_dataset),
        }
    }

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    tmp_path = os.fspath(save_path) + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(datasets, f)
        # On POSIX the rename is atomic: readers see either the old or the new file.
        os.replace(tmp_path, save_path)
    finally:
        # Only still present if the dump or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"数据集已保存到: {save_path}")

# Load
def load_datasets(save_path="/root/autodl-tmp/mimic_cxr/datasets.pkl"):
    """Load the (train, val, test) datasets written by ``save_datasets``.

    The pickle stores instances of the dataset classes defined in this
    notebook, so those classes must be defined in the current session before
    this is called. Only load files you produced yourself: ``pickle.load`` can
    execute arbitrary code.

    Args:
        save_path: pickle file to read.

    Returns:
        tuple: (train_dataset, val_dataset, test_dataset).
    """
    with open(save_path, 'rb') as fh:
        bundle = pickle.load(fh)

    splits = tuple(bundle[key] for key in ('train', 'val', 'test'))
    print("数据集加载成功!")
    return splits

In [8]:
# Persist all three splits (whole dataset objects, including their transforms) to the default pickle path.
save_datasets(train_dataset, val_dataset, test_dataset)

数据集已保存到: /root/autodl-tmp/mimic_cxr/datasets.pkl


In [12]:
# Samples kept per split (each split stops at its first missing image).
print(f"训练集大小: {len(train_dataset)}")  # train
print(f"测试集大小: {len(test_dataset)}")  # test
print(f"验证集大小: {len(val_dataset)}")  # val

训练集大小: 26488
测试集大小: 255
验证集大小: 185


In [14]:
# Dump the first training samples to samples.txt for manual inspection.
# The count is capped at the dataset size: the missing-image cutoff can leave
# fewer than NUM_DUMP_SAMPLES items, which previously raised IndexError.
# NOTE(review): the file contains report text -- keep it out of shared outputs / VCS.
NUM_DUMP_SAMPLES = 100
num_to_dump = min(NUM_DUMP_SAMPLES, len(train_dataset))
with open('samples.txt', 'w', encoding='utf-8') as f:
    for i in range(num_to_dump):
        image, caption, cls_labels, clip_memory = train_dataset[i]

        f.write(f"样本 {i}:\n")
        f.write(f"  图像形状: {image.shape}\n")
        f.write(f"  文本: {caption}\n")
        f.write(f"  分类标签: {cls_labels}\n")
        f.write(f"  CLIP记忆形状: {clip_memory.shape}\n")
        f.write("-" * 50 + "\n")  # separator line

        if i % 10 == 0:
            print(f"已处理 {i} 条样本...")
print("样本已保存到 samples.txt")

已处理 0 条样本...
已处理 10 条样本...
已处理 20 条样本...
已处理 30 条样本...
已处理 40 条样本...
已处理 50 条样本...
已处理 60 条样本...
已处理 70 条样本...
已处理 80 条样本...
已处理 90 条样本...
样本已保存到 samples.txt


In [10]:
# Reload the pickle to verify the round trip; t1 / t2 / t3 are the train / val / test splits.
t1,t2,t3 = load_datasets()

数据集加载成功!


In [15]:
# Spot-check one sample of the reloaded validation split (t2).
sample_idx = 0
image, caption, cls_labels, clip_memory = t2[sample_idx]
print(f"样本 {sample_idx}:")
print(f"  图像形状: {image.shape}")
print(f"  文本: {caption}")
print(f"  分类标签: {cls_labels}")
print(f"  CLIP记忆形状: {clip_memory.shape}")

样本 0:
  图像形状: torch.Size([3, 336, 336])
  文本: no evidence of consolidation to suggest pneumonia is seen . there is some retrocardiac atelectasis . a small left pleural effusion may be present . no pneumothorax is seen . no pulmonary edema . a right granuloma is unchanged . the heart is mildly enlarged unchanged . there is tortuosity of the aorta .
  分类标签: tensor([0, 1, 0, 0, 2, 2, 2, 1, 2, 1, 0, 0, 0, 0])
  CLIP记忆形状: torch.Size([21, 512])
