In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
import numpy as np

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def MNIST_loaders(train_batch_size=20000, test_batch_size=10000):
    """Build MNIST train/test DataLoaders whose samples are flattened 1-D tensors.

    Each image is converted to a tensor, normalised with mean 0.1307 and
    std 0.3081, then flattened. `download=True` lets torchvision fetch the
    dataset into ./data/ if it is not already there.

    Args:
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda img: torch.flatten(img)),
    ])

    def make_loader(is_train, batch_size):
        # Same dataset root and transform for both splits; shuffle only the train split.
        dataset = MNIST('./data/', train=is_train, download=True, transform=transform)
        return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)

    train_loader = make_loader(True, train_batch_size)
    test_loader = make_loader(False, test_batch_size)
    return train_loader, test_loader

In [None]:
def overlay_y_on_x(x, y):
    """Return a copy of `x` with label(s) `y` written one-hot into the first 10 columns.

    The first 10 features of every row are zeroed. Then, in each row, the
    column given by the label is set to the maximum value of the whole input
    `x`. The input tensor is left untouched.

    Args:
        x: 2-D tensor, one sample per row.
        y: a single int label applied to every row, or one label per row.

    Returns:
        The modified copy of `x`.
    """
    peak = x.max()
    labelled = x.clone()
    labelled[:, :10].mul_(0.0)
    labelled[range(len(x)), y] = peak
    return labelled

In [None]:
class Net(torch.nn.Module):
    """Stack of forward-forward `Layer`s, each trained greedily on its own local loss.

    Args:
        dims: layer widths; consecutive pairs become one Layer each,
            e.g. [784, 512, 512] builds 784->512 and 512->512.

    NOTE: `train` shadows nn.Module.train(mode), so the usual
    `net.train()` / `net.eval()` mode switches do not work on this class.
    """

    def __init__(self, dims):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers the layers as
        # submodules: they appear in parameters()/state_dict() and follow
        # .to()/.cuda(). Indexing and iteration behave like the old list.
        self.layers = nn.ModuleList(
            Layer(d_in, d_out).to(DEVICE) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    @torch.no_grad()
    def predict(self, x):
        """Predict a label per row of `x` by trying all 10 labels.

        For each candidate label, the label is overlaid on `x` and the input is
        run through every layer. The per-layer goodness (mean squared
        activation) is summed. The label with the largest total goodness wins.
        Gradients are disabled, so no autograd graph is kept for the 10
        overlaid copies of the batch.

        Returns:
            1-D tensor of predicted labels, one per row of `x`.
        """
        goodness_per_label = []
        for label in range(10):
            h = overlay_y_on_x(x, label)
            goodness = []
            for layer in self.layers:
                h = layer(h)
                goodness.append(h.pow(2).mean(1))
            goodness_per_label.append(sum(goodness).unsqueeze(1))
        return torch.cat(goodness_per_label, 1).argmax(1)

    def train(self, x_pos, x_neg):
        """Train the layers one after another on positive/negative data.

        Each layer trains locally. Its detached outputs become the next
        layer's positive/negative inputs.
        """
        h_pos, h_neg = x_pos, x_neg
        for layer in self.layers:
            h_pos, h_neg = layer.train(h_pos, h_neg)

In [None]:
class Layer(nn.Linear):
    """One forward-forward layer: input normalisation, linear map, ReLU, own optimiser.

    The layer is trained locally: its "goodness" (mean squared activation)
    should exceed `threshold` for positive inputs and stay below it for
    negative inputs.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to nn.Linear.
        threshold: goodness threshold separating positive from negative data.
        num_epochs: number of full-batch optimisation steps per call to train().
        lr: Adam learning rate.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None,
                 threshold=3.0, num_epochs=1000, lr=0.03):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        """Normalise each row of `x` to ~unit L2 length, then apply linear + ReLU."""
        # Dividing by the row norm passes only the direction of the input, not
        # its length, to the linear map. 1e-4 avoids division by zero.
        x_direction = x / (x.pow(2).sum(dim=1).sqrt().reshape((x.shape[0], 1)) + 1e-4)
        # nn.Linear.forward handles bias=None; the previous manual
        # `mm + bias.unsqueeze(0)` crashed for Layer(..., bias=False).
        return self.relu(super().forward(x_direction))

    def train(self, x_pos, x_neg):
        """Run `num_epochs` full-batch Adam steps on the forward-forward loss.

        Args:
            x_pos: positive inputs, one sample per row.
            x_neg: negative inputs, one sample per row.

        Returns:
            (h_pos, h_neg): this layer's outputs for both inputs after training,
            detached so the next layer trains independently.

        NOTE: this shadows nn.Module.train(mode); the usual .train()/.eval()
        mode switches do not work on this layer.
        """
        for _ in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)

            # Penalise positive goodness below the threshold and negative
            # goodness above it.
            margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
            # softplus(z) == log(1 + exp(z)), but it stays finite for large z.
            # The explicit form overflows to inf and yields NaN gradients.
            loss = nn.functional.softplus(margins).mean()

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

        with torch.no_grad():
            return self.forward(x_pos).detach(), self.forward(x_neg).detach()

# Training

In [None]:
def generate_negative_data(x, y):
    """Overlay a random *wrong* label on every sample of `x`.

    Args:
        x: input batch, one sample per row.
        y: 1-D integer tensor of true labels in [0, 10), one per row of `x`.

    Returns:
        `overlay_y_on_x(x, y_neg)`, where each y_neg[i] is drawn uniformly
        from the 9 labels other than y[i].
    """
    # Adding an offset in [1, 9] modulo 10 hits each wrong label with equal
    # probability and never the true one. This is one vectorised op on y's
    # device; the old per-sample loop did a host<->device round trip per row.
    # It also draws from torch's RNG, so torch.manual_seed makes it reproducible.
    offsets = torch.randint(1, 10, y.shape, device=y.device, dtype=y.dtype)
    y_neg = (y + offsets) % 10
    return overlay_y_on_x(x, y_neg)

In [None]:
# Seed torch (data shuffling, weight init). Seed numpy as well, since
# negative-label sampling may draw from np.random; without both seeds,
# reruns give different results.
torch.manual_seed(1234)
np.random.seed(1234)
train_loader, test_loader = MNIST_loaders()

# 784 input features -> two layers of 512 units.
net = Net([784, 512, 512])

# Forward-forward training. Each batch is shown with its true label overlaid
# (positive data) and with a wrong label overlaid (negative data). Each layer
# then runs its own optimisation loop on that batch.
# NOTE: x and y from the final batch are reused by the evaluation cell below.
for x, y in tqdm(train_loader):
    x, y = x.to(DEVICE), y.to(DEVICE)

    x_pos = overlay_y_on_x(x, y)
    x_neg = generate_negative_data(x, y)

    net.train(x_pos, x_neg)

In [None]:
# torch.save(net, 'net.pth')

In [None]:
# net = torch.load('net.pth')
# train_loader, test_loader = MNIST_loaders()
# torch.manual_seed(1234)

In [None]:
# Classification error (%). Autograd is disabled: predict() runs 10
# label-overlaid copies of each batch through the net, and tracking gradients
# would keep every intermediate activation alive.
with torch.no_grad():
    # NOTE: x, y are the *last* training batch left over from the training
    # loop, so this is the error on that batch only, not the whole train set.
    train_err = round((1.0 - net.predict(x).eq(y).float().mean().item()) * 100, 2)
    print('train error:', train_err)

    # Only the first test batch is evaluated. With the default
    # test_batch_size=10000 this is presumably the whole MNIST test split.
    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(DEVICE), y_te.to(DEVICE)

    test_err = round((1.0 - net.predict(x_te).eq(y_te).float().mean().item()) * 100, 2)
    print('test error:', test_err)

In [None]:
# Label of the first test sample, and that sample's pixels as a 28x28 image.
print(y_te[:1])
img = x_te[0].reshape(28, 28).cpu()

In [None]:
# Return cached-but-unused GPU memory held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [None]:
# First layer's weight shape: (out_features, in_features), i.e. (512, 784) for Net([784, 512, 512]).
net.layers[0].weight.shape

In [None]:
import matplotlib.pyplot as plt

# Positive example: the first test image with label 7 written into its first
# 10 features. It is built here because this cell runs before the cell that
# defines `pos_img`/`neg_img`; otherwise Restart & Run All raises NameError.
pos_img = overlay_y_on_x(x_te[:1, :], 7)
plt.imshow(pos_img.reshape((28, 28)).cpu())
plt.show()

In [None]:
# The first test image overlaid with label 7 ("positive") and label 0
# ("negative"). Presumably 7 is the true label of x_te[0]; check the y_te print above.
first_image = x_te[:1, :]
pos_img, neg_img = overlay_y_on_x(first_image, 7), overlay_y_on_x(first_image, 0)

In [None]:
def show_activations(img, title, model=None):
    """Plot every layer's activations for one input, side by side.

    Args:
        img: input batch to feed through the layers (e.g. one overlaid image).
        title: figure title.
        model: network whose `layers` are applied in order; defaults to the
            notebook-global `net`.

    Returns:
        List of per-layer activation tensors, computed without autograd.
    """
    if model is None:
        model = net

    activations = []
    activ = img
    with torch.no_grad():
        for layer in model.layers:
            activ = layer(activ)
            activations.append(activ)

    fig = plt.figure(figsize=(6, 4))
    fig.suptitle(title)
    # One column per layer. A hard-coded 2 breaks subplot() for deeper nets.
    columns = len(activations)

    for i, act in enumerate(activations):
        ax = fig.add_subplot(1, columns, i + 1)
        # Displayed as a 32-row grid; assumes each layer's output size is a
        # multiple of 32 (true for the 512-unit layers built above).
        ax.imshow(act.detach().cpu().reshape((32, -1)))

    return activations

In [None]:
# Per-layer activations for the positive example (first test image + label 7).
pos_act = show_activations(pos_img, "Положительные данные")

In [None]:
# Per-layer activations for the negative example (first test image + label 0).
neg_act = show_activations(neg_img, "Негативные данные")

In [None]:
# Total first-layer activation for the negative example.
neg_act[0].sum()

In [None]:
# Per-layer element-wise difference: positive minus negative activations.
diffs = [act - neg_act[i] for i, act in enumerate(pos_act)]

In [None]:
# Lay out one subplot per layer for the positive-minus-negative activation differences.
plt.figure(figsize=(12,4))
# NOTE(review): 3 columns, but `diffs` has one entry per layer (2 for the
# Net([784, 512, 512]) built above), so one slot goes unused.
columns = 3
for i, image in enumerate(diffs):
    plt.subplot(1, columns, i+1)
    # The imshow call below is commented out, so these axes stay empty
    # unless later code draws into them.
#     plt.imshow(image.detach().cpu().reshape((32,-1)))