In [1]:
import os
import gc
import pandas as pd

import torch
from SwiFT.project.module.models.swin4d_transformer_ver7 import SwinTransformer4D
from nilearn.image import load_img, resample_img
import numpy as np
from tqdm import tqdm
import nibabel as nib
# User settings
data_root = "./data"  # dataset root; per-subject paths are joined under it
save_root = "./fmri_pooled_features"  # root directory for the saved feature tensors
metadata_csv = "fmri_volumes_with_events.csv"  # per-run metadata table read with pandas
TR = 1.5  # divisor turning event onset/duration into volume indices; presumably seconds per volume — confirm
WINDOW_SIZE = 20  # number of consecutive volumes in each model input window
target_shape = (96, 96, 96)  # spatial grid the functional images are resampled to

In [2]:
# Model (prepare the user-defined encoder)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_path = "SwiFT/pretrained_models/contrastive_pretrained.ckpt"
#ckpt_path = "SwiFT/pretrained_models/hcp_sex_classification.ckpt"
ckpt = torch.load(ckpt_path, map_location=device)

# Extract only the state_dict (PyTorch Lightning checkpoints keep it under the
# 'state_dict' key). Strip the "model." wrapper prefix only where it leads the
# key: a blanket str.replace would also rewrite any key that merely contains
# "model." further in (e.g. "submodel.weight" -> "subweight").
ckpt_key_prefix = "model."
if "state_dict" in ckpt:
    state_dict = {
        (k[len(ckpt_key_prefix):] if k.startswith(ckpt_key_prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }
else:
    state_dict = ckpt

# Encoder configuration: 96x96x96 spatial grid with 20 volumes per window.
# NOTE(review): these hyperparameters presumably have to match the ones the
# checkpoint was trained with — confirm against the pretraining config.
model = SwinTransformer4D(
    img_size=(96, 96, 96, 20),
    in_chans=1,
    embed_dim=36,
    window_size=(4, 4, 4, 4),
    first_window_size=(2, 2, 2, 2),
    patch_size=(6, 6, 6, 1),
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    downsample="mergingv2",
)

# Inference only: switch to eval mode, then move the model to the chosen device.
model.eval()
model.to(device)

img_size:  (96, 96, 96, 20)
patch_size:  (6, 6, 6, 1)
patch_dim:  (16, 16, 16, 20)


SwinTransformer4D(
  (patch_embed): PatchEmbed(
    (fc): Linear(in_features=216, out_features=36, bias=True)
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (pos_embeds): ModuleList(
    (0-3): 4 x PositionalEmbedding()
  )
  (layers): ModuleList(
    (0): BasicLayer(
      (blocks): ModuleList(
        (0-1): 2 x SwinTransformerBlock4D(
          (norm1): LayerNorm((36,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention4D(
            (qkv): Linear(in_features=36, out_features=108, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=36, out_features=36, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((36,), eps=1e-05, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear1): Linear(in_features=36, out_features=144, bias=True)
            (l

In [3]:
# Sanity check: push one random window through the encoder and print the output
# shape. Run it under no_grad so no autograd graph is built for a throwaway
# forward pass, and allocate the dummy input directly on the target device.
with torch.no_grad():
    print(model(torch.randn(1, 1, 96, 96, 96, 20, device=device)).shape)

torch.Size([1, 288])


In [4]:
# Load the pretrained weights into the encoder. strict=False lets the load
# proceed even when checkpoint and model keys differ, so print the returned
# report and check it: any missing_keys would be layers left at their
# initial (untrained) values.
res = model.load_state_dict(state_dict, strict=False)
print(res)


_IncompatibleKeys(missing_keys=[], unexpected_keys=['head.weight', 'head.bias', 'emb_mlp.fc1.weight', 'emb_mlp.bn1.weight', 'emb_mlp.bn1.bias', 'emb_mlp.bn1.running_mean', 'emb_mlp.bn1.running_var', 'emb_mlp.bn1.num_batches_tracked'])


In [5]:
# Load the metadata
df = pd.read_csv(metadata_csv)
# NOTE(review): 630 is an unexplained row offset (possibly resuming after rows
# handled elsewhere) — confirm and move it to a named constant.
df = df.iloc[630:].reset_index(drop=True)

# Keep only runs whose trial_type mentions a story event. na=False treats rows
# with a missing trial_type as non-matching; without it the mask contains NaN
# and boolean indexing raises.
df = df[df['trial_type'].str.contains("story", na=False)].copy()

In [6]:
# Preview the first rows of the filtered metadata.
df.head()

Unnamed: 0,subject,file,task,n_volumes,TR_nii,TR_json,duration_sec,onset,duration,trial_type,stim_file
0,sub-264,sub-264_task-21styear_bold.nii.gz,21styear,2249,1.5,1.5,3373.5,"0.0,21.0","18.0,3338.0","music,story","21styear_audio.wav,21styear_audio.wav"
1,sub-265,sub-265_task-piemanpni_bold.nii.gz,piemanpni,294,1.5,1.5,441.0,12.0,400.0,story,piemanpni_audio.wav
2,sub-265,sub-265_task-forgot_bold.nii.gz,forgot,574,1.5,1.5,861.0,12.0,837.0,story,forgot_audio.wav
3,sub-265,sub-265_task-21styear_bold.nii.gz,21styear,2249,1.5,1.5,3373.5,"0.0,21.0","18.0,3338.0","music,story","21styear_audio.wav,21styear_audio.wav"
4,sub-265,sub-265_task-bronx_bold.nii.gz,bronx,390,1.5,1.5,585.0,12.0,536.0,story,bronx_audio.wav


In [7]:
def extract_story_onset_duration(row):
    """Return the onsets and durations of the "story" events of one metadata row.

    The 'trial_type', 'onset' and 'duration' fields hold one value per event,
    comma-separated when a run has several events (e.g. "music,story" with
    "0.0,21.0"). Fields are coerced to text first, so a row whose onset or
    duration was parsed as a plain number (single-event runs) is handled too.

    Args:
        row: mapping-like object (e.g. a pandas Series) with the keys
            'trial_type', 'onset' and 'duration'.

    Returns:
        A tuple (onsets, durations) of two float lists, one entry per event
        whose trial type is exactly "story", in their original order.

    Raises:
        ValueError: if the three fields do not list the same number of events,
            or if an onset/duration entry is not a valid float.
    """
    def _fields(value):
        # str() makes numeric cells splittable like their text counterparts.
        return [part.strip() for part in str(value).split(',')]

    types = _fields(row['trial_type'])
    onsets = [float(part) for part in _fields(row['onset'])]
    durations = [float(part) for part in _fields(row['duration'])]

    if not (len(types) == len(onsets) == len(durations)):
        raise ValueError(
            "trial_type/onset/duration list different numbers of events: "
            f"{len(types)}/{len(onsets)}/{len(durations)}"
        )

    story_idx = [i for i, t in enumerate(types) if t == "story"]
    return [onsets[i] for i in story_idx], [durations[i] for i in story_idx]

def resample_to_target(func_path, anat_path, target_shape=(96, 96, 96)):
    """Resample a functional image onto a target-shaped grid derived from the anatomical image.

    The target affine is the anatomical affine with its voxel axes scaled by
    (anatomical shape / target shape), so the target grid spans the anatomical
    grid's extent with `target_shape` voxels.

    Args:
        func_path: path of the functional image to resample.
        anat_path: path of the anatomical image that defines the reference grid.
        target_shape: spatial shape (3 values) of the output grid.

    Returns:
        The resampled image data as returned by `get_fdata()`.
    """
    anat_img = load_img(anat_path)
    func_img = load_img(func_path)

    # Only the three spatial axes take part in the scaling; slicing keeps the
    # division valid when the anatomical image carries an extra trailing axis.
    anat_spatial_shape = np.array(anat_img.shape[:3])
    zoom_factors = anat_spatial_shape / np.array(target_shape)

    target_affine = anat_img.affine.copy()
    target_affine[:3, :3] = anat_img.affine[:3, :3] @ np.diag(zoom_factors)

    # NOTE(review): "nearest" interpolation is kept as-is; confirm it is the
    # intended choice for this data.
    resampled = resample_img(func_img, target_affine=target_affine,
                             target_shape=target_shape, interpolation="nearest")
    return resampled.get_fdata()

# Step between consecutive window starts, in volumes. It equals the 20-volume
# window length, so windows do not overlap. Loop-invariant, so set once.
stride = 20

# For every run: resample the functional image, cut each story segment into
# fixed-length windows, encode every window and save one tensor per window.
for _, row in tqdm(df.iterrows(), total=len(df)):
    subject = row['subject']
    task = row['task']
    file_name = row['file']
    story_onsets, story_durations = extract_story_onset_duration(row)

    func_path = os.path.join(data_root, subject, "func", file_name)
    scan_tsv = os.path.join(data_root, subject, f"{subject}_scans.tsv")

    if not os.path.exists(func_path) or not os.path.exists(scan_tsv):
        print(f"[!] Missing file for {subject}")
        continue

    # The first "anat/" entry of the subject's scans table is the reference image.
    try:
        scan_df = pd.read_csv(scan_tsv, sep="\t")
        anat_row = scan_df[scan_df["filename"].str.startswith("anat/")].iloc[0]
        anat_path = os.path.join(data_root, subject, anat_row["filename"])
    except Exception as e:
        print(f"[!] Failed to find anat for {subject}: {e}")
        continue

    try:
        resampled_data = resample_to_target(func_path, anat_path, target_shape)
    except Exception as e:
        print(f"[!] Failed to resample {func_path}: {e}")
        continue

    n_volumes = resampled_data.shape[-1]
    run_name = file_name.replace(".nii.gz", "")
    run_dir = os.path.join(save_root, subject, task, run_name)

    for onset, duration in zip(story_onsets, story_durations):
        # Clamp the segment to the volumes that actually exist, so every
        # window sliced below holds exactly WINDOW_SIZE volumes.
        start_vol = max(int(onset // TR), 0)
        end_vol = min(int((onset + duration) // TR), n_volumes)

        for i in range(start_vol, end_vol - WINDOW_SIZE + 1, stride):
            vol = resampled_data[..., i:i + WINDOW_SIZE]
            # Add batch and channel axes; convert to float32 while copying.
            vol_tensor = torch.tensor(vol, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

            with torch.no_grad():
                pooled = model(vol_tensor)  # (1, 288) in the sanity check
                pooled = pooled.squeeze(0).cpu()

            # Created lazily so runs without any window leave no empty folder.
            os.makedirs(run_dir, exist_ok=True)
            torch.save(pooled, os.path.join(run_dir, f"frame_{i}.pt"))

    # Drop references to the large buffers before loading the next run.
    # Rebinding to None (rather than `del`) stays valid when a run produced
    # no window and the per-window names were never bound.
    resampled_data = None
    vol = vol_tensor = pooled = None

    gc.collect()



100%|█████████████████████████████████████████| 243/243 [46:26<00:00, 11.47s/it]
