# Load the teacher and student models

In [2]:
import os
import random

import torch
from celeba_u import make_model, make_dataset
from moving_average import init_ema_model
from train_utils import *
from celeba_dataset import CelebaWrapper

class WeightedTrainingDataset(torch.utils.data.Dataset):
    """Fixed-length dataset that returns randomly drawn, weighted samples.

    Every access ignores the requested index. It draws a uniformly random
    position in the wrapped dataset (with replacement) and multiplies element
    ``[0]`` of that item by the matching entry of ``weight``. It returns the
    product paired with a constant label ``0``.

    Args:
        dataset: indexable dataset whose items are indexable; item ``[0]`` is
            the data to be weighted.
        L: length reported by ``__len__``, i.e. the number of draws per epoch.
        weight: indexable per-sample weights aligned with ``dataset``
            positions (presumably a 1-D tensor of length ``len(dataset)`` --
            TODO confirm against callers).
    """

    def __init__(self, dataset, L, weight):
        self.dataset, self.L, self.weight = dataset, L, weight

    def __getitem__(self, item):
        # `item` is intentionally unused: sampling is random with replacement,
        # so DataLoader shuffling does not influence which samples appear.
        last = len(self.dataset) - 1
        pick = random.randint(0, last)
        scaled = self.weight[pick] * self.dataset[pick][0]
        return scaled, 0

    def __len__(self):
        return self.L

device = torch.device("cuda")

train_set = make_dataset()
# Size of the underlying image collection held by the wrapper.
train_len = len(train_set.dataset)
# One weight per training image; all start at 1, i.e. unweighted.
weight = torch.ones(train_len)
train_dataset = WeightedTrainingDataset(train_set, L=train_len, weight=weight)
# Notebook display: number of draws per epoch.
len(train_dataset)

29999

In [4]:
# Teacher: the network from make_model() restored from the original
# (undistilled) CelebA checkpoint.
teacher_ema = make_model().to(device)
image_size = teacher_ema.image_size
# map_location puts every loaded tensor on `device`, whatever device the
# checkpoint was saved from (e.g. a GPU index that does not exist here).
ckpt = torch.load('./checkpoints/celeba/original/checkpoint.pt', map_location=device)  # teacher checkpoint
teacher_ema.load_state_dict(ckpt["G"])
# Diffusion settings stored alongside the weights.
n_timesteps = ckpt["n_timesteps"]
time_scale = ckpt["time_scale"]
del ckpt  # drop the checkpoint dict once its contents have been copied
print(f"Num timesteps: {n_timesteps}, time scale: {time_scale}.")

Num timesteps: 1024, time scale: 1.


In [8]:
from v_diffusion import *

scheduler = StrategyConstantLR()  # constant learning-rate strategy
distillation_model = DiffusionDistillation(scheduler)


def make_diffusion(model, n_timestep, time_scale, device):
    """Wrap ``model`` in a ``GaussianDiffusion`` with a cosine beta schedule.

    Args:
        model: the denoising network.
        n_timestep: number of discrete steps in the beta schedule.
        time_scale: forwarded unchanged to ``GaussianDiffusion``.
        device: device the beta schedule tensor is moved to.

    Returns:
        A ``GaussianDiffusion`` that uses the "ddpm" sampler.
    """
    schedule = make_beta_schedule("cosine", cosine_s=8e-3, n_timestep=n_timestep)
    return GaussianDiffusion(
        model,
        schedule.to(device),
        time_scale=time_scale,
        sampler="ddpm",
    )


# The teacher uses the step count and time scale read from its checkpoint.
teacher_ema_diffusion = make_diffusion(teacher_ema, n_timesteps, time_scale, device)


In [9]:
def make_dataset_val():
    """Return the CelebA validation split stored under ./data/celeba_val/ (resolution=256)."""
    return CelebaWrapper(dataset_dir="./data/celeba_val/", resolution=256)


val_set = make_dataset_val()
val_len = len(val_set.dataset)
# Wrapper with nominal length equal to the validation set size; sampling
# behaviour is defined by InfinityDataset in train_utils.
val_dataset = InfinityDataset(val_set, val_len)

In [10]:
from torch.utils.tensorboard import SummaryWriter

# Directory holding the student's starting checkpoint (also used for logging below).
checkpoints_dir = os.path.join("checkpoints", "celeba", "base_0")

# The student also uses the UNet backbone; the student and its EMA copy
# start from the same base_0 weights.
student = make_model().to(device)
student_ema = make_model().to(device)
# map_location keeps the load on `device` regardless of where it was saved.
ckpt = torch.load(os.path.join(checkpoints_dir, "checkpoint.pt"), map_location=device)
student.load_state_dict(ckpt["G"])
student_ema.load_state_dict(ckpt["G"])
del ckpt
# Distillation draws batches from the weighted training set.
distill_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

# NOTE(review): this is the same directory the student weights were loaded
# from; if the iteration callback saves checkpoints here it may overwrite
# the starting checkpoint -- confirm against make_iter_callback.
tensorboard = SummaryWriter(os.path.join(checkpoints_dir, "tensorboard"))

In [11]:
# Instantiate the student diffusions: half the teacher's timesteps, with the
# teacher's time scale doubled.
student_diffusion = make_diffusion(
    model=student,
    n_timestep=teacher_ema_diffusion.num_timesteps // 2,
    time_scale=teacher_ema_diffusion.time_scale * 2,
    device=device,
)
student_ema_diffusion = make_diffusion(
    model=student_ema,
    n_timestep=teacher_ema_diffusion.num_timesteps // 2,
    time_scale=teacher_ema_diffusion.time_scale * 2,
    device=device,
)
# The trailing positional values (15, 30, False) are passed through
# unchanged; their meaning is defined by make_iter_callback in train_utils.
on_iter = make_iter_callback(
    student_ema_diffusion, device, checkpoints_dir, image_size, tensorboard, 15, 30, False
)

In [None]:
# Launch DiffusionDistillation.reweight_val_student (defined in v_diffusion):
# teacher diffusion, student diffusion and student EMA model, learning rate 0.3 * 5e-5.
distillation_model.reweight_val_student(
    distill_train_loader,
    teacher_ema_diffusion,
    student_diffusion,
    student_ema,
    student_lr=0.3 * 5e-5,
    device=device,
    make_extra_args=make_condition,
    on_iter=on_iter,
)


# Prepare the validation dataset

### The validation data is stored in `./data/celeba_val/`

In [3]:


def make_dataset_val():
    """Return the CelebA validation split stored under ./data/celeba_val/ (resolution=256)."""
    return CelebaWrapper(dataset_dir="./data/celeba_val/", resolution=256)


val_set = make_dataset_val()
val_len = len(val_set.dataset)
weight = torch.ones(val_len)  # one weight per validation image, all 1
# InfinityDataset is called elsewhere in this notebook as
# InfinityDataset(val_set, val_len). The (dataset, L, weight) signature
# belongs to WeightedTrainingDataset, defined in the first cell.
val_dataset = WeightedTrainingDataset(val_set, val_len, weight)

In [4]:
from torch.utils.tensorboard import SummaryWriter

checkpoints_dir = os.path.join("checkpoints", "celeba", "base_0")
# The student network from make_model() exposes image_size (the teacher
# built by the same factory does too); no `unet_model` exists in this notebook.
image_size = student.image_size
tensorboard = SummaryWriter(os.path.join(checkpoints_dir, "tensorboard"))

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=True)
# Log through the student's EMA diffusion, created as `student_ema_diffusion`.
on_iter = make_iter_callback(student_ema_diffusion, device, checkpoints_dir, image_size, tensorboard, 15, 30, False)

In [34]:
# Plain diffusion trainer driven by a constant learning-rate strategy.
scheduler = StrategyConstantLR()
diffusion_train = DiffusionTrain(scheduler)



In [None]:
# Train the student on the validation loader. The EMA model passed here is
# student_ema, the model wrapped by the diffusion used in `on_iter`
# (`unet_model_ema` is never defined in this notebook).
diffusion_train.train(val_loader, student_diffusion, student_ema, model_lr=5e-5, device=device, make_extra_args=make_condition, on_iter=on_iter)

# Test of the reweighting framework

In [12]:
# Toy test of the reweighting idea.
import torch
import torch.nn as nn

# Dataset-loading stand-in: one weight per sample, 8 samples in total.
weight = [1] * 8  # load weight
sample, label = torch.randn(4, 10), torch.randn(4, 5)

# Column vector of shape (8, 1) so each row can scale one sample by broadcasting.
weight = torch.tensor(weight, dtype=torch.float)[:, None]
print(weight)

# Turn the weights into a trainable parameter so gradients reach them.
weight = nn.Parameter(weight)
weight

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])


Parameter containing:
tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], requires_grad=True)

In [13]:
# Load the model: a stand-in network for the demo.
class DemoModel(nn.Module):
    """Toy model consisting of a single ``nn.Linear(10, 5)`` layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 5)

    def forward(self, x):
        projected = self.backbone(x)
        return projected


model = DemoModel()
# Freeze the model's parameters (disable gradients); requires_grad_ returns
# the module, which the notebook displays.
model.requires_grad_(False)

DemoModel(
  (backbone): Linear(in_features=10, out_features=5, bias=True)
)

In [14]:
# Forward pass: fresh validation data, each row scaled by weight rows 4..7.
val_sample, val_label = torch.randn(4, 10), torch.randn(4, 5)
out = model(torch.mul(val_sample, weight[4:8]))
out

tensor([[ 0.3503, -0.8262,  0.4586,  0.3807, -0.5822],
        [ 0.0338,  0.5614,  1.1887, -0.1049, -0.0895],
        [-0.4504,  0.5381,  0.8031, -0.7693,  0.6945],
        [ 0.5707, -0.0521,  0.2479, -0.2745, -0.1978]],
       grad_fn=<AddmmBackward0>)

In [15]:
# Backward pass: summed squared error of the validation predictions against
# the validation targets. `val_label` was drawn together with `val_sample`;
# the training `label` belongs to `sample` and does not describe these inputs.
loss = torch.sum(torch.square(out - val_label))
loss.backward()

In [16]:
# Inspect the gradient of the loss with respect to each sample weight; rows
# not used in the forward pass stay zero.
weight.grad

tensor([[ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 3.9352],
        [ 3.7999],
        [-0.1606],
        [ 0.5773]])

In [42]:
# Update the weights with one manual gradient-descent step; detach() yields a
# plain tensor outside the autograd graph.
lr = 0.001
new_weight = weight.sub(lr * weight.grad).detach()
new_weight

tensor([[1.0025],
        [0.9975],
        [0.9961],
        [0.9981]])