### Text & Image 매핑 -> tokenzier & embedding
- tokenzier : 이미지의 경우 patch로 나눠서 각 숫자 벡터로 바꿈
              자연어의 경우 단어 하나하나를 임베딩해서 문장을 토큰 시퀀스로 만드는 것
              = Input Slice encoding
              + Learned positional encoding(슬라이스 순서)
- embedding : 위 token한 값을 각 이미지나 단어에 매핑됨

- Transformer : token + embedding 후 2단계


In [3]:
import os, glob
import pandas as pd
import numpy as np
import torch
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
from PIL import Image
import random

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load CLIP model & processor
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


### 이미 PD, PDX 분류한 폴더에 접근해서 2D 데이터 + text + label 해서 입력함

In [5]:
# 1. 텍스트 프롬프트 정의
text_prompts = {
    "PD": [
        "Brain MRI with Parkinson's disease",
        "T1-weighted axial brain MRI showing Parkinsonian features",
        "MRI scan of the brain affected by Parkinson's disease",
        "Neuroimaging of a patient diagnosed with Parkinson's",
        "Axial brain MRI with signs of neurodegeneration"
    ],
    "PDX": [
        "Brain MRI of a healthy subject",
        "T1-weighted axial brain MRI without abnormalities",
        "Normal brain scan with no pathological findings",
        "Control subject brain MRI image",
        "Axial brain MRI with typical anatomy and no disease"
    ]
}

label_map = {"PD": 0, "PDX": 1}

# 2. 이미지-텍스트 쌍 수집 함수
def collect_image_text_pairs(root_dir):
    imagetext_dataset = []

    label_dirs = {
        "PD": 0,
        "PDX": 1
    }

    for label_key, label_value in label_dirs.items():
        subject_root = os.path.join(root_dir, label_key)
        subject_folders = glob.glob(os.path.join(subject_root, "RJPD_*"))

        for subject_folder in subject_folders:
            image_dir = os.path.join(subject_folder, f"T1_{label_key}")
            image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))

            if len(image_paths) != 70:
                print(f"⚠️ {label_key} - {os.path.basename(subject_folder)}: 슬라이스 수 = {len(image_paths)} (70장 아님)")

            for img_path in image_paths:
                text_prompt = random.choice(text_prompts[label_key])  # 라벨에 맞는 텍스트
                imagetext_dataset.append((img_path, text_prompt))

    print(f"✅ 총 수집된 이미지-텍스트 쌍: {len(imagetext_dataset)}")
    return imagetext_dataset

# 3. CLIPDataset 정의
class CLIPDataset(Dataset):
    def __init__(self, imagetext_dataset):
        self.imagetext_dataset = imagetext_dataset  # 리스트: (img_path, text, label)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=False)

    def __len__(self):
        return len(self.imagetext_dataset)

    def __getitem__(self, idx):
        img_path, text, label = self.imagetext_dataset[idx]

        image = Image.open(img_path).convert("RGB")
        processed = self.processor(text=text, images=image, return_tensors="pt", padding="max_length")

        # 배치 차원 제거
        processed = {k: v.squeeze(0) for k, v in processed.items()}
        processed["label"] = torch.tensor(label, dtype=torch.long)

        return processed


In [6]:
root_dir = r"C:\visual code\MRI\dataset\MRI_train"
imagetext_dataset = collect_image_text_pairs(root_dir)

dataset = CLIPDataset(imagetext_dataset)
print(dataset[0])  # 하나만 예시로 출력


✅ 총 수집된 이미지-텍스트 쌍: 14000
{'input_ids': tensor([49406,  4812, 24773,   593, 28129,   568,  5336, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
        49407, 49407, 49407, 49407, 49407, 49407, 49407]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0]), 'pixel_values': tensor([[[-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.79

### AttentionPoolingAdapter
- multi-head attention 기준

In [7]:
import torch
import torch.nn as nn

class AttentionPoolingAdapter(nn.Module):
    def __init__(self, embed_dim=1024, num_heads=8, num_slices=70):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_slices = num_slices

        # Learnable positional encoding: (num_slices, embed_dim)
        self.positional_encoding = nn.Parameter(torch.randn(num_slices, embed_dim))

        # Multi-head self-attention
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, cls_sequence):
        """
        cls_sequence: Tensor of shape (70, 1024) ← 70 slices의 CLS 벡터
        Returns: pooled_output (1, 1024)
        """
        if cls_sequence.dim() == 2:
            cls_sequence = cls_sequence.unsqueeze(0)  # (1, 70, 1024)

        # Positional encoding 추가
        cls_sequence = cls_sequence + self.positional_encoding.unsqueeze(0)  # (1, 70, 1024)

        # Self-attention
        attn_output, _ = self.attn(cls_sequence, cls_sequence, cls_sequence, need_weights=False)  # (1, 70, 1024)

        # Mean pooling across slices
        pooled = attn_output.mean(dim=1)  # (1, 1024)
        pooled = self.norm(pooled)

        return pooled  # shape: (1, 1024)



In [8]:
adapter = AttentionPoolingAdapter(embed_dim=1024, num_heads=8, num_slices=70)

# 예시 입력 (70개의 CLS 임베딩)
cls_sequence = torch.randn(70, 1024)  # (num_slices, embed_dim)

# forward 실행
pooled = adapter(cls_sequence)

# 출력
print("🟢 Final pooled embedding:", pooled)
print("📐 Shape:", pooled.shape)



🟢 Final pooled embedding: tensor([[1.0010, 0.3381, 0.3126,  ..., 1.3538, 0.2206, 1.5222]],
       grad_fn=<NativeLayerNormBackward0>)
📐 Shape: torch.Size([1, 1024])


### Contrastive Loss

In [9]:
import torch
import torch.nn.functional as F

def contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):

    # 1. Normalize embeddings
    image_embeddings = F.normalize(image_embeddings, dim=-1)  # (B, D)
    text_embeddings = F.normalize(text_embeddings, dim=-1)    # (B, D)

    # Cosine similarity as logits
    logits_per_image = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    logits_per_text = logits_per_image.T  # transpose for text-to-image

    # 3. Ground truth labels: diagonal (i == j)
    batch_size = image_embeddings.size(0)
    labels = torch.arange(batch_size, device=image_embeddings.device)

    # 4. Cross-entropy loss (image → text + text → image)
    loss_i2t = F.cross_entropy(logits_per_image, labels)
    loss_t2i = F.cross_entropy(logits_per_text, labels)
    loss = (loss_i2t + loss_t2i) / 2

    return loss, logits_per_image, logits_per_text


### CLIP Model에서 image 부분 수정

In [10]:
from transformers.models.clip.modeling_clip import CLIPModel, CLIPTextModel, CLIPVisionModel, CLIPOutput, _get_vector_norm
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.utils import can_return_tuple
from transformers import CLIPPreTrainedModel
from typing import Optional

class CLIPModel(CLIPPreTrainedModel):
    config: CLIPConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        text_model = CLIPTextModel._from_config(text_config)
        self.text_model = text_model.text_model

        vision_model = CLIPVisionModel._from_config(vision_config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = text_outputs.pooler_output
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs.pooler_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> CLIPOutput:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output

        text_embeds = text_outputs.pooler_output

        return CLIPOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )

### 가중치 불러오기(hugging)

In [11]:
from transformers import logging
logging.set_verbosity_error()  # Suppress warnings from transformers

batch_size = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

root_dir = r"C:\visual code\MRI\dataset\MRI_train" # 실제 경로로 교체 필요
imagetext_dataset = collect_image_text_pairs(root_dir)

dataset = CLIPDataset(imagetext_dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# print(model)
print(model.text_model.encoder.layers[0].mlp)

weights=torch.load('C:/visual code/MRI/RadCLIP.pth', map_location=device)
model.load_state_dict(weights)

✅ 총 수집된 이미지-텍스트 쌍: 14000
CLIPMLP(
  (activation_fn): QuickGELUActivation()
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
)


<All keys matched successfully>

### 루프

In [None]:
def evaluate_accuracy(model, dataloader, batch_size, device):
    projection_dim = 768
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_fast=False)
    visual_projection = nn.Linear(1024, projection_dim, bias=False).to(device)

    # PD = 0, PDX = 1
    text_prompts = {
        "PD": [
            "Brain MRI with Parkinson's disease",
            "T1-weighted axial brain MRI showing Parkinsonian features",
            "MRI scan of the brain affected by Parkinson's disease",
            "Neuroimaging of a patient diagnosed with Parkinson's",
            "Axial brain MRI with signs of neurodegeneration"
        ],
        "PDX": [
            "Brain MRI of a healthy subject",
            "T1-weighted axial brain MRI without abnormalities",
            "Normal brain scan with no pathological findings",
            "Control subject brain MRI image",
            "Axial brain MRI with typical anatomy and no disease"
        ]
    }

    # 1) 텍스트 특징 벡터 평균 계산
    all_texts = text_prompts["PDX"] + text_prompts["PD"]
    text_inputs = processor(text=all_texts, padding=True, return_tensors="pt").to(device)

    with torch.no_grad():
        text_embeds = model.get_text_features(**text_inputs)  # (10, 768)
        healthy_text = text_embeds[:5].mean(dim=0, keepdim=True)     # (1, 768)
        parkinson_text = text_embeds[5:].mean(dim=0, keepdim=True)   # (1, 768)
        text_feats = torch.cat([parkinson_text, healthy_text], dim=0)  # (2, 768)

    # 2) 평가 루프
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            image_embeds = outputs.image_embeds  # (B, 1024)
            image_proj = visual_projection(image_embeds)  # (B, 768)

            # normalize
            image_proj = image_proj / image_proj.norm(dim=-1, keepdim=True)
            text_feats_norm = text_feats / text_feats.norm(dim=-1, keepdim=True)

            logits = image_proj @ text_feats_norm.T  # (B, 2)
            preds = logits.argmax(dim=1)  # (B,)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            print(f"Logits: {logits}, Preds: {preds}, Labels: {labels}, Correct: {correct}")

    accuracy = correct / total if total > 0 else 0
    return accuracy

acc = evaluate_accuracy(model, dataloader, batch_size, device)
print(f"Validation Accuracy: {acc * 100:.2f}%")


Logits: tensor([[-0.0345, -0.0010],
        [-0.0039,  0.0290],
        [-0.0234,  0.0102],
        [ 0.0040,  0.0414]]), Preds: tensor([1, 1, 1, 1]), Labels: tensor([0, 0, 0, 0]), Correct: 0
Logits: tensor([[-0.0179,  0.0129],
        [-0.0118,  0.0149],
        [-0.0198,  0.0054],
        [ 0.0036,  0.0320]]), Preds: tensor([1, 1, 1, 1]), Labels: tensor([0, 0, 0, 0]), Correct: 0
Logits: tensor([[0.0252, 0.0518],
        [0.0244, 0.0471],
        [0.0405, 0.0649],
        [0.0375, 0.0540]]), Preds: tensor([1, 1, 1, 1]), Labels: tensor([0, 0, 0, 0]), Correct: 0
Logits: tensor([[0.0291, 0.0507],
        [0.0369, 0.0556],
        [0.0256, 0.0595],
        [0.0365, 0.0655]]), Preds: tensor([1, 1, 1, 1]), Labels: tensor([0, 0, 0, 0]), Correct: 0
Logits: tensor([[0.0440, 0.0461],
        [0.0357, 0.0493],
        [0.0302, 0.0425],
        [0.0289, 0.0467]]), Preds: tensor([1, 1, 1, 1]), Labels: tensor([0, 0, 0, 0]), Correct: 0
Logits: tensor([[0.0413, 0.0462],
        [0.0178, 0.0303],
    

### 수정 필요

In [None]:
from tqdm import tqdm

projection_dim = 768  # Assuming projection dimension is 768
visual_projection = nn.Linear(1024, projection_dim, bias=False).to(device)
text_projection = nn.Linear(768, projection_dim, bias=False).to(device)
adapter = AttentionPoolingAdapter(embed_dim=1024, num_heads=8, num_slices=70).to(device)
optimizer = torch.optim.AdamW(list(adapter.parameters())+list(visual_projection.parameters()), lr=1e-5)
epochs = 10

#freeze the text encoder of the model
for param in model.text_model.parameters():
    param.requires_grad = False

#freeze the vision encoder of the model
for param in model.vision_model.parameters():
    param.requires_grad = False

# freeze text projection
for params in text_projection.parameters():
    params.requires_grad = False

# unfreeze visual projection
for params in visual_projection.parameters():
    params.requires_grad = True

for epoch in range(epochs):
    train_loss=0
    step=0
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        text_features = torch.empty(batch_size, 70, 768).to(device)  # Assuming 70 slices and 768 features
        image_features = torch.empty(batch_size, 70, 1024).to(device)  # Assuming 70 slices and 1024 features

        for idx in range(len(batch)):
            input_ids = batch[idx]['input_ids'].to(device)
            attention_mask = batch[idx]['attention_mask'].to(device)
            pixel_values = batch[idx]['pixel_values'].to(device)
            # print(f"[Step {step}] input_ids: {input_ids.shape}, attention_mask: {attention_mask.shape}, pixel_values: {pixel_values.shape}")
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            text_features[:, idx, :] = outputs.text_embeds
            image_features[:, idx, :] = outputs.image_embeds
            #print(f"[Step {step}] Text Features: {text_features.shape}, Image Features: {image_features.shape}")

        #print(image_features[0, :, 0], "\n", image_features[1, :, 0], "\n", image_features[2, :, 0], "\n", image_features[3, :, 0])
        #break

        pooled_text_features = text_features[:, 0, :].squeeze(1)  # Assuming we want to pool the first slice
        pooled_image_features = adapter(image_features) # Apply the SlicePoolingAdapter to image features
        # print(f"[Step {step}] Pooled Text Features: {pooled_text_features.shape}, Pooled Image Features: {pooled_image_features.shape}")

        # Project features to the same dimension
        projected_texts  = text_projection(pooled_text_features)  # (4, projection_dim)
        projected_images = visual_projection(pooled_image_features)  # (4, projection_dim)

        # normalized features
        projected_texts = projected_texts / _get_vector_norm(projected_texts)
        projected_images = projected_images / _get_vector_norm(projected_images)

        loss = contrastive_loss(projected_texts, projected_images)
        train_loss += loss.item()

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}], Step [{step+1}/{len(dataloader)}], Loss: {loss.item()}")
        step+=1
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {train_loss/len(dataloader)}")
torch.save(model.state_dict(), "clip_model_epoch_"+str(epoch+1)+".pth")