In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm,trange
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

In [2]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# and finally the CPU. The MPS lookup is guarded with getattr because
# older torch builds do not expose torch.backends.mps.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

device(type='mps')

In [3]:
def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    """Build the MNIST train and test ``DataLoader`` objects.

    Every image is converted to a tensor, normalised with mean 0.1307 and
    std 0.3081, and flattened. The dataset is read from (and downloaded to,
    if missing) ``./data/``.

    Args:
        train_batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    pipeline = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda img: torch.flatten(img)),
    ])

    def _make_loader(is_train, batch_size):
        # Only the training split is shuffled.
        dataset = MNIST('./data/', train=is_train,
                        download=True,
                        transform=pipeline)
        return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)

    train_loader = _make_loader(True, train_batch_size)
    test_loader = _make_loader(False, test_batch_size)
    return train_loader, test_loader


def overlay_y_on_x(x, y):
    """Embed the label [y] into the first 10 features of a copy of [x].

    The first 10 columns of the copy are zeroed, then column ``y`` of each
    row is set to the maximum value found in ``x``. ``y`` may hold one label
    per row or a single label applied to every row. ``x`` itself is not
    modified.
    """
    overlaid = x.clone()
    overlaid[:, :10] *= 0.0
    hot_value = x.max()
    row_ids = range(x.shape[0])
    overlaid[row_ids, y] = hot_value
    return overlaid

In [None]:
# Fix the RNG so the drawn batch and the negative labels are reproducible.
torch.manual_seed(1234)
train_loader, test_loader = MNIST_loaders()
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

# Positive samples: each image carries its true label.
x_pos = overlay_y_on_x(x, y)

# Negative samples: each image carries a label borrowed from another row.
# A plain permutation leaves some rows with their own (correct) label,
# which would make those "negative" samples identical to the positive ones,
# so those rows are shifted by a random non-zero offset modulo 10.
rnd = torch.randperm(x.size(0))
y_neg = y[rnd]
clash = y_neg == y
shift = torch.randint(1, 10, y.shape).to(y.device)
y_neg[clash] = (y_neg[clash] + shift[clash]) % 10
x_neg = overlay_y_on_x(x, y_neg)


In [20]:
class Net(torch.nn.Module):
    """Stack of ``Layer`` modules that are trained one after another.

    NOTE(review): ``train`` below shadows ``torch.nn.Module.train(mode)``,
    so ``net.eval()`` / ``net.train(False)`` do not toggle the training flag.
    The name is kept because callers rely on it.
    """

    def __init__(self, dims):
        """Create one ``Layer`` per consecutive pair of sizes in ``dims``."""
        super().__init__()
        # ModuleList (rather than a plain list) registers the layers as
        # sub-modules, so parameters(), state_dict() and .to() can see them.
        self.layers = torch.nn.ModuleList(
            [Layer(n_in, n_out).to(device)
             for n_in, n_out in zip(dims, dims[1:])]
        )

    def predict_goodness(self, x):
        """Return the mean squared activation, averaged over all layers.

        ``x`` is fed through the layers in order; the result is a single
        0-dim tensor computed without gradient tracking.
        """
        with torch.no_grad():
            total = 0
            for layer in self.layers:
                x = layer(x)
                total = total + x.pow(2).mean()
            return total / len(self.layers)

    def predict(self, x):
        """Return, for each row of ``x``, the label with the highest goodness.

        Each of the 10 candidate labels is overlaid on ``x`` in turn, the
        per-row mean squared activations are summed over the layers, and
        the arg-max over labels is returned.
        """
        with torch.no_grad():
            goodness_per_label = []
            for label in range(10):
                h = overlay_y_on_x(x, label)
                goodness = 0
                for layer in self.layers:
                    h = layer(h)
                    goodness = goodness + h.pow(2).mean(1)
                goodness_per_label.append(goodness)
            # One column per candidate label.
            return torch.stack(goodness_per_label, dim=1).argmax(1)

    def train(self, x, pos=True, verbose=False):
        """Train every layer in order on ``x``.

        Each layer is trained on the detached output that the previous
        layer returns from its own ``train`` call.

        Args:
            x: input batch for the first layer.
            pos: True for a positive sample, False for a negative one.
            verbose: forwarded to each layer's ``train``.
        """
        for layer in self.layers:
            x = layer.train(x, pos, verbose)


class Layer(nn.Linear):
    """Linear + ReLU layer that is trained locally on its own "goodness".

    Goodness is the mean of the squared activations. ``train`` pushes it
    towards a large target for positive samples and a small one for
    negative samples, using a learning rate derived from the distance
    between the current and the target goodness.

    NOTE(review): ``train`` shadows ``torch.nn.Module.train(mode)``, so
    ``layer.eval()`` / ``layer.train(False)`` do not toggle the training
    flag. The name is kept because callers rely on it.
    """

    # Lower bound applied to the measured goodness before it is used as a
    # divisor, so a layer whose activations are all zero cannot trigger a
    # ZeroDivisionError.
    _GOODNESS_FLOOR = 1e-12

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        # Offset subtracted from the goodness inside the log-sigmoid loss.
        self.threshold = 5
        # Number of update steps performed by one call to train().
        self.num_epochs = 10
        # Last measured goodness (float) and (mean, std) of the normalised
        # activations; both are filled in by train().
        self.goodness = None
        self.direction = None

    def forward(self, x):
        """L2-normalise each row of ``x``, then apply the linear map and ReLU.

        ``normalize`` clamps the norm from below, so an all-zero row maps to
        zeros instead of NaN.
        """
        x_dir = nn.functional.normalize(x, p=2, dim=1)
        return self.relu(super().forward(x_dir))

    def train(self, x, pos=True, verbose=False):
        """Run ``num_epochs`` SGD steps on ``x`` and return the new output.

        Args:
            x: 2-D input batch (one sample per row); any batch size works.
            pos: True targets a goodness of 1e4, False a goodness of 1e-4.
            verbose: print goodness and direction after every step.

        Returns:
            The detached output of ``forward(x)`` after the last step.
        """
        # The target does not change between steps, so set it once.
        self.required_goodness = 1e4 if pos else 1e-4

        for _ in range(self.num_epochs):
            g = self.forward(x).pow(2).mean(1)
            # Average over the batch so batches larger than 1 are supported.
            self.goodness = g.mean().item()

            current = max(self.goodness, self._GOODNESS_FLOOR)
            lr = (self.required_goodness / current) ** 0.5 - 1
            # The sign of lr selects the direction: ascend while below the
            # target goodness, descend while above it. Plain SGD keeps no
            # state, so rebuilding it each step only changes lr/maximize.
            self.opt = torch.optim.SGD(self.parameters(),
                                       lr=abs(lr), maximize=lr > 0)

            # logsigmoid is the numerically stable form of log(sigmoid(.)).
            loss = nn.functional.logsigmoid(g - self.threshold).mean()

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            with torch.no_grad():
                out = self.forward(x)
                self.goodness = out.pow(2).mean(1).mean().item()
                unit = nn.functional.normalize(out, p=2, dim=1)
                self.direction = (unit.mean().item(), unit.std().item())
            if verbose:
                print(f"Goodness : {self.goodness} , Direction : {self.direction}")

        with torch.no_grad():
            return self.forward(x)

    
# x_te, y_te = next(iter(test_loader))
# x_te, y_te = x_te.cuda(), y_te.cuda()
# preds = net.predict(x_te)
# print("Test Accuracy :",(preds==y_te).float().mean().item())

### Issue 1: Predictions stagnate on a single value

In [21]:
# Train a fresh network on samples 4..9, one positive/negative pair at a
# time, and compare the prediction for each sample with its true label.
net = Net([784, 500, 500])
for i in range(4, 10):
    one_sample = slice(i, i + 1)  # keeps the batch dimension
    net.train(x_pos[one_sample], pos=True)
    net.train(x_neg[one_sample], pos=False)
    pred = net.predict(x[one_sample])
    print("Prediction:", pred.item(), "Actual Value", y[i].item())


# net.train(x_neg,pos=False)
# preds = net.predict(x)
# print("Train Accuracy :",(preds==y[:1]).float().mean().item())


Prediction: 3 Actual Value 3
Prediction: 3 Actual Value 2
Prediction: 3 Actual Value 8
Prediction: 3 Actual Value 7
Prediction: 3 Actual Value 0
Prediction: 3 Actual Value 9


### Issue 2: Does not recognize a negative sample (after a positive sample has been passed)

In [25]:
# Train a fresh network on the first positive sample and then on the first
# negative sample, printing goodness and direction after every step.
net = Net([784, 500, 500])
net.train(x_pos[:1], pos=True, verbose=True)
print()
net.train(x_neg[:1], pos=False, verbose=True)

Goodness : 0.6635534167289734 , Direction : (0.02447320893406868, 0.03746824711561203)
Goodness : 2.553041934967041 , Direction : (0.02447321265935898, 0.03746825084090233)
Goodness : 5.3932204246521 , Direction : (0.02447321079671383, 0.03746824711561203)
Goodness : 6.95481538772583 , Direction : (0.02447321265935898, 0.03746825084090233)
Goodness : 7.473690509796143 , Direction : (0.02447320893406868, 0.03746824711561203)
Goodness : 7.8080220222473145 , Direction : (0.02447320893406868, 0.03746824711561203)
Goodness : 8.057228088378906 , Direction : (0.02447321265935898, 0.03746824711561203)
Goodness : 8.256607055664062 , Direction : (0.02447320893406868, 0.03746824711561203)
Goodness : 8.42305850982666 , Direction : (0.02447321265935898, 0.03746825084090233)
Goodness : 8.566062927246094 , Direction : (0.02447321265935898, 0.03746824711561203)
Goodness : 0.6757698655128479 , Direction : (0.02679372765123844, 0.035842228680849075)
Goodness : 2.5763633251190186 , Direction : (0.0267937

# Advantage: Fast adaptation

In [26]:
# Positive sample: train a fresh network on the first positive sample only.
net = Net([784, 500, 500])
net.train(x_pos[:1], pos=True, verbose=True)

Goodness : 0.6646897196769714 , Direction : (0.02463086135685444, 0.037364594638347626)
Goodness : 2.555216073989868 , Direction : (0.02463085949420929, 0.037364594638347626)
Goodness : 5.395772933959961 , Direction : (0.02463086135685444, 0.037364598363637924)
Goodness : 6.955174922943115 , Direction : (0.02463085949420929, 0.037364594638347626)
Goodness : 7.473896026611328 , Direction : (0.02463085949420929, 0.037364594638347626)
Goodness : 7.808168888092041 , Direction : (0.02463086135685444, 0.037364594638347626)
Goodness : 8.057342529296875 , Direction : (0.02463086135685444, 0.037364594638347626)
Goodness : 8.25670051574707 , Direction : (0.02463086135685444, 0.037364598363637924)
Goodness : 8.423139572143555 , Direction : (0.02463085949420929, 0.037364594638347626)
Goodness : 8.566130638122559 , Direction : (0.02463086135685444, 0.037364598363637924)
Goodness : 0.6734951734542847 , Direction : (0.026480956003069878, 0.03607439249753952)
Goodness : 2.5720293521881104 , Direction 

In [27]:
# Negative sample: train a fresh network on the first negative sample only.
net = Net([784, 500, 500])
net.train(x_neg[:1], pos=False, verbose=True)

Goodness : 0.00043666508281603456 , Direction : (0.025418071076273918, 0.036832522600889206)
Goodness : 0.00043305379222147167 , Direction : (0.02541806921362877, 0.03683251887559891)
Goodness : 0.00042948604095727205 , Direction : (0.025418074801564217, 0.036832526326179504)
Goodness : 0.0004259611596353352 , Direction : (0.025418072938919067, 0.036832526326179504)
Goodness : 0.00042247865349054337 , Direction : (0.025418072938919067, 0.036832526326179504)
Goodness : 0.00041903796955011785 , Direction : (0.025418072938919067, 0.036832522600889206)
Goodness : 0.0004156385257374495 , Direction : (0.025418071076273918, 0.036832522600889206)
Goodness : 0.0004122797981835902 , Direction : (0.025418074801564217, 0.036832522600889206)
Goodness : 0.00040896114660426974 , Direction : (0.025418074801564217, 0.036832526326179504)
Goodness : 0.0004056821344420314 , Direction : (0.02541806921362877, 0.036832522600889206)
Goodness : 0.0006938425358384848 , Direction : (0.02524782344698906, 0.036949