In [3]:
import json
import re
from collections import Counter
import torch
import numpy as np
import json
import re
from collections import Counter
import argparse

class Tokenizer(object):
    """Word-level tokenizer for radiology reports.

    The vocabulary is built from the ``report`` field of every example in the
    ``train`` split of the annotation JSON file. Token ids start at 1; id 0 is
    emitted by ``__call__`` as the leading/trailing marker and is treated by
    ``decode`` as the stop signal.
    """

    # Punctuation stripped from every sentence. NOTE: inside the class,
    # ":-\[" is a character *range* (':' through '['), so '<', '=', '>', '@'
    # are removed as well, while a literal '-' is NOT removed. The pattern is
    # byte-for-byte the one used originally; it is only written as a raw
    # string so Python does not warn about invalid escape sequences.
    _SENT_PUNCT_RE = re.compile(r'[.,?;*!%^&_+():-\[\]{}]')

    def __init__(self, args):
        """Load the annotation file and build the vocabulary.

        Args:
            args: object exposing ``ann_path`` (path to the annotation JSON),
                ``threshold`` (minimum token count to enter the vocabulary)
                and ``dataset_name`` (``'iu_xray'`` selects the IU X-Ray
                cleaner; any other value selects the MIMIC-CXR cleaner).
        """
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dataset_name = args.dataset_name
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        # Use a context manager so the file handle is closed deterministically.
        with open(self.ann_path, 'r') as f:
            self.ann = json.load(f)
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        """Count tokens over the train split and build the id mappings.

        Returns:
            tuple: ``(token2idx, idx2token)``; ids start at 1 and follow the
            sorted order of the kept tokens plus ``'<unk>'``.
        """
        counter = Counter()
        for example in self.ann['train']:
            counter.update(self.clean_report(example['report']).split())

        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()
        token2idx = {token: idx for idx, token in enumerate(vocab, start=1)}
        idx2token = dict(enumerate(vocab, start=1))
        return token2idx, idx2token

    @staticmethod
    def _replace_repeatedly(text, old, new, times):
        """Apply ``text.replace(old, new)`` exactly ``times`` times in a row."""
        for _ in range(times):
            text = text.replace(old, new)
        return text

    @staticmethod
    def _strip_list_numbering(text):
        """Remove the ``1. `` ... ``5. `` list markers, keeping sentence breaks."""
        text = text.replace('1. ', '')
        for digit in '2345':
            text = text.replace('. ' + digit + '. ', '. ')
        for digit in '2345':
            text = text.replace(' ' + digit + '. ', '. ')
        return text

    @classmethod
    def _clean_sentence(cls, sentence):
        """Drop quotes, slashes and punctuation from one sentence and lower-case it."""
        stripped = sentence.replace('"', '').replace('/', '').replace('\\', '') \
            .replace("'", '').strip().lower()
        return cls._SENT_PUNCT_RE.sub('', stripped)

    def _sentences_to_report(self, text):
        """Split normalised text on ``'. '`` and re-join cleaned sentences with ``' . '``."""
        sentences = text.strip().lower().split('. ')
        # NOTE(review): the original code filtered with
        # ``sent_cleaner(sent) != []``, which compares a str with a list and is
        # therefore always true. Empty sentences are consequently kept; that
        # behaviour is preserved here so existing vocabularies stay valid.
        tokens = [self._clean_sentence(sent) for sent in sentences]
        return ' . '.join(tokens) + ' .'

    def clean_report_iu_xray(self, report):
        """Normalise an IU X-Ray report into space-separated tokens ending in ``' .'``."""
        text = self._replace_repeatedly(report, '..', '.', 3)
        text = self._strip_list_numbering(text)
        return self._sentences_to_report(text)

    def clean_report_mimic_cxr(self, report):
        """Normalise a MIMIC-CXR report into space-separated tokens ending in ``' .'``."""
        text = report.replace('\n', ' ')
        # Fixed repetition counts mirror the original chained replace() calls.
        text = self._replace_repeatedly(text, '__', '_', 7)
        text = self._replace_repeatedly(text, '  ', ' ', 6)
        text = self._replace_repeatedly(text, '..', '.', 8)
        text = self._strip_list_numbering(text)
        return self._sentences_to_report(text)

    def get_token_by_id(self, id):
        """Return the token stored for ``id``; raises ``KeyError`` if unknown."""
        return self.idx2token[id]

    def get_id_by_token(self, token):
        """Return the id of ``token``, or the id of ``'<unk>'`` if it is not in the vocabulary."""
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        """Return the number of vocabulary entries (``'<unk>'`` included, id 0 excluded)."""
        return len(self.token2idx)

    def __call__(self, report):
        """Clean ``report`` and return its token ids wrapped in a leading and trailing 0."""
        tokens = self.clean_report(report).split()
        ids = [self.get_id_by_token(token) for token in tokens]
        return [0] + ids + [0]

    def decode(self, ids):
        """Convert ids back to a space-joined string.

        Decoding stops at the first id that is not greater than 0, so a
        sequence that starts with 0 (as returned by ``__call__``) decodes to
        an empty string.
        """
        pieces = []
        for idx in ids:
            if idx > 0:
                pieces.append(self.idx2token[idx])
            else:
                break
        return ' '.join(pieces)

    def decode_batch(self, ids_batch):
        """Decode every id sequence in ``ids_batch`` and return the list of strings."""
        return [self.decode(ids) for ids in ids_batch]


In [36]:
def get_args(ann_path='data/iu_xray/annotation_label_with_tokens.json',
             threshold=3,
             dataset_name='iu_xray',
             max_seq_length=200,
             seed=9233):
    """Build a simple argument object that stands in for argparse parsing.

    Called with no arguments it returns the same values as before; every
    setting can now be overridden by keyword instead of editing the function.

    Args:
        ann_path: path to the annotation JSON file.
        threshold: minimum token count stored on the returned object.
        dataset_name: dataset identifier stored on the returned object.
        max_seq_length: maximum sequence length stored on the returned object.
        seed: random seed stored on the returned object.

    Returns:
        An object exposing the five settings above as attributes.
    """
    class Args:
        def __init__(self):
            self.ann_path = ann_path
            self.threshold = threshold
            self.dataset_name = dataset_name
            self.max_seq_length = max_seq_length
            self.seed = seed
            # Add any other required parameters here.

    return Args()

# Seed the random number generators
def set_seed(seed):
    """Seed torch and NumPy; when CUDA is available also set the cuDNN flags.

    Args:
        seed: integer passed to ``torch.manual_seed`` and ``np.random.seed``.
    """
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    np.random.seed(seed)

# Main entry point
def init_tokenizer():
    """Create a ``Tokenizer`` from the settings returned by ``get_args()``.

    The random seed from those settings is applied via ``set_seed`` before
    the tokenizer is constructed.

    Returns:
        The newly constructed ``Tokenizer`` instance.
    """
    config = get_args()
    set_seed(config.seed)
    return Tokenizer(config)

# Run this in Jupyter to obtain the tokenizer
tokenizer = init_tokenizer()

In [37]:
# Bare expression: the notebook renders the tokenizer object's repr as the cell output.
tokenizer

<__main__.Tokenizer at 0x7df91d6ae390>

In [38]:
# Print the vocabulary size
vocab_size = tokenizer.get_vocab_size()
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 760


In [39]:
import json
from collections import Counter

# Assumes the tokenizer has already been initialised
# tokenizer = init_tokenizer()  # If not, run the earlier cells first

# 1. Load the processed annotation file
processed_file_path = "/home/ghan/R2Gen/data/iu_xray/annotation_label_with_tokens.json"

print("Loading processed data...")
with open(processed_file_path, 'r') as f:
    processed_data = json.load(f)

# 2. Collect every token (lower-cased) and de-duplicate
all_tokens = set()
token_counter = Counter()
token_examples = {}  # first report in which each token was seen

print("Collecting tokens from processed data...")
for split_name, split_data in processed_data.items():
    print(f"Processing {split_name} split...")
    for sample in split_data:
        if 'tokens' not in sample:
            continue

        # Make sure every token is lower-case
        lowercase_tokens = [token.lower() for token in sample['tokens']]

        # Count every occurrence of every token
        token_counter.update(lowercase_tokens)

        # Remember the first report each token appeared in
        if 'report' in sample:
            for token in lowercase_tokens:
                token_examples.setdefault(token, sample['report'])

        # Add the lower-cased tokens to the de-duplicated set
        all_tokens.update(lowercase_tokens)

# 3. Check which tokens are in the tokenizer vocabulary
missing_tokens = []
existing_tokens = []

print("Checking tokens against tokenizer vocabulary...")
# Iterate in sorted order so the lists (and everything derived from them,
# including the saved JSON) do not depend on set iteration order.
for token in sorted(all_tokens):
    if token in tokenizer.token2idx:
        existing_tokens.append(token)
    else:
        missing_tokens.append(token)

# 4. Report the statistics
# Guard the denominator: with no collected tokens the percentages are 0.00
# instead of raising ZeroDivisionError.
percent_base = max(len(all_tokens), 1)
print("\n===== Token Analysis =====")
print(f"Total unique tokens from RadGraph (lowercase): {len(all_tokens)}")
print(f"Tokens present in tokenizer: {len(existing_tokens)} ({len(existing_tokens)/percent_base*100:.2f}%)")
print(f"Tokens missing from tokenizer: {len(missing_tokens)} ({len(missing_tokens)/percent_base*100:.2f}%)")

# 5. Show the most frequent missing tokens
if missing_tokens:
    print("\n===== Most Frequent Missing Tokens =====")
    most_frequent_missing = [(token, token_counter[token]) for token in missing_tokens]
    # Stable sort: ties keep the alphabetical order established above.
    most_frequent_missing.sort(key=lambda x: x[1], reverse=True)

    # Show the 20 most frequent missing tokens with an example report
    for token, count in most_frequent_missing[:20]:
        example = token_examples.get(token, "N/A")
        print(f"'{token}' - appears {count} times")
        print(f"Example context: '{example[:100]}...'")  # first 100 characters only
        print("-"*50)

# 6. Show tokens that are in the tokenizer but were not extracted by RadGraph
tokenizer_only_tokens = set(tokenizer.token2idx.keys()) - all_tokens
if len(tokenizer_only_tokens) > 0:
    print(f"\nTokens only in tokenizer (not in RadGraph): {len(tokenizer_only_tokens)}")
    print("Examples:", sorted(tokenizer_only_tokens)[:10])  # show 10 examples

# 7. Save the analysis results for further study
output_analysis_path = "/home/ghan/R2Gen/data/iu_xray/token_analysis.json"
analysis_results = {
    "total_radgraph_tokens": len(all_tokens),
    "tokens_in_tokenizer": len(existing_tokens),
    "tokens_missing_from_tokenizer": len(missing_tokens),
    "missing_tokens_with_frequency": {token: token_counter[token] for token in missing_tokens},
    "tokenizer_only_tokens": sorted(tokenizer_only_tokens)
}

print(f"\nSaving analysis results to {output_analysis_path}")
with open(output_analysis_path, 'w') as f:
    json.dump(analysis_results, f, indent=2)

print("Analysis complete!")

Loading processed data...
Collecting tokens from processed data...
Processing train split...
Processing val split...
Processing test split...
Checking tokens against tokenizer vocabulary...

===== Token Analysis =====
Total unique tokens from RadGraph (lowercase): 1141
Tokens present in tokenizer: 595 (52.15%)
Tokens missing from tokenizer: 546 (47.85%)

===== Most Frequent Missing Tokens =====
'-' - appears 174 times
Example context: 'The cardiomediastinal silhouette is normal in size and contour. There are a few XXXX opacities in th...'
--------------------------------------------------
'sided' - appears 58 times
Example context: 'The heart size is moderate to severely enlarged. There is prominence of the central pulmonary XXXX s...'
--------------------------------------------------
't' - appears 28 times
Example context: 'The aortic XXXX is mildly tortuous. The cardiomediastinal silhouette and pulmonary vasculature are w...'
--------------------------------------------------
'indet

In [40]:
import json
from tqdm import tqdm

# Assumes the tokenizer has already been initialised
# tokenizer = init_tokenizer()  # If not, run the earlier cells first

def filter_tokens_by_tokenizer(input_path, output_path, tokenizer):
    """Lower-case each sample's ``tokens`` and keep only those in the tokenizer vocabulary.

    Reads the JSON file at ``input_path`` (a mapping of split name to a list
    of samples), rewrites the ``tokens`` list of every sample that has one,
    prints summary statistics and writes the result to ``output_path``.
    Samples without a ``tokens`` key are left untouched.

    Args:
        input_path: path of the JSON file to read.
        output_path: path of the JSON file to write.
        tokenizer: object whose ``token2idx`` supports ``in`` membership tests.
    """
    print(f"Loading data from {input_path}...")
    with open(input_path, 'r') as f:
        data = json.load(f)

    # Running statistics
    total_tokens_before = 0
    total_tokens_after = 0
    removed_tokens_count = 0
    samples_with_no_tokens = 0

    vocabulary = tokenizer.token2idx  # loop-invariant lookup hoisted out of the loops

    print("Processing tokens...")
    # Process every split
    for split_name, split_data in data.items():
        print(f"Processing {split_name} split with {len(split_data)} samples...")

        # tqdm provides the progress bar
        for sample in tqdm(split_data, desc=f"{split_name}"):
            if 'tokens' not in sample:
                continue

            original_tokens = sample['tokens']
            total_tokens_before += len(original_tokens)

            # 1. Lower-case, 2. keep only tokens known to the tokenizer
            filtered_tokens = [
                token for token in (raw.lower() for raw in original_tokens)
                if token in vocabulary
            ]

            # Update the statistics
            total_tokens_after += len(filtered_tokens)
            removed_tokens_count += (len(original_tokens) - len(filtered_tokens))

            # Count samples whose tokens were all filtered out
            if len(filtered_tokens) == 0 and len(original_tokens) > 0:
                samples_with_no_tokens += 1

            # Store the filtered tokens back on the sample
            sample['tokens'] = filtered_tokens

    # Guard against an input with no tokens at all: previously this raised
    # ZeroDivisionError before the output file was written.
    if total_tokens_before:
        removed_percent = removed_tokens_count / total_tokens_before * 100
    else:
        removed_percent = 0.0

    # Print the statistics
    print("\n===== Filtering Results =====")
    print(f"Total tokens before filtering: {total_tokens_before}")
    print(f"Total tokens after filtering: {total_tokens_after}")
    print(f"Removed tokens: {removed_tokens_count} ({removed_percent:.2f}% of original)")
    print(f"Samples that lost all their tokens: {samples_with_no_tokens}")

    # Save the processed data
    print(f"Saving filtered data to {output_path}...")
    with open(output_path, 'w') as f:
        json.dump(data, f)

    print("Done!")

# Apply the filter function to the annotation file
input_path = "/home/ghan/R2Gen/data/iu_xray/annotation_label_with_tokens.json"
output_path = "/home/ghan/R2Gen/data/iu_xray/annotation_label_with_filtered_tokens.json"

filter_tokens_by_tokenizer(input_path, output_path, tokenizer)

Loading data from /home/ghan/R2Gen/data/iu_xray/annotation_label_with_tokens.json...
Processing tokens...
Processing train split with 2069 samples...


 ... (more hidden) ...


Processing val split with 296 samples...


 ... (more hidden) ...


Processing test split with 590 samples...


 ... (more hidden) ...


===== Filtering Results =====
Total tokens before filtering: 51709
Total tokens after filtering: 50572
Removed tokens: 1137 (2.20% of original)
Samples that lost all their tokens: 0
Saving filtered data to /home/ghan/R2Gen/data/iu_xray/annotation_label_with_filtered_tokens.json...
Done!





In [41]:
def get_args(ann_path='data/mimic_cxr/annotation_label_with_tokens.json',
             threshold=10,
             dataset_name='mimic',
             max_seq_length=200,
             seed=9233):
    """Build a simple argument object that stands in for argparse parsing.

    Called with no arguments it returns the same values as before; every
    setting can now be overridden by keyword instead of editing the function.

    Args:
        ann_path: path to the annotation JSON file.
        threshold: minimum token count stored on the returned object.
        dataset_name: dataset identifier stored on the returned object.
        max_seq_length: maximum sequence length stored on the returned object.
        seed: random seed stored on the returned object.

    Returns:
        An object exposing the five settings above as attributes.
    """
    class Args:
        def __init__(self):
            self.ann_path = ann_path
            self.threshold = threshold
            self.dataset_name = dataset_name
            self.max_seq_length = max_seq_length
            self.seed = seed
            # Add any other required parameters here.

    return Args()

# Seed the random number generators
def set_seed(seed):
    """Seed torch and NumPy; when CUDA is available also set the cuDNN flags.

    Args:
        seed: integer passed to ``torch.manual_seed`` and ``np.random.seed``.
    """
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    np.random.seed(seed)

# Main entry point
def init_tokenizer():
    """Create a ``Tokenizer`` from the settings returned by ``get_args()``.

    The random seed from those settings is applied via ``set_seed`` before
    the tokenizer is constructed.

    Returns:
        The newly constructed ``Tokenizer`` instance.
    """
    config = get_args()
    set_seed(config.seed)
    return Tokenizer(config)

# Run this in Jupyter to obtain the tokenizer
# (this assignment rebinds the notebook-level `tokenizer` name).
tokenizer = init_tokenizer()
# Print the vocabulary size
vocab_size = tokenizer.get_vocab_size()
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 4335


In [42]:
# Print the vocabulary size
vocab_size = tokenizer.get_vocab_size()
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 4335


In [43]:
import json
from tqdm import tqdm

# Assumes the tokenizer has already been initialised
# tokenizer = init_tokenizer()  # If not, run the earlier cells first

def filter_tokens_by_tokenizer(input_path, output_path, tokenizer):
    """Lower-case each sample's ``tokens`` and keep only those in the tokenizer vocabulary.

    Reads the JSON file at ``input_path`` (a mapping of split name to a list
    of samples), rewrites the ``tokens`` list of every sample that has one,
    prints summary statistics and writes the result to ``output_path``.
    Samples without a ``tokens`` key are left untouched.

    Args:
        input_path: path of the JSON file to read.
        output_path: path of the JSON file to write.
        tokenizer: object whose ``token2idx`` supports ``in`` membership tests.
    """
    print(f"Loading data from {input_path}...")
    with open(input_path, 'r') as f:
        data = json.load(f)

    # Running statistics
    total_tokens_before = 0
    total_tokens_after = 0
    removed_tokens_count = 0
    samples_with_no_tokens = 0

    vocabulary = tokenizer.token2idx  # loop-invariant lookup hoisted out of the loops

    print("Processing tokens...")
    # Process every split
    for split_name, split_data in data.items():
        print(f"Processing {split_name} split with {len(split_data)} samples...")

        # tqdm provides the progress bar
        for sample in tqdm(split_data, desc=f"{split_name}"):
            if 'tokens' not in sample:
                continue

            original_tokens = sample['tokens']
            total_tokens_before += len(original_tokens)

            # 1. Lower-case, 2. keep only tokens known to the tokenizer
            filtered_tokens = [
                token for token in (raw.lower() for raw in original_tokens)
                if token in vocabulary
            ]

            # Update the statistics
            total_tokens_after += len(filtered_tokens)
            removed_tokens_count += (len(original_tokens) - len(filtered_tokens))

            # Count samples whose tokens were all filtered out
            if len(filtered_tokens) == 0 and len(original_tokens) > 0:
                samples_with_no_tokens += 1

            # Store the filtered tokens back on the sample
            sample['tokens'] = filtered_tokens

    # Guard against an input with no tokens at all: previously this raised
    # ZeroDivisionError before the output file was written.
    if total_tokens_before:
        removed_percent = removed_tokens_count / total_tokens_before * 100
    else:
        removed_percent = 0.0

    # Print the statistics
    print("\n===== Filtering Results =====")
    print(f"Total tokens before filtering: {total_tokens_before}")
    print(f"Total tokens after filtering: {total_tokens_after}")
    print(f"Removed tokens: {removed_tokens_count} ({removed_percent:.2f}% of original)")
    print(f"Samples that lost all their tokens: {samples_with_no_tokens}")

    # Save the processed data
    print(f"Saving filtered data to {output_path}...")
    with open(output_path, 'w') as f:
        json.dump(data, f)

    print("Done!")

# Apply the filter function to the annotation file
input_path = "/home/ghan/R2Gen/data/mimic_cxr/annotation_label_with_tokens.json"
output_path = "/home/ghan/R2Gen/data/mimic_cxr/annotation_label_with_filtered_tokens.json"

filter_tokens_by_tokenizer(input_path, output_path, tokenizer)

Loading data from /home/ghan/R2Gen/data/mimic_cxr/annotation_label_with_tokens.json...
Processing tokens...
Processing train split with 270790 samples...


 ... (more hidden) ...


Processing val split with 2130 samples...


 ... (more hidden) ...


Processing test split with 3858 samples...


 ... (more hidden) ...



===== Filtering Results =====
Total tokens before filtering: 6305288
Total tokens after filtering: 6294404
Removed tokens: 10884 (0.17% of original)
Samples that lost all their tokens: 1
Saving filtered data to /home/ghan/R2Gen/data/mimic_cxr/annotation_label_with_filtered_tokens.json...
Done!
