In [0]:
import numpy as np
import torch
import torch.nn as nn
import copy
import itertools
from IPython.display import clear_output
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import torchvision

torch.manual_seed(1337)
np.random.seed(1337)

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    torch.cuda.manual_seed(1337)
else:
    DEVICE = None

In [0]:
X_DIM = 28 * 28
Y_DIM = 10
Z_DIM = 50
HIDDEN_DIM1 = 600
HIDDEN_DIM2 = 500
INIT_VAR = 0.001


def ohe_convert(y):
    res = torch.zeros(len(y), Y_DIM)
    res[torch.arange(len(y)), y.squeeze()] = 1
    return res.to(y.device)


class M2(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.encoder_y = torch.nn.Sequential(
            nn.Linear(X_DIM, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, Y_DIM)
        )

        self.encoder_z = torch.nn.Sequential(
            nn.Linear(X_DIM + Y_DIM, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, Z_DIM * 2)
        )

        self.decoder = nn.Sequential(
            nn.Linear(Y_DIM + Z_DIM, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, HIDDEN_DIM2),
            nn.Softplus(),
            nn.Linear(HIDDEN_DIM2, X_DIM)
        )

        for p in self.parameters():
            if p.ndim == 1:
                p.data.fill_(0)
            else:
                p.data.normal_(0, INIT_VAR)

                
        self.p_z = torch.distributions.Normal(
            torch.zeros(1, device=device), torch.ones(1, device=device)
        )
        self.p_y = torch.distributions.OneHotCategorical(
            probs=torch.ones((1, Y_DIM), device=device) / Y_DIM)


    def forward(self, x):
        probs = self.encode_y(x).probs
        return probs.max(dim=1)[1]


    def encode_y(self, x):
        return torch.distributions.OneHotCategorical(logits=self.encoder_y(x))


    def encode_z(self, x, y):
        res = self.encoder_z(torch.cat([x, y], axis=1))
        means_z, logsigma_z = torch.chunk(res, 2, dim=-1)
        return torch.distributions.Normal(means_z, torch.exp(logsigma_z))


    def decode(self, y, z):
        return torch.distributions.Bernoulli(
            logits=self.decoder(torch.cat([y, z], axis=1)))

In [0]:
LR = 0.0003
BETA1 = 0.9
BETA2 = 0.999
ALPHA = 0.1

def train_M2(model: M2, dl_labeled, dl_unlabeled, dl_test, n_epochs, device):
    def loss_func(x, y, z, p_x_yz, q_z_xy):
        return p_x_yz.log_prob(x).sum(1) + \
            model.p_y.log_prob(y) + \
                model.p_z.log_prob(z).sum(1) - \
                    q_z_xy.log_prob(z).sum(1)

    opt = torch.optim.Adam(model.parameters(), lr=LR, betas=(BETA1, BETA2))
    n_batches = len(dl_labeled) + len(dl_unlabeled)


    unlabeled_per_labeled = len(dl_unlabeled) // len(dl_labeled) + 1
    train_loss_log = []
    val_acc_log = []

    for epoch in tqdm(range(n_epochs)):
        model.train()
        labeled_i = unlabeled_i = 0
        dl_labeled_iterable = iter(dl_labeled)
        dl_unlabeled_iterable = iter(dl_unlabeled)
        for batch_id in range(n_batches):
            unsupervised = bool(batch_id % unlabeled_per_labeled)
            if not unsupervised and labeled_i == len(dl_labeled):
                unsupervised = True
            if unsupervised:
                unlabeled_i += 1
                x, _ = next(dl_unlabeled_iterable)
                x = x.view(x.shape[0], -1).to(device)
                q_y = model.encode_y(x)
                loss = - q_y.entropy()
                for y in q_y.enumerate_support():
                    q_z_xy = model.encode_z(x, y)
                    z = q_z_xy.rsample()
                    p_x_yz = model.decode(y, z)
                    L = loss_func(x, y, z, p_x_yz, q_z_xy)
                    loss += q_y.log_prob(y).exp() * (-L)
                
            else:
                labeled_i += 1
                x, y = next(dl_labeled_iterable)
                x = x.view(x.shape[0], -1).to(device)
                y = ohe_convert(y).to(device)

                q_y = model.encode_y(x)

                q_z_xy = model.encode_z(x, y)
                z = q_z_xy.rsample()
                p_x_yz = model.decode(y, z)
                loss = -loss_func(x, y, z, p_x_yz, q_z_xy)
                loss -= ALPHA * len(dl_labeled) * q_y.log_prob(y)

            loss = loss.mean(0)
            opt.zero_grad()
            loss.backward()
            opt.step()

        train_loss_log.append(loss.data.item())
        val_acc_log.append(evaluate(model, dl_test, device))
                # print (f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.data.item()}')
        plot_history(train_loss_log, val_acc_log, epoch + 1)
        # torch.save(model.state_dict(), '../log/')
    return train_loss_log, val_acc_log


def plot_history(train_history, val_history, epoch):
    clear_output()
    plt.figure()
    plt.title('Train loss')
    plt.plot(np.arange(len(train_history)), train_history, label='train', zorder=1)
    plt.xlabel('train steps')
    
    plt.legend(loc='best')
    plt.grid()

    plt.figure()
    plt.title('Val accuracy')
    plt.plot(np.arange(len(val_history)), val_history, label='val', c='orange', zorder=1)
    plt.xlabel('train steps')
    
    plt.legend(loc='best')
    plt.grid()

    plt.show()


@torch.no_grad()
def evaluate(model, dl, device):
    model.eval()
    accurate_preds = 0
    all_count = 0
    for x, y in dl:
        all_count += x.shape[0]
        x = x.to(device).view(x.shape[0], -1)
        y = y.to(device)
        preds = model(x)
        accurate_preds += (preds == y).sum().item()

    return accurate_preds / all_count


In [0]:
def load_data(n_labeled, batch_size=64):

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    train_labeled = torchvision.datasets.MNIST(PATH, download=True, train=True, transform=transforms)
    train_unlabeled = copy.deepcopy(train_labeled)

    n_classes = np.unique(train_labeled.train_labels).size
    n_labeled_class = n_labeled // n_classes

    x_labeled, y_labeled, x_unlabeled, y_unlabeled = map(lambda x: [], range(4))
    for i in range(n_classes):
        mask = train_labeled.train_labels == i
        x_masked = train_labeled.data[mask]
        y_masked = train_labeled.train_labels[mask]
        np.random.shuffle(x_masked)

        x_labeled.append(x_masked[:n_labeled_class])
        x_unlabeled.append(x_masked[n_labeled_class:])
        y_labeled.append(y_masked[:n_labeled_class])
        y_unlabeled.append(y_masked[n_labeled_class:])

    
    train_unlabeled.data = torch.cat(x_unlabeled).squeeze()
    train_unlabeled.labels = torch.cat(y_unlabeled)
    train_labeled.data = torch.cat(x_labeled).squeeze()
    train_labeled.labels = torch.cat(y_labeled)

    dl_train_labeled = torch.utils.data.DataLoader(train_labeled, batch_size=batch_size, shuffle=True)
    dl_train_unlabeled = torch.utils.data.DataLoader(train_unlabeled, batch_size=batch_size, shuffle=True)

    test = torchvision.datasets.MNIST(PATH, download=True, train=False, transform=transforms)
    dl_test = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

    return dl_train_labeled, dl_train_unlabeled, dl_test

In [0]:
dl_labeled, dl_unlabeled, dl_test = load_data(3000)



In [0]:
N_EPOCHS = 1000

In [0]:
model = M2(DEVICE).to(DEVICE)
train_loss_log, val_acc_log = train_M2(model, dl_labeled, dl_unlabeled, dl_test, N_EPOCHS, DEVICE)

In [0]:
train_loss_log[-1], val_acc_log[-1]

(105.78288269042969, 0.098)