In [1]:
!nvidia-smi
!which python | grep DYY

Wed Sep 18 21:28:45 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   47C    P0              55W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SinousEmbedding(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        assert dim%2==0,NotImplementedError()
        self.angles = (10000.**(-2/dim))**torch.arange(1,dim//2+1,1,dtype=torch.float).cuda()
        self.angles.requires_grad_(False)
    def forward(self,x):
        angles = torch.einsum('m,i->im',self.angles,x.float())
        return torch.cat((torch.sin(angles),torch.cos(angles)),dim=1)
    
class Attn(nn.Module):
    def __init__(self,dim,head):
        super().__init__()
        assert dim%head==0,NotImplementedError()
        self.head = head
        self.head_dim = dim // head
        self.Q = nn.Linear(dim,dim)
        self.K = nn.Linear(dim,dim)
        self.V = nn.Linear(dim,dim)
        self.apply_init()

    def apply_init(self):
        self.Q.weight.data.normal_()
        self.K.weight.data.normal_()
        self.V.weight.data.normal_()
        # self.V.weight.data.zero_()
        # self.V.bias.data.zero_()
        pass
    
    def forward(self,query,context):
        # query: [B, H, head * head_dim]
        # print('query',query)
        # print('context',context)
        q = self.Q(query).reshape(*query.shape[:2],self.head,self.head_dim) # [B, H, head, head_dim]
        k = self.K(context).reshape(*context.shape[:2],self.head,self.head_dim) # [B, H, head, head_dim]
        v = self.V(context).reshape(*context.shape[:2],self.head,self.head_dim) # [B, H_c, head, head_dim]
        score = torch.einsum('bihd,bjhd->bijh',q,k) / self.head_dim**0.5 # [B, H_q, H_c, head]
        # print('score',score * (self.head_dim ** 0.5))
        score = F.softmax(score,dim=2) # [B, H_q, H_c, head]
        # print('attention max:',score.max(),'attention min:',score.min(),'attention to time:',score[:,-1])
        return torch.einsum('bijh,bjhd->bihd',score,v).reshape_as(query) # [B, H_q, head * head_dim]

class GELU(nn.Module):
    def forward(self,x):
        return 0.5 * x * (1 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))
    
class SiLU(nn.Module):
    def forward(self,x):
        return x * torch.sigmoid(x)

class Layer(nn.Module):

    def __init__(self,dim,head):
        super().__init__()
        self.attn = Attn(dim,head)
        self.mlp = nn.Sequential(
            nn.Linear(dim,dim),
            GELU(),
            nn.Linear(dim,dim)
        )
        self.condition_mlp = nn.Sequential(
            nn.Linear(dim,2*dim),
            SiLU(),
            nn.Linear(2*dim,6*dim)
        )
        self.norm1 = nn.LayerNorm(dim,elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim,elementwise_affine=False)
        self.apply_init()

    def apply_init(self):
        self.condition_mlp[0].weight.data.zero_()
        self.condition_mlp[0].bias.data.zero_()
        self.condition_mlp[2].weight.data.zero_()
        self.condition_mlp[2].bias.data.zero_()
        self.mlp[0].weight.data.zero_()
        self.mlp[0].bias.data.zero_()
        self.mlp[2].weight.data.zero_()
        self.mlp[2].bias.data.zero_()

    def forward(self,x,condition):
        alpha1,beta1,gamma1,alpha2,beta2,gamma2 = self.condition_mlp(condition).unsqueeze(1).chunk(6,dim=-1)

        # first half
        xc = x.clone()
        x = self.norm1(x)
        x = x * gamma1 + beta1
        x = self.attn(x,x)
        x = x * alpha1
        x = x + xc

        # second half
        xc = x.clone()
        x = self.norm2(x)
        x = x * gamma2 + beta2
        x = self.mlp(x)
        x = x * alpha2
        x = x + xc

        return x

class DiT(nn.Module):

    def __init__(self,
                 patch_size=4,
                 hidden_dim=128,
                 num_layers=3,
                 image_size = 28*28,
                 num_heads=4
        ):
        super().__init__()
        self.num_patches = image_size // (patch_size * patch_size)
        self.embedding = nn.Linear(patch_size * patch_size, hidden_dim)
        self.pos_embedding = SinousEmbedding(hidden_dim)
        self.t_embedding = nn.Sequential(
            SinousEmbedding(hidden_dim),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,hidden_dim)
        )
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.layers = nn.ModuleList([Layer(hidden_dim,num_heads) for _ in range(num_layers)])
        self.out_norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, patch_size * patch_size)
        self.apply_init()

    def apply_init(self):
        self.out_proj.weight.data.zero_()
        self.out_proj.bias.data.zero_()
        # pass

    def first(self,x,t):
        t_embed = self.t_embedding(t)
        # patchify the image x
        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) # [B, 1, H/4, W/4, 4, 4]
        x_embed = self.embedding(x.reshape(x.shape[0],-1,self.patch_size*self.patch_size))
        x_embed += self.pos_embedding(torch.arange(x_embed.shape[1],device=device))
        data = torch.cat((x_embed,t_embed.unsqueeze(1)),dim=1)
        # print('position embedding range:',pos_embed.min(),pos_embed.max(),pos_embed.std())
        # print('x range:',data.min(),data.max(),data.std())
        return data, t_embed
    def forward(self,x,t):
        x = x.reshape(x.shape[0],1,28,28)
        inputs, conditioned = self.first(x,t)
        for i,ly in enumerate(self.layers):
            inputs = ly(inputs,conditioned)
            # print('layer',i,'input.range:',inputs.min(),inputs.max())
        # remove t token
        inputs = inputs[:,:-1,:]
        inputs = self.out_proj(self.out_norm(inputs))
        # patchify the image x
        length = 28 // self.patch_size
        inputs = inputs.reshape(inputs.shape[0],length,length,self.patch_size,self.patch_size).permute(0,3,4,1,2).reshape(inputs.shape[0],self.patch_size*self.patch_size,length*length)
        x = F.fold(inputs, (28,28), self.patch_size, stride=self.patch_size)
        return x.reshape(x.shape[0],-1)

# model = DiT().to(device)
# x = torch.randn(7,1,28,28).to(device)
# t = torch.randint(0,10,(7,)).to(device)
# model(x,t).shape

# img = torch.arange(12*12).reshape(1,1,12,12).float()
# img
# # separate to 7x7 4x4 patches
# patch_size = 4
# p = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# good = p.squeeze(1).permute(0,3,4,1,2)
# good.shape
# # change patch back to 28x28

# back = F.fold(good.reshape(img.shape[0],16,9), (12,12), 4, stride=4)
# back

# Train

In [3]:
import sys
import os

# parent_dir = os.path.abspath('/root/DeepLearning')
parent_dir = os.path.abspath('/home/zhh24/DeepLearning')

sys.path.append(parent_dir)
print('appended',parent_dir)

import utils

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchvision.utils

import matplotlib.pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.autograd.set_detect_anomaly(True)
mnist = utils.MNIST(batch_size=512)
train_loader = mnist.train_dataloader
valid_loader = mnist.valid_dataloader

T=200
beta1=1e-4 # variance of lowest temperature
betaT=5e-2 # variance of highest temperature

# step = torch.log(torch.tensor(betaT/beta1))/(T-1)
# betas = beta1 * torch.exp(step*torch.arange(T,dtype=torch.float).to(device))
step = (betaT-beta1)/(T-1)
betas = torch.arange(T,dtype=torch.float,device=device) * step + beta1


alphas = 1-betas
alpha_bars = alphas.clone()
for i in range(1,T):
    alpha_bars[i] *= alpha_bars[i-1]

print(alpha_bars)
print('range of bars',alpha_bars.min(),alpha_bars.max())
# print(alphas)
# assert False

sqrt = torch.sqrt
sigmas = sqrt(betas * (1-alpha_bars / alphas)/(1-alpha_bars))
sigmas[0] = 1
print('range of sigmas,',sigmas.min(),sigmas.max())
alphas = alphas.to(device)
alpha_bars = alpha_bars.to(device)
betas = betas.to(device)
sigmas = sigmas.to(device)
weights = torch.ones(T,dtype=torch.float,device=device)

@torch.no_grad()
def sample(model:DiT,save_dir):
    x = torch.randn([100,784]).to(device)
    for t in range(T-1,-1,-1):
        sigmaz = torch.randn_like(x)*sigmas[t]
        if t==0:
            sigmaz = 0
        x = (x-(1-alphas[t])/(sqrt(1-alpha_bars[t]))*model(x,t*torch.ones(x.shape[0],dtype=torch.long,device=device)))/(sqrt(alphas[t]))+sigmaz
        # x = torch.clamp(x,0,1)
    grid = torchvision.utils.make_grid(post_process(x).reshape(-1,1,28,28).cpu(), nrow=10)
    torchvision.utils.save_image(grid, save_dir)

@torch.no_grad()
def visualize(model,save_dir):
    interval = (T-1) // 20
    x = torch.randn([10,784]).to(device)
    x_history = []
    for t in range(T-1,-1,-1):
        sigmaz = torch.randn_like(x)*((betas[t])**0.5).to(device)
        if t==0:
            sigmaz = 0
        x = (x-(1-alphas[t])/(sqrt(1-alpha_bars[t]))*model(x,t*torch.ones(x.shape[0],dtype=torch.long,device=device)))/(sqrt(alphas[t]))+sigmaz
        # x = torch.clamp(x,0,1)
        x_history.append(x)
    # print('cat.shape',torch.cat(x_history,dim=0).shape)
    grid = torchvision.utils.make_grid(post_process(torch.stack(x_history,dim=0)[::interval,...]).reshape(-1,1,28,28).cpu(), nrow=10)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved visualize to',os.path.abspath(save_dir))

@torch.no_grad()
def visualize_denoise(model,save_dir):
    # get 10 images from the dataset
    x,_ = next(iter(valid_loader))
    x = x[:20,...].reshape(20,784).to(device)
    x = pre_process(x)
    t = torch.tensor([i * T // 20 for i in range(20)],dtype=torch.long,device=device)
    noise = torch.randn_like(x).reshape(-1,784)
    v1 = (sqrt(alpha_bars[t]).reshape(-1,1)*x).reshape(-1,784)
    v2 = sqrt(1-alpha_bars[t]).reshape(-1,1)*noise
    x_corr = v1+v2
    est = model(x_corr,t)
    x_rec = (x_corr - sqrt(1-alpha_bars[t]).reshape(-1,1)*est)/(sqrt(alpha_bars[t])).reshape(-1,1)
    grid_orig = torchvision.utils.make_grid(post_process(x).reshape(-1,1,28,28).cpu(), nrow=10)
    grid_corr = torchvision.utils.make_grid(post_process(x_corr).reshape(-1,1,28,28).cpu(), nrow=10)
    grid_rec = torchvision.utils.make_grid(post_process(x_rec).reshape(-1,1,28,28).cpu(), nrow=10)
    # add noise level infomation to the image
    noise_level = (1-alpha_bars[t]).reshape(-1).tolist()
    ori_mse = noise.pow(2).mean(dim=1).reshape(-1).tolist()
    mse = ((est-noise)**2).mean(dim=1).reshape(-1).tolist()
    print(noise_level)
    print(ori_mse)
    print(mse)
    grid = torch.cat([grid_orig,grid_corr,grid_rec],dim=1)
    torchvision.utils.save_image(grid, save_dir)
    print('Saved denoise to',os.path.abspath(save_dir))

def plot_loss(losses,save_dir):
    losses_vals, t_vals = zip(*losses)
    losses_vals = torch.cat(losses_vals,dim=0)
    t_vals = torch.cat(t_vals,dim=0)
    # print('t_vals',t_vals)
    # print('losses_vals',losses_vals)

    results = []
    for t in range(T):
        this_t = abs(t_vals.float()-float(t))<0.5
        results.append(torch.sum(torch.where(this_t,losses_vals,torch.tensor(0.,device=device))).item() / (torch.sum(this_t.float())+1e-3).item())
    plt.plot(results)
    plt.ylim(0,max(results)* 1.2)
    plt.savefig(save_dir)
    plt.close()
    # weights = (torch.tensor(results,device=device)) # weights
    weights = torch.ones(T,dtype=torch.float,device=device)
    # weights[:10]=0
    # weights[10:80] /= 100
    return weights

def pre_process(x):
    # do the logit transform
    # return (torch.log(x+1e-3)-torch.log(1-x+1e-3))
    return x*2-1 #MODIFIED
    return (x+1)/2

def post_process(x):
    # return torch.sigmoid(x)
    return (x+1)/2 #MODIFIED
    return x*2-1

def train(epochs,model:DiT,optimizer,eval_interval=5):
    global weights
    for epoch in range(epochs):
        # print('weights normalized:',weights/weights.sum())
        all_ts = torch.distributions.Categorical(weights).sample((50000,))
        cnt = 0
        model.train()
        with tqdm(train_loader) as bar:
            losses = []
            for x,_ in bar:
                cnt += x.shape[0]
                x = pre_process(x.to(device))
                epss = torch.randn_like(x).reshape(-1,784).to(device)
                # ts = torch.randint(0,T,(x.shape[0],),device=device,dtype=torch.long)
                ts = all_ts[cnt-x.shape[0]:cnt]
                alpha_tbars = alpha_bars[ts]
                value = (sqrt(alpha_tbars).reshape(-1,1,1,1)*x).reshape(-1,784)+sqrt(1-alpha_tbars).reshape(-1,1)*epss
                out = model(value,ts) # [batch,784]

                # loss = ((epss-out).pow(2).mean(dim=-1) * (betas[ts])/(2*alphas[ts]*(1-alpha_tbars))).sum(dim=0)
                loss = ((epss-out).pow(2).mean(dim=-1)).mean(dim=0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
                bar.set_description('epoch {}, loss {:.4f}'.format(epoch,sum(losses)/len(losses)))

        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as bar:
                mses = []
                losses = []
                losses_for_t = []
                for x,_ in bar:
                    x = pre_process(x.to(device))
                    epss = torch.randn_like(x).reshape(-1,784).to(device)
                    ts = torch.randint(0,T,(x.shape[0],),device=device,dtype=torch.long)
                    # print(ts)
                    alpha_tbars = alpha_bars[ts]
                    value = (sqrt(alpha_tbars).reshape(-1,1,1,1)*x).reshape(-1,784)+sqrt(1-alpha_tbars).reshape(-1,1)*epss
                    out = model(value,ts)
                    mse = F.mse_loss(epss,out)
                    mses.append(mse.item())
                    loss = ((epss-out).pow(2).mean(dim=-1))
                    # loss = (epss-out).pow(2).mean(dim=-1)
                    losses_for_t.append((loss.clone().detach(),ts))
                    loss = (loss).mean(dim=0)
                    losses.append(loss.item())
                    bar.set_description('epoch {}, MSE {:.4f}, [Valid] {:.4f}'.format(epoch,sum(mses)/len(mses),sum(losses)/len(losses)))
                    
        if epoch % eval_interval == 0:
            visualize(model,save_dir=os.path.join('./samples',f'diffuse_epoch_{epoch}.png'))
            sample(model,save_dir=os.path.join('./samples',f'sample_epoch_{epoch}.png'))
            # visualize_denoise(model,save_dir=os.path.join('./samples',f'denoise_epoch_{epoch}.png'))
            weights = plot_loss(losses_for_t,save_dir=os.path.join('./samples',f'loss_epoch_{epoch}.png'))
            torch.save(model,os.path.join('./samples',f'epoch_{epoch}.pt'))

if __name__ == '__main__':
    model = DiT(
        num_layers=3,
        hidden_dim=128,
        num_heads=8
    ).to(device)
    print('Number parameters of the model:', sum(p.numel() for p in model.parameters()))
    print('Model strcuture:',model)
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    os.makedirs('./samples',exist_ok=True)
    # sample(model,save_dir=os.path.join('./samples',f'init.png'))
    # visualize(model,save_dir=os.path.join('./samples',f'init_visualize.png'))
    train(200,model,optimizer,eval_interval=5)

appended /home/zhh24/DeepLearning


  0%|                                                                                                                                     | 0/94 [00:00<?, ?it/s]

tensor([0.9999, 0.9995, 0.9989, 0.9981, 0.9970, 0.9956, 0.9940, 0.9922, 0.9901,
        0.9878, 0.9852, 0.9824, 0.9793, 0.9760, 0.9725, 0.9688, 0.9648, 0.9606,
        0.9561, 0.9515, 0.9466, 0.9415, 0.9363, 0.9308, 0.9251, 0.9192, 0.9131,
        0.9068, 0.9004, 0.8937, 0.8869, 0.8799, 0.8728, 0.8655, 0.8580, 0.8504,
        0.8426, 0.8347, 0.8267, 0.8185, 0.8102, 0.8018, 0.7933, 0.7847, 0.7759,
        0.7671, 0.7582, 0.7492, 0.7401, 0.7309, 0.7217, 0.7124, 0.7030, 0.6936,
        0.6841, 0.6746, 0.6651, 0.6555, 0.6459, 0.6363, 0.6267, 0.6170, 0.6073,
        0.5977, 0.5880, 0.5784, 0.5688, 0.5592, 0.5496, 0.5400, 0.5305, 0.5210,
        0.5115, 0.5021, 0.4927, 0.4834, 0.4742, 0.4650, 0.4558, 0.4467, 0.4377,
        0.4288, 0.4199, 0.4112, 0.4025, 0.3938, 0.3853, 0.3769, 0.3685, 0.3602,
        0.3521, 0.3440, 0.3360, 0.3282, 0.3204, 0.3127, 0.3052, 0.2977, 0.2904,
        0.2831, 0.2760, 0.2690, 0.2621, 0.2553, 0.2486, 0.2420, 0.2356, 0.2292,
        0.2230, 0.2169, 0.2109, 0.2050, 

epoch 0, loss 0.3839: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.98it/s]
epoch 0, MSE 0.2044, [Valid] 0.2044: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.68it/s]


Saved visualize to /home/zhh24/samples/diffuse_epoch_0.png


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
epoch 1, loss 0.1772: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.99it/s]
epoch 1, MSE 0.1513, [Valid] 0.1513: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.65it/s]
epoch 2, loss 0.1400: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.99it/s]
epoch 2, MSE 0.1293, [Valid] 0.1293: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.66it/s]
epoch 3, loss 0.1244: 100%|█████████████

Saved visualize to /home/zhh24/samples/diffuse_epoch_5.png


epoch 6, loss 0.1060: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.98it/s]
epoch 6, MSE 0.1041, [Valid] 0.1041: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.67it/s]
epoch 7, loss 0.1034: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.98it/s]
epoch 7, MSE 0.1017, [Valid] 0.1017: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.66it/s]
epoch 8, loss 0.1004: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.99it/s]
epoch 8, MSE 0.0985, [Valid] 0.0985: 100%|███████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.68it/s]
epoch 9, loss 0.0981: 100%|█

Saved visualize to /home/zhh24/samples/diffuse_epoch_10.png


epoch 11, loss 0.0947: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.98it/s]
epoch 11, MSE 0.0942, [Valid] 0.0942: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.56it/s]
epoch 12, loss 0.0937: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.97it/s]
epoch 12, MSE 0.0899, [Valid] 0.0899: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.68it/s]
epoch 13, loss 0.0922: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 94/94 [00:32<00:00,  2.98it/s]
epoch 13, MSE 0.0904, [Valid] 0.0904: 100%|██████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  8.67it/s]
epoch 14, loss 0.0906:  36%|

: 