In [1]:
import os, json, wandb, torch
import numpy as np
import torch.nn.functional as F
from monai import transforms, data
from monai.data import DataLoader, DistributedSampler
from monai.utils import set_determinism
from tqdm import tqdm
import torch.distributed as dist
import torch
from torch import nn
from diffusers import AutoencoderKLWan, WanTransformer3DModel
from peft import LoraConfig, inject_adapter_in_model
from flow_match import FlowMatchScheduler

In [2]:
train_json_path = './json/train.json'
with open(train_json_path) as f:
    train_files = json.load(f)
val_json_path = './json/val.json'
with open(val_json_path) as f:
    val_files = json.load(f)
train_batchsize  = 1

transforms_1mm = transforms.Compose(
    [transforms.Spacingd(keys=["image"], pixdim=(1, 1, 1), mode=("bilinear")),
    transforms.Spacingd(keys=["brainmask"], pixdim=(1, 1, 1), mode=("nearest")),
    transforms.SpatialPadd(keys=["image","brainmask"], spatial_size=(160, 160, 128)),
    transforms.CropForegroundd(keys=["image"], source_key="brainmask",allow_smaller=False),
    transforms.DeleteItemsd(keys=["brainmask"]),
    transforms.RandSpatialCropd(keys=["image"], roi_size=(80, 80, 64),max_roi_size = (100, 100, 80), random_size=True),
    transforms.Resized(keys=["image"], spatial_size=(80, 80, 64), size_mode = 'all', mode='bilinear'),
    ]
)
transforms_2mm = transforms.Compose(
    [transforms.Spacingd(keys=["image"], pixdim=(2, 2, 2), mode=("bilinear")),
    transforms.Spacingd(keys=["brainmask"], pixdim=(2, 2, 2), mode=("nearest")),
    transforms.SpatialPadd(keys=["image","brainmask"], spatial_size=(80, 80, 64)),
    transforms.CropForegroundd(keys=["image"], source_key="brainmask",allow_smaller=False),
    transforms.DeleteItemsd(keys=["brainmask"]),
    transforms.Resized(keys=["image"], spatial_size=80, size_mode = 'longest', mode='bilinear'),
    transforms.CenterSpatialCropd(keys=["image"], roi_size=(80, 80, 64)),
    transforms.SpatialPadd(keys=["image"], spatial_size=(80, 80, 64)),
    ]
)
train_transforms = transforms.Compose(
    [
        transforms.CopyItemsd(keys=["image"], names=["path"]),
        transforms.LoadImaged(keys=["image","brainmask"]),
        transforms.EnsureChannelFirstd(keys=["image","brainmask"]),
        transforms.EnsureTyped(keys=["image","brainmask"]),
        transforms.Orientationd(keys=["image","brainmask"], axcodes="RAS"),
        transforms.RandAffined(
            keys=["image","brainmask"],
            rotate_range=(-np.pi / 36, np.pi / 36),
            translate_range=(-1, 1),
            scale_range=(-0.05, 0.05),
            padding_mode="zeros",
            prob=0.5,
        ),
        transforms.CropForegroundd(
            keys=["image", "brainmask"],
            source_key="brainmask",
            allow_smaller=False,
        ),
        transforms.ResizeWithPadOrCropd(keys=["image","brainmask"], spatial_size=(192, 192, 141)),
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0.5, upper=99.5, b_min=0, b_max=1,clip=True ),
    ]
)
train_ds = data.Dataset(data=train_files, transform=train_transforms)
# sampler_train = DistributedSampler(train_ds, num_replicas=4, rank=rank)
train_loader = DataLoader(train_ds, batch_size=train_batchsize, shuffle=False, num_workers=8, persistent_workers=True, drop_last=True, sampler=None)

val_transforms = transforms.Compose(
    [
        transforms.CopyItemsd(keys=["image"], names=["path"]),
        transforms.LoadImaged(keys=["image","brainmask"]),
        transforms.EnsureChannelFirstd(keys=["image","brainmask"]),
        transforms.EnsureTyped(keys=["image","brainmask"]),
        transforms.Orientationd(keys=["image","brainmask"], axcodes="RAS"),
        transforms.CropForegroundd(
            keys=["image", "brainmask"],
            source_key="brainmask",
            allow_smaller=False,
        ),
        transforms.ResizeWithPadOrCropd(keys=["image","brainmask"], spatial_size=(192, 192, 141)),
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0.5, upper=99.5, b_min=0, b_max=1),
    ]
)
val_ds = data.Dataset(data=val_files, transform=train_transforms)
# sampler_val = DistributedSampler(val_ds, num_replicas=4, rank=rank)
val_loader = DataLoader(val_ds, batch_size=train_batchsize, shuffle=False, num_workers=8, persistent_workers=True, drop_last=True, sampler=None)

In [3]:

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir="/working/cache/huggingface/hub")
vae = vae.to("cuda")
vae.eval()

AutoencoderKLWan(
  (encoder): WanEncoder3d(
    (nonlinearity): SiLU()
    (conv_in): WanCausalConv3d(3, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    (down_blocks): ModuleList(
      (0-1): 2 x WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): Identity()
      )
      (2): WanResample(
        (resample): Sequential(
          (0): ZeroPad2d((0, 1, 0, 1))
          (1): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (3): WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(96, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
 

In [12]:
class UnifiedFlowNet(nn.Module):
    def __init__(self, latent_dim, input_channels = 4,  drop_prob=0.1, lora = True):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_channels = input_channels
        self.model = self.init_model(lora)
        self.drop_prob = drop_prob

    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
        # Add LoRA to UNet
        
        lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
            init_lora_weights = True
            
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(lora_config, model)
        for param in model.parameters():
            # Upcast LoRA parameters into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
        return model
        
    def init_model(self, lora):
        transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16, cache_dir="/working/cache/huggingface/hub")
        old_patch_embed = transformer.patch_embedding
        new_patch_embed = nn.Conv3d(
            in_channels=old_patch_embed.in_channels*8,               # 修改为新输入通道
            out_channels=old_patch_embed.out_channels,
            kernel_size=old_patch_embed.kernel_size,
            stride=old_patch_embed.stride,
            padding=old_patch_embed.padding
        )
        transformer.patch_embedding = new_patch_embed
        old_proj_out = transformer.proj_out
        new_proj_out = nn.Linear(
            in_features=old_proj_out.in_features,
            out_features=old_proj_out.out_features*4,            # 修改为新输出通道
            bias=True
        )
        transformer.proj_out = new_proj_out
        if lora:
            transformer = self.add_lora_to_model(transformer, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out.0,linear_1,linear_2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None)
        return transformer
    
    def forward(self, z_t, timestep, encoder_hidden_states=None, z_c=None):
        B, C, D, H, W = z_t.shape
        if encoder_hidden_states is None:
            encoder_hidden_states = torch.zeros([1,256,4096], device=z_t.device)
        if z_c is None:
            z_c_in = torch.zeros_like(z_t, device=z_t.device, dtype=z_t.dtype)
        else:
            # 以 drop_prob 随机丢弃
            mask = (torch.rand(B,self.input_channels, device=z_t.device, dtype=z_t.dtype) < self.drop_prob).float()
            mask = mask.unsqueeze(2)
            mask = mask.repeat(1,1,self.latent_dim)
            mask = mask.view(B,self.latent_dim*self.input_channels,1,1,1)
            z_c_keep = z_c
            z_c_zero = torch.zeros_like(z_c, device=z_t.device, dtype=z_t.dtype)
            z_c_in = z_c_keep * (1-mask) + z_c_zero * mask

        # 拼接输入
        inp = torch.cat([z_t, z_c_in], dim=1)
        v_pred = self.model(inp, timestep, encoder_hidden_states)
        return v_pred


In [13]:
scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(1000, training=True)

In [19]:
model = UnifiedFlowNet(latent_dim=16, input_channels=4, drop_prob=0.2, lora=True)
model = model.to("cuda")
model.train()

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

UnifiedFlowNet(
  (model): WanTransformer3DModel(
    (rope): WanRotaryPosEmbed()
    (patch_embedding): Conv3d(128, 1536, kernel_size=(1, 2, 2), stride=(1, 2, 2))
    (condition_embedder): WanTimeTextImageEmbedding(
      (timesteps_proj): Timesteps()
      (time_embedder): TimestepEmbedding(
        (linear_1): lora.Linear(
          (base_layer): Linear(in_features=256, out_features=1536, bias=True)
          (lora_dropout): ModuleDict(
            (default): Identity()
          )
          (lora_A): ModuleDict(
            (default): Linear(in_features=256, out_features=4, bias=False)
          )
          (lora_B): ModuleDict(
            (default): Linear(in_features=4, out_features=1536, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
        (act): SiLU()
        (linear_2): lora.Linear(
          (base_layer): Linear(in_features=1536, out_features=1536

In [20]:
for batch in train_loader:
    image = batch['image']
    mask = batch['brainmask']
    path = batch['path']
    image = image.permute(0, 1, 4, 2, 3)  # [B, C, H, W, D] -> [B, D, H, W, C]
    B, C, D, H, W  = image.shape
    image = image.view(B*C, D, H, W)  # [4, 1, D, H, W]
    image = image.unsqueeze(1)
    image = image.repeat(1, 3, 1, 1, 1)  # [4, 3, D, H, W]
    image = image.to(torch.bfloat16).to("cuda")
    with torch.no_grad():
        encode = vae.encode(image, return_dict=True)
    latent = encode.latent_dist.sample()
    latents = latent.view(B, 64, -1, int(H/8), int(W/8))
    noise = torch.randn_like(latents)
    timestep_id = torch.randint(0, scheduler.num_train_timesteps, (1,))
    timestep = scheduler.timesteps[timestep_id].to(dtype=latents.dtype, device=latents.device)
    noisy_latents = scheduler.add_noise(latents, noise, timestep)
    training_target = scheduler.training_target(latents, noise, timestep)

    noise_pred = model(noisy_latents, timestep, encoder_hidden_states=None, z_c=latents)
    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * scheduler.training_weight(timestep)
    print(loss.item())
    break

RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16