In [1]:
"""Utility classes for NICE.
"""
import argparse
import torch, torchvision
import numpy as np
import nice, utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import numpy as np
from torch.utils.data import TensorDataset

def save_list(data, path):
    with open(path, "wb") as fp:   #Pickling
        pickle.dump(data, fp)
        
def load_list(path):
    with open(path, "rb") as fp:   # Unpickling
        b = pickle.load(fp)
        return b

"""Additive coupling layer.
"""
class Coupling(nn.Module):
    def __init__(self, in_out_dim, mid_dim, hidden, mask_config):
        """Initialize a coupling layer.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(Coupling, self).__init__()
        self.mask_config = mask_config

        self.in_block = nn.Sequential(
            nn.Linear(in_out_dim//2, mid_dim),
            nn.ReLU())
        self.mid_block = nn.ModuleList([
            nn.Sequential(
                nn.Linear(mid_dim, mid_dim),
                nn.ReLU()) for _ in range(hidden - 1)])
        self.out_block = nn.Linear(mid_dim, in_out_dim//2)

    def forward(self, x, reverse=False):
        """Forward pass.

        Args:
            x: input tensor.
            reverse: True in inference mode, False in sampling mode.
        Returns:
            transformed tensor.
        """
        [B, W] = list(x.size())
        x = x.reshape((B, W//2, 2))
        if self.mask_config:
            on, off = x[:, :, 0], x[:, :, 1]
        else:
            off, on = x[:, :, 0], x[:, :, 1]

        off_ = self.in_block(off)
        for i in range(len(self.mid_block)):
            off_ = self.mid_block[i](off_)
        shift = self.out_block(off_)
        if reverse:
            on = on - shift
        else:
            on = on + shift

        if self.mask_config:
            x = torch.stack((on, off), dim=2)
        else:
            x = torch.stack((off, on), dim=2)
        return x.reshape((B, W))

"""Log-scaling layer.
"""
class Scaling(nn.Module):
    def __init__(self, dim):
        """Initialize a (log-)scaling layer.

        Args:
            dim: input/output dimensions.
        """
        super(Scaling, self).__init__()
        self.scale = nn.Parameter(
            torch.zeros((1, dim)), requires_grad=True)

    def forward(self, x, reverse=False):
        """Forward pass.

        Args:
            x: input tensor.
            reverse: True in inference mode, False in sampling mode.
        Returns:
            transformed tensor and log-determinant of Jacobian.
        """
        log_det_J = torch.sum(self.scale)
        if reverse:
            x = x * torch.exp(-self.scale)
        else:
            x = x * torch.exp(self.scale)
        return x, log_det_J

"""NICE main model.
"""
class NICE(nn.Module):
    def __init__(self, prior, coupling, 
        fc_in_dim, in_out_dim, mid_dim, hidden, mask_config):
        """Initialize a NICE.

        Args:
            prior: prior distribution over latent space Z.
            coupling: number of coupling layers.
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(NICE, self).__init__()
        self.prior = prior
        self.in_out_dim = in_out_dim
        self.fc_in_dim = fc_in_dim
        self.coupling = nn.ModuleList([
            Coupling(in_out_dim=in_out_dim, 
                     mid_dim=mid_dim, 
                     hidden=hidden, 
                     mask_config=(mask_config+i)%2) \
            for i in range(coupling)])
        self.scaling = Scaling(in_out_dim)
        self.fc = nn.Sequential(
            nn.Linear(fc_in_dim, 128),
            nn.LeakyReLU(),
            nn.Linear(128, in_out_dim)
        )

    def g(self, z):
        """Transformation g: Z -> X (inverse of f).

        Args:
            z: tensor in latent space Z.
        Returns:
            transformed tensor in data space X.
        """
        x, _ = self.scaling(z, reverse=True)
        for i in reversed(range(len(self.coupling))):
            x = self.coupling[i](x, reverse=True)
        return x

    def f(self, x):
        """Transformation f: X -> Z (inverse of g).

        Args:
            x: tensor in data space X.
        Returns:
            transformed tensor in latent space Z.
        """
        for i in range(len(self.coupling)):
            x = self.coupling[i](x)
        return self.scaling(x)

    def log_prob(self, x):
        """Computes data log-likelihood.

        (See Section 3.3 in the NICE paper.)

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
#         semantic_z = self.fc(x)
        z, log_det_J = self.f(x)
        log_ll = torch.sum(self.prior.log_prob(z), dim=1)
        return log_ll + log_det_J

    def sample(self, size):
        """Generates samples.

        Args:
            size: number of samples to generate.
        Returns:
            samples from the data space X.
        """
        z = self.prior.sample((size, self.in_out_dim)).cuda()
        return self.g(z)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input minibatch.
        Returns:
            log-likelihood of input.
        """
        return self.log_prob(x)
    
class Net(nn.Module):
    def __init__(self):
        """Initialize a coupling layer.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 2, stride=3)
        self.conv2 = nn.Conv2d(16, 32, 2, stride=3)
        self.fc1 = nn.Linear(32 * 3 *3, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x, reverse=False):
        x = self.conv1(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        x = F.leaky_relu(x)
        x = x.reshape(-1, 32 * 3 *3)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        embd = x
        x = self.fc2(x)
        x = F.softmax(x)
        return x, embd

In [2]:
device = torch.device("cuda:6")

# model hyperparameters
dataset = 'mnist'
batch_size = 200
latent = 'logistic'
max_iter = 15000
fc_in_dim = 128
sample_size = 64
coupling = 4
mask_config = 1.

# optimization hyperparameters
lr = 1e-3
momentum = 0.9
decay = 0.999

zca = None
mean = None
if dataset == 'mnist':
    mean = torch.load('./statistics/mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.MNIST(root='~/torch/data/MNIST',
        train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=True, num_workers=2)
elif dataset == 'emnist':
    mean = torch.load('./statistics/emnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    emnist_data = load_list('./emnist_test_data.pkl')
    emnist_label = load_list('./emnist_test_label.pkl')
    trainset = TensorDataset(torch.Tensor(emnist_data), torch.Tensor(emnist_label))
    trainloader = torch.utils.data.DataLoader(trainset,
                    batch_size=batch_size, shuffle=True, num_workers=2)
elif dataset == 'fashion-mnist':
    mean = torch.load('./statistics/fashion_mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
        train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=True, num_workers=2)

if latent == 'normal':
    prior = torch.distributions.Normal(
        torch.tensor(0.).to(device), torch.tensor(1.).to(device))
elif latent == 'logistic':
    prior = utils.StandardLogistic()

filename = '%s_' % dataset \
         + 'bs%d_' % batch_size \
         + '%s_' % latent \
         + 'cp%d_' % coupling \
         + 'md%d_' % mid_dim \
         + 'hd%d_' % hidden

flow = NICE(prior=prior, 
            coupling=coupling,
            fc_in_dim = fc_in_dim,
            in_out_dim=full_dim, 
            mid_dim=mid_dim, 
            hidden=hidden, 
            mask_config=mask_config).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    flow.parameters(), lr=lr, betas=(momentum, decay), eps=1e-4)

model = Net().to(device)
model.load_state_dict(torch.load('./%s_embd_model.pt'%dataset))
model = model.to(device)
model.eval()

total_iter = 0
train = True
running_loss = 0

while train:
    for _, data in enumerate(trainloader, 1):
        flow.train()    # set to training mode
        if total_iter == max_iter:
            train = False
            break

        total_iter += 1
        optimizer.zero_grad()    # clear gradient tensors

        inputs, target = data
        inputs = utils.prepare_data(
            inputs, dataset, zca=zca, mean=mean).to(device)
        inputs = inputs.reshape(-1, 1, 28, 28)
        _, inputs = model(inputs)
        target = target.long().to(device)

        # log-likelihood of input minibatch
        loss_output= flow(inputs)
#         if dataset == 'emnist':
#             loss = -loss_output.mean() + criterion(F.softmax(sematic), target-1)
#         else:
#             loss = -loss_output.mean() + criterion(F.softmax(sematic), target)
        loss = -loss_output.mean()
        running_loss += float(loss)

        # backprop and update parameters
        loss.backward()
        optimizer.step()

        if total_iter % 1000 == 0:
            mean_loss = running_loss / 1000
            bit_per_dim = (mean_loss + np.log(256.) * full_dim) \
                        / (full_dim * np.log(2.))
            print('iter %s:' % total_iter, 
                'loss = %.3f' % mean_loss, 
                'bits/dim = %.3f' % bit_per_dim)
            running_loss = 0.0

#             flow.eval()        # set to inference mode
#             with torch.no_grad():
#                 z, _ = flow.f(inputs)
#                 reconst = flow.g(z).cpu()
#                 reconst = utils.prepare_data(
#                     reconst, dataset, zca=zca, mean=mean, reverse=True)
#                 samples = flow.sample(sample_size).cpu()
#                 samples = utils.prepare_data(
#                     samples, dataset, zca=zca, mean=mean, reverse=True)
#                 torchvision.utils.save_image(torchvision.utils.make_grid(reconst),
#                     './reconstruction/' + filename +'iter%d.png' % total_iter)
#                 torchvision.utils.save_image(torchvision.utils.make_grid(samples),
#                     './samples/' + filename +'iter%d.png' % total_iter)

print('Finished training!')

torch.save({
    'total_iter': total_iter, 
    'model_state_dict': flow.state_dict(), 
    'optimizer_state_dict': optimizer.state_dict(), 
    'dataset': dataset, 
    'batch_size': batch_size, 
    'latent': latent, 
    'coupling': coupling, 
    'mid_dim': mid_dim, 
    'hidden': hidden, 
    'mask_config': mask_config}, 
    './models_aux/%s/'%dataset + filename +'iter%d.tar' % total_iter)

print('Checkpoint Saved')



iter 1000: loss = 216.652 bits/dim = 10.442
iter 2000: loss = 167.700 bits/dim = 9.890
iter 3000: loss = 148.281 bits/dim = 9.671
iter 4000: loss = 135.019 bits/dim = 9.522
iter 5000: loss = 125.547 bits/dim = 9.415
iter 6000: loss = 118.176 bits/dim = 9.332
iter 7000: loss = 111.626 bits/dim = 9.258
iter 8000: loss = 107.021 bits/dim = 9.206
iter 9000: loss = 103.130 bits/dim = 9.162
iter 10000: loss = 99.165 bits/dim = 9.118
iter 11000: loss = 96.443 bits/dim = 9.087
iter 12000: loss = 93.926 bits/dim = 9.059
iter 13000: loss = 91.408 bits/dim = 9.030
iter 14000: loss = 89.516 bits/dim = 9.009
iter 15000: loss = 87.820 bits/dim = 8.990
Finished training!
Checkpoint Saved


# Likelihood Test

In [9]:
device = torch.device("cuda:0")

# model hyperparameters
dataset = 'emnist'
trained_dataset = 'mnist'
batch_size = 200
latent = 'logistic'
max_iter = 15000
fc_in_dim = 128
sample_size = 64
coupling = 4
mask_config = 1.

# optimization hyperparameters
lr = 1e-3
momentum = 0.9
decay = 0.999

zca = None
mean = None
if dataset == 'mnist':
    mean = torch.load('./statistics/mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.MNIST(root='~/torch/data/MNIST',
        train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=False, num_workers=2)
elif dataset == 'emnist':
    mean = torch.load('./statistics/emnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    emnist_data = load_list('./emnist_test_data.pkl')
    emnist_label = load_list('./emnist_test_label.pkl')
    trainset = TensorDataset(torch.Tensor(emnist_data), torch.Tensor(emnist_label))
    trainloader = torch.utils.data.DataLoader(trainset,
                    batch_size=batch_size, shuffle=False, num_workers=2)
elif dataset == 'fashion-mnist':
    mean = torch.load('./statistics/fashion_mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (128, 500, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
        train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=False, num_workers=2)

if latent == 'normal':
    prior = torch.distributions.Normal(
        torch.tensor(0.).to(device), torch.tensor(1.).to(device))
elif latent == 'logistic':
    prior = utils.StandardLogistic()

filename = '%s_' % trained_dataset \
         + 'bs%d_' % batch_size \
         + '%s_' % latent \
         + 'cp%d_' % coupling \
         + 'md%d_' % mid_dim \
         + 'hd%d_' % hidden

flow = NICE(prior=prior, 
            coupling=coupling,
            fc_in_dim = fc_in_dim,
            in_out_dim=full_dim, 
            mid_dim=mid_dim, 
            hidden=hidden, 
            mask_config=mask_config).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    flow.parameters(), lr=lr, betas=(momentum, decay), eps=1e-4)

total_iter = 0
train = True
running_loss = 0
loss_list = []

model = Net().to(device)
model.load_state_dict(torch.load('./%s_embd_model.pt'%trained_dataset))
model = model.to(device)
model.eval()

print('Model loading complete!')
print('./models_aux/%s/%siter%d.tar'%(trained_dataset, filename, max_iter))
checkpoints = torch.load('./models_aux/%s/%siter%d.tar'%(trained_dataset, filename, max_iter))
flow.load_state_dict(checkpoints['model_state_dict'])
print('Calculating log_prob')

for _, data in tqdm(enumerate(trainloader, 1)):
    flow.eval()
    if total_iter == max_iter:
        train = False
        break

    total_iter += 1
    optimizer.zero_grad()    # clear gradient tensors

    inputs, target = data
    inputs = utils.prepare_data(
        inputs, dataset, zca=None, mean=mean).to(device)
    inputs = inputs.reshape(-1, 1, 28, 28)
    _, inputs = model(inputs)
    target = target.long().to(device)

    # log-likelihood of input minibatch
    loss_output= flow(inputs)
    
    loss_list.append(-loss_output.cpu().data.numpy())
    
loss_list = np.array(loss_list).reshape(-1)
print('Finished!')

Model loading complete!
./models_aux/mnist/mnist_bs200_logistic_cp4_md500_hd5_iter15000.tar
Calculating log_prob


40it [00:00, 99.29it/s]

Finished!





In [10]:
loss_list

array([432.8667 , 392.5661 , 339.13953, ..., 487.11993, 266.9073 ,
       229.79628], dtype=float32)

In [11]:
save_list(loss_list, './results/aux/train_%s_test_%s_loss.pkl'%(trained_dataset, dataset))
load_list('./results/aux/train_%s_test_%s_loss.pkl'%(trained_dataset, dataset))

array([432.8667 , 392.5661 , 339.13953, ..., 487.11993, 266.9073 ,
       229.79628], dtype=float32)

In [5]:
dataset = 'emnist'
batch_size = 100
device = 'cuda:4'

if dataset == 'mnist':
    mean = torch.load('./statistics/mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.MNIST(root='~/torch/data/MNIST',
        train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=True, num_workers=2)
elif dataset == 'emnist':
    mean = torch.load('./statistics/emnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    emnist_data = load_list('./emnist_train_data.pkl')
    emnist_label = load_list('./emnist_train_label.pkl')
    trainset = TensorDataset(torch.Tensor(emnist_data), torch.Tensor(emnist_label))
    trainloader = torch.utils.data.DataLoader(trainset,
                batch_size=batch_size, shuffle=True, num_workers=2)
elif dataset == 'fashion-mnist':
    mean = torch.load('./statistics/fashion_mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
        train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=True, num_workers=2)

In [6]:
class Net(nn.Module):
    def __init__(self):
        """Initialize a coupling layer.

        Args:
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            mask_config: 1 if transform odd units, 0 if transform even units.
        """
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 2, stride=3)
        self.conv2 = nn.Conv2d(16, 32, 2, stride=3)
        self.fc1 = nn.Linear(32 * 3 *3, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x, reverse=False):
        x = self.conv1(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        x = F.leaky_relu(x)
        x = x.reshape(-1, 32 * 3 *3)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        embd = x
        x = self.fc2(x)
        x = F.softmax(x)
        return x, embd

In [7]:
model = Net().to(device)
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, eps=1e-4)
criterion = nn.CrossEntropyLoss()

In [8]:
for epoch in tqdm(range(50)):
    epoch_loss = 0
    for _, (inputs, target) in enumerate(trainloader):
        model.train()

        optimizer.zero_grad()    # clear gradient tensors
        inputs = utils.prepare_data(
            inputs, dataset, zca=None, mean=mean).to(device)
        inputs = inputs.reshape(-1, 1, 28, 28)
        target = target.long().to(device)

        # log-likelihood of input minibatch
        outputs, _ = model(inputs)

        loss = criterion(outputs, target-1)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print('epoch loss:', epoch_loss)
        
print('Finished!')

  2%|▏         | 1/50 [00:02<02:24,  2.95s/it]

epoch loss: 933.7865806818008


  4%|▍         | 2/50 [00:05<02:17,  2.86s/it]

epoch loss: 850.8915197849274


  6%|▌         | 3/50 [00:08<02:15,  2.88s/it]

epoch loss: 822.8712948560715


  8%|▊         | 4/50 [00:11<02:08,  2.80s/it]

epoch loss: 804.37983751297


 10%|█         | 5/50 [00:13<02:05,  2.79s/it]

epoch loss: 792.2115323543549


 12%|█▏        | 6/50 [00:16<01:57,  2.67s/it]

epoch loss: 783.8098410367966


 14%|█▍        | 7/50 [00:19<02:04,  2.89s/it]

epoch loss: 777.7055968046188


 16%|█▌        | 8/50 [00:22<01:57,  2.80s/it]

epoch loss: 773.0303585529327


 18%|█▊        | 9/50 [00:24<01:51,  2.73s/it]

epoch loss: 769.3047661781311


 20%|██        | 10/50 [00:27<01:48,  2.70s/it]

epoch loss: 766.1822824478149


 22%|██▏       | 11/50 [00:29<01:41,  2.59s/it]

epoch loss: 762.9142169952393


 24%|██▍       | 12/50 [00:32<01:37,  2.56s/it]

epoch loss: 760.3056217432022


 26%|██▌       | 13/50 [00:34<01:34,  2.55s/it]

epoch loss: 758.3159655332565


 28%|██▊       | 14/50 [00:37<01:34,  2.63s/it]

epoch loss: 756.2790558338165


 30%|███       | 15/50 [00:40<01:32,  2.64s/it]

epoch loss: 754.0227473974228


 32%|███▏      | 16/50 [00:42<01:29,  2.62s/it]

epoch loss: 752.7859129905701


 34%|███▍      | 17/50 [00:45<01:26,  2.63s/it]

epoch loss: 750.8165979385376


 36%|███▌      | 18/50 [00:48<01:24,  2.63s/it]

epoch loss: 749.4775792360306


 38%|███▊      | 19/50 [00:50<01:21,  2.64s/it]

epoch loss: 747.9158524274826


 40%|████      | 20/50 [00:53<01:21,  2.71s/it]

epoch loss: 746.5630462169647


 42%|████▏     | 21/50 [00:56<01:17,  2.66s/it]

epoch loss: 744.9250929355621


 44%|████▍     | 22/50 [00:58<01:14,  2.67s/it]

epoch loss: 743.7624547481537


 46%|████▌     | 23/50 [01:01<01:12,  2.68s/it]

epoch loss: 742.8824466466904


 48%|████▊     | 24/50 [01:04<01:09,  2.68s/it]

epoch loss: 741.8954775333405


 50%|█████     | 25/50 [01:07<01:07,  2.72s/it]

epoch loss: 740.6927067041397


 52%|█████▏    | 26/50 [01:09<01:05,  2.71s/it]

epoch loss: 739.6339681148529


 54%|█████▍    | 27/50 [01:12<01:02,  2.70s/it]

epoch loss: 738.7188880443573


 56%|█████▌    | 28/50 [01:15<01:00,  2.74s/it]

epoch loss: 737.5434273481369


 58%|█████▊    | 29/50 [01:18<00:57,  2.72s/it]

epoch loss: 736.8673624992371


 60%|██████    | 30/50 [01:20<00:54,  2.72s/it]

epoch loss: 736.4634869098663


 62%|██████▏   | 31/50 [01:23<00:51,  2.70s/it]

epoch loss: 735.9175111055374


 64%|██████▍   | 32/50 [01:25<00:48,  2.67s/it]

epoch loss: 735.0809383392334


 66%|██████▌   | 33/50 [01:28<00:44,  2.60s/it]

epoch loss: 733.9444080591202


 68%|██████▊   | 34/50 [01:31<00:42,  2.65s/it]

epoch loss: 733.4062818288803


 70%|███████   | 35/50 [01:33<00:39,  2.65s/it]

epoch loss: 732.9633536338806


 72%|███████▏  | 36/50 [01:36<00:37,  2.67s/it]

epoch loss: 732.2552486658096


 74%|███████▍  | 37/50 [01:39<00:34,  2.62s/it]

epoch loss: 731.9425407648087


 76%|███████▌  | 38/50 [01:41<00:31,  2.60s/it]

epoch loss: 731.7111604213715


 78%|███████▊  | 39/50 [01:44<00:28,  2.63s/it]

epoch loss: 731.0291463136673


 80%|████████  | 40/50 [01:46<00:25,  2.57s/it]

epoch loss: 730.5048418045044


 82%|████████▏ | 41/50 [01:49<00:24,  2.68s/it]

epoch loss: 729.976172208786


 84%|████████▍ | 42/50 [01:52<00:21,  2.64s/it]

epoch loss: 729.3653045892715


 86%|████████▌ | 43/50 [01:54<00:18,  2.64s/it]

epoch loss: 729.1404027938843


 88%|████████▊ | 44/50 [01:57<00:16,  2.68s/it]

epoch loss: 728.6273589134216


 90%|█████████ | 45/50 [02:00<00:14,  2.83s/it]

epoch loss: 728.3801941871643


 92%|█████████▏| 46/50 [02:03<00:11,  2.78s/it]

epoch loss: 728.1932671070099


 94%|█████████▍| 47/50 [02:06<00:08,  2.72s/it]

epoch loss: 727.8758239746094


 96%|█████████▌| 48/50 [02:08<00:05,  2.71s/it]

epoch loss: 727.5501205921173


 98%|█████████▊| 49/50 [02:11<00:02,  2.67s/it]

epoch loss: 727.2476817369461


100%|██████████| 50/50 [02:14<00:00,  2.68s/it]

epoch loss: 726.8566207885742
Finished!





In [9]:
torch.save(model.to('cpu').state_dict(), './emnist_embd_model.pt')

In [10]:
dataset = 'emnist'
batch_size = 100
device = 'cuda:6'

if dataset == 'mnist':
    mean = torch.load('./statistics/mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.MNIST(root='~/torch/data/MNIST',
        train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=False, num_workers=2)
elif dataset == 'emnist':
    mean = torch.load('./statistics/emnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    emnist_data = load_list('./emnist_test_data.pkl')
    emnist_label = load_list('./emnist_test_label.pkl')
    trainset = TensorDataset(torch.Tensor(emnist_data), torch.Tensor(emnist_label))
    trainloader = torch.utils.data.DataLoader(trainset,
                batch_size=batch_size, shuffle=False, num_workers=2)
elif dataset == 'fashion-mnist':
    mean = torch.load('./statistics/fashion_mnist_mean.pt')
    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)
    transform = torchvision.transforms.ToTensor()
    trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
        train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size, shuffle=False, num_workers=2)

In [11]:
model = Net().to(device)
model.load_state_dict(torch.load('./emnist_embd_model.pt'))
model = model.to(device)
model.eval()

Net(
  (conv1): Conv2d(1, 16, kernel_size=(2, 2), stride=(3, 3))
  (conv2): Conv2d(16, 32, kernel_size=(2, 2), stride=(3, 3))
  (fc1): Linear(in_features=288, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [12]:
correct = 0
for _, (inputs, target) in enumerate(trainloader):
    model.train()

    optimizer.zero_grad()    # clear gradient tensors
    inputs = utils.prepare_data(
        inputs, dataset, zca=None, mean=mean).to(device)
    inputs = inputs.reshape(-1, 1, 28, 28)
    target = target.long().to(device) -1

    # log-likelihood of input minibatch
    outputs, embd= model(inputs)

    pred = torch.argmax(outputs, axis=1).detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    correct += (pred == target).sum()
     
print(correct/len(trainloader.dataset))



0.891
