In [1]:
# Environment sanity check: list the visible GPU(s), then print the active Python interpreter path
# only if it contains "DYY" (presumably the expected conda env name — confirm).
!nvidia-smi
!which python | grep DYY

Sat Sep 14 21:25:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   43C    P0              54W / 184W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# class Res(nn.Module):
#     def __init__(self, channel, kernel_size=3,x_size = 28) -> None:
#         super().__init__()
#         self.channel = channel
#         self.conv1=nn.Sequential(
#             nn.Conv2d(channel,channel,(kernel_size,kernel_size),padding=(kernel_size-1)//2),
#             nn.BatchNorm2d(channel),
#         )
#         self.conv2= nn.Sequential(
#             nn.ReLU(),
#             nn.Conv2d(channel,channel,(kernel_size,kernel_size),padding=(kernel_size-1)//2),
#         )
#         self.t_net = nn.Linear(256,self.channel*x_size*x_size)
#     def forward(self,x,t):
#         res = x.clone()
#         x = self.conv1(x)
#         x = x + self.t_net(t).reshape(x.shape)
#         x = self.conv2(x) + res
#         return x

class SinousEmbedding(nn.Module):
    """Sinusoidal embedding of timesteps.

    Maps a 1-D tensor of ``N`` timesteps to an ``(N, dim)`` float tensor whose first
    ``dim // 2`` columns are ``sin(t * w_k)`` and last ``dim // 2`` columns are ``cos(t * w_k)``,
    with frequencies ``w_k = (base ** (-2 / dim)) ** k`` for ``k = 1 .. dim // 2``.

    Args:
        dim: embedding width; must be even.
        base: frequency base (defaults to the previously hard-coded 100).
    """
    def __init__(self, dim, base=100.) -> None:
        super().__init__()
        assert dim % 2 == 0, 'SinousEmbedding requires an even dim, got {}'.format(dim)
        exponents = torch.arange(1, dim // 2 + 1, 1, dtype=torch.float)
        # Registered as a buffer (instead of a hard-coded `.cuda()` tensor) so `module.to(device)`
        # moves it together with the model and the class also works on CPU-only machines.
        # `persistent=False` keeps it out of the state_dict, so saved checkpoints keep their keys.
        self.register_buffer('angles', (base ** (-2 / dim)) ** exponents, persistent=False)

    def forward(self, x):
        """Embed timesteps ``x`` (1-D tensor, any numeric dtype) -> ``(len(x), dim)`` tensor."""
        # Outer product: angles[i, m] = x[i] * w_m.
        angles = torch.einsum('m,i->im', self.angles, x.float())
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)

# class Attention(nn.Module):
#     def __init__(self, channel,hidden_size=512) -> None:
#         super().__init__()
#         self.hidden_size = hidden_size
#         self.q_proj = nn.Linear(channel,hidden_size)
#         self.k_proj = nn.Linear(channel,hidden_size)
#         self.v_proj = nn.Linear(channel,hidden_size)
#         self.out_proj = nn.Linear(hidden_size,channel)
#     def forward(self,x):
#         res = x.clone()
#         batch,channel = x.shape[:2]
#         seq_len = x.shape[-1]
#         x = x.reshape(batch,channel,seq_len*seq_len).transpose(1,2)
#         v = self.v_proj(x)
#         q = self.q_proj(x)
#         k = self.k_proj(x)
#         att_sc = torch.einsum('bic,bjc->bij',q,k)*((self.hidden_size)**-0.5)
#         att_sc = torch.softmax(att_sc,dim=-1)
#         att_out = torch.einsum('bij,bjc->bic',att_sc,v)
#         ans = self.out_proj(att_out).transpose(1,2).reshape(batch,channel,seq_len,seq_len)
#         return ans+res


# class ResBlockWithAttention(nn.Module):
#     def __init__(self, in_channel, out_channel,x_size,with_attention=True,kernel_size=3) -> None:
#         super().__init__()
#         self.conv=nn.Sequential(
#             nn.Conv2d(in_channel,out_channel,(kernel_size,kernel_size),padding=(kernel_size-1)//2),
#             nn.BatchNorm2d(out_channel),
#             nn.ReLU(),
#         )
#         self.reses = nn.ModuleList(
#             [Res(out_channel,kernel_size,x_size) for _ in range(4)]
#         )
#         if with_attention:
#             self.attentions = nn.ModuleList(
#                 [Attention(out_channel) for _ in range(4)]
#             )
        
#     def forward(self,x,t):
#         x = self.conv(x)
#         for i,ly in enumerate(self.reses):
#             x = ly(x,t)
#             if hasattr(self,'attentions'):
#                 x = self.attentions[i](x)
#         return x
class F_x_t(nn.Module):
    """Convolution conditioned on a time embedding by channel concatenation.

    The output has ``out_channels`` channels at ``out_size x out_size``. The first
    ``out_channels - out_channels // 2`` channels are a convolution of ``x``. The remaining
    ``out_channels // 2`` channels are a linear projection of the time embedding ``t``, reshaped
    into feature maps. When ``out_channels // 2 == 0`` (e.g. ``out_channels=1``), only the conv is used.

    Args:
        in_channels: channels of ``x``.
        out_channels: total output channels.
        out_size: spatial size of ``x`` and of the output; the conv must preserve it.
        kernel_size: conv kernel size (odd, so that size-preserving padding exists).
        padding: conv padding. ``None`` (default) uses ``(kernel_size - 1) // 2``, which keeps the
            spatial size. For the default ``kernel_size=5`` this equals the previous default of 2.
        t_shape: width of the time embedding ``t``.
    """

    def __init__(self,in_channels,out_channels,out_size,kernel_size=5,padding=None,t_shape=64) -> None:
        super().__init__()
        if padding is None:
            # "Same" padding. A hard-coded 2 only preserves the size for kernel_size == 5, and any
            # size change breaks the reshape/concat in forward().
            padding = (kernel_size - 1) // 2
        self.t_channels = out_channels // 2
        self.conv_channels = out_channels - self.t_channels
        self.conv = nn.Conv2d(in_channels, self.conv_channels, kernel_size=kernel_size, padding=padding)
        self.out_size = out_size
        # Kept even when t_channels == 0 (a Linear with 0 outputs) so the module's state_dict keys stay stable.
        self.fc = nn.Linear(t_shape, self.t_channels*out_size*out_size)

    def forward(self, x, t):
        """Compute the output for ``x`` of shape (B, in_channels, out_size, out_size) and ``t`` of shape (B, t_shape).

        Returns a tensor of shape (B, out_channels, out_size, out_size).
        """
        features = self.conv(x)
        if self.t_channels == 0:
            return features
        t_maps = self.fc(t).reshape(-1, self.t_channels, self.out_size, self.out_size)
        return torch.cat([features, t_maps], dim=1)

class DDPM(nn.Module):
    """Small U-Net-style noise predictor eps_theta(x_t, t) for 28x28 single-channel images.

    ``forward(x, t)`` reshapes ``x`` to ``(batch, 1, 28, 28)``; ``t`` holds one integer timestep
    per sample.
    - The encoder (``up``) halves the resolution twice (28 -> 14 -> 7), and the decoder (``down``)
      doubles it back.
    - Each decoder ``F_x_t`` output is summed with the tensor that entered the mirrored encoder
      ``F_x_t`` (a U-Net skip). The last skip is the network input itself, so the final output
      is ``conv(...) + x_t``.
    - Returns ``(batch, 784)``.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.in_size = 28 * 28  # flattened image size; not read by forward()
        self.t_embedding = SinousEmbedding(dim=64)  # width must match F_x_t's default t_shape=64
        # Encoder (historically named `up`, although it reduces resolution).
        self.up = nn.ModuleList([
            F_x_t(in_channels=1, out_channels=32, out_size=28),   # 1x28x28 -> 32x28x28
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2)),                     # -> 32x14x14
            F_x_t(in_channels=32, out_channels=64, out_size=14),  # -> 64x14x14
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2)),                     # -> 64x7x7
        ])
        self.middle = nn.ModuleList([
            nn.Identity(),  # placeholder bottleneck
        ])
        # Decoder (historically named `down`, although it increases resolution).
        self.down = nn.ModuleList([
            nn.Upsample(scale_factor=2),                          # 64x7x7 -> 64x14x14
            F_x_t(in_channels=64, out_channels=32, out_size=14),  # -> 32x14x14, + skip (32x14x14)
            nn.ReLU(),
            nn.Upsample(scale_factor=2),                          # -> 32x28x28
            F_x_t(in_channels=32, out_channels=1, out_size=28),   # -> 1x28x28, + skip (the input x_t)
        ])

    def forward(self, x, t):
        x = x.reshape(-1, 1, 28, 28)
        batch = x.shape[0]
        t_emb = self.t_embedding(t)  # [batch, 64]
        skips = []
        for layer in self.up:
            if isinstance(layer, F_x_t):
                # No clone needed: nothing below modifies tensors in place, `x` is only rebound.
                skips.append(x)
                x = layer(x, t_emb)
            else:
                x = layer(x)
        for layer in self.middle:
            x = layer(x)
        for layer in self.down:
            if isinstance(layer, F_x_t):
                # Skips are consumed in reverse (LIFO) order, pairing mirrored encoder/decoder layers.
                x = layer(x, t_emb) + skips.pop()
            else:
                x = layer(x)
        return x.reshape(batch, -1)

# Train

In [4]:
import sys
import os

# Make the parent of the kernel's current working directory importable, so the project-local
# `utils` module (which provides the MNIST loaders) can be found. This assumes the notebook is
# started from a sub-directory of the project — TODO confirm. The guard stops re-runs of this
# cell from appending the same entry to sys.path again.
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
    print('appended',parent_dir)

import utils

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchvision.utils

device = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist = utils.MNIST(batch_size=512)
train_loader = mnist.train_dataloader
valid_loader = mnist.valid_dataloader

# Diffusion noise schedule: beta_t rises linearly from beta1 to betaT over T steps.
T=64
beta1=1e-2 # variance of lowest temperature
betaT=1 # variance of highest temperature
# Upper clip for beta_t (as in Nichol & Dhariwal, "Improved DDPM"). With betaT == 1 the last
# alpha_t = 1 - beta_t is 0, and the reverse step divides by sqrt(alpha_t). That is infinite,
# or NaN if float rounding makes alpha_t slightly negative.
BETA_MAX = 0.999
# Alternative geometric schedule:
# step = torch.log(torch.tensor(betaT/beta1))/(T-1)
# betas = beta1 * torch.exp(step*torch.arange(T,dtype=torch.float).to(device))
step = (betaT-beta1)/(T-1)  # spacing between consecutive betas
# linspace always yields exactly T values. The previous torch.arange(beta1, betaT+step, step)
# could yield T+1 values because of float rounding of the step.
betas = torch.linspace(beta1, betaT, T, device=device).clamp(max=BETA_MAX)
assert betas.shape == (T,)
alphas = 1-betas
alpha_bars = torch.cumprod(alphas, dim=0)  # alpha_bar_t = prod_{s<=t} alpha_s

sqrt = torch.sqrt
sigmas = sqrt(betas)  # reverse-process noise std: sigma_t^2 = beta_t

@torch.no_grad()
def sample(model:DDPM,save_dir,n_samples=100,nrow=10):
    """Generate images with the reverse diffusion process and save them as one grid PNG.

    Starting from x ~ N(0, I), for t = T-1 .. 0 this applies
        x <- (x - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * model(x, t)) / sqrt(alpha_t) + sigmas[t] * z,
    with z ~ N(0, I) and no noise added at t == 0.

    Args:
        model: noise predictor called as ``model(x, t)`` with x of shape (n_samples, 784) and a
            long tensor of timesteps.
        save_dir: output image file path (despite the name). Its parent directory is created if missing.
        n_samples: number of images to generate (default 100, as before).
        nrow: images per grid row.
    """
    parent = os.path.dirname(save_dir)
    if parent:
        os.makedirs(parent, exist_ok=True)
    x = torch.randn(n_samples, 784, device=device)
    for t in range(T-1,-1,-1):
        ts = torch.full((n_samples,), t, dtype=torch.long, device=device)
        eps_hat = model(x, ts)
        x = (x - (1-alphas[t])/(sqrt(1-alpha_bars[t]))*eps_hat)/(sqrt(alphas[t]))
        if t > 0:
            x = x + torch.randn_like(x)*sigmas[t]
        # NOTE(review): clamping x_t to [0, 1] at every step departs from the standard DDPM
        # sampler, which only clips the final image. During training x_t is not confined to
        # [0, 1]. The clamp also masks the blow-up when alphas[t] is ~0.
        x = torch.clamp(x,0,1)
    grid = torchvision.utils.make_grid(x.reshape(-1,1,28,28).cpu(), nrow=nrow)
    torchvision.utils.save_image(grid, save_dir)

@torch.no_grad()
def visualize(model,save_dir,n_samples=10):
    """Run the reverse process for ``n_samples`` chains and save their trajectories as one grid PNG.

    The grid has one row per kept timestep, with ``n_samples`` images per row. Every other
    timestep is kept, ending at the final t == 0 sample. The reverse step is the same as in
    ``sample``.

    Args:
        model: noise predictor called as ``model(x, t)``.
        save_dir: output image file path (despite the name). Its parent directory is created if missing.
        n_samples: number of chains, i.e. images per grid row (default 10, as before).
    """
    parent = os.path.dirname(save_dir)
    if parent:
        os.makedirs(parent, exist_ok=True)
    x = torch.randn(n_samples, 784, device=device)
    x_history = []
    for t in range(T-1,-1,-1):
        ts = torch.full((n_samples,), t, dtype=torch.long, device=device)
        eps_hat = model(x, ts)
        x = (x - (1-alphas[t])/(sqrt(1-alpha_bars[t]))*eps_hat)/(sqrt(alphas[t]))
        if t > 0:
            x = x + torch.randn_like(x)*sigmas[t]
        # NOTE(review): per-step clamp to [0, 1] kept from the original sampler (see `sample`).
        x = torch.clamp(x,0,1)
        x_history.append(x)
    # Keep every other *timestep*, counting back from the last (t == 0), so each grid row shows all
    # chains at one timestep. The previous `torch.cat(x_history)[::2]` kept every other *image*:
    # it dropped half the chains, mixed two timesteps per row and never showed the final sample.
    frames = x_history[::-2][::-1]
    grid = torchvision.utils.make_grid(torch.cat(frames,dim=0).reshape(-1,1,28,28).cpu(), nrow=n_samples)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved visualize to',save_dir)

def _noised_batch(x):
    """Build the forward-diffusion training input for a batch of clean images.

    Draws eps ~ N(0, I) and t ~ U{0, ..., T-1}, then computes
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    Assumes ``x`` is (batch, 1, 28, 28) as yielded by the loaders — TODO confirm.
    Returns ``(x_t, ts, eps)`` with x_t and eps flattened to (batch, 784).
    """
    x = x.to(device)
    eps = torch.randn_like(x).reshape(-1,784)
    ts = torch.randint(0,T,(x.shape[0],),device=device,dtype=torch.long)
    alpha_tbars = alpha_bars[ts]
    x_t = (sqrt(alpha_tbars).reshape(-1,1,1,1)*x).reshape(-1,784)+sqrt(1-alpha_tbars).reshape(-1,1)*eps
    return x_t, ts, eps


def _eps_loss(eps, eps_hat):
    """Return the training objective: per-sample pixel-mean squared error, summed over the batch."""
    # Weighted variant that was tried:
    # ((eps-eps_hat).pow(2).mean(dim=-1) * (betas[ts])/(2*alphas[ts]*(1-alpha_tbars))).sum(dim=0)
    return (eps-eps_hat).pow(2).mean(dim=-1).sum(dim=0)


def train(epochs,model:DDPM,optimizer,eval_interval=5,sample_dir='./samples'):
    """Train ``model`` to predict the injected noise, with a validation pass after every epoch.

    Every ``eval_interval`` epochs (starting at epoch 0) it writes, into ``sample_dir``:
    - a trajectory grid (``visualize``),
    - a sample grid (``sample``),
    - a pickled copy of the whole model (``torch.save(model, ...)``). Loading that file needs
      these class definitions to be importable.
    The reported values are running means over the batches seen so far in the epoch.
    """
    os.makedirs(sample_dir, exist_ok=True)
    for epoch in range(epochs):
        model.train()
        with tqdm(train_loader) as bar:
            loss_sum, n_batches = 0.0, 0
            for x,_ in bar:
                value, ts, epss = _noised_batch(x)
                out = model(value,ts) # [batch,784]
                loss = _eps_loss(epss, out)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_sum += loss.item()
                n_batches += 1
                bar.set_description('epoch {}, loss {:.4f}'.format(epoch,loss_sum/n_batches))
        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as bar:
                mse_sum, loss_sum, n_batches = 0.0, 0.0, 0
                for x,_ in bar:
                    value, ts, epss = _noised_batch(x)
                    out = model(value,ts)
                    mse_sum += F.mse_loss(epss,out).item()
                    loss_sum += _eps_loss(epss, out).item()
                    n_batches += 1
                    bar.set_description('epoch {}, MSE {:.4f}, [Valid] {:.4f}'.format(epoch,mse_sum/n_batches,loss_sum/n_batches))

        if epoch % eval_interval == 0:
            visualize(model,save_dir=os.path.join(sample_dir,f'diffuse_epoch_{epoch}.png'))
            sample(model,save_dir=os.path.join(sample_dir,f'sample_epoch_{epoch}.png'))
            torch.save(model,os.path.join(sample_dir,f'epoch_{epoch}.pt'))

if __name__ == '__main__':
    # Build a fresh model and optimizer, save samples from the untrained model as a baseline,
    # then train for 100 epochs with periodic evaluation.
    model = DDPM().to(device)
    param_count = sum(param.numel() for param in model.parameters())
    print('Number parameters of the model:', param_count)
    print('Model strcuture:', model)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    out_root = './samples'
    os.makedirs(out_root, exist_ok=True)
    sample(model, save_dir=os.path.join(out_root, 'init.png'))
    train(100, model, optimizer, eval_interval=5)

appended /home/zhh24/DeepLearning


  0%|                                                                                                                                     | 0/94 [00:00<?, ?it/s]

Number parameters of the model: 1479345
Model strcuture: DDPM(
  (t_embedding): SinousEmbedding()
  (up): ModuleList(
    (0): F_x_t(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (fc): Linear(in_features=64, out_features=12544, bias=True)
    )
    (1): ReLU()
    (2): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
    (3): F_x_t(
      (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (fc): Linear(in_features=64, out_features=6272, bias=True)
    )
    (4): ReLU()
    (5): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  )
  (middle): ModuleList(
    (0): Identity()
  )
  (down): ModuleList(
    (0): Upsample(scale_factor=2.0, mode=nearest)
    (1): F_x_t(
      (conv): Conv2d(64, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (fc): Linear(in_features=64, out_features=3136, bias=True)
    )
    (2): ReLU()
    (3): Upsample(scale_factor=2.0, mode=nearest)
    (4): F_x_t(
      (conv):

epoch 0, loss 34.9476: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.40it/s]
epoch 0, MSE 0.0617, [Valid] 30.9762: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.08it/s]


Saved visualize to ./samples/diffuse_epoch_0.png


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
epoch 1, loss 31.2828: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.36it/s]
epoch 1, MSE 0.0592, [Valid] 29.6298: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.08it/s]
epoch 2, loss 29.8576: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.25it/s]
epoch 2, MSE 0.0568, [Valid] 28.3667: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.08it/s]
epoch 3, loss 29.4801: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.37it/s]
epoch 3, MSE 0.0579, [Valid] 29.08

Saved visualize to ./samples/diffuse_epoch_5.png


epoch 6, loss 26.5727: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.37it/s]
epoch 6, MSE 0.0513, [Valid] 25.6103: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.07it/s]
epoch 7, loss 25.6989: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.38it/s]
epoch 7, MSE 0.0479, [Valid] 23.9537: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.09it/s]
epoch 8, loss 25.1130: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.34it/s]
epoch 8, MSE 0.0485, [Valid] 24.1230: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.09it/s]
epoch 9, loss 24.1537: 100%|

Saved visualize to ./samples/diffuse_epoch_10.png


epoch 11, loss 22.8539: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.40it/s]
epoch 11, MSE 0.0454, [Valid] 22.7209: 100%|█████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.06it/s]
epoch 12, loss 23.0506: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.35it/s]
epoch 12, MSE 0.0447, [Valid] 22.3688: 100%|█████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.09it/s]
epoch 13, loss 22.3320: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.35it/s]
epoch 13, MSE 0.0432, [Valid] 21.4805: 100%|█████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.06it/s]
epoch 14, loss 22.1643: 100%

Saved visualize to ./samples/diffuse_epoch_15.png


epoch 16, loss 21.4069: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:13<00:00,  7.39it/s]
epoch 16, MSE 0.0425, [Valid] 21.7634:  46%|██████████████████████████████████████▉                                              | 11/24 [00:01<00:01,  7.78it/s]


KeyboardInterrupt: 

# Next step: lower the smallest noise variance (`beta1`, currently 1e-2). How can training and sampling be kept working?