In [None]:
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Circle
import numpy as np
import torch
from model import BioMLP
import torch.nn as nn
import os

# Quick look at the two-moons dataset: noisy samples coloured by class label.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

n_sample = 500
X, y = make_moons(n_samples=n_sample, noise=0.1)
color = ['red', 'blue']  # colour for class 0 and class 1

# One vectorized scatter call instead of one call (and one artist) per point.
plt.scatter(X[:, 0], X[:, 1], c=np.array(color)[y])
plt.title(f"make_moons(n_samples={n_sample}, noise=0.1)")
plt.xlabel(r"$x_1$")
plt.ylabel(r"$x_2$")
plt.show()

In [7]:
# Experiment configuration: logging cadence, RNG seeds, train/test data, architecture.
phase = "toy_swap"

plot_log = 50   # draw the network + decision boundary every `plot_log` steps
save_log = 100  # checkpoint the model every `save_log` steps
steps = 10001   # number of optimizer steps

# Seed first so the datasets below are identical no matter which cells ran before
# (make_moons is called without random_state, so it draws from NumPy's global RNG).
# Nothing in this cell draws from torch's RNG, so model initialization still starts
# from torch seed 2.
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)

# create dataset
d_in = 2
d_out = 2

n_sample = 200  # size of both the train and the test split
X, y = make_moons(n_samples=n_sample, noise=0.1)
# Inputs are plain data: requires_grad=True only made every backward pass compute
# and accumulate an unused X.grad.
X = torch.tensor(X, dtype=torch.float)
y = torch.tensor(y, dtype=torch.long)

X_test, y_test = make_moons(n_samples=n_sample, noise=0.1)
X_test = torch.tensor(X_test, dtype=torch.float)
y_test = torch.tensor(y_test, dtype=torch.long)

# Network shape: d_in -> (depth - 1) hidden layers of `width` neurons -> d_out,
# i.e. `depth` weight matrices. With the values below this is [2, 20, 20, 2].
width = 20
depth = 3
shp = [d_in] + [width] * (depth - 1) + [d_out]

<torch._C.Generator at 0x7f7860a1eb30>

In [None]:
# Create a fresh numbered run directory for checkpoints, plus the figures directory.
os.makedirs("./saved_models", exist_ok=True)  # os.listdir below fails if it is missing
os.makedirs("./figures", exist_ok=True)       # training frames are saved under figures/
# Start from the number of existing entries (the original numbering scheme) but skip
# past any index that is already taken, so os.mkdir never raises FileExistsError.
indx = len(os.listdir("./saved_models"))
while os.path.exists(f"./saved_models/{indx}"):
    indx += 1
os.mkdir(f"./saved_models/{indx}")

In [None]:
model = BioMLP(shp=shp)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.0)
log = 200       # print metrics every `log` steps
lamb = 0.001    # weight of the connection-cost regularizer (model.get_cc)
swap_log = 200  # call model.relocate() every `swap_log` steps
CEL = nn.CrossEntropyLoss()  # stateless, so build it once rather than every step
os.makedirs("figures", exist_ok=True)  # plt.savefig below fails if the directory is missing


def _draw_network(model, shp, step):
    """Draw the network on the current axes.

    Neurons are black ellipses, one horizontal row per entry of `shp` (row r at
    y = 0.1 * r). Each 2-D parameter is drawn as lines between consecutive rows,
    with line width |w| / max|w| and colour blue for w > 0, red otherwise.
    Input neurons are labelled via `model.in_perm`, outputs via `model.out_perm`.
    """
    n_rows = len(shp)
    s = 1 / (2 * max(shp))
    for row, n_neurons in enumerate(shp):
        for k in range(n_neurons):
            neuron = Ellipse((1/(2*n_neurons) + k/n_neurons, 0.1*row),
                             s, s/10*((n_rows-1)+0.4), color='black')
            plt.gca().add_patch(neuron)

    plt.ylim(-0.02, 0.1*(n_rows-1)+0.02)
    plt.xlim(-0.02, 1.02)

    formulas = ["Class 1", "Class 2"]
    ii = 0        # index of the weight layer being drawn (= its lower neuron row)
    p_shp = None  # shape of the last 2-D parameter, used for the output labels
    for p in model.parameters():
        if p.dim() != 2:
            continue  # only weight matrices are drawn
        p_shp = p.shape
        # Normalize once and convert to numpy once, instead of per element.
        w = (p / torch.abs(p).max()).detach().cpu().numpy()
        for i in range(p_shp[0]):          # rows: neurons in layer ii + 1
            x_top = 1/(2*p_shp[0]) + i/p_shp[0]
            for j in range(p_shp[1]):      # columns: neurons in layer ii
                x_bottom = 1/(2*p_shp[1]) + j/p_shp[1]
                plt.plot([x_top, x_bottom], [0.1*(ii+1), 0.1*ii],
                         lw=abs(w[i, j]), color="blue" if w[i, j] > 0 else "red")
        if ii == 0:
            for j in range(p_shp[1]):
                label = int(model.in_perm[j]) + 1
                plt.text(1/(2*p_shp[1]) + j/p_shp[1] - 0.05, 0.1*ii - 0.04,
                         f"$x_{{{label}}}$", fontsize=15)
        ii += 1

    if p_shp is not None:
        for j in range(p_shp[0]):
            plt.text(1/(2*p_shp[0]) + j/p_shp[0] - 0.15, 0.1*ii + 0.02,
                     formulas[int(model.out_perm[j])], fontsize=15)

    plt.gca().axis('off')
    plt.title("step={}".format(step), fontsize=15, y=1.1)


def _draw_decision_boundary(model, X, y, n_values=30):
    """Shade sign(logit[1] - logit[0]) on an n_values x n_values grid around X, then scatter (X, y)."""
    X_np = X.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()
    start_x, start_y = X_np.min(axis=0) - 0.1
    end_x, end_y = X_np.max(axis=0) + 0.1

    XX, YY = np.meshgrid(np.linspace(start_x, end_x, n_values),
                         np.linspace(start_y, end_y, n_values))
    # (n_values**2, 2) grid of points: column 0 is x_1, column 1 is x_2.
    grid = torch.tensor(np.stack([XX.ravel(), YY.ravel()], axis=1), dtype=torch.float)
    with torch.no_grad():  # visualization only: no autograd graph needed
        logits = model(grid)
    margin = (logits[:, 1] - logits[:, 0]).reshape(n_values, n_values).cpu().numpy()

    plt.contourf(XX, YY, margin, [-100, 0., 100.], colors=["green", "orange"], alpha=0.2)
    # One vectorized scatter over all of X (not a per-point loop over an unrelated global).
    plt.scatter(X_np[:, 0], X_np[:, 1], c=np.array(['green', 'orange'])[y_np])

    plt.xticks([])
    plt.yticks([])
    plt.xlabel(r"$x_1$", fontsize=15)
    plt.ylabel(r"$x_2$", fontsize=15)


for step in range(steps):

    # Regularization schedule: small lambda first, then large lambda, then small again.
    # NOTE: the switch back at step 15000 only takes effect if `steps` > 15000.
    if step == 5000:
        lamb = 0.01

    if step == 15000:
        lamb = 0.001

    optimizer.zero_grad()
    pred = model(X)
    loss = CEL(pred, y)
    acc = torch.mean((torch.argmax(pred, dim=1) == y).float())

    # Test metrics are for monitoring only, so skip building an autograd graph.
    with torch.no_grad():
        pred_test = model(X_test)
        loss_test = CEL(pred_test, y_test)
        acc_test = torch.mean((torch.argmax(pred_test, dim=1) == y_test).float())

    reg = model.get_cc(weight_factor=1.)
    total_loss = loss + lamb*reg
    total_loss.backward()
    optimizer.step()

    if step % log == 0:
        print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | train acc : %.2f | test acc : %.2f | reg: %.2e "%(step, total_loss.item(), loss.item(), loss_test.item(), acc.item(), acc_test.item(), reg.item()))

    if (step+1) % swap_log == 0:
        model.relocate()

    if step % save_log == 0:
        torch.save(model.state_dict(), f"saved_models/{indx}/model_{step}.pt")

    if step % plot_log == 0:
        fig = plt.figure(figsize=(3, 7))
        plt.subplot(2, 1, 1)
        _draw_network(model, shp, step)
        plt.subplot(2, 1, 2)
        _draw_decision_boundary(model, X, y)
        plt.savefig("figures/{0:05d}.png".format(step))
        plt.show()
        plt.close(fig)  # avoid accumulating ~200 open figures outside the inline backend
    