# Omniglot

## 1.Gaussian weights

In [1]:
import torch
import torch.nn as nn 
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import Omniglot
import torch.nn.functional as F
import torchvision.transforms as transforms

In [2]:
# Load Omniglot dataset
transform = transforms.Compose([transforms.Resize((28, 28)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
train_dataset = Omniglot(root="./data", download=True, transform=transform)
test_dataset = Omniglot(root="./data", download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to ./data/omniglot-py/images_background.zip


100%|██████████| 9464212/9464212 [00:00<00:00, 179087334.65it/s]

Extracting ./data/omniglot-py/images_background.zip to ./data/omniglot-py





Files already downloaded and verified


### VI

In [3]:
# Model with Gaussian prior 
class BNN(nn.Module):
      def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.randn(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide) 
        self.fc2 = nn.Linear(n_hide, 964)
        self.dropout = nn.Dropout(0.5) 
    
      def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  
  
      def kl_divergence(self): 
        mean, std = self.w.mean(), self.w.std()
        return ((mean ** 2 + std ** 2 - torch.log(std **2)-1)/2).sum()

In [4]:
model = BNN(784, 400, 964)

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(50):
    for x, y in train_loader:
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1   #添加KL散度
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')

Epoch: 1, Train Loss: 6.866
Epoch: 2, Train Loss: 6.843
Epoch: 3, Train Loss: 6.742
Epoch: 4, Train Loss: 6.665
Epoch: 5, Train Loss: 6.639
Epoch: 6, Train Loss: 6.562
Epoch: 7, Train Loss: 6.403
Epoch: 8, Train Loss: 6.547
Epoch: 9, Train Loss: 6.433
Epoch: 10, Train Loss: 6.417
Epoch: 11, Train Loss: 6.493
Epoch: 12, Train Loss: 6.566
Epoch: 13, Train Loss: 6.400
Epoch: 14, Train Loss: 6.507
Epoch: 15, Train Loss: 6.554
Epoch: 16, Train Loss: 6.411
Epoch: 17, Train Loss: 6.398
Epoch: 18, Train Loss: 6.534
Epoch: 19, Train Loss: 6.349
Epoch: 20, Train Loss: 6.468
Epoch: 21, Train Loss: 6.275
Epoch: 22, Train Loss: 6.400
Epoch: 23, Train Loss: 6.287
Epoch: 24, Train Loss: 6.455
Epoch: 25, Train Loss: 6.369
Epoch: 26, Train Loss: 6.195
Epoch: 27, Train Loss: 6.382
Epoch: 28, Train Loss: 6.357
Epoch: 29, Train Loss: 6.341
Epoch: 30, Train Loss: 6.421
Epoch: 31, Train Loss: 6.248
Epoch: 32, Train Loss: 6.390
Epoch: 33, Train Loss: 6.198
Epoch: 34, Train Loss: 6.422
Epoch: 35, Train Loss: 

In [8]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item()

print(f'Test accuracy: {correct/len(test_dataset):.5f}')

Test accuracy: 0.40716


### 1.2MCMC

In [9]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [10]:
# MCMC Sampling
model = BNN(784, 964, 400)
w = model.w.data  # Get weight tensor 

In [11]:
for i in range(1500):  # 1500 iterations
    # Proposal distribution
    w_proposal = w + torch.randn(w.size())*0.1 

    # Acceptance ratio
    outputs = model(train_x)  # log_softmax输出
    ap = torch.exp(outputs.gather(1, train_y.unsqueeze(1)).sum()) 

    # Metropolis acceptance
    u = torch.rand(1)
    if u < ap:
        w = w_proposal  # Accept proposal 

    model.w.data = w  # Set weight to sampled value

In [13]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1) 
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_dataset):.5f}') 

Test accuracy: 0.10270


## 2.Laplace weights

### 2.1Variable Inference

In [43]:
# Model with Laplace prior 
class BNN(nn.Module):
    def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, 964)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  

    def kl_divergence(self):
        return (self.w.abs().sum() / 2).sum()  # Laplace prior  

In [44]:
model = BNN(784, 400, 964) 

In [45]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(50):
    for x, y in train_loader:
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1   
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')

Epoch: 1, Train Loss: 6.878
Epoch: 2, Train Loss: 6.821
Epoch: 3, Train Loss: 6.807
Epoch: 4, Train Loss: 6.828
Epoch: 5, Train Loss: 6.752
Epoch: 6, Train Loss: 6.799
Epoch: 7, Train Loss: 6.745
Epoch: 8, Train Loss: 6.727
Epoch: 9, Train Loss: 6.727
Epoch: 10, Train Loss: 6.764
Epoch: 11, Train Loss: 6.695
Epoch: 12, Train Loss: 6.603
Epoch: 13, Train Loss: 6.738
Epoch: 14, Train Loss: 6.570
Epoch: 15, Train Loss: 6.450
Epoch: 16, Train Loss: 6.662
Epoch: 17, Train Loss: 6.550
Epoch: 18, Train Loss: 6.570
Epoch: 19, Train Loss: 6.544
Epoch: 20, Train Loss: 6.505
Epoch: 21, Train Loss: 6.434
Epoch: 22, Train Loss: 6.385
Epoch: 23, Train Loss: 6.295
Epoch: 24, Train Loss: 6.285
Epoch: 25, Train Loss: 6.450
Epoch: 26, Train Loss: 6.310
Epoch: 27, Train Loss: 6.313
Epoch: 28, Train Loss: 6.379
Epoch: 29, Train Loss: 6.403
Epoch: 30, Train Loss: 6.272
Epoch: 31, Train Loss: 6.232
Epoch: 32, Train Loss: 6.224
Epoch: 33, Train Loss: 6.370
Epoch: 34, Train Loss: 6.313
Epoch: 35, Train Loss: 

In [47]:
# Test accuracy  
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item() 
print(f'Test accuracy: {correct/len(test_dataset):.5f}') 

Test accuracy: 0.42116


### 2.2MCMC

In [18]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [19]:
# Model with Laplace prior
class BNN(nn.Module):
    def __init__(self, n_in, n_hide, n_out):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)
        self.w = torch.zeros(n_in, n_out)  

    def forward(self, x, w_proposal=None):
        if w_proposal is None:  
              w = self.w
        else:
              w = w_proposal
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        dist = torch.distributions.Categorical(logits=pred)
        return dist

    def log_prior(self, w):
        return -w.abs().sum() 

In [20]:
# MCMC Sampling
model = BNN(784, 964, 400) 
w = model.w.data  # Get weight tensor 

In [21]:
# training
for i in range(1500): 
    # Proposal distribution
    w_proposal = w + torch.randn(w.size())*0.1

    # Prior ratio 
    lp = model.log_prior(w_proposal)  
    ap = torch.exp(lp - model.log_prior(w))

    # Metropolis acceptance 
    u = torch.rand(1)
    if u < ap:
        w = w_proposal  

    model.w.data = w 

In [22]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).sample()
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_dataset):.5f}') 

Test accuracy: 0.09803


## 3.Equalisation weights

### 3.1Variable Inference

In [23]:
# Model with Uniform prior
class BNN(nn.Module):
    def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide)
        self.fc2 = nn.Linear(n_hide, 964)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  

    def kl_divergence(self):
        return 0.5*(self.w**2).sum()  # Uniform prior  

In [24]:
model = BNN(784, 400, 964) 

In [25]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(50):
    for x, y in train_loader:
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1   
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')


Epoch: 1, Train Loss: 6.856
Epoch: 2, Train Loss: 6.737
Epoch: 3, Train Loss: 6.583
Epoch: 4, Train Loss: 6.467
Epoch: 5, Train Loss: 6.348
Epoch: 6, Train Loss: 6.424
Epoch: 7, Train Loss: 6.168
Epoch: 8, Train Loss: 6.203
Epoch: 9, Train Loss: 6.385
Epoch: 10, Train Loss: 6.391
Epoch: 11, Train Loss: 6.072
Epoch: 12, Train Loss: 6.200
Epoch: 13, Train Loss: 6.325
Epoch: 14, Train Loss: 6.207
Epoch: 15, Train Loss: 6.199
Epoch: 16, Train Loss: 6.170
Epoch: 17, Train Loss: 6.367
Epoch: 18, Train Loss: 6.145
Epoch: 19, Train Loss: 6.220
Epoch: 20, Train Loss: 6.229
Epoch: 21, Train Loss: 6.201
Epoch: 22, Train Loss: 6.274
Epoch: 23, Train Loss: 6.358
Epoch: 24, Train Loss: 6.086
Epoch: 25, Train Loss: 6.204
Epoch: 26, Train Loss: 6.155
Epoch: 27, Train Loss: 6.425
Epoch: 28, Train Loss: 6.327
Epoch: 29, Train Loss: 6.311
Epoch: 30, Train Loss: 5.984
Epoch: 31, Train Loss: 6.325
Epoch: 32, Train Loss: 6.163
Epoch: 33, Train Loss: 6.284
Epoch: 34, Train Loss: 6.282
Epoch: 35, Train Loss: 

In [26]:
# Test accuracy  
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item() 
print(f'Test accuracy: {correct/len(test_dataset):.5f}') 

Test accuracy: 0.45643


### 3.2MCMC

In [27]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [28]:
# Model with Uniform prior
class BNN(nn.Module):
    def __init__(self, n_in, n_hide, n_out):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)
        self.w = torch.zeros(n_in, n_out)  # Weight parameter

    def forward(self, x, w_proposal=None):
        if w_proposal is None:
              w = self.w
        else:
              w = w_proposal
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        dist = torch.distributions.Categorical(logits=pred)
        return dist

    def log_prior(self, w):
        return -torch.sum(torch.abs(w))

In [29]:
# MCMC Sampling
model = BNN(784, 964 ,400) 
w = model.w.data  # Get weight tensor 

In [36]:
for i in range(1500): 
    w_proposal = torch.rand(w.size())

    ap = torch.tensor(0.)  
    
    u = torch.rand(1)
    if u < ap:  
          w = w_proposal
            
    # update
    model.w.data = w  

    # loss
    logits = model(train_x, w_proposal).logits
    loss = F.cross_entropy(logits, train_y)

In [42]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).sample()
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_dataset):.5f}') 

Test accuracy: 0.09129
