In [1]:
import os, sys

# Move the working directory three levels up from where the notebook starts
# (the recorded output shows this lands on the repository root).
# NOTE(review): the path is relative to the *current* directory, so running
# this cell a second time climbs three more levels -- run it once per kernel.
os.chdir(os.path.join("..", "..", ".."))
print(os.path.abspath(os.curdir))

/data/projects/src/github.com/any35/MOS


In [8]:
from transformers import BertTokenizer, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.tokenization_utils_base import BatchEncoding

# Load the multilingual BERT tokenizer and encoder with local_files_only=True.
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
    "bert-base-multilingual-cased",
    local_files_only=True,
)
bert: BertModel = BertModel.from_pretrained(
    "bert-base-multilingual-cased",
    local_files_only=True,
)

# Tokenize one sample prompt, padded to exactly 40 positions, as PyTorch tensors.
token: BatchEncoding = tokenizer(
    "please explan EAT",
    return_tensors="pt",
    max_length=40,
    padding="max_length",
)

# Run the encoder and show the hidden-state shape plus the tokenizer's output keys
# (the recorded output is [1, 40, 768] and input_ids / token_type_ids / attention_mask).
result: BaseModelOutputWithPoolingAndCrossAttentions = bert(**token)
print(result.last_hidden_state.shape, token.keys())

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 40, 768]) dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


In [6]:
# Import only the helper this cell uses instead of `import *`, so the notebook
# namespace is not silently filled with (or overwritten by) that module's names.
from mos.models.sam.modeling_sam.embedding.cmri_cls2text_embedding import get_cls_text_embedding

# Shape check of the embedding returned for the argument 2 (presumably a class id --
# TODO confirm); the recorded output is torch.Size([1, 3, 768]).
print(get_cls_text_embedding(2).shape)

torch.Size([1, 3, 768])


In [2]:
# make token dataset
# Required dataset layout:
# *.pt {
#   "image": torch.tensor, # [N, H, W]
#   "mae_pair": torch.tensor, # [N, 7] (bs, [src_image_index, token_index, target_image_index*5])
#   "mae_pair_simple": torch.tensor, # [N, 3] (bs, [src_image_index, token_index, target_image_index])
# }
# token.pt {
#   "token_list": torch.tensor, # [N, seq_len, hidden_size]
#   "token_selector": torch.tensor, # [N_l, 2] (bs, [offset, len])
# }
import torch, glob, json
from mos.models.sam.modeling_sam.embedding.cmri_cls2text_embedding import cls2index_key, make_cls_text_compose
from mos.models.sam.modeling_sam.embedding.text_embedding import text2tensor

# Every per-dataset output file of this cell is written below this directory.
os.makedirs(".cache/dataset/text-mae-sam-dataset/split", exist_ok=True)

exclude_files = {"CMRI.pt"}

file_list = sorted(glob.glob(".cache/dataset/mmae-dataset/split-1.5mm-segment-only/*.pt"))

# State shared by all input files:
#   token_index_dict    : class-combination key -> row index in token_selector_list
#   all_text_list       : flat list of every generated prompt text
#   token_selector_list : one [offset, length] row per key, pointing into all_text_list
token_index_dict = {}
all_text_list = []
token_selector_list = []

total_image_count = 0
total_pair_count = 0

for file in file_list:
    file_name = file.split("/")[-1]
    if file_name in exclude_files:
        continue

    data = torch.load(file)
    image = data["image"]

    train_pair = []  # (bs,[src_index, token_index, target_index*5])
    mae_pair_simple_list = []  # (bs,[src_index, token_index, target_index])

    if "all_merged_pair" not in data:
        # No merged pairs stored: promote every simple pair to the 12-column merged
        # layout, filling the unused class slots with 30 and the unused targets with -1.
        pairs = data["pair"].tolist()
        all_merged_pair = [
            [src_modality, src_index, target_modality, 30, 30, 30, 30, target_index, -1, -1, -1, -1]
            for src_modality, src_index, target_modality, target_index in pairs
        ]
    else:
        all_merged_pair = data["all_merged_pair"].tolist()

    # Single-target pairs.
    for src_modality, src_index, target_modality, target_index in data["pair"].tolist():
        key = cls2index_key([target_modality, 30, 30, 30, 30])
        token_index = token_index_dict.get(key)
        if token_index is None:
            # First time this class combination shows up: register its prompt texts.
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([target_modality, 30, 30, 30, 30])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)

        mae_pair_simple_list.append([src_index, token_index, target_index])

    # Multi-target (merged) pairs.
    for merged_pair in all_merged_pair:
        (
            src_modality,
            src_index,
            attr1,
            attr2,
            attr3,
            attr4,
            attr5,
            idx1,
            idx2,
            idx3,
            idx4,
            idx5,
        ) = merged_pair

        key = cls2index_key([attr1, attr2, attr3, attr4, attr5])
        token_index = token_index_dict.get(key)
        if token_index is None:
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([attr1, attr2, attr3, attr4, attr5])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)
        train_pair.append([src_index, token_index, idx1, idx2, idx3, idx4, idx5])

    torch.save(
        {
            "image": image,
            "mae_pair": torch.tensor(train_pair),
            "mae_pair_simple": torch.tensor(mae_pair_simple_list),
        },
        f".cache/dataset/text-mae-sam-dataset/split/{file_name}",
    )
    total_image_count += image.shape[0]
    total_pair_count += len(train_pair)
    print(f"processed {file_name}...\t {len(train_pair)} pairs")

# todo, print max seq_len
print(f"total {total_image_count} train images, {total_pair_count} pairs!")

print(f"cls text count: {len(all_text_list)}")
# print("\n".join(all_text_list))

# Keep a human-readable copy of every prompt text. The context manager makes sure
# the file is flushed and closed instead of leaving that to garbage collection.
with open(".cache/dataset/text-mae-sam-dataset/cls_text.json", "w") as cls_text_file:
    json.dump(all_text_list, cls_text_file, indent=2)

# Encode every prompt with a sequence length of 40 and stack the results along dim 0.
# The assert below pins the resulting shape to (N, 40, 768).
token_list = torch.cat([text2tensor(txt, 40).detach() for txt in all_text_list], dim=0).float()

# One [offset, length] row per class combination, pointing into token_list.
token_selector_list = torch.tensor(token_selector_list, dtype=torch.int32)

assert len(token_selector_list.shape) == 2
assert len(token_list.shape) == 3 and token_list.shape[1:] == (40, 768)

torch.save(
    {
        "token_list": token_list,
        "token_selector": token_selector_list,
    },
    ".cache/dataset/text-mae-sam-dataset/split/token.pt",
)

processed ACDC.pt...	 18006 pairs
processed CMRI-0.pt...	 149 pairs
processed CMRI-1.pt...	 136 pairs
processed CMRI-2.pt...	 154 pairs
processed CMRI-3.pt...	 150 pairs
processed CMRI-4.pt...	 152 pairs
processed CMRI-5.pt...	 162 pairs
processed CMRI-6.pt...	 145 pairs
processed CMRI-7.pt...	 162 pairs
processed CMRI-8.pt...	 146 pairs
processed CMRI-9.pt...	 145 pairs
processed CT_MR_2D_Dataset_DA.pt...	 48528 pairs
processed EMIDEC.pt...	 2124 pairs
processed HVSMR2016.pt...	 5269 pairs
processed LeftAtrialSegmentationChallenge2013.pt...	 2963 pairs
processed LeftAtrialSegmentationKaggle.pt...	 1351 pairs
processed MMWHS2017.pt...	 185103 pairs
processed MyoPS2020.pt...	 1968 pairs
processed VarDA.pt...	 6107 pairs
processed mnms.pt...	 35452 pairs
processed msd-seg.pt...	 1351 pairs
processed yorku_CardiacMRIDataset.pt...	 15033 pairs
total 113751 train images, 324756 pairs!
cls text count: 6297


In [3]:
# Assemble the combined text-mae-sam-dataset (all labels)

import torch, glob

# Files that are not merged in the first pass: the token table and the ten CMRI
# splits, which are handled per cross-validation partition further down.
exclude_files = set(["CMRI.pt", "token.pt"])
for i in range(10):
    exclude_files.add(f"CMRI-{i}.pt")


file_list = glob.glob(".cache/dataset/text-mae-sam-dataset/split/*.pt")
file_list.sort()

# Global image index 0 is an all-zero image; target slots stored as -1 are
# redirected to it below (-1 + 1 == 0).
image_list = [
    torch.zeros(1, 214, 214, dtype=torch.uint8),  # zero padding
]
train_pair = []
mae_pair_simple_list = []

# `label` is the running global image offset: the index the next file's image 0
# will get in the concatenated image tensor (starts at 1 because of the padding image).
label = 1
for file in file_list:
    file_name = file.split("/")[-1]
    if file_name in exclude_files:
        continue

    data = torch.load(file)
    image, mea_pair, mea_pair_simple = data["image"], data["mae_pair"], data["mae_pair_simple"]
    # (bs, [src_image_index, token_index, target_image_index*5])
    # Build a per-element offset: +label for real image indices, +1 for -1
    # placeholders, and 0 for the token_index column, which must stay unchanged.
    offset = mea_pair.clone()
    offset[offset >= 0] = label
    offset[offset == -1] = 1
    offset[:, 1] = 0
    mea_pair += offset  # in place; `data` is freshly loaded, so nothing else sees it
    train_pair.append(mea_pair)

    # (bs, [src_image_index, token_index, target_image_index])
    # Same shift for the single-target pairs.
    offset = mea_pair_simple.clone()
    offset[offset >= 0] = label
    offset[offset == -1] = 1
    offset[:, 1] = 0
    mea_pair_simple += offset
    mae_pair_simple_list.append(mea_pair_simple)

    image_list.append(image)
    label += image.shape[0]

image_list = torch.cat(image_list, dim=0).to(torch.uint8)
train_pair = torch.cat(train_pair, dim=0)
mae_pair_simple_list = torch.cat(mae_pair_simple_list, dim=0)

# load token
data = torch.load(".cache/dataset/text-mae-sam-dataset/split/token.pt")
token_list, token_selector_list = data["token_list"], data["token_selector"]
# NOTE(review): partition_index is not read anywhere in this cell.
partition_index = label

# Split the dataset for cross-validation: CMRI-{partition} is the validation fold,
# the other nine CMRI splits are added to the training pairs.
for partition in range(10):
    train_image_list = [image_list]
    train_pair_list = [train_pair]
    train_pair_simple_list = [mae_pair_simple_list]

    train_cmri_only_list = []

    # CMRI images are appended after all public images, so their offset starts at `label`.
    image_index = label

    valid_pair_list = []
    valid_pair_simple_list = []
    for i in range(10):
        file = f".cache/dataset/text-mae-sam-dataset/split/CMRI-{i}.pt"
        # Reloaded for every partition because the pairs are shifted in place below.
        data = torch.load(file)
        image, mea_pair, mea_pair_simple = data["image"], data["mae_pair"], data["mae_pair_simple"]

        is_valid = i == partition

        # (bs, [src_image_index, token_index, target_image_index*5])
        offset = mea_pair.clone()
        offset[offset >= 0] = image_index
        offset[:, 1] = 0
        offset[offset == -1] = 1
        mea_pair += offset
        if is_valid:
            valid_pair_list.append(mea_pair)
        else:
            for _ in range(10):  # oversample the CMRI pairs 10x
                train_pair_list.append(mea_pair)

        # (bs, [src_image_index, token_index, target_image_index])
        offset = mea_pair_simple.clone()
        offset[offset >= 0] = image_index
        offset[:, 1] = 0
        offset[offset == -1] = 1
        mea_pair_simple += offset
        if is_valid:
            valid_pair_simple_list.append(mea_pair_simple)
        else:
            train_cmri_only_list.append(mea_pair_simple)
            for _ in range(10):  # oversample the CMRI pairs 10x
                train_pair_simple_list.append(mea_pair_simple)

        # Images of every fold (including the validation fold) go into the shared
        # image tensor, because both train and valid pairs index into it.
        train_image_list.append(image)
        image_index += image.shape[0]

    train_image_list = torch.cat(train_image_list, dim=0)
    train_pair_list = torch.cat(train_pair_list, dim=0)
    train_pair_simple_list = torch.cat(train_pair_simple_list, dim=0)
    train_cmri_only_list = torch.cat(train_cmri_only_list, dim=0)
    valid_pair_list = torch.cat(valid_pair_list, dim=0)
    valid_pair_simple_list = torch.cat(valid_pair_simple_list, dim=0)

    torch.save(
        {
            "image": train_image_list.to(torch.uint8),
            "token_list": token_list.float(),
            "token_selector_list": token_selector_list.int(),
            "train_mae_pair": train_pair_list.int(),
            "train_mae_pair_simple": train_pair_simple_list.int(),
            "train_cmri_pair_simple": train_cmri_only_list.int(),
            "valid_mae_pair": valid_pair_list.int(),
            "valid_mae_pair_simple": valid_pair_simple_list.int(),
        },
        f".cache/dataset/text-mae-sam-dataset/dataset-all-label-{partition}.pt",
    )
    print(f"processed dataset-all-label-{partition}.pt...")

processed dataset-0.pt...
processed dataset-1.pt...
processed dataset-2.pt...
processed dataset-3.pt...
processed dataset-4.pt...
processed dataset-5.pt...
processed dataset-6.pt...
processed dataset-7.pt...
processed dataset-8.pt...
processed dataset-9.pt...


In [2]:
# make token dataset with label filter (keep only labels related to / adjacent to fat, and also add the unlabeled images)
# Required dataset layout:
# *.pt {
#   "image": torch.tensor, # [N, H, W]
#   "mae_pair": torch.tensor, # [N, 7] (bs, [src_image_index, token_index, target_image_index*5])
#   "mae_pair_simple": torch.tensor, # [N, 3] (bs, [src_image_index, token_index, target_image_index])
# }
# token.pt {
#   "token_list": torch.tensor, # [N, seq_len, hidden_size]
#   "token_selector": torch.tensor, # [N_l, 2] (bs, [offset, len])
# }
import torch, glob, json
from mos.models.sam.modeling_sam.embedding.cmri_cls2text_embedding import cls2index_key, make_cls_text_compose
from mos.models.sam.modeling_sam.embedding.text_embedding import text2tensor

# Destination of the label-filtered split files.
export_dir = ".cache/dataset/text-mae-sam-dataset/split-filter-label"

os.makedirs(export_dir, exist_ok=True)

exclude_files = {"CMRI.pt"}

file_list = sorted(glob.glob(".cache/dataset/mmae-dataset/split-1.5mm-segment-only/*.pt"))

# Target labels that are kept, and source image types that are kept.
include_label = {11, 12, 13, 20}  # LV,RV,MYO,EAT
include_image_type = {0, 1, 2, 3, 4, 5, 7}  # excludes CT and ultrasound images


def is_any_included(label_list):
    """Return True when at least one entry of `label_list` is in the module-level `include_label` set."""
    return any(label in include_label for label in label_list)


# State shared by all input files:
#   token_index_dict    : class-combination key -> row index in token_selector_list
#   all_text_list       : flat list of every generated prompt text
#   token_selector_list : one [offset, length] row per key, pointing into all_text_list
token_index_dict = dict()
all_text_list = []
token_selector_list = []

total_image_count = 0
total_pair_count = 0


def _remap_image_index(index, mapping, kept):
    """Return the compacted position of image `index`, registering it on first use.

    `mapping` maps an original image index to its position in `kept`; `kept` lists
    the original indices in the order they will appear in the filtered image tensor.
    """
    if index not in mapping:
        mapping[index] = len(kept)
        kept.append(index)
    return mapping[index]


for file in file_list:
    file_name = file.split("/")[-1]
    if file_name in exclude_files:
        continue

    data = torch.load(file)

    # Only images referenced by a kept pair are exported; these two structures
    # translate original image indices to positions in the exported tensor.
    new_image_index = []
    new_image_mapping = dict()

    train_pair = []  # (bs,[src_index, token_index, target_index*5])
    mae_pair_simple_list = []  # (bs,[src_index, token_index, target_index])

    image = data["image"]

    if "all_merged_pair" in data:
        all_merged_pair = data["all_merged_pair"].tolist()
    else:
        # No merged pairs stored: promote every simple pair to the 12-column merged
        # layout, filling the unused class slots with 30 and the unused targets with -1.
        pairs = data["pair"].tolist()
        all_merged_pair = [
            [src_modality, src_index, target_modality, 30, 30, 30, 30, target_index, -1, -1, -1, -1]
            for src_modality, src_index, target_modality, target_index in pairs
        ]

    # Single-target pairs: keep a pair only when both the source image type and
    # the target label are whitelisted.
    for pair in data["pair"].tolist():
        src_modality, src_index, target_modality, target_index = pair

        if src_modality not in include_image_type:
            continue

        if target_modality not in include_label:
            continue

        src_index = _remap_image_index(src_index, new_image_mapping, new_image_index)
        target_index = _remap_image_index(target_index, new_image_mapping, new_image_index)

        key = cls2index_key([target_modality, 30, 30, 30, 30])
        if key in token_index_dict:
            token_index = token_index_dict[key]
        else:
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([target_modality, 30, 30, 30, 30])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)

        mae_pair_simple_list.append([src_index, token_index, target_index])

    # Multi-target (merged) pairs: keep a pair when the source image type is
    # whitelisted and at least one of its target labels is whitelisted.
    for merged_pair in all_merged_pair:
        src_modality, src_index, attr1, attr2, attr3, attr4, attr5, idx1, idx2, idx3, idx4, idx5 = merged_pair

        if src_modality not in include_image_type:
            continue

        target_modality_list = [attr1, attr2, attr3, attr4, attr5]
        if not is_any_included(target_modality_list):
            continue

        # FIX: the source index has to be translated to the compacted numbering as
        # well. The images are re-indexed by `image[new_image_index]` below, so an
        # untranslated src_index would point at the wrong (or a missing) image.
        src_index = _remap_image_index(src_index, new_image_mapping, new_image_index)

        # Translate the real targets; negative entries are "no target" placeholders
        # and are passed through unchanged.
        index_list = [
            _remap_image_index(index, new_image_mapping, new_image_index) if index >= 0 else index
            for index in (idx1, idx2, idx3, idx4, idx5)
        ]

        key = cls2index_key([attr1, attr2, attr3, attr4, attr5])
        if key in token_index_dict:
            token_index = token_index_dict[key]
        else:
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([attr1, attr2, attr3, attr4, attr5])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)
        train_pair.append([src_index, token_index, *index_list])

    # Nothing survived the filter for this file: do not write an output file.
    if len(new_image_index) == 0:
        continue
    image = image[new_image_index]

    torch.save(
        {
            "image": image,
            "mae_pair": torch.tensor(train_pair),
            "mae_pair_simple": torch.tensor(mae_pair_simple_list),
        },
        f"{export_dir}/{file_name}",
    )
    print(f"processed {file_name}...\t {len(train_pair)} pairs")
    total_image_count += image.shape[0]
    total_pair_count += len(train_pair)

# Add the CMRI image-only splits. No label filter is applied here; only the
# source image type is checked.
exclude_files = set(["CMRI-private.pt"])
file_list = glob.glob(".cache/dataset/mmae-dataset/split-1.5mm-image-only/CMRI-*.pt")
file_list.sort()
for file in file_list:
    file_name = file.split("/")[-1]
    if file_name in exclude_files:
        continue

    data = torch.load(file)

    # Original image index -> position in the exported (compacted) image tensor.
    new_image_index = []
    new_image_mapping = dict()

    image_pair_simple_list = []  # (bs,[src_index, token_index, target_index])

    image = data["image"]

    pairs = data["pair"].tolist()

    for pair in pairs:
        src_modality, src_index, target_modality, target_index = pair

        if src_modality not in include_image_type:
            continue

        if src_index not in new_image_mapping:
            new_image_mapping[src_index] = len(new_image_index)
            new_image_index.append(src_index)
        src_index = new_image_mapping[src_index]

        if target_index not in new_image_mapping:
            new_image_mapping[target_index] = len(new_image_index)
            new_image_index.append(target_index)
        target_index = new_image_mapping[target_index]

        key = cls2index_key([target_modality, 30, 30, 30, 30])
        if key in token_index_dict:
            token_index = token_index_dict[key]
        else:
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([target_modality, 30, 30, 30, 30])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)

        image_pair_simple_list.append([src_index, token_index, target_index])

    # Nothing survived the filter for this file: do not write an output file.
    if len(new_image_index) == 0:
        continue

    image = image[new_image_index]

    # Written next to the labelled split of the same name, with an "-image" suffix.
    file_name = file_name.replace(".pt", "-image.pt")

    torch.save(
        {
            "image": image,
            "image_pair_simple": torch.tensor(image_pair_simple_list),
        },
        f"{export_dir}/{file_name}",
    )
    print(f"processed {file_name}...\t image {len(image_pair_simple_list)} pairs")
    total_image_count += image.shape[0]
    # FIX: count the pairs produced by THIS loop. The previous code added
    # len(train_pair), a stale list left over from the labelled-split loop.
    total_pair_count += len(image_pair_simple_list)

# todo, print max seq_len
print(f"total {total_image_count} train images, {total_pair_count} pairs!")

print(f"cls text count: {len(all_text_list)}")
# print("\n".join(all_text_list))

# Keep a human-readable copy of every prompt text. The context manager makes sure
# the file is flushed and closed instead of leaving that to garbage collection.
with open(".cache/dataset/text-mae-sam-dataset/cls_text-filter-label.json", "w") as cls_text_file:
    json.dump(all_text_list, cls_text_file, indent=2)

# Encode every prompt with a sequence length of 40 and stack the results along dim 0.
# The assert below pins the resulting shape to (N, 40, 768).
token_list = torch.cat([text2tensor(txt, 40).detach() for txt in all_text_list], dim=0).float()

# One [offset, length] row per class combination, pointing into token_list.
token_selector_list = torch.tensor(token_selector_list, dtype=torch.int32)

assert len(token_selector_list.shape) == 2
assert len(token_list.shape) == 3 and token_list.shape[1:] == (40, 768)

torch.save(
    {
        "token_list": token_list,
        "token_selector": token_selector_list,
    },
    f"{export_dir}/token.pt",
)

processed ACDC.pt...	 18006 pairs
processed CMRI-0.pt...	 149 pairs
processed CMRI-1.pt...	 136 pairs
processed CMRI-2.pt...	 154 pairs
processed CMRI-3.pt...	 150 pairs
processed CMRI-4.pt...	 152 pairs
processed CMRI-5.pt...	 162 pairs
processed CMRI-6.pt...	 145 pairs
processed CMRI-7.pt...	 162 pairs
processed CMRI-8.pt...	 146 pairs
processed CMRI-9.pt...	 145 pairs
processed CT_MR_2D_Dataset_DA.pt...	 20594 pairs
processed EMIDEC.pt...	 2124 pairs
processed HVSMR2016.pt...	 3158 pairs
processed MMWHS2017.pt...	 52218 pairs
processed MyoPS2020.pt...	 1968 pairs
processed VarDA.pt...	 6107 pairs
processed mnms.pt...	 35452 pairs
processed yorku_CardiacMRIDataset.pt...	 15033 pairs
processed CMRI-0-image.pt...	 image 3090 pairs
processed CMRI-1-image.pt...	 image 2340 pairs
processed CMRI-2-image.pt...	 image 3000 pairs
processed CMRI-3-image.pt...	 image 3210 pairs
processed CMRI-4-image.pt...	 image 2970 pairs
processed CMRI-5-image.pt...	 image 2850 pairs
processed CMRI-6-image.p

In [2]:
# Assemble the combined text-mae-sam-dataset (keep only labels related to / adjacent to fat)

import torch, glob

# Files that are not merged in the first pass: the token table and the CMRI splits
# (labelled and image-only), which are handled per cross-validation partition below.
exclude_files = set(["CMRI.pt", "token.pt"])
for i in range(10):
    exclude_files.add(f"CMRI-{i}.pt")
    exclude_files.add(f"CMRI-{i}-image.pt")

src_dir = ".cache/dataset/text-mae-sam-dataset/split-filter-label"
export_dir = ".cache/dataset/text-mae-sam-dataset"

file_list = glob.glob(f"{src_dir}/*.pt")
file_list.sort()

# Global image index 0 is an all-zero image; target slots stored as -1 are
# redirected to it below (-1 + 1 == 0).
image_list = [
    torch.zeros(1, 214, 214, dtype=torch.uint8),  # zero padding
]
train_pair = []
mae_pair_simple_list = []

# `label` is the running global image offset: the index the next file's image 0
# will get in the concatenated image tensor (starts at 1 because of the padding image).
label = 1
for file in file_list:
    file_name = file.split("/")[-1]
    if file_name in exclude_files:
        continue

    print(file)
    data = torch.load(file)
    image, mea_pair, mea_pair_simple = data["image"], data["mae_pair"], data["mae_pair_simple"]
    # (bs, [src_image_index, token_index, target_image_index*5])
    # Build a per-element offset: +label for real image indices, +1 for -1
    # placeholders, and 0 for the token_index column, which must stay unchanged.
    offset = mea_pair.clone()
    offset[offset >= 0] = label
    offset[offset == -1] = 1
    offset[:, 1] = 0
    mea_pair += offset  # in place; `data` is freshly loaded, so nothing else sees it
    train_pair.append(mea_pair)

    # (bs, [src_image_index, token_index, target_image_index])
    # Same shift for the single-target pairs.
    offset = mea_pair_simple.clone()
    offset[offset >= 0] = label
    offset[offset == -1] = 1
    offset[:, 1] = 0
    mea_pair_simple += offset
    mae_pair_simple_list.append(mea_pair_simple)

    image_list.append(image)
    label += image.shape[0]

image_list = torch.cat(image_list, dim=0).to(torch.uint8)
train_pair = torch.cat(train_pair, dim=0)
mae_pair_simple_list = torch.cat(mae_pair_simple_list, dim=0)

# load token
data = torch.load(f"{src_dir}/token.pt")
token_list, token_selector_list = data["token_list"], data["token_selector"]
# NOTE(review): partition_index is not read anywhere in this cell.
partition_index = label

# Split the dataset for cross-validation: CMRI-{partition} is the validation fold,
# the other nine CMRI splits are added to the training pairs.
for partition in range(10):
    train_image_list = [image_list]
    train_pair_list = [train_pair]
    train_pair_simple_list = [mae_pair_simple_list]

    train_cmri_only_list = []

    # CMRI images are appended after all public images, so their offset starts at `label`.
    image_index = label

    valid_pair_list = []
    valid_pair_simple_list = []
    # Pass 1: labelled CMRI splits.
    for i in range(10):
        file = f"{src_dir}/CMRI-{i}.pt"
        # Reloaded for every partition because the pairs are shifted in place below.
        data = torch.load(file)
        image, mea_pair, mea_pair_simple = data["image"], data["mae_pair"], data["mae_pair_simple"]

        is_valid = i == partition

        # (bs, [src_image_index, token_index, target_image_index*5])
        offset = mea_pair.clone()
        offset[offset >= 0] = image_index
        offset[:, 1] = 0
        offset[offset == -1] = 1
        mea_pair += offset
        if is_valid:
            valid_pair_list.append(mea_pair)
        else:
            for _ in range(10):  # oversample the CMRI pairs 10x
                train_pair_list.append(mea_pair)

        # (bs, [src_image_index, token_index, target_image_index])
        offset = mea_pair_simple.clone()
        offset[offset >= 0] = image_index
        offset[:, 1] = 0
        offset[offset == -1] = 1
        mea_pair_simple += offset
        if is_valid:
            valid_pair_simple_list.append(mea_pair_simple)
        else:
            for _ in range(15):  # oversample the CMRI pairs 15x
                train_cmri_only_list.append(mea_pair_simple)
            for _ in range(10):  # oversample the CMRI pairs 10x
                train_pair_simple_list.append(mea_pair_simple)

        # Images of every fold (including the validation fold) go into the shared
        # image tensor, because both train and valid pairs index into it.
        train_image_list.append(image)
        image_index += image.shape[0]

    # Pass 2: CMRI image-only splits, appended after the labelled CMRI images.
    for i in range(10):
        file = f"{src_dir}/CMRI-{i}-image.pt"
        data = torch.load(file)
        image, mea_pair_simple = data["image"], data["image_pair_simple"]

        is_valid = i == partition

        # (bs, [src_image_index, token_index, target_image_index])
        offset = mea_pair_simple.clone()
        offset[offset >= 0] = image_index
        offset[:, 1] = 0
        offset[offset == -1] = 1
        mea_pair_simple += offset
        if is_valid:
            # The image-only pairs of the validation fold are neither trained on
            # nor validated on.
            # valid_pair_simple_list.append(mea_pair_simple)
            pass
        else:
            # Added once (no oversampling) to both simple training lists.
            train_cmri_only_list.append(mea_pair_simple)
            train_pair_simple_list.append(mea_pair_simple)

        train_image_list.append(image)
        image_index += image.shape[0]

    train_image_list = torch.cat(train_image_list, dim=0)
    train_pair_list = torch.cat(train_pair_list, dim=0)
    train_pair_simple_list = torch.cat(train_pair_simple_list, dim=0)
    train_cmri_only_list = torch.cat(train_cmri_only_list, dim=0)
    valid_pair_list = torch.cat(valid_pair_list, dim=0)
    valid_pair_simple_list = torch.cat(valid_pair_simple_list, dim=0)

    torch.save(
        {
            "image": train_image_list.to(torch.uint8),
            "token_list": token_list.float(),
            "token_selector_list": token_selector_list.int(),
            "train_mae_pair": train_pair_list.int(),
            "train_mae_pair_simple": train_pair_simple_list.int(),
            "train_cmri_pair_simple": train_cmri_only_list.int(),
            "valid_mae_pair": valid_pair_list.int(),
            "valid_mae_pair_simple": valid_pair_simple_list.int(),
        },
        f"{export_dir}/dataset-filter-label-{partition}.pt",
    )
    print(f"processed dataset-filter-label-{partition}.pt...")

.cache/dataset/text-mae-sam-dataset/split-filter-label/ACDC.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/CT_MR_2D_Dataset_DA.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/EMIDEC.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/HVSMR2016.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/MMWHS2017.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/MyoPS2020.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/VarDA.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/mnms.pt
.cache/dataset/text-mae-sam-dataset/split-filter-label/yorku_CardiacMRIDataset.pt
processed dataset-filter-label-0.pt...
processed dataset-filter-label-1.pt...
processed dataset-filter-label-2.pt...
processed dataset-filter-label-3.pt...
processed dataset-filter-label-4.pt...
processed dataset-filter-label-5.pt...
processed dataset-filter-label-6.pt...
processed dataset-filter-label-7.pt...
processed dataset-filter-label-8.pt...
processed dataset-filter-label

In [3]:
# show dataset
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact

# Load partition 0 of the label-filtered dataset and pull out the tensors used by the viewers.
file = ".cache/dataset/text-mae-sam-dataset/dataset-filter-label-0.pt"
data = torch.load(file)

all_image = data["image"]  # (bs, h, w)
train_mae_pair = data["train_mae_pair"]  # (bs, [src_image_index, token_index, target_image_index*5])
valid_mae_pair = data["valid_mae_pair"]
train_mae_pair_simple = data["train_mae_pair_simple"]  # (bs, [src_image_index, token_index, target_image_index])
train_cmri_pair_simple = data["train_cmri_pair_simple"]
valid_mae_pair_simple = data["valid_mae_pair_simple"]


def draw_image(pair):
    """Build an interactive viewer for multi-target pairs.

    Args:
        pair: 2-D tensor whose rows are
            [src_image_index, token_index, idx1, idx2, idx3, idx4, idx5];
            the image indices refer to the module-level `all_image` tensor.

    Returns:
        The `interact`-decorated inner function. A slider selects the row, and the
        source image plus its five targets are shown side by side.
    """
    pair_count = pair.shape[0] - 1

    @interact
    def _draw_image(pair_index=(0, pair_count)):
        src_image_index, token_index, idx1, idx2, idx3, idx4, idx5 = pair[pair_index].tolist()

        image = all_image[[src_image_index, idx1, idx2, idx3, idx4, idx5], ::]

        # (n, h, w) -> (h, n, w) -> (h, n * w): the n images end up left to right.
        # The height is read from the tensor instead of being hard-coded to 214,
        # so the viewer also works for other image sizes.
        image = image.permute(1, 0, 2)
        image = image.reshape(image.shape[0], -1)

        plt.title([src_image_index, token_index, idx1, idx2, idx3, idx4, idx5])
        plt.imshow(image, cmap="gray")

    return _draw_image


def draw_image_simple(pair):
    """Build an interactive viewer for single-target pairs.

    Args:
        pair: 2-D tensor whose rows are [src_image_index, token_index, target_index];
            the image indices refer to the module-level `all_image` tensor.

    Returns:
        The `interact`-decorated inner function. A slider selects the row, and the
        source and target images are shown side by side.
    """
    pair_count = pair.shape[0] - 1

    @interact
    def _draw_image(pair_index=(0, pair_count)):
        src_image_index, token_index, target_index = pair[pair_index].tolist()

        image = all_image[[src_image_index, target_index], ::]

        # (n, h, w) -> (h, n, w) -> (h, n * w): the n images end up left to right.
        # The height is read from the tensor instead of being hard-coded to 214,
        # so the viewer also works for other image sizes.
        image = image.permute(1, 0, 2)
        image = image.reshape(image.shape[0], -1)

        plt.title([src_image_index, token_index, target_index])
        plt.imshow(image, cmap="gray")

    return _draw_image


# One viewer per pair table; the tuple is the cell's last expression, so its repr
# is displayed below the widgets.
(
    draw_image(train_mae_pair),
    draw_image(valid_mae_pair),
    draw_image_simple(train_mae_pair_simple),
    draw_image_simple(train_cmri_pair_simple),
    draw_image_simple(valid_mae_pair_simple),
)

interactive(children=(IntSlider(value=84089, description='pair_index', max=168179), Output()), _dom_classes=('…

interactive(children=(IntSlider(value=74, description='pair_index', max=148), Output()), _dom_classes=('widget…

interactive(children=(IntSlider(value=43415, description='pair_index', max=86831), Output()), _dom_classes=('w…

interactive(children=(IntSlider(value=13855, description='pair_index', max=27711), Output()), _dom_classes=('w…

interactive(children=(IntSlider(value=74, description='pair_index', max=148), Output()), _dom_classes=('widget…

(<function __main__.draw_image.<locals>._draw_image(pair_index=(0, 168179))>,
 <function __main__.draw_image.<locals>._draw_image(pair_index=(0, 148))>,
 <function __main__.draw_image_simple.<locals>._draw_image(pair_index=(0, 86831))>,
 <function __main__.draw_image_simple.<locals>._draw_image(pair_index=(0, 27711))>,
 <function __main__.draw_image_simple.<locals>._draw_image(pair_index=(0, 148))>)

In [16]:
# Debug leftover: displays whatever `file` currently holds in the kernel.
# NOTE(review): the recorded output is a split-filter-label path that the cell
# above does not assign, so this depends on out-of-order execution -- consider removing.
file

'.cache/dataset/text-mae-sam-dataset/split-filter-label/LeftAtrialSegmentationChallenge2013.pt'

In [2]:
# Migrate the data from sam-dataset
import torch, json
from typing import Any
from mos.models.sam.modeling_sam.embedding.cmri_cls2text_embedding import cls2index_key, make_cls_text_compose
from mos.models.sam.modeling_sam.embedding.text_embedding import text2tensor


base_path = ".cache/dataset/sam-dataset"
device = "cpu"

# Metadata is read through context managers so the JSON files are closed right
# away instead of being left to the garbage collector.
with open(f"{base_path}/valid-metas.json") as meta_file:
    valid_meta: list[dict[str, Any]] = json.load(meta_file)
# Images and segments are scaled by 255 after loading.
valid_images = torch.load(f"{base_path}/valid-images.ot", map_location=device).squeeze(1) * 255  # (bs, h, w)
valid_segments = torch.load(f"{base_path}/valid-segments.ot", map_location=device) * 255  # (bs, h, w)

with open(f"{base_path}/train-metas.json") as meta_file:
    train_meta: list[dict[str, Any]] = json.load(meta_file)
train_images: torch.Tensor = torch.load(f"{base_path}/train-images.ot", map_location=device).squeeze(1) * 255
train_segments = torch.load(f"{base_path}/train-segments.ot", map_location=device) * 255

# Translate the sam-dataset class ids ("cls" in the metadata) to the label ids
# used by this dataset. The tensors are rewritten in place, so from here on they
# hold 20 / 13 / 12, not 1 / 2 / 3.
# 1: EAT => 20
# 2: MYO => 13
# 3: RV => 12
train_image_types = [d["cls"] for d in train_meta]
train_image_types = torch.tensor(train_image_types, dtype=torch.int32)
train_image_types[train_image_types == 1] = 20
train_image_types[train_image_types == 2] = 13
train_image_types[train_image_types == 3] = 12

valid_image_types = [d["cls"] for d in valid_meta]
valid_image_types = torch.tensor(valid_image_types, dtype=torch.int32)
valid_image_types[valid_image_types == 1] = 20
valid_image_types[valid_image_types == 2] = 13
valid_image_types[valid_image_types == 3] = 12

image_list = []
train_pair = []
valid_pair = []

# class-combination key -> row index in token_selector_list; all_text_list is the
# flat list of prompt texts that token_selector_list's [offset, length] rows point into.
token_index_dict = {}
all_text_list = []
token_selector_list = []

# Training samples are processed first, then validation samples; both append to
# the same image_list, so all indices refer to one shared image tensor.
for images, segments, image_types, pair_rows in (
    (train_images, train_segments, train_image_types, train_pair),
    (valid_images, valid_segments, valid_image_types, valid_pair),
):
    for image, segment, target_modality in zip(images.split(1, 0), segments.split(1, 0), image_types.tolist()):
        # Samples whose segment sums to zero or less are dropped.
        if segment.sum() <= 0:
            continue

        key = cls2index_key([target_modality, 30, 30, 30, 30])
        token_index = token_index_dict.get(key)
        if token_index is None:
            # First time this class shows up: register its prompt texts.
            token_index = len(token_selector_list)
            token_index_dict[key] = token_index
            text_list = make_cls_text_compose([target_modality, 30, 30, 30, 30])
            token_selector_list.append([len(all_text_list), len(text_list)])
            all_text_list.extend(text_list)

        # The image and its segment are stored back to back: [..., image, segment, ...].
        src_index = len(image_list)
        target_index = src_index + 1
        image_list.append(image)
        image_list.append(segment)
        pair_rows.append([src_index, token_index, target_index])

train_pair = torch.tensor(train_pair, dtype=torch.int32)
valid_pair = torch.tensor(valid_pair, dtype=torch.int32)
image_list = torch.cat(image_list, dim=0)

# Encode every prompt with a sequence length of 40 and stack the results along dim 0.
# The assert below pins the resulting shape to (N, 40, 768).
token_list = torch.cat([text2tensor(txt, 40).detach() for txt in all_text_list], dim=0).float()
token_selector_list = torch.tensor(token_selector_list, dtype=torch.int32)
assert len(token_selector_list.shape) == 2
assert len(token_list.shape) == 3 and token_list.shape[1:] == (40, 768)

# Keep a human-readable copy of the prompt texts. The context manager makes sure
# the file is flushed and closed instead of leaving that to garbage collection.
with open(".cache/dataset/text-mae-sam-dataset/dataset-reprod-0.json", "w") as text_file:
    json.dump(all_text_list, text_file, indent=2)

# Same keys as the dataset-*-{partition}.pt files written by the combine cells.
# This dataset only has single-target pairs, so the remaining pair tables are
# stored as empty tensors.
data = {
    "image": image_list,
    "train_mae_pair": torch.empty(0),
    "train_mae_pair_simple": torch.empty(0),
    "train_cmri_pair_simple": train_pair,
    "valid_mae_pair": torch.empty(0),
    "valid_mae_pair_simple": valid_pair,
    "token_list": token_list,
    "token_selector_list": token_selector_list,
}
torch.save(
    data,
    ".cache/dataset/text-mae-sam-dataset/dataset-reprod-0.pt",
)

In [19]:
def plot_segment(imgae, segment, alpha=0.4):
    """Plot segment on image.

    The docstring now comes first in the body; it used to sit after the import,
    where Python treats it as a plain string expression and `__doc__` stays None.

    Args:
        imgae: [h, w] image, drawn with the "gray" colormap. The misspelled
            parameter name is kept so existing keyword callers keep working.
        segment: [h, w] mask, drawn over the image with the "jet" colormap.
        alpha: opacity of the mask overlay.
    """
    # Imported inside the function, as in the original cell.
    import matplotlib.pyplot as plt

    plt.imshow(imgae, cmap="gray")
    plt.imshow(segment, cmap="jet", alpha=alpha)
    plt.show()


def ui_draw_segment_3d(image, segment, alpha=0.4):
    """Build an interactive slice viewer for a 3-D image with an optional mask overlay.

    Args:
        image: 3-D array-like; the first axis is the slice axis.
        segment: mask indexed the same way as `image`, or None for no overlay.
        alpha: opacity passed on to `plot_segment` for the overlay.

    Returns:
        The `interact`-decorated inner function (a checkbox toggles the overlay,
        a slider selects the slice).
    """
    import matplotlib.pyplot as plt
    from ipywidgets import interact

    # Unpacking also enforces that `image` has exactly three dimensions.
    img_slices, _, _ = image.shape

    @interact
    def _draw_segment_ui_3d(show_segment=True, slice_index=(0, img_slices - 1)):
        if show_segment is None:
            return
        if not show_segment or segment is None:
            # Overlay switched off (or no mask available): plain image only.
            plt.imshow(image[slice_index, ::], cmap="gray")
        else:
            plot_segment(image[slice_index, ::], segment[slice_index, ::], alpha)

    return _draw_segment_ui_3d


# Class ids held by `train_image_types` after the in-place remapping done when the
# sam-dataset was loaded (raw id -> remapped id):
# 1: EAT -> 20
# 2: MYO -> 13
# 3: RV  -> 12
# FIX: select RV with the remapped id 12. Comparing against the raw id 3 matches
# nothing once the remapping has run, which blanks every segment.
ui_draw_segment_3d(train_images, train_segments * (train_image_types == 12)[:, None, None])

interactive(children=(Checkbox(value=True, description='show_segment'), IntSlider(value=5810, description='sli…

<function __main__.ui_draw_segment_3d.<locals>._draw_segment_ui_3d(show_segment=True, slice_index=(0, 11620))>

True