In [1]:
import torch
from Unet import *
from typing import *
from train_utils import *
import matplotlib as plt
from tqdm import tqdm
from dataset import QuickDrawDataset
from GRU import HistoryEncoder
import os

In [2]:
data_path = '/home/yujunwei/CS280-_Project_autoregressive_diffusuon/quickdraw_1k5_apple_cat.npz'
val_dataset = QuickDrawDataset(data_path, split='valid')
train_dataset = QuickDrawDataset(data_path, split='train')
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,  # 在这里设置批处理大小
    shuffle=True,
    num_workers=10,
    pin_memory=True if torch.cuda.is_available() else False
)
# val_loader = torch.utils.data.DataLoader(
#     val_dataset,
#     batch_size=8,  # 同上
#     shuffle=False,
#     num_workers=2,
#     pin_memory=True if torch.cuda.is_available() else False
# )

In [17]:
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
model = FlowMatching(UNet(1, 32, 32), GRU=HistoryEncoder(in_frames=12), num_ts=100)
checkpoint = torch.load("/home/yujunwei/CS280-_Project_autoregressive_diffusuon/checkpoints_0505_64_resolution_32_unet_1k/model_20.pth", map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()
label = torch.nn.functional.one_hot(torch.tensor(2), num_classes=10).float().unsqueeze(0).to(device)

model.unet.to(device)
model.GRU.to(device)

# x_p = torch.zeros(1, 12, 1, 256, 256).to(device)
# video = model.autoregressive_sample(c=label, autoregressive_steps=5, img_wh=[256, 256], x_p=None)

HistoryEncoder(
  (cnn): ConvEncoder(
    (backbone): Sequential(
      (0): Conv2d(12, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (2): SiLU()
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): GroupNorm(32, 128, eps=1e-05, affine=True)
      (5): SiLU()
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (7): GroupNorm(32, 256, eps=1e-05, affine=True)
      (8): SiLU()
      (9): Conv2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (10): GroupNorm(32, 128, eps=1e-05, affine=True)
      (11): SiLU()
    )
    (pool): AdaptiveAvgPool2d(output_size=1)
  )
  (proj): Linear(in_features=128, out_features=128, bias=True)
  (gru): GRU(128, 128, num_layers=2, batch_first=True)
)

In [4]:
def count_parameters_by_layer(model):
    """Count parameters for each layer in the model."""
    param_counts = {}
    total_params = 0
    
    # 遍历所有命名参数
    for name, param in model.named_parameters():
        if param.requires_grad:
            # 提取层名
            layer_name = name.split('.')[0] if '.' in name else name
            
            # 计算参数数量
            param_count = param.numel()
            
            # 更新层参数计数
            if layer_name in param_counts:
                param_counts[layer_name] += param_count
            else:
                param_counts[layer_name] = param_count
                
            total_params += param_count
    
    # 打印每一层的参数数量
    print("Parameters by layer:")
    for layer_name, count in param_counts.items():
        print(f"{layer_name}: {count:,} parameters")
    
    print(f"\nTotal trainable parameters: {total_params:,}")
    
    return param_counts

# 使用方法
count_parameters_by_layer(model.unet)

# by reducing the resolution to 64 x 64: curr_param = 123437132
# with hidden_dim = 32, we have 36713068 parameters
# by removing the after_gru MLP: 32510572

Parameters by layer:
conv_1: 12,992 parameters
down_1: 28,128 parameters
down_2: 93,120 parameters
down_3: 111,552 parameters
down_4: 370,560 parameters
after_GRU: 1,056,768 parameters
up_1: 894,848 parameters
up_2: 632,704 parameters
up_3: 224,192 parameters
up_4: 158,656 parameters
final_conv: 21,932 parameters
timeembed_1: 16,768 parameters
timeembed_2: 16,768 parameters
classembed_1: 17,920 parameters
classembed_2: 17,920 parameters

Total trainable parameters: 3,674,828


{'conv_1': 12992,
 'down_1': 28128,
 'down_2': 93120,
 'down_3': 111552,
 'down_4': 370560,
 'after_GRU': 1056768,
 'up_1': 894848,
 'up_2': 632704,
 'up_3': 224192,
 'up_4': 158656,
 'final_conv': 21932,
 'timeembed_1': 16768,
 'timeembed_2': 16768,
 'classembed_1': 17920,
 'classembed_2': 17920}

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
model_2 = FlowMatching(UNet(1, 64, 64), GRU=HistoryEncoder(in_frames=12))
checkpoint = torch.load("/home/yujunwei/CS280-_Project_autoregressive_diffusuon/checkpoints_0503_apple_update_gru/model_5.pth", map_location=device)
model_2.load_state_dict(checkpoint)
model_2 = model_2.to(device)
model_2.eval()
label = torch.nn.functional.one_hot(torch.tensor(2), num_classes=10).float().unsqueeze(0).to(device)

model_2.unet.to(device)
model_2.GRU.to(device)

# x_p = torch.zeros(1, 12, 1, 256, 256).to(device)
# video = model.autoregressive_sample(c=label, autoregressive_steps=5, img_wh=[256, 256], x_p=None)

HistoryEncoder(
  (cnn): ConvEncoder(
    (backbone): Sequential(
      (0): Conv2d(12, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (2): SiLU()
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): GroupNorm(32, 128, eps=1e-05, affine=True)
      (5): SiLU()
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (7): GroupNorm(32, 256, eps=1e-05, affine=True)
      (8): SiLU()
      (9): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (10): GroupNorm(32, 512, eps=1e-05, affine=True)
      (11): SiLU()
    )
    (pool): AdaptiveAvgPool2d(output_size=1)
  )
  (proj): Linear(in_features=512, out_features=512, bias=True)
  (gru): GRU(512, 512, num_layers=2, batch_first=True)
)

In [None]:
for k, v in model_2.GRU.named_aprameters():
    print(k)
    print(v)
    break

cnn.backbone.0.weight
Parameter containing:
tensor([[[[ 0.0447,  0.0104,  0.0721],
          [-0.0283, -0.0678, -0.0100],
          [-0.0279,  0.0455, -0.0571]],

         [[ 0.0681, -0.0854, -0.0304],
          [-0.0376, -0.0608, -0.0927],
          [-0.0465, -0.0049,  0.0798]],

         [[ 0.0043, -0.0599,  0.0896],
          [-0.0052,  0.0885,  0.0523],
          [ 0.0194,  0.0474, -0.0065]],

         ...,

         [[-0.0657,  0.0028,  0.0623],
          [ 0.0245,  0.0492, -0.0036],
          [-0.0580, -0.0280,  0.0738]],

         [[ 0.0411, -0.0069,  0.0381],
          [ 0.0864,  0.0450,  0.0551],
          [-0.0751,  0.0383, -0.0339]],

         [[-0.0180,  0.0423,  0.0873],
          [-0.0030,  0.0411, -0.0252],
          [ 0.0903, -0.0617,  0.0810]]],


        [[[-0.0446, -0.0735, -0.0365],
          [ 0.0509, -0.0259,  0.0773],
          [-0.0188,  0.0720,  0.0897]],

         [[ 0.0633, -0.0219,  0.0059],
          [-0.0816, -0.0466,  0.0381],
          [ 0.0571, -0.0208,

In [6]:
for k, v in model.GRU.named_parameters():
    print(k)
    print(v)
    break

cnn.backbone.0.weight
Parameter containing:
tensor([[[[ 0.0486,  0.0118,  0.0707],
          [-0.0283, -0.0692, -0.0120],
          [-0.0295,  0.0435, -0.0595]],

         [[ 0.0745, -0.0813, -0.0288],
          [-0.0341, -0.0593, -0.0925],
          [-0.0449, -0.0045,  0.0797]],

         [[ 0.0111, -0.0546,  0.0928],
          [-0.0002,  0.0918,  0.0546],
          [ 0.0231,  0.0503, -0.0050]],

         ...,

         [[-0.0634,  0.0024,  0.0605],
          [ 0.0268,  0.0479, -0.0075],
          [-0.0567, -0.0272,  0.0724]],

         [[ 0.0393, -0.0100,  0.0330],
          [ 0.0842,  0.0405,  0.0494],
          [-0.0776,  0.0354, -0.0375]],

         [[-0.0200,  0.0391,  0.0824],
          [-0.0059,  0.0363, -0.0311],
          [ 0.0870, -0.0655,  0.0763]]],


        [[[-0.0397, -0.0686, -0.0313],
          [ 0.0567, -0.0200,  0.0830],
          [-0.0130,  0.0776,  0.0951]],

         [[ 0.0681, -0.0171,  0.0111],
          [-0.0760, -0.0411,  0.0436],
          [ 0.0628, -0.0152,

In [None]:
out = (torch.sigmoid(logits) > 0.5).float()

In [18]:
cnt = 0

In [19]:
for video, labels in tqdm(train_loader):
    video = video.to(device)
    labels = labels.to(device)
    cnt += 1
    print(labels.shape)
    print(labels[0][0])
    if int(labels[0][1]) == 0:
        continue
    print(labels)
    video = model.autoregressive_sample(c=labels, autoregressive_steps=5, img_wh=[64, 64], x_p=video[:,:12])
    break
    # video = model.autoregressive_sample(c=labels, autoregressive_steps=5, img_wh=[256, 256], x_p=None)

  0%|          | 1/3000 [00:00<32:43,  1.53it/s]

torch.Size([1, 10])
tensor(1., device='cuda:0')
torch.Size([1, 10])
tensor(0., device='cuda:0')
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')
torch.Size([1, 12, 64, 64])


  0%|          | 1/3000 [00:03<3:11:19,  3.83s/it]


In [20]:
video.shape

torch.Size([1, 72, 1, 64, 64])

In [21]:
copy_video = video.clone()

In [22]:
for t in range(copy_video.shape[1]):  # 遍历每一帧
    # 对每帧应用softmax使像素互相竞争
    frame = copy_video[:, t].reshape(1, -1)  # 将形状变为 (1, 256*256)
    softmax_frame = torch.softmax(frame, dim=1)
    
    # 归一化处理，让值分布更均匀
    # 将最大值缩放到接近1.0，使二值化有意义
    min_val = softmax_frame.min()
    max_val = softmax_frame.max()
    if max_val > min_val:  # 避免除零错误
        normalized = (softmax_frame - min_val) / (max_val - min_val)
        copy_video[:, t] = normalized.reshape(copy_video[:, t].shape)
    else:
        copy_video[:, t] = softmax_frame.reshape(copy_video[:, t].shape)
copy_video = (copy_video > 0.5).float()

In [23]:
import imageio
import numpy as np

# 将视频数据转移到 CPU 并转换为 numpy 数组
video_np = copy_video[:,12:].cpu().detach().numpy()

# 假设视频形状为 [batch, autoregressive_steps, time_step, channels, height, width]
# 展平多个维度以获得总共 60 帧
flat_video = video_np.reshape(-1, *video_np.shape[-3:])  # [total_frames, channels, height, width]

# 如果帧数不足 60，可以循环播放以达到 60 帧
frames = []
total_frames = flat_video.shape[0]
repeats = max(1, int(np.ceil(60 / total_frames)))

for _ in range(repeats):
    for i in range(total_frames):
        if len(frames) >= 60:
            break
            
        # 获取当前帧并去掉通道维度（如果为1）
        frame = flat_video[i, 0] if flat_video.shape[1] == 1 else flat_video[i].transpose(1, 2, 0)
        
        # 缩放到 [0, 255] 范围并转换为 uint8
        frame = (frame * 255).astype(np.uint8)
        frames.append(frame)

# 创建 GIF
imageio.mimsave('animation_20_0505_400k_32_unet_cat_1k.gif', frames, fps=10)
print("GIF 已保存为 'animation_5_add_1st_chunk.gif'")

GIF 已保存为 'animation_5_add_1st_chunk.gif'
