<a href="https://colab.research.google.com/github/blufzzz/Introspective-Neural-Networks/blob/master/Introspective_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torchvision
from torch import nn
import torchvision
from torchvision.models import resnet18


In [3]:
# Sanity check: build the pretrained ResNet-18 and look at the output shape
# it produces for a single random 3x256x256 input.
model = resnet18(pretrained=True)
X = torch.randn(1, 3, 256, 256)
model(X).shape

torch.Size([1, 1000])

# INN Model

In [0]:
class Synthesis():
    """Synthesize inputs by gradient ascent on a module's own output.

    The sampler works on any module that exposes

    * ``X`` -- a ``torch.nn.Parameter`` holding the samples being synthesized;
    * ``to_synth()`` -- a method returning the module's output for ``X``.

    Args:
        init_std: Standard deviation of the zero-mean Gaussian used to
            re-initialise ``X`` at the start of every ``sample`` call.
        noise_grad_std: If not ``None``, standard deviation of zero-mean
            Gaussian noise added to ``X`` before every optimizer step.
    """

    def __init__(self, init_std=0.3, noise_grad_std=None):
        self.init_std = init_std
        self.noise_grad_std = noise_grad_std

    def sample(self, module, num_iter=10, learning_rate=1e-2):
        """Run ``num_iter`` SGD steps on ``module.X`` that increase ``module.to_synth()``.

        Args:
            module: Module exposing ``X`` and ``to_synth()`` (see class docstring).
            num_iter: Number of SGD steps.
            learning_rate: SGD learning rate.

        Returns:
            ``module.X.data`` -- the synthesized tensor. It shares storage with
            the parameter, so a later ``sample`` call overwrites it.

        Raises:
            AssertionError: If ``module.X`` is not a ``torch.nn.Parameter``.
        """
        assert isinstance(module.X, torch.nn.Parameter), 'Expected X to be an instance of torch.nn.Parameter'

        # Only the gradient w.r.t. X is needed, so every other parameter is
        # frozen while sampling.  Remember the previous flags so they can be
        # restored afterwards; otherwise the network would stay frozen and any
        # later training on it would silently receive no gradients.
        saved_flags = [(param, param.requires_grad) for param in module.parameters()]
        for name, param in module.named_parameters():
            param.requires_grad = (name == 'X')

        try:
            with torch.no_grad():
                module.X.normal_(mean=0, std=self.init_std)
            opt = torch.optim.SGD([module.X], lr=learning_rate)

            for _ in range(num_iter):
                opt.zero_grad()
                # Minimising the negated output maximises the output.  One
                # backward pass over the sum accumulates the same gradient as
                # one backward pass per row, and also works when a row holds
                # more than one value.
                (-module.to_synth()).sum().backward()

                if self.noise_grad_std is not None:
                    with torch.no_grad():
                        noise = torch.empty_like(module.X).normal_(mean=0, std=self.noise_grad_std)
                        module.X.add_(noise)

                opt.step()
        finally:
            for param, flag in saved_flags:
                param.requires_grad = flag

        return module.X.data

In [0]:
class ICNN(nn.Module):
    """Wrap a CNN so it can score inputs and hold a synthesis buffer ``X``.

    The wrapped model's ``fc`` layer is replaced by a single-output linear
    layer, and a forward hook on ``model.avgpool`` appends that layer's output
    to ``conv_features`` on every pass.

    Args:
        cnn_model: Network exposing ``avgpool`` and ``fc`` sub-modules.
            It is modified in place (``fc`` is replaced).
        size: Shape of the synthesis buffer ``X``.
        std: Standard deviation of the zero-mean Gaussian used to initialise ``X``.
        n_classes: Stored on the instance; not used by this class itself.
    """

    def __init__(self, cnn_model, size, std=0.3, n_classes = False):
        super(ICNN, self).__init__()

        self.conv_features = []  # avgpool outputs of all passes, in call order
        self.model = cnn_model
        self.n_classes = n_classes
        self.sigmoid = nn.Sigmoid()

        # Take the feature width from the existing head when there is one, so
        # backbones wider than 512 work too; 512 remains the fallback.
        in_features = getattr(getattr(cnn_model, 'fc', None), 'in_features', 512)
        self.model.fc = nn.Linear(in_features=in_features, out_features=1, bias=True)

        self.X = torch.nn.Parameter(torch.empty(size).normal_(mean=0, std=std))
        self.X.requires_grad = False

        # Register the hook exactly once.  Registering it inside every
        # forward pass stacks hooks, so the k-th pass would append k copies.
        self._hook_handle = self.model.avgpool.register_forward_hook(self._store_conv_features)

    def _store_conv_features(self, module, input, output):
        """Forward hook: remember the output of ``avgpool``."""
        self.conv_features.append(output)

    def forward(self, input, sigmoid = False):
        """Score ``input``; return raw scores, or their sigmoid if ``sigmoid`` is true."""
        output = self.model(input)
        return self.sigmoid(output) if sigmoid else output

    def to_synth(self):
        """Return the raw score of the synthesis buffer ``X``."""
        return self.model(self.X)

In [0]:
# Smoke test of the sampler: X only supplies the shape (2, 3, 256, 256) of
# the buffer that ICNN allocates and Synthesis then optimises.
X = torch.randn(2, 3, 256, 256)
icnn_model = ICNN(model, size=X.shape)
s = Synthesis(init_std=0.3)
result = s.sample(icnn_model, num_iter=5, learning_rate=0.1)

In [19]:
# Shape of the tensor returned by Synthesis.sample.
result.shape

torch.Size([2, 3, 256, 256])

# Wasserstein loss

In [0]:
# Toy two-class dataset: n random tensors with random labels in {-1, +1}.
batch_size = 16
n = 100
# Size X by n rather than a second hard-coded 100, so the boolean masks built
# from y always match X's first dimension.
X = torch.randn(n, 3, 256, 256)
y = torch.randint(0, 2, (n,)) * 2 - 1
X_1 = X[y == 1]
X_2 = X[y == -1]

In [0]:
# Draw batch_size samples (with replacement) from each class.
batch_1 = X_1[torch.randint(len(X_1), (batch_size,))]
batch_2 = X_2[torch.randint(len(X_2), (batch_size,))]
# One interpolation coefficient per sample, uniform in [0, 1).
alphas = torch.rand(batch_size)
# Build the interpolates out of place and only then mark them as requiring
# grad: item assignment into a leaf tensor created with requires_grad=True
# raises a RuntimeError (in-place modification of a leaf).
mix = alphas.view(-1, 1, 1, 1)
batch_hat = (mix * batch_1 + (1 - mix) * batch_2).requires_grad_(True)

In [0]:
# ICNN requires the shape of its synthesis buffer X as second argument;
# omitting it raises a TypeError.  Use the shape of the batches scored here.
icnn_model = ICNN(model, batch_1.size())
# Scores for the two class batches and for their interpolates.
f_1 = icnn_model.forward(batch_1, sigmoid=False)
f_2 = icnn_model.forward(batch_2, sigmoid=False)
f_hat = icnn_model.forward(batch_hat, sigmoid=False)

In [0]:
# Gradient of each interpolate's score w.r.t. that same interpolate.
# create_graph=True keeps these gradients differentiable, so the gradient-norm
# term of the loss actually back-propagates into the network parameters;
# without it that term is a constant and contributes nothing to the update.
# (create_graph=True also retains the graph between iterations.)
grads_hat = torch.stack([
    torch.autograd.grad(f_hat[i], batch_hat, create_graph=True)[0][i]
    for i in range(batch_size)
])

In [0]:
# Per-sample term: score difference f_2 - f_1 plus the L2 norm of that
# sample's interpolate gradient.
per_sample_terms = (
    f_2[idx] - f_1[idx] + torch.norm(grads_hat[idx], p=2)
    for idx in range(batch_size)
)

# Average the terms over the batch.
loss = sum(per_sample_terms) / batch_size

In [0]:
# Back-propagate the batch-averaged loss through the autograd graph.
loss.backward()