In [None]:
import numpy as np

# Load the per-ROI arrays of one HCP subject. `dict(...)` reads every array into memory,
# and the `with` block then closes the underlying .npz file handle (NpzFile keeps it open).
with np.load('./data/HCP/npz/100206/HCP_visual_voxel.npz') as npz_file:
    npz = dict(npz_file)
v1 = npz['V1']
v2 = npz['V2']
v3 = npz['V3']
v4 = npz['V4']
a = len(v1)
# Split V1 into 3 near-equal chunks along axis 0.
b = np.array_split(v1, 3, axis=0)

在单个被试（subject 100206）下，V1、V2、V3、V4 四个视觉区域的 fMRI 数据形状分别为 (1200, 1618)、(1200, 2220)、(1200, 684)、(1200, 661)。

In [None]:
import os
from config import Config_MBM_fMRI
from dataset import hcp_dataset
from stageA1_mbm_pretrain import fmri_transform
import cv2 as cv
import numpy as np

config = Config_MBM_fMRI()
# Build the HCP pretraining dataset with every option taken from the config.
dataset_pretrain = hcp_dataset(
    path=os.path.join(config.root_path, 'data/HCP/npz'),
    roi=config.roi,
    patch_size=config.patch_size,
    transform=fmri_transform,
    aug_times=config.aug_times,
    num_sub_limit=config.num_sub_limit,
    include_kam=config.include_kam,
    include_hcp=config.include_hcp,
)

# Inspect one sample; 6000 is presumably just an arbitrary probe index.
sample_data = dataset_pretrain[6000]
img = sample_data['image']
image_array = np.array(img)


In [None]:
import torch
# Show the first GPU's properties; guarded so the cell does not raise on CPU-only machines.
torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CUDA is not available'

In [None]:
from config import Config_MBM_fMRI

# Instantiate the MBM pretraining config, override num_voxels, and show every attribute.
config = Config_MBM_fMRI()
config.num_voxels = 0
print(vars(config))

In [None]:
import datetime
# Current local time in the day-month-year-hour-minute-second form.
print(f"{datetime.datetime.now():%d-%m-%Y-%H-%M-%S}")

In [None]:
from omegaconf import OmegaConf

config = OmegaConf.load('./pretrains/ldm/label2img/config.yaml')
# The 'params' entry of the model section, or an empty mapping when the YAML omits it.
a = config['model'].get("params", {})


In [None]:
import torch
import numpy as np
import copy
from dc_ldm.ldm_for_fmri import fLDM
from config import Config_Generative_Model
from dataset import create_Kamitani_dataset, create_BOLD5000_dataset
from einops import rearrange
import torchvision.transforms as transforms


class random_crop:
    """Square random crop applied with probability ``p``; otherwise the input passes through.

    Args:
        size: side length of the square crop.
        p: probability of cropping on each call.
    """

    def __init__(self, size, p):
        self.size = size
        self.p = p

    def __call__(self, img):
        draw = torch.rand(1)
        if not draw < self.p:
            return img
        cropper = transforms.RandomCrop(size=(self.size, self.size))
        return cropper(img)


def normalize(img):
    """Convert an image to a channel-first tensor and map values with ``x * 2 - 1``.

    For inputs in [0, 1] the result lies in [-1, 1]. An input whose last dimension is 3
    is treated as ``h w c`` and rearranged to ``c h w``.

    Args:
        img: numpy array or torch tensor.

    Returns:
        A new tensor; the input is never modified.
    """
    if img.shape[-1] == 3:
        img = rearrange(img, 'h w c -> c h w')
    # as_tensor avoids the redundant copy (and UserWarning) torch.tensor makes for tensor
    # inputs; the arithmetic below allocates a fresh tensor, so the input is not mutated.
    img = torch.as_tensor(img)
    img = img * 2.0 - 1.0  # to -1 ~ 1
    return img

def channel_last(img):
    """Return ``img`` in ``h w c`` layout.

    An input whose last dimension is 3 is assumed to already be channel-last and is
    returned unchanged; anything else is rearranged from ``c h w``.
    """
    already_channel_last = img.shape[-1] == 3
    return img if already_channel_last else rearrange(img, 'c h w -> h w c')

def fmri_transform(x, sparse_rate=0.2):
    """Randomly zero a fraction of voxels and return the result as a FloatTensor.

    Args:
        x: voxel array, either ``(num_voxels,)`` or ``(1, num_voxels)``.
        sparse_rate: fraction of voxels (last axis) to set to zero.

    Returns:
        ``torch.FloatTensor`` with the same shape as ``x``; ``x`` itself is not modified.
    """
    x_aug = copy.deepcopy(x)
    # Sample along the voxel (last) axis. With a (1, num_voxels) input, axis 0 has size 1,
    # so int(1 * sparse_rate) == 0 and sampling there would never drop anything.
    num_voxels = x.shape[-1]
    idx = np.random.choice(num_voxels, int(num_voxels * sparse_rate), replace=False)
    x_aug[..., idx] = 0
    return torch.FloatTensor(x_aug)



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = Config_Generative_Model()
config.checkpoint_path = './results/generation/07-09-2023-01-50-29/checkpoint_best.pth'
if config.checkpoint_path is not None:
    ckp = config.checkpoint_path
    # NOTE: torch.load unpickles arbitrary objects (the checkpoint stores a whole config
    # object), so only load checkpoints from trusted sources.
    model_meta = torch.load(ckp, map_location='cpu')
    # The checkpoint carries its own config; adopt it, but keep pointing at the file
    # that was actually loaded.
    config = model_meta['config']
    config.checkpoint_path = ckp
    print('Resuming from checkpoint: {}'.format(ckp))

# FIXME(review): originally marked "???". Whatever logger the config carried is replaced by
# None here — confirm that is intended.
config.logger = None
pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu')
crop_pix = int(config.crop_ratio * config.img_size)

# Train: normalize (x*2-1), with p=0.5 randomly crop to an (img_size - crop_pix) square,
# resize to 256x256, then make channels last. Test: the same without the random crop.
img_transform_train = transforms.Compose([
    normalize,
    random_crop(config.img_size - crop_pix, p=0.5),
    transforms.Resize((256, 256)),
    channel_last,
])
img_transform_test = transforms.Compose([
    normalize,
    transforms.Resize((256, 256)),
    channel_last,
])

if config.dataset == 'GOD':
    fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(
        config.kam_path, config.roi, config.patch_size,
        fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test],
        subjects=config.kam_subs)
elif config.dataset == 'BOLD5000':
    fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset(
        config.bold5000_path, config.patch_size,
        fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test],
        subjects=config.bold5000_subs)
else:
    raise NotImplementedError(
        f"Unsupported dataset {config.dataset!r}; expected 'GOD' or 'BOLD5000'")

# Both dataset factories expose the voxel count on the training split.
num_voxels = fmri_latents_dataset_train.num_voxels


def print_trainable_params(model, keyword='norm'):
    """Print ``<name> <requires_grad>`` for the model's parameters, filtered by name.

    Args:
        model: an ``nn.Module``.
        keyword: only parameters whose qualified name contains this substring are printed
            (default ``'norm'``, the original behavior); pass ``None`` to print all of them.

    Filtering is purely by name, so normalization layers whose qualified names lack the
    keyword are not listed.
    """
    for name, param in model.named_parameters():
        if keyword is None or keyword in name:
            print(name, param.requires_grad)


generative_model = fLDM(pretrain_mbm_metafile, num_voxels,
                        device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger,
                        ddim_steps=config.ddim_steps, global_pool=config.global_pool,
                        use_time_cond=config.use_time_cond)

if config.checkpoint_path is not None:
    # model_meta was loaded from this same checkpoint when the config was resumed;
    # reuse it rather than unpickling the (large) file a second time.
    generative_model.model.load_state_dict(model_meta['model_state_dict'])
    print('model resumed')


# print_trainable_params(generative_model.model)

In [None]:
# Module tree of the underlying diffusion model (nn.Module's str() is its repr()).
print(repr(generative_model.model))

In [None]:
# Smoke-test one forward pass with random inputs. The shapes are ad hoc — x: (5, 3, 256, 256),
# c: (5, 1, 16) — TODO confirm they match what generative_model.model expects.
# Inputs are created on the model's device so a GPU-resident model does not hit a device
# mismatch; no_grad avoids keeping an autograd graph for this inspection-only call.
model_device = next(generative_model.model.parameters()).device
x = torch.randn(5, 3, 256, 256, device=model_device)
c = torch.randn(5, 1, 16, device=model_device)
with torch.no_grad():
    y = generative_model.model(x, c)
print(y)

1. 打印名称中包含 'norm' 的参数（通过 `named_parameters()` 按名称过滤）。
2. 遍历所有子模块，查找类型名中含 "Norm" 的模块，确认模型中的归一化层只有 nn.GroupNorm 和 nn.LayerNorm 两种。
3. 打印模型中所有归一化层的模块名，发现其中确实有名称不包含 'norm' 的——因此仅按名称中的 'norm' 过滤会遗漏部分归一化层。

In [None]:
import torch.nn as nn
import copy

# Deep copy, so changes made through stageC_model leave generative_model.model untouched.
stageC_model = copy.deepcopy(generative_model.model)

# List parameters whose qualified name contains 'norm' (a name-based filter only).
for name, param in stageC_model.named_parameters():
    if 'norm' not in name:
        continue
    print(f"Parameter name: {name}")
    print(f"Parameter shape: {param.shape}")
    print()

In [None]:
# List every GroupNorm/LayerNorm module by type (independent of its name) and flag any
# whose parameters are not just the standard affine `weight`/`bias` pair.
for name, module in stageC_model.named_modules():
    if not isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
        continue
    print(name)
    print(module)
    for para_name, param in module.named_parameters():
        if para_name not in ('weight', 'bias'):
            print('++++++++++++++++++++++++++++', name)
    print()

In [None]:
# Visit each GroupNorm/LayerNorm module and grab its affine parameters.
for name, module in stageC_model.named_modules():
    if isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
        print('module name++++++', name)
        # TODO: a bare `stageC_model.add_module()` call stood here; add_module requires a
        # name and a module, so it always raised TypeError. Re-add it with both arguments
        # once the replacement module for this Norm layer is decided.
        weight = module.weight
        bias = module.bias
        print()

In [None]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    """One linear layer followed by ReLU, mapping ``in_dim`` features to ``out_dim``.

    Despite the name there is no hidden layer. The linear layer is registered as ``ln``
    (so state-dict keys stay ``ln.weight`` / ``ln.bias``) and is also reachable as
    ``linear``, the attribute name other code in this notebook uses.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.ln = nn.Linear(in_dim, out_dim)

    @property
    def linear(self):
        """Alias for ``self.ln``; not a separately registered submodule."""
        return self.ln

    def forward(self, x):
        return torch.relu(self.ln(x))

def replace_norm_layers_with_prompts(model):
    """Work in progress: re-parameterise normalization-layer affine parameters via an MLP.

    Intended flow, per the comments below: for each norm layer, read its weight/bias,
    build an ``MLP(num_features, num_features)``, copy the norm parameters into the MLP's
    linear layer, then assign the MLP's linear weight/bias back to the norm layer.

    NOTE(review): this cannot run as written:
      * ``nn.ModuleNorm`` does not exist in ``torch.nn``, so the ``isinstance`` check raises
        AttributeError on the first module visited; other cells in this notebook match
        ``(nn.GroupNorm, nn.LayerNorm)`` instead.
      * ``nn.LayerNorm`` has no ``num_features`` attribute (``nn.GroupNorm`` calls it
        ``num_channels``).
      * ``MLP`` stores its linear layer as ``self.ln``, not ``self.linear``.
      * ``bias.view(num_features, 1)`` cannot be copied into a bias of shape
        ``(num_features,)``, and the ``(num_features, num_features)`` linear weight does not
        match the per-feature shape of the norm layer's ``weight``.
      * Norm layers built without affine parameters have ``weight``/``bias`` set to None.
    TODO: settle the intended design before fixing.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.ModuleNorm):
            num_features = module.num_features

            # Extract the Norm layer's parameters
            weight = module.weight
            bias = module.bias

            # Create the MLP
            mlp = MLP(num_features, num_features)

            # Original note: "set the MLP parameters to the externally trained parameters".
            # In practice this copies the Norm layer's current weight/bias into the MLP.
            mlp.linear.weight.data.copy_(weight.view(num_features, 1))
            mlp.linear.bias.data.copy_(bias.view(num_features, 1))

            # Replace the original Norm layer parameters with the MLP's parameters
            module.weight = nn.Parameter(mlp.linear.weight)
            module.bias = nn.Parameter(mlp.linear.bias)

# Work on a deep copy (as the earlier stageC_model cell does), so generative_model.model is
# not mutated and re-running this cell always starts from the original weights.
stageC_model = copy.deepcopy(generative_model.model)

# Replace the trainable Norm-layer parameters with MLP-derived ones.
replace_norm_layers_with_prompts(stageC_model)

# Train the MLP externally...
# Assign the MLP outputs to the replaced Norm-layer parameters...

# Model training...

In [None]:
def count_parameters(model):
    """Count a model's scalar parameters and the percentage that are frozen.

    The name of every parameter with ``requires_grad=False`` is printed along the way.

    Args:
        model: an ``nn.Module``.

    Returns:
        ``(total_params, percentage_frozen)`` where ``percentage_frozen`` is in [0, 100].
        A module with no parameters yields ``(0, 0.0)``.
    """
    total_params = 0
    frozen_params = 0

    for name, param in model.named_parameters():
        count = param.numel()
        total_params += count
        if not param.requires_grad:
            print(name)
            frozen_params += count

    # Guard against ZeroDivisionError for parameter-less modules.
    percentage_frozen = (frozen_params / total_params) * 100 if total_params else 0.0
    return total_params, percentage_frozen

total_params, frozen_percentage = count_parameters(generative_model.model)
# Decimal millions; in a format spec the grouping option (,) must precede the precision (.2f).
print("Total parameters: {:,.2f}M".format(total_params / 1e6))
print("Frozen parameters percentage: {:.2f}%".format(frozen_percentage))

在所有可训练参数中，将 norm 参数……（TODO：此句未写完，原文到此中断。）