<a href="https://colab.research.google.com/github/Yang-star-source/Latent_Diffusion_From_Scratch/blob/main/LDM_T2I_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNetBlock(nn.Module):
    """Pre-activation residual block used by the VAE.

    Two ``GroupNorm -> SiLU -> 3x3 conv`` stages (dropout sits in front of the
    second convolution) are added to a skip path.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the block.
        dropout_prob: dropout probability applied before the second conv.

    Both channel counts are passed to ``nn.GroupNorm(32, ...)``, so each one
    has to be divisible by 32.
    """

    def __init__(self, in_channels, out_channels, dropout_prob=0.0):
        super().__init__()
        # Stage 1: normalise the input and map it to the output width.
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Stage 2: refine at the output width.
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # The skip path only needs a 1x1 projection when the width changes.
        width_unchanged = in_channels == out_channels
        self.shortcut = (
            nn.Identity()
            if width_unchanged
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        """Return ``conv path(x) + shortcut(x)``; spatial size is preserved."""
        hidden = self.conv1(self.silu(self.norm1(x)))
        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Args:
        in_channels: channel count of the feature map; it is also the width of
            the query/key/value projections and is passed to
            ``nn.GroupNorm(32, ...)``, so it has to be divisible by 32.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)

        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Attend across all H*W positions and add the result to ``x``."""
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = self.norm(x).flatten(2).transpose(1, 2)
        query, key, value = self.to_q(tokens), self.to_k(tokens), self.to_v(tokens)

        # Scaled dot-product attention; the scale is 1/sqrt(C).
        scores = torch.bmm(query, key.transpose(1, 2)) * (channels ** (-0.5))
        weights = scores.softmax(dim=-1)
        mixed = self.to_out(torch.bmm(weights, value))

        # Back to the (B, C, H, W) layout before the residual add.
        return mixed.transpose(1, 2).reshape(batch, channels, height, width) + x

class MidBlock(nn.Module):
    """VAE bottleneck: ResNetBlock -> AttentionBlock -> ResNetBlock.

    Args:
        in_channels: channel count, kept unchanged through all three stages.
    """

    def __init__(self, in_channels):
        super().__init__()
        width = in_channels
        self.res1 = ResNetBlock(width, width)
        self.attn1 = AttentionBlock(width)
        self.res2 = ResNetBlock(width, width)

    def forward(self, x):
        """Run the three stages in sequence."""
        return self.res2(self.attn1(self.res1(x)))

class DownBlock(nn.Module):
    """VAE encoder stage: two residual blocks, optional attention, then a
    stride-2 convolution.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count after the first residual block.
        has_attn: insert an ``AttentionBlock`` before downsampling when True;
            otherwise that slot is an ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        # 3x3 conv with stride 2 performs the spatial downsampling.
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x):
        """Apply res1 -> res2 -> attn -> down."""
        for stage in (self.res1, self.res2, self.attn, self.down):
            x = stage(x)
        return x

class Encoder(nn.Module):
    """VAE encoder: maps an image to a sampled latent and a KL penalty.

    Three ``DownBlock`` stages each apply a stride-2 convolution before the
    bottleneck. The final 1x1 convolution emits ``2 * out_channels`` channels
    that are split into a mean and a log-variance.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the latent ``z``.
    """

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

        stages = [
            DownBlock(64, 128),
            DownBlock(128, 256),
            DownBlock(256, 512, has_attn=True),
        ]
        self.down_block = nn.Sequential(*stages)

        self.bottle = MidBlock(512)

        head = [
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            # Twice the latent width: first half is the mean, second the log-variance.
            nn.Conv2d(512, out_channels * 2, kernel_size=1),
        ]
        self.out = nn.Sequential(*head)

    def reparameterize(self, x):
        """Split ``x`` into (mean, log_var), sample ``z`` and compute the KL term.

        Returns:
            ``(z, D_kl)`` where ``z = mean + eps * std`` with ``eps ~ N(0, 1)`` and
            ``D_kl`` is the KL divergence to a standard normal, summed over
            dims 1-3 and averaged over the batch.
        """
        mean, log_var = x.chunk(2, dim=1)
        # Keep exp(log_var) inside a numerically safe range.
        log_var = log_var.clamp(-30.0, 20.0)

        kl_per_element = 0.5 * (log_var.exp() + mean ** 2 - log_var - 1)
        D_kl = kl_per_element.sum(dim=[1, 2, 3]).mean()  # mean over the batch dimension

        std = (0.5 * log_var).exp()
        z = mean + torch.randn_like(std) * std
        return z, D_kl

    def forward(self, x):
        """Encode ``x`` and return ``(z, D_kl)``."""
        features = self.bottle(self.down_block(self.inp(x)))
        return self.reparameterize(self.out(features))

class UpBlock(nn.Module):
    """VAE decoder stage: two residual blocks, optional attention, then a 2x
    upsample followed by a 3x3 convolution.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count after the first residual block.
        has_attn: insert an ``AttentionBlock`` before upsampling when True;
            otherwise that slot is an ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)

        upsample_layers = [
            nn.Upsample(scale_factor=2),  # default interpolation mode is "nearest"
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.up = nn.Sequential(*upsample_layers)

        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x):
        """Apply res1 -> res2 -> attn -> up."""
        for stage in (self.res1, self.res2, self.attn, self.up):
            x = stage(x)
        return x

class Decoder(nn.Module):
    """VAE decoder: maps a latent back to image space.

    The latent is lifted to 512 channels, passed through the bottleneck and
    three ``UpBlock`` stages (each one doubles the spatial size), then projected
    to ``out_channels`` by a GroupNorm -> SiLU -> 3x3 conv head.

    Args:
        in_channels: channels of the latent.
        out_channels: channels of the reconstructed image.
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.bottle = MidBlock(512)

        stages = [
            UpBlock(512, 256, has_attn=True),
            UpBlock(256, 128),
            UpBlock(128, 64),
        ]
        self.up_block = nn.Sequential(*stages)

        head = [
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        ]
        self.out = nn.Sequential(*head)

    def forward(self, x):
        """Decode the latent ``x`` into an image tensor."""
        for stage in (self.inp, self.bottle, self.up_block, self.out):
            x = stage(x)
        return x


class VAE(nn.Module):
    """Variational auto-encoder built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels: channels of the input (and reconstructed) image.
        out_channels: channels of the latent passed between encoder and decoder.
    """

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()
        latent_channels = out_channels
        self.encoder = Encoder(in_channels, latent_channels)
        # The decoder mirrors the encoder: latent channels in, image channels out.
        self.decoder = Decoder(latent_channels, in_channels)

    def forward(self, x):
        """Return ``(reconstruction, D_kl)`` for the input batch ``x``."""
        latent, D_kl = self.encoder(x)
        return self.decoder(latent), D_kl




In [None]:
def raw_time_embedding(time , dim):
    """Build sinusoidal timestep embeddings.

    Args:
        time: a Python number, a 0-d tensor, or a tensor of shape (B,).
        dim: embedding width; ``dim // 2`` sine and ``dim // 2`` cosine
            features are produced.

    Returns:
        Tensor of shape (B, 2 * (dim // 2)) laid out as ``[sin | cos]``, on the
        same device as ``time`` (B is 1 for scalar input).
    """
    time = time if torch.is_tensor(time) else torch.tensor(time)

    # Scalar -> (1, 1); a (B,) batch, as used during training, -> (B, 1).
    time = time.reshape(1, 1) if time.ndim == 0 else time.unsqueeze(1)

    # Create the frequency index on the same device as `time`.
    half_index = torch.arange(dim // 2, device=time.device).float()
    angles = time / (10000 ** (2 * half_index / dim))
    return torch.cat([angles.sin(), angles.cos()], dim=1)

class TextConditionTimeEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a time embedding.

    Args:
        dim: width of the input, hidden and output features.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Project ``X`` (last dimension ``dim``) through the MLP."""
        projected = self.net(X)
        return projected

class TextConditionDiffusionResNetBlock(nn.Module):
    """Residual block of the diffusion UNet, conditioned on a time embedding.

    Same layout as a pre-activation residual block, except that a per-channel
    bias derived from the time embedding is added between the two stages.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the block.
        time_emb_dim: width of the time embedding passed to ``forward``.
        dropout_prob: dropout probability applied before the second conv.

    Both channel counts are passed to ``nn.GroupNorm(32, ...)``, so each one
    has to be divisible by 32.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout_prob=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # Maps the time embedding to one value per output channel so that it
        # can be added to the feature map.
        self.time_proj = nn.Linear(time_emb_dim, out_channels)

        width_unchanged = in_channels == out_channels
        self.shortcut = (
            nn.Identity()
            if width_unchanged
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x, time_emb):
        """Return the time-conditioned residual update of ``x``.

        ``time_emb`` is indexed as (B, time_emb_dim) and broadcast over H and W.
        """
        hidden = self.conv1(self.silu(self.norm1(x)))

        # SiLU puts a non-linearity in front of the projection; the result
        # (B, C) is reshaped to (B, C, 1, 1) for broadcasting.
        time_bias = self.time_proj(self.silu(time_emb))
        hidden = hidden + time_bias[:, :, None, None]

        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network built from 1x1 convolutions.

    Args:
        channel_dim: channel count of the input and output.
        multiplier: expansion factor of the hidden layer.
    """

    def __init__(self, channel_dim, multiplier=4):
        super().__init__()
        hidden_dim = channel_dim * multiplier
        self.net = nn.Sequential(
            nn.Conv2d(channel_dim, hidden_dim, 1),  # expand
            nn.SiLU(),
            nn.Conv2d(hidden_dim, channel_dim, 1),  # project back
        )

    def forward(self, x):
        """Apply expand -> SiLU -> project to every spatial position."""
        out = self.net(x)
        return out


class CrossAttentionBlock(nn.Module):
    """Multi-head attention from image positions (queries) to a context
    sequence (keys/values).

    Args:
        channel_dim: channel count of the feature map; also the total width of
            the query/key/value projections.
        context_dim: feature width of the context tokens.
        heads: number of attention heads.

    Raises:
        ValueError: if ``heads`` is not positive or ``channel_dim`` is not
            divisible by ``heads``. Without this check the module would be
            built successfully and only fail later, inside ``forward``, with
            an opaque reshape error.
    """

    def __init__(self,channel_dim , context_dim = 768 ,heads=8):
        super().__init__()
        if heads <= 0 or channel_dim % heads != 0:
            raise ValueError(
                f"channel_dim ({channel_dim}) must be divisible by heads ({heads})"
            )

        self.heads = heads
        self.head_dim = channel_dim // heads
        # Kept for reference: F.scaled_dot_product_attention applies this same
        # 1/sqrt(head_dim) factor by default, so it is not passed explicitly.
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(channel_dim,channel_dim,bias=False)
        self.to_k = nn.Linear(context_dim,channel_dim,bias=False)
        self.to_v = nn.Linear(context_dim,channel_dim,bias=False)

        self.to_out = nn.Linear(channel_dim,channel_dim)

    def forward(self,x,context):
        """Attend from every spatial position of ``x`` to the context tokens.

        Args:
            x: feature map of shape [B, C, H, W].
            context: token sequence of shape [B, L, context_dim].

        Returns:
            Tensor of shape [B, C, H, W]. No residual is added here.
        """
        B,C,H,W = x.shape

        # [B, C, H, W] -> [B, H*W, C]: one query token per spatial position
        x_flat = x.permute(0,2,3,1).reshape(B,H*W,C)

        # [B, H*W, C]
        q = self.to_q(x_flat)

        # [B, L, C]
        k = self.to_k(context)
        v = self.to_v(context)

        # Split the channel axis into heads: [B, heads, tokens, head_dim]
        q = q.reshape(B,H*W,self.heads,self.head_dim).permute(0,2,1,3)
        k = k.reshape(B,-1,self.heads,self.head_dim).permute(0,2,1,3)
        v = v.reshape(B,-1,self.heads,self.head_dim).permute(0,2,1,3)

        # Q: [B, heads, H*W, head_dim]
        # K: [B, heads, L,   head_dim]  -> attention weights [B, heads, H*W, L]
        # V: [B, heads, L,   head_dim]  -> output            [B, heads, H*W, head_dim]
        out = F.scaled_dot_product_attention(q,k,v,dropout_p=0.0,is_causal=False)

        # Merge heads back: [B, H*W, C] where C = heads * head_dim
        out = out.permute(0,2,1,3).reshape(B,H*W,C)

        # [B, H*W, C]
        out = self.to_out(out)

        # [B, C, H, W]
        out = out.reshape(B,H,W,C).permute(0,3,1,2)
        return out

class TransformerBlock(nn.Module):
    """Self-attention, cross-attention and feed-forward, each applied to a
    GroupNorm-ed input and added back as a residual.

    Args:
        channel_dim: channel count of the feature map (passed to
            ``nn.GroupNorm(32, ...)``, so it has to be divisible by 32).
        context_dim: feature width of the conditioning tokens.
    """

    def __init__(self, channel_dim, context_dim=768):
        super().__init__()

        self.norm1 = nn.GroupNorm(32, channel_dim)
        # Self-attention reuses the cross-attention module with the image
        # tokens themselves as context, hence context_dim == channel_dim.
        self.self_attn = CrossAttentionBlock(channel_dim, context_dim=channel_dim)

        self.norm2 = nn.GroupNorm(32, channel_dim)
        self.cross_attn = CrossAttentionBlock(channel_dim, context_dim=context_dim)

        self.norm3 = nn.GroupNorm(32, channel_dim)
        self.ff = FeedForwardBlock(channel_dim)

    def forward(self, x, context):
        """Return ``x`` after the three residual sub-layers.

        ``x`` is [B, C, H, W]; ``context`` is forwarded unchanged to the
        cross-attention sub-layer.
        """
        batch, channels, height, width = x.shape

        # 1) Self-attention: the normalised map attends to its own tokens.
        normed = self.norm1(x)
        own_tokens = normed.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        x = x + self.self_attn(normed, own_tokens)

        # 2) Cross-attention to the conditioning tokens.
        x = x + self.cross_attn(self.norm2(x), context)

        # 3) Position-wise feed-forward.
        x = x + self.ff(self.norm3(x))
        return x

class TextConditionDiffusionDownBlock(nn.Module):
    """UNet encoder stage: two time-conditioned residual blocks, an optional
    transformer block, then a stride-2 convolution.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count after the first residual block.
        time_emb_dim: width of the time embedding.
        has_attn: when True a ``TransformerBlock`` (which consumes the text
            context) runs before downsampling; otherwise that slot is an
            ``nn.Identity``.
    """

    def __init__(self,in_channels,out_channels,time_emb_dim,has_attn=False):
        # nn.Module must be initialised before any attribute is assigned on
        # the instance, so this call comes first.
        super().__init__()

        self.need_attn = has_attn

        self.res1 = TextConditionDiffusionResNetBlock(in_channels,out_channels,time_emb_dim)
        self.res2 = TextConditionDiffusionResNetBlock(out_channels,out_channels,time_emb_dim)
        self.down = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=2,padding=1)
        if has_attn:
            self.attn = TransformerBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self,x,time_emb,context):
        """Return ``(downsampled, skip_connection)``.

        ``skip_connection`` is the feature map before the stride-2 conv; the
        matching up block concatenates it back in.
        """
        x = self.res1(x,time_emb)
        x = self.res2(x,time_emb)

        # nn.Identity takes a single argument, so the context is only passed
        # when a real transformer block is present.
        if self.need_attn:
            x = self.attn(x,context)
        else:
            x = self.attn(x)

        skip_connection = x

        x = self.down(x)
        return x , skip_connection

class TextConditionDiffusionMidBlock(nn.Module):
    """UNet bottleneck: residual block -> transformer block -> residual block.

    Args:
        in_channels: channel count, unchanged through the block.
        time_emb_dim: width of the time embedding.
    """

    def __init__(self, in_channels, time_emb_dim):
        super().__init__()
        width = in_channels

        self.res1 = TextConditionDiffusionResNetBlock(width, width, time_emb_dim)
        self.attn = TransformerBlock(width)
        self.res2 = TextConditionDiffusionResNetBlock(width, width, time_emb_dim)

    def forward(self, x, time_emb, context):
        """Apply the three stages; the text context is used by the middle one."""
        hidden = self.res1(x, time_emb)
        hidden = self.attn(hidden, context)
        return self.res2(hidden, time_emb)

class TextConditionDiffusionUpBlock(nn.Module):
    """UNet decoder stage: upsample, concatenate the skip connection, then two
    time-conditioned residual blocks and an optional transformer block.

    Args:
        in_channels: channel count of the incoming feature map; the skip
            connection is expected to carry the same number of channels.
        out_channels: channel count produced by the block.
        time_emb_dim: width of the time embedding.
        has_attn: when True a ``TransformerBlock`` runs last; otherwise that
            slot is an ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, has_attn=False):
        super().__init__()

        self.need_attn = has_attn

        upsample_layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        ]
        self.up = nn.Sequential(*upsample_layers)

        # The first residual block sees 2 * in_channels because the upsampled
        # map and the skip connection are concatenated along the channel axis.
        self.res1 = TextConditionDiffusionResNetBlock(in_channels * 2, out_channels, time_emb_dim)
        self.res2 = TextConditionDiffusionResNetBlock(out_channels, out_channels, time_emb_dim)

        self.attn = TransformerBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x, skip_connection, time_emb, context):
        """Upsample ``x``, merge ``skip_connection`` and refine the result."""
        # Concatenate along the channel dimension.
        merged = torch.cat([self.up(x), skip_connection], dim=1)
        hidden = self.res2(self.res1(merged, time_emb), time_emb)

        # nn.Identity accepts a single argument only.
        return self.attn(hidden, context) if self.need_attn else self.attn(hidden)

class TextConditionDiffusionUnet(nn.Module):
    """Text-conditioned UNet that predicts noise from a noisy latent.

    Four down stages (64 -> 64 -> 128 -> 128 -> 256 channels), a bottleneck and
    four mirrored up stages with skip connections. The two deepest stages on
    each side, plus the bottleneck, contain transformer blocks that attend to
    the text context.

    Args:
        in_channels: channels of the noisy latent.
        out_channels: channels of the prediction.
        time_dim: width of the sinusoidal time embedding and of its MLP.
    """

    def __init__(self, in_channels=4, out_channels=4, time_dim=256):
        super().__init__()

        self.time_dim = time_dim
        self.time_embedding = TextConditionTimeEmbedding(time_dim)

        self.init_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

        # Encoder path
        self.down1 = TextConditionDiffusionDownBlock(64, 64, time_dim)
        self.down2 = TextConditionDiffusionDownBlock(64, 128, time_dim)
        self.down3 = TextConditionDiffusionDownBlock(128, 128, time_dim, has_attn=True)
        self.down4 = TextConditionDiffusionDownBlock(128, 256, time_dim, has_attn=True)

        self.mid = TextConditionDiffusionMidBlock(256, time_dim)

        # Decoder path; each stage's in_channels equals the width of the skip
        # connection it is paired with.
        self.up1 = TextConditionDiffusionUpBlock(256, 128, time_dim, has_attn=True)
        self.up2 = TextConditionDiffusionUpBlock(128, 128, time_dim, has_attn=True)
        self.up3 = TextConditionDiffusionUpBlock(128, 64, time_dim)
        self.up4 = TextConditionDiffusionUpBlock(64, 64, time_dim)

        head = [
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        ]
        self.out = nn.Sequential(*head)

    def forward(self, x, t, context):
        """Predict the noise in ``x`` at timestep(s) ``t`` given ``context``.

        Args:
            x: noisy latent, [B, in_channels, H, W].
            t: timestep(s) accepted by ``raw_time_embedding``.
            context: text tokens forwarded to every transformer block.
        """
        emb = self.time_embedding(raw_time_embedding(t, self.time_dim))

        x = self.init_conv(x)

        # Collect one skip connection per down stage...
        skips = []
        for down in (self.down1, self.down2, self.down3, self.down4):
            x, skip = down(x, emb, context)
            skips.append(skip)

        x = self.mid(x, emb, context)

        # ...and consume them in reverse order on the way up.
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop(), emb, context)

        return self.out(x)




# Data Preprocessing

In [None]:
from huggingface_hub import hf_hub_download
import os
import shutil

# Download the pre-computed VAE latents and caption embeddings from the
# Hugging Face dataset repo and extract both archives in place.
latent_dir = "/content/dataset/LATENTS"
caption_dir = "/content/dataset/captions_embedding/captions_embedding"
os.makedirs(latent_dir,exist_ok=True)
os.makedirs(caption_dir,exist_ok=True)

repo_id = "ziyang06315/cats_images_dataset"
LATENTS_ZIP = hf_hub_download(repo_id=repo_id,filename="LATENTS.zip",repo_type="dataset",local_dir=latent_dir)
TEXT_EMB_ZIP = hf_hub_download(repo_id=repo_id,filename="captions_embedding.zip",repo_type="dataset",local_dir=caption_dir)

# Extract with the standard library instead of shelling out to `unzip`:
# a corrupt or missing archive now raises instead of being silently ignored,
# and the paths are never interpolated into a shell command line.
shutil.unpack_archive(LATENTS_ZIP, latent_dir, format="zip")
shutil.unpack_archive(TEXT_EMB_ZIP, caption_dir, format="zip")
print("Finish unzip")

In [None]:
import os
import shutil
# Directories that hold the extracted tensors.
latent_dir  = "/content/dataset/LATENTS/LATENTS"
caption_dir = "/content/dataset/captions_embedding/captions_embedding"

# Entry counts before cleanup.
print(len(os.listdir(latent_dir)))
print(len(os.listdir(caption_dir)))

# Anything in the caption directory that is not a ".pt" tensor file is removed
# so that directory listings contain tensors only.
junk_files = [name for name in os.listdir(caption_dir) if not name.endswith(".pt")]
if not junk_files:
    print("No junk files found")
else:
    print(f"Found {len(junk_files)} junk files")
    for name in junk_files:
        target = os.path.join(caption_dir, name)
        # Plain files are unlinked; directories are removed recursively.
        remove = os.remove if os.path.isfile(target) else shutil.rmtree
        remove(target)

# Entry counts after cleanup.
print(len(os.listdir(latent_dir)))
print(len(os.listdir(caption_dir)))

# Load Dataset
- **Uncomment** the check in the cell below if you need to verify file matching or to debug

In [None]:
from  torch.utils.data import Dataset,DataLoader
from natsort import natsorted
import torch

# Device string used for models and tensors: "cuda" when PyTorch can see a
# GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class MyDataset(Dataset):
    """Pairs of (latent, text embedding) tensors stored as individual files.

    Files in the two directories are paired by position after natural
    sorting, so both directories must list matching items in the same order.

    Args:
        latent_dir: directory containing one saved latent per sample.
        text_emb_dir: directory containing one saved text embedding per sample.
    """

    def __init__(self,latent_dir,text_emb_dir):
        self.latent_dir = latent_dir
        self.text_emb_dir = text_emb_dir
        # Natural sort keeps e.g. "2.pt" before "10.pt" in both listings.
        self.latent_files = natsorted(os.listdir(latent_dir))
        self.text_emb_files = natsorted(os.listdir(text_emb_dir))

    def __len__(self):
        """Number of samples, taken from the latent directory."""
        return len(self.latent_files)

    def __getitem__(self,idx):
        """Load and return ``(latent, text_emb)`` for sample ``idx``."""
        latent_path = os.path.join(self.latent_dir,self.latent_files[idx])
        text_emb_path = os.path.join(self.text_emb_dir,self.text_emb_files[idx])

        # Always materialise on the CPU: tensors saved from a GPU would
        # otherwise be restored onto CUDA inside DataLoader worker processes,
        # which fails in forked workers and is incompatible with
        # pin_memory=True. The training loop moves batches to the device.
        latent = torch.load(latent_path, map_location="cpu")
        text_emb = torch.load(text_emb_path, map_location="cpu")

        # To check file matching, return the paths instead:
        #return latent_path,text_emb_path
        return latent,text_emb

'''
# Check file matching
dataset = MyDataset(latent_dir,caption_dir)
for i in [6,9,100]:
    print(dataset[i])
'''

# Get a Prompt Embedding from CLIP for Sampling Purposes

In [None]:
from transformers import CLIPTextModel , CLIPTokenizer

# CLIP checkpoint that provides both the tokenizer and the text encoder.
model_id = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(model_id)
# The encoder is only used for inference, so it is switched to eval mode
# right after being moved to the target device (eval() returns the module).
text_encoder = CLIPTextModel.from_pretrained(model_id).to(DEVICE).eval()

# Pass batch_size=8 when sampling a batch of 8 images.
def get_text_embedding(text,batch_size:int):
    """Encode a prompt with the CLIP text encoder and tile it over a batch.

    Args:
        text: prompt passed to the tokenizer; only the first encoded sequence
            is kept.
        batch_size: number of copies to stack along a new leading dimension.

    Returns:
        CPU tensor of shape (batch_size, 77, hidden) holding the encoder's
        last hidden state, where 77 is the padded/truncated token length.
    """
    tokens = tokenizer(text, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    with torch.no_grad():
        hidden_states = text_encoder(**tokens).last_hidden_state
    first_prompt = hidden_states.cpu()[0]
    return first_prompt.repeat(batch_size, 1, 1)

# Prompt used for the preview images rendered during training; it is tiled
# 8 times, one copy per preview image.
text = "A cute orange cat with green eyes"
sampling_text_emb = get_text_embedding(text, batch_size=8)
print(sampling_text_emb.shape)




# Training and Sampling Loop

In [None]:
from huggingface_hub import hf_hub_download
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image

# Output locations (Google Drive) for UNet checkpoints and preview images.
MODEL_SAVE_PATH = "/content/drive/MyDrive/VAE_Training/TextConditionLDM/UnetWeight"
SAMPLING_PATH = "/content/drive/MyDrive/VAE_Training/TextConditionLDM/SamplingWhileTraining"
os.makedirs(MODEL_SAVE_PATH,exist_ok=True)
os.makedirs(SAMPLING_PATH,exist_ok=True)

# Pre-trained VAE weights are fetched from the Hugging Face Hub.
repo_id = "ziyang06315/latent_diffusion_from_scratch"
VAE_CHECKPOINT = hf_hub_download(repo_id, filename="checkpoint.pth")

# The VAE is only used in eval mode, to decode sampled latents into images.
vae=VAE().to(DEVICE)
checkpoint = torch.load(VAE_CHECKPOINT,map_location=DEVICE)
vae.load_state_dict(checkpoint['vae_state_dict'])
vae.eval()

# The UNet is the model being trained; it is wrapped by torch.compile.
unet = TextConditionDiffusionUnet().to(DEVICE)
unet = torch.compile(unet)

# Training hyper-parameters.
LEARNING_RATE = 1e-4
EPOCH = 800
BATCH_SIZE = 64

# Pre-computed latents and caption embeddings, shuffled every epoch.
dataset = MyDataset(latent_dir,caption_dir)
dataloader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=2,pin_memory=True)

# Linear beta schedule over `num_timesteps` diffusion steps.
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

# Per-timestep lookup tables derived from the schedule, kept on DEVICE so they
# can be indexed directly with timestep tensors.
betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

optimizer = AdamW(unet.parameters(),lr = LEARNING_RATE)

# Gradient scaler paired with the autocast context used in the training loop.
scaler = torch.cuda.amp.GradScaler()

# Train the UNet to predict the noise that was added to a latent at a random
# timestep. Every 50 epochs the weights are saved and a grid of preview images
# is sampled with the full reverse-diffusion chain and decoded by the VAE.
for epoch in range(EPOCH):
    unet.train()
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader,desc = f"Epoch {epoch}")

    for latent, text_emb in progress_bar:
        latents = latent.to(DEVICE)
        text_emb = text_emb.to(DEVICE)

        noise = torch.randn_like(latents)
        # One timestep per sample, drawn over the whole schedule (tied to
        # num_timesteps so it can never index past the lookup tables).
        timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=DEVICE)

        # Gather the per-sample coefficients and make them broadcast over (C, H, W).
        sqrt_alpha_cumprod = sqrt_alphas_cumprod[timesteps][:,None,None,None]
        sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alphas_cumprod[timesteps][:,None,None,None]

        # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
        noisy_latents = sqrt_alpha_cumprod * latents + sqrt_one_minus_alpha_cumprod * noise

        with torch.cuda.amp.autocast():
            noise_pred = unet(noisy_latents, timesteps , text_emb)

            loss = F.mse_loss(noise_pred, noise)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Read the loss back from the device once and reuse the value.
        loss_value = loss.item()
        epoch_loss += loss_value
        progress_bar.set_postfix({"Loss": loss_value})

    print(f"Epoch {epoch+1} | Loss: {epoch_loss / len(dataloader)}")

    if (epoch+1)%50 == 0:
        # NOTE(review): `unet` is the torch.compile wrapper, so the saved keys
        # presumably carry an `_orig_mod.` prefix - confirm that whatever loads
        # these checkpoints accounts for it.
        torch_path = os.path.join(MODEL_SAVE_PATH,f"TextConditionUnetEpoch_{epoch+1}.pth")
        torch.save(unet.state_dict(), torch_path)
        print(f"Saved model to epoch_{epoch+1}")

        unet.eval()
        with torch.no_grad():
            sampling_text_emb = sampling_text_emb.to(DEVICE)

            # One latent per prompt copy so the batch sizes always agree.
            num_samples = sampling_text_emb.shape[0]
            latents = torch.randn(
                (num_samples, 4, 32, 32),
                device=DEVICE
            )

            # Reverse diffusion from pure noise down to t = 0.
            for t in reversed(range(num_timesteps)):
                t_tensor = torch.full((num_samples,), t, device=DEVICE, dtype=torch.long)

                noise_pred = unet(latents, t_tensor,sampling_text_emb)

                alpha = alphas[t]
                alpha_cumprod = alphas_cumprod[t]
                beta = betas[t]

                # No fresh noise is injected on the final step.
                if t >0:
                    noise = torch.randn_like(latents)
                else:
                    noise = torch.zeros_like(latents)

                # Remove the predicted noise, then add sqrt(beta_t) * noise.
                latents = (1 / torch.sqrt(alpha)) * (latents - ((1 - alpha) / (torch.sqrt(1 - alpha_cumprod))) * noise_pred) + torch.sqrt(beta) * noise

            reconstructed = vae.decoder(latents)

            # Un-normalize (-1,1 -> 0,1)
            reconstructed = reconstructed * 0.5 + 0.5
            reconstructed = torch.clamp(reconstructed, 0, 1)

            # Save the previews as a grid with 4 images per row.
            save_image(reconstructed, f"{SAMPLING_PATH}/ImageSampling{epoch+1}.png", nrow=4)
            print(f"Saved to {SAMPLING_PATH}/ImageSampling{epoch+1}.png")

        unet.train()



