# Trains a diffusion model on CIFAR-10 (version 2).

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings).

The model is a denoising diffusion probabilistic model (https://arxiv.org/abs/2006.11239), which is trained to reverse a gradual noising process, allowing the model to generate samples from the learned data distribution starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. This model is also trained on continuous timesteps parameterized by the log SNR on each timestep (see Variational Diffusion Models, https://arxiv.org/abs/2107.00630), allowing different noise schedules than the one used during training to be easily used during sampling. It uses the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI) for better conditioned denoised images at high noise levels, but reweights the loss function so that it has the same relative weighting as the 'eps' objective.

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# Check the GPU type by running the `nvidia-smi` command in a shell
# (IPython `!` shell escape).
!nvidia-smi

In [None]:
# Imports
from contextlib import contextmanager
from copy import deepcopy
import math
import numpy as np
from IPython import display
from matplotlib import pyplot as plt
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm, trange

In [None]:
# Utilities

@contextmanager
def train_mode(model, mode=True):
    """Temporarily switch a model into training mode (or evaluation mode when
    `mode` is false), yield it, and put every submodule back into the mode it
    had before once the context exits."""
    saved_flags = [submodule.training for submodule in model.modules()]
    try:
        yield model.train(mode)
    finally:
        # Restore per submodule: individual submodules may have been in
        # different modes before entering the context.
        for submodule, was_training in zip(model.modules(), saved_flags):
            submodule.training = was_training
def eval_mode(model):
    """Return a context manager that switches a model into evaluation mode
    and restores the previous mode of every submodule on exit."""
    return train_mode(model, mode=False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Fold the current weights of `model` into `averaged_model`, an
    exponential moving average of it. Call it after each optimizer step.

    Parameters are blended as `decay * averaged + (1 - decay) * current`;
    buffers are copied over unchanged."""
    current_params = dict(model.named_parameters())
    ema_params = dict(averaged_model.named_parameters())
    assert current_params.keys() == ema_params.keys()

    for key, current in current_params.items():
        target = ema_params[key]
        target.mul_(decay)
        target.add_(current, alpha=1 - decay)

    current_buffers = dict(model.named_buffers())
    ema_buffers = dict(averaged_model.named_buffers())
    assert current_buffers.keys() == ema_buffers.keys()

    for key in current_buffers:
        ema_buffers[key].copy_(current_buffers[key])


In [None]:
# Define the model (a residual U-Net)

class ResidualBlock(nn.Module):
    """A stack of layers with an additive skip connection.

    The output is `main(input) + skip(input)`, where `main` is the given
    layers applied in order and `skip` defaults to the identity.

    Args:
        main: an iterable of modules, wrapped in an `nn.Sequential`.
        skip: an optional module applied to the input on the skip path.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare against None explicitly: container modules define __len__,
        # so an empty one is falsy and a truthiness test would replace it.
        self.skip = nn.Identity() if skip is None else skip

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(ResidualBlock):
    """Residual block made of two 3x3 convolutions (c_in -> c_mid -> c_out),
    each followed by 2D dropout (p=0.1) and ReLU.

    When `c_in != c_out` the skip path is a bias-free 1x1 convolution so both
    paths have `c_out` channels; otherwise the default skip of ResidualBlock
    is used. `dropout_last=False` replaces the second dropout with an
    identity layer.
    """

    def __init__(self, c_in, c_mid, c_out, dropout_last=True):
        # Modules are created in the same order as before (skip projection,
        # first conv, second conv) so seeded weight initialization is unchanged.
        if c_in == c_out:
            projection = None
        else:
            projection = nn.Conv2d(c_in, c_out, 1, bias=False)

        conv_in = nn.Conv2d(c_in, c_mid, 3, padding=1)
        conv_out = nn.Conv2d(c_mid, c_out, 3, padding=1)
        if dropout_last:
            final_dropout = nn.Dropout2d(0.1, inplace=True)
        else:
            final_dropout = nn.Identity()

        layers = [
            conv_in,
            nn.Dropout2d(0.1, inplace=True),
            nn.ReLU(inplace=True),
            conv_out,
            final_dropout,
            nn.ReLU(inplace=True),
        ]
        super().__init__(layers, projection)


class SkipBlock(nn.Module):
    """A stack of layers with a concatenating skip connection.

    The output is `main(input)` and `skip(input)` concatenated along the
    channel dimension (dim 1), in that order. `skip` defaults to the identity.

    Args:
        main: an iterable of modules, wrapped in an `nn.Sequential`.
        skip: an optional module applied to the input on the skip path.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare against None explicitly: container modules define __len__,
        # so an empty one is falsy and a truthiness test would replace it.
        self.skip = nn.Identity() if skip is None else skip

    def forward(self, input):
        return torch.cat([self.main(input), self.skip(input)], dim=1)


class FourierFeatures(nn.Module):
    """Fourier feature embedding with a trainable projection matrix.

    The input is projected onto `out_features // 2` frequencies (initialized
    from a normal distribution scaled by `std`) and the cosines and sines of
    the resulting phases are concatenated along the last dimension.
    `out_features` must be even.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        n_frequencies = out_features // 2
        initial_weight = torch.randn([n_frequencies, in_features]) * std
        self.weight = nn.Parameter(initial_weight)

    def forward(self, input):
        phases = 2 * math.pi * input @ self.weight.T
        return torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)


def expand_to_planes(input, shape):
    """Append two trailing singleton dimensions to `input` and tile them to
    `shape[2]` x `shape[3]`, turning per-sample feature vectors into constant
    feature planes."""
    height, width = shape[2], shape[3]
    planes = input.unsqueeze(-1).unsqueeze(-1)
    return planes.repeat([1, 1, height, width])


class Diffusion(nn.Module):
    """A residual U-Net that maps a noised image and its log SNR to the
    model output (used as `v` by the sampling loop and the training loss).

    The network input is the image (3 channels) concatenated with 4 class
    planes and 16 timestep-embedding planes. Class conditioning is disabled:
    the class planes are all zeros and `cond` is ignored, but the planes are
    kept so the first convolution still takes `3 + 16 + 4` channels.
    """

    def __init__(self):
        super().__init__()
        c = 64  # The base channel count

        # The inputs to timestep_embed will approximately fall into the range
        # -10 to 10, so use std 0.2 for the Fourier Features.
        self.timestep_embed = FourierFeatures(1, 16, std=0.2)

        self.net = nn.Sequential(   # 32x32
            ResConvBlock(3 + 16 + 4, c, c),
            ResConvBlock(c, c, c),
            SkipBlock([
                nn.AvgPool2d(2),  # 32x32 -> 16x16
                ResConvBlock(c, c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c * 2),
                SkipBlock([
                    nn.AvgPool2d(2),  # 16x16 -> 8x8
                    ResConvBlock(c * 2, c * 4, c * 4),
                    ResConvBlock(c * 4, c * 4, c * 4),
                    SkipBlock([
                        nn.AvgPool2d(2),  # 8x8 -> 4x4
                        ResConvBlock(c * 4, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 4),
                        nn.Upsample(scale_factor=2),
                    ]),  # 4x4 -> 8x8
                    ResConvBlock(c * 8, c * 4, c * 4),
                    ResConvBlock(c * 4, c * 4, c * 2),
                    nn.Upsample(scale_factor=2),
                ]),  # 8x8 -> 16x16
                ResConvBlock(c * 4, c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c),
                nn.Upsample(scale_factor=2),
            ]),  # 16x16 -> 32x32
            ResConvBlock(c * 2, c, c),
            ResConvBlock(c, c, 3, dropout_last=False),
        )

    def _assemble_input(self, input, log_snrs):
        """Concatenate the image, the zero class planes and the timestep
        embedding planes along the channel dimension."""
        timestep_embed = expand_to_planes(self.timestep_embed(log_snrs[:, None]), input.shape)
        b, _, h, w = input.shape
        # Allocate on the input's own device and dtype rather than relying on
        # a module-level `device` global.
        class_embed = input.new_zeros([b, 4, h, w])
        return torch.cat([input, class_embed, timestep_embed], dim=1)

    def _apply_collecting(self, layer, x, collected):
        """Apply `layer` to `x`, appending the input of every `nn.Upsample`
        encountered (at any nesting depth) to `collected`."""
        if isinstance(layer, SkipBlock):
            inner = x
            for sublayer in layer.main:
                inner = self._apply_collecting(sublayer, inner, collected)
            # Same channel order as SkipBlock.forward: main path, then skip.
            return torch.cat([inner, layer.skip(x)], dim=1)
        if isinstance(layer, nn.Upsample):
            collected.append(x)
        return layer(x)

    def forward(self, input, log_snrs, cond):
        """Run the U-Net.

        Args:
            input: a batch of noised images, (batch, 3, height, width).
            log_snrs: the log SNR of each image, shape (batch,).
            cond: class labels; accepted for interface compatibility but
                unused because class conditioning is disabled.

        Returns:
            The network output, with 3 channels and the spatial size of
            `input`.
        """
        return self.net(self._assemble_input(input, log_snrs))

    def get_features(self, input, log_snrs, cond):
        """Extract a pooled multi-scale feature vector for each image.

        The activations entering each `nn.Upsample` (deepest level first) and
        the output of the third top-level `ResConvBlock` (the first block
        after the outermost skip block) are averaged over their spatial
        dimensions and concatenated. For this architecture that gives
        256 + 128 + 64 + 64 = 512 features per image.

        Gradients are not propagated: the result is detached and moved to the
        CPU. The module's current train/eval mode is left untouched, so call
        `eval()` first to disable dropout.

        Args:
            input: a batch of (noised) images, (batch, 3, height, width).
            log_snrs: the log SNR of each image, shape (batch,).
            cond: unused, see `forward`.

        Returns:
            A CPU tensor of shape (batch, num_features).
        """
        x = self._assemble_input(input, log_snrs)

        collected = []
        top_level_res_blocks = 0
        for layer in self.net:
            x = self._apply_collecting(layer, x, collected)
            if isinstance(layer, ResConvBlock):
                top_level_res_blocks += 1
                if top_level_res_blocks == 3:
                    collected.append(x)
                    # Nothing after this block contributes a feature.
                    break

        # Global average pooling works for any input resolution and keeps the
        # batch dimension even when the batch holds a single image.
        pooled = [F.adaptive_avg_pool2d(feature, 1).flatten(1) for feature in collected]
        return torch.cat(pooled, dim=1).detach().cpu()

In [None]:
# Define the noise schedule and sampling loop

def get_alphas_sigmas(log_snrs):
    """Convert log SNRs into the clean-image scale (alpha) and the noise
    scale (sigma), returned as the tuple `(alphas, sigmas)`.

    alpha = sqrt(sigmoid(log_snr)) and sigma = sqrt(sigmoid(-log_snr)), so
    alpha**2 + sigma**2 == 1."""
    alphas = torch.sqrt(torch.sigmoid(log_snrs))
    sigmas = torch.sqrt(torch.sigmoid(-log_snrs))
    return alphas, sigmas

def get_ddpm_schedule(t):
    """Map timesteps `t` to the log SNRs of the DDPM paper's noise schedule:
    `-log(expm1(1e-4 + 10 * t**2))`."""
    exponent = 1e-4 + 10 * t**2
    return -torch.log(torch.special.expm1(exponent))

@torch.no_grad()
def sample(model, x, steps, eta, classes):
    """Draws samples from a model given starting noise.

    Args:
        model: called as `model(x, log_snrs, classes)`; its output is used as
            `v`, the predicted velocity.
        x: the starting noise, one entry per sample along dim 0.
        steps: the number of sampling steps (at least 1).
        eta: the amount of fresh noise added per step
            (0 = deterministic DDIM, 1 = DDPM).
        classes: passed through to the model unchanged.

    Returns:
        The denoised prediction from the final step.

    Raises:
        ValueError: if `steps` is less than 1.
    """
    if steps < 1:
        # Without at least one step there is no prediction to return.
        raise ValueError(f'steps must be at least 1, got {steps}')

    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule: timesteps run from 1 (pure noise) towards 0,
    # excluding t = 0 itself.
    t = torch.linspace(1, 0, steps + 1)[:-1]
    log_snrs = get_ddpm_schedule(t)
    alphas, sigmas = get_alphas_sigmas(log_snrs)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * log_snrs[i], classes).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma
            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma
    # The prediction from the last timestep is the denoised image
    return pred

# Diff net training

In [None]:
# Visualize the noise schedule

# Render inline figures in the 'retina' format and set the figure DPI.
%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.dpi'] = 100

# Evaluate the DDPM schedule at 1000 evenly spaced timesteps in [0, 1].
t_vis = torch.linspace(0, 1, 1000)
log_snrs_vis = get_ddpm_schedule(t_vis)
alphas_vis, sigmas_vis = get_alphas_sigmas(log_snrs_vis)

print('The noise schedule:')

# First figure: the signal (alpha) and noise (sigma) scaling factors.
plt.plot(t_vis, alphas_vis, label='alpha (signal level)')
plt.plot(t_vis, sigmas_vis, label='sigma (noise level)')
plt.legend()
plt.xlabel('timestep')
plt.grid()
plt.show()

# Second figure: the log SNR at each timestep.
plt.plot(t_vis, log_snrs_vis, label='log SNR')
plt.legend()
plt.xlabel('timestep')
plt.grid()
plt.show()


In [None]:
# Prepare the dataset

# Number of images per batch for both loaders below.
batch_size = 100

# Convert images to tensors, then normalize with mean 0.5 and std 0.5.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# Training split of CIFAR-10 (downloaded into ./data if needed), shuffled.
train_set = datasets.CIFAR10('data', train=True, download=True, transform=tf)
train_dl = data.DataLoader(train_set, batch_size, shuffle=True,
                           num_workers=4, persistent_workers=True, pin_memory=True)
# Test split of CIFAR-10, used for validation and iterated in a fixed order.
val_set = datasets.CIFAR10('data', train=False, download=True, transform=tf)
val_dl = data.DataLoader(val_set, batch_size,
                         num_workers=4, persistent_workers=True, pin_memory=True)


In [None]:
# Create the model and optimizer

# Seed used for model initialization here and re-applied by val() and demo().
seed = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# Seed from the variable above instead of repeating the literal, so changing
# `seed` affects initialization as well.
torch.manual_seed(seed)

model = Diffusion().to(device)
# The EMA model starts as an exact copy and is updated by ema_update().
model_ema = deepcopy(model)
print('Model parameters:', sum(p.numel() for p in model.parameters()))

opt = optim.Adam(model.parameters(), lr=2e-4)
# Gradient scaler for mixed precision; train() uses it around the optimizer step.
scaler = torch.cuda.amp.GradScaler()
epoch = 0

# Use a low discrepancy quasi-random sequence to sample uniformly distributed
# timesteps. This considerably reduces the between-batch variance of the loss.
rng = torch.quasirandom.SobolEngine(1, scramble=True)

In [None]:
# Actually train the model

# EMA decay passed to ema_update() by train() once `epoch` reaches 20
# (train() uses 0.95 before that).
ema_decay = 0.998

# The number of timesteps to use when sampling
steps = 500

# The amount of noise to add each timestep when sampling
# 0 = no noise (DDIM)
# 1 = full noise (DDPM)
eta = 1.


def eval_loss(model, rng, reals, classes):
    """Compute the weighted 'v'-objective loss for one batch.

    One timestep per image is drawn from `rng`, the images are noised
    according to the DDPM schedule, and the per-image mean squared error
    between the model output and the target `v` is weighted by
    `snr / (snr + 1)` before being averaged over the batch.
    """
    # Draw uniformly distributed continuous timesteps, one per image
    num_images = reals.shape[0]
    t = rng.draw(num_images)[:, 0].to(device)

    # Noise schedule parameters and loss weights for those timesteps
    log_snrs = get_ddpm_schedule(t)
    alphas, sigmas = get_alphas_sigmas(log_snrs)
    snrs = log_snrs.exp()
    weights = snrs / snrs.add(1)

    # Broadcast the per-image scales over channels and pixels, then mix the
    # ground truth images with fresh noise
    alphas = alphas[:, None, None, None]
    sigmas = sigmas[:, None, None, None]
    noise = torch.randn_like(reals)
    noised_reals = reals * alphas + noise * sigmas
    targets = noise * alphas - reals * sigmas

    # Model output and loss, under mixed precision
    with torch.cuda.amp.autocast():
        v = model(noised_reals, log_snrs, classes)
        per_image_mse = (v - targets).pow(2).mean([1, 2, 3])
        return per_image_mse.mul(weights).mean()


def train():
    """Train `model` for one pass over `train_dl`, updating `model_ema` after
    every optimizer step and logging the loss every 50 iterations."""
    for i, (reals, classes) in enumerate(tqdm(train_dl)):
        opt.zero_grad()
        reals, classes = reals.to(device), classes.to(device)

        # Evaluate the loss
        loss = eval_loss(model, rng, reals, classes)

        # Scaled backward pass and optimizer step, then the EMA update.
        # The EMA uses a faster decay of 0.95 during the first 20 epochs.
        scaler.scale(loss).backward()
        scaler.step(opt)
        decay = ema_decay if epoch >= 20 else 0.95
        ema_update(model, model_ema, decay)
        scaler.update()

        if i % 50:
            continue
        tqdm.write(f'Epoch: {epoch}, iteration: {i}, loss: {loss.item():g}')


@torch.no_grad()
@torch.random.fork_rng()
@eval_mode(model_ema)
def val():
    """Report the mean loss of `model_ema` over `val_dl`.

    Runs with a forked, re-seeded RNG and a fresh Sobol engine so the
    surrounding random state is not disturbed."""
    tqdm.write('\nValidating...')
    torch.manual_seed(seed)
    rng = torch.quasirandom.SobolEngine(1, scramble=True)
    weighted_sum, num_seen = 0, 0
    for reals, classes in tqdm(val_dl):
        reals = reals.to(device)
        classes = classes.to(device)

        batch_loss = eval_loss(model_ema, rng, reals, classes).item()

        # Weight by batch size so a smaller final batch is averaged correctly
        batch_len = len(reals)
        weighted_sum += batch_loss * batch_len
        num_seen += batch_len
    loss = weighted_sum / num_seen
    tqdm.write(f'Validation: Epoch: {epoch}, loss: {loss:g}')


@torch.no_grad()
@torch.random.fork_rng()
@eval_mode(model_ema)
def demo():
    """Sample 100 images from `model_ema`, save them as a 10-per-row grid to
    `demo_<epoch>.png` and display the file."""
    tqdm.write('\nSampling...')
    torch.manual_seed(seed)

    noise = torch.randn([100, 3, 32, 32], device=device)
    # Ten samples for each of the class indices 0..9
    fakes_classes = torch.arange(10, device=device).repeat_interleave(10, 0)
    fakes = sample(model_ema, noise, steps, eta, fakes_classes)

    grid = utils.make_grid(fakes, 10).cpu()
    # Map from [-1, 1] to [0, 1] before converting to an image
    grid_image = TF.to_pil_image(grid.add(1).div(2).clamp(0, 1))
    filename = f'demo_{epoch:05}.png'
    grid_image.save(filename)
    display.display(display.Image(filename))
    tqdm.write('')


def save():
    """Write a checkpoint with the model, EMA model, optimizer and gradient
    scaler state plus the current epoch to `cifar_diffusion.pth`."""
    model_state = model.state_dict()
    ema_state = model_ema.state_dict()
    checkpoint = {
        'model': model_state,
        'model_ema': ema_state,
        'opt': opt.state_dict(),
        'scaler': scaler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, 'cifar_diffusion.pth')


# Validate and sample once before any training, then train indefinitely:
# each pass trains one epoch, validates and samples on every 5th epoch, and
# writes a checkpoint after every epoch. Interrupting the kernel
# (KeyboardInterrupt) ends the loop without an error.
try:
    val()
    demo()
    while True:
        print('Epoch', epoch)
        train()
        epoch += 1
        if epoch % 5 == 0:
            val()
            demo()
        save()
except KeyboardInterrupt:
    pass


In [None]:
# Mount Google Drive at /content/drive (Colab only) so the checkpoint can be
# copied to it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import shutil

# Copy the checkpoint at /content/cifar_diffusion.pth to Google Drive under a
# descriptive name.
source_path = '/content/cifar_diffusion.pth'
destination_path = '/content/drive/My Drive/without_labels_cifar_diffusion_30_epochs.pth'

shutil.copyfile(source_path, destination_path)


# Load model

In [None]:
# Pick the GPU when available (otherwise the CPU) and seed PyTorch's RNG.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
torch.manual_seed(0)

In [None]:
from google.colab import drive
import torch

drive.mount('/content/drive')
weights_path = '/content/drive/My Drive/without_labels_cifar_diffusion_30_epochs.pth' # Path to the checkpoint file on Google Drive
# Alternative checkpoint paths:
# '/content/drive/My Drive/cifar_diffusion_30_epochs.pth'
# or '/content/drive/My Drive/cifar_diffusion.pth'

saved_obj = torch.load(weights_path, map_location=torch.device('cpu')) # Load the checkpoint dict (weights plus training state) onto the CPU
model_dif = Diffusion().to(device) # Create an instance of the Diffusion model
# NOTE(review): this loads the raw training weights stored under 'model';
# save() also writes EMA weights under 'model_ema' - confirm which is intended.
model_dif.load_state_dict(saved_obj['model']) # Load the weights into the model

# Single-image sampling demonstration

In [None]:
# Sampling settings; these mirror the values used in the training cells.
seed = 0
ema_decay = 0.998
steps = 500
eta = 1.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
torch.manual_seed(seed)

# The sampling helpers below expect a global `model_ema`; here it is a copy
# of the loaded model.
model_ema = deepcopy(model_dif)
# Count the parameters of the model created in this cell. The training-time
# `model` global does not exist when only the loading cells have been run.
print('Model parameters:', sum(p.numel() for p in model_ema.parameters()))

In [None]:
@torch.no_grad()
@torch.random.fork_rng()
@eval_mode(model_ema)
def demo_small():
    """Sample a single image from `model_ema` and show it with matplotlib.

    Runs with a forked RNG seeded with `seed + 100`, so the surrounding
    random state is not disturbed."""
    tqdm.write('\nSampling...')
    torch.manual_seed(seed+100)
    noise = torch.randn([1, 3, 32, 32], device=device)
    # Create the class tensor on the same device as the noise
    fakes_classes = torch.tensor([1], device=device)
    print(fakes_classes)
    fakes = sample(model_ema, noise, steps, eta, fakes_classes)
    print(fakes.shape)
    # Map from [-1, 1] to [0, 1] as demo() does; imshow would otherwise clip
    # every negative value to black.
    fakes = fakes.squeeze(0).add(1).div(2).clamp(0, 1)
    fakes = fakes.permute(1, 2, 0).cpu().numpy()
    print(fakes.shape)
    plt.imshow(fakes)
    plt.axis('off')
    plt.show()

In [None]:
# Sample and display one image.
demo_small()

# Feature extraction example

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
rng = torch.quasirandom.SobolEngine(1, scramble=True)

# Use the same preprocessing as the training data so the model sees inputs
# in the range it was trained on.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


dataset = CIFAR10(root='./data', train=False, download=True, transform=transform) # Load the CIFAR10 dataset
dataloader_small = DataLoader(dataset, batch_size=10, shuffle=False)
model_dif.eval() # Set the model to evaluation mode

# Only the first batch is needed for this example.
image, classes = next(iter(dataloader_small))
print(image.shape)
print(classes)
t = rng.draw(image.shape[0])[:, 0]
print(t.shape)
log_snrs = get_ddpm_schedule(t) # Calculate the noise schedule parameters for those timesteps
# Compute features for the batch, with every input on the model's device
with torch.no_grad():
    features = model_dif.get_features(image.to(device), log_snrs.to(device), classes.to(device))

In [None]:
# Display the timesteps drawn for the example batch.
t

In [None]:
# Display the shape of the extracted feature matrix.
features.shape

# UMAP feature visualization

In [None]:
pip install umap-learn

In [None]:
import numpy as np
from sklearn.preprocessing import StandardScaler
import umap
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
rng = torch.quasirandom.SobolEngine(1, scramble=True)

# Same preprocessing as the training data.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


dataset = CIFAR10(root='./data', train=False, download=True, transform=transform) # Load the CIFAR10 dataset
dataloader_small = DataLoader(dataset, batch_size=1000, shuffle=False)
model_dif.eval() # Set the model to evaluation mode

# Only the first batch is embedded.
image, classes = next(iter(dataloader_small))
print(image.shape)
# One fixed timestep per image, sized from the batch rather than hard-coded
t = torch.full((image.shape[0],), 0.001)

print(t.shape)
log_snrs = get_ddpm_schedule(t) # Calculate the noise schedule parameters for those timesteps
# Compute features for the batch, with every input on the model's device
with torch.no_grad():
    features = model_dif.get_features(image.to(device), log_snrs.to(device), classes.to(device))

# features shape: (num_samples, num_features)
# labels shape: (num_samples,)
# Named `feature_scaler` so the global `scaler` (the AMP gradient scaler used
# by train()) is not overwritten.
feature_scaler = StandardScaler()
normalized_features = feature_scaler.fit_transform(features.cpu().numpy())
reducer = umap.UMAP()
embedding = reducer.fit_transform(normalized_features)
plt.figure(figsize=(10, 8))
plt.scatter(embedding[:, 0], embedding[:, 1], c=classes, cmap='tab10', s=5)
plt.title('UMAP Visualization of CIFAR-10 Features')
plt.colorbar()
plt.show()

In [None]:
# Only the first batch is embedded.
image, classes = next(iter(dataloader_small))
print(image.shape)
# One fixed timestep per image, sized from the batch rather than hard-coded
t = torch.full((image.shape[0],), 0.001)

print(t.shape)
log_snrs = get_ddpm_schedule(t) # Calculate the noise schedule parameters for those timesteps
# Compute features for the batch, with every input on the model's device
with torch.no_grad():
    features = model_dif.get_features(image.to(device), log_snrs.to(device), classes.to(device))

# features shape: (num_samples, num_features)
# labels shape: (num_samples,)
# Named `feature_scaler` so the global `scaler` (the AMP gradient scaler used
# by train()) is not overwritten.
feature_scaler = StandardScaler()
normalized_features = feature_scaler.fit_transform(features.cpu().numpy())
umap_3d = umap.UMAP(n_components=3, random_state=42, n_neighbors=10, min_dist=0.3)
projection = umap_3d.fit_transform(normalized_features)

# Plot the 3D graph
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot with labels as color; the labels are CIFAR-10 class indices
scatter = ax.scatter(projection[:, 0], projection[:, 1], projection[:, 2], c=classes, cmap='Spectral', s=5)
legend1 = ax.legend(*scatter.legend_elements(), title="Classes")
ax.add_artist(legend1)

plt.show()

# Forward process noising

In [None]:
bs = 1

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_dl2 = data.DataLoader(train_set, bs, shuffle=True,
                           num_workers=4, persistent_workers=True, pin_memory=True)
# Take a single random training image and show it at several noise levels.
reals, classes = next(iter(train_dl2))
reals = reals.to(device)
classes = classes.to(device)
t_list = [0.001, 0.01, 0.1, 0.5, 0.9]
for t in t_list:
    # Build the schedule on the same device as the image so the scaling
    # factors can be broadcast against it.
    log_snrs = get_ddpm_schedule(torch.tensor([t], device=device))
    alphas, sigmas = get_alphas_sigmas(log_snrs)

    # Combine the ground truth images and the noise
    alphas = alphas[:, None, None, None]
    sigmas = sigmas[:, None, None, None]
    noise = torch.randn_like(reals)
    noised_reals = reals * alphas + noise * sigmas
    print(noised_reals.shape)
    # Map from [-1, 1] to [0, 1] so imshow does not clip negative values
    noised_reals = noised_reals.squeeze(0).add(1).div(2).clamp(0, 1)
    noised_reals = noised_reals.permute(1, 2, 0).cpu().numpy()
    print(noised_reals.shape)
    plt.imshow(noised_reals)
    plt.axis('off')
    plt.show()

# Training a small classifier on diffusion features

In [None]:
def extract_features(images, labels, t_up, batch_size, model_dif):
    """
    Noises a batch of images to timestep `t_up` of the forward diffusion
    process and extracts diffusion-model features from the noised images.

    :images: a batch of images; moved to the global `device`.
    :labels: the labels of the batch; passed through to the diffusion model.
    :t_up: the forward diffusion timestep (between 0 and 1) to noise the
        images to.
    :batch_size: unused; kept for backward compatibility (the batch size is
        taken from `images`).
    :model_dif: a diffusion model providing `get_features`.
    :return: the tensor returned by `model_dif.get_features`, of shape
        [batch, num_features] (512 features for the Diffusion model defined
        in this notebook).
    """
    with torch.no_grad():
        images = images.to(device)
        labels = labels.to(device)

        # The same timestep for every image in the batch
        t = torch.full([images.shape[0]], float(t_up))
        log_snrs = get_ddpm_schedule(t).to(device)
        alphas, sigmas = get_alphas_sigmas(log_snrs)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]

        # Forward diffusion: mix the clean images with fresh Gaussian noise
        noise = torch.randn_like(images)
        noised_reals = images * alphas + noise * sigmas

        features = model_dif.get_features(noised_reals, log_snrs, labels)
    # `features` is already a tensor; wrapping it in torch.tensor() again
    # would only copy it and emit a warning.
    return features

In [None]:
def train_model(train_loader, t_up, batch_size, model, criterion, optimizer, model_dif,  epochs=90, loss_list=None):
    """Train a classifier on features extracted by a diffusion model.

    Args:
        train_loader: yields `(images, labels)` batches.
        t_up: forward-diffusion timestep passed to `extract_features`.
        batch_size: passed through to `extract_features`.
        model: the classifier being trained; called on the extracted features.
        criterion: loss function called as `criterion(outputs, labels)`.
        optimizer: optimizer over the classifier's parameters.
        model_dif: the diffusion model used as the feature extractor; moved
            to the global `device`.
        epochs: number of passes over `train_loader`.
        loss_list: ignored; kept for backward compatibility. A fresh list is
            always returned.

    Returns:
        A list with the mean training loss of each epoch.
    """
    model.train()
    # Move the feature extractor once instead of on every iteration.
    # NOTE(review): the mode of model_dif is left as the caller set it;
    # call model_dif.eval() beforehand if dropout should be disabled.
    model_dif = model_dif.to(device)
    loss_list = []
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            features = extract_features(images, labels, t_up, batch_size, model_dif).to(device)
            outputs = model(features)
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        loss_list.append(epoch_loss)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
    return loss_list

In [None]:
def test_model(model, test_loader, t_up, batch_size, model_dif=model_dif):
    """Evaluate a classifier on diffusion features and return its accuracy.

    Features are extracted from every batch of `test_loader` with
    `extract_features`. The accuracy is printed and returned as a percentage
    (0-100).
    """
    model.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            features = extract_features(images, labels, t_up, batch_size, model_dif).to(device)
            # The index of the largest output is the predicted class
            predicted = torch.max(model(features).data, 1)[1]
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()
    accuracy = 100 * num_correct / num_seen
    print(f"Accuracy: {accuracy}%")
    return accuracy

In [None]:
# Batch size for training the small classifier.
# NOTE(review): this cell rebinds the globals batch_size, tf, train_set,
# train_dl, val_set and val_dl used by the diffusion training cells.
batch_size = 256

# Convert images to tensors, then normalize with mean 0.5 and std 0.5.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# CIFAR-10 training split, shuffled
train_set = datasets.CIFAR10('data', train=True, download=True, transform=tf)
train_dl = data.DataLoader(train_set, batch_size, shuffle=True,
                           num_workers=4, persistent_workers=True, pin_memory=True)
# CIFAR-10 test split, iterated in a fixed order
val_set = datasets.CIFAR10('data', train=False, download=True, transform=tf)
val_dl = data.DataLoader(val_set, batch_size,
                         num_workers=4, persistent_workers=True, pin_memory=True)

In [None]:
from smallnet import LinearNet, Net, split_dataset
input_size = 512  # embedding size (width of the feature vectors fed to the classifier)
num_classes = 10
epochs = 10
t_up = 0.1  # forward-diffusion timestep used for feature extraction; must be between 0 and 1

# NOTE(review): this rebinds the global `model`, which the diffusion training
# cells use for the diffusion network.
model = Net(input_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the classifier on features from the training split, then report its
# accuracy on the validation loader.
loss_list = train_model(train_dl, t_up, batch_size, model, criterion, optimizer, model_dif, epochs=epochs)
accuracy = test_model(model, val_dl, t_up, batch_size, model_dif=model_dif)
accuracy

In [None]:
# Re-run the evaluation; extract_features draws fresh noise on every call.
accuracy = test_model(model, val_dl, t_up, batch_size, model_dif=model_dif)