In [1]:
"""
Vanilla VAE implementation for Wfield Data.
This uses a 3D encoder/decoder structure.
"""

import numpy as np
import random
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import os
import datetime
import itertools
from umap import UMAP
import torch
from torch.distributions import LowRankMultivariateNormal
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
# Directory that receives the TensorBoard event files. It can be overridden
# with the SUMMARY_WRITER_PATH environment variable so the notebook also runs on
# machines other than the original author's; the default is the original path.
SUMMARY_WRITER_PATH = os.environ.get("SUMMARY_WRITER_PATH", "/home/achint/Practice_code/logs")

In [2]:


# Shape of one input sample, excluding the batch axis. The leading 7 is the
# channel count consumed by the first Conv3d layer, and the trailing 540 x 640
# plane is what gets logged as an image during training. What the axes mean in
# terms of the recording (time window, camera channel, ...) is not shown in
# this file — TODO confirm.
X_SHAPE = (7, 2, 540, 640)
# Number of scalar values in one sample (the product of X_SHAPE).
X_DIM = np.prod(X_SHAPE)

class VAE(nn.Module):
    """Variational autoencoder with a 3D-convolutional encoder and decoder.

    The encoder maps a batch of shape ``(batch,) + X_SHAPE`` to the parameters
    ``(mu, u, d)`` of a rank-1 ``LowRankMultivariateNormal`` posterior over a
    ``z_dim``-dimensional latent. The decoder maps a latent sample back to a
    flattened reconstruction with ``X_DIM`` values per sample.

    The layer sizes are tied to ``X_SHAPE``: the encoder's convolutional stack
    must end with exactly 640 features per sample (32 x 1 x 4 x 5).
    """

    def __init__(self, save_dir='', lr=1e-3, z_dim=32, model_precision=10, device_name="auto"):
        """Build the network, the optimizer and the TensorBoard writer.

        Parameters
        ----------
        save_dir : str
            Directory used for checkpoints and latent dumps. Created if it does
            not exist; an empty string means the current working directory.
        lr : float
            Learning rate passed to Adam.
        z_dim : int
            Dimensionality of the latent space.
        model_precision : float
            Precision (inverse variance) of the Gaussian observation model used
            in the reconstruction term of the ELBO.
        device_name : str
            ``"auto"`` selects CUDA when available and the CPU otherwise. Any
            other value is passed to ``torch.device`` unchanged.
        """
        super().__init__()
        self.save_dir = save_dir
        self.lr = lr
        self.z_dim = z_dim
        self.model_precision = model_precision
        assert device_name != "cuda" or torch.cuda.is_available(), \
            "device_name='cuda' was requested but CUDA is not available"
        if device_name == "auto":
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        # The device must be set for every device_name, not only for "auto";
        # otherwise an explicit "cpu"/"cuda" leaves self.device undefined.
        self.device = torch.device(device_name)
        if self.save_dir != '':
            os.makedirs(self.save_dir, exist_ok=True)
        self._build_network()
        self.optimizer = Adam(self.parameters(), lr=self.lr)
        self.epoch = 0
        # Per-epoch losses, keyed by epoch index.
        self.loss = {'train': {}, 'test': {}}
        # Summary writer instance for TensorBoard logging.
        self.writer = SummaryWriter(log_dir=SUMMARY_WRITER_PATH)
        self.to(self.device)

    def _build_network(self):
        """Define all the layers of the encoder and the decoder."""
        # Encoder. With an input of X_SHAPE = (7, 2, 540, 640) the feature maps
        # are (16,1,108,128) -> (16,1,36,43) -> (24,1,36,43) -> (24,1,12,15)
        # -> (32,1,4,5), i.e. 640 features going into fc1.
        self.conv1 = nn.Conv3d(7, 16, 3, 5, padding=1)
        self.conv2 = nn.Conv3d(16, 16, 3, 3, padding=1)
        self.conv3 = nn.Conv3d(16, 24, 3, 1, padding=1)
        self.conv4 = nn.Conv3d(24, 24, 3, 3, padding=1)
        self.conv5 = nn.Conv3d(24, 32, 3, 3, padding=1)
        # Each batch norm is applied to the *input* of the matching conv layer.
        self.bn1 = nn.BatchNorm3d(7)
        self.bn2 = nn.BatchNorm3d(16)
        self.bn3 = nn.BatchNorm3d(16)
        self.bn4 = nn.BatchNorm3d(24)
        self.bn5 = nn.BatchNorm3d(24)
        self.fc1 = nn.Linear(640, 256)
        # Three parallel heads: posterior mean (fc21/fc31), low-rank covariance
        # factor (fc22/fc32) and covariance diagonal (fc23/fc33).
        self.fc21 = nn.Linear(256, 64)
        self.fc22 = nn.Linear(256, 64)
        self.fc23 = nn.Linear(256, 64)
        self.fc31 = nn.Linear(64, self.z_dim)
        self.fc32 = nn.Linear(64, self.z_dim)
        self.fc33 = nn.Linear(64, self.z_dim)
        # Decoder. The output_padding values are chosen so that the transposed
        # convolutions retrace the encoder shapes back up to X_SHAPE.
        self.fc4 = nn.Linear(self.z_dim, 64)
        self.fc5 = nn.Linear(64, 256)
        self.fc6 = nn.Linear(256, 640)
        self.convt1 = nn.ConvTranspose3d(32, 24, 3, 3, padding=1, output_padding=(0, 2, 2))
        self.convt2 = nn.ConvTranspose3d(24, 24, 3, 3, padding=1, output_padding=(0, 2, 0))
        self.convt3 = nn.ConvTranspose3d(24, 16, 3, 1, padding=1)
        self.convt4 = nn.ConvTranspose3d(16, 16, 3, 3, padding=1, output_padding=(0, 2, 1))
        self.convt5 = nn.ConvTranspose3d(16, 7, 3, 5, padding=1, output_padding=(1, 4, 4))
        self.bn6 = nn.BatchNorm3d(32)
        self.bn7 = nn.BatchNorm3d(24)
        self.bn8 = nn.BatchNorm3d(24)
        self.bn9 = nn.BatchNorm3d(16)
        self.bn10 = nn.BatchNorm3d(16)

    def _get_layers(self):
        """Return a dictionary mapping names to network layers."""
        return {'fc1':self.fc1, 'fc21':self.fc21, 'fc22':self.fc22,
                'fc23':self.fc23,'fc31':self.fc31,'fc32':self.fc32,
                'fc33':self.fc33, 'fc4':self.fc4, 'fc5':self.fc5,
                'fc6':self.fc6,'bn1':self.bn1, 'bn2':self.bn2,
                'bn3':self.bn3, 'bn4':self.bn4, 'bn5':self.bn5,
                'bn6':self.bn6, 'bn7':self.bn7, 'bn8':self.bn8,
                'bn9':self.bn9, 'bn10':self.bn10,'conv1':self.conv1,
                'conv2':self.conv2, 'conv3':self.conv3,'conv4':self.conv4,
                'conv5':self.conv5, 'convt1':self.convt1,'convt2':self.convt2,
                'convt3':self.convt3, 'convt4':self.convt4,'convt5':self.convt5}

    def encode(self, x):
        """Map a batch of inputs to the posterior parameters.

        Parameters
        ----------
        x : torch.Tensor
            Batch of shape ``(batch,) + X_SHAPE`` on ``self.device``.

        Returns
        -------
        mu : torch.Tensor
            Posterior mean, shape ``(batch, z_dim)``.
        u : torch.Tensor
            Rank-1 covariance factor, shape ``(batch, z_dim, 1)``.
        d : torch.Tensor
            Strictly positive covariance diagonal, shape ``(batch, z_dim)``.
        """
        x = F.relu(self.conv1(self.bn1(x)))
        x = F.relu(self.conv2(self.bn2(x)))
        x = F.relu(self.conv3(self.bn3(x)))
        x = F.relu(self.conv4(self.bn4(x)))
        x = F.relu(self.conv5(self.bn5(x)))
        x = x.view(-1, 640)
        x = F.relu(self.fc1(x))
        mu = F.relu(self.fc21(x))
        mu = self.fc31(mu)
        u = F.relu(self.fc22(x))
        u = self.fc32(u).unsqueeze(-1) # Last dimension is rank \Sigma = 1.
        d = F.relu(self.fc23(x))
        d = torch.exp(self.fc33(d)) # d must be positive.
        return mu, u, d

    def decode(self, z):
        """Map latent samples to flattened reconstructions.

        Parameters
        ----------
        z : torch.Tensor
            Latent samples of shape ``(batch, z_dim)``.

        Returns
        -------
        torch.Tensor
            Reconstructions of shape ``(batch, X_DIM)``. The final ReLU makes
            every reconstructed value non-negative.
        """
        z = F.relu(self.fc4(z))
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = z.view(-1, 32, 1, 4, 5)
        z = F.relu(self.convt1(self.bn6(z)))
        z = F.relu(self.convt2(self.bn7(z)))
        z = F.relu(self.convt3(self.bn8(z)))
        z = F.relu(self.convt4(self.bn9(z)))
        z = F.relu(self.convt5(self.bn10(z)))
        z = z.view(-1, X_DIM)
        return z

    def forward(self, x, return_latent_rec=False, train_mode=False):
        """Compute the negative ELBO for a batch using one posterior sample.

        Parameters
        ----------
        x : torch.Tensor
            Batch of shape ``(batch,) + X_SHAPE`` on ``self.device``.
        return_latent_rec : bool
            When True, also return the latent sample and the reconstruction as
            numpy arrays.
        train_mode : bool
            When True, log a few reconstructions of this batch to TensorBoard.

        Returns
        -------
        torch.Tensor or tuple
            The negative ELBO summed over the batch (a scalar tensor). With
            ``return_latent_rec=True`` the result is
            ``(loss, z, x_rec)`` where ``z`` has shape ``(batch, z_dim)`` and
            ``x_rec`` has shape ``(batch,) + X_SHAPE``.
        """
        mu, u, d = self.encode(x)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        x_rec = self.decode(z)
        if train_mode:
            # Read the batch size from the tensor and convert to numpy once;
            # the reconstruction is large, so a second copy is wasteful.
            self.log_reconstruction(x_rec.detach().cpu().numpy(), x_rec.shape[0],
                                    log_type='train')
        # NOTE(review): the two Gaussian normalisation constants below are added
        # once per batch, not once per sample. They do not depend on the
        # parameters, so gradients are unaffected, but the reported loss value
        # depends on the batch size. Left unchanged to keep losses comparable.
        # E_{q(z|x)} p(z)
        elbo = -0.5 * (torch.sum(torch.pow(z,2)) + self.z_dim * np.log(2*np.pi))
        # E_{q(z|x)} p(x|z)
        pxz_term = -0.5 * X_DIM * (np.log(2*np.pi/self.model_precision))
        l2s = torch.sum(torch.pow(x.view(x.shape[0],-1) - x_rec, 2), dim=1)
        pxz_term = pxz_term - 0.5 * self.model_precision * torch.sum(l2s)
        elbo = elbo + pxz_term
        # H[q(z|x)]
        elbo = elbo + torch.sum(latent_dist.entropy())
        if return_latent_rec:
            return -elbo, z.detach().cpu().numpy(), \
                x_rec.view(-1, *X_SHAPE).detach().cpu().numpy()
        return -elbo

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Each batch must be a mapping with a ``'frame'`` tensor. Increments
        ``self.epoch`` and returns the summed loss divided by the dataset size.
        """
        self.train()
        train_loss = 0.0
        for data in train_loader:
            self.optimizer.zero_grad()
            frame = data['frame'].to(self.device)
            loss = self.forward(frame, train_mode=True)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(train_loader.dataset)
        print('Epoch: {} Average loss: {:.4f}'.format(self.epoch, \
        train_loss))
        self.epoch += 1
        return train_loss

    def test_epoch(self, test_loader):
        """Evaluate on ``test_loader`` without gradients.

        Returns the summed loss divided by the dataset size.
        """
        self.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data in test_loader:
                frame = data['frame'].to(self.device)
                loss = self.forward(frame)
                test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}'.format(test_loss))
        return test_loss

    def train_loop(self, loaders, epochs=100, test_freq=2, save_freq=10):
        """Train for ``epochs`` epochs, testing and checkpointing periodically.

        Parameters
        ----------
        loaders : dict
            Must provide ``'train'`` and ``'test'`` loaders, each exposing a
            ``dataset`` attribute.
        epochs : int
            Number of epochs to run, continuing from ``self.epoch``.
        test_freq : int or None
            Evaluate on the test loader every ``test_freq`` epochs; None
            disables testing.
        save_freq : int or None
            Save a checkpoint every ``save_freq`` epochs (never at epoch 0);
            None disables checkpointing.
        """
        print("="*40)
        print("Training: epochs", self.epoch, "to", self.epoch+epochs-1)
        print("Training set:", len(loaders['train'].dataset))
        print("Test set:", len(loaders['test'].dataset))
        print("="*40)
        # For some number of epochs...
        for epoch in range(self.epoch, self.epoch+epochs):
            # Run through the training data and record a loss.
            loss = self.train_epoch(loaders['train'])
            self.loss['train'][epoch] = loss
            # Log the loss to TensorBoard. train_epoch() has already advanced
            # self.epoch, so use the loop variable to stay aligned with the
            # keys of self.loss.
            self.writer.add_scalar("Loss/Train", loss, epoch)
            self.writer.flush()
            # Run through the test data and record a loss.
            if (test_freq is not None) and (epoch % test_freq == 0):
                loss = self.test_epoch(loaders['test'])
                self.loss['test'][epoch] = loss
            # Save the model.
            if (save_freq is not None) and (epoch % save_freq == 0) and (epoch > 0):
                filename = "checkpoint_"+str(epoch).zfill(3)+'.tar'
                self.save_state(filename)

    def save_state(self, filename):
        """Save all the model parameters to the given file.

        ``filename`` is joined onto ``self.save_dir``. Besides the per-layer
        state dicts, the checkpoint stores the optimizer state, the loss
        history, ``z_dim``, ``epoch``, ``lr`` and ``save_dir``.
        """
        layers = self._get_layers()
        state = {}
        for layer_name in layers:
            state[layer_name] = layers[layer_name].state_dict()
        state['optimizer_state'] = self.optimizer.state_dict()
        state['loss'] = self.loss
        state['z_dim'] = self.z_dim
        state['epoch'] = self.epoch
        state['lr'] = self.lr
        state['save_dir'] = self.save_dir
        filename = os.path.join(self.save_dir, filename)
        torch.save(state, filename)

    def load_state(self, filename):
        """Restore layers, optimizer, loss history and epoch from a checkpoint.

        ``filename`` is used as given (it is not joined onto ``self.save_dir``).
        The checkpoint's ``z_dim`` must match this model's.
        """
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(filename, map_location=self.device)
        assert checkpoint['z_dim'] == self.z_dim
        layers = self._get_layers()
        for layer_name in layers:
            layer = layers[layer_name]
            layer.load_state_dict(checkpoint[layer_name])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])
        self.loss = checkpoint['loss']
        self.epoch = checkpoint['epoch']

    def get_latent_umap(self, loaders, save_dir, title=None):
        """Encode the test set, project the latent means with UMAP and save both.

        The latent means and their 2D UMAP projection are written to
        ``latent_info.tar`` inside ``self.save_dir``.

        Parameters
        ----------
        loaders : dict
            Must provide a ``'test'`` loader. It should iterate the data
            unshuffled so that rows line up with dataset order.
        save_dir : str
            Currently unused; the output goes to ``self.save_dir``.
        title : str or None
            Currently unused.
        """
        latent = np.zeros((len(loaders['test'].dataset), self.z_dim))
        # Encode in eval mode so BatchNorm uses its running statistics rather
        # than per-batch statistics; restore the previous mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                j = 0
                for sample in loaders['test']:
                    x = sample['frame'].to(self.device)
                    mu, _, _ = self.encode(x)
                    latent[j:j+len(mu)] = mu.detach().cpu().numpy()
                    j += len(mu)
        finally:
            self.train(was_training)
        # UMAP these
        transform = UMAP(n_components=2, n_neighbors=20, min_dist=0.1, \
        metric='euclidean', random_state=42)
        projection = transform.fit_transform(latent)
        # save these to do PCA and the rest
        latent_info = {'latents':latent, 'UMAP':projection}
        fname = os.path.join(self.save_dir, 'latent_info.tar')
        torch.save(latent_info, fname)

    def log_reconstruction(self, frames, batch_size, log_type):
        """Log a random subset of reconstructions to TensorBoard as images.

        Parameters
        ----------
        frames : numpy.ndarray
            Reconstructions holding ``batch_size * X_DIM`` values.
        batch_size : int
            Number of samples in ``frames``.
        log_type : str
            Label inserted into the image tag, e.g. ``'train'``.
        """
        frames = frames.reshape((batch_size,) + tuple(X_SHAPE))
        # Log only some elements of the batch. A batch may hold fewer than five
        # items (e.g. the last batch of an epoch), so cap the sample size.
        n_logged = min(5, batch_size)
        random_subset = random.sample(range(batch_size), n_logged)
        for i in random_subset:
            # Index 4 along the first sample axis, one image per entry of the
            # second axis.
            ch1_frame = frames[i, 4, 0, :, :]
            ch2_frame = frames[i, 4, 1, :, :]
            ch1_fig_name = 'reconstruction_{}/{}_ch1'.format(log_type, i)
            ch2_fig_name = 'reconstruction_{}/{}_ch2'.format(log_type, i)
            # Tag every image with the current epoch so successive logs form a
            # series instead of all landing on the same step.
            self.writer.add_image(ch1_fig_name, ch1_frame,
                                  global_step=self.epoch, dataformats='HW')
            self.writer.add_image(ch2_fig_name, ch2_frame,
                                  global_step=self.epoch, dataformats='HW')

    def get_recons(self, dataset, vals_list):
        """
        Returns a np array with recons for a range of frames (in order).
        Useful when wanting to build tooltip plot.
        Args:
        ---------
        Dataset: dataclass instance. Indexing it must return a mapping with a
        'frame' tensor.
        Vals_list: (list). First element is starting frame (inclusive) and
        second is ending frame (exclusive).
        Returns:
        ---------
        np.ndarray of shape (end - start,) + X_SHAPE.
        """
        start = vals_list[0]
        end = vals_list[1]
        all_recons = []
        # Reconstruct in eval mode and without autograd: in train mode each
        # single-frame forward pass would update the BatchNorm running
        # statistics and build an unused graph. The previous mode is restored.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for i in range(start, end):
                    frame = dataset[i]['frame'].unsqueeze(0).to(self.device)
                    _, _, recon = self.forward(frame, return_latent_rec=True)
                    # Drop the batch axis of size one.
                    all_recons.append(recon[0])
        finally:
            self.train(was_training)
        return np.array(all_recons)

if __name__ == '__main__':
    # No-op: nothing is executed when this file is run as a script.
    pass