In [7]:
# Noise.py

import numpy as np
import torch

# https://huggingface.co/blog/annotated-diffusion

# Cosine schedule function

def cosine_schedule(timesteps, s=0.008):
    # Cosine schedule from: `Improved Denoising Diffusion Probabilistic Models`
    # Based on code from: https://huggingface.co/blog/annotated-diffusion
    time = torch.linspace(0, timesteps, timesteps+1)
    alphas = torch.cos((time / timesteps + s) / (1 + s) * (np.pi / 2))**2
    alphas = alphas / alphas[0]
    alphas = alphas[1:] / alphas[:-1]
    betas = 1 - alphas
    return torch.clip(betas, 0, 0.999)


class NoiseSampler():
    def __init__(self, timesteps, noise_schedule, device):
        self.timesteps = timesteps
        self.noise_schedule = noise_schedule
        self.betas = noise_schedule(timesteps)

    def __call__(self, x):
        return self.sample_noise(x)
    
    def sample_noise(self, x):
        pass


class CosineScheduler():
    """Scheduler for discrete temporal values in Diffusion Model training.
    """

    def __init__(self, timesteps, s=0.008, device=None):
        self.device = device
        self.timesteps = timesteps
        time = torch.linspace(0, timesteps, timesteps + 1)
        alphas = alphas = torch.cos((time / timesteps + s) / (1 + s) * np.pi / 2)**2
        alphas = alphas / alphas[0]
        alphas = alphas[1:] / alphas[:-1]
        self.betas = torch.clip(1 - alphas, 0.0, 0.9999)

        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.alphas_inv_sqrt = torch.sqrt(1 / self.alphas)
        self.alphas_cumprod_sqrt = torch.sqrt(self.alphas_cumprod)
        self.alphas_cumprod_min_sqrt = torch.sqrt(1 - self.alphas_cumprod)
        
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        
        self.posterior_variance = self.betas * (1 - alphas_cumprod_prev) / (1 - self.alphas_cumprod)
    
    def __call__(self, img, timestep):
        """Noise image based on schedule."""
        return self.apply_noise(img, timestep)
    
    def apply_noise(self, img, timestep):
        """Apply noise to image.
        
        Returns:
            Image tensor with noise
            Noise tensor for loss calculation"""
        noise = self._gaussian_noise(img.shape)
        sqrt_alpha = self._obtain(self.alphas_cumprod_sqrt, timestep, img.shape)
        sqrt_one_minus_alpha = self._obtain(self.alphas_cumprod_min_sqrt, timestep, img.shape)
        return sqrt_alpha * img + sqrt_one_minus_alpha * noise, noise

    
    def _obtain(self, source, timestep, target_shape):
        """Obtain values from target in timestep index for batches.
        Based on the extract function from: https://huggingface.co/blog/annotated-diffusion.
        """
        batch_size = timestep.shape[0]
        values = source.gather(-1, timestep.cpu())
        return values.reshape(batch_size, *((1,) * (len(target_shape) - 1))).to(timestep.device)
    
    def _gaussian_noise(self, shape):
        """Gaussian noise for sampling."""
        return torch.randn(shape)

In [2]:
# dataloader.py
import torch
import torch.utils.data as tud
import numpy as np
import torchvision as tv
import torchvision.transforms.v2 as tfv2
import os
import PIL

def get_data_loaders(train_dir, val_dir, test_dir, batch_size, timesteps=1000, shuffle=True):
    train_dataset = CocoDataset(train_dir, timesteps)
    val_dataset = CocoDataset(val_dir, timesteps)
    test_dataset = CocoDataset(test_dir, timesteps)

    train_loader = tud.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_loader = tud.DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
    test_loader = tud.DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

    return train_loader, val_loader, test_loader

def get_data_loader(directory, batch_size, timesteps=1000, shuffle=True):
    dataset = CocoDataset(directory, timesteps)
    loader = tud.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return loader

class CocoDataset(tud.Dataset):
    def __init__(self, directory, timesteps, img_size=256, labels=None):
        self.labels = labels
        self.dir = directory
        self.imgs = os.listdir(directory)
        self.timesteps = timesteps

        self.transform = tv.transforms.Compose([
            tfv2.Resize(img_size),
            tfv2.RandomCrop(img_size),
            tfv2.ToTensor(),
            tfv2.Lambda(lambda x: x * 2 - 1)
        ])
 
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img = PIL.Image.open(self.dir + '/' + self.imgs[index])
        img = img.convert('RGB')
        # Transform the image and apply noise
        time = np.random.randint(0, self.timesteps)
        img = self.transform(img)
        return img, time

    def __iter__(self):
        return self



In [3]:
# Unet.py
import torch
import torch.nn as nn
import torchvision.transforms.v2 as tfv2

"""
Still need to add the following:
-Conditional
-Time Stamp

Possible Additions:
-BatchNorm
-Self-Attention
-Different Activation Functions
"""

# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ConvBlock, self).__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3),
#             nn.LeakyReLU(inplace=True, negative_slope=0.1),
#             nn.Conv2d(out_channels, out_channels, kernel_size=3),
#             nn.LeakyReLU(inplace=True, negative_slope=0.1),
#         )

#     def forward(self, x):
#         return self.conv(x)
    

# class Downscale(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(Downscale, self).__init__()
#         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.conv = DoubleConv(in_channels, out_channels)

#     def forward(self, x):
#         x = self.pool(x)
#         x = self.conv(x)
#         return x
    
    
# class Upscale(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(Upscale, self).__init__()
#         self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
#         self.conv = DoubleConv(in_channels, out_channels)
    
#     def forward(self, x1, x2, img_size):
#         x1 = tfv2.Resize(img_size)(x1)
#         x2 = tfv2.Resize(img_size*2)(x2)
#         x1 = self.up(x1)
#         x = torch.cat([x2, x1], dim=1)
#         x = self.conv(x)
#         x = tfv2.Resize(img_size*2)(x)
#         return x

    

# class DoubleConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(DoubleConv, self).__init__()
    
#         self.conv1 = ConvBlock(in_channels, out_channels)
#         self.conv2 = ConvBlock(out_channels, out_channels)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.conv2(x)
#         return x

# class UNet(nn.Module):
#     def __init__(self, out_channels, img_size=256):
#         super(UNet, self).__init__()
#         self.img_size = img_size
#         # Encoder
#         self.layer1 = DoubleConv(3, 64)
#         self.down1 = Downscale(64, 128)
#         self.down2 = Downscale(128, 256)
#         self.down3 = Downscale(256, 512)
#         #self.down4 = Downscale(512, 1024)

#         # Bottleneck
#         #self.bottleneck = DoubleConv(512, 512)
        
#         # Decoder
#         #self.up1 = Upscale(1024, 512)
#         self.up2 = Upscale(512, 256)
#         self.up3 = Upscale(256, 128)
#         self.up4 = Upscale(128, 64)

#         # Output
#         self.out = nn.Conv2d(64, out_channels, kernel_size=1)

#     def forward(self, x, timestep):
#         # Encoder
#         x1 = self.layer1(x)
#         x2 = self.down1(x1)
#         self.img_size //= 2
#         x3 = self.down2(x2)
#         self.img_size //= 2
#         x = self.down3(x3)
#         self.img_size //= 2
#         #x5 = self.down4(x4)

#         # Bottleneck
#         #x = self.bottleneck(x4)

#         # Decoder
#         #x = self.up1(x, x4)
#         x = self.up2(x, x3, self.img_size)
#         self.img_size *= 2
#         x = self.up3(x, x2, self.img_size)
#         self.img_size *= 2
#         x = self.up4(x, x1, self.img_size)

#         # Output
#         x = self.out(x)
#         return x

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.dconv_down1 = DoubleConv(in_channels, 64)
        self.dconv_down2 = DoubleConv(64, 128)
        self.dconv_down3 = DoubleConv(128, 256)
        self.dconv_down4 = DoubleConv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dconv_up3 = DoubleConv(256 + 512, 256)
        self.dconv_up2 = DoubleConv(128 + 256, 128)
        self.dconv_up1 = DoubleConv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out

# Example usage
if __name__ == "__main__":
    model = UNet(in_channels=3, out_channels=1)  # Adjust input and output channels as per your task
    input_tensor = torch.randn(1, 3, 256, 256)  # Batch size 1, 3 channels, 256x256 input image
    output_tensor = model(input_tensor)
    print("Output tensor shape:", output_tensor.shape)

Output tensor shape: torch.Size([1, 1, 256, 256])


In [10]:
# train.py
import torch
from torch.nn.functional import mse_loss
from torch.optim import Adam

# from dataloader import get_data_loaders
# from noise import CosineScheduler

# def train_diffusion(
#         model,
#         scheduler,
#         train_loader,
#         val_loader,
#         test_loader=None,
#         epochs=100,
#         early_stopping=10,
#         optimizer=Adam,
#         learning_rate=1e-3,
#         weight_decay=0,
#         device="mps",
#         log_path=None,
#         save_path=None,
# ):
#     model.to(device)
#     optimizer = optimizer(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
#     best_val_loss = float("inf")
#     early_stopping_counter = 0

#     if log_path is not None:
#         with open(log_path, "w") as log_file:
#             log_file.write("epoch,train_loss,val_loss\n")
    
#     for epoch in range(epochs):
#         model.train()
#         train_loss = 0
#         for img, time in train_loader:
#             img, noise = scheduler(img, time)
            
#             img = img.to(device)
#             time = time.to(device)
#             noise = noise.to(device)
#             optimizer.zero_grad()
#             loss = mse_loss(model(img), img)
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()

#         train_loss /= len(train_loader)

#         with torch.no_grad():
#             model.eval()
#             val_loss = 0
#             for img, time in val_loader:
#                 img, noise = scheduler(img, time)
                
#                 img = img.to(device)
#                 time = time.to(device)
#                 noise = noise.to(device)

#                 loss = mse_loss(model(img, time), img)
#                 val_loss += loss.item()

#             val_loss /= len(val_loader)

#             if val_loss < best_val_loss:
#                 best_val_loss = val_loss
#                 early_stopping_counter = 0
#                 if save_path is not None:
#                     torch.save(model.state_dict(), save_path)
#             else:
#                 early_stopping_counter += 1
#                 if early_stopping_counter >= early_stopping:
#                     print(f'--- Early Stop @ {epoch} ---')
#                     break

#         if log_path is not None:
#             with open(log_path, "a") as log_file:
#                 log_file.write(f"{epoch},{train_loss},{val_loss}\n")
        
#         print(f'Epoch: {epoch}')
#         print(f'Train Loss: {train_loss}')
#         print(f'Validation Loss: {val_loss}', end='\n\n')

#     if test_loader is not None:
#         with torch.no_grad():
#             model.eval()
#             test_loss = 0
#             for img, time in test_loader:
#                 img, noise = scheduler(img, time)
                
#                 img = img.to(device)
#                 time = time.to(device)
#                 noise = noise.to(device)
#                 loss = mse_loss(model(img, time), img)
#                 test_loss += loss.item()
            
#             test_loss /= len(test_loader)
#             print(f'Test Loss: {test_loss}')

from tqdm import tqdm

def train_diffusion(
        model,
        scheduler,
        train_loader,
        val_loader,
        test_loader=None,
        epochs=100,
        early_stopping=10,
        optimizer=Adam,
        learning_rate=1e-3,
        weight_decay=0,
        device="cuda",
        log_path=None,
        save_path=None,
):
    model.to(device)
    optimizer = optimizer(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    best_val_loss = float("inf")
    early_stopping_counter = 0

    if log_path is not None:
        with open(log_path, "w") as log_file:
            log_file.write("epoch,train_loss,val_loss\n")
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for img, time in tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}', leave=False):
            img, noise = scheduler(img, time)
            
            img = img.to(device)
            time = time.to(device)
            noise = noise.to(device)
            optimizer.zero_grad()
            loss = mse_loss(model(img), img)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        with torch.no_grad():
            model.eval()
            val_loss = 0
            for img, time in val_loader:
                img, noise = scheduler(img, time)
                
                img = img.to(device)
                time = time.to(device)
                noise = noise.to(device)

                loss = mse_loss(model(img, time), img)
                val_loss += loss.item()

            val_loss /= len(val_loader)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
                if save_path is not None:
                    torch.save(model.state_dict(), save_path)
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= early_stopping:
                    print(f'--- Early Stop @ {epoch} ---')
                    break

        if log_path is not None:
            with open(log_path, "a") as log_file:
                log_file.write(f"{epoch},{train_loss},{val_loss}\n")
        
        print(f'Epoch: {epoch}')
        print(f'Train Loss: {train_loss}')
        print(f'Validation Loss: {val_loss}', end='\n\n')

    if test_loader is not None:
        with torch.no_grad():
            model.eval()
            test_loss = 0
            for img, time in test_loader:
                img, noise = scheduler(img, time)
                
                img = img.to(device)
                time = time.to(device)
                noise = noise.to(device)
                loss = mse_loss(model(img, time), img)
                test_loss += loss.item()
            
            test_loss /= len(test_loader)
            print(f'Test Loss: {test_loss}')

    

In [11]:
def main():
    model = UNet(in_channels=3, out_channels=3)
    # model = UNet(in_channels=3)
    train_loader, val_loader, test_loader = get_data_loaders(
        "D:\\datasets\\coco_2014\\train\\data",
        "D:\\datasets\\coco_2014\\validation\\data",
        "D:\\datasets\\coco_2014\\test\\data",
        batch_size=32,
        )
    train_diffusion(
        model=model,
        scheduler=CosineScheduler(timesteps=1000),
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        log_path="logs/log.csv",
        save_path="models/model.pt",
    )
    
if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: 'logs/log.csv'