<a href="https://colab.research.google.com/github/blufzzz/Introspective-Neural-Networks/blob/master/training_wasserstein.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline

from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from torch.cuda import set_device
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
import random
from torch import nn, optim
import torch
from tqdm import tnrange, tqdm_notebook

In [0]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``; used for channel changes and shortcuts."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The output has ``planes`` channels. The stride is applied by the first convolution, so when
    the shortcut would not match the main branch's shape, a ``downsample`` module (applied to
    the block input) must be supplied.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Registration order (conv1, bn1, relu, conv2, bn2, downsample) fixes the state_dict
        # layout and the order in which weights are initialised.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # Shortcut: identity unless a projection was provided.
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut connection.

    The output has ``planes * expansion`` channels. The stride is applied by the 3x3
    convolution; when the shortcut would not match the main branch's shape, a ``downsample``
    module (applied to the block input) must be supplied.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * self.expansion
        # Registration order is kept fixed so the state_dict layout and init order are stable.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, width_out)
        self.bn3 = nn.BatchNorm2d(width_out)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Add the shortcut (identity or projection) before the final activation.
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet truncated to three residual stages, plus members used for introspective training.

    Besides the classifier (stem -> layer1..layer3 -> global average pool -> ``fc``) it owns:

    * ``X``: a ``(10, 1, 28, 28)`` parameter drawn from N(0, 0.3^2), created with
      ``requires_grad=False``; ``to_synth`` runs the network on it. Because it has one channel,
      ``to_synth`` only works once ``conv1`` accepts a single input channel (the stem built here
      takes 3).
    * ``wass_fc``: a ``Linear(10, 1)`` head; it expects 10-dimensional inputs.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        layers: blocks per stage; only ``layers[0]``..``layers[2]`` are read.
        num_classes: output size of ``fc``.
        zero_init_residual: zero the last BatchNorm weight of every residual branch.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        # Channel count of the tensor entering the next stage; updated by _make_layer.
        self.inplanes = 64

        # Stem: 7x7/2 convolution (3 input channels) -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Three residual stages; the 512-channel fourth stage is intentionally left out.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256 * block.expansion, num_classes)

        # Synthesis canvas (frozen by default) and the critic head.
        self.X = torch.nn.Parameter(
            torch.empty((10, 1, 28, 28)).normal_(mean=0, std=0.3), requires_grad=False
        )
        self.wass_fc = torch.nn.Linear(10, 1, bias=True)

        # He-normal init for convolutions; unit weight / zero bias for BatchNorm.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Optionally start every residual branch at zero so each block begins as an identity
        # map (reported to gain 0.2~0.3%, https://arxiv.org/abs/1706.02677).
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.zeros_(module.bn3.weight)
                elif isinstance(module, BasicBlock):
                    nn.init.zeros_(module.bn2.weight)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first may change stride/width."""
        out_planes = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            # Projection shortcut so the residual addition sees matching shapes.
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.fc(pooled.view(pooled.size(0), -1))

    def to_synth(self):
        """Run the network on the synthesis canvas ``X`` and return its outputs."""
        return self.forward(self.X)


def resnet18(pretrained=False, **kwargs):
    """Construct the truncated "ResNet-18" used in this notebook.

    Uses ``BasicBlock`` with ``[2, 2, 2, 2]`` blocks per stage. ``ResNet`` here builds only the
    first three stages, so the last entry is not used.

    Args:
        pretrained (bool): Not supported. The official ImageNet ResNet-18 weights do not fit this
            architecture (no fourth stage, a 256-wide ``fc``, extra ``X``/``wass_fc`` members).
        **kwargs: Forwarded to ``ResNet`` (e.g. ``num_classes``, ``zero_init_residual``).

    Raises:
        NotImplementedError: If ``pretrained`` is true.
    """
    if pretrained:
        # Fail with a clear message instead of a NameError on undefined loader globals.
        raise NotImplementedError(
            "pretrained weights are not available for this truncated ResNet variant"
        )
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)



# MNIST adaptation: a single-channel 5x5/2 stem replaces the 3-channel stem, and the classifier
# head becomes a 10-way layer over the 256-d pooled features.
MnistResNet = resnet18()
MnistResNet.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=(5, 5), stride=(2, 2), padding=(3, 3), bias=False
)
MnistResNet.fc = torch.nn.Linear(in_features=256, out_features=10, bias=True)


In [0]:
class MnistDataset(torch.utils.data.Dataset):
    """MNIST images, interleaved with synthesized "fake" images once those are available.

    Real images are read from ``root_dir + 'processed/training.pt'`` (or ``processed/test.pt``
    when ``training`` is false) and scaled to [0, 1]. On access every image is mapped through
    ``(x - 0.5) / 0.5 * 0.6``, i.e. [0, 1] -> [-0.6, 0.6].

    Indexing (the length is ``2 * N`` for ``N`` real images):

    * Until every class bucket in ``fake_samples`` holds at least one image, index ``idx``
      returns real image ``idx % N``.
    * Afterwards, even ``idx`` returns real image ``idx // 2`` and odd ``idx`` returns a random
      fake from the bucket of the negated label of real image ``idx // 2``. Every real image is
      therefore used once per epoch, and consecutive even/odd indices form real/fake pairs.
      The pairing depends only on ``idx``, not on call order.

    Items are ``(image, label, fake)``:
    * ``image`` has a leading channel dimension added.
    * ``label`` is a 1-element float tensor; it is negated for fakes.
    * ``fake`` is a 0-dim uint8 flag.
    """

    def __init__(self, root_dir='data/', training=True, train_samples='all'):
        self.real_imgs = []
        self.fake_imgs = []
        self.real_labels = []
        self.fake_labels = []
        # Kept for backward compatibility; __getitem__ does not depend on them.
        self.sample_real = True
        self.sampled_fakes = 0

        # One bucket of synthesized images per class, keyed by the negated class index
        # (0, -1, ..., -9).
        self.fake_samples = {-i: [] for i in range(10)}

        split_file = 'processed/training.pt' if training else 'processed/test.pt'
        x, y = torch.load(root_dir + split_file)
        count = y.shape[0] if train_samples == 'all' else train_samples
        for img, label in zip(x[:count], y[:count]):
            self.real_imgs.append(img.float() / 255)
            self.real_labels.append(label.long())

    def __len__(self):
        return 2 * len(self.real_labels)

    def _has_fakes(self):
        """Return True once every class bucket holds at least one synthesized image."""
        return all(len(bucket) > 0 for bucket in self.fake_samples.values())

    def __getitem__(self, idx):
        n_real = len(self.real_labels)
        if self._has_fakes():
            # Even index -> real image idx // 2; odd index -> a fake paired with that image.
            real_idx = (idx // 2) % n_real
            is_fake = idx % 2 == 1
        else:
            # No fakes yet: both halves of the index range map onto the real images.
            real_idx = idx % n_real
            is_fake = False

        y = self.real_labels[real_idx]
        if is_fake:
            y = -y
            x = random.choice(self.fake_samples[int(y)])
        else:
            x = self.real_imgs[real_idx]
        fake = torch.tensor(int(is_fake), dtype=torch.uint8)
        return (x.unsqueeze(0) - 0.5) / 0.5 * 0.6, y.float().unsqueeze(0), fake

    def add_artificial(self, X):
        """Store synthesized images: row ``i`` of ``X`` goes into the bucket for class ``i``.

        ``X`` is indexed as ``X[i, 0, ...]``, and at most 10 rows are supported (one per
        bucket). The rows are assumed to already be in the model's input range (as produced
        by ``Synthesis.sample``, which rescales to [-0.6, 0.6]).

        They are mapped back to the [0, 1] pixel scale so that the normalisation in
        ``__getitem__`` returns them unchanged instead of normalising them twice. The
        arithmetic also produces new tensors, so the stored images never alias ``X``.
        """
        for i in range(X.shape[0]):
            img = X[i, 0, ...].detach().cpu()
            self.fake_samples[-i].append(img / 0.6 * 0.5 + 0.5)


def get_data_loaders(train_batch_size, val_batch_size, train_size, root_dir=''):
    """Build the training and validation ``DataLoader``s over ``MnistDataset``.

    Args:
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        train_size: number of real training images to load (``'all'`` for the whole split).
        root_dir: prefix prepended to ``processed/training.pt`` / ``processed/test.pt``.

    Returns:
        (train_loader, val_loader)

    Both loaders use ``shuffle=False``. Once fakes exist, ``MnistDataset`` interleaves real and
    synthesized samples across consecutive items, and ``wasserstein_loss`` pairs the real and
    fake samples of a batch.
    """
    train_loader = DataLoader(MnistDataset(root_dir=root_dir, train_samples=train_size),
                              batch_size=train_batch_size, shuffle=False)

    val_loader = DataLoader(MnistDataset(root_dir=root_dir, training=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader


    
def calculate_metric(metric_fn, true_y, pred_y):
    """Evaluate ``metric_fn(true_y, pred_y)``, passing ``average="macro"`` when it is accepted.

    The full signature is inspected, including keyword-only parameters, and ``__wrapped__`` is
    followed for decorated callables. scikit-learn declares ``average`` as keyword-only, which
    ``getfullargspec(...).args`` does not report.
    """
    try:
        params = inspect.signature(metric_fn).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no introspectable signature.
        params = {}
    if "average" in params:
        return metric_fn(true_y, pred_y, average="macro")
    return metric_fn(true_y, pred_y)
    
    
    
def print_scores(p, r, f1, a, batch_size):
    """Print the average of each metric, one tab-indented, right-aligned line per metric.

    Args:
        p, r, f1, a: sequences of precision / recall / F1 / accuracy values.
        batch_size: divisor for the averages (``train`` passes the number of validation batches).
    """
    named = (("precision", p), ("recall", r), ("F1", f1), ("accuracy", a))
    for label, values in named:
        mean = sum(values) / batch_size
        print(f"\t{label.rjust(14, ' ')}: {mean:.4f}")

# Wasserstein loss

In [0]:
def wasserstein_loss(model, outputs, X, fakes, scale=10):
    """WGAN-GP style critic loss computed on top of the model's outputs.

    The critic is ``model.wass_fc`` applied to the model output. The loss is::

        mean(critic(fake)) - mean(critic(real)) + scale * mean((||grad_xhat critic||_2 - 1)^2)

    The gradient penalty is evaluated at random interpolations
    ``xhat = eps * x_real + (1 - eps) * x_fake`` with ``eps ~ U(0, 1)`` per pair.

    Args:
        model: maps ``X`` to ``outputs`` and exposes a ``wass_fc`` head.
        outputs: ``model(X)``.
        X: input batch. The per-pair ``eps`` has shape ``(n, 1, 1, 1)``, so a 4-D
            ``(batch, C, H, W)`` input is assumed.
        fakes: per-sample flag (uint8 or bool); nonzero marks a fake sample.
        scale: weight of the gradient penalty.

    Returns:
        A 0-dim tensor. If the batch lacks real or fake samples, a zero tensor is returned
        (no penalty can be formed). When the counts differ, only the first
        ``min(n_real, n_fake)`` samples of each kind are paired for the penalty.

    Note: ``model`` is also run on the interpolated images, so in train mode its BatchNorm
    running statistics are updated by them as well.
    """
    is_fake = fakes.bool()
    is_real = ~is_fake
    real_images = X[is_real]
    fake_images = X[is_fake]
    n_pairs = min(real_images.shape[0], fake_images.shape[0])
    if n_pairs == 0:
        return outputs.new_zeros(())

    critic = model.wass_fc(outputs)
    neg_wass_dist = critic[is_fake].mean() - critic[is_real].mean()

    # Random points on the segments between paired real and fake images.
    eps = X.new_empty(n_pairs, 1, 1, 1).uniform_()
    inter = eps * real_images[:n_pairs] + (1 - eps) * fake_images[:n_pairs]
    inter = inter.detach().requires_grad_(True)

    inter_critic = model.wass_fc(model(inter))
    grad = torch.autograd.grad(
        inter_critic, inter, grad_outputs=torch.ones_like(inter_critic), create_graph=True
    )[0]
    grad_norm = grad.reshape(grad.shape[0], -1).norm(2, dim=1)
    penalty = ((grad_norm - 1) ** 2).mean()

    return neg_wass_dist + scale * penalty

In [0]:
class Synthesis():
    """Generate class-conditional inputs by gradient ascent on a model's own outputs.

    ``sample`` optimises ``module.X`` so that row ``i`` of ``module.to_synth()`` has a large
    value in column ``i`` (the diagonal is maximised). After every step each row of ``X`` is
    min-max rescaled to span [-0.6, 0.6].
    """

    def __init__(self, init_std=0.3):
        # Std of the Gaussian used to re-initialise X at the start of every sample() call.
        self.init_std = init_std

    def sample(self, module, num_iter=10, learning_rate=0.01, add_noise=True):
        """Optimise ``module.X`` in place and return a detached copy of the result.

        Args:
            module: has a ``torch.nn.Parameter`` attribute ``X`` and a ``to_synth()`` method
                returning a 2-D output for ``X``.
            num_iter: number of Adam (AMSGrad, betas=(0.9, 0.9)) steps.
            learning_rate: initial learning rate.
            add_noise: if true, add Gaussian noise with std ``2 * lr`` to the gradient before
                each step, then decay ``lr`` by a factor of 0.92.

        Returns:
            A tensor with ``module.X``'s shape and values that does not share storage with it.

        Side effects:
            ``module.X`` is overwritten. While sampling, the module is in eval mode and only
            ``X`` requires grad. The previous ``requires_grad`` flags and train/eval mode are
            restored afterwards, also when an exception is raised.
        """
        assert isinstance(module.X, torch.nn.Parameter), 'Expected X to be an instance of torch.nn.Parameter'

        was_training = module.training
        saved_flags = [(param, param.requires_grad) for param in module.parameters()]
        try:
            module.train(False)
            # Only X is optimised; freezing the weights avoids computing their gradients.
            for param in module.parameters():
                param.requires_grad = param is module.X

            module.X.data.normal_(mean=0, std=self.init_std)
            opt = torch.optim.Adam([module.X], lr=learning_rate, amsgrad=True, betas=(0.9, 0.9))

            for _ in range(num_iter):
                opt.zero_grad()
                # Row i of X is pushed towards class i.
                objective = -torch.sum(torch.diag(module.to_synth()))
                objective.backward()

                if add_noise:
                    lr = opt.param_groups[0]['lr']
                    module.X.grad += torch.empty_like(module.X.grad).normal_(mean=0, std=2 * lr)
                    opt.param_groups[0]['lr'] = lr * 0.92

                opt.step()

                # Per-row min-max rescale to [-0.6, 0.6]; the clamp guards constant rows.
                flat = module.X.data.view(module.X.shape[0], -1)
                lo = flat.min(dim=1, keepdim=True)[0]
                hi = flat.max(dim=1, keepdim=True)[0]
                flat.sub_(lo).div_((hi - lo).clamp(min=1e-12)).mul_(1.2).sub_(0.6)
        finally:
            for param, flag in saved_flags:
                param.requires_grad = flag
            module.train(was_training)

        return module.X.detach().clone()

In [0]:
def _train_one_epoch(model, loader, optimizer, loss_function, device):
    """Run one optimisation pass over ``loader`` and return the sum of the per-batch losses."""
    model.train()
    total_loss = 0.0
    progress = tqdm(loader, desc="Loss: ", total=len(loader))
    for data in progress:
        X = data[0].to(device)
        y = data[1].to(device).squeeze(-1).long()
        fakes = data[2].to(device)
        is_real = fakes == 0

        model.zero_grad()
        outputs = model(X)
        # Fakes carry negated labels, so only real samples enter the classification loss.
        if bool(is_real.any()):
            loss = loss_function(outputs[is_real], y[is_real])
        else:
            loss = outputs.new_zeros(())
        if not bool(is_real.all()):
            loss = loss + 0.01 * wasserstein_loss(model, outputs, X, fakes)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress.set_description("Loss: {:.4f}".format(loss.item()))
    return total_loss


def _evaluate(model, loader, loss_function, device):
    """Return (summed validation loss, per-batch (precision, recall, f1, accuracy) lists)."""
    model.eval()
    val_loss = 0.0
    precision, recall, f1, accuracy = [], [], [], []
    with torch.no_grad():
        for data in loader:
            X = data[0].to(device)
            y = data[1].to(device).squeeze(-1).long()
            logits = model(X)
            # CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
            val_loss += loss_function(logits, y).item()
            predicted_classes = logits.argmax(dim=1)

            for acc, metric in zip((precision, recall, f1, accuracy),
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(calculate_metric(metric, y.cpu(), predicted_classes.cpu()))
    return val_loss, (precision, recall, f1, accuracy)


def train(model, train_loader, val_loader, epochs=1, alpha=0.9, lrate=1e-3):
    """Train ``model`` for ``epochs`` epochs, validating and printing metrics after each epoch.

    The per-batch training loss is the cross-entropy on real samples (fake flag 0), plus
    ``0.01 * wasserstein_loss`` when the batch contains fakes. A fresh Adam optimizer
    (AMSGrad, ``betas=(0.0, 0.9)``) is created on every call.

    Args:
        model: classifier with a ``wass_fc`` head. Batches are moved to the device of its
            parameters.
        train_loader: yields ``(X, y, fake)`` batches.
        val_loader: yields ``(X, y, ...)`` batches.
        epochs: number of epochs.
        alpha: unused; kept for backward compatibility.
        lrate: learning rate.

    Returns:
        list[float]: mean training loss per epoch.
    """
    device = next(model.parameters()).device
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), amsgrad=True, lr=lrate, betas=(0.0, 0.9))

    batches = len(train_loader)
    val_batches = len(val_loader)
    losses = []

    for epoch in range(epochs):
        total_loss = _train_one_epoch(model, train_loader, optimizer, loss_function, device)
        torch.cuda.empty_cache()

        val_losses, scores = _evaluate(model, val_loader, loss_function, device)
        print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
        print_scores(*scores, val_batches)
        losses.append(total_loss / batches)

    return losses


In [0]:
# Outer loop:
#   1. Train for one epoch.
#   2. Synthesize `l` batches of class-conditional images (one image per class per batch).
#   3. Add the synthesized images to the training set as fakes.
#   4. Show the last synthesized batch.
l = 200
tmax = 200
model = MnistResNet.cuda()
train_loader, val_loader = get_data_loaders(100, 256, 25000)
for t in range(tmax):
    print('t = ', t)
    model.train(True)
    train(model, train_loader, val_loader, epochs=1, alpha=0.5, lrate=0.001)

    model.train(False)
    s = Synthesis(init_std=0.3)
    for i in range(l):
        aug = s.sample(model, num_iter=100, learning_rate=0.1, add_noise=False)
        train_loader.dataset.add_artificial(aug)

    # plt.subplots creates its own figure; a separate plt.figure() would only add an empty one.
    fig, ax = plt.subplots(nrows=1, ncols=10, figsize=(15, 3))
    for i in range(10):
        ax[i].imshow(aug[i, 0, ...].cpu())
        ax[i].axis('off')
    plt.show()

In [0]:
# Draw five independent batches of synthesized samples (one image per class per row).
fig, ax = plt.subplots(nrows=5, ncols=10, figsize=(17, 10))
for j in range(5):
    s = Synthesis(init_std=0.3)
    aug = s.sample(model, num_iter=50, learning_rate=1e-1, add_noise=True)
    # Per-image min-max rescale to [0, 1] for display. This is done out of place, so the tensor
    # returned by sample() (which may share storage with model.X) is left untouched. The clamp
    # guards constant images.
    flat = aug.reshape(aug.shape[0], -1)
    flat = flat - flat.min(dim=1, keepdim=True)[0]
    flat = flat / flat.max(dim=1, keepdim=True)[0].clamp(min=1e-12)
    aug = flat.view_as(aug)
    for i in range(10):
        ax[j][i].imshow(aug[i, 0, ...].cpu())
        ax[j][i].axis('off')