In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as td
from typing import Optional
import matplotlib.pyplot as plt
import math
import seaborn as sns
sns.set()
%load_ext autoreload
%autoreload 2
%matplotlib notebook

# Model environment

### Questions
1. **Data normalization**: Should the statistics be computed across each batch, or computed once on the large tensor sampled when solving least squares and stored as an attribute? **BATCH**
3. **Smoothness**: In this first implementation, the changes in X0 and X1 are very abrupt. Should we make them smoother? For example, we could make the means follow 2D random walks and let the variances vary according to a smooth function.
4. **Experiment interpretability**: Is there really a relation between `env.w` and `adam_w`, considering that one is for least squares `w.T@x` and the other for logistic regression `sigmoid(w.T@x)`?
    1. We probably should not use logistic regression but linear regression, for comparison's sake. However, linear regression is known to be a bad model for classification ...
    2. We may want to do a third experiment in which we use a neural net in higher dimension. There we could compare the SGD loss and the losses of different flavours of Adam. We can also analyze the evolution of the momentum and therefore analyze how Adam adapts to non-stationarity.
->> We should probably not consider the least squares solution and just monitor the loss of the different optimizers. This gives more freedom regarding the environment.

### Todo
1. **seed**: add a seed to control pytorch random states and be able to replicate experiments !
2. **distribution plot**: add `env.w`, `adam_model.w`, `sgd_model.w` (and then others)
3. **weight error plot**: do it

In [2]:
class Env2(object):
    """Experiment2 environment: non-stationary two-class Gaussian data.

    Two diagonal Gaussians, X0 (label 0) and X1 (label 1), are drawn at
    random. Every ``tau`` calls to ``get_batch`` the environment resets:
    a new lifetime ``tau`` is sampled, both Gaussians are re-drawn and the
    least squares solution ``w`` is recomputed.

    Arguments
    ---------
    d: int
        dimensionality of the environment

    lambda_: float
        parameter (rate) of the lifetime exponential distribution

    batch_size: int
        batch size

    N: int
        maximum number of samples to use to solve the least squares problem.
        Must be at least 2 (N//2 samples are drawn from each class).

    p: torch.Tensor
        parameter of the bernoulli label distribution. Must be of size (1).
        Defaults to ``torch.tensor([0.5])`` when omitted or ``None``.

    Attributes
    ----------
    Y: torch.distributions.distribution.Distribution
        labels distribution

    X0: torch.distributions.distribution.Distribution
        negative samples distribution

    X1: torch.distributions.distribution.Distribution
        positive samples distribution

    T: torch.distributions.distribution.Distribution
        lifetime of X0 and X1 distribution

    counter: int
        counts calls to get_batch since the last reset

    tau: int
        lifetime of X0 and X1 (number of get_batch calls, always >= 1)

    mu0, sigma0: torch.Tensor, torch.Tensor
        parameters (mean, scale) of Gaussian distribution X0

    mu1, sigma1: torch.Tensor, torch.Tensor
        parameters (mean, scale) of Gaussian distribution X1

    w: torch.Tensor
        current least squares solution, shape (d+1, 1); the first entry
        multiplies the constant bias feature

    Raises
    ------
    ValueError
        if N < 2, since no sample could be drawn for the least squares fit
    """
    def __init__(self, d:int, lambda_:float, batch_size: int, N: int, 
                 p:Optional[torch.Tensor]=None):
        if N < 2:
            raise ValueError(f"N must be at least 2, got {N}")
        # A None sentinel avoids sharing one default tensor between instances
        if p is None:
            p = torch.tensor([0.5])
        # Attributes
        self.d = d
        self.lambda_ = lambda_
        self.batch_size = batch_size
        self.N = N
        # label random variable
        self.Y = td.bernoulli.Bernoulli(p)
        # data random variables
        self.build_data()
        self.compute_min()
        # lifetime random variable
        self.T = td.exponential.Exponential(lambda_)
        self.reset_lifetime()
        # counter (for lifetime)
        self.counter = 0
    
    def reset(self):
        """reset environment internals"""
        self.counter = 0
        self.reset_lifetime()
        self.build_data()
        self.compute_min()
        
    def reset_lifetime(self):
        """sample lifetime"""
        # Clamp to 1: a lifetime of 0 could never be matched by the counter,
        # which would freeze the environment forever.
        self.tau = max(1, math.ceil(self.T.sample().item()))
        
    def build_data(self):
        """Build two normal distributions"""
        mus = torch.randn(self.d, 2)
        # Squared standard normal draws, floored at 1 (magic number) so that
        # no dimension has a very small scale
        sigmas = torch.randn(self.d, 2).pow(2).clamp(min=1.0)
        self.mu0, self.mu1 = mus[:,0], mus[:,1]
        self.sigma0, self.sigma1 = sigmas[:,0], sigmas[:,1]
        self.X0 = td.normal.Normal(loc=self.mu0, scale=self.sigma0)
        self.X1 = td.normal.Normal(loc=self.mu1, scale=self.sigma1)
        
    def compute_min(self):
        """solve least squares"""
        # Sample data: N//2 points per class
        half = self.N//2
        x0 = self.X0.sample((half,))
        x1 = self.X1.sample((half,))
        X = torch.cat([x0, x1], dim=0)
        # Targets as a column so the solver receives a 2-D right-hand side
        Y = torch.cat([
            torch.zeros((half,1)), torch.ones((half,1))
        ], dim=0)
        # Add bias dimension
        X = torch.cat([
            torch.ones((X.shape[0],1)), X
        ], dim=1)
        # Shuffle; index over the rows actually present (2*half), which is
        # N-1 when N is odd
        idx = torch.randperm(X.shape[0])
        X = X[idx]
        Y = Y[idx]
        # solve least squares
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
            w = torch.linalg.lstsq(X, Y).solution
        else:
            # Legacy API for PyTorch versions without torch.linalg.lstsq
            w, _ = torch.lstsq(Y, X)
        self.w = w[:self.d+1]
    
    def get_batch(self, bias:bool=False):
        """produce a batch of data and labels

        Arguments
        ---------
        bias: bool
            if True, a column of ones is prepended to the data

        Returns
        -------
        X: torch.Tensor
            data of shape (batch_size, d), or (batch_size, d+1) with bias
        Y: torch.Tensor
            labels (0. or 1.) of shape (batch_size, 1)
        """
        # Sample from Y, get the number of 0s and 1s.
        # Count the zeros: the size of the comparison mask is always
        # batch_size and says nothing about the labels drawn.
        labels = self.Y.sample((self.batch_size,))
        n0 = int((labels==0).sum().item())
        n1 = self.batch_size-n0
        # Sample data points
        x0 = self.X0.sample((n0,))
        x1 = self.X1.sample((n1,))
        X = torch.cat([x0, x1], dim=0)
        # Add bias dimension
        if bias:
            X = torch.cat([
                torch.ones((X.shape[0],1)), X
            ], dim=1)
        # Build labels
        Y = torch.cat([
            torch.zeros((n0,1)), torch.ones((n1,1))
        ], dim=0)
        # Shuffle
        idx = torch.randperm(self.batch_size)
        X = X[idx]
        Y = Y[idx]
        # increment lifetime counter; once the lifetime is reached the
        # distributions are re-drawn (the batch returned below still comes
        # from the previous distributions)
        self.counter += 1
        if self.counter==self.tau:
            self.reset()
            
        return X, Y

# Instantiate environment

In [3]:
class Options:
    """Plain container holding every tunable parameter of the notebook."""
    def __init__(self):
        # --- environment settings ---
        self.d = 100                    # data dimensionality
        self.lambda_ = 0.01             # rate of the lifetime distribution
        self.batch_size = 128
        self.N = 10_000                 # samples for the least squares fit
        self.p = torch.Tensor([0.5])    # Bernoulli parameter of the labels
        self.bias_in_batch = False      # prepend a ones column to batches
        # --- training settings ---
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.lr = 0.001


opt = Options()

In [4]:
# Build the notebook-level environment from the shared options
env = Env2(
    d=opt.d,
    lambda_=opt.lambda_,
    batch_size=opt.batch_size,
    N=opt.N,
    p=opt.p,
)

# Model: logistic regression

In [5]:
class LinearModel(nn.Module):
    """Single linear layer mapping ``d`` input features to one output.

    ``forward`` returns the raw layer output; no activation is applied here.

    Arguments
    ---------
    d: number of input features
    bias: whether the layer learns an additive bias term
    """
    def __init__(self, d, bias):
        super(LinearModel, self).__init__()
        self.f = nn.Linear(d, 1, bias=bias)

    def forward(self, x):
        out = self.f(x)
        return out

In [6]:
# When the environment already prepends a ones column to each batch, the
# model takes one extra input and does not learn its own bias term.
if opt.bias_in_batch:
    d_, bias_ = opt.d + 1, False
else:
    d_, bias_ = opt.d, True

# One independently initialised model per optimizer
adam_model = LinearModel(d_, bias_)
adam_model = adam_model.to(opt.device)
sgd_model = LinearModel(d_, bias_)
sgd_model = sgd_model.to(opt.device)

In [7]:
# Loss: binary cross-entropy computed directly on the models' raw outputs
# (logits). Alternative kept for reference: nn.MSELoss().to(opt.device)
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(opt.device)

# Optimizers: Adam and SGD

In [8]:
# One optimizer per model, both using the same learning rate
adam = optim.Adam(params=adam_model.parameters(), lr=opt.lr)
sgd = optim.SGD(params=sgd_model.parameters(), lr=opt.lr)

# Run experiment

In [17]:
def _redraw_experiment(fig, axes, d, adam_losses, sgd_losses,
                       avg_history, avg_sq_history):
    """Redraw the loss panel and the two Adam-moment panels from the history."""
    ax_l, ax_avg, ax_avg_sq = axes
    steps = np.arange(len(adam_losses))
    # plot losses
    ax_l.clear()
    ax_l.plot(steps, adam_losses, label='adam')
    ax_l.plot(steps, sgd_losses, label='sgd')
    ax_l.set_title(f'loss in dimension {d}')
    ax_l.set_xlabel('Number of batches')
    ax_l.set_ylabel('BCE')
    ax_l.legend()
    # plot Adam exp_avg and exp_avg_sq, one curve per weight
    panels = (
        (ax_avg, avg_history,
         f'Evolution of avg in dimension {d}', 'exponential average'),
        (ax_avg_sq, avg_sq_history,
         f'Evolution of avg_sq in dimension {d}', 'exponential average square'),
    )
    for ax, history, title, ylabel in panels:
        # Clear first: the full history is re-plotted on every redraw, so
        # without clearing the old curves would pile up under the new ones.
        ax.clear()
        curves = np.stack(history, axis=1)  # (n_weights, n_steps)
        for k, curve in enumerate(curves):
            ax.plot(steps, curve, label=f'dim({k})')
        ax.set_title(title)
        ax.set_xlabel('Number of batches')
        ax.set_ylabel(ylabel)
    # Draw
    fig.canvas.draw()


def experiment_dim(d, n_steps=4000, seed=None):
    """Train Adam and SGD linear models on a fresh ``Env2`` and plot the run.

    Both models see the same batches. The figure shows the two losses and,
    for the weight vector of the Adam model, the evolution of Adam's
    ``exp_avg`` and ``exp_avg_sq`` state. The figure is redrawn each time
    the environment resets its distributions, and once more at the end.

    Relies on the notebook globals ``opt``, ``Env2`` and ``LinearModel``.

    Arguments
    ---------
    d: int
        dimensionality of the environment
    n_steps: int
        number of batches to train on (default 4000)
    seed: Optional[int]
        if given, passed to ``torch.manual_seed`` before anything is sampled
        so the run can be replicated
    """
    if seed is not None:
        torch.manual_seed(seed)
    # Create models; the input is one wider when batches carry a bias column
    d_ = d + 1 if opt.bias_in_batch else d
    bias_ = not opt.bias_in_batch
    adam_model = LinearModel(d_, bias_).to(opt.device)
    sgd_model = LinearModel(d_, bias_).to(opt.device)
    # Create loss and optimizers
    criterion = nn.BCEWithLogitsLoss().to(opt.device)
    adam = optim.Adam(adam_model.parameters(), lr=opt.lr)
    sgd = optim.SGD(sgd_model.parameters(), lr=opt.lr)
    # Reset env
    env = Env2(d, opt.lambda_, opt.batch_size, opt.N, opt.p)
    # Init plots
    plt.ion()
    plt.rcParams['figure.figsize'] = (10,15)
    fig = plt.figure()
    axes = (fig.add_subplot(311), fig.add_subplot(312), fig.add_subplot(313))
    fig.canvas.draw()
    # Histories are kept as lists and stacked only when plotting, instead of
    # re-concatenating an ever growing array at every step.
    adam_losses = []
    sgd_losses = []
    avg_history = []
    avg_sq_history = []
    weight = adam_model.f.weight
    last_drawn = 0
    for gen in range(1, n_steps + 1):
        # Get batch
        X, Y = env.get_batch(opt.bias_in_batch)
        X = X.to(opt.device)
        Y = Y.to(opt.device)
        # Adam model
        loss = criterion(adam_model(X), Y)
        adam.zero_grad()
        loss.backward()
        adam.step()
        adam_losses.append(loss.item())
        # Sgd model
        loss = criterion(sgd_model(X), Y)
        sgd.zero_grad()
        loss.backward()
        sgd.step()
        sgd_losses.append(loss.item())
        # Record Adam's moment estimates for the weight vector. Copy them:
        # on CPU, .numpy() shares memory with the optimizer state, which
        # Adam updates in place at the next step.
        state = adam.state[weight]
        avg_history.append(
            state['exp_avg'].detach().cpu().numpy().reshape(-1).copy())
        avg_sq_history.append(
            state['exp_avg_sq'].detach().cpu().numpy().reshape(-1).copy())
        # Plots when there are changes (the env counter is back to 0 right
        # after the distributions were re-drawn)
        if env.counter == 0:
            _redraw_experiment(fig, axes, d, adam_losses, sgd_losses,
                               avg_history, avg_sq_history)
            last_drawn = gen
    # Make sure the figure also shows the batches seen after the last reset
    if n_steps > 0 and last_drawn != n_steps:
        _redraw_experiment(fig, axes, d, adam_losses, sgd_losses,
                           avg_history, avg_sq_history)

In [18]:
# Run the experiment in dimension 2
experiment_dim(d=2)

<IPython.core.display.Javascript object>

In [19]:
# Run the experiment in dimension 4
experiment_dim(d=4)

<IPython.core.display.Javascript object>

In [20]:
# Run the experiment in dimension 6
experiment_dim(d=6)

<IPython.core.display.Javascript object>

In [21]:
# Run the experiment in dimension 10
experiment_dim(d=10)

<IPython.core.display.Javascript object>

In [22]:
# Run the experiment in dimension 20
experiment_dim(d=20)

<IPython.core.display.Javascript object>