In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import sys
import os
from datasets import load_dataset  # ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo")

In [2]:
# Load one parquet shard from the working directory into a DataFrame. The preview
# rendered below shows three columns: `image` (a dict carrying the encoded image
# under 'bytes'), `id` and `caption`.
pq = pd.read_parquet('./train-00000-of-00010.parquet')
pq.head()  # last expression of the cell, so the notebook renders the first rows

Unnamed: 0,image,id,caption
0,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,8031efe0-1b5c-11ef-8929-000066532cad,The image is a non-contrasted computed tomogra...
1,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,8031fb83-1b5c-11ef-a2c7-000066532cad,The image is a non-contrast computed tomograph...
2,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,8032083e-1b5c-11ef-bcf7-000066532cad,"The image is a CT scan of the brain, showing t..."
3,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,8031ea0d-1b5c-11ef-b7fd-000066532cad,The image is a non-contrasted computed tomogra...
4,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,8031f5b4-1b5c-11ef-9ae7-000066532cad,The image is a non-contrasted computed tomogra...


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import io

class ImageCaptionDataset(Dataset):
    """Dataset of (image tensor, caption) pairs backed by a pandas DataFrame.

    Every row must provide an ``image`` cell holding a dict with the encoded
    image under the ``'bytes'`` key, and a ``caption`` cell with its text.
    """

    def __init__(self, dataframe):
        """Store the frame and build the preprocessing pipeline.

        Args:
            dataframe: DataFrame with ``image`` and ``caption`` columns.
        """
        self.dataframe = dataframe
        # Resize to 224x224, convert to a float tensor, then normalise each of
        # the three colour channels with the given mean / std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Decode, validate and preprocess the sample at position ``idx``.

        Args:
            idx: Positional (``iloc``) index of the row.

        Returns:
            Tuple ``(image, caption)`` where ``image`` is the transformed
            3-channel tensor and ``caption`` is the row's caption value.

        Raises:
            ValueError: If every pixel value of the decoded image is zero.
        """
        row = self.dataframe.iloc[idx]  # fetch the row once instead of twice
        img_data = row['image']['bytes']
        caption = row['caption']
        image = Image.open(io.BytesIO(img_data))

        # Reject blank images (pixel sum of the decoded image is 0).
        if np.array(image).sum() == 0:
            raise ValueError(f"Image array at index {idx} is all zeros.")

        # Normalize above uses three-channel statistics, so grayscale, paletted
        # or RGBA inputs must be brought to RGB first or the transform fails.
        image = image.convert('RGB')

        image = self.transform(image)
        return image, caption

# Create the dataset from the DataFrame loaded above
dataset = ImageCaptionDataset(pq)

# Create the dataloader: batches of 16 samples, reshuffled on every pass
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Smoke test: pull a single batch, show the first image's tensor shape and its
# caption, then stop (the `break` keeps this from walking the whole dataset).
for images, captions in dataloader:
    print(images[0].shape, captions[0])
    break

torch.Size([3, 224, 224]) The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cranial cavity with brain structures. The region of interest, located centrally at the top of the image, occupies approximately 0.6% of the area and appears to have a different density compared to the surrounding brain tissue, which may indicate an abnormality such as a hemorrhage or a mass effect. This region's relative position to other brain structures suggests it could be affecting or be affected by adjacent tissues, potentially indicating a pathological process that may have implications for the patient's neurological function.


In [4]:
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
# Pick the compute device by priority: Apple MPS, then CUDA, then CPU.
# The hasattr guard covers torch builds that do not expose `torch.backends.mps`.
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


class ImageFeatureExtract(nn.Module):
    """Image encoder: a pretrained torchvision ResNet-50 without its last layer."""

    def __init__(self):
        super().__init__()
        # Load the pretrained ResNet-50 shipped with torchvision.
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Keep every top-level child except the last one (the final
        # fully-connected layer) and chain the rest together.
        self.feature_extract = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        """Encode a batch of images into flat feature vectors.

        Args:
            x: Image batch; the notebook feeds ``(batch_size, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch_size, 2048)``, on the same device as the
            module's computation.
        """
        # After global average pooling the backbone yields (batch_size, 2048, 1, 1).
        x = self.feature_extract(x)
        # Flatten to (batch_size, 2048). The result is intentionally NOT moved to
        # the global DEVICE: it already lives where the module and its input
        # live, and forcing a global device here silently copies across devices
        # whenever the module is placed elsewhere or DEVICE is redefined.
        return x.view(x.size(0), -1)


In [5]:
# Draw one (shuffled) batch from the loader.
imgs, captions = next(iter(dataloader))
print(f'imgs.shape: {imgs.shape}')

# Build the image encoder on the selected device, move the batch there as well,
# and run a single forward pass to inspect the feature shape.
img_ext = ImageFeatureExtract().to(DEVICE)
imgs = imgs.to(DEVICE)
img_features = img_ext(imgs)
print(f'img_features.shape: {img_features.shape}')

imgs.shape: torch.Size([16, 3, 224, 224])
img_features.shape: torch.Size([16, 2048])


In [6]:
# Load model directly
from transformers import AutoTokenizer, AutoModel

# Extract text features
# use BERT-base
class TextFeatureExtract(nn.Module):
    """Text encoder: ``bert-base-uncased`` tokenizer + model, token-level output."""

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = AutoModel.from_pretrained("bert-base-uncased").to(DEVICE)

    def forward(self, text):
        """Tokenize raw captions and return BERT's last hidden states.

        Args:
            text: A caption string or a batch (sequence) of caption strings.

        Returns:
            Tensor of shape ``(batch_size, seq_len, 768)``; ``seq_len`` is the
            padded token length of the batch. Padding positions are included.
        """
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Send the token tensors to wherever the encoder's weights actually are,
        # not to the global DEVICE: the two differ as soon as the module is moved
        # with `.to(...)` or DEVICE is reassigned after construction.
        inputs = inputs.to(next(self.bert.parameters()).device)
        outputs = self.bert(**inputs)
        return outputs.last_hidden_state

print(type(captions), len(captions))  # sanity check: expect a sequence of caption strings (the recorded run shows a tuple of 16)

# Build the text encoder on the selected device and encode the batch of captions.
txt_ext = TextFeatureExtract().to(DEVICE)
txt_features = txt_ext(captions)
print(f'txt_features.shape: {txt_features.shape}')

<class 'tuple'> 16
txt_features.shape: torch.Size([16, 176, 768])


In [7]:
import torch.nn.functional as F

# Contrastive Learning for mapping text with image
class MappingNetwork(nn.Module):
    """Linear projection into a shared embedding space followed by L2 normalisation."""

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim: Size of the incoming feature (last) dimension.
            output_dim: Size of the projected embedding.
        """
        super().__init__()
        self.map = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` and scale every embedding vector to unit L2 norm.

        Args:
            x: Tensor whose last dimension is ``input_dim``, e.g.
                ``(batch, input_dim)`` or ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``output_dim``, unit-normalised along that last dimension.
        """
        x = self.map(x)
        # Normalise over the feature axis. The previous `dim=1` is the same axis
        # for 2-D inputs, but for (batch, seq_len, dim) token features it
        # normalised across the sequence positions instead of each embedding.
        return F.normalize(x, p=2, dim=-1)

    
# img_features.shape: torch.Size([16, 2048])
img_mapping = MappingNetwork(2048, 1024).to(DEVICE)
# txt_features.shape: torch.Size([16, seq_len, 768]); seq_len follows the padded caption length of the batch (176 in the recorded run)
txt_mapping = MappingNetwork(768, 1024).to(DEVICE)

img_map = img_mapping(img_features)  # img_map.shape: torch.Size([16, 1024])
txt_map = txt_mapping(txt_features)  # txt_map.shape: torch.Size([16, seq_len, 1024])
# Average the token embeddings over the sequence axis, then fuse by element-wise addition.
fuse_embed = img_map + txt_map.mean(dim=1)  # fuse_embed.shape: torch.Size([16, 1024])
print(f'img_map.shape: {img_map.shape}, txt_map.shape: {txt_map.shape}, fuse_embed.shape: {fuse_embed.shape}')

img_map.shape: torch.Size([16, 1024]), txt_map.shape: torch.Size([16, 176, 1024]), fuse_embed.shape: torch.Size([16, 1024])


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import io
import numpy as np
from transformers import AutoTokenizer, AutoModel

# Device used by everything defined in this cell: CUDA when available, else CPU.
# NOTE(review): this rebinds the DEVICE chosen in an earlier cell and, unlike that
# definition, never selects 'mps' — presumably intentional; confirm before unifying.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageCaptionDataset(Dataset):
    """Dataset of (image tensor, caption) pairs backed by a pandas DataFrame.

    Every row must provide an ``image`` cell holding a dict with the encoded
    image under the ``'bytes'`` key, and a ``caption`` cell with its text.
    """

    def __init__(self, dataframe):
        """Store the frame and build the preprocessing pipeline.

        Args:
            dataframe: DataFrame with ``image`` and ``caption`` columns.
        """
        self.dataframe = dataframe
        # Resize to 224x224, convert to a float tensor, then normalise each of
        # the three colour channels with the given mean / std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Decode, validate and preprocess the sample at position ``idx``.

        Args:
            idx: Positional (``iloc``) index of the row.

        Returns:
            Tuple ``(image, caption)`` where ``image`` is the transformed
            3-channel tensor and ``caption`` is the row's caption value.

        Raises:
            ValueError: If every pixel value of the decoded image is zero.
        """
        row = self.dataframe.iloc[idx]  # fetch the row once instead of twice
        img_data = row['image']['bytes']
        caption = row['caption']
        image = Image.open(io.BytesIO(img_data))

        # Check that the image data is valid (not entirely zero).
        if np.array(image).sum() == 0:
            raise ValueError(f"Image array at index {idx} is all zeros.")

        # Normalize above uses three-channel statistics, so grayscale, paletted
        # or RGBA inputs must be brought to RGB first or the transform fails.
        image = image.convert('RGB')

        image = self.transform(image)
        return image, caption

class ImageFeatureExtract(nn.Module):
    """Image encoder: a pretrained torchvision ResNet-50 without its last layer."""

    def __init__(self):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop the last top-level child (the final fully-connected layer) and
        # chain the remaining modules into a single feature extractor.
        self.feature_extract = nn.Sequential(
            *list(resnet.children())[:-1]
        )

    def forward(self, x):
        """Encode a batch of images into flat per-image feature vectors.

        Args:
            x: Image batch; the notebook feeds ``(batch_size, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch_size, -1)`` (2048 features for the
            224x224 inputs used here), on the device the module computed on.
        """
        x = self.feature_extract(x)
        # Flatten everything after the batch axis. The result is deliberately not
        # forced onto the global DEVICE: it already sits on the module's device,
        # and a hard-coded transfer breaks when the module lives elsewhere.
        return x.view(x.size(0), -1)

class TextFeatureExtract(nn.Module):
    """Text encoder: ``bert-base-uncased`` tokenizer + model, token-level output."""

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = AutoModel.from_pretrained("bert-base-uncased").to(DEVICE)

    def forward(self, text):
        """Tokenize raw captions and return BERT's last hidden states.

        Args:
            text: A caption string or a batch (sequence) of caption strings.

        Returns:
            Tensor of shape ``(batch_size, seq_len, 768)``; ``seq_len`` is the
            padded token length of the batch. Padding positions are included.
        """
        encoded = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Follow the encoder's real device rather than the global DEVICE so the
        # module keeps working after `.to(...)` or a later DEVICE reassignment.
        encoded = encoded.to(next(self.bert.parameters()).device)
        outputs = self.bert(**encoded)
        return outputs.last_hidden_state

class MappingNetwork(nn.Module):
    """Linear projection into a shared embedding space followed by L2 normalisation."""

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim: Size of the incoming feature (last) dimension.
            output_dim: Size of the projected embedding.
        """
        super().__init__()
        self.map = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` and scale every embedding vector to unit L2 norm.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with a last dimension of ``output_dim``, unit-normalised
            along that last dimension.
        """
        projected = self.map(x)
        # L2-normalise over the feature axis. `dim=-1` equals the former `dim=1`
        # for (batch, dim) inputs and stays correct if a (batch, seq_len, dim)
        # tensor is passed, where `dim=1` would normalise across tokens.
        return F.normalize(projected, p=2, dim=-1)

class TwoPathEncoder(nn.Module):
    """Encoder with a strided-convolution path and a parallel max-pooling path.

    At every stage the convolution path halves the spatial size (3x3 conv,
    stride 2) while the pooling path is adaptively max-pooled to the same size
    and channel-matched with a 1x1 convolution; the two results are summed and
    feed the next convolution stage.
    """

    def __init__(self, channels=(1024, 32, 64, 128, 256, 512)):
        """
        Args:
            channels: Channel widths from input to output; consecutive pairs
                define one stage each. The default reproduces the original
                1024 -> 32 -> 64 -> 128 -> 256 -> 512 progression.

        Raises:
            ValueError: If fewer than two widths are given.
        """
        super().__init__()
        channels = tuple(channels)
        if len(channels) < 2:
            raise ValueError("channels must contain at least an input and an output width")
        # (in_channels, out_channels) per stage, shared by both paths so the two
        # layer lists can never drift apart.
        stages = list(zip(channels[:-1], channels[1:]))

        # Convolution path: conv(3x3, stride 2) -> batch norm -> ReLU per stage.
        self.cnn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ) for in_channels, out_channels in stages
        ])

        # 1x1 convolutions that adjust the channel count of the pooled path.
        self.pool_conv_layers = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            for in_channels, out_channels in stages
        ])

    def forward(self, x):
        """Run all stages.

        Args:
            x: Tensor of shape ``(batch, channels[0], h, w)``.

        Returns:
            Tensor of shape ``(batch, channels[-1], h', w')`` where the spatial
            size has been reduced by every stride-2 stage.
        """
        pooled = x
        for layer, pool_conv in zip(self.cnn_layers, self.pool_conv_layers):
            cnn_out = layer(x)  # [batch, out_channels, h, w]
            # Adaptive pooling makes the pooled path match cnn_out spatially.
            pooled = F.adaptive_max_pool2d(pooled, output_size=cnn_out.shape[2:])
            pooled = pool_conv(pooled)  # match channels: [batch, out_channels, h, w]
            x = cnn_out + pooled  # shapes agree, so the two paths can be summed
        return x

class Decoder(nn.Module):
    """Upsampling decoder that turns a fused vector plus encoder features into a map.

    The 1024-d fused vector is projected to a ``(encoded_dim, 7, 7)`` grid,
    summed with the encoder output, and upsampled by five stride-2 transposed
    convolutions (7 -> 224) before a final 1x1 convolution.
    """

    def __init__(self, encoded_dim=512, output_channels=1):
        """
        Args:
            encoded_dim: Channel count of the encoder output and of the 7x7
                grid built from the fused vector.
            output_channels: Number of channels in the decoded output.
        """
        super().__init__()
        # Remembered so forward() reshapes with the same width the layers were
        # built for; previously 512 was hard-coded there and in the first
        # deconvolution, so any other `encoded_dim` failed.
        self.encoded_dim = encoded_dim
        self.fc = nn.Linear(1024, 7 * 7 * encoded_dim)  # fused vector -> spatial representation
        widths = [encoded_dim, 256, 128, 64, 32, 16]
        self.deconv_layers = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ) for in_channels, out_channels in zip(widths[:-1], widths[1:])
        ])
        self.last_conv = nn.Conv2d(16, output_channels, kernel_size=1)

    def forward(self, z, encoded):
        """Decode the fused vector together with the encoder features.

        Args:
            z: Fused vector of shape ``(batch, 1024)``.
            encoded: Encoder output that must be broadcastable to
                ``(batch, encoded_dim, 7, 7)``.

        Returns:
            Tensor of shape ``(batch, output_channels, 224, 224)``.
        """
        z = self.fc(z).view(-1, self.encoded_dim, 7, 7)
        x = z + encoded
        for layer in self.deconv_layers:
            x = layer(x)  # each stage doubles the spatial size
        return self.last_conv(x)

class TextImageFusionModel(nn.Module):
    """End-to-end model: encode image and caption, fuse them, then encode/decode.

    Pipeline: ResNet-50 image features and BERT text features are each projected
    to 1024-d, summed, expanded to a ``(1024, 7, 7)`` grid, passed through
    ``TwoPathEncoder`` and finally decoded by ``Decoder``.
    """

    def __init__(self):
        super().__init__()
        self.img_extract = ImageFeatureExtract().to(DEVICE)
        self.txt_extract = TextFeatureExtract().to(DEVICE)
        self.img_map = MappingNetwork(2048, 1024).to(DEVICE)
        self.txt_map = MappingNetwork(768, 1024).to(DEVICE)
        # Linear layer that lifts the fused vector to a spatial 1024x7x7 grid.
        self.fusion_to_spatial = nn.Linear(1024, 1024 * 7 * 7).to(DEVICE)
        self.encoder = TwoPathEncoder().to(DEVICE)
        self.decoder = Decoder().to(DEVICE)

    def forward(self, image, text_ids):
        """Run the full image + text pipeline.

        Args:
            image: Image batch of shape ``(batch_size, 3, H, W)``; may live on
                any device, it is moved to the model's device here.
            text_ids: Despite the name, the raw caption strings for the batch;
                they are handed to the text extractor, which tokenizes them.

        Returns:
            Decoder output; ``(16, 1, 224, 224)`` in the recorded run.
        """
        # DataLoader batches arrive on the CPU; without this move the first
        # convolution fails with a device mismatch when the model is on a GPU.
        image = image.to(self.fusion_to_spatial.weight.device)

        # Step 1: extract features
        img_fea = self.img_extract(image)  # [batch_size, 2048]
        txt_fea = self.txt_extract(text_ids)  # [batch_size, seq_length, 768]
        # Mean-pool over tokens: [batch_size, 768].
        # NOTE(review): this appears to average padded positions too — confirm
        # whether masked pooling is wanted.
        txt_fea = txt_fea.mean(dim=1)

        # Step 2: project both modalities to the shared 1024-d space
        img_mapped = self.img_map(img_fea)  # [batch_size, 1024]
        txt_mapped = self.txt_map(txt_fea)  # [batch_size, 1024]

        # Step 3: fuse by element-wise addition
        fused = img_mapped + txt_mapped  # [batch_size, 1024]

        # Step 4: lift the fused vector to a spatial grid
        fused_spatial = self.fusion_to_spatial(fused)  # [batch_size, 1024*7*7]
        fused_spatial = fused_spatial.view(fused.size(0), 1024, 7, 7)  # [batch_size, 1024, 7, 7]

        # Step 5: encode and decode
        encoded = self.encoder(fused_spatial)  # [batch_size, 512, h, w]
        output = self.decoder(fused, encoded)  # [batch_size, output_channels, h, w]

        return output

# Create the dataset and the data loader.
# Assumes `pq` is the DataFrame loaded earlier in the notebook.
dataset = ImageCaptionDataset(pq)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Build the model and move it to the selected device.
model = TextImageFusionModel().to(DEVICE)

# Fetch a single batch.
imgs, captions = next(iter(dataloader))
print(f'imgs.shape: {imgs.shape}')  # expected [16, 3, 224, 224]

# The loader yields CPU tensors; put the images on the model's device first,
# otherwise the forward pass fails with a device mismatch when DEVICE is CUDA.
imgs = imgs.to(DEVICE)

# Run the batch through the model.
output = model(imgs, captions)
print(f"Output shape: {output.shape}")  # expected shape, e.g. [16, 1, 224, 224]

imgs.shape: torch.Size([16, 3, 224, 224])
Output shape: torch.Size([16, 1, 224, 224])
