In [3]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
#import emd
import os

from torch import nn
import torch
from IPython.display import clear_output

from makegif import make_gif
import copy
import time

torch.set_default_dtype(torch.float64)

In [4]:
plt.style.use('dark_background')

In [5]:
dataset = torch.tensor(np.random.randn(300,3))

### Some functions for visualization:

In [7]:
# Here are some functions for visualizing the output

def get_angle(v, w):
    '''get the angle between two vectors'''
    return v @ w / (torch.norm(v) * torch.norm(w))


def get_axis(M):
    '''Finds the eigenvector corresponding to the eigenvalue with the least 
    imaginary part, if the matrix is a rotation matrix or a generator of 
    rotations then this vector is the axis of rotation'''    

    eig_vals, eig_vecs = torch.linalg.eig(M)
    axis = eig_vecs.T[torch.argmin(torch.abs(eig_vals.imag))]
    return torch.sign(torch.sum(axis).real) * axis


def draw_vec(ax, v, color='C0', lw=4):
    '''Draw a vector to ax, this adds lines for the projection'''
    ax.plot([0,v[0]], [0,v[1]], [0,v[2]], color=color, lw=lw)

    ax.plot([v[0],v[0]], [v[1],v[1]], [0,v[2]], color='w', alpha=.25, ls='--')
    ax.plot([v[0],v[0]], [0   ,v[1]], [0,0   ], color='w', alpha=.25, ls='--')
    ax.plot([0   ,v[0]], [v[1],v[1]], [0,0   ], color='w', alpha=.25, ls='--')


def visualize_generators(params, eps=1e-3):
    '''Imshow each of the generators'''
    plt.figure(figsize=[12,4])
    
    for i, X in enumerate(params):
        plt.subplot(1,len(params),i+1)
        plt.imshow(X.detach().numpy(), cmap='RdBu')
        plt.title(f'det = {np.linalg.det(np.eye(X.shape[0]) + eps * X.detach().numpy())}')
        plt.colorbar()


def draw_4tych(params):
    
    fig, [ax1, ax2, ax3, ax4] = plt.subplots(nrows=1, ncols=4, figsize=[18,4])
    
    # draw each generator
    for i, [ax, X] in enumerate(zip([ax1, ax2, ax3], params)):
        
        ax.imshow(X.detach().numpy(), cmap='RdBu')
        ax.set_title(f'det = {np.linalg.det(np.eye(3) + X.detach().numpy())}')
        ax.axis('off')

    ax4 = fig.add_subplot(144, projection='3d')
    
    # draw coordinate axis
    ax_lim = 1
    ax4.plot([-ax_lim,ax_lim],[0,0],[0,0], color='w', alpha=.3)
    ax4.plot([0,0],[-ax_lim,ax_lim],[0,0], color='w', alpha=.3)
    ax4.plot([0,0],[0,0],[-ax_lim,ax_lim], color='w', alpha=.3)
    
    # draw each rotation axis
    for X, col in zip(params, ['C0', 'C1', 'C2']):
        draw_vec(ax4, get_axis(X.detach().numpy()+np.eye(3)), col)

    ax4.set_xlim(-ax_lim,ax_lim)
    ax4.set_ylim(-ax_lim,ax_lim)
    ax4.set_zlim(-ax_lim,ax_lim)
    ax4.grid(False)
    ax4.axis('off')
    
    ax4.set_title(f'Epoch {i}')

In [None]:
class MyModel(nn.Module):
    
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.X1 = nn.Linear(3,3, bias=False)
        self.X2 = nn.Linear(3,3, bias=False)
        self.X3 = nn.Linear(3,3, bias=False)
        
    def forward(self, inputs, epsilon):
        
        out1 = inputs[0] + epsilon * self.X1(inputs[0])
        out2 = inputs[1] + epsilon * self.X2(inputs[1])
        out3 = inputs[2] + epsilon * self.X3(inputs[2])
        
        return [out1, out2, out3]
    
model = MyModel()

In [12]:
class MyModel(nn.Module):
    
    def __init__(self, n_generators=3, n_dim=3):
        super(MyModel, self).__init__()
        
        self.n_generators = n_generators
        self.X_list = []
        
        for _ in range(n_generators):
            self.X_list.append(nn.Linear(n_dim, n_dim, bias=False))
        
    def forward(self, inputs, epsilon):
        
        outputs = []
        for i in range(self.n_generators):
            outputs.append(inputs[0] + epsilon * self.X_list[i](inp))
        
        return outputs
    
model = MyModel()

In [13]:
def bracket(A, B):
    
    return A @ B - B @ A


def single_loss(inp, pred, X, w_infinitesimal=1, eps=1e-3):
    
    dl = torch.mean((torch.norm(inp, dim=1) - torch.norm(pred, dim=1))**2)
    
    return dl + w_infinitesimal * (torch.sum(X**2) - eps**2)**2


def ensemble_loss(inp, preds, params, eps=1e-3):
    
    loss = 0
    for i, X in enumerate(params):
        
        for j, Y in enumerate(params):
            loss += get_angle(get_axis(X).real, 
                              get_axis(Y).real)**2
            
        loss += single_loss(inputs[0], preds[0], X1, eps)
     
    return loss

In [14]:
model = MyModel()

In [16]:
list(model.parameters())

[]