In [1]:
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision.transforms import v2 as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [2]:
# Number of diffusion timesteps; also the length of `schedule`.
NUM_TIMESTEPS = 1000
# NOTE: the import cell binds `T` to `torchvision.transforms.v2`; this assignment
# rebinds it to an int. It is kept because the training cell reads `T` as the
# timestep count, but any later `T.<transform>` lookup will fail after this cell.
T = NUM_TIMESTEPS
# schedule[t] is the weight of the clean signal at timestep t: 1.0 at t=0
# (no noise) decreasing linearly to 0.001 at the last timestep.
schedule = torch.linspace(1, 0.001, NUM_TIMESTEPS)


def forward_diffuse(x, timestep, schedule=schedule):
    """Mix Gaussian noise into `x` according to `schedule[timestep]`.

    Args:
        x: Tensor to corrupt. When `timestep` is a tensor, dimension 0 of `x`
            is treated as the batch dimension.
        timestep: Either a single int index (one noise level for the whole
            tensor) or a 1-D integer tensor with one index per batch element.
        schedule: 1-D tensor of signal weights indexed by timestep.

    Returns:
        A tuple `(noisy, noise)` where
        `noisy = sqrt(alpha) * x + sqrt(1 - alpha) * noise`,
        `alpha = schedule[timestep]`, and `noise` is the standard-normal
        sample that was mixed in (same shape as `x`).
    """
    noise = torch.randn_like(x)
    if torch.is_tensor(timestep):
        # Index on the schedule's own device, whatever device the caller used.
        timestep = timestep.to(schedule.device)
    alpha = schedule[timestep].to(device=x.device, dtype=x.dtype)
    if alpha.ndim > 0:
        # One alpha per sample: reshape (B,) -> (B, 1, ..., 1) so it scales each
        # sample as a whole. Without this, broadcasting aligns the batch axis of
        # alpha with the last axis of x and either fails or scales along width.
        alpha = alpha.reshape(-1, *([1] * (x.ndim - 1)))
    return torch.sqrt(alpha) * x + torch.sqrt(1 - alpha) * noise, noise

In [3]:
class BackwardDiffuse(nn.Module):
    """Small convolutional network mapping a 3-channel image batch to a
    3-channel output of the same spatial size.

    Architecture: 3x3 conv (3 -> 128 channels) -> ReLU -> 3x3 conv (128 -> 3).
    Both convolutions use padding 1, so height and width are unchanged.
    In this notebook the output is trained against the noise that was added
    to the input image.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 128
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=hidden_channels, kernel_size=3, padding=1
        )
        # The output layer keeps the attribute name `conv8` so that parameter
        # names (and therefore state_dict keys) stay the same.
        self.conv8 = nn.Conv2d(
            in_channels=hidden_channels, out_channels=3, kernel_size=3, padding=1
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv8 and return the result."""
        hidden = self.relu(self.conv1(x))
        return self.conv8(hidden)

In [4]:
# Per-channel statistics for the optional Normalize step below (these are the
# values commonly quoted for ImageNet). They are unused while Normalize stays
# commented out.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Reach the v2 transforms through the `torchvision` package rather than the `T`
# alias from the import cell: an earlier cell rebinds `T` to the integer number
# of timesteps, so `T.Compose` raises AttributeError on a fresh top-to-bottom run.
# (`torchvision.transforms.v2` is loaded by the import cell's `v2 as T` line.)
transforms_v2 = torchvision.transforms.v2

transform = transforms_v2.Compose([
    transforms_v2.ToImage(),
    transforms_v2.Resize(size=(32, 32), antialias=True),
    # scale=True rescales values while converting the dtype to float32.
    transforms_v2.ToDtype(torch.float32, scale=True),
    # transforms_v2.Normalize(mean=mean, std=std),
])

In [None]:
class CustomDataset(Dataset):
    """Dataset stub that only records where its data lives and how to
    transform it.

    NOTE(review): no `__len__` or `__getitem__` is defined here, so an
    instance cannot be indexed or passed to a DataLoader as-is.
    """

    def __init__(self, root_dir, transform):
        """Store `root_dir` and `transform` on the instance unchanged."""
        super().__init__()
        self.root_dir, self.transform = root_dir, transform

# CIFAR-10 training split, downloaded into ./data on first use; every image is
# passed through `transform` when it is read.
train_data = datasets.CIFAR10(root="data", train=True, transform=transform, download=True)

# Reshuffled mini-batches of up to 512 images for the training loop.
train_loader = DataLoader(train_data, batch_size=512, shuffle=True)

In [9]:
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fresh, randomly initialised network placed on the chosen device.
model = BackwardDiffuse().to(device)

# Mean squared error between the network output and its target.
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

In [None]:
num_epochs = 250
losses = []  # mean training loss per completed epoch

for epoch in tqdm(range(num_epochs)):
    model.train()
    avg_loss = 0
    for images, _ in train_loader:
        # Draw one random timestep per image. The count comes from the schedule
        # itself because the name `T` is used in this notebook both for the
        # torchvision transforms alias and for the number of timesteps.
        t = torch.randint(0, len(schedule), (len(images),))
        noisy_images, noise = forward_diffuse(images, t)
        noisy_images, noise = noisy_images.to(device), noise.to(device)

        # The network is trained to predict the noise that was mixed in.
        output = model(noisy_images)
        loss = criterion(output, noise)
        avg_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_loss = avg_loss / len(train_loader)
    losses.append(avg_loss)

    # Redraw the loss curve on a fresh, fully labelled figure each epoch.
    # Setting the title/labels once before the loop does not work: the first
    # plt.show() discards that figure and later plots come out unlabelled.
    clear_output(wait=True)
    fig, ax = plt.subplots()
    ax.plot(losses, marker='o', linestyle='-', color='b', label='Loss')
    ax.set_title('Loss Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_ylim(bottom=0.0)
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

In [None]:
def plot_image(tensor):
    """Draw a channels-first image tensor as a small thumbnail without axes.

    Args:
        tensor: 3-D tensor laid out as (channels, height, width). It may live
            on any device and may require grad; it is detached and copied to
            the CPU before plotting.
    """
    image = tensor.detach().cpu()
    if image.is_floating_point() and image.shape[0] in (3, 4):
        # matplotlib clips float RGB(A) data to [0, 1] anyway (with a warning);
        # doing it here gives the same picture without the warning.
        image = image.clamp(0.0, 1.0)
    plt.figure(figsize=(1.25, 1.25))
    plt.axis('off')
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(image.permute(1, 2, 0).numpy())

In [None]:
# Number of reverse steps (and thumbnails) to produce.
num_sampling_steps = 5

model.eval()
with torch.no_grad():
    # Start from pure Gaussian noise.
    test_image = torch.randn((1, 3, 32, 32)).to(device)

    # Evenly spaced timesteps from the noisiest level down to t=0,
    # e.g. 999 -> 799 -> 599 -> 399 -> 199 -> 0 for a 1000-step schedule.
    step_ids = torch.linspace(len(schedule) - 1, 0, num_sampling_steps + 1).long()

    for t_now, t_next in zip(step_ids[:-1], step_ids[1:]):
        alpha_now = schedule[t_now].to(device)
        alpha_next = schedule[t_next].to(device)

        # The network was trained to output the noise contained in its input,
        # so its output must not be used directly as the next image.
        predicted_noise = model(test_image)

        # Invert the forward formula x_t = sqrt(a) * x_0 + sqrt(1 - a) * noise
        # to estimate the clean image. Training images are scaled to [0, 1],
        # so the estimate is clamped to that range.
        clean_estimate = (
            test_image - torch.sqrt(1 - alpha_now) * predicted_noise
        ) / torch.sqrt(alpha_now)
        clean_estimate = clean_estimate.clamp(0.0, 1.0)

        # Deterministic (DDIM-style) move to the next, less noisy timestep.
        test_image = (
            torch.sqrt(alpha_next) * clean_estimate
            + torch.sqrt(1 - alpha_next) * predicted_noise
        )
        plot_image(test_image[0].cpu())