In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
plt.rcParams['figure.dpi'] = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def try_gpu():
    """
    If GPU is available, return torch.device as cuda:0; else return torch.device
    as cpu.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device

plot latent space

In [1]:
def plot_latent(model, data_loader, num_batches=100, device=device):
    '''
    Plots position of data in latent space (which is either 2D or 3D)
    Args:
      autoencoder: pytorch network that contains an encoder subnetwork
      data_loader: the data we want to plot in latent space
      num_batches: number of batches to use in for the plot
    '''
    # Iterate over all data
    plt.rcParams['figure.figsize'] = (5, 3)
    plt.rcParams['figure.dpi'] = 144
    for idx, data in enumerate(data_loader):
        x, y = data
        z = model.encoder(x.to(device))
        # Encode image data
        z = z.to('cpu').detach().numpy()  # Get numpy version of data in latent space

        # 2D latent space (single image)
        if np.size(z, axis=1) == 2:
            plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')  # Add data to plot

        # ------------------------------------------------------------------------
        # 3D latent space (4 images: 3D image, and 3 x 2D projections onto xy, xz
        # and yz)
        if np.size(z, axis=1) == 3:
            if idx == 0:  # initialize at first iteration
                plt.rcParams['figure.figsize'] = (5, 5)
                fig1 = plt.figure()
                plt.rcParams['figure.figsize'] = (15, 5)
                fig2 = plt.figure()
                ax1 = fig1.add_subplot(1, 1, 1, projection='3d')
                ax2 = fig2.add_subplot(1, 3, 1)
                ax3 = fig2.add_subplot(1, 3, 2)
                ax4 = fig2.add_subplot(1, 3, 3)
                ax1.grid(False)
                # Hide axes ticks
                ax1.set_xticks([])
                ax1.set_yticks([])
                ax1.set_zticks([])
                # set labels
                ax1.set_xlabel('dimension 1')
                ax1.set_ylabel('dimension 2')
                ax1.set_zlabel('dimension 3')
                ax2.set_xlabel('dimension 1')
                ax2.set_ylabel('dimension 2')
                ax3.set_xlabel('dimension 1')
                ax3.set_ylabel('dimension 3')
                ax4.set_xlabel('dimension 2')
                ax4.set_ylabel('dimension 3')
            ax1.scatter3D(z[:, 0], z[:, 1], z[:, 2], c=y, cmap='tab10');
            ax2.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10');
            ax3.scatter(z[:, 0], z[:, 2], c=y, cmap='tab10');
            ax4.scatter(z[:, 1], z[:, 2], c=y, cmap='tab10');
            if idx > num_batches:
                fig1.tight_layout()
                fig2.tight_layout()

        # Stop if we've reach the maximum number of batches
        if idx > num_batches:
            if np.size(z, axis=1) == 2:
                plt.colorbar()
            break

NameError: name 'device' is not defined

In [None]:
class EarlyStop:
    """Used to early stop the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=20, verbose=False, delta=0,
                 save_name="checkpoint.pt"):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 20
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            save_name (string): The filename with which the model and the optimizer is saved when improved.
                            Default: "checkpoint.pt"
        """
        self.patience = patience
        self.verbose = verbose
        self.save_name = save_name
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model, optimizer):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer)
            self.counter = 0

        return self.early_stop

    def save_checkpoint(self, val_loss, model, optimizer):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        state = {"net":model.state_dict(), "optimizer":optimizer.state_dict()}
        torch.save(state, self.save_name)
        self.val_loss_min = val_loss
        print("model saved")