# Denoising Diffusion Probabilistic Models

In this notebook, we'll implement the Denoising Diffusion Probabilistic Model (DDPM) proposed in the paper [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) from scratch using PyTorch. We'll implement the U-Net model, the forward and reverse diffusion processes, the training loop and the sampling process.

In [1]:
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
import math

## UNet

The U-net is implemented as in the original paper. It's been divided into blocks to make it more modular and easier to work with. The "downblock" consists of a series of convolutional layers, each followed by a maxpooling layer. The "upblock" consists of a convolutional layer followed by a transpose convolutional layer. The bottom-most conv block is implemented separately and is different from the rest of the blocks.

In [2]:
from src.resnet import ResBlock

## TODO: all resblocks should be able to accept timestep embedding as an input 

class ConvBlock(nn.Module):
    def __init__(
            self, 
            in_channels: int, 
            out_channels: int, 
            num_layers: int, 
            num_groups: int = 1, 
            dropout: float = 0.2, 
            activation: nn.Module = nn.ReLU
            ):
        super(ConvBlock, self).__init__()
        convs = []
        convs.append(
            ResBlock(
            in_channels, 
            out_channels, 
            num_groups=num_groups, 
            dropout=dropout, 
            activation=activation
            )
        )
        
        for _ in range(num_layers-1):
            convs.append(
                ResBlock(
                out_channels,
                out_channels, 
                num_groups=num_groups, 
                dropout=dropout, 
                activation=activation
                )
            )

        self.convs = nn.ModuleList(convs)

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        
        return x
    
class DownBlock(nn.Module):
    """
    Downampling (left) side of the UNet.
    Excludes the bottom-most conv block.
    """
    def __init__(self, in_channels: int, filters: list[int], num_layers: int):
        super(DownBlock, self).__init__()
        conv_blocks = [ConvBlock(in_channels, filters[0], num_layers)]
        for i in range(1, len(filters)):
            conv_blocks.append(ConvBlock(filters[i-1], filters[i], num_layers))

        self.conv_blocks = nn.Sequential(*conv_blocks)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual_outputs = []
        for conv_block in self.conv_blocks:
            x = conv_block(x)
            residual_outputs.append(x)
            x = self.maxpool(x)

        return residual_outputs, x

class UpBlock(nn.Module):
    """
    Upsampling (right) side of the UNet.
    """
    def __init__(self, filters: list[int], num_layers: int):
        super(UpBlock, self).__init__()
        layers = []
        for i in range(len(filters) - 2):
            layers.append(
                nn.Sequential(
                    ConvBlock(filters[i], filters[i+1], num_layers), 
                    nn.ConvTranspose2d(filters[i+1], filters[i+1]//2, 2, stride=2)
                )
            )
        
        layers.append(ConvBlock(filters[-2], filters[-1], num_layers))
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x, residual_outputs):
        for i in range(len(self.layers)):
            residual = residual_outputs[-(i+1)]
            _, _, h, w = x.shape
            residual = residual[:, :, :h, :w]
            x = torch.cat([x, residual], dim=1)
            x = self.layers[i](x)
        
        return x


In [3]:
db = DownBlock(in_channels=3, filters=[32, 64, 128], num_layers=2)

In [4]:
db

DownBlock(
  (conv_blocks): Sequential(
    (0): ConvBlock(
      (convs): ModuleList(
        (0): ResBlock(
          (norm1): Sequential(
            (0): GroupNorm(1, 3, eps=1e-05, affine=True)
            (1): ReLU()
          )
          (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): GroupNorm(1, 32, eps=1e-05, affine=True)
          )
          (activation): ReLU()
          (idconv): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
          (avgpool): Identity()
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (1): ResBlock(
          (norm1): Sequential(
            (0): GroupNorm(1, 32, eps=1e-05, affine=True)
            (1): ReLU()
          )
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Sequential(
            (0): Conv2d(32, 32, kernel_si

In [5]:
a = db(torch.randn(1, 3, 128, 128))
a[1].shape

torch.Size([1, 128, 16, 16])

In [6]:
for residual_output in a[0]:
    print(residual_output.shape)
    print("-"*10)

print(a[1].shape)

torch.Size([1, 32, 128, 128])
----------
torch.Size([1, 64, 64, 64])
----------
torch.Size([1, 128, 32, 32])
----------
torch.Size([1, 128, 16, 16])


In [7]:
bottom_conv = nn.Sequential(ResBlock(128, 256), nn.ConvTranspose2d(256, 128, 2, stride=2))
bottom_conv(a[1]).shape

torch.Size([1, 128, 32, 32])

In [8]:
up = UpBlock([256, 128, 64, 32], 2)



In [9]:
for layer in up.layers:
    print(layer)
    print("-"*10)

Sequential(
  (0): ConvBlock(
    (convs): ModuleList(
      (0): ResBlock(
        (norm1): Sequential(
          (0): GroupNorm(1, 256, eps=1e-05, affine=True)
          (1): ReLU()
        )
        (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(1, 128, eps=1e-05, affine=True)
        )
        (activation): ReLU()
        (idconv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (avgpool): Identity()
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (1): ResBlock(
        (norm1): Sequential(
          (0): GroupNorm(1, 128, eps=1e-05, affine=True)
          (1): ReLU()
        )
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Group

In [17]:
class UNet(nn.Module):
    def __init__(self, down_filters, in_channels, num_layers):
        super(UNet, self).__init__()
        self.down_filters = down_filters
        self.down_block = DownBlock(filters=down_filters, num_layers=num_layers, in_channels=in_channels)
        
        # the bottom-most conv block is different in that it doesn't have a maxpool or a residual connection
        self.bottom_conv = nn.Sequential(
            ConvBlock(down_filters[-1], down_filters[-1]*2, num_layers=num_layers), 
            nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
        )
        
        self.up_filters = [down_filters[-1]*2]
        self.up_filters.extend(reversed(down_filters))
        self.up_block = UpBlock(self.up_filters, num_layers)

    def forward(self, x):
        residual_outputs, down_output = self.down_block(x)
        bottom_output = self.bottom_conv(down_output)
        return self.up_block(bottom_output, residual_outputs)

In [19]:
u_net = UNet([32, 64, 128], 3, 2)
a = u_net(torch.randn(1, 3, 128, 128))
a.shape

torch.Size([1, 32, 128, 128])

In [20]:
def alpha_bar(beta):
    alpha = 1. - beta
    return alpha.cumprod(dim=0)

def prepare_batch(x: torch.Tensor, T: int, alpha_bar: torch.Tensor):
    """
    Prepare a batch for training by generating the noise and the noisy image.
    """
    batch_size = x.shape[0]
    t = torch.randint(0, T, (batch_size,))
    e = torch.randn_like(x)
    alpha_bar_t = alpha_bar(t)
    alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)
    noisy_images = alpha_bar_t.sqrt() * x + (1 - alpha_bar_t).sqrt() * e
    
    return (noisy_images, t), e


Changes that need to be made to the U-Net for DDPM:
1. Swap batch norm with group norm 
2. Introduce an attention mechanism at each conv block in the down and up blocks
3. Create an embedding for the timestep

Next, we'll implement these missing features

In [24]:

def scaled_dot_product_attention(q, k, d_k, mask):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is True:
        mask = torch.tril(torch.ones(scores.shape)).to(q.device)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return nn.Softmax(-1)(scores)

class Attention(nn.Module):
    """
    Multihead attention.

    Differences from multihead attention for text:
    
    - we no longer need a d_model, the internal hidden size 
    is determined by the number of channels which is determined 
    by the convolutional layers leading up to the attention layer.

    - swap batch norm with group norm

    - we resize the image to one of shape: (batch_size, num_channels, height * width)
    so that we can perform multihead attention across the image. This closely 
    mirrors (batch_size, embed_dim, seq_len
    """
    def __init__(self,
                 d_k: int, 
                 dropout: float, 
                 num_heads: int, 
                 num_channels: int,
                 num_groups: int = 8,
                 mask: bool = False
                 ):
        super(Attention, self).__init__()
        self.d_k, self.num_heads = d_k, num_heads
        self.query_projection, self.key_projection, self.value_projection = (
            nn.Linear(num_channels, num_heads* d_k),
            nn.Linear(num_channels, num_heads* d_k), 
            nn.Linear(num_channels, num_heads*d_k)
        )
        self.layer_norm = nn.LayerNorm(num_channels)
        self.output_layer = nn.Linear(num_heads*d_k, num_channels)
        self.dropout = nn.Dropout(dropout)
        self.mask = mask
        self.num_channels = num_channels

    def forward(self, x, y = None):
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, height * width).permute(0, 2, 1)
        residual = x
        x = self.layer_norm(x)
        
        if y is not None:
            k, q, v = y, x, y
        else:
            k, q, v = x, x, x
        
        k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
        
        k = self.key_projection(k).view(batch_size, k_len,  self.num_heads, self.d_k)
        q = self.query_projection(q).view(batch_size, q_len,  self.num_heads, self.d_k)
        v = self.value_projection(v).view(batch_size, v_len,  self.num_heads, self.d_k)
        
        attention = scaled_dot_product_attention(
            q.transpose(1, 2), 
            k.transpose(1, 2), 
            self.d_k, 
            self.mask
        )
        output = torch.matmul(attention, v.transpose(1, 2))
        output = self.output_layer(output.transpose(1, 2).contiguous().view(batch_size, q_len, -1))
        
        return self.dropout(output) + residual


U-Net left and right blocks with attention


In [60]:
   
import pdb

class ConvBlockAttn(nn.Module):
    def __init__(
            self, 
            in_channels: int, 
            out_channels: int, 
            num_layers: int, 
            num_groups: int = 1, 
            dropout: float = 0.2, 
            activation: nn.Module = nn.ReLU,
            has_attention: bool = False,
            num_heads: int = 8
            ):
        super(ConvBlockAttn, self).__init__()
        convs = []
        convs.append(
            ResBlock(
            in_channels, 
            out_channels, 
            num_groups=num_groups, 
            dropout=dropout, 
            activation=activation,
            )
        )
        
        for _ in range(num_layers-1):
            convs.append(
                ResBlock(
                out_channels,
                out_channels, 
                num_groups=num_groups, 
                dropout=dropout, 
                activation=activation,
                )
            )

        self.convs = nn.ModuleList(convs)

    def forward(self, x):
        pdb.set_trace()
        for conv in self.convs:
            x = conv(x)
        
        return x

class DownBlockAttn(nn.Module):
    """
    Downampling (left) side of the UNet.
    Excludes the bottom-most conv block.
    """
    def __init__(
            self, 
            in_channels: int, 
            filters: list[int], 
            num_layers: int, 
            has_attention: bool = False, 
            num_heads: int = 8, 
            dropout: float = 0.2, 
            ):
        super(DownBlockAttn, self).__init__()
        
        self.has_attention = has_attention
        conv_blocks = [ConvBlockAttn(in_channels, filters[0], num_layers, has_attention, num_heads)]
        attention_blocks = [
            Attention(
                d_k=64, 
                dropout=0.1, 
                num_heads=num_heads, 
                num_channels=filters[0], 
            )
        ] if has_attention else []
        
        for i in range(1, len(filters)):
            conv_blocks.append(ConvBlockAttn(filters[i-1], filters[i], num_layers, has_attention, num_heads))
            if has_attention:
                attention_blocks.append(
                    Attention(
                        d_k=64, 
                        dropout=0.1, 
                        num_heads=num_heads, 
                        num_channels=filters[i], 
                    )
                )
        
        self.conv_blocks = nn.ModuleList(conv_blocks)
        self.attention_blocks = nn.ModuleList(attention_blocks)
        
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual_outputs = []
        for conv_block, attention_block in zip(self.conv_blocks, self.attention_blocks):
            x = conv_block(x)
            if self.has_attention:
                x = attention_block(x)
            
            residual_outputs.append(x)
            x = self.maxpool(x)

        return residual_outputs, x

class UpBlockAttn(nn.Module):
    """
    Upsampling (right) side of the UNet.
    """
    def __init__(
            self, 
            filters: list[int], 
            num_layers: int, 
            has_attention: bool = False, 
            num_heads: int = 8, 
            dropout: float = 0.2
            ):
        super(UpBlockAttn, self).__init__()
        
        conv_layers = []
        attention_layers = []
        
        for i in range(len(filters) - 2):
            conv_layers.append(
                nn.Sequential(
                    ConvBlock(filters[i], filters[i+1], num_layers), 
                    nn.ConvTranspose2d(filters[i+1], filters[i+1]//2, 2, stride=2)
                )
            )
            if has_attention:
                attention_layers.append(
                    Attention(
                        d_k=64, 
                        dropout=0.1, 
                        num_heads=num_heads, 
                        num_channels=filters[i+1]//2, 
                    )
                )
        conv_layers.append(ConvBlock(filters[-2], filters[-1], num_layers))
        self.conv_layers = nn.Sequential(*conv_layers)
        self.attention_layers = nn.Sequential(*attention_layers)
    
    def forward(self, x, residual_outputs):
        for i in range(len(self.conv_layers)):
            residual = residual_outputs[-(i+1)]
            _, _, h, w = x.shape
            residual = residual[:, :, :h, :w]
            x = torch.cat([x, residual], dim=1)
            
            x = self.conv_layers[i](x)
            if self.has_attention:
                x = self.attention_layers[i](x)
        
        return x


In [61]:
def timestep_encoding(curr_t: torch.Tensor, T: torch.Tensor, embedding_dim: int, n=10000):
    """
    Naive sin/cosin positional embedding adapted for timestep embedding in DDPM
    """
    curr_t = curr_t / T # normalize the timestep to be between 0 and 1
    p = torch.zeros((curr_t.shape[-1], embedding_dim)) # initialize the positional embedding tensor

    m = torch.arange(int(embedding_dim/2)) # this is divided by two because we alternate between sin and cos
    denominators = torch.float_power(n, 2*m/embedding_dim) # compute the denominators for the sin and cos functions
    
    p[:, 0::2] = torch.sin(curr_t.unsqueeze(1) / denominators.unsqueeze(0))
    p[:, 1::2] = torch.cos(curr_t.unsqueeze(1) / denominators.unsqueeze(0))
    return p



class TimestepEmbedding(nn.Module):
    """
    Embeds the timestep into a higher dimensional space using a 2 layer MLP.
    """
    def __init__(self,
                in_channels: int, 
                embedding_dim: int, 
                activation: nn.Module = nn.ReLU, 
                post_activation: bool = True
                ):
        super(TimestepEmbedding, self).__init__()
        self.linear1 = nn.Linear(in_channels, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, embedding_dim)
        self.activation = activation()
        self.post_activation = post_activation

    def forward(self, curr_t: torch.Tensor, T: torch.Tensor):
        x = self.linear1(curr_t)
        x = self.activation(x)
        x = self.linear2(x)
        
        if self.post_activation:
            x = self.activation(x)
        
        return x

In [62]:
class UNetAttn(nn.Module):
    """
    UNet model for DDPM.
    """
    def __init__(self,
                 down_filters, 
                 in_channels, 
                 num_layers, 
                 has_attention: bool = False, 
                 num_heads: int = 8,
                 diffusion_steps: int = None,
                 num_groups: int = 8,
                 activation: nn.Module = nn.ReLU
                ):
        super(UNetAttn, self).__init__()

        self.T = diffusion_steps
        self.down_filters = down_filters

        self.time_embed_dim = down_filters[0] * 4 
        
        if self.T is not None:
            self.timestep_embedding = TimestepEmbedding(
                in_channels=in_channels, 
                embedding_dim=in_channels, 
                activation=activation, 
                post_activation=True
            )
        
        
        self.down_block = DownBlock(
            filters=down_filters, 
            num_layers=num_layers, 
            in_channels=in_channels, 
            has_attention=has_attention, 
            num_heads=num_heads
        )
        
        # the bottom-most (middle) conv block 
        
        # The timestep embedding is fed to the bottom-most conv block because it 
        # provides temporal context at the deepest, most abstract level of features,
        # allowing the model to adjust its predictions based on the diffusion step.
        if has_attention:
            self.bottom_conv = nn.Sequential(
                ConvBlock(down_filters[-1], down_filters[-1]*2, num_layers=num_layers), 
                Attention(
                    d_k=64, 
                    dropout=0.1, 
                    num_heads=num_heads, 
                    num_channels=down_filters[-1]*2, 
                ),
                nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
            )
        else:
            self.bottom_conv = nn.Sequential(
                ConvBlock(down_filters[-1], down_filters[-1]*2, num_layers=num_layers), 
                nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
            )
        
        self.up_filters = [down_filters[-1]*2]
        self.up_filters.extend(reversed(down_filters))
        self.up_block = UpBlock(self.up_filters, num_layers)

        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.conv_out = nn.Conv2d(down_filters[-1], in_channels, 3, padding=1)
    
    def forward(self, x, t):
        if self.T is not None:
            t_emb = self.timestep_embedding(curr_t=t, T=self.T)
            t_emb = t_emb.view(-1, self.time_embed_dim, 1, 1)

        x = self.group_norm(x)
        residual_outputs, down_output = self.down_block(x)
        bottom_output = self.bottom_conv(down_output)
        return self.up_block(bottom_output, residual_outputs)

In [63]:
bottom_out = bottom_conv(a[1])
bottom_out.shape


torch.Size([1, 128, 32, 32])

In [64]:
up_out = up(bottom_out, a[0])

In [65]:
db = DownBlockAttn(in_channels=3, filters=[32, 64, 128], num_layers=2, has_attention=True, num_heads=3)
db

ValueError: dropout probability has to be between 0 and 1, but got 3

In [66]:
a = db(torch.randn(1, 3, 128, 128))
a[1].shape

RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [32] and input of shape [1, 8192, 16]

In [None]:
for residual_output in a[0]:
    print(residual_output.shape)
    print("-"*10)

print(a[1].shape)

torch.Size([1, 32, 128, 128])
----------
torch.Size([1, 64, 64, 64])
----------
torch.Size([1, 128, 32, 32])
----------
torch.Size([1, 128, 16, 16])


In [None]:
bottom_conv = nn.Sequential(ResBlock(128, 256), nn.ConvTranspose2d(256, 128, 2, stride=2))
bottom_conv(a[1]).shape

torch.Size([1, 128, 32, 32])

In [None]:
up = UpBlock([256, 128, 64, 32], 2)



In [None]:
for layer in up.layers:
    print(layer)
    print("-"*10)

Sequential(
  (0): ConvBlock(
    (convs): ModuleList(
      (0): ResBlock(
        (norm1): Sequential(
          (0): GroupNorm(1, 256, eps=1e-05, affine=True)
          (1): ReLU()
        )
        (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(1, 128, eps=1e-05, affine=True)
        )
        (activation): ReLU()
        (idconv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (avgpool): Identity()
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (1): ResBlock(
        (norm1): Sequential(
          (0): GroupNorm(1, 128, eps=1e-05, affine=True)
          (1): ReLU()
        )
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Group

In [None]:
bottom_out = bottom_conv(a[1])
bottom_out.shape


torch.Size([1, 128, 32, 32])

In [None]:
up_out = up(bottom_out, a[0])

In [70]:
example_image = torch.randn(2, 3, 128, 128)
conv_block = ConvBlock(3, 32, 2)
h_example_image = conv_block(example_image)

In [72]:
attention_block = Attention(d_k=64, dropout=0.1, num_heads=3, num_channels=32, num_groups=8)
attn_output = attention_block(h_example_image)
attn_output.shape

torch.Size([2, 16384, 32])

In [21]:
x = torch.randn(15, 3, 128, 128)
T = 1000
batch_size = x.shape[0]
t = torch.randint(0, T, (batch_size,))
e = torch.randn_like(x)
alpha_bar_t = alpha_bar(t)
print(alpha_bar_t.shape)
alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)

torch.Size([15])


In [22]:
alpha_bar_t.shape

torch.Size([15, 1, 1, 1])

In [23]:
def validate_model(model, valid_loader, loss_fn, all_valid_loss):
    model.eval()
    valid_loss = []
    
    with torch.no_grad():
        for batch in valid_loader:
            x, y = batch
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            valid_loss.append(loss.item())
    
    all_valid_loss.append(sum(valid_loss) / len(valid_loss))

def plot_loss(all_train_loss, all_valid_loss):
    plt.figure(figsize=(10, 5))
    plt.plot(all_train_loss, label='Training Loss')
    plt.plot(all_valid_loss, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def train_model(model: nn.Module,
                optim: torch.optim.Optimizer,
                loss_fn,
                train_loader,
                valid_loader,
                scheduler,
                epochs=10,
                valid_every=1
                ):

    all_train_loss = []
    all_valid_loss = []
    
    for epoch in range(epochs):
        model.train()
        train_loss = []
        
        for batch in train_loader:
            x, _ = batch
            (noisy_images, t), e = prepare_batch(x, T, alpha_bar_t)
            optim.zero_grad()
            y_pred = model(noisy_images)
            loss = loss_fn(y_pred, e)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
        
        all_train_loss.append(sum(train_loss) / len(train_loss))
        scheduler.step()
        
        if epoch % valid_every == 0:
            validate_model(model, valid_loader, loss_fn, all_valid_loss)
            print(
                f"Epoch {epoch}, Train Loss: {sum(train_loss) / len(train_loss)}, "
                f"Valid Loss: {all_valid_loss[-1]}"
            )
    
    plot_loss(all_train_loss, all_valid_loss)
