In [10]:
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from tqdm import trange
import torch

from main import Model, NoiseScheduler, DATASETS, denormalize, CONFIGS, parse_args, ACTIVATION_FUNCTIONS

In [None]:
def _test(device, noise_scheduler, model, file_path="img.png", progress=False, dataset='cifar10', resolution=None,
          conditional=True):
    """Generate a grid of samples with the reverse process and save it as an image.

    Starting from Gaussian noise, the loop walks from ``noise_scheduler.steps - 1``
    down to 0, asking ``model`` for the predicted noise and
    ``noise_scheduler.sample_prev_step`` for the next batch. The final batch is
    passed through ``denormalize``, clamped to [0, 1], tiled with ``make_grid``
    and written to ``file_path``.

    Args:
        device: Device on which the noise, timesteps and class labels are created.
        noise_scheduler: Object exposing ``steps`` and
            ``sample_prev_step(x, t, pred_noise)``.
        model: Called as ``model(x, t, classes)`` when ``conditional`` is true,
            otherwise as ``model(x, t)``.
        file_path: Destination of the saved grid image.
        progress: If true, iterate with ``tqdm.trange`` to show a progress bar.
        dataset: Key into ``DATASETS``.
        resolution: Side length of the (square) samples. When ``None`` the
            dataset entry's ``resolution`` attribute is used, if it has one.
        conditional: If true, sample with class labels. Only the datasets named
            'cifar10', 'mnist' and 'cifar100' are supported in that mode.

    Raises:
        ValueError: If ``conditional`` is requested for an unsupported dataset,
            or if no resolution was given and none can be read from the dataset.
    """
    dataset = DATASETS[dataset]

    if resolution is None:
        # The original code passed None straight to torch.randn, which raises TypeError.
        # NOTE(review): assumes the dataset entry exposes its native side length as
        # `resolution` — confirm against main.DATASETS.
        resolution = getattr(dataset, 'resolution', None)
        if resolution is None:
            raise ValueError(f"resolution must be given explicitly for {dataset.name}")

    # Unconditional default: 256 samples laid out 16 per row.
    n, nr = 256, 16
    classes = None
    if conditional:
        if dataset.name in ['cifar10', 'mnist']:
            # 10 classes x 10 samples each; one class per grid row.
            n, nr = 100, 10
            classes = torch.arange(10).repeat_interleave(10).to(device)
        elif dataset.name == 'cifar100':
            # One sample for each of the 100 classes.
            n, nr = 100, 10
            classes = torch.arange(100).to(device)
        else:
            raise ValueError(f"Conditional model is not supported for {dataset.name}")

    # Fixed seed so repeated calls draw the same noise. Seeding happens after
    # validation and is undone in `finally`, so an error never leaves the global
    # RNG pinned to 0.
    torch.manual_seed(0)
    try:
        x = torch.randn(n, dataset.image_channels, resolution, resolution, device=device)

        last_step = noise_scheduler.steps - 1
        steps = trange(last_step, -1, -1) if progress else range(last_step, -1, -1)

        with torch.no_grad():
            for step in steps:
                # Same timestep for every sample in the batch.
                t = torch.tensor(step, device=device).expand(x.size(0))
                if conditional:
                    pred_noise = model(x, t, classes)
                else:
                    pred_noise = model(x, t)
                x = noise_scheduler.sample_prev_step(x, t, pred_noise)

        x = denormalize(x).clamp(0, 1)

        # Create an image grid
        grid = to_pil_image(make_grid(x, nrow=nr, padding=2))
        grid.save(file_path)
    finally:
        # torch.seed() re-seeds non-deterministically; it does not restore the
        # RNG state that was active before this call.
        torch.seed()


def test(model_channels=32,
         activation_fn=torch.nn.SiLU,
         num_res_blocks=2,
         channel_mult=(1, 2, 2, 2),
         dropout=0.1,
         attention_resolutions=(2,),
         gpu=None,
         model_path='model.pth',
         file_path='img.png',
         dataset='cifar10',
         conditional=True,
         resolution=None):
    """Build a ``Model``, load its weights from ``model_path`` and save a sample grid.

    Args:
        model_channels, activation_fn, num_res_blocks, channel_mult, dropout,
        attention_resolutions: Forwarded unchanged to ``Model``; they must match
            the architecture the checkpoint was saved from for
            ``load_state_dict`` to succeed.
        gpu: CUDA device index, or ``None`` to run on the CPU.
        model_path: Path of the state-dict checkpoint to load.
        file_path: Where the generated image grid is written.
        dataset: Key into ``DATASETS``; supplies ``image_channels`` and
            ``num_classes`` for the model.
        conditional: Forwarded to ``_test``.
        resolution: Optional sample side length, forwarded to ``_test``.
    """
    # Keep the looked-up entry under its own name instead of rebinding the
    # `dataset` parameter from a key string to a different kind of object.
    dataset_info = DATASETS[dataset]
    device = torch.device(f'cuda:{gpu}' if gpu is not None else 'cpu')
    noise_scheduler = NoiseScheduler().to(device)
    model = Model(image_channels=dataset_info.image_channels,
                  model_channels=model_channels,
                  activation_fn=activation_fn,
                  num_res_blocks=num_res_blocks,
                  channel_mult=channel_mult,
                  dropout=dropout,
                  attention_resolutions=attention_resolutions,
                  num_classes=dataset_info.num_classes)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU still loads when running on the CPU (or another GPU).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    _test(device, noise_scheduler, model, file_path, progress=True, dataset=dataset_info.name,
          resolution=resolution, conditional=conditional)

In [2]:
class NoiseSchedulerParameterized(NoiseScheduler):
    """``NoiseScheduler`` whose reverse step takes the noise ``z`` as an argument.

    ``sample_prev_step_z`` computes ``mean + sigma * z`` from a caller-supplied
    ``z``, so the caller controls the randomness of every step.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``NoiseScheduler.__init__``."""
        super().__init__(*args, **kwargs)

    def sample_prev_step_z(self, xt, t, pred_noise, z):
        """Compute the previous-step sample from ``xt`` using the given noise ``z``.

        Args:
            xt: Current sample batch.
            t: Timestep indices, one per batch element; reshaped to
                ``(-1, 1, 1, 1)`` so the indexed schedule values broadcast over
                the remaining dimensions of ``xt``.
            pred_noise: Noise predicted by the model for ``xt`` at ``t``.
            z: Noise to inject. It is not modified; entries belonging to samples
                with ``t == 0`` are treated as zero.

        Returns:
            ``mean + sigma * z`` where
            ``mean = (xt - beta_t / sqrt(1 - alpha_bar_t) * pred_noise) / sqrt(alpha_t)``
            and ``sigma**2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t``.
        """
        t = t.view(-1, 1, 1, 1)
        # No noise is added on the final step (t == 0). Use the out-of-place
        # masked_fill rather than `z[mask] = 0`: the in-place write altered the
        # caller's tensor (breaking reuse of z across steps) and is rejected by
        # autograd when z is a leaf that requires grad.
        z = z.masked_fill(t == 0, 0)

        # assumes self.alpha / self.beta / self.alpha_bar are per-timestep
        # schedules indexable by t — defined in NoiseScheduler, not visible here.
        mean = (1 / torch.sqrt(self.alpha[t])) * (xt - (self.beta[t] / torch.sqrt(1 - self.alpha_bar[t])) * pred_noise)
        # For t == 0 the index t - 1 is -1 and wraps to the last schedule entry;
        # the resulting sigma is multiplied by the zeroed noise above.
        var = ((1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])) * self.beta[t]
        sigma = torch.sqrt(var)

        return mean + sigma * z

In [2]:
# Take the predefined 'mnist' entry from CONFIGS and pass it to parse_args
# through its `args` keyword to obtain the run settings used below.
config = CONFIGS['mnist']
args = parse_args(args=config)

In [3]:
# Show the parsed namespace as this cell's output.
args

Namespace(command='train', dataset='mnist', batch_size=256, epochs=2000, steps=200000, val_interval=4000, lr=0.001, grad_clip=1.0, grad_accum=1, warmup=5000, ema_decay=0.9999, model_channels=32, activation='silu', num_res_blocks=1, channel_mult=[1, 1, 2], hflip=False, dropout=0.1, attention_resolutions=[2], gpu=None, model='model.pth', save_checkpoints=True, log_interval=1, output_dir='output', file_path='img.png', conditional=False, resolution=None, progress=True, config=None)

In [11]:
# Resolve the string options stored on `args` to the objects registered under
# those keys in DATASETS and ACTIVATION_FUNCTIONS.
dataset = DATASETS[args.dataset]
activation_fn = ACTIVATION_FUNCTIONS[args.activation]

In [12]:
# Instantiate the network from the parsed settings and the selected dataset
# entry. No checkpoint is loaded in this cell.
model = Model(
    image_channels=dataset.image_channels,
    model_channels=args.model_channels,
    activation_fn=activation_fn,
    num_res_blocks=args.num_res_blocks,
    channel_mult=args.channel_mult,
    dropout=args.dropout,
    attention_resolutions=args.attention_resolutions,
    num_classes=dataset.num_classes,
)