------
0. 𝒲 ~ MVN(mean, cov)
* parameter 𝒲는 100개라고 가정
* μ<- random한 1x100 벡터 
* ∑ <- random한 100x100 행렬
------
1. mean perturbation
* $w_{\mu}$ ~ MVN(μ+Δμ, ∑)


------
2. cov perturbation은
* $w_{\sum}$ ~ MVN(μ, ∑+Δ∑)


------
3. mean&cov perturbation은
* $w_{\mu\sum}$ ~ MVN(μ+Δμ, ∑+Δ∑)

# 0. Import

In [1]:
import torch
import random
import numpy as np
# from scipy.special import rel_entr
# from scipy.stats import multivariate_normal
from torch.distributions import multivariate_normal, normal, kl_divergence
import torch.nn.functional as F
from sklearn.datasets import make_spd_matrix

torch.manual_seed(0)
random.seed(0)

# 1. Univariate Gaussian에서 증명 (완료)

In [None]:
# mean / delta_mean 생성
mean = torch.randn(1, requires_grad=True)
delta_mean = torch.Tensor(np.random.randint(low=1, high=10, size=1)) / 1000 ; delta_mean = delta_mean.clone().detach().requires_grad_(True)

# std / delta_std 생성
std = torch.Tensor(np.random.randint(low=1, high=5, size=1)) ; std = std.clone().detach().requires_grad_(True)
delta_std = torch.Tensor(np.random.randint(low=1, high=10, size=1)) / 1000 ; delta_std = delta_std.clone().detach().requires_grad_(True)

In [None]:
print(f"mean : {mean} \n delta_mean : {delta_mean} \n std : {std} \n delta_std : {delta_std}")

mean : tensor([1.5410], requires_grad=True) 
 delta_mean : tensor([0.0060], requires_grad=True) 
 std : tensor([2.], requires_grad=True) 
 delta_std : tensor([0.0070], requires_grad=True)


## 1.1 Mean Perturbation만

In [None]:
# W, W_mean 분포 생성
W = normal.Normal(mean, std)
W_sample = W.sample((1000,)) ; W_sample = W_sample.clone().detach().requires_grad_(True)

W_pert = normal.Normal(mean+delta_mean, std)
W_pert_sample = W_pert.sample((1000,)) ; W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

* KLD between two univariate gaussian : 

$\log \frac{{\sigma}_{2}}{{\sigma}_{1}} + \frac{ {\sigma}_{1}^2 + (\mu_{1} - \mu_{2})^{2}}{ 2\sigma_{2}^{2}} - \frac{1}{2}$




In [None]:
# (내가 맹든) KLD 
torch.log(std/std) + (std**2 + (delta_mean)**2)/(2 * (std**2)) - 1/2

tensor([4.4703e-06], grad_fn=<SubBackward0>)

* Emprical Fisher : 
$ F(\mu) = \frac{1}{N}∑^N_{i=1} ∇_\mu log p(w_i | \mu, L) ∇_\mu log p(w_i | \mu, L)^T$

In [None]:
# empirical Fisher 계산
score = 0
for i in range(W_sample.shape[0]):
  tmp = (W.log_prob(W_sample))[i]
  tmp.backward()
  score = score + mean.grad **2
  mean.grad.data.zero_()

empirical_fisher = score / W_sample.shape[0]

In [None]:
empirical_fisher

tensor([0.2608])

$ D_{KL} [p(w|\mu+\Delta\mu, L) || p(w|\mu, L)] ≈ \Delta\mu^T F(\mu) \Delta\mu$       with small $\Delta\mu$

In [None]:
delta_mean * empirical_fisher * delta_mean

tensor([9.3877e-06], grad_fn=<MulBackward0>)

## 1.2 Variance Perturbation만

In [None]:
W_pert = normal.Normal(mean, std + delta_std)
W_pert_sample = W_pert.sample((1000,)) ; W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

* KLD between two univariate gaussian : 

$\log \frac{{\sigma}_{2}}{{\sigma}_{1}} + \frac{ {\sigma}_{1}^2 + (\mu_{1} - \mu_{2})^{2}}{ 2\sigma_{2}^{2}} - \frac{1}{2}$

In [None]:
torch.log(std/(std+delta_std)) + ((std+delta_std)**2)/(2 * (std**2)) - 1/2

tensor([1.2279e-05], grad_fn=<SubBackward0>)

* Emprical Fisher : 
$ F(\sigma) = \frac{1}{N}∑^N_{i=1} ∇_\sigma log p(w_i | \mu, L) ∇_\sigma log p(w_i | \mu, L)^T$

In [None]:
# empirical Fisher 계산
score = 0
for i in range(W_sample.shape[0]):
  tmp = (W.log_prob(W_sample))[i]
  tmp.backward()
  score = score + std.grad **2
  std.grad.data.zero_()

empirical_fisher = score / W_sample.shape[0]

In [None]:
empirical_fisher

tensor([1.0126])

$ D_{KL} [p(w|\mu+\Delta\mu, L) || p(w|\mu, L)] ≈ \Delta\sigma^T F(\sigma) \Delta\sigma$       with small $\Delta\sigma$

In [None]:
delta_std * empirical_fisher * delta_std

tensor([4.9616e-05], grad_fn=<MulBackward0>)

## 1.3 Mean and Covariance Perturbation

In [None]:
W_pert = normal.Normal(mean + delta_mean, std + delta_std)
W_pert_sample = W_pert.sample((1000,)) ; W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

* KLD between two univariate gaussian : 

$\log \frac{{\sigma}_{2}}{{\sigma}_{1}} + \frac{ {\sigma}_{1}^2 + (\mu_{1} - \mu_{2})^{2}}{ 2\sigma_{2}^{2}} - \frac{1}{2}$

In [None]:
torch.log(std/(std+delta_std)) + ((std+delta_std)**2 + delta_mean**2)/(2 * (std**2)) - 1/2

tensor([1.6749e-05], grad_fn=<SubBackward0>)

* $\theta = [\mu, \sigma]$

* Emprical Fisher : 
$ F(\theta) = \frac{1}{N}∑^N_{i=1} ∇_\theta log p(w_i | \mu, \sigma) ∇_\theta log p(w_i | \mu, \sigma)^T$

https://www.statlect.com/glossary/information-matrix

In [None]:
mean_mean_score = 0; mean_std_score = 0; std_std_score = 0
for i in range(W_sample.shape[0]):
  tmp = (W.log_prob(W_sample))[i]
  tmp.backward()
  mean_mean_score = mean_mean_score + mean.grad **2
  mean_std_score = mean_std_score + mean.grad * std.grad
  std_std_score = std_std_score + std.grad**2
  mean.grad.data.zero_()
  std.grad.data.zero_()

# | mean_mean_score   mean_std_score |
# | mean_std_score    std_std_score  |
# 행렬 생성
fisher = 1/W_sample.shape[0] * torch.tensor([[mean_mean_score, mean_std_score],
              [mean_std_score, std_std_score]], requires_grad=True)

In [None]:
delta_theta = torch.tensor([delta_mean,
              delta_std])

$ D_{KL} [p(w|\mu+\Delta\mu, \sigma) || p(w|\mu, \sigma)] ≈ \Delta\theta^T F(\theta) \Delta\theta$       with small $\Delta\theta$

In [None]:
tmp = torch.mm(delta_theta.reshape(1,-1), fisher)
torch.mm(tmp, delta_theta.reshape(-1,1))

tensor([[3.9940e-05]], grad_fn=<MmBackward0>)

# 2. MVN에서의 증명 (Univariate Gaussian에서 되면 넘어가는 걸로..)

## 2.1. 필요한 변수 생성

*   μ
*   ∑
*   Δμ
*   Δ∑


In [100]:
# Mean
mean = torch.randn(10)
mean = mean.clone().detach().requires_grad_(True) # tensor 만들면서 requires_grad=True로 주는 옵션이 있다면 그걸로 바꾸자

# Delta_Mean
delta_mean = torch.randn(10) / 100
delta_mean = delta_mean.clone().detach().requires_grad_(True)

In [101]:
# Covariance
# PSD 성질 유지하도록 
cov = torch.Tensor(make_spd_matrix(10))
cov = cov.clone().detach().requires_grad_(True)
L = torch.linalg.cholesky(cov)
L = L.clone().detach().requires_grad_(True)

# Delta_Cov
delta_cov = torch.Tensor(make_spd_matrix(10) / 1000)
delta_cov = delta_cov.clone().detach().requires_grad_(True)
delta_L = torch.linalg.cholesky(delta_cov)
delta_L = delta_L.clone().detach().requires_grad_(True)

In [102]:
W = multivariate_normal.MultivariateNormal(loc=mean, scale_tril=L)
# W = multivariate_normal.MultivariateNormal(mean, cov)
n_sample = 10000

W_sample = W.sample((n_sample,))
W_sample = W_sample.clone().detach().requires_grad_(True)

## 2.2. mean perturbation 실험

In [None]:
W_pert = multivariate_normal.MultivariateNormal(loc=(mean + delta_mean), scale_tril=L)
W_pert_sample = W_pert.sample((n_sample,))
W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

In [None]:
# KLD 계산
kl_divergence(W_pert, W)

tensor(0.0002, grad_fn=<AddBackward0>)

* Empirical Fisher : $ F(μ) = \frac{1}{N}\sum_{i=1}^{N}[\Delta_{\mu} \log P(w_i | μ, L) \Delta_{\mu} log P(w_i | μ, L)] $




In [None]:
# empirical Fisher 계산
tmp = W.log_prob(W_sample).sum()
tmp.backward()

score = torch.mm(mean.grad.reshape(-1,1), mean.grad.reshape(1,-1)) # 100 x 100

empirical_fisher = score / W_sample.shape[0] # 100 x 100

* $ D_{KL} [p(w|\mu+\Delta\mu, L) || p(w|\mu, L)] ≈ \Delta\mu^T F(\mu) \Delta\mu$       with small $\Delta\mu$

In [None]:
tmp = torch.mm(delta_mean.reshape(1,-1), empirical_fisher)
torch.mm(tmp, delta_mean.reshape(-1,1))

tensor([[0.0002]], grad_fn=<MmBackward0>)

## 2.3. Covariance Perturbation 실험

In [None]:
W_pert = multivariate_normal.MultivariateNormal(loc=mean, scale_tril=(L+delta_L))
# W_pert = multivariate_normal.MultivariateNormal(mean, cov+delta_cov)
W_pert_sample = W_pert.sample((n_sample,))
W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

In [None]:
# KLD 계산
kl_divergence(W_pert, W)

tensor(15.8182, grad_fn=<AddBackward0>)

* Empirical Fisher : $ F(L) = \frac{1}{N}\sum_{i=1}^{N}[\nabla_{L} \log P(w_i | μ, L) \nabla_{L} log P(w_i | μ, L)] $




In [None]:
tmp = W.log_prob(W_sample).sum()
tmp.backward()

score = torch.mm(L.grad.reshape(-1,1), L.grad.reshape(1,-1)) # 10000 x 10000
# score = torch.mm(cov.grad.reshape(-1,1), cov.grad.reshape(1,-1)) # 10000 x 10000
empirical_fisher = score / W_sample.shape[0] # 10000 x 10000

* $ D_{KL} [p(w|\mu+\Delta\mu, L) || p(w|\mu, L)] ≈ \Delta L^T F(L) \Delta L$       with small $\Delta L$

In [None]:
tmp = torch.mm(torch.flatten(delta_L).reshape(1,-1), empirical_fisher)
# tmp = torch.mm(torch.flatten(delta_cov).reshape(1,-1), empirical_fisher)
torch.mm(tmp, torch.flatten(delta_L).reshape(-1,1))
# torch.mm(tmp, torch.flatten(delta_cov).reshape(-1,1))

* ΔL은 Δμ에 비해서 많이 작게 해야될 것 같다!!!

* Δ∑으로 할 때보다 ΔL로 할 때 더 작은 perturbation(ΔL)을 줘야한다. (근데 왜????)

## 2.4. Mean and Covariance Perturbation 실험

In [103]:
W_pert = multivariate_normal.MultivariateNormal(loc=(mean+delta_mean), scale_tril=(L+delta_L))
# W_pert = multivariate_normal.MultivariateNormal(mean+delta_mean, cov+delta_cov)
W_pert_sample = W_pert.sample((n_sample,))
W_pert_sample = W_pert_sample.clone().detach().requires_grad_(True)

In [104]:
# KLD 계산
kl_divergence(W_pert, W)

tensor(0.0413, grad_fn=<AddBackward0>)

* Empirical Fisher : $ F(w) = \frac{1}{N}\sum_{i=1}^{N}[\nabla_{θ} \log P(w_i | μ, L) \nabla_{θ} log P(w_i | μ, L)] $
* $θ = [μ \ \  L]^T $ (계산 시에는 flatten해서 이용)




In [105]:
tmp = W.log_prob(W_sample).sum()
tmp.backward()

nabla_theta = torch.cat([torch.flatten(mean.grad), torch.flatten(L.grad)])
# nabla_theta = torch.cat([torch.flatten(mean.grad), torch.flatten(cov.grad)])

score = torch.mm(nabla_theta.reshape(-1,1), nabla_theta.reshape(1,-1)) # 10100 x 10100
empirical_fisher = score / W_sample.shape[0] # 10100 x 10100

In [106]:
delta_theta = torch.cat([torch.flatten(delta_mean), torch.flatten(delta_L)])
# delta_theta = torch.cat([torch.flatten(delta_mean), torch.flatten(delta_cov)])

tmp = torch.mm(delta_theta.reshape(1,-1), empirical_fisher)
torch.mm(tmp, delta_theta.reshape(-1,1))

tensor([[0.0459]], grad_fn=<MmBackward0>)

# 3. FisherSAM 수식 검증 (진행 중)

* $\mathbb{E}_x[KL(p(y|x,\theta+\epsilon) \ || \ p(y|x,\theta))] = \epsilon^T F
(\theta) \epsilon$
* $\theta$ : model paramter
* $\epsilon$ : perturbation

In [None]:
# output의 KLD를 측정한다....
# Fisher는 또 model parameter로 계산한다...
 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transfroms
 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
if device == 'cuda':
    torch.cuda.manual_seed_all(0)
print(device + " is available")
 
learning_rate = 0.001
batch_size = 100
num_classes = 10
epochs = 5
 
# MNIST 데이터셋 로드
train_set = torchvision.datasets.MNIST(
    root = './data/MNIST',
    train = True,
    download = True,
    transform = transfroms.Compose([
        transfroms.ToTensor() # 데이터를 0에서 255까지 있는 값을 0에서 1사이 값으로 변환
    ])
)
test_set = torchvision.datasets.MNIST(
    root = './data/MNIST',
    train = False,
    download = True,
    transform = transfroms.Compose([
        transfroms.ToTensor() # 데이터를 0에서 255까지 있는 값을 0에서 1사이 값으로 변환
    ])
)
 
# train_loader, test_loader 생성
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
 
# input size를 알기 위해서
examples = enumerate(train_set)
batch_idx, (example_data, example_targets) = next(examples)
example_data.shape
 
class ConvNet(nn.Module):
  def __init__(self): # layer 정의
        super(ConvNet, self).__init__()

        # input size = 28x28 
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5) # input channel = 1, filter = 10, kernel size = 5, zero padding = 0, stribe = 1
        # ((W-K+2P)/S)+1 공식으로 인해 ((28-5+0)/1)+1=24 -> 24x24로 변환
        # maxpooling하면 12x12
  
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5) # input channel = 1, filter = 10, kernel size = 5, zero padding = 0, stribe = 1
        # ((12-5+0)/1)+1=8 -> 8x8로 변환
        # maxpooling하면 4x4

        self.drop2D = nn.Dropout2d(p=0.25, inplace=False) # 랜덤하게 뉴런을 종료해서 학습을 방해해 학습이 학습용 데이터에 치우치는 현상을 막기 위해 사용
        self.mp = nn.MaxPool2d(2)  # 오버피팅을 방지하고, 연산에 들어가는 자원을 줄이기 위해 maxpolling
        self.fc1 = nn.Linear(320,100) # 4x4x20 vector로 flat한 것을 100개의 출력으로 변경
        self.fc2 = nn.Linear(100,10) # 100개의 출력을 10개의 출력으로 변경

  def forward(self, x):
        x = F.relu(self.mp(self.conv1(x))) # convolution layer 1번에 relu를 씌우고 maxpool, 결과값은 12x12x10
        x = F.relu(self.mp(self.conv2(x))) # convolution layer 2번에 relu를 씌우고 maxpool, 결과값은 4x4x20
        x = self.drop2D(x)
        x = x.view(x.size(0), -1) # flat
        x = self.fc1(x) # fc1 레이어에 삽입
        x = self.fc2(x) # fc2 레이어에 삽입
        return F.log_softmax(x) # fully-connected layer에 넣고 logsoftmax 적용
 
model = ConvNet().to(device) # CNN instance 생성
# Cost Function과 Optimizer 선택
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
 

for epoch in range(epochs): # epochs수만큼 반복
    avg_cost = 0

    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad() # 모든 model의 gradient 값을 0으로 설정
        pred = model(data) # 모델을 forward pass해 결과값 저장 
        cost = criterion(pred, target) # output과 target의 loss 계산
        cost.backward() # backward 함수를 호출해 gradient 계산
        optimizer.step() # 모델의 학습 파라미터 갱신
        avg_cost += cost / len(train_loader) # loss 값을 변수에 누적하고 train_loader의 개수로 나눔 = 평균
    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

cuda is available
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw





[Epoch:    1] cost = 0.336576045
[Epoch:    2] cost = 0.108807892
[Epoch:    3] cost = 0.0845165476
[Epoch:    4] cost = 0.0707272887
[Epoch:    5] cost = 0.063354589


In [3]:
pred.shape

torch.Size([100, 10])

In [None]:
# output 저장
output = pred

#perturbation 지정
