In [1]:
import argparse
from utils.load_params import load_params
from utils.sample import sample
from model.train import train
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from noise_scheduler.scheduler import NoiseScheduler
import torch
from model.unet import Unet

from noise_scheduler._archived_scheduler import NoiseScheduler as NoiseScheduler2

### Load Parameters
All run parameters (data directory, batch size, diffusion schedule, number of epochs) are read from `config.json`.

In [2]:
# Load the run configuration from config.json. The result is indexed by string
# keys in later cells (e.g. params["root_dir"], params["batch_size"]).
# The argparse-based CLI parsing below is left commented out; nothing in this
# notebook reads command-line arguments.
#parser = argparse.ArgumentParser(description="Run models with optional training.")
#args = parser.parse_args()
params = load_params("config.json")

### Load Data

In [3]:
# Build the training dataset from the directory configured under "root_dir".
traindataset = MnistDataset(root_dir=params["root_dir"])

# Wrap the dataset in a shuffling loader; batch size comes from the config and
# loading is spread over two worker subprocesses.
train_loader = DataLoader(
    dataset=traindataset,
    batch_size=params["batch_size"],
    shuffle=True,
    num_workers=2,
)

Found 60000 images


### Noise Scheduler: Train or Sample

In [None]:
# Mode switch for this cell:
#   True  -> build a NoiseScheduler on the selected device and call train(...)
#   False -> load the saved checkpoint and call sample(...)
train_model = False

if train_model:
    # Use the GPU for training when one is available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    scheduler = NoiseScheduler(device, params["num_timesteps"], params["beta_start"], params["beta_end"])
    train(
        epochs_num=params["epochs_num"],
        train_loader=train_loader,
        noise_scheduler=scheduler
    )
else:
    # NOTE(review): sampling uses the archived scheduler class (constructed
    # without a device argument) rather than the NoiseScheduler used for
    # training above -- confirm this mismatch is intentional.
    scheduler = NoiseScheduler2(params["num_timesteps"], params["beta_start"], params["beta_end"])

    model = Unet()
    # The checkpoint is always mapped onto the CPU, so this branch does not
    # require CUDA even if the weights were saved from a GPU.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted
    # source (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.load("model/model.pth", map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    # Sampling is pure inference: disable autograd so no graph is recorded
    # across the denoising steps. This is harmless if sample() already does so.
    with torch.no_grad():
        sample(
            model=model,
            scheduler=scheduler,
            config=params
        )
    

1it [00:03,  3.10s/it]

tensor([[[[ 0.7091,  1.0742,  1.2403,  ...,  1.0385,  0.1881, -1.6217],
          [ 0.0032,  0.2035, -1.1948,  ...,  1.1383, -1.0080,  0.4361],
          [ 0.1956, -0.1845,  1.4496,  ...,  0.3318,  1.1738,  1.1153],
          ...,
          [ 0.1356, -0.1732, -0.5932,  ...,  0.4422,  0.4380, -0.7069],
          [ 0.7866,  0.6659, -1.7020,  ...,  2.2680, -0.5266,  1.5246],
          [-1.9091, -0.5354,  0.8463,  ...,  0.4438,  0.1641, -0.2859]]],


        [[[-1.0341,  1.2731,  0.5610,  ..., -0.1551,  1.7822,  0.3779],
          [-0.4225, -1.3237,  0.4533,  ...,  2.0477, -1.0226, -0.6864],
          [-0.2520,  0.9098,  0.0135,  ...,  1.1391, -1.7362,  0.4268],
          ...,
          [ 0.7363,  0.6258,  1.4630,  ..., -1.1713,  0.5192,  0.0451],
          [-0.8688, -0.5332, -0.0797,  ...,  0.4065, -1.3182,  0.1141],
          [-0.3480, -1.3200,  0.4356,  ..., -0.3924,  2.7445, -0.3550]]],


        [[[ 0.4869,  0.6204,  0.0240,  ..., -0.4432,  0.2591,  1.6308],
          [ 2.1831, -0.114

200it [00:08, 26.87it/s]

tensor([[[[ 0.7729,  1.2127,  1.5418,  ...,  1.0558,  0.3338, -1.6287],
          [ 0.0057,  0.3542, -1.0668,  ...,  1.1777, -1.0307,  0.2799],
          [ 0.3281, -0.1562,  1.5780,  ...,  0.1244,  1.0634,  1.0799],
          ...,
          [ 0.1081, -0.0327, -0.5219,  ...,  0.6679,  0.4619, -0.7063],
          [ 0.9880,  0.8429, -1.7611,  ...,  2.1871, -0.7532,  1.6544],
          [-1.9492, -0.5031,  0.9167,  ...,  0.3212,  0.2537, -0.3324]]],


        [[[-0.8948,  1.2580,  0.5097,  ..., -0.3719,  1.7746,  0.3284],
          [-0.5122, -1.4060,  0.1443,  ...,  1.9948, -0.9392, -0.6134],
          [-0.4980,  0.9544, -0.0028,  ...,  0.9743, -1.8176,  0.3728],
          ...,
          [ 0.6552,  1.0005,  1.8174,  ..., -1.0518,  0.5191,  0.2127],
          [-0.9166, -0.6678,  0.0151,  ...,  0.3860, -1.2774,  0.1135],
          [-0.5126, -1.1959,  0.3379,  ..., -0.2982,  2.5602, -0.3689]]],


        [[[ 0.4582,  0.7498,  0.0848,  ..., -0.5335,  0.4770,  1.6579],
          [ 2.1589,  0.076

700it [00:11, 74.01it/s]

tensor([[[[ 1.1611,  1.6040,  1.9456,  ...,  1.3421,  0.7112, -1.3894],
          [ 0.1237,  0.5564, -0.6237,  ...,  1.5196, -0.8982,  0.6184],
          [ 0.6782,  0.2191,  2.0270,  ...,  0.4853,  1.1434,  1.3932],
          ...,
          [ 0.3731, -0.1122, -0.2012,  ...,  0.9002,  0.7203, -0.5742],
          [ 1.3615,  1.4485, -1.4552,  ...,  2.5265, -0.4938,  1.7535],
          [-1.9019, -0.1321,  1.0882,  ...,  0.4782,  0.4536,  0.0413]]],


        [[[-0.6407,  1.6266,  0.7780,  ...,  0.0732,  1.9012,  0.7227],
          [-0.2527, -0.9426,  0.4597,  ...,  2.5559, -0.5602, -0.3998],
          [-0.0810,  1.2035,  0.4604,  ...,  1.2690, -1.4868,  0.4054],
          ...,
          [ 0.7943,  1.5745,  2.3014,  ..., -0.7546,  0.7239,  0.4300],
          [-0.6491, -0.5939,  0.3530,  ...,  0.7384, -1.0555,  0.3201],
          [-0.3338, -0.8145,  0.6119,  ..., -0.0636,  2.6792, -0.0924]]],


        [[[ 0.7908,  1.3282,  0.3064,  ..., -0.1714,  0.9390,  1.8570],
          [ 2.4846,  0.786

900it [00:15, 69.61it/s]

tensor([[[[ 2.0270,  3.0236,  2.7933,  ...,  1.7869,  1.0374, -1.3873],
          [ 0.3967,  0.8113, -0.4686,  ...,  1.6320, -1.0188,  0.7289],
          [ 1.1827,  0.4590,  2.0073,  ...,  0.5108,  1.2812,  2.2084],
          ...,
          [ 0.5602, -0.0828, -0.0477,  ...,  0.2564,  0.7019, -0.7538],
          [ 2.1536,  1.9830, -1.5536,  ...,  2.4683, -0.5003,  2.1836],
          [-1.9261,  0.1733,  1.6221,  ...,  0.6328,  0.9618,  0.3377]]],


        [[[-0.5437,  1.8441,  0.9613,  ...,  0.1345,  2.0605,  0.9533],
          [-0.0410, -1.1338,  0.1562,  ...,  2.6993, -0.8660, -0.2556],
          [ 0.1173,  1.1578, -0.1158,  ...,  0.7421, -1.6547,  0.5202],
          ...,
          [ 1.1759,  1.4927,  1.9794,  ..., -0.5866,  1.0468,  0.6246],
          [-0.5131, -0.7174,  0.4157,  ...,  0.9658, -1.0555,  0.7548],
          [-0.0768, -0.6989,  0.7662,  ...,  0.0873,  2.9632,  0.0790]]],


        [[[ 1.0332,  1.5903,  0.3199,  ..., -0.1964,  0.9513,  2.0210],
          [ 2.8090,  0.860

1000it [00:18, 54.27it/s]

tensor([[[[ 1.2560,  2.7557,  1.8401,  ...,  1.7941,  0.7100, -1.5873],
          [ 0.0289,  0.8865, -0.9783,  ...,  1.3973, -1.5135,  0.6998],
          [ 1.0568,  0.6041,  2.0980,  ...,  0.3221,  1.8177,  2.4812],
          ...,
          [ 0.1266, -0.3088, -0.2113,  ..., -0.0164,  0.6767, -0.8251],
          [ 2.1725,  1.8263, -1.9755,  ...,  3.0099, -0.7533,  2.1703],
          [-1.9233,  0.2218,  1.9540,  ...,  0.7000,  1.1133,  0.1793]]],


        [[[-0.7125,  1.7443,  0.7777,  ..., -0.2828,  1.8497,  0.3376],
          [-0.0089, -1.4559,  0.7170,  ...,  2.3485, -1.2870, -0.1089],
          [ 0.1416,  1.1859,  0.1715,  ...,  0.6820, -1.7638,  0.6378],
          ...,
          [ 1.2161,  1.6511,  1.7753,  ..., -0.8454,  0.8860,  0.6306],
          [-0.7794, -0.9227,  0.8146,  ...,  0.8829, -1.4247,  0.9242],
          [-0.0675, -0.7578,  0.9524,  ..., -0.0067,  2.3662, -0.1062]]],


        [[[ 0.4421,  1.2779,  0.1244,  ..., -0.6278,  0.7429,  1.0656],
          [ 2.0368,  0.897


