本module包含所有fLDM的子类module  
每个nn.Module都应该为原来nn.Module的子类,  
Norm(本例中只含nn.LayerNorm和nn.GroupNorm,以及一个自定的GroupNorm32)的父亲(即直接或间接调用了Norm模块的)节点必须重写  
- 直接调用的需要在init里初始化PDNorm类(替换Norm),并在forward里增加prompt到输入输出,training_step也需要修改,包括以上方法(__init__,forward,training_step)调用的方法。  
- 间接调用的除了不需要初始化, 其余一样  
- 不调用的就不用改了,可以用脚本(Norm not in module.named_children)快速找出  

最外层两种思路:  
- 重写fLDM  
- 重写LatentDiffusion,将generative_model.model 载入为 new_model,而generative_model的实例变量似乎都是为了初始化LatentDiffusion载入的一些config,因此直接load_state_dict没问题  
    
重写后,外部先载入老模型?再初始化一个新的LatentDiffusion对象直接load_state_dict  

---
init并载入一个原始模型

In [None]:
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torchvision.transforms as transforms


from dc_ldm.ldm_for_fmri import fLDM
from config import Config_Generative_Model
from dataset import create_Kamitani_dataset, create_BOLD5000_dataset



class random_crop:
    """Callable transform that applies a square random crop with probability ``p``.

    Args:
        size: side length of the square crop handed to ``transforms.RandomCrop``.
        p: probability threshold; the crop is applied when a fresh
            ``torch.rand(1)`` draw is strictly below it.
    """

    def __init__(self, size, p):
        self.p = p
        self.size = size

    def __call__(self, img):
        """Return a random ``size x size`` crop of ``img``, or ``img`` unchanged."""
        should_crop = bool(torch.rand(1) < self.p)
        if not should_crop:
            return img
        cropper = transforms.RandomCrop(size=(self.size, self.size))
        return cropper(img)


def normalize(img):
    """Convert an image to a channel-first tensor rescaled by ``x * 2 - 1``.

    If the last axis has length 3 the input is treated as ``h w c`` and
    rearranged to ``c h w``; otherwise the layout is left as given. The affine
    rescale maps values in [0, 1] onto [-1, 1].

    Returns:
        A ``torch.Tensor`` built from ``img`` with the rescale applied.
    """
    is_channel_last = img.shape[-1] == 3
    if is_channel_last:
        img = rearrange(img, 'h w c -> c h w')
    as_tensor = torch.tensor(img)
    return as_tensor * 2.0 - 1.0

def channel_last(img):
    """Return ``img`` in ``h w c`` layout.

    An input whose last axis already has length 3 is returned as-is; anything
    else is treated as ``c h w`` and rearranged.
    """
    already_channel_last = img.shape[-1] == 3
    return img if already_channel_last else rearrange(img, 'c h w -> h w c')

def fmri_transform(x, sparse_rate=0.2):
    """Randomly zero out a fraction of voxels as an fMRI augmentation.

    Args:
        x: array of voxel values, either ``(num_voxels,)`` or
            ``(1, num_voxels)``; voxels are indexed along the last axis.
        sparse_rate: fraction of voxels to set to zero. The count is
            ``int(num_voxels * sparse_rate)``, so it rounds down.

    Returns:
        A new ``torch.FloatTensor`` with the same shape as ``x``. ``x`` itself
        is not modified (the masking is done on a deep copy).
    """
    x_aug = copy.deepcopy(x)
    # Sample along the voxel (last) axis. Using x.shape[0] here gave 1 for the
    # (1, num_voxels) layout, so int(1 * sparse_rate) was 0 and nothing was
    # ever masked. For 1-D input the last axis is axis 0, so behaviour is
    # unchanged there.
    num_voxels = x.shape[-1]
    idx = np.random.choice(num_voxels, int(num_voxels * sparse_rate), replace=False)
    x_aug[..., idx] = 0
    return torch.FloatTensor(x_aug)



# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Start from the default config, then point it at a fine-tuned checkpoint.
config = Config_Generative_Model()
config.checkpoint_path = './results/generation/07-09-2023-01-50-29/checkpoint_best.pth'
if config.checkpoint_path is not None:
    # The checkpoint carries the config it was saved with under the 'config' key;
    # that stored config replaces the default one created above.
    model_meta = torch.load(config.checkpoint_path, map_location='cpu')
    ckp = config.checkpoint_path
    config = model_meta['config']
    config.checkpoint_path = ckp  # overwrite the stored path with the path actually loaded from
    print('Resuming from checkpoint: {}'.format(config.checkpoint_path))

config.logger = None # FIXME: unclear why the logger is cleared; it is passed to fLDM as logger= below
# Metadata object loaded from config.pretrain_mbm_path; handed to fLDM as its first argument.
pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu')
# Number of pixels removed by the random crop: crop_ratio scaled by the image size.
crop_pix = int(config.crop_ratio*config.img_size)

# Training images: rescale, random crop half of the time, resize to 256x256, back to channel-last.
img_transform_train = transforms.Compose([
        normalize,
        random_crop(config.img_size-crop_pix, p=0.5),
        transforms.Resize((256, 256)), 
        channel_last
    ])
# Test images: same pipeline without the random crop.
img_transform_test = transforms.Compose([
        normalize, transforms.Resize((256, 256)), 
        channel_last
    ])

# Build the train/test datasets for the configured dataset name; any other name is rejected.
if config.dataset == 'GOD':
    fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, 
            fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], 
            subjects=config.kam_subs)
    num_voxels = fmri_latents_dataset_train.num_voxels
elif config.dataset == 'BOLD5000':
    fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, 
            fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], 
            subjects=config.bold5000_subs)
    num_voxels = fmri_latents_dataset_train.num_voxels
else:
    raise NotImplementedError


# Instantiate the original fLDM; num_voxels comes from the training dataset built above.
generative_model = fLDM(pretrain_mbm_metafile, num_voxels,
                device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger, 
                ddim_steps=config.ddim_steps, global_pool=config.global_pool, use_time_cond=config.use_time_cond)


def count_trainable_parameters(model):
    """Print trainable / non-trainable parameter totals for each direct child of ``model``.

    Totals are the sum of ``numel()`` over each child's parameters, split by
    ``requires_grad`` and reported in units of 1024*1024 elements (labelled "M").
    Parameters registered directly on ``model`` (outside any child) are not
    listed, because only ``named_children()`` is walked.
    """
    unit = 1024 * 1024
    divider = '-' * 40
    for child_name, child in model.named_children():
        # One pass over the parameters, bucketed by their requires_grad flag.
        totals = {True: 0, False: 0}
        for param in child.parameters():
            totals[bool(param.requires_grad)] += param.numel()
        print(f"Module: {child_name}")
        print(f"  Trainable parameters: {totals[True]/unit:.1f}M")
        print(f"  Non-trainable parameters: {totals[False]/unit:.1f}M")
        print(divider)

# count_trainable_parameters(generative_model.model)
# Work on a deep copy so generative_model.model keeps its originally loaded weights.
stageC_model = copy.deepcopy(generative_model.model)
if config.checkpoint_path is not None:  # load the state dict from our own fine-tuned checkpoint
        # NOTE(review): this appears to re-read the same checkpoint file that was
        # already loaded into model_meta when config was resolved — confirm whether
        # the second torch.load is needed.
        model_meta = torch.load(config.checkpoint_path, map_location='cpu')
        stageC_model.load_state_dict(model_meta['model_state_dict'])
        print('model resumed')
count_trainable_parameters(stageC_model)


冻结first_stage_model后，查看参数

In [None]:
stageC_model.freeze_first_stage()
count_trainable_parameters(stageC_model)

新建一个自定义的LDM

统计不同Norm类的数量

In [None]:
def count_norm(model):
    """Print a dict mapping each Norm class name to how many instances ``model`` contains.

    Every module yielded by ``model.named_modules()`` is inspected; a module
    counts as a Norm when the substring ``'Norm'`` appears in its class name.
    """
    tally = {}
    for _, submodule in model.named_modules():
        cls_name = submodule.__class__.__name__
        if 'Norm' not in cls_name:
            continue
        tally.setdefault(cls_name, 0)
        tally[cls_name] += 1
    print(tally)

def count_norm_i(model):
    """Print a dict of Norm class name -> instance count for ``model``.

    A module is counted when ``'Norm'`` occurs in its class name.
    NOTE(review): this currently behaves exactly like ``count_norm``; confirm
    whether a distinct variant was intended.
    """
    norm_names = [
        submodule.__class__.__name__
        for _, submodule in model.named_modules()
        if 'Norm' in submodule.__class__.__name__
    ]
    # fromkeys keeps first-seen order, matching insertion order of incremental counting.
    norm_counts = dict.fromkeys(norm_names, 0)
    for cls_name in norm_names:
        norm_counts[cls_name] += 1
    print(norm_counts)
count_norm(stageC_model)

按照树的结构打印各个module中Norm类的数量，作为重写各module的对照

In [None]:
from collections import defaultdict

class Node:
    """One module in a Norm-count tree, holding its children keyed by instance name.

    Fields:
        class_name: class name of the module this node represents.
        instance_name: name under which the module is registered in its parent.
        children: dict mapping a child's instance name to its ``Node``.
        norm_count: Norm tally for this node; starts at 0 and is maintained by
            whoever builds the tree.
    """

    def __init__(self, class_name, instance_name):
        self.class_name = class_name
        self.instance_name = instance_name
        # Plain dict on purpose: a defaultdict(Node) factory calls Node() with no
        # arguments, which cannot succeed, so a missing key raised TypeError
        # instead of a meaningful lookup error.
        self.children = {}
        self.norm_count = 0

    def __getitem__(self, key):
        """Return the child registered under ``key``; raises ``KeyError`` if there is none."""
        return self.children[key]

    def display(self, indent=0):
        """Print the subtree, two spaces of indent per level.

        Only nodes with ``norm_count > 0`` produce a line, but children are
        always visited.
        """
        prefix = '  ' * indent
        if self.norm_count > 0:
            print(f"{prefix}({self.instance_name}){self.class_name}: {self.norm_count}")
        for child in self.children.values():
            child.display(indent + 1)

def build_tree(module, path=None):
    """Recursively mirror ``module``'s children as a tree of ``Node`` objects with Norm counts.

    Args:
        module: object exposing ``named_children()`` (e.g. a ``torch.nn.Module``).
        path: instance names from the root down to ``module``; the last entry
            becomes this node's ``instance_name``. When omitted or empty the
            node is named ``'root'`` (previously this case raised IndexError
            on ``path[-1]``).

    Returns:
        The ``Node`` for ``module``. On return its ``norm_count`` equals the
        number of strict descendants whose class name contains ``'Norm'``; a
        parent call adds 1 more if ``module`` itself is such a class.
    """
    path = list(path) if path else ['root']
    node = Node(module.__class__.__name__, path[-1])  # starts with no children and norm_count == 0
    for child_instance_name, submodule in module.named_children():
        # Extend the path with the child's attribute name and recurse.
        child_node = build_tree(submodule, path + [child_instance_name])
        # Key children by instance name rather than class name: sibling modules
        # of the same class would otherwise overwrite one another. The class
        # name is still shown by Node.display.
        node.children[child_node.instance_name] = child_node

        if 'Norm' in submodule.__class__.__name__:
            child_node.norm_count += 1

        node.norm_count += child_node.norm_count

    return node

root = build_tree(generative_model.model,['ldm_model'])
root.display()


---
修改LatentDiffusion，包括：  
**属性**：  
- self.model
- self.first_stage_model
- self.cond_stage_model


**方法**： 
- training_step (DDPM,直接在LatentDiffusion中重写即可)
    - shared_step (LatentDiffusion)
        - get_input (LatentDiffusion)
        - forward (LatentDiffusion)

In [None]:
from pdnorm_model import PDfLDM
from omegaconf import OmegaConf

# NOTE(review): this rebinds `config`, which earlier cells use for the
# checkpoint/generative-model config (e.g. config.checkpoint_path). Re-running
# those cells after this one will see the OmegaConf object instead — consider a
# distinct name.
config = OmegaConf.load('stageC_config.yaml')
# Construct the PDNorm variant of the model from the `model.params` section of the YAML file.
pd_model = PDfLDM(**config.model.params)

修改DiffusionWrapper及其子模块，一个instance共包含了109个Norm (instance)  
自下而上重写  

In [None]:
count_trainable_parameters(pd_model)

In [None]:
count_norm(pd_model)

In [None]:
print(pd_model.cond_stage_model.mae.pos_embed)

In [None]:
m,u = pd_model.load_state_dict(stageC_model.state_dict(),strict=False)
print('missing keys:',m)
print('-'*40)
print('unexpected keys:',u)

In [None]:
for missing_key in m:
    if 'mlp_' not in missing_key:
        print(missing_key)

In [None]:
pd_model2 = PDfLDM(**config.model.params)
mm,uu = pd_model2.load_state_dict(model_meta['model_state_dict'],strict=False)
print('missing keys:',mm)
print('-'*40)
print('unexpected keys:',uu)

In [None]:
count_trainable_parameters(pd_model2)

In [None]:
for missing_key in mm:
    if 'mlp_' not in missing_key:
        print(missing_key)

In [None]:
from pdnorm_model import PDNorm
def findout_norm(model):
    """Walk ``model`` depth-first and print every ``PDNorm`` child with its local name.

    Recursion stops at a ``PDNorm`` instance: its own children are not visited.
    The printed name is the attribute name within the immediate parent, not a
    dotted path.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, PDNorm):
            findout_norm(child)
            continue
        print(child_name, child)
findout_norm(pd_model)

In [None]:
list(pd_model.cond_stage_model.mae.blocks[22].norm2.named_parameters())

In [134]:
a = ['a','b','c']
b = [1,2,3]
c = {k:v for k,v in zip(a,b)}
print(c)

{'a': 1, 'b': 2, 'c': 3}


In [137]:
d = {f'val/{k}':v for k,v in c.items()}
print(d)

{'val/a': 1, 'val/b': 2, 'val/c': 3}


In [2]:
import torch
checkpoint_path = '/data/xiaozhaoliu/stageC1/mind-vis/bjcgbjql/checkpoints/epoch=294-step=23600.ckpt'
model_meta = torch.load(checkpoint_path, map_location='cpu')
print('👀')

👀


In [2]:
import time
start_time_generating = time.time()
time.sleep(61)
end_time_generating = time.time()
print(f"Execution Time: {(end_time_generating - start_time_generating)/60:.2f} mins")

Execution Time: 1.02 mins


In [8]:
grouped_fmri = [i for i in range(0, 50, 3)]
print(grouped_fmri)

[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48]


In [4]:
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure
preds = torch.rand([3, 3, 256, 256])
target = preds * 0.75
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
ssim(preds, target)

tensor(0.9219)

In [3]:
a = preds[1:,:,:,:]
type(a)

torch.Tensor

In [7]:
import torch
x = torch.rand([3, 3])
print(x)
y = (x*255).to(torch.uint8)
print(y)
z = y/255.
print(z.dtype)

tensor([[0.7572, 0.0961, 0.7131],
        [0.4517, 0.5330, 0.5896],
        [0.5377, 0.6059, 0.1126]])
tensor([[193,  24, 181],
        [115, 135, 150],
        [137, 154,  28]], dtype=torch.uint8)
torch.float32


In [9]:
print('\U0001F4C8')

📈


In [1]:
import torch
from einops import rearrange, repeat
x = torch.rand([2, 3]).to(torch.device('cuda:0')).detach()
print(x.device)
y = rearrange(x, 'h w -> w h').cpu().numpy()
print(type(y))
print(x.device)

cuda:0
<class 'numpy.ndarray'>
cuda:0


In [11]:
grouped_images = torch.randint(0,255,(5,6,3,64,64)).to(torch.uint8).to(torch.device('cuda:0'))
print(grouped_images.dtype)
print(grouped_images.device)

all_imgs = grouped_images/255.
print(all_imgs.dtype)
print(grouped_images.dtype)


torch.uint8
cuda:0
torch.float32
torch.uint8


In [6]:
from torchvision.utils import make_grid
from PIL import Image
# Merge the (group, image-in-group) axes into one batch axis so make_grid can tile the images.
images = rearrange(grouped_images, 'b n c h w -> (b n) c h w')
# nrow is the number of images per grid row. Use the group size n (axis 1) so
# each row shows one group; axis 2 is the channel axis in the 'b n c h w' layout.
grid = make_grid(images, nrow=grouped_images.shape[1])
# PIL expects height x width x channels.
grid_rgb = rearrange(grid, 'c h w -> h w c')
grid_PIL= Image.fromarray(grid_rgb.cpu().numpy())

In [11]:
ep=1
print(f'recon_epoch_{ep:04}')

recon_epoch_0001


In [8]:
from pytz import timezone
from datetime import datetime
# Ask for the current time directly in the target zone. datetime.utcnow() returns
# a naive datetime and astimezone() interprets naive values as system-local time,
# so the utcnow().astimezone(...) form is only correct on machines running in UTC.
print(datetime.now(timezone('Asia/Shanghai')).strftime("%y-%m-%d-%H-%M-%S"))
print(datetime.now(timezone('Asia/Shanghai')).strftime("%y%m%d-%H%M%S"))

23-11-13-11-36-26
231113-113626


In [13]:
# Timestamped run directory in Shanghai time. datetime.now(tz) yields an aware
# datetime in that zone; utcnow().astimezone(tz) would treat the naive UTC value
# as system-local time and shift it on any machine not running in UTC.
run_dir = datetime.now(timezone('Asia/Shanghai')).strftime("%y-%m-%d-%H-%M-%S")
print(type(run_dir))
print(run_dir)
ckpt_save_dir = run_dir + '/checkpoint'
print(ckpt_save_dir)

<class 'str'>
23-11-13-11-37-24
23-11-13-11-37-24/checkpoint


In [2]:
import time
from rich.progress import track

for i in track(range(20), description="Processing..."):
    time.sleep(0.1)  # Simulate work being done

Output()

In [4]:
print(f'[bold][red]Done!')

[bold][red]Done!


In [5]:
from rich.console import Console
from rich.table import Table

table = Table(title="Todo List")

table.add_column("S. No.", style="cyan", no_wrap=True)
table.add_column("Task", style="magenta")
table.add_column("Status", justify="right", style="green")

table.add_row("1", "Buy Milk", "✅")
table.add_row("2", "Buy Bread", "✅")
table.add_row("3", "Buy Jam", "❌")

console = Console()
console.print(table)

In [2]:
# initializing lists 
name = [ "Manjeet", "Nikhil", "Shambhavi", "Astha" ] 
roll_no = [ 4, 1, 3, 2 ] 
marks = [ 40, 50, 60, 70 ] 
  
# using zip() to map values 
mapped = zip(name, roll_no, marks) 
  
# converting values to print as set 
mapped = set(mapped) 
  
# printing resultant values  
print ("The zipped result is : ",end="") 
print (mapped) 

The zipped result is : {('Nikhil', 1, 50), ('Manjeet', 4, 40), ('Shambhavi', 3, 60), ('Astha', 2, 70)}


In [16]:
from einops import rearrange,repeat
import torch

a = torch.arange(0,5)
print('a:\n',a)
aa = repeat(a, 'b -> n b', n=3)
print('aa\n:',aa)
b = repeat(a, 'b -> (n b)', n=3)
print('b:\n',b)

c = b+10
print('c:\n',c)
d = rearrange(c,'(n b) -> n b',n=3,b=5)
print('d:\n',d)

a:
 tensor([0, 1, 2, 3, 4])
aa
: tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])
b:
 tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
c:
 tensor([10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14])
d:
 tensor([[10, 11, 12, 13, 14],
        [10, 11, 12, 13, 14],
        [10, 11, 12, 13, 14]])
