# In [None]:
"""
Vanilla VAE implementation for EvilMouse Data.
This uses a 3D encoder/decoder structure, since the data is a time series of 2D images
and we wish to convolve over points that are close in time.
This is for mouse video data from Musall et al. (2019) which has 160x120 frames.
"""

import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import h5py
import os
import datetime
import itertools
from umap import UMAP
import torch
from torch.distributions import LowRankMultivariateNormal
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
# Directory handed to SummaryWriter as ``log_dir`` for TensorBoard event files.
# NOTE(review): this is a hard-coded, machine-specific absolute path.
SUMMARY_WRITER_PATH = "/home/achint/Practice_code/logs"

# In [None]:
# Shape of a single input sample. The module docstring describes 160x120 video
# frames; the trailing 31 is presumably the number of consecutive frames per
# sample -- TODO confirm against the dataset class.
X_SHAPE = (160, 120, 31)
# Total number of scalar values in one sample (product of the X_SHAPE entries).
X_DIM = np.prod(X_SHAPE)

class VAE(nn.Module):
    """Variational autoencoder built from 3D convolutions.

    The encoder maps a batch of samples, each holding ``X_DIM`` values arranged
    as ``X_SHAPE``, to the parameters ``(mu, u, d)`` of a low-rank-plus-diagonal
    Gaussian posterior over a ``z_dim``-dimensional latent space. The decoder
    maps latent vectors back to flattened samples of length ``X_DIM``.

    Besides the network itself the object owns its Adam optimizer, a per-epoch
    loss history (``self.loss``), an epoch counter (``self.epoch``) and a
    TensorBoard ``SummaryWriter`` (``self.writer``).
    """

    # Attribute names of every sub-module that is checkpointed by
    # ``save_state`` / restored by ``load_state``.
    _LAYER_NAMES = (
        'fc1', 'fc2', 'fc31', 'fc32', 'fc33', 'fc41', 'fc42', 'fc43',
        'fc5', 'fc6', 'fc7', 'fc8',
        'bn1', 'bn2', 'bn3', 'bn4', 'bn5', 'bn6', 'bn7', 'bn8', 'bn9',
        'bn10', 'bn11', 'bn12', 'bn13', 'bn14', 'bn15', 'bn16', 'bn17',
        'conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'conv7',
        'conv8',
        'convt1', 'convt2', 'convt3', 'convt4', 'convt5', 'convt6',
        'convt7', 'convt8', 'convt9',
    )

    def __init__(self, save_dir='', lr=1e-3, z_dim=32, model_precision=10, device_name="auto"):
        """Build the network, optimizer and TensorBoard writer.

        Args:
            save_dir: Directory for checkpoints and latent dumps. Created if it
                does not exist; ``''`` means the current working directory.
            lr: Learning rate for the Adam optimizer.
            z_dim: Dimensionality of the latent space.
            model_precision: Precision (inverse variance) of the isotropic
                Gaussian observation model used in the reconstruction term.
            device_name: ``"auto"`` picks CUDA when available and CPU
                otherwise; any other value is passed to ``torch.device``.

        Raises:
            AssertionError: if ``device_name == "cuda"`` but CUDA is unavailable.
        """
        super().__init__()
        self.save_dir = save_dir
        self.lr = lr
        self.z_dim = z_dim
        self.model_precision = model_precision
        # The device must be resolved for *every* device_name, not only for
        # "auto"; otherwise ``self.to(self.device)`` below fails.
        self.device = self._resolve_device(device_name)
        if self.save_dir != '':
            os.makedirs(self.save_dir, exist_ok=True)
        self._build_network()
        self.optimizer = Adam(self.parameters(), lr=self.lr)
        self.epoch = 0
        self.loss = {'train':{}, 'test':{}}
        # Summary writer instance for TensorBoard logging.
        self.writer = SummaryWriter(log_dir = SUMMARY_WRITER_PATH)

        self.to(self.device)

    @staticmethod
    def _resolve_device(device_name):
        """Translate a device name ("auto", "cpu", "cuda", ...) to a torch.device."""
        assert device_name != "cuda" or torch.cuda.is_available(), \
            "device_name='cuda' was requested but CUDA is not available"
        if device_name == "auto":
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device_name)

    def _build_network(self):
        """Instantiate all encoder and decoder layers."""
        # Encoder: alternating stride-1 / stride-2 3D convolutions. The four
        # stride-2 layers reduce (160, 120, 31) to (10, 8, 2); with 32 channels
        # that gives the 5120 features consumed by fc1.
        self.conv1 = nn.Conv3d(1,8,3,1,padding=1)
        self.conv2 = nn.Conv3d(8,8,3,2,padding=1)
        self.conv3 = nn.Conv3d(8,16,3,1,padding=1)
        self.conv4 = nn.Conv3d(16,16,3,2,padding=1)
        self.conv5 = nn.Conv3d(16,24,3,1,padding=1)
        self.conv6 = nn.Conv3d(24,24,3,2,padding=1)
        self.conv7 = nn.Conv3d(24,32,3,1,padding=1)
        self.conv8 = nn.Conv3d(32,32,3,2,padding=1)
        # Each batch norm is applied to the *input* of the same-numbered conv.
        self.bn1 = nn.BatchNorm3d(1)
        self.bn2 = nn.BatchNorm3d(8)
        self.bn3 = nn.BatchNorm3d(8)
        self.bn4 = nn.BatchNorm3d(16)
        self.bn5 = nn.BatchNorm3d(16)
        self.bn6 = nn.BatchNorm3d(24)
        self.bn7 = nn.BatchNorm3d(24)
        self.bn8 = nn.BatchNorm3d(32)
        self.fc1 = nn.Linear(5120, 1024)
        self.fc2 = nn.Linear(1024, 256)
        # Three parallel heads: *1 -> mu, *2 -> u (low-rank factor), *3 -> d (diagonal).
        self.fc31 = nn.Linear(256, 64)
        self.fc32 = nn.Linear(256, 64)
        self.fc33 = nn.Linear(256, 64)
        self.fc41 = nn.Linear(64, self.z_dim)
        self.fc42 = nn.Linear(64, self.z_dim)
        self.fc43 = nn.Linear(64, self.z_dim)
        # Decoder: mirrors the encoder. The output_padding values are chosen so
        # the stride-2 transposed convolutions recover (160, 120, 31) exactly.
        self.fc5 = nn.Linear(self.z_dim,64)
        self.fc6 = nn.Linear(64,256)
        self.fc7 = nn.Linear(256,1024)
        self.fc8 = nn.Linear(1024,5120)
        self.convt1 = nn.ConvTranspose3d(32,24,3,1,padding=1)
        self.convt2 = nn.ConvTranspose3d(24,24,3,2,padding=1,output_padding=(1, 0, 1))
        self.convt3 = nn.ConvTranspose3d(24,16,3,1,padding=1)
        self.convt4 = nn.ConvTranspose3d(16,16,3,2,padding=1,output_padding=1)
        self.convt5 = nn.ConvTranspose3d(16,8,3,1,padding=1)
        self.convt6 = nn.ConvTranspose3d(8,8,3,2,padding=1,output_padding=1)
        self.convt7 = nn.ConvTranspose3d(8,4,3,1,padding=1)
        self.convt8 = nn.ConvTranspose3d(4,4,3,2,padding=1, output_padding=(1, 1, 0))
        self.convt9 = nn.ConvTranspose3d(4,1,3,1,padding=1)
        self.bn9 = nn.BatchNorm3d(32)
        self.bn10 = nn.BatchNorm3d(24)
        self.bn11 = nn.BatchNorm3d(24)
        self.bn12 = nn.BatchNorm3d(16)
        self.bn13 = nn.BatchNorm3d(16)
        self.bn14 = nn.BatchNorm3d(8)
        self.bn15 = nn.BatchNorm3d(8)
        self.bn16 = nn.BatchNorm3d(4)
        self.bn17 = nn.BatchNorm3d(4)

    def _get_layers(self):
        """Return a dictionary mapping names to network layers."""
        return {name: getattr(self, name) for name in self._LAYER_NAMES}

    def encode(self, x):
        """Compute the posterior parameters for a batch of samples.

        Args:
            x: Tensor whose per-sample shape is ``X_SHAPE`` (a channel axis is
                inserted here).

        Returns:
            Tuple ``(mu, u, d)``: mean ``(batch, z_dim)``, rank-1 covariance
            factor ``(batch, z_dim, 1)`` and strictly positive covariance
            diagonal ``(batch, z_dim)``.
        """
        x = x.unsqueeze(1)  # add the single input channel
        x = F.relu(self.conv1(self.bn1(x)))
        x = F.relu(self.conv2(self.bn2(x)))
        x = F.relu(self.conv3(self.bn3(x)))
        x = F.relu(self.conv4(self.bn4(x)))
        x = F.relu(self.conv5(self.bn5(x)))
        x = F.relu(self.conv6(self.bn6(x)))
        x = F.relu(self.conv7(self.bn7(x)))
        x = F.relu(self.conv8(self.bn8(x)))
        x = x.view(-1, 5120)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = F.relu(self.fc31(x))
        mu = self.fc41(mu)
        u = F.relu(self.fc32(x))
        u = self.fc42(u).unsqueeze(-1) # Last dimension is rank \Sigma = 1.
        d = F.relu(self.fc33(x))
        d = torch.exp(self.fc43(d)) # d must be positive.
        return mu, u, d

    def decode(self, z):
        """Map latent vectors ``(batch, z_dim)`` to flat reconstructions ``(batch, X_DIM)``."""
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = F.relu(self.fc7(z))
        z = F.relu(self.fc8(z))
        z = z.view(-1,32,10,8,2)
        z = F.relu(self.convt1(self.bn9(z)))
        z = F.relu(self.convt2(self.bn10(z)))
        z = F.relu(self.convt3(self.bn11(z)))
        z = F.relu(self.convt4(self.bn12(z)))
        z = F.relu(self.convt5(self.bn13(z)))
        z = F.relu(self.convt6(self.bn14(z)))
        # The last three layers are linear (no ReLU), so outputs are unbounded.
        z = self.convt7(self.bn15(z))
        z = self.convt8(self.bn16(z))
        z = self.convt9(self.bn17(z))
        z = z.view(-1, X_DIM)
        return z

    def forward(self, x, return_latent_rec=False, train_mode=False):
        """Encode, sample, decode and return the negative ELBO summed over the batch.

        Args:
            x: Batch of input samples, per-sample shape ``X_SHAPE``.
            return_latent_rec: If True, also return the sampled latents and the
                reconstructions as numpy arrays.
            train_mode: If True, log the reconstructions of this batch to
                TensorBoard via ``log_reconstruction``.

        Returns:
            ``-elbo`` (scalar tensor), or the tuple
            ``(-elbo, z, x_rec)`` when ``return_latent_rec`` is set, with ``z``
            of shape ``(batch, z_dim)`` and ``x_rec`` of shape
            ``(batch,) + X_SHAPE``.
        """
        mu, u, d = self.encode(x)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        x_rec = self.decode(z)
        if train_mode: #log reconstructions
            rec_np = x_rec.detach().cpu().numpy()
            self.log_reconstruction(rec_np, rec_np.shape[0], log_type='train')
        # NOTE(review): the two log(2*pi) constants below are added once per
        # call, not once per sample, while the data-dependent terms are summed
        # over the batch. Gradients are unaffected; only the reported value is
        # offset. Left unchanged so loss values stay comparable with old runs.
        # E_{q(z|x)} p(z)
        elbo = -0.5 * (torch.sum(torch.pow(z,2)) + self.z_dim * np.log(2*np.pi))
        # E_{q(z|x)} p(x|z)
        pxz_term = -0.5 * X_DIM * (np.log(2*np.pi/self.model_precision))
        l2s = torch.sum(torch.pow(x.view(x.shape[0],-1) - x_rec, 2), dim=1)
        pxz_term = pxz_term - 0.5 * self.model_precision * torch.sum(l2s)
        elbo = elbo + pxz_term
        # H[q(z|x)]
        elbo = elbo + torch.sum(latent_dist.entropy())
        if return_latent_rec:
            return -elbo, z.detach().cpu().numpy(), \
            x_rec.view(-1, X_SHAPE[0], X_SHAPE[1], X_SHAPE[2]).detach().cpu().numpy()
        return -elbo

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Each batch is expected to be a mapping with a ``'frame'`` entry.
        Increments ``self.epoch`` and returns the summed loss divided by the
        number of items in ``train_loader.dataset``.
        """
        self.train()
        train_loss = 0.0
        for data in train_loader:
            self.optimizer.zero_grad()
            frame = data['frame'].to(self.device)
            loss = self.forward(frame, train_mode=True)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(train_loader.dataset)
        print('Epoch: {} Average loss: {:.4f}'.format(self.epoch, \
        train_loss))
        self.epoch += 1
        return train_loss

    def test_epoch(self, test_loader):
        """Evaluate on ``test_loader`` without gradients; return the per-item loss."""
        self.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data in test_loader:
                frame = data['frame'].to(self.device)
                loss = self.forward(frame)
                test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}'.format(test_loss))
        return test_loss

    def train_loop(self, loaders, epochs=100, test_freq=2, save_freq=10):
        """Train for ``epochs`` epochs, periodically testing and checkpointing.

        Args:
            loaders: Mapping with ``'train'`` and ``'test'`` data loaders.
            epochs: Number of epochs to run, starting from ``self.epoch``.
            test_freq: Evaluate on the test loader every ``test_freq`` epochs;
                ``None`` disables testing.
            save_freq: Write ``checkpoint_<epoch>.tar`` every ``save_freq``
                epochs (never at epoch 0); ``None`` disables checkpointing.
        """
        print("="*40)
        print("Training: epochs", self.epoch, "to", self.epoch+epochs-1)
        print("Training set:", len(loaders['train'].dataset))
        print("Test set:", len(loaders['test'].dataset))
        print("="*40)
        # For some number of epochs...
        for epoch in range(self.epoch, self.epoch+epochs):
            # Run through the training data and record a loss.
            loss = self.train_epoch(loaders['train'])
            self.loss['train'][epoch] = loss
            # Log loss to TensorBoard. Use ``epoch`` rather than ``self.epoch``:
            # train_epoch has already incremented the latter, which would put
            # the scalar one step ahead of the key used in ``self.loss``.
            self.writer.add_scalar("Loss/Train", loss, epoch)
            self.writer.flush()
            # Run through the test data and record a loss.
            if (test_freq is not None) and (epoch % test_freq == 0):
                loss = self.test_epoch(loaders['test'])
                self.loss['test'][epoch] = loss
            # Save the model.
            if (save_freq is not None) and (epoch % save_freq == 0) and (epoch > 0):
                filename = "checkpoint_"+str(epoch).zfill(3)+'.tar'
                self.save_state(filename)

    def save_state(self, filename):
        """Save all the model parameters to the given file.

        The checkpoint holds one state dict per layer, the optimizer state, the
        loss history, ``z_dim``, ``epoch``, ``lr`` and ``save_dir``. ``filename``
        is interpreted relative to ``self.save_dir``.
        """
        state = {name: layer.state_dict() for name, layer in self._get_layers().items()}
        state['optimizer_state'] = self.optimizer.state_dict()
        state['loss'] = self.loss
        state['z_dim'] = self.z_dim
        state['epoch'] = self.epoch
        state['lr'] = self.lr
        state['save_dir'] = self.save_dir
        filename = os.path.join(self.save_dir, filename)
        torch.save(state, filename)

    def load_state(self, filename):
        """Restore layers, optimizer, loss history and epoch from a checkpoint.

        Args:
            filename: Path to a file written by ``save_state`` (used as given,
                *not* joined with ``self.save_dir``).

        Raises:
            AssertionError: if the checkpoint was saved with a different ``z_dim``.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(filename, map_location=self.device)
        assert checkpoint['z_dim'] == self.z_dim, \
            "checkpoint z_dim does not match this model's z_dim"
        for layer_name, layer in self._get_layers().items():
            layer.load_state_dict(checkpoint[layer_name])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])
        self.loss = checkpoint['loss']
        self.epoch = checkpoint['epoch']

    def get_latent_umap(self, loaders, save_dir, title=None):
        """Embed the test set, project the latent means with UMAP and save both.

        The posterior means of every item yielded by ``loaders['test']`` are
        collected, projected to 2D, and written together with the projection
        to ``latent_info.tar`` inside ``self.save_dir``.

        Args:
            loaders: Mapping with a ``'test'`` data loader (used because it is
                unshuffled).
            save_dir: Currently unused; output goes to ``self.save_dir``.
            title: Currently unused.

        Returns:
            Array of shape ``(n_items, 2)`` with the UMAP projection.
        """
        # Inference must run in eval mode: in train mode the BatchNorm layers
        # normalise with per-batch statistics and keep updating their running
        # estimates, even under no_grad.
        self.eval()
        latent = np.zeros((len(loaders['test'].dataset), self.z_dim))
        with torch.no_grad():
            j = 0
            for sample in loaders['test']:
                x = sample['frame'].to(self.device)
                mu, _, _ = self.encode(x)
                latent[j:j+len(mu)] = mu.detach().cpu().numpy()
                j += len(mu)
        # UMAP these
        transform = UMAP(n_components=2, n_neighbors=20, min_dist=0.1, \
        metric='euclidean', random_state=42)
        projection = transform.fit_transform(latent)
        # save these to do PCA and the rest
        latent_info = {'latents':latent, 'UMAP':projection}
        fname = os.path.join(self.save_dir, 'latent_info.tar')
        torch.save(latent_info, fname)
        #and return projections for plotting
        return projection

    def log_reconstruction(self, frames, batch_size, log_type):
        """Log the middle slice (index 15 of the last axis) of each reconstruction.

        Args:
            frames: Numpy array reshapeable to ``(batch_size,) + X_SHAPE``.
            batch_size: Number of reconstructions in ``frames``.
            log_type: Suffix for the TensorBoard tag, e.g. ``'train'``.
        """
        frames = frames.reshape((batch_size, X_SHAPE[0], X_SHAPE[1], X_SHAPE[2]))
        # NOTE(review): decoder outputs are unbounded while add_image documents
        # float images in [0, 1] -- confirm the data are scaled accordingly.
        for i in range(batch_size):
            frame = frames[i, :, :, 15]
            fig_name = 'reconstruction_{}/{}'.format(log_type, i)
            # Tag images with the current epoch so successive calls do not all
            # land on the same (default) step.
            self.writer.add_image(fig_name, frame, global_step=self.epoch, dataformats='HW')

    def get_recons(self, dataset, vals_list):
        """
        Returns a np array with recons for a contiguous range of dataset items
        (in order). Useful when wanting to build tooltip plot.
        Args:
        ---------
        dataset: dataset instance; ``dataset[i]`` must return a mapping with a
        'frame' tensor.
        vals_list: (list). First element is starting index (inclusive) and
        second is ending index (exclusive).
        Returns:
        ---------
        np array holding one reconstruction of shape X_SHAPE per requested item.
        """
        start = vals_list[0]
        end = vals_list[1]
        all_recons = []
        # Eval mode + no_grad: items are fed one at a time, so train-mode
        # BatchNorm would normalise each sample by its own statistics and
        # corrupt the running estimates; gradients are never needed here.
        self.eval()
        with torch.no_grad():
            for i in range(start, end):
                frame = dataset[i]['frame'].unsqueeze(0).to(self.device)
                _, _, recon = self.forward(frame, return_latent_rec=True)
                all_recons.append(np.squeeze(recon))
        return np.array(all_recons)

# No action is taken when this file is run directly as a script.
if __name__ == '__main__':
    pass