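# Demo - Freeze Bayesian Neural Network

This notebook shows how `torchbnn.utils.freeze` and `unfreeze` affect a Bayesian neural network. While the model is unfrozen, repeated forward passes on the same input return different outputs. After `freeze`, they return the same output.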

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torchbnn as bnn
from torchbnn.utils import freeze, unfreeze

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

## 2. Define Model

In [3]:
# Seed torch's global RNG so weight initialisation is reproducible on
# "Restart & Run All". This assumes torchbnn samples through torch's RNG;
# if so, the forward-pass outputs below become reproducible too.
SEED = 0
torch.manual_seed(SEED)

# Prior mean / std passed to both BayesLinear layers (kept in one place so
# the two layers cannot drift apart).
PRIOR_MU = 0
PRIOR_SIGMA = 0.05

# 2 -> 2 -> 1 network: Bayesian linear, ReLU, Bayesian linear.
model = nn.Sequential(
    bnn.BayesLinear(prior_mu=PRIOR_MU, prior_sigma=PRIOR_SIGMA, in_features=2, out_features=2),
    nn.ReLU(),
    bnn.BayesLinear(prior_mu=PRIOR_MU, prior_sigma=PRIOR_SIGMA, in_features=2, out_features=1),
)

## 3. Forward Model

Before freezing, two forward passes on the same input return different outputs.

In [4]:
# Not frozen yet: repeating this call on the same input changes the output.
model(torch.ones((1, 2)))

tensor([[-0.4672]], grad_fn=<AddmmBackward>)

In [5]:
# Same input as the previous cell; in the saved run the unfrozen model gave a different result.
model(torch.ones((1, 2)))

tensor([[-0.3220]], grad_fn=<AddmmBackward>)

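## 4. Freeze Model

After `freeze(model)`, repeated forward passes on the same input return the same output. Calling `freeze` a second time again gives constant outputs, but at a different value than after the first call.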

In [6]:
# Freeze the Bayesian layers: the following forward passes on the same input
# return the same output (compare the next two cells).
# NOTE(review): presumably this fixes the sampled weights; confirm against torchbnn.utils.freeze.
freeze(model)

In [7]:
# Frozen: this output should not change if the cell is re-run.
model(torch.ones((1, 2)))

tensor([[-0.4340]], grad_fn=<AddmmBackward>)

In [8]:
# Repeat while frozen: matches the previous cell's output.
model(torch.ones((1, 2)))

tensor([[-0.4340]], grad_fn=<AddmmBackward>)

In [9]:
# Freeze a second time. Outputs are again constant across calls, but at a
# different value than after the first freeze in the saved run.
# NOTE(review): presumably freeze() re-samples the fixed weights on each call;
# verify against torchbnn.utils.freeze.
freeze(model)

In [10]:
# First forward pass after the second freeze() call.
model(torch.ones((1, 2)))

tensor([[-0.2875]], grad_fn=<AddmmBackward>)

In [11]:
# Repeat: identical to the previous cell while the model stays frozen.
model(torch.ones((1, 2)))

tensor([[-0.2875]], grad_fn=<AddmmBackward>)

## 5. Unfreeze Model

After `unfreeze(model)`, repeated forward passes on the same input return different outputs again.

In [12]:
# Undo freeze(): forward passes on the same input vary between calls again
# (compare the next two cells).
unfreeze(model)

In [13]:
# After unfreeze(): outputs are no longer fixed.
model(torch.ones((1, 2)))

tensor([[-0.4530]], grad_fn=<AddmmBackward>)

In [14]:
# Repeat: in the saved run this differed from the previous cell, since the model is unfrozen.
model(torch.ones((1, 2)))

tensor([[-0.4920]], grad_fn=<AddmmBackward>)