In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

  Referenced from: <AA076682-B999-31A8-AEA6-2BF697B54762> /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/torchvision/image.so
  warn(


In [2]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(tensor):
    tensor = tensor.permute(1, 2, 0)

    # Step 2: Convert the tensor to a NumPy array
    array = tensor.cpu().numpy()

    # Step 3: Scale the pixel values to the range [0, 1]
    array = (array - np.min(array)) / (np.max(array) - np.min(array))

    # Step 4: Clip the pixel values to the valid range [0, 1]
    array = np.clip(array, 0, 1)

    # Step 5: Display the RGB image using matplotlib
    plt.imshow(array)
    plt.axis('off')
    plt.show()

def imshow_grayscale(img):
    img = (img + 1) * 0.5
    # clipped_output = np.clip(img, 0, 1)
    plt.imshow(img, cmap='gray')
    plt.show()

In [3]:
# hyperparameters
batch_size = 128 
n_channels = 3
num_timesteps = 1000
img_dim = 32

device = "mps"

In [4]:

# transformMNIST = transforms.Compose([
#     transforms.Pad(padding=2, fill=0, padding_mode='constant'), 
#     transforms.ToTensor(),
#     transforms.Lambda(lambda x: (x - 0.5) * 2.0)
# ])

transformCIFAR = transforms.Compose([
    transforms.ToTensor(),
   transforms.Normalize([0.5], [0.5]),

])

# mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transformMNIST)
# trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transformCIFAR)
train_airplanes = [(img, label) for img, label in cifar_train if label == 7]
# only pick airplanes
trainloader = torch.utils.data.DataLoader(train_airplanes, batch_size=batch_size, shuffle=True)

# # sample from the dataloader
for batch in (trainloader):
    # imshow(batch[0][0])
    break

# # sample from the dataloader



Files already downloaded and verified


In [5]:

beta_schedule = torch.linspace(1e-4, 0.02, num_timesteps)
alpha_schedule = 1. - beta_schedule

def get_alphas(num_timesteps):
    alphas = [alpha_schedule[t] for t in range(num_timesteps)]
    return torch.cumprod(torch.tensor(alphas), 0).to(device)

# alphas_cumprod = get_alphas(num_timesteps)
alphas_cumprod = torch.tensor([torch.prod(alpha_schedule[:i + 1]) for i in range(len(alpha_schedule))]).to(device)

# returns tuple (noised image, noise)
def get_noised_image_at(t, image):
    alpha = alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
    noise = torch.randn(batch_size, n_channels, img_dim, img_dim, device=device)
    noise_factor = torch.sqrt(1 - alpha)
    return ((torch.sqrt(alpha) * image) + (noise * noise_factor), noise)


In [6]:
def get_sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    embedding[:,::2] = torch.sin(t * wk[:,::2])
    embedding[:,1::2] = torch.cos(t * wk[:,::2])
    return embedding

In [7]:
# class MultiHeadAttention(nn.Module):
#     def __init__(self, dim):
#         super(MultiHeadAttention, self).__init__()
#         self.dim = dim
#         self.num_heads = 2
#         self.hidden_dim = 16
#         self.qkv = nn.Linear(dim, self.hidden_dim * 3, bias=False) 
#         self.map = nn.Linear(self.hidden_dim, dim)

#     def forward(self, x):
#         # apply along channel dim (1)
#         shaped_x = x.permute(0, 2, 3, 1) # (B, H, W, C)
#         q,k,v = self.qkv(shaped_x).chunk(3, dim=-1) # (B, H, W, 3 * hidden_dim)
        
#         q = q.reshape(x.shape[0] * self.num_heads, x.shape[-1] * x.shape[-2], self.hidden_dim // (self.num_heads))
#         k = k.reshape(x.shape[0] * self.num_heads, self.hidden_dim // (self.num_heads), x.shape[-1] * x.shape[-2])
#         v = v.reshape(x.shape[0] * self.num_heads, x.shape[-1] * x.shape[-2], self.hidden_dim // (self.num_heads)) 

#         out = (1/np.sqrt(self.hidden_dim)) * (q @ k) # (B*Hds, H*W, H*W) -> each pixel "sees" the other
#         out = F.softmax(out, dim=-1) @ v # (B*Hds, H*W, dv)

#         # target shape = (B, H, W, C)
#         out = out.view(x.shape[0], x.shape[-2], x.shape[-1], self.hidden_dim)
#         out = self.map(out)
#         out = out.permute(0, 3, 1, 2)
#         # residual connection
#         out += x

#         return out

# copy linear attention for now from HuggingFace
from einops import rearrange

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=20):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out) + x


In [8]:
class WideResidualBlock(nn.Module):
    def __init__(self, channels_in, channels_out, im_dim):
        super(WideResidualBlock, self).__init__()
        # two convolutional layers and a skip connection for residual
        self.conv1 = nn.Conv2d(channels_in, channels_out, 3, padding=1).to(device) 
        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1).to(device)
        self.ln1 = nn.LayerNorm([channels_in, im_dim, im_dim]).to(device)
        # self.ln2 = nn.LayerNorm([channels_out, im_dim, im_dim]).to(device)
        self.skip_connection = nn.Conv2d(channels_in, channels_out, 1).to(device)

    def forward(self, x, norm=True):
        out = self.ln1(x) if norm else x
        out = F.silu((self.conv1(out)))
        out = F.silu((self.conv2(out)))
        out = out + self.skip_connection(x)
        return out
        
class DownsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, im_dim):
        super(DownsampleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # two residual blocks and a downsampling operation
        self.attention = LinearAttention(out_channels).to(device)
        self.wide_residual_block1 = WideResidualBlock(self.in_channels, self.out_channels, im_dim).to(device)
        self.wide_residual_block2 = WideResidualBlock(self.out_channels, self.out_channels, im_dim).to(device)
        self.wide_residual_block3 = WideResidualBlock(self.out_channels, self.out_channels, im_dim).to(device)
        self.max_pool = nn.MaxPool2d(2, 2).to(device)

    def forward(self, x, t):
        x = x + t.reshape(x.shape[0], -1, 1, 1)
        x = self.wide_residual_block1(x)
        x = self.attention(x)
        x = self.wide_residual_block2(x)
        x = self.wide_residual_block3(x)
        x = self.max_pool(x)
        return x

class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, im_dim, concat=True):
        super(UpsampleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.concat = concat
        self.attention = LinearAttention(out_channels).to(device)
        self.upsample = nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1).to(device)
        self.upsample_channels = nn.Conv2d(in_channels, out_channels, 1).to(device) 
        self.cat_map = nn.Conv2d(out_channels*2, out_channels, 1).to(device) if concat else None
        self.wide_residual_block1 = WideResidualBlock(self.out_channels, self.out_channels, im_dim).to(device)
        self.wide_residual_block2 = WideResidualBlock(self.out_channels, self.out_channels, im_dim).to(device)
        self.wide_residual_block3 = WideResidualBlock(self.out_channels, self.out_channels, im_dim).to(device)

    def forward(self, x, skip_conn, time_embed, norm=True):
        x = F.silu(self.upsample_channels(self.upsample(x)))
        # skip connection
        x = torch.cat((x, skip_conn), dim=1) if self.concat else x
        x = F.silu(self.cat_map(x)) if self.concat else x
        x = x + time_embed.reshape(x.shape[0], -1, 1, 1)
        x = self.wide_residual_block1(x)
        x = self.attention(x) 
        x = self.wide_residual_block2(x)
        x = self.wide_residual_block3(x, norm)
        return x

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.time_emb_dim = 100
        self.max_dim = 150

        self.time_embed = nn.Embedding(num_timesteps, self.time_emb_dim)
        self.time_embed.weight.data = get_sinusoidal_embedding(num_timesteps, self.time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.time_embed_mlps = nn.ModuleList([nn.Sequential(
            nn.Linear(self.time_emb_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        ) for dim in [n_channels,32,64]])
    

        self.downsamplers = nn.ModuleList([DownsampleBlock(n_channels, 32, img_dim), DownsampleBlock(32, 64, 16), DownsampleBlock(64, self.max_dim, 8)])
        self.upsamplers = nn.ModuleList([UpsampleBlock(self.max_dim, 64, 8), UpsampleBlock(64, 32, 16), UpsampleBlock(32, 32, 32, concat=False)])
        

        self.latent_resnet1 = WideResidualBlock(self.max_dim, self.max_dim, 4).to(device)
        self.latent_attention = LinearAttention(self.max_dim).to(device)
        self.latent_resnet2 = WideResidualBlock(self.max_dim, self.max_dim, 4).to(device)

        self.final_conv = nn.Sequential(
            nn.Conv2d(32 , n_channels , 3, padding=1).to(device),
            )

    def forward(self, x, t):
        t = self.time_embed(t)
        time_embeds = [self.time_embed_mlps[i](t) for i in range(len(self.downsamplers))]
        residuals = []
        for i in range(len(self.downsamplers)):
            residuals.append(x)
            x = self.downsamplers[i](x, time_embeds[i])

        x = self.latent_resnet2(self.latent_attention(self.latent_resnet1(x)))
        # x = self.latent_resnet2((self.latent_resnet1(x)))

        for i in range(len(self.upsamplers)):
            normalise_ = False if i == len(self.upsamplers) - 1 else True
            t_embed = time_embeds[-(i+1)] if i < len(self.upsamplers) - 1 else time_embeds[1]
            x = self.upsamplers[i](x, residuals.pop(),t_embed, normalise_)
            

        return self.final_conv(x)

    

In [29]:
unet = UNet()
unet.to(device)
PATH = "tinyhorse.pth"

checkpoint = torch.load(PATH)
unet.load_state_dict(checkpoint['model_state_dict'])

unet.train()

print(sum([p.numel() for p in unet.parameters() if p.requires_grad]))

num_epochs = 300
# l1 loss
loss_fn = F.huber_loss
optimizer = torch.optim.Adam(unet.parameters(), lr=2e-3)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


for i in range(num_epochs):
    for j, (batch, _) in tqdm(enumerate(trainloader)):
        optimizer.zero_grad()
        images = batch.to(device)

        if(images.shape[0] != batch_size):
            continue
        # adding noise
        time_vals = torch.randint(1, num_timesteps+1, (batch_size,), device=device)
        batch_, noise = get_noised_image_at(time_vals - 1, images)
        # forward pass, get predicted noise
        predicted_noise = unet(batch_, time_vals)
        loss_val = loss_fn(predicted_noise, noise)
        # if j % 100 == 0:
        #     print(loss_val.item())
        loss_val.backward()
        optimizer.step()
    print(f"Epoch {i+1}, Loss: {loss_val.item()}")

3821366


40it [00:14,  2.70it/s]


Epoch 1, Loss: 0.017704768106341362


40it [00:11,  3.38it/s]


Epoch 2, Loss: 0.01778866909444332


40it [00:11,  3.39it/s]


Epoch 3, Loss: 0.018920211121439934


40it [00:11,  3.40it/s]


Epoch 4, Loss: 0.015475473366677761


40it [00:11,  3.40it/s]


Epoch 5, Loss: 0.016049077734351158


40it [00:11,  3.46it/s]


Epoch 6, Loss: 0.015537510626018047


40it [00:12,  3.23it/s]


Epoch 7, Loss: 0.017676200717687607


40it [00:11,  3.45it/s]


Epoch 8, Loss: 0.016586661338806152


40it [00:11,  3.41it/s]


Epoch 9, Loss: 0.012949854135513306


40it [00:11,  3.43it/s]


Epoch 10, Loss: 0.016779666766524315


40it [00:11,  3.44it/s]


Epoch 11, Loss: 0.018444888293743134


40it [00:11,  3.41it/s]


Epoch 12, Loss: 0.02133018523454666


40it [00:11,  3.42it/s]


Epoch 13, Loss: 0.024512512609362602


40it [00:11,  3.43it/s]


Epoch 14, Loss: 0.01659471169114113


40it [00:11,  3.45it/s]


Epoch 15, Loss: 0.016908466815948486


40it [00:11,  3.38it/s]


Epoch 16, Loss: 0.01884816214442253


40it [00:11,  3.43it/s]


Epoch 17, Loss: 0.015645578503608704


40it [00:11,  3.42it/s]


Epoch 18, Loss: 0.014798969030380249


40it [00:11,  3.40it/s]


Epoch 19, Loss: 0.01789822429418564


40it [00:11,  3.47it/s]


Epoch 20, Loss: 0.01929916813969612


40it [00:11,  3.48it/s]


Epoch 21, Loss: 0.019395429641008377


40it [00:11,  3.48it/s]


Epoch 22, Loss: 0.01874835044145584


40it [00:11,  3.46it/s]


Epoch 23, Loss: 0.016510983929038048


40it [00:11,  3.35it/s]


Epoch 24, Loss: 0.01743200793862343


40it [00:12,  3.27it/s]


Epoch 25, Loss: 0.02043038234114647


40it [00:12,  3.12it/s]


Epoch 26, Loss: 0.018045689910650253


40it [00:12,  3.13it/s]


Epoch 27, Loss: 0.01346028596162796


40it [00:12,  3.28it/s]


Epoch 28, Loss: 0.02607794851064682


40it [00:12,  3.17it/s]


Epoch 29, Loss: 0.01498749852180481


40it [00:12,  3.12it/s]


Epoch 30, Loss: 0.02328816056251526


40it [00:12,  3.13it/s]


Epoch 31, Loss: 0.02112426422536373


40it [00:12,  3.14it/s]


Epoch 32, Loss: 0.013742441311478615


40it [00:12,  3.16it/s]


Epoch 33, Loss: 0.019044868648052216


40it [00:12,  3.16it/s]


Epoch 34, Loss: 0.012811332941055298


40it [00:12,  3.33it/s]


Epoch 35, Loss: 0.021610077470541


40it [00:11,  3.47it/s]


Epoch 36, Loss: 0.020416878163814545


40it [00:11,  3.48it/s]


Epoch 37, Loss: 0.025122148916125298


40it [00:11,  3.49it/s]


Epoch 38, Loss: 0.016252532601356506


40it [00:11,  3.49it/s]


Epoch 39, Loss: 0.01628122851252556


40it [00:11,  3.49it/s]


Epoch 40, Loss: 0.017756909132003784


40it [00:11,  3.48it/s]


Epoch 41, Loss: 0.01542632095515728


40it [00:11,  3.49it/s]


Epoch 42, Loss: 0.020208533853292465


40it [00:11,  3.49it/s]


Epoch 43, Loss: 0.022682681679725647


40it [00:11,  3.49it/s]


Epoch 44, Loss: 0.02153506875038147


40it [00:11,  3.49it/s]


Epoch 45, Loss: 0.020506368950009346


40it [00:11,  3.49it/s]


Epoch 46, Loss: 0.016127225011587143


40it [00:11,  3.47it/s]


Epoch 47, Loss: 0.017471443861722946


40it [00:11,  3.46it/s]


Epoch 48, Loss: 0.01495419442653656


40it [00:11,  3.45it/s]


Epoch 49, Loss: 0.014137251302599907


40it [00:11,  3.47it/s]


Epoch 50, Loss: 0.01994810812175274


40it [00:11,  3.40it/s]


Epoch 51, Loss: 0.020529016852378845


40it [00:11,  3.42it/s]


Epoch 52, Loss: 0.01566666178405285


40it [00:11,  3.39it/s]


Epoch 53, Loss: 0.023067452013492584


40it [00:11,  3.36it/s]


Epoch 54, Loss: 0.019311869516968727


40it [00:11,  3.36it/s]


Epoch 55, Loss: 0.017743799835443497


40it [00:11,  3.38it/s]


Epoch 56, Loss: 0.015784455463290215


40it [00:11,  3.43it/s]


Epoch 57, Loss: 0.012441938742995262


40it [00:11,  3.42it/s]


Epoch 58, Loss: 0.0193868987262249


40it [00:11,  3.41it/s]


Epoch 59, Loss: 0.015552561730146408


40it [00:11,  3.43it/s]


Epoch 60, Loss: 0.019995499402284622


40it [00:11,  3.43it/s]


Epoch 61, Loss: 0.015101316384971142


40it [00:12,  3.24it/s]


Epoch 62, Loss: 0.01842600479722023


40it [00:11,  3.42it/s]


Epoch 63, Loss: 0.015973538160324097


40it [00:12,  3.22it/s]


Epoch 64, Loss: 0.02024497464299202


40it [00:12,  3.22it/s]


Epoch 65, Loss: 0.01900976337492466


40it [00:12,  3.31it/s]


Epoch 66, Loss: 0.012013490311801434


40it [00:11,  3.41it/s]


Epoch 67, Loss: 0.01899806410074234


40it [00:11,  3.41it/s]


Epoch 68, Loss: 0.014049327000975609


40it [00:11,  3.41it/s]


Epoch 69, Loss: 0.01577538624405861


40it [00:11,  3.42it/s]


Epoch 70, Loss: 0.01641817018389702


40it [00:11,  3.39it/s]


Epoch 71, Loss: 0.017485547810792923


40it [00:11,  3.41it/s]


Epoch 72, Loss: 0.010780349373817444


40it [00:11,  3.43it/s]


Epoch 73, Loss: 0.01896776258945465


40it [00:11,  3.39it/s]


Epoch 74, Loss: 0.015789151191711426


40it [00:11,  3.37it/s]


Epoch 75, Loss: 0.01839020475745201


40it [00:11,  3.48it/s]


Epoch 76, Loss: 0.015585492365062237


40it [00:11,  3.44it/s]


Epoch 77, Loss: 0.018538853153586388


40it [00:12,  3.33it/s]


Epoch 78, Loss: 0.0204886794090271


40it [00:12,  3.32it/s]


Epoch 79, Loss: 0.018409304320812225


40it [00:11,  3.38it/s]


Epoch 80, Loss: 0.012127958238124847


40it [00:11,  3.38it/s]


Epoch 81, Loss: 0.015906769782304764


40it [00:11,  3.44it/s]


Epoch 82, Loss: 0.012555897235870361


40it [00:11,  3.42it/s]


Epoch 83, Loss: 0.019128508865833282


40it [00:11,  3.43it/s]


Epoch 84, Loss: 0.015326067805290222


40it [00:11,  3.45it/s]


Epoch 85, Loss: 0.018540387973189354


40it [00:11,  3.46it/s]


Epoch 86, Loss: 0.020166348665952682


40it [00:11,  3.47it/s]


Epoch 87, Loss: 0.020003218203783035


40it [00:11,  3.47it/s]


Epoch 88, Loss: 0.012905269861221313


40it [00:11,  3.47it/s]


Epoch 89, Loss: 0.015704097226262093


40it [00:11,  3.47it/s]


Epoch 90, Loss: 0.014402851462364197


40it [00:11,  3.47it/s]


Epoch 91, Loss: 0.017618926241993904


40it [00:11,  3.46it/s]


Epoch 92, Loss: 0.015483282506465912


40it [00:11,  3.46it/s]


Epoch 93, Loss: 0.01770683005452156


40it [00:11,  3.47it/s]


Epoch 94, Loss: 0.020643899217247963


40it [00:11,  3.47it/s]


Epoch 95, Loss: 0.02227027341723442


40it [00:11,  3.47it/s]


Epoch 96, Loss: 0.015764830633997917


40it [00:11,  3.46it/s]


Epoch 97, Loss: 0.015995845198631287


40it [00:11,  3.46it/s]


Epoch 98, Loss: 0.014771837741136551


40it [00:11,  3.45it/s]


Epoch 99, Loss: 0.018324710428714752


40it [00:11,  3.46it/s]


Epoch 100, Loss: 0.017182067036628723


40it [00:11,  3.47it/s]


Epoch 101, Loss: 0.01611843705177307


40it [00:11,  3.47it/s]


Epoch 102, Loss: 0.017377108335494995


40it [00:11,  3.47it/s]


Epoch 103, Loss: 0.023426871746778488


40it [00:11,  3.45it/s]


Epoch 104, Loss: 0.014114531688392162


40it [00:11,  3.46it/s]


Epoch 105, Loss: 0.01591670699417591


40it [00:11,  3.44it/s]


Epoch 106, Loss: 0.01877433992922306


40it [00:11,  3.45it/s]


Epoch 107, Loss: 0.015673112124204636


40it [00:11,  3.42it/s]


Epoch 108, Loss: 0.02381782978773117


40it [00:11,  3.46it/s]


Epoch 109, Loss: 0.014684910885989666


40it [00:11,  3.47it/s]


Epoch 110, Loss: 0.014027832075953484


40it [00:11,  3.47it/s]


Epoch 111, Loss: 0.016300112009048462


40it [00:11,  3.47it/s]


Epoch 112, Loss: 0.01807093434035778


40it [00:11,  3.47it/s]


Epoch 113, Loss: 0.014237161725759506


40it [00:11,  3.47it/s]


Epoch 114, Loss: 0.02306341379880905


40it [00:11,  3.44it/s]


Epoch 115, Loss: 0.011490844190120697


40it [00:11,  3.45it/s]


Epoch 116, Loss: 0.013104848563671112


40it [00:11,  3.47it/s]


Epoch 117, Loss: 0.018986236304044724


40it [00:11,  3.47it/s]


Epoch 118, Loss: 0.01582496240735054


40it [00:11,  3.46it/s]


Epoch 119, Loss: 0.01757049188017845


40it [00:11,  3.47it/s]


Epoch 120, Loss: 0.017793020233511925


40it [00:11,  3.47it/s]


Epoch 121, Loss: 0.02492162212729454


40it [00:11,  3.47it/s]


Epoch 122, Loss: 0.019618459045886993


40it [00:11,  3.44it/s]


Epoch 123, Loss: 0.016979437321424484


40it [00:11,  3.39it/s]


Epoch 124, Loss: 0.016745854169130325


40it [00:11,  3.42it/s]


Epoch 125, Loss: 0.014770494773983955


40it [00:11,  3.46it/s]


Epoch 126, Loss: 0.019220072776079178


40it [00:11,  3.40it/s]


Epoch 127, Loss: 0.02232484146952629


40it [00:11,  3.36it/s]


Epoch 128, Loss: 0.021917706355452538


40it [00:11,  3.42it/s]


Epoch 129, Loss: 0.01643204502761364


40it [00:11,  3.44it/s]


Epoch 130, Loss: 0.020461291074752808


40it [00:11,  3.40it/s]


Epoch 131, Loss: 0.01701834611594677


40it [00:12,  3.30it/s]


Epoch 132, Loss: 0.017172858119010925


40it [00:11,  3.38it/s]


Epoch 133, Loss: 0.013185403309762478


40it [00:11,  3.39it/s]


Epoch 134, Loss: 0.013094313442707062


40it [00:11,  3.39it/s]


Epoch 135, Loss: 0.019444048404693604


40it [00:11,  3.44it/s]


Epoch 136, Loss: 0.016211196780204773


40it [00:11,  3.44it/s]


Epoch 137, Loss: 0.015204187482595444


40it [00:11,  3.43it/s]


Epoch 138, Loss: 0.019022732973098755


40it [00:11,  3.37it/s]


Epoch 139, Loss: 0.020510775968432426


40it [00:11,  3.39it/s]


Epoch 140, Loss: 0.016353804618120193


40it [00:12,  3.21it/s]


Epoch 141, Loss: 0.020191093906760216


40it [00:12,  3.16it/s]


Epoch 142, Loss: 0.02556145191192627


40it [00:11,  3.38it/s]


Epoch 143, Loss: 0.02352862060070038


40it [00:11,  3.45it/s]


Epoch 144, Loss: 0.015908174216747284


40it [00:12,  3.31it/s]


Epoch 145, Loss: 0.012663180939853191


40it [00:11,  3.41it/s]


Epoch 146, Loss: 0.013738544657826424


40it [00:11,  3.45it/s]


Epoch 147, Loss: 0.016043348237872124


40it [00:11,  3.46it/s]


Epoch 148, Loss: 0.01631811261177063


40it [00:11,  3.46it/s]


Epoch 149, Loss: 0.016356583684682846


40it [00:11,  3.46it/s]


Epoch 150, Loss: 0.015767693519592285


40it [00:11,  3.45it/s]


Epoch 151, Loss: 0.0156905185431242


40it [00:11,  3.39it/s]


Epoch 152, Loss: 0.015394976362586021


40it [00:11,  3.39it/s]


Epoch 153, Loss: 0.01975543610751629


40it [00:11,  3.38it/s]


Epoch 154, Loss: 0.018093114718794823


40it [00:11,  3.41it/s]


Epoch 155, Loss: 0.02542008087038994


40it [00:11,  3.43it/s]


Epoch 156, Loss: 0.015319744125008583


40it [00:11,  3.43it/s]


Epoch 157, Loss: 0.01791616529226303


40it [00:11,  3.43it/s]


Epoch 158, Loss: 0.011627430096268654


40it [00:11,  3.44it/s]


Epoch 159, Loss: 0.014428885653614998


40it [00:11,  3.44it/s]


Epoch 160, Loss: 0.011155456304550171


40it [00:11,  3.48it/s]


Epoch 161, Loss: 0.018257442861795425


40it [00:11,  3.49it/s]


Epoch 162, Loss: 0.018187178298830986


40it [00:11,  3.44it/s]


Epoch 163, Loss: 0.025686796754598618


40it [00:11,  3.44it/s]


Epoch 164, Loss: 0.014780012890696526


40it [00:11,  3.42it/s]


Epoch 165, Loss: 0.01908726803958416


40it [00:11,  3.42it/s]


Epoch 166, Loss: 0.015893176198005676


40it [00:11,  3.45it/s]


Epoch 167, Loss: 0.01456967368721962


40it [00:11,  3.42it/s]


Epoch 168, Loss: 0.020200088620185852


40it [00:11,  3.44it/s]


Epoch 169, Loss: 0.012158025056123734


40it [00:11,  3.41it/s]


Epoch 170, Loss: 0.016614748165011406


40it [00:11,  3.49it/s]


Epoch 171, Loss: 0.01913696900010109


40it [00:11,  3.48it/s]


Epoch 172, Loss: 0.019869714975357056


40it [00:11,  3.45it/s]


Epoch 173, Loss: 0.01227051392197609


40it [00:11,  3.43it/s]


Epoch 174, Loss: 0.01590888202190399


40it [00:11,  3.40it/s]


Epoch 175, Loss: 0.016090264543890953


40it [00:11,  3.37it/s]


Epoch 176, Loss: 0.021776553243398666


40it [00:11,  3.42it/s]


Epoch 177, Loss: 0.0149629395455122


40it [00:11,  3.43it/s]


Epoch 178, Loss: 0.011566773056983948


40it [00:11,  3.46it/s]


Epoch 179, Loss: 0.020178914070129395


40it [00:11,  3.43it/s]


Epoch 180, Loss: 0.01567319594323635


40it [00:11,  3.34it/s]


Epoch 181, Loss: 0.016706835478544235


40it [00:11,  3.41it/s]


Epoch 182, Loss: 0.017874963581562042


40it [00:11,  3.41it/s]


Epoch 183, Loss: 0.017197366803884506


40it [00:11,  3.40it/s]


Epoch 184, Loss: 0.018741320818662643


40it [00:11,  3.41it/s]


Epoch 185, Loss: 0.015456038527190685


40it [00:11,  3.47it/s]


Epoch 186, Loss: 0.016448525711894035


40it [00:11,  3.47it/s]


Epoch 187, Loss: 0.013959254138171673


40it [00:11,  3.49it/s]


Epoch 188, Loss: 0.014923656359314919


40it [00:11,  3.40it/s]


Epoch 189, Loss: 0.018490619957447052


40it [00:11,  3.41it/s]


Epoch 190, Loss: 0.013810491189360619


40it [00:11,  3.40it/s]


Epoch 191, Loss: 0.019246486946940422


40it [00:11,  3.45it/s]


Epoch 192, Loss: 0.01685379259288311


40it [00:11,  3.41it/s]


Epoch 193, Loss: 0.014871910214424133


40it [00:11,  3.43it/s]


Epoch 194, Loss: 0.015870019793510437


40it [00:11,  3.44it/s]


Epoch 195, Loss: 0.014393040910363197


40it [00:11,  3.44it/s]


Epoch 196, Loss: 0.022327570244669914


40it [00:11,  3.46it/s]


Epoch 197, Loss: 0.013607638888061047


40it [00:11,  3.45it/s]


Epoch 198, Loss: 0.01613878272473812


40it [00:11,  3.45it/s]


Epoch 199, Loss: 0.015386853367090225


40it [00:11,  3.43it/s]


Epoch 200, Loss: 0.016487160697579384


40it [00:11,  3.44it/s]


Epoch 201, Loss: 0.016118427738547325


40it [00:11,  3.43it/s]


Epoch 202, Loss: 0.014957597479224205


40it [00:11,  3.41it/s]


Epoch 203, Loss: 0.02109045349061489


40it [00:11,  3.40it/s]


Epoch 204, Loss: 0.018900126218795776


40it [00:11,  3.43it/s]


Epoch 205, Loss: 0.01538415439426899


40it [00:11,  3.47it/s]


Epoch 206, Loss: 0.015149356797337532


40it [00:11,  3.45it/s]


Epoch 207, Loss: 0.01626746729016304


40it [00:11,  3.40it/s]


Epoch 208, Loss: 0.0167365912348032


40it [00:11,  3.43it/s]


Epoch 209, Loss: 0.013340520672500134


40it [00:11,  3.40it/s]


Epoch 210, Loss: 0.012862331233918667


40it [00:11,  3.40it/s]


Epoch 211, Loss: 0.01927195116877556


40it [00:11,  3.36it/s]


Epoch 212, Loss: 0.015538638457655907


40it [00:11,  3.45it/s]


Epoch 213, Loss: 0.015481874346733093


40it [00:11,  3.48it/s]


Epoch 214, Loss: 0.01709652505815029


40it [00:11,  3.48it/s]


Epoch 215, Loss: 0.012875160202383995


40it [00:11,  3.48it/s]


Epoch 216, Loss: 0.015230458229780197


40it [00:11,  3.48it/s]


Epoch 217, Loss: 0.022626321762800217


40it [00:11,  3.47it/s]


Epoch 218, Loss: 0.017682328820228577


40it [00:11,  3.48it/s]


Epoch 219, Loss: 0.019470620900392532


40it [00:11,  3.44it/s]


Epoch 220, Loss: 0.015646960586309433


40it [00:11,  3.35it/s]


Epoch 221, Loss: 0.02170853316783905


40it [00:11,  3.40it/s]


Epoch 222, Loss: 0.018400968983769417


40it [00:11,  3.42it/s]


Epoch 223, Loss: 0.01670018583536148


40it [00:11,  3.39it/s]


Epoch 224, Loss: 0.014118100516498089


40it [00:12,  3.26it/s]


Epoch 225, Loss: 0.015140637755393982


40it [00:11,  3.36it/s]


Epoch 226, Loss: 0.017927654087543488


40it [00:12,  3.32it/s]


Epoch 227, Loss: 0.017842555418610573


40it [00:11,  3.42it/s]


Epoch 228, Loss: 0.013063906691968441


40it [00:11,  3.43it/s]


Epoch 229, Loss: 0.015252306126058102


40it [00:11,  3.43it/s]


Epoch 230, Loss: 0.021372109651565552


40it [00:11,  3.40it/s]


Epoch 231, Loss: 0.012836471199989319


40it [00:11,  3.42it/s]


Epoch 232, Loss: 0.016393370926380157


40it [00:11,  3.43it/s]


Epoch 233, Loss: 0.0186306219547987


40it [00:11,  3.40it/s]


Epoch 234, Loss: 0.019295211881399155


40it [00:11,  3.39it/s]


Epoch 235, Loss: 0.014847252517938614


40it [00:11,  3.41it/s]


Epoch 236, Loss: 0.021835193037986755


40it [00:11,  3.37it/s]


Epoch 237, Loss: 0.014618201181292534


40it [00:11,  3.42it/s]


Epoch 238, Loss: 0.013644257560372353


40it [00:11,  3.40it/s]


Epoch 239, Loss: 0.01581006869673729


40it [00:11,  3.46it/s]


Epoch 240, Loss: 0.017121128737926483


40it [00:11,  3.44it/s]


Epoch 241, Loss: 0.02000507526099682


40it [00:11,  3.46it/s]


Epoch 242, Loss: 0.021243687719106674


40it [00:11,  3.45it/s]


Epoch 243, Loss: 0.01797817274928093


40it [00:11,  3.43it/s]


Epoch 244, Loss: 0.019525133073329926


40it [00:11,  3.41it/s]


Epoch 245, Loss: 0.01876593753695488


40it [00:11,  3.39it/s]


Epoch 246, Loss: 0.01412968896329403


40it [00:11,  3.42it/s]


Epoch 247, Loss: 0.016070615500211716


40it [00:11,  3.39it/s]


Epoch 248, Loss: 0.014987394213676453


40it [00:11,  3.42it/s]


Epoch 249, Loss: 0.018086005002260208


40it [00:11,  3.36it/s]


Epoch 250, Loss: 0.014710514806210995


40it [00:11,  3.38it/s]


Epoch 251, Loss: 0.013423341326415539


40it [00:11,  3.36it/s]


Epoch 252, Loss: 0.018015308305621147


40it [00:11,  3.42it/s]


Epoch 253, Loss: 0.01645868830382824


40it [00:11,  3.45it/s]


Epoch 254, Loss: 0.02067844197154045


40it [00:11,  3.42it/s]


Epoch 255, Loss: 0.018263565376400948


40it [00:11,  3.42it/s]


Epoch 256, Loss: 0.017002740874886513


40it [00:11,  3.44it/s]


Epoch 257, Loss: 0.011358475312590599


40it [00:11,  3.38it/s]


Epoch 258, Loss: 0.018061654642224312


40it [00:11,  3.40it/s]


Epoch 259, Loss: 0.02195793017745018


40it [00:11,  3.43it/s]


Epoch 260, Loss: 0.02044004201889038


40it [00:11,  3.45it/s]


Epoch 261, Loss: 0.012072518467903137


40it [00:11,  3.45it/s]


Epoch 262, Loss: 0.01255890540778637


40it [00:11,  3.45it/s]


Epoch 263, Loss: 0.014095954596996307


40it [00:11,  3.44it/s]


Epoch 264, Loss: 0.023977046832442284


40it [00:11,  3.45it/s]


Epoch 265, Loss: 0.016780298203229904


40it [00:11,  3.45it/s]


Epoch 266, Loss: 0.019355840981006622


40it [00:11,  3.47it/s]


Epoch 267, Loss: 0.01713757961988449


40it [00:11,  3.48it/s]


Epoch 268, Loss: 0.018925029784440994


40it [00:11,  3.38it/s]


Epoch 269, Loss: 0.021185997873544693


40it [00:11,  3.41it/s]


Epoch 270, Loss: 0.017119307070970535


40it [00:11,  3.39it/s]


Epoch 271, Loss: 0.014737683348357677


40it [00:11,  3.37it/s]


Epoch 272, Loss: 0.021760765463113785


40it [00:11,  3.39it/s]


Epoch 273, Loss: 0.015229831449687481


40it [00:11,  3.44it/s]


Epoch 274, Loss: 0.01479967962950468


40it [00:11,  3.44it/s]


Epoch 275, Loss: 0.017360804602503777


40it [00:11,  3.44it/s]


Epoch 276, Loss: 0.020932069048285484


40it [00:11,  3.44it/s]


Epoch 277, Loss: 0.01855628937482834


40it [00:11,  3.43it/s]


Epoch 278, Loss: 0.014535119757056236


40it [00:11,  3.44it/s]


Epoch 279, Loss: 0.019962789490818977


40it [00:11,  3.44it/s]


Epoch 280, Loss: 0.01588747836649418


40it [00:11,  3.46it/s]


Epoch 281, Loss: 0.01223771832883358


40it [00:11,  3.45it/s]


Epoch 282, Loss: 0.012835266068577766


40it [00:11,  3.42it/s]


Epoch 283, Loss: 0.012496141716837883


40it [00:11,  3.45it/s]


Epoch 284, Loss: 0.014742578379809856


40it [00:11,  3.46it/s]


Epoch 285, Loss: 0.017511921003460884


40it [00:11,  3.46it/s]


Epoch 286, Loss: 0.012583231553435326


40it [00:11,  3.46it/s]


Epoch 287, Loss: 0.01651732437312603


40it [00:11,  3.46it/s]


Epoch 288, Loss: 0.012995805591344833


40it [00:11,  3.44it/s]


Epoch 289, Loss: 0.014851473271846771


40it [00:11,  3.42it/s]


Epoch 290, Loss: 0.01760990545153618


40it [00:11,  3.44it/s]


Epoch 291, Loss: 0.01572663150727749


40it [00:11,  3.44it/s]


Epoch 292, Loss: 0.016085611656308174


40it [00:11,  3.47it/s]


Epoch 293, Loss: 0.012134438380599022


40it [00:11,  3.47it/s]


Epoch 294, Loss: 0.017449868842959404


40it [00:11,  3.48it/s]


Epoch 295, Loss: 0.01423533447086811


40it [00:11,  3.47it/s]


Epoch 296, Loss: 0.01568271964788437


40it [00:11,  3.47it/s]


Epoch 297, Loss: 0.012763766571879387


40it [00:11,  3.41it/s]


Epoch 298, Loss: 0.016624119132757187


40it [00:11,  3.34it/s]


Epoch 299, Loss: 0.01707272045314312


40it [00:11,  3.38it/s]

Epoch 300, Loss: 0.01865219511091709





In [31]:
# diffusion model sampling code:
unet = UNet()
unet.to(device)

#unet.load_state_dict(torch.load("tinyhorse.pth"))

# load the model checkpoint
unet.load_state_dict(torch.load("tinyhorse.pth")["model_state_dict"])


with torch.no_grad():
    unet.eval()

    # sample from the model
    # time_vals = torch.randint(1, num_timesteps+1, (batch_size,), device=device)
    test_batch_size = 10
    x = torch.randn(test_batch_size, n_channels, img_dim, img_dim, device=device)
    # print(x.shape)

    for t in range(num_timesteps, 0, -1):
        time_tensor = (torch.ones(1, 1) * t).to(device).long()
        # repeat the time tensor along the batch dimension
        time_tensor = time_tensor.repeat(test_batch_size, 1)
    
        
        eta_theta = unet(x, time_tensor)
        alpha_t = alpha_schedule[t-1]
        alpha_t_bar = alphas_cumprod[t-1]

        x = (1 / alpha_t.sqrt()) * (x - ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta)
        
        z = torch.randn(test_batch_size, n_channels, img_dim, img_dim).to(device)

        beta_t = beta_schedule[t-1]
        sigma_t = beta_t.sqrt()

        x = x + (sigma_t * z)

        if( t == 1):
            for i in range(test_batch_size):
                imshow(x[i])
            # imshow(x[0])


    

In [30]:
# # checkpoint model
# torch.save({
#             'model_state_dict': unet.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             }, "tinyhorse.pth")


In [14]:
# torch.load("tinyhorse.pth")['model_state_dict']

OrderedDict([('time_embed.weight',
              tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
                        0.0000e+00,  1.0000e+00],
                      [ 8.4147e-01,  5.4030e-01,  6.3795e-01,  ...,  1.0000e+00,
                        1.4454e-08,  1.0000e+00],
                      [ 9.0930e-01, -4.1615e-01,  9.8254e-01,  ...,  1.0000e+00,
                        2.8909e-08,  1.0000e+00],
                      ...,
                      [-8.9797e-01, -4.4006e-01, -9.8457e-01,  ...,  1.0000e+00,
                        1.4411e-05,  1.0000e+00],
                      [-8.5547e-01,  5.1785e-01, -6.4655e-01,  ...,  1.0000e+00,
                        1.4425e-05,  1.0000e+00],
                      [-2.6461e-02,  9.9965e-01, -1.1223e-02,  ...,  1.0000e+00,
                        1.4440e-05,  1.0000e+00]], device='mps:0')),
             ('time_embed_mlps.0.0.weight',
              tensor([[-2.4899e-03, -9.4468e-03,  6.2953e-04,  2.3557e-03, -6.1437e-04,
 