Train MNIST
-------------------------

In [None]:
%matplotlib inline
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import os
# Probe for a GPU, then deliberately override the result so the whole
# notebook runs on the CPU.
# NOTE(review): later cells feed DataLoader batches to the network without
# moving them to `device`, so re-enabling CUDA here needs those cells adapted.
cuda = torch.cuda.is_available()
cuda = False
device = 'cpu' if not cuda else 'cuda'

In [None]:
# Fetch the project checkout and make it the working directory
# (presumably the local `models`, `layers`, `loss`, `utils` and `optimization`
# modules imported further down live in it — confirm).
# The basename guard keeps the cell idempotent: re-running it while already
# inside the checkout would otherwise clone a nested copy and cd into that.
if os.path.basename(os.getcwd()) != "Representation_Learning":
    if not os.path.exists("./Representation_Learning"):
        !git clone https://github.com/Sibylse/Representation_Learning.git
    %cd Representation_Learning

In [None]:
# Update the checkout in the current working directory from its remote.
!git pull

In [None]:
from models import *
from layers import *
from loss import *
from utils import *
from optimization import Optimizer

Load the data

In [None]:
# Keep only the first `c` digit classes (labels 0 .. c-1) so the embedding
# plots stay readable. Set c = 10 to train on every digit.
c = 4
classes = tuple(str(digit) for digit in range(c))

# Data
print('==> Preparing data..')
# Normalize subtracts 0.5 and divides by 1.0, i.e. it only shifts the pixel values.
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5,), (1.0,))])

train_data = datasets.MNIST(root='./data', train=True,
                            download=True, transform=trans)

# Select only some classes for motivating picture
idx = train_data.targets < c
train_data.targets = train_data.targets[idx]
train_data.data = train_data.data[idx]
trainloader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)


testset = datasets.MNIST(root='./data', train=False,
                         download=True, transform=trans)
# Select only some classes for motivating picture
idx = testset.targets < c
testset.targets = testset.targets[idx]
testset.data = testset.data[idx]

# 65% / 35% test / validation split of the official MNIST test images.
# The generator is seeded so the same images land in each part on every run;
# an unseeded split made the reported test accuracy depend on the random draw.
test_size = int(0.65 * len(testset))
val_size = len(testset) - test_size
testset, valset = random_split(testset, [test_size, val_size],
                               generator=torch.Generator().manual_seed(0))

testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)


In [None]:
# d is the embedding dimension; `name` is the prefix of the checkpoint file name.
d = 2
name = "MNISTd{:d}".format(d)

# Train Deep Softmax Regression

In [None]:
# Model
print('==> Building model..')
# Linear head mapping the d-dimensional embedding to c class scores; it is
# handed to LeNet as its classifier.
classifier = nn.Linear(in_features=d, out_features=c, bias=True)
net = LeNet(embedding_dim=d, classifier=classifier).to(device)

criterion = CE_Loss(c, device)

# One parameter group holding every network parameter. The learning rate set
# here is overwritten by the training cell before each phase.
sgd = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer = Optimizer(sgd, trainloader, device)

In [None]:
# Training schedule: a sequence of (step size, number of epochs) phases.
# epoch_offset keeps the epoch counter running on across phases.
epoch_offset = 0
for lr, max_epochs in ((0.05, 10), (0.01, 10)):
    optimizer.optimizer.param_groups[0]['lr'] = lr
    print("===== Optimize with step size ", lr)
    for epoch in range(epoch_offset, epoch_offset + max_epochs):
        print('\nEpoch: %d' % epoch)
        optimizer.train_epoch(net, criterion)
        acc, conf = optimizer.test_acc(net, criterion, testloader)
        if epoch % 5 != 0:
            continue
        # Every fifth epoch: plot the current state on the test loader.
        with torch.no_grad():
            plot_epoch(net, testloader, device, figsize=(10, 5))
    epoch_offset += max_epochs

print('Saving..')
# `acc` is the value returned by the last test_acc call above.
state = {'net': net.state_dict(), 'acc': acc}
os.makedirs('checkpoint', exist_ok=True)
checkpoint_parts = (name, type(net).__name__, type(net.classifier).__name__)
torch.save(state, './checkpoint/%s%s%s.t7' % checkpoint_parts)
                

In [None]:
# Sanity check: draw one training mini-batch and display the shapes of the
# image tensor and the label tensor.
inputs, targets = next(iter(trainloader))
tuple(t.shape for t in (inputs, targets))

# Load and Inspect Models

In [None]:
# List the checkpoint files written by the training cells.
!ls checkpoint

In [None]:
# Rebuild the softmax architecture (same d and c as at training time) and
# restore the saved weights into it.
# NOTE(review): load_net presumably resolves the file name relative to
# ./checkpoint — confirm against utils.
classifier = nn.Linear(in_features=d, out_features=c, bias=True)
net_sm = load_net('MNISTd2LeNetLinear.t7',
                  LeNet(embedding_dim=d, classifier=classifier)).to(device)

In [None]:
# Build the loss with the same arguments as every other CE_Loss call in this
# notebook: (number of classes, device). The previous call passed
# (net_sm.classifier, c), which does not match that signature.
criterion_sm = CE_Loss(c, device)

In [None]:
# Evaluate the restored softmax network on the test loader.
# Reuses the `optimizer` wrapper created in the training section.
acc, conf = optimizer.test_acc(net_sm, criterion_sm, testloader)

# Plot the test data

In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
def scatter_pictures(inputs, outputs, samples=30, ax=None):
    """Draw input images as thumbnails at their embedding coordinates.

    Args:
        inputs: batch of images, indexed as ``inputs[j, :, :, :]``; each image
            is squeezed before display (assumes single-channel images, so the
            squeeze yields a 2-D array — TODO confirm).
        outputs: per-image coordinates; columns 0 and 1 are used as x and y.
        samples: maximum number of images to draw.
        ax: matplotlib Axes to draw on. When omitted, the current axes
            (``plt.gca()``) are used, so the function no longer depends on a
            global variable that happens to be called ``ax``.
    """
    if ax is None:
        ax = plt.gca()
    zoom = 0.7

    n_shown = min(inputs.shape[0], samples)
    for j in range(n_shown):
        image = inputs[j, :, :, :].squeeze()
        thumbnail = OffsetImage(image, cmap="gray", zoom=zoom)
        # Anchor the thumbnail in data coordinates, without a frame.
        box = AnnotationBbox(thumbnail, (outputs[j, 0], outputs[j, 1]),
                             xycoords='data', frameon=False, alpha=0.5)
        ax.add_artist(box)

In [None]:
# Embed one test batch with the softmax network, draw the classifier's
# softmax output over the embedding plane, then overlay the digit images at
# their embedding coordinates.
inputs, targets = next(iter(testloader))
outputs = net_sm.embed(inputs).detach()
fig, ax = plt.subplots(figsize=(8, 5))
# Plot window: the batch's extent, padded by 3 below/left and 5 above/right.
plot_conf(
    lambda x: torch.softmax(net_sm.classifier(x), dim=1),
    x_min=outputs[:, 0].min() - 3, x_max=outputs[:, 0].max() + 5,
    y_min=outputs[:, 1].min() - 3, y_max=outputs[:, 1].max() + 5,
)
scatter_pictures(inputs, outputs, samples=100)

# Apply QDA Layer (known from DDU model)

In [None]:
import copy

# Start from the trained softmax network and replace its linear head with a
# Gauss_DDU classifier fitted on the training embeddings.
classifier = nn.Linear(d, c, bias=True)
net_sm = LeNet(embedding_dim=d, classifier=classifier)
net_sm = load_net('MNISTd2LeNetLinear.t7', net_sm).to(device)
net_ddu = copy.deepcopy(net_sm)  # net_sm itself stays unchanged
embeddings, labels = gather_embeddings(net_ddu, d, trainloader,
                                       device=device, storage_device=device)

classifier = Gauss_DDU(d, c, gamma=1)
classifier.fit(embeddings, labels)
net_ddu.classifier = classifier
with torch.no_grad():
    # Rescale gamma by -log(0.99995) divided by the smallest absolute
    # classifier output over the training embeddings.
    # NOTE(review): presumably this calibrates the confidence of the closest
    # training point to 0.99995 — confirm against Gauss_DDU. A min_dist of 0
    # would make gamma infinite.
    min_dist = classifier(embeddings).abs().min().item()
    classifier.gamma.data = classifier.gamma.data * (-np.log(0.99995) / min_dist)

# The optimizer wrapper is only used for evaluation here. It is built from
# net_ddu so this cell no longer depends on the separately trained `net`.
sgd = optim.SGD([{'params': net_ddu.parameters()}],
                lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer = Optimizer(sgd, trainloader, device)
(acc, conf) = optimizer.test_acc(net_ddu, CE_Loss(c, device), testloader)

# Save the DDU model under its own name. Using `net` here stored the wrong
# weights and reused the softmax checkpoint's file name, overwriting it.
state = {'net': net_ddu.state_dict(), 'acc': acc}
os.makedirs('checkpoint', exist_ok=True)
f = 'checkpoint/%s%s%s.t7' % (name, net_ddu.__class__.__name__,
                              net_ddu.classifier.__class__.__name__)
torch.save(state, './' + f)
print('Saved as ' + f)

In [None]:
# Display the gamma of the classifier currently attached to net_ddu.
net_ddu.classifier.gamma

In [None]:
# Embed one test batch with the DDU network, draw the classifier's `conf`
# over the embedding plane, then overlay the digit images.
inputs, targets = next(iter(testloader))
outputs = net_ddu.embed(inputs).detach()
fig, ax = plt.subplots(figsize=(8, 5))
# Plot window: the batch's extent, padded by 3 below/left and 5 above/right.
plot_conf(
    net_ddu.classifier.conf,
    x_min=outputs[:, 0].min() - 3, x_max=outputs[:, 0].max() + 5,
    y_min=outputs[:, 1].min() - 3, y_max=outputs[:, 1].max() + 5,
)
scatter_pictures(inputs, outputs, samples=100)

# Where do random noise pictures land in the embedding?

In [None]:
# Uniform noise shifted by -0.5, with the same shape as the most recently
# drawn image batch, embedded with the DDU network.
noise_inputs = torch.rand_like(inputs) - 0.5
noise_outputs = net_ddu.embed(noise_inputs).detach()
fig, ax = plt.subplots(figsize=(8, 5))
# Plot window: the noise embeddings' extent, padded by 3 below/left and 5
# above/right.
plot_conf(
    net_ddu.classifier.conf,
    x_min=noise_outputs[:, 0].min() - 3, x_max=noise_outputs[:, 0].max() + 5,
    y_min=noise_outputs[:, 1].min() - 3, y_max=noise_outputs[:, 1].max() + 5,
)
scatter_pictures(noise_inputs, noise_outputs, samples=100)

# What about more structured data?

In [None]:
# NOTE(review): this cell rebinds `c`, `classes`, `trans`, `train_data` and
# `testset`. Every later cell (and any earlier cell that is re-run) sees
# c = 10, while `trainloader` / `testloader` still hold the class-filtered
# MNIST data — confirm that this is intended.
c = 10
classes = tuple(str(label) for label in range(c))

# Data
print('==> Preparing data..')
# Same preprocessing as for MNIST: to tensor, then shift by 0.5.
to_tensor = transforms.ToTensor()
shift = transforms.Normalize((0.5,), (1.0,))
trans = transforms.Compose([to_tensor, shift])

train_data = datasets.FashionMNIST(root='./data', train=True,
                                   download=True, transform=trans)
trainloader_fashion = torch.utils.data.DataLoader(
    train_data, batch_size=128, shuffle=True, num_workers=2)


testset = datasets.FashionMNIST(root='./data', train=False,
                                download=True, transform=trans)
testloader_fashion = DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=2)

In [None]:
# Embed one FashionMNIST test batch with the MNIST-trained DDU network, draw
# the classifier's `conf` over the embedding plane, then overlay the images.
inputs, targets = next(iter(testloader_fashion))
outputs = net_ddu.embed(inputs).detach()
fig, ax = plt.subplots(figsize=(8, 5))
# Plot window: the batch's extent, padded by 3 below/left and 5 above/right.
plot_conf(
    net_ddu.classifier.conf,
    x_min=outputs[:, 0].min() - 3, x_max=outputs[:, 0].max() + 5,
    y_min=outputs[:, 1].min() - 3, y_max=outputs[:, 1].max() + 5,
)
scatter_pictures(inputs, outputs, samples=100)

# Visualizing Higher Dimensional Embeddings 

In [None]:
# d is the embedding dimension; `name` is the prefix of the checkpoint file name.
d = 8
name = "MNISTd{:d}".format(d)

In [None]:
# Model
print('==> Building model..')
# Linear head mapping the d-dimensional embedding to c class scores; it is
# handed to LeNet as its classifier.
classifier = nn.Linear(in_features=d, out_features=c, bias=True)
net = LeNet(embedding_dim=d, classifier=classifier).to(device)

criterion = CE_Loss(c, device)

# One parameter group holding every network parameter. The learning rate set
# here is overwritten by the training cell before each phase.
sgd = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer = Optimizer(sgd, trainloader, device)

In [None]:
# Training schedule: a sequence of (step size, number of epochs) phases.
# epoch_offset keeps the epoch counter running on across phases.
# NOTE(review): both phases use the same step size (0.001) — confirm that no
# decay was intended.
epoch_offset = 0
for lr, max_epochs in ((0.001, 20), (0.001, 10)):
    optimizer.optimizer.param_groups[0]['lr'] = lr
    print("===== Optimize with step size ", lr)
    for epoch in range(epoch_offset, epoch_offset + max_epochs):
        print('\nEpoch: %d' % epoch)
        optimizer.train_epoch(net, criterion)
        acc, conf = optimizer.test_acc(net, criterion, testloader)
        if epoch % 5 != 0:
            continue
        # Every fifth epoch: plot the current state on the test loader.
        with torch.no_grad():
            plot_epoch(net, testloader, device, figsize=(10, 5))
    epoch_offset += max_epochs

print('Saving..')
# `acc` is the value returned by the last test_acc call above.
state = {'net': net.state_dict(), 'acc': acc}
os.makedirs('checkpoint', exist_ok=True)
checkpoint_parts = (name, type(net).__name__, type(net.classifier).__name__)
torch.save(state, './checkpoint/%s%s%s.t7' % checkpoint_parts)
                