<a href="https://colab.research.google.com/github/TomFrederik/learning_training_dynamics/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Environment setup. The commented-out commands only need to run on a fresh
# runtime: they install torchdiffeq and clone the project repository.
#!pip install torchdiffeq # only run on first execution
#!git clone https://github.com/TomFrederik/learning_training_dynamics.git
# Change into the repository directory (this expects the clone to exist already).
%cd learning_training_dynamics
# Uncomment to update an existing clone.
#!git pull

/content/learning_training_dynamics


In [2]:
import torchdiffeq as teq
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os

import datasets

In [3]:
class MLP(nn.Module):
    '''
    Time-dependent vector field f(t, x) used as the right-hand side of the ODE.

    The scalar time t is prepended to the state x and the result is passed
    through a fixed tanh network with three hidden layers.

    Args:
        input_dim - size of the network input, i.e. state size + 1 (for time)
        hidden_dim - accepted for interface compatibility but currently NOT used:
                     the hidden width is fixed at 200. Left unchanged so that
                     checkpoints saved from this class keep loading.
        output_dim - size of the returned derivative (the state size)
    '''

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        width = 200  # fixed hidden width, see class docstring
        self.net = nn.Sequential(
            nn.Linear(input_dim, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, output_dim),
        )
        # Kept only for backward compatibility with code that reads it;
        # forward() takes device and dtype from its input instead, so the
        # module keeps working after .to(...) / .cpu() / .double().
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def forward(self, t, x):
        '''
        Evaluate the vector field.

        Args:
            t - scalar time (Python number or single-element tensor)
            x - state tensor of shape [..., state_dim]; leading batch
                dimensions are optional
        Returns:
            tensor of shape [..., output_dim]
        '''
        # Match the state's dtype/device: the solver may hand over a float64
        # time while the state (and the weights) are float32.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        # Broadcast the scalar time to one extra feature per state vector.
        t = t.reshape((1,) * x.dim()).expand(tuple(x.shape[:-1]) + (1,))
        return self.net(torch.cat([t, x], dim=-1))

In [4]:
def rms_norm(tensor):
    '''
    Root-mean-square over all elements of a tensor.
    Args:
        tensor - torch tensor of any shape
    Returns:
        0-dim torch tensor holding sqrt(mean(tensor ** 2))
    '''
    # No printing here: a norm is evaluated many times per solve, so a debug
    # print of the full tensor floods the notebook output.
    return tensor.pow(2).mean().sqrt()

def make_norm(state):
    '''
    Build a norm function over an augmented state vector.

    The returned function looks only at two consecutive slices of the
    augmented vector, each of length state.numel(): entries
    [1, 1 + n) (called y) and [1 + n, 1 + 2n) (called adj_y). Entry 0 and
    everything after the second slice are ignored.
    # NOTE(review): presumably entry 0 is time and the tail holds parameter
    # adjoints, as in torchdiffeq's adjoint state - confirm before use.

    Args:
        state - tensor whose number of elements defines the slice length n
    Returns:
        norm - callable mapping a 1-D augmented state to
               max(rms_norm(y), rms_norm(adj_y))
    '''
    state_size = state.numel()

    def norm(aug_state):
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        # Debug printing removed: each norm is now computed exactly once.
        return max(rms_norm(y), rms_norm(adj_y))

    return norm


In [39]:
# --- hyperparameters for fitting the dynamics model ---
hidden_dim = 100
train_steps = 1000
lr = 1e-3

# use the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# --- data: one item per batch (the held-out split is not loaded here) ---
data_dir = './data/mini_mnist'
train_dataset = datasets.MiniMNISTParams(data_dir)
train_loader = DataLoader(train_dataset, batch_size=1)

# The second entry of a dataset item is indexed as [time, state]:
# axis 1 gives the state size, axis 0 the number of integration time stamps.
input_dim = train_dataset[0][1].shape[1]
time_stamps = torch.arange(train_dataset[0][1].shape[0], dtype=torch.float64, device=device)

# --- model: the network sees the state plus one extra input for time ---
model_kwargs = dict(
    input_dim=input_dim + 1,
    hidden_dim=hidden_dim,  # is cur 100, should change to [20,20]
    output_dim=input_dim,
)
model = MLP(**model_kwargs).to(device)

# --- optimisation ---
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fct = nn.MSELoss()

In [40]:
# training: fit the dynamics model so that integrating it from y_0 reproduces
# the recorded trajectory y at all time stamps
step = 0
epoch = 0
while step < train_steps:

    epoch += 1
    print(f'\nStarting epoch {epoch}')
    model.train()
    steps_at_epoch_start = step
    for y_0, y in train_loader:
        # Enforce the step budget inside the epoch as well; checking only in
        # the while-condition would overshoot train_steps by up to one epoch.
        if step >= train_steps:
            break
        step += 1

        # drop the singleton batch dimension added by the DataLoader
        y_0 = y_0.squeeze().to(device)
        y = y.squeeze().to(device)

        # train step
        optimizer.zero_grad()
        pred = teq.odeint_adjoint(model, y_0, time_stamps, adjoint_options=dict(norm='seminorm'))
        loss = loss_fct(pred, y)
        print(f'Loss at step {step} is {loss.item():1.3f}')
        loss.backward()
        optimizer.step()

    # An empty loader would otherwise spin in the while-loop forever.
    if step == steps_at_epoch_start:
        raise RuntimeError('train_loader produced no batches; cannot train')

    # eval step
    #model.eval()


Starting epoch 1
Loss at step 1 is 0.883
Loss at step 2 is 4.949
Loss at step 3 is 3.521
Loss at step 4 is 2.485
Loss at step 5 is 1.925
Loss at step 6 is 1.539
Loss at step 7 is 1.142
Loss at step 8 is 0.869
Loss at step 9 is 0.876
Loss at step 10 is 0.808
Loss at step 11 is 0.755
Loss at step 12 is 0.679
Loss at step 13 is 0.648
Loss at step 14 is 0.640
Loss at step 15 is 0.607

Starting epoch 2
Loss at step 16 is 0.558
Loss at step 17 is 0.590
Loss at step 18 is 0.536
Loss at step 19 is 0.492
Loss at step 20 is 0.466
Loss at step 21 is 0.427
Loss at step 22 is 0.410
Loss at step 23 is 0.399
Loss at step 24 is 0.379
Loss at step 25 is 0.347
Loss at step 26 is 0.331
Loss at step 27 is 0.325
Loss at step 28 is 0.314
Loss at step 29 is 0.302
Loss at step 30 is 0.287

Starting epoch 3
Loss at step 31 is 0.273
Loss at step 32 is 0.271
Loss at step 33 is 0.263
Loss at step 34 is 0.271
Loss at step 35 is 0.260
Loss at step 36 is 0.247
Loss at step 37 is 0.234
Loss at step 38 is 0.224
Loss 

KeyboardInterrupt: ignored

In [41]:
# write the dynamics model's weights (state dict only) to disk
torch.save(obj=model.state_dict(), f='./trained_model.pt')

In [48]:
# base model class
class BaseMLP(nn.Module):
    '''
    Plain fully connected ReLU network whose parameters are the objects
    being evolved/trained in this notebook.

    Args:
        input_dim - number of input features
        hidden_dim - sequence of hidden layer widths, e.g. [20, 20]; a single
                     int is treated as one hidden layer and an empty sequence
                     yields a purely linear model
        output_dim - number of outputs
    '''

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim]

        # Layer widths from input to last hidden layer; consecutive pairs
        # define the Linear+ReLU stages. Layers are created in the same order
        # as before, so parameter order and state-dict keys are unchanged.
        widths = [input_dim, *hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], output_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        '''Apply the network to x of shape [..., input_dim].'''
        return self.net(x)

In [49]:
def get_params(model):
    '''
    Flattens all parameters of a model into one Python list
    Args:
        model - instance of torch.nn.Module
    Returns:
        params - flat list of length n_params, ordered as model.parameters()
    '''
    return [
        value
        for tensor in model.parameters()
        for value in tensor.detach().reshape(-1).tolist()
    ]

In [50]:
def params2weights(param_vector, param_ranges):
    '''
    Cuts a flat parameter vector into consecutive per-tensor pieces
    Args:
        param_vector - torch tensor of shape [n_params]
        param_ranges - list with the number of entries belonging to each weight matrix or bias vector
    Returns:
        weights - list of 1-D torch tensors, one per entry of param_ranges: [w1, b1, w2, b2, ...]
                  (still flat; reshaping is left to the caller)
    '''
    pieces = param_vector.split(param_ranges)
    return [piece for piece in pieces]


In [66]:
# initialize random model
random_model = BaseMLP(input_dim=28*28, hidden_dim=[20,20], output_dim=1).to(device)

# load data
# NOTE: this rebinds data_dir, train_dataset, train_loader and loss_fct, which
# an earlier cell also defines for the dynamics-model training.
data_dir = './data/mini_mnist'

train_dataset = datasets.MiniMNIST(data_dir, flatten=True)
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=len(train_dataset))

test_dataset = datasets.MiniMNIST(data_dir, flatten=True, train=False)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=len(test_dataset))

# eval func
loss_fct = nn.BCEWithLogitsLoss()


def evaluate_first_batch(model, loader, loss_fct, device):
    '''
    Evaluate a binary classifier on the first batch of a loader.
    Args:
        model - torch.nn.Module producing one logit per sample
        loader - DataLoader yielding (data, labels); only its first batch is used
        loss_fct - loss taking (logits, labels)
        device - device the batch is moved to
    Returns:
        (data, labels, logits, loss, acc) where loss is a Python float and
        acc is a 0-dim tensor with the fraction of correct predictions
    '''
    model.eval()
    data, labels = next(iter(loader))
    data = data.to(device)
    labels = labels.to(device)
    # pure evaluation: do not build an autograd graph
    with torch.no_grad():
        logits = model(data).squeeze()
        loss = loss_fct(logits, labels).item()
        acc = (torch.round(torch.sigmoid(logits)) == labels).sum() / len(labels)
    return data, labels, logits, loss, acc


# evaluate model on train data
train_data, train_labels, train_logits, train_loss, train_acc = evaluate_first_batch(
    random_model, train_loader, loss_fct, device)
print(f'Train loss at init is {train_loss:1.5f}')
print(f'Train accuracy at init is {train_acc:1.5f}')


# evaluate model on test data
test_data, test_labels, test_logits, test_loss, test_acc = evaluate_first_batch(
    random_model, test_loader, loss_fct, device)
print(f'\nTest loss at init is {test_loss:1.5f}')
print(f'Test accuracy at init is {test_acc:1.5f}')

Train loss at init is 2.30721
Train accuracy at init is 0.50000

Test loss at init is 2.04358
Test accuracy at init is 0.52500


In [67]:


####
# Evolve using learned dynamics
####


# get parameters of random model as one flat vector
params = torch.tensor(get_params(random_model), device=device)

# use learned dynamics model to evolve the parameters
# (forward rollout only, so no autograd graph is needed)
with torch.no_grad():
    evolved_params = teq.odeint_adjoint(model, params, time_stamps, adjoint_options={'norm': 'seminorm'})
params_at_25 = evolved_params[-1]
# NOTE: despite its name this is the state at index 25, not 8. The name is
# kept so that any later cell referring to it keeps working.
params_at_8 = evolved_params[25]

# sizes and shapes of the base model's tensors, in parameters() order
param_ranges = []
param_shapes = []
for param in random_model.parameters():
    param_ranges.append(param.numel())
    param_shapes.append(param.shape)

# cut the flat vector into per-tensor pieces and restore their shapes
sliced_params = [
    chunk.reshape(shape)
    for chunk, shape in zip(params2weights(params_at_8, param_ranges), param_shapes)
]

# init a new model
evolved_model = BaseMLP(input_dim=28*28, hidden_dim=[20,20], output_dim=1).to(device)

# Copy the evolved values into the new model. copy_ checks shapes and leaves
# the model with its own storage instead of aliasing slices of evolved_params.
with torch.no_grad():
    for old_param, new_param in zip(evolved_model.parameters(), sliced_params):
        old_param.copy_(new_param)

# eval evolved model on the first test batch
evolved_model.eval()
test_data, test_labels = next(iter(test_loader))
test_data = test_data.to(device)
test_labels = test_labels.to(device)
with torch.no_grad():
    test_logits = evolved_model(test_data).squeeze()
    test_loss = loss_fct(test_logits, test_labels).item()
    test_acc = (torch.round(torch.sigmoid(test_logits)) == test_labels).sum() / len(test_labels)
print(f'\nTest loss after evolving is {test_loss:1.5f}')
print(f'Test accuracy after evolving is {test_acc:1.5f}')
print(f'Test logits after evolving {test_logits}')
print(f'Test probs after evolving {torch.sigmoid(test_logits)}')

####
# train in normal way
####

def train_model(model, opt_class, opt_kwargs, loss_fct, train_loader, test_loader, train_steps):
    '''
    Trains a model with a standard optimizer and records its parameter trajectory
    Args:
        model - instance of torch.nn.Module producing one logit per sample
        opt_class - optimizer class, called as opt_class(model.parameters(), **opt_kwargs)
        opt_kwargs - dict of keyword arguments for the optimizer
        loss_fct - loss taking (logits, labels)
        train_loader - DataLoader yielding (data, labels); batches are consumed
                       in order and the loader is restarted when exhausted
        test_loader - DataLoader yielding (data, labels); its first batch is
                      evaluated after every step
        train_steps - number of optimizer steps to take
    Returns:
        params - list of train_steps + 1 flat parameter lists (initial
                 parameters first, then one entry per step)
    '''
    # Follow the model's own device instead of relying on a notebook global.
    model_device = next(model.parameters()).device

    # init param trajectory
    params = [get_params(model)]

    # set up optimizer
    optimizer = opt_class(model.parameters(), **opt_kwargs)

    # One persistent iterator: calling next(iter(loader)) every step would
    # rebuild the iterator each time and only ever return the first batch.
    train_iter = iter(train_loader)

    for step in range(train_steps):

        # fetch the next training batch, restarting the loader when it runs out
        try:
            data, labels = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            data, labels = next(train_iter)

        # one training step
        model.train()
        data = data.to(model_device)
        labels = labels.to(model_device)
        logits = model(data).squeeze()
        loss = loss_fct(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # store new params
        params.append(get_params(model))

        # eval on the first test batch, without building an autograd graph
        model.eval()
        with torch.no_grad():
            test_data, test_labels = next(iter(test_loader))
            test_data = test_data.to(model_device)
            test_labels = test_labels.to(model_device)
            test_logits = model(test_data).squeeze()
            test_loss = loss_fct(test_logits, test_labels)
            test_acc = torch.sum(torch.round(torch.sigmoid(test_logits)) == test_labels) / len(test_labels)

        print(f'Step {step+1}: train loss = {loss.item():1.5f}, test loss = {test_loss.item():1.5f}, test acc = {test_acc.item():1.4f}')

    return params

# Hyperparameters for the reference training run. They must be assigned
# BEFORE opt_kwargs is built: otherwise the optimizer silently picks up
# whatever value `lr` had from an earlier cell instead of the one set here.
# NOTE: this rebinds the notebook-level train_steps, lr and loss_fct.
train_steps = 25
lr = 1e-4

# optimizer
opt_class = torch.optim.Adam
opt_kwargs = {'lr':lr}

# loss function
loss_fct = torch.nn.BCEWithLogitsLoss()

train_kwargs = {'opt_class':opt_class,
                'opt_kwargs':opt_kwargs, 
                'loss_fct':loss_fct, 
                'train_loader':train_loader, 
                'test_loader':test_loader, 
                'train_steps':train_steps
                }

# train the model; the recorded trajectory has train_steps + 1 rows and is
# viewed in the shape of evolved_params, so the two must have equally many entries
trained_params = torch.tensor(train_model(random_model, **train_kwargs), device=device).view(evolved_params.shape)

# compare evolved with trained: mean squared parameter error per time step
mse_loss = nn.MSELoss(reduction='none')
evol_loss = mse_loss(trained_params, evolved_params).mean(dim=1)
print(evol_loss)


Test loss after evolving is 23.34353
Test accuracy after evolving is 0.50000
Test logits after evolving tensor([ -91.3628,  -63.6737,  -84.7311,  -81.5237,  -55.1757,  -59.5826,
         -59.3035, -101.9971,  -51.6361, -108.9850,  -18.5841, -115.8868,
         -55.6021,  -53.9199,  -80.9800,   -0.4303,  -14.1597,  -15.8407,
         -88.3558,  -38.3715,  -42.9586,  -45.3851,  -67.4790,   -7.7007,
         -36.7669,  -39.1791,  -19.7898,  -28.4659,  -18.9605,  -58.8635,
         -47.1996,  -38.9131,  -62.8426,  -98.1344,  -48.8066,  -52.9295,
         -43.5422,  -39.6262,  -56.8791,  -78.8172], device='cuda:0',
       grad_fn=<SqueezeBackward0>)
Test probs after evolving tensor([0.0000e+00, 2.2225e-28, 1.5913e-37, 3.9328e-36, 1.0901e-24, 1.3292e-26,
        1.7572e-26, 0.0000e+00, 3.7561e-23, 0.0000e+00, 8.4922e-09, 0.0000e+00,
        7.1176e-25, 3.8270e-24, 6.7737e-36, 3.9406e-01, 7.0879e-07, 1.3197e-07,
        4.2418e-39, 2.1650e-17, 2.2046e-19, 1.9475e-20, 4.9458e-30, 4.5230e-04,
