In [29]:
import argparse
import collections
import copy
import gc
import logging
import math
import os
import time
import torchsde
import numpy as np

import models
import utils
import tqdm
import diffeq_layers

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter



def make_y_net(input_size,
               blocks=(2, 2, 2),
               activation="softplus",
               verbose=False,
               explicit_params=True,
               hidden_width=128,
               aug_dim=0):
    
    """This is the bayesian neural network"""

    _input_size = (input_size[0] + aug_dim,) + input_size[1:]
    layers = []
    print(f"input_size ynet: {_input_size}")

    for i, num_blocks in enumerate(blocks, 1):
        for j in range(1, num_blocks + 1):

            print(_input_size, 'forloop')
            layers.extend(diffeq_layers.make_ode_k3_block_layers(input_size=_input_size,
                                                                 activation=activation,
                                                                 last_activation=i < len(blocks) or j < num_blocks,
                                                                 hidden_width=hidden_width, mode=1))

            if verbose:
                if i == 1:
                    print(f"y_net (augmented) input size: {_input_size}")
                layers.append(diffeq_layers.Print(name=f"group: {i}, block: {j}"))

        if i < len(blocks):
            layers.append(diffeq_layers.ConvDownsample(_input_size))
            _input_size = _input_size[0] * 4, _input_size[1] // 2, _input_size[2] // 2
            print(_input_size, 'convdownsample')

    y_net = diffeq_layers.DiffEqSequential(*layers, explicit_params=explicit_params)

    # return augmented input size b/c y net should have same input / output
    return y_net, _input_size


# TODO: add STL
class SDENet(torchsde.SDEStratonovich):
    def __init__(self,
                 input_size=(3, 32, 32),
                 blocks=(2, 2, 2),
                 weight_network_sizes=(1, 64, 1),
                 num_classes=10,
                 activation="softplus",
                 verbose=False,
                 inhomogeneous=True,
                 sigma=0.1,
                 hidden_width=128,
                 aug_dim=0):
        super(SDENet, self).__init__(noise_type="diagonal")
        self.input_size = input_size
        self.aug_input_size = (aug_dim + input_size[0], *input_size[1:])
        self.aug_zeros_size = (aug_dim, *input_size[1:])
        self.register_buffer('aug_zeros', torch.zeros(size=(1, *self.aug_zeros_size)))

        # Create network evolving state.
        self.y_net, self.output_size = make_y_net(
            input_size=input_size,
            blocks=blocks,
            activation=activation,
            verbose=verbose,
            hidden_width=hidden_width,
            aug_dim=aug_dim
        )
        # Create network evolving weights.
        initial_params = self.y_net.make_initial_params()  # w0.
        flat_initial_params, unravel_params = utils.ravel_pytree(initial_params)
        self.flat_initial_params = nn.Parameter(flat_initial_params, requires_grad=True)
        self.params_size = flat_initial_params.numel()
        print(f"initial_params ({self.params_size}): {flat_initial_params.shape}")
        self.unravel_params = unravel_params
        self.w_net = models.make_w_net(
            in_features=self.params_size,
            hidden_sizes=weight_network_sizes,
            activation="tanh",
            inhomogeneous=inhomogeneous
        )

        # Final projection layer.
        self.projection = nn.Sequential(
            nn.Flatten(),
            # nn.Linear(int(np.prod(self.output_size)), num_classes), # option: projection w/o ReLU
            nn.Linear(int(np.prod(self.output_size)), 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )

        self.register_buffer('ts', torch.tensor([0., 1.]))
        self.sigma = sigma
        self.nfe = 0

    def f(self, t, y: torch.Tensor):
        input_y = y
        self.nfe += 1
        y, w, _ = y.split(split_size=(y.numel() - self.params_size - 1, self.params_size, 1), dim=1) # params_size: 606408

        fy = self.y_net(t, y.reshape((-1, *self.aug_input_size)), self.unravel_params(w.squeeze(0))).reshape(-1).unsqueeze(0)
        nn = self.w_net(t, w)
        fw = nn - w  # hardcoded OU prior on weights w
        fl = (nn ** 2).sum(dim=1, keepdim=True) / (self.sigma ** 2)

        assert input_y.shape == torch.cat([fy, fw, fl], dim=1).shape, f"Want: {input_y.shape} Got: {torch.cat((fy, fw, fl)).shape}. Check nblocks for dataset divisibility.\n"
        return torch.cat([fy, fw, fl], dim=1)#.squeeze(0)

    def g(self, t, y):
        self.nfe += 1
        gy = torch.zeros(size=(y.numel() - self.params_size - 1,), device=y.device)
        gw = torch.full(size=(self.params_size,), fill_value=self.sigma, device=y.device)
        gl = torch.tensor([0.], device=y.device)
        return torch.cat([gy, gw, gl], dim=0).unsqueeze(0)

    def make_initial_params(self):
        return self.y_net.make_initial_params()

    def forward(self, y, adjoint=False, dt=0.02, adaptive=False, adjoint_adaptive=False, method="midpoint", rtol=1e-4, atol=1e-3):
        # Note: This works correctly, as long as we are requesting the nfe after each gradient update.
        #  There are obviously cleaner ways to achieve this.
        self.nfe = 0    
        sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint
        if self.aug_zeros.numel() > 0:  # Add zero channels.
            aug_zeros = self.aug_zeros.expand(y.shape[0], *self.aug_zeros_size)
            y = torch.cat([y, aug_zeros], dim=1) # 235200
        aug_y = torch.cat((y.reshape(-1), self.flat_initial_params, torch.tensor([0.], device=y.device))) # 841609: (235200, 606408, 1)
        aug_y = aug_y[None]
        bm = torchsde.BrownianInterval(
            t0=self.ts[0], t1=self.ts[-1], size=aug_y.shape, dtype=aug_y.dtype, device=aug_y.device,
            cache_size=45 if adjoint else 30  # If not adjoint, don't really need to cache.
        )
        if adjoint_adaptive:
            _, aug_y1 = sdeint(self, aug_y, self.ts, bm=bm, method=method, dt=dt, adaptive=adaptive, adjoint_adaptive=adjoint_adaptive, rtol=rtol, atol=atol)
        else:
            _, aug_y1 = sdeint(self, aug_y, self.ts, bm=bm, method=method, dt=dt, adaptive=adaptive, rtol=rtol, atol=atol)
        
        print(aug_y1.shape, 'ww, aug_y1')

        y1 = aug_y1[:,:y.numel()].reshape(y.size())
        logits = self.projection(y1)
        #logits = nn.functional.softmax(logits, dim=1)
        logqp = .5 * aug_y1[-1]
        return logits, logqp

    def zero_grad(self) -> None:
        for p in self.parameters(): p.grad = None

print(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

def train(model, ema_model, optimizer, scheduler, epochs, global_step=0, output_dir=None, start_epoch=0,
          best_test=0, best_val=0, info=collections.defaultdict(dict), tb_writer=None):
    
    def add_scalar_(name: str, arg1, arg2):
        if torch.is_tensor(arg1):
            tb_writer.add_scalar(name, arg1[0], arg2)
        else:
            tb_writer.add_scalar(name, arg1, arg2)
    
    train_xent = utils.EMAMeter()
    train_accuracy = utils.EMAMeter()
    test_accuracy, best_test_acc = 0, best_test
    val_xent, val_xent_ema, val_accuracy, best_val_acc = 0, 0, 0, best_val
    obj, kl, ece, nfe = 0, 0, utils.AverageMeter(), utils.AverageMeter()
    epoch_start = time.time()
    for epoch in range(start_epoch, epochs):
        itr_per_epoch = 0
        for i, (x, y) in tqdm.tqdm(enumerate(train_loader)):
            model.train()
            model.zero_grad()
            x, y = x.to(device), y.to(device, non_blocking=True)
            logits, logqp = model(
                x, dt=args.dt, adjoint=args.adjoint, method=args.method, adaptive=args.adaptive, adjoint_adaptive=args.adjoint_adaptive, rtol=args.rtol, atol=args.atol
            )
            nfes = model.nfe
            xent = F.cross_entropy(logits, y, reduction="mean")
            loss = xent + args.kl_coeff * logqp
            obj, kl = loss, args.kl_coeff * logqp
            predictions = logits.detach().argmax(dim=1)
            accuracy = torch.eq(predictions, y).float().mean()
            train_ece = utils.score_model(logits.detach().cpu().numpy(), y.detach().cpu().numpy())[2]
            ece.step(train_ece)
            nfe.step(nfes)
            loss.mean().backward()  # retain_graph=True
            optimizer.step()
            scheduler.step()
            train_xent.step(loss.mean())
            train_accuracy.step(accuracy)
            utils.ema_update(model=model, ema_model=ema_model, gamma=args.gamma)
            global_step += 1
            itr_per_epoch += 1
            gc.collect()
            # per itr nfes: {global step: [train nfe, [for each pause_every:] val nfe, val nfe ema, test nfe,
            # test nfe ema]}
            info["nfes"] = {global_step: [nfes]}

            if global_step % args.pause_every == 0:
                # tb_writer.add_scalar(f'Grad Norm (pause/{args.pause_every})', torch.norm(x.grad), global_step)
                # TODO: magnitude of learned drift function
                # drift_y = model.f(0, aug_y)[:y.numel()]
                add_scalar_(f'Activation Norm (pause/{args.pause_every})', torch.norm(y.detach().cpu().float()).numpy().mean(), global_step)
                add_scalar_(f'NFE/train (pause/{args.pause_every})', nfes, global_step)
                val_xent, val_accuracy, val_ece, val_nfe = evaluate(model, validate=True)
                val_xent_ema, val_accuracy_ema, val_ece_ema, val_nfe_ema = evaluate(ema_model, validate=True)
                info['nfes'][global_step].extend([val_nfe, val_nfe_ema])
                if val_accuracy > best_val_acc:
                    best_val_acc = val_accuracy
                    utils.save_ckpt(model, ema_model, optimizer, os.path.join(output_dir, "best_val_acc.ckpt"),
                                    scheduler, epoch=epoch, global_step=global_step, best_acc=best_test_acc,
                                    best_val=best_val_acc, info=info)
                add_scalar_('Accuracy/val', val_accuracy, global_step)
                add_scalar_('Accuracy EMA/val', val_accuracy_ema, global_step)
                add_scalar_('NLL/val', val_xent, global_step)
                add_scalar_('NLL EMA/val', val_xent_ema, global_step)
                add_scalar_('ECE/val', val_ece, global_step)
                add_scalar_('ECE EMA/val', val_ece_ema, global_step)
                add_scalar_('NFE/val (total/inference)', val_nfe, global_step)
                add_scalar_('NFE EMA/val (total/inference)', val_nfe_ema, global_step)
                logging.warning(
                    f"global step: {global_step}, "
                    f"epoch: {epoch}, "
                    f"train_xent: {train_xent.val:.4f}, "
                    f"train_accuracy: {train_accuracy.val:.4f}, "
                    f"val_xent: {val_xent:.4f}, "
                    f"val_accuracy: {val_accuracy:.4f}, "
                    f"val_xent_ema: {val_xent_ema:.4f}, "
                    f"val_accuracy_ema: {val_accuracy_ema:.4f}"
                )

        epoch_time = epoch_start - time.time()
        utils.save_ckpt(model, ema_model, optimizer, os.path.join(output_dir, "state.ckpt"), scheduler, epoch=epoch,
                        global_step=global_step, best_val=best_val_acc, best_acc=best_test_acc, info=info)
        # import pdb; pdb.set_trace()
        add_scalar_('Accuracy/train', train_accuracy.val, epoch)
        add_scalar_('NLL/train', train_xent.val, epoch)
        add_scalar_('KL/train', kl, epoch)
        add_scalar_('Loss/train', obj, epoch)
        add_scalar_('ECE/train', ece.val, epoch)
        add_scalar_('NFE/train (avg/epoch)', nfe.val, epoch)
        nfe.__init__()  # reset for new epoch
        logging.warning("Wrote training scalars to tensorboard")

        test_xent, test_accuracy, test_ece, test_nfe = evaluate(model)
        test_xent_ema, test_accuracy_ema, test_ece_ema, test_nfe_ema = evaluate(ema_model)
        info['nfes'][global_step].extend([test_nfe, test_nfe_ema])
        add_scalar_('Accuracy/test', test_accuracy, epoch)
        add_scalar_('Accuracy EMA/test', test_accuracy_ema, epoch)
        add_scalar_('NLL/test', test_xent, epoch)
        add_scalar_('NLL EMA/test', test_xent_ema, global_step)
        add_scalar_('ECE/test', test_ece, epoch)
        add_scalar_('ECE EMA/test', test_ece_ema, epoch)
        add_scalar_('NFE/test (total/inference)', test_nfe, epoch)
        add_scalar_('NFE EMA/test (total/inference)', test_nfe_ema, epoch)
        logging.warning("Wrote test scalars to tensorboard")
        if test_accuracy > best_test_acc:
            best_test_acc = test_accuracy
            utils.save_ckpt(model, ema_model, optimizer, os.path.join(output_dir, "best_test_acc.ckpt"), scheduler,
                            epoch=epoch, global_step=global_step, best_val=best_val_acc, best_acc=best_test_acc,
                            info=info)
        with open(os.path.join(output_dir, "results.txt"), "a") as f:
            f.write(f"Epoch {epoch} (step {global_step}) in {epoch_time:.4f} sec | Train acc {train_accuracy.val}" + \
                    f" | Test accuracy {test_accuracy} | Test EMA accuracy {test_accuracy_ema}" + \
                    f" | Train NLL {train_xent.val} | Test NLL {test_xent} | Test EMA NLL {test_xent_ema} | Train "
                    f"Loss " + \
                    f" {obj.detach().cpu().numpy().tolist()} | Train KL {kl}" + \
                    f" | Train ECE {ece.val} | Test ECE {test_ece} | Test ECE EMA {test_ece_ema}" + \
                    f" | Train nfes {nfe.val} | Test NFE {test_nfe} | Test NFE EMA {test_nfe_ema}\n")
            logging.warning(f"Wrote epoch info to {os.path.join(output_dir, 'results.txt')}")
        info[global_step] = {'epoch': epoch, 'time': epoch_time, 'train_acc': train_accuracy.val,
                             'test_acc': test_accuracy,
                             'train_nll': train_xent.val, 'test_nll': test_xent, 'test_ema_nll': test_xent_ema,
                             'train_loss': obj.detach().cpu().numpy().tolist(),
                             'train_kl': kl.detach().cpu().numpy().tolist(), "val_acc": val_accuracy,
                             "val_xent": val_xent, "val_xent_ema": val_xent_ema,
                             "train_ece": ece.val, "test_ece": test_ece, "test_ece_ema": test_ece_ema,
                             "itr_per_epoch": itr_per_epoch, "avg_train_nfe": nfe.val,
                             "test_nfe": test_nfe, "test_nfe_ema": test_nfe_ema}
        utils.write_state_config(info, args.train_dir, file_name='state.json')


@torch.no_grad()
def _evaluate_with_loader(model, loader):
    xents = []
    accuracies = []
    eces = []
    nfes = 0
    model.eval()
    for i, (x, y) in enumerate(loader, 1):
        x, y = x.to(device), y.to(device, non_blocking=True)
        logits, _ = model(x, dt=args.dt, adjoint=args.adjoint, adjoint_adaptive=args.adjoint_adaptive, method=args.method)  # , rtol=args.rtol, atol=args.atol)
        loss = F.cross_entropy(logits, y, reduction="none")
        predictions = logits.detach().argmax(dim=1)
        accuracy = torch.eq(predictions, y).float()
        scores = utils.score_model(logits.detach().cpu().numpy(), y.detach().cpu().numpy())
        xents.append(loss)
        accuracies.append(accuracy)
        eces.append(torch.tensor([scores[2]]))
        nfes += model.nfe
        if i >= args.eval_batches: break
    return tuple(torch.cat(x, dim=0).mean(dim=0).cpu().item() for x in (xents, accuracies, eces)) + (nfes,)


def evaluate(model, validate=False):
    if validate:
        logging.warning("evaluating on validation set")
        test_xent, test_accuracy, test_ece, test_nfe = _evaluate_with_loader(model, val_loader)
    else:
        logging.warning("evaluating on test set")
        test_xent, test_accuracy, test_ece, test_nfe = _evaluate_with_loader(model, test_loader)
    return test_xent, test_accuracy, test_ece, test_nfe


def get_cosine_schedule_with_warmup(optimizer,
                                    num_training_steps,
                                    num_warmup_steps=0,
                                    num_cycles=7. / 16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda, last_epoch)


def main():
    input_chw = (3, 32, 32)
    if args.data == "mnist":
        input_chw = (1, 28, 28)
    if args.model == "baseline":
        model = models.BaselineYNet(
            input_size=input_chw,
            activation=args.activation,
            hidden_width=args.hidden_width
        )
    elif args.model == "sdebnn":
        model = SDENet(
            input_size=input_chw,
            inhomogeneous=args.inhomogeneous,
            activation=args.activation,
            verbose=args.verbose,
            hidden_width=args.hidden_width,
            weight_network_sizes=(1,128,1),
            blocks=(2,2,2),
            sigma=args.sigma,
            aug_dim=args.aug,
        )
    else:
        raise ValueError(f"Unknown model: {args.model}")
    ema_model = copy.deepcopy(model)
    model.to(device)
    ema_model.to(device)

    optimizer = optim.Adam(lr=args.lr, params=model.parameters())
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_training_steps=args.epochs * (50000 // args.batch_size))

    start_epoch, best_test_acc, best_val_acc, global_step = 0, 0, 0, 0
    info = collections.defaultdict(dict)
    if os.path.exists(os.path.join(args.train_dir, "state.ckpt")):
        # if os.path.exists(os.path.join(args.train_dir, "best_val_acc.ckpt")): # TODO: for debugging
        logging.warning("Loading checkpoints...")
        checkpoint = torch.load(os.path.join(args.train_dir, "state.ckpt"))
        # checkpoint = torch.load(os.path.join(args.train_dir, "best_val_acc.ckpt")) # TODO: for debugging
        start_epoch = checkpoint['epoch']
        best_test_acc = checkpoint['best_acc']
        best_val_acc = checkpoint['best_val_acc']
        info = checkpoint['info']
        global_step = checkpoint['global_step']
        model.load_state_dict(checkpoint["model"])
        ema_model.load_state_dict(checkpoint["ema_model"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        logging.warning(f"Successfully loaded checkpoints for epoch {start_epoch} | best acc {best_test_acc}")

    logging.warning(f'model: {model}')
    logging.warning(f'{utils.count_parameters(model) / 1e6:.4f} million parameters')

    tb_writer = SummaryWriter(os.path.join(args.train_dir, 'tb'))
    train(
        model, ema_model, optimizer, scheduler, args.epochs,
        output_dir=args.train_dir, global_step=global_step, start_epoch=start_epoch,
        best_test=best_test_acc, best_val=best_val_acc, info=info, tb_writer=tb_writer
    )
    tb_writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', type=str, default='train', required=False)
    parser.add_argument('--seed', type=int, default=1000000)
    parser.add_argument('--no-gpu', action="store_true")
    parser.add_argument('--subset', type=int, default=None, help="Use subset of mnist data.")
    parser.add_argument('--data', type=str, default="mnist", choices=['mnist', 'cifar10', 'cifar100'])
    parser.add_argument('--pin-memory', type=utils.str2bool, default=True)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--model', type=str, choices=['baseline', 'sdebnn'], default='sdebnn')
    parser.add_argument('--method', type=str, choices=['milstein', 'midpoint', "heun", "euler_heun"], default='midpoint')
    parser.add_argument('--gamma', type=float, default=0.999)

    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--aug', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--eval-batch-size', type=int, default=512)
    parser.add_argument('--pause-every', type=int, default=200) 
    parser.add_argument('--eval-batches', type=int, default=10000)

    # Model.
    parser.add_argument('--dt', type=float, default=0.1)
    parser.add_argument('--rtol', type=float, default=1e-5)
    parser.add_argument('--atol', type=float, default=1e-4)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--adjoint', type=utils.str2bool, default=False)
    parser.add_argument('--adaptive', type=utils.str2bool, default=False)
    parser.add_argument('--adjoint_adaptive', type=utils.str2bool, default=False)
    parser.add_argument('--inhomogeneous', type=utils.str2bool, default=True)
    parser.add_argument('--activation', type=str, default="softplus",
                        choices=['swish', 'mish', 'softplus', 'tanh', 'relu', 'elu'])
    parser.add_argument('--verbose', type=utils.str2bool, default=False)
    parser.add_argument('--hidden-width', type=int, default=32)
    parser.add_argument('--fw-width', type=str, default="1-128-1")
    parser.add_argument('--nblocks', type=str, default="2-2-2")
    parser.add_argument('--sigma', type=float, default=0.1)

    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--nesterov', type=utils.str2bool, default=True)
    parser.add_argument('--kl-coeff', type=float, default=1e-3, help='Coefficient on the KL term.')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_gpu else 'cpu')
    torch.backends.cudnn.benchmark = True  # noqa

    utils.manual_seed(args)
    utils.write_config(args)

    print(args.pin_memory, args.num_workers)

    train_loader, val_loader, test_loader = utils.get_loader(
        args.data,
        train_batch_size=args.batch_size,
        test_batch_size=args.eval_batch_size,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        subset=args.subset,
        task="classification"
    )

    logging.warning(
        f"Training set size: {utils.count_examples(train_loader)}, "
        f"Val set size: {utils.count_examples(val_loader)}, "
        f"test set size: {utils.count_examples(test_loader)}"
    )







cpu
True 1




In [30]:
main()

  (y_net): Sequential(
    (0): ConcatConv2d(2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (2): ConcatConv2d(33, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (4): ConcatConv2d(33, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (6): ConcatConv2d(2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (8): ConcatConv2d(33, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (10): ConcatConv2d(33, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): DiffEqWrapper(
      (module): Softplus(beta=1, threshold=20)
    )
    (12): ConvDownsample(2, 4, 

input_size ynet: (1, 28, 28)
(1, 28, 28) forloop
(1, 28, 28) forloop
(4, 14, 14) convdownsample
(4, 14, 14) forloop
(4, 14, 14) forloop
(16, 7, 7) convdownsample
(16, 7, 7) forloop
(16, 7, 7) forloop
initial_params (84560): torch.Size([84560])


0it [00:00, ?it/s]

torch.Size([1, 184913]) ww, aug_y1


1it [00:13, 13.62s/it]

torch.Size([1, 184913]) ww, aug_y1


1it [00:24, 24.74s/it]


KeyboardInterrupt: 

In [None]:
131072/128

1024.0

In [None]:
model = SDENet(
input_size=(1,28,28),
inhomogeneous=args.inhomogeneous,
activation=args.activation,
verbose=args.verbose,
hidden_width=args.hidden_width,
weight_network_sizes=(1,128,1),
blocks=(2,2,2),
sigma=args.sigma,
aug_dim=args.aug,
)

initial_params (141776): torch.Size([141776])


In [None]:
model.output_size

(16, 7, 7)

In [None]:
184913

784

In [None]:
inputsize = (5, 1, 28, 28)
sample = torch.randn(inputsize)

time_sample = torch.tensor([0])
y, out1= make_y_net(inputsize[1:],
               blocks=(2, 2, 2),
               activation="softplus",
               verbose=False,
               explicit_params=True,
               hidden_width=128,
               aug_dim=0)
print(out1)
ee, ff = utils.ravel_pytree(y.make_initial_params())
print(ee.shape)
fyyy = y(time_sample, sample.reshape((-1, 1, 28,28)), ff(ee.squeeze(0)))
print(fyyy.shape)

input_size ynet: (1, 28, 28)
(1, 28, 28) forloop
(1, 28, 28) forloop
(4, 14, 14) convdownsample
(4, 14, 14) forloop
(4, 14, 14) forloop
(16, 7, 7) convdownsample
(16, 7, 7) forloop
(16, 7, 7) forloop
(16, 7, 7)
torch.Size([998096])
torch.Size([5, 16, 7, 7])


In [None]:
sample.reshape((-1, 1, 28,28)).shape

torch.Size([5, 1, 28, 28])

In [None]:
fyyy.flatten(start_dim=1).shape

torch.Size([5, 1024])

In [None]:
fyyy.shape

torch.Size([5, 16, 8, 8])

In [None]:
sample.flatten(start_dim=1).shape

torch.Size([5, 784])

In [None]:
ee, ff = utils.ravel_pytree(y.make_initial_params())

In [None]:
ee.shape

torch.Size([1890512])

In [None]:
32*32

1024

In [None]:
fyyy.shape

torch.Size([5, 16, 8, 8])

In [None]:
ee

tensor([ 0.0716, -0.1845, -0.1163,  ...,  0.0176, -0.0211, -0.0175])

In [None]:
blocks = (2,2,2)
input_size = (3,28,28)
aug_dim=0

_input_size = (input_size[0] + aug_dim,) + input_size[1:]
layers = []
print(f"input_size ynet: {_input_size}")

for i, num_blocks in enumerate(blocks, 1):

    if i < len(blocks):
        
        _input_size = _input_size[0] * 4, _input_size[1] // 2, _input_size[2] // 2
        print(_input_size)


input_size ynet: (3, 28, 28)
(12, 14, 14)
(48, 7, 7)


In [None]:
235200/784

300.0

In [None]:
[2,1,189397]
_, aug = sdeint
[1, 12432423], [1,12432423]