<a href="https://colab.research.google.com/github/TomFrederik/learning_training_dynamics/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torchdiffeq # only run on first execution
!git clone https://github.com/TomFrederik/learning_training_dynamics.git
%cd learning_training_dynamics
#!git pull

Collecting torchdiffeq
  Downloading https://files.pythonhosted.org/packages/63/c2/daf5cc6c548f789d0f5222a6daecb8a76d72ad2fa96d958d46cb85f7ae3a/torchdiffeq-0.2.1-py3-none-any.whl
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.1
Cloning into 'learning_training_dynamics'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 24 (delta 5), reused 16 (delta 3), pack-reused 0[K
Unpacking objects: 100% (24/24), done.
/content/learning_training_dynamics


In [16]:
import torchdiffeq as teq
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os

import datasets

# set device
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        self.net = nn.Sequential(nn.Linear(input_dim, 200), nn.Tanh(), nn.Linear(200, 200), nn.Tanh(), nn.Linear(200, 200), nn.Tanh(), nn.Linear(200, output_dim))
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    
    def forward(self, t, x):
        x = torch.cat([torch.tensor([t], device=self.device), x], dim=-1)
        return self.net(x)

In [12]:
# some hyperparams
hidden_dim = 100
train_steps = 1000
lr = 1e-3



# set up data
data_dir = './data/mini_mnist'
train_dataset = datasets.MiniMNISTParams(data_dir)
train_subset = torch.utils.data.Subset(train_dataset, indices=[0]) # for debugging
#test_dataset = datasets.MiniMNISTParams(data_dir, train=False)

train_loader = DataLoader(train_subset, batch_size=1)
#test_loader = DataLoader(test_dataset, batch_size=1)

# 
input_dim = train_dataset[0][1].shape[1]
time_stamps = torch.arange(0, train_dataset[0][1].shape[0], 1, dtype=float, device=device)


# set up model
model_kwargs = {'input_dim': input_dim+1, # +1 for time  
                'hidden_dim': hidden_dim, # is cur 100, should change to [20,20]
                'output_dim':input_dim
                }
model = MLP(**model_kwargs).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fct = nn.MSELoss()

In [13]:
print(len(train_subset))

1


In [14]:
# training
step = 0
epoch = 0
while step < train_steps:

    epoch += 1
    print(f'\nStarting epoch {epoch}')
    model.train()
    for y_0, y in iter(train_loader):
        step += 1
        
        y_0 = y_0.squeeze().to(device)
        y = y.squeeze().to(device)
        
        # train step
        optimizer.zero_grad()
        pred = teq.odeint_adjoint(model, y_0, time_stamps, adjoint_options=dict(norm='seminorm'))
        loss = loss_fct(pred, y.squeeze())
        print(f'Loss at step {step} is {loss.item():1.3f}')
        loss.backward()
        optimizer.step()

    # eval step
    #model.eval()

KeyboardInterrupt: ignored

In [15]:
# save model
torch.save(model.state_dict(), './trained_model.pt')

In [5]:
# base model class
class BaseMLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        layers = [nn.Linear(input_dim, hidden_dim[0]), nn.ReLU()]
        for i in range(1,len(hidden_dim)):
            layers.append(nn.Linear(hidden_dim[i-1], hidden_dim[i]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim[-1], output_dim))
        
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.net(x)

In [6]:
def get_params(model):
    '''
    Returns the parameters of a model as a single vector
    Args:
        model - instance of torch.nn.Module
    Returns:
        params - list of shape [n_params]
    '''
    params = []
    for param in model.parameters():
        params.extend(param.flatten().detach().tolist())
    return params

In [7]:
def params2weights(param_vector, param_ranges):
    '''
    Maps a vector of parameters to a list of tensors capturing the model weights and biases
    Args:
        param_vector - torch tensor of shape [n_params]
        param_ranges - list containing number of parameters for each weight matrix or bias vector
    Returns:
        weights - list of torch tensors containing weights and biases: [w1, b1, w2, b2, ...]
    '''
    return list(torch.split(param_vector, param_ranges))


In [32]:
# initialize random model 
random_model = BaseMLP(input_dim=28*28, hidden_dim=[20,20], output_dim=1).to(device)

# load data
data_dir = './data/mini_mnist'

train_base_dataset = datasets.MiniMNIST(data_dir, flatten=True)
train_base_loader = DataLoader(train_base_dataset, shuffle=False, batch_size=len(train_base_dataset))

test_base_dataset = datasets.MiniMNIST(data_dir, flatten=True, train=False)
test_base_loader = DataLoader(test_base_dataset, shuffle=False, batch_size=len(test_base_dataset))

# eval func
loss_fct = nn.BCEWithLogitsLoss()

# evluate model on train data
random_model.eval()
train_data, train_labels = next(iter(train_base_loader))
train_data = train_data.to(device)
train_labels = train_labels.to(device)
train_logits = random_model(train_data).squeeze()
train_loss = loss_fct(train_logits, train_labels).item()
train_acc = (torch.round(torch.sigmoid(train_logits)) == train_labels).sum() / len(train_labels)
print(f'Train loss at init is {train_loss:1.5f}')
print(f'Train accuracy at init is {train_acc:1.5f}')


# evluate model on test data
random_model.eval()
test_data, test_labels = next(iter(test_base_loader))
test_data = test_data.to(device)
test_labels = test_labels.to(device)
test_logits = random_model(test_data).squeeze()
test_loss = loss_fct(test_logits, test_labels).item()
test_acc = (torch.round(torch.sigmoid(test_logits)) == test_labels).sum() / len(test_labels)
print(f'\nTest loss at init is {test_loss:1.5f}')
print(f'Test accuracy at init is {test_acc:1.5f}')

Train loss at init is 3.46422
Train accuracy at init is 0.50000

Test loss at init is 3.08259
Test accuracy at init is 0.50000


In [60]:
def eval_model_params(model_class, model_kwargs, params, data_loader, id=0, device='cpu'):
    model = model_class(**model_kwargs).to(device)

    # convert param vector to fit weight dimensions of base model
    param_ranges = []
    param_shapes = []
    for param in model.parameters():
        param_ranges.append(torch.prod(torch.tensor(param.shape)).item())
        param_shapes.append(param.shape)

    sliced_params = params2weights(params, param_ranges)
    # reshape them into correct shape
    for i, shape in enumerate(param_shapes):
        sliced_params[i] = sliced_params[i].reshape(shape)

    # set base model params to evolved params
    for old_param, new_param in zip(model.parameters(), sliced_params):
      old_param.data = new_param

    # eval evolved model on the sample it was trained on
    model.eval()
    data, labels = next(iter(data_loader))
    data = data.to(device)
    labels = labels.to(device)
    logits = model(data).squeeze()
    loss = loss_fct(logits, labels).item()
    acc = (torch.round(torch.sigmoid(logits)) == labels).sum() / len(labels)
    print(f'Loss at {id} is {loss:1.5f}')
    print(f'Acc at {id} is {acc:1.5f}')

def eval_model_trajectory(model_class, model_kwargs, all_params, data_loader, device='cpu'):
    for i in range(all_params.shape[0]):
        eval_model_params(model_class, model_kwargs, all_params[i], data_loader, f'time {i}', device)


In [None]:
####
# Evolve using learned dynamics
####


# get parameters of random model
#params = torch.tensor(get_params(random_model), device=device)
params, _ = torch.utils.data.Subset(train_dataset, indices=[1])[0]
params = params.to(device)

# use learned dynamics model to evolve the parameters
evolved_params = teq.odeint_adjoint(model, params, time_stamps, adjoint_options={'norm': 'seminorm'})
params_at_25 = evolved_params[-1]
#params_at_8 = evolved_params[25]
params_traj = evolved_params

model_kwargs = {'input_dim':28*28, 'hidden_dim': [20,20], 'output_dim': 1}

eval_model_trajectory(BaseMLP, model_kwargs, params_traj, train_base_loader, device)

# init a new model
evolved_model = BaseMLP(input_dim=28*28, hidden_dim=[20,20], output_dim=1).to(device)

# convert param vector to fit weight dimensions of base model
param_ranges = []
param_shapes = []
for param in evolved_model.parameters():
    param_ranges.append(torch.prod(torch.tensor(param.shape)).item())
    param_shapes.append(param.shape)

sliced_params = params2weights(params_at_25, param_ranges)
# reshape them into correct shape
for i, shape in enumerate(param_shapes):
    sliced_params[i] = sliced_params[i].reshape(shape)

# set base model params to evolved params
for old_param, new_param in zip(evolved_model.parameters(), sliced_params):
   old_param.data = new_param

# eval evolved model on the sample it was trained on
evolved_model.eval()
data, labels = next(iter(train_base_loader))
data = data.to(device)
labels = labels.to(device)
logits = evolved_model(data).squeeze()
loss = loss_fct(logits, labels).item()
acc = (torch.round(torch.sigmoid(logits)) == labels).sum() / len(labels)
print(f'Loss on train sample after evol is {loss:1.5f}')
print(f'Acc on train sample after evol is {acc:1.5f}')

# compare evolved with trained
trained_params = train_subset[0][1].to(device)
mse_loss = nn.MSELoss(reduction='none')
evol_loss = mse_loss(trained_params, evolved_params).mean(dim=1)
print(evol_loss)
print(evol_loss.mean())


# eval evolved model
evolved_model.eval()
test_data, test_labels = next(iter(test_loader))
test_data = test_data.to(device)
test_labels = test_labels.to(device)
test_logits = evolved_model(test_data).squeeze()
test_loss = loss_fct(test_logits, test_labels).item()
test_acc = (torch.round(torch.sigmoid(test_logits)) == test_labels).sum() / len(test_labels)
print(f'\nTest loss after evolving is {test_loss:1.5f}')
print(f'Test accuracy after evolving is {test_acc:1.5f}')
print(f'Test logits after evolving {test_logits}')
print(f'Test probs after evolving {torch.sigmoid(test_logits)}')

In [40]:
####
# train in normal way
####

def train_model(model, opt_class, opt_kwargs, loss_fct, train_loader, test_loader, train_steps):

    # init param trajectory
    params = []
    params.append(get_params(model))
    
    # set up optimizer
    optimizer = opt_class(model.parameters(), **opt_kwargs)

    # initial random logits
    #train_logits = []
    #data, _ = next(iter(train_loader))
    #train_logits.append(model(data).squeeze().tolist())

    for step in range(train_steps):
        
        # one training step
        model.train()
        data, labels = next(iter(train_loader))
        data = data.to(device)
        labels = labels.to(device)
        logits = model(data).squeeze()
        loss = loss_fct(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # store new params
        params.append(get_params(model))

        # store output of network at this step
        #train_logits.append(logits.tolist())

        # eval on test set
        model.eval()
        test_data, test_labels = next(iter(test_loader))
        test_data = test_data.to(device)
        test_labels = test_labels.to(device)
        test_logits = model(test_data).squeeze()
        test_loss = loss_fct(test_logits, test_labels)
        test_acc = torch.sum(torch.round(torch.sigmoid(test_logits)) == test_labels) / len(test_labels)

        print(f'Step {step+1}: train loss = {loss.item():1.5f}, test loss = {test_loss.item():1.5f}, test acc = {test_acc.item():1.4f}')
                

    return params


In [None]:
# loss function
loss_fct = torch.nn.BCEWithLogitsLoss()

train_steps = 25
lr = 1e-4

# optimizer
opt_class = torch.optim.Adam
opt_kwargs = {'lr':lr}

train_kwargs = {'opt_class':opt_class,
                'opt_kwargs':opt_kwargs, 
                'loss_fct':loss_fct, 
                'train_loader':train_loader, 
                'test_loader':test_loader, 
                'train_steps':train_steps
                }

# train the model
trained_params = torch.tensor(train_model(random_model, **train_kwargs), device=device).view(evolved_params.shape)

# compare evolved with trained
mse_loss = nn.MSELoss(reduction='none')
evol_loss = mse_loss(trained_params, evolved_params).mean(dim=1)
print(evol_loss)

In [71]:
class Simple2dData(torch.utils.data.Dataset):

    def __init__(self, num_samples):
        super().__init__()

        self.num_samples = num_samples
        assert num_samples % 2 == 0, 'Number of samples should be even!'

        # sample around (1,1)
        pos_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.ones(2), covariance_matrix=0.5*torch.eye(2))
        pos_samples = pos_dist.sample((num_samples//2,))

        # sample around (-1,-1)
        neg_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=-1 * torch.ones(2), covariance_matrix=0.5*torch.eye(2))
        neg_samples = neg_dist.sample((num_samples//2,))

        # concat
        self.data = torch.cat([pos_samples, neg_samples], dim=0)
        self.labels = torch.cat([torch.ones(len(pos_samples)), torch.zeros(len(neg_samples))], dim=0)
      
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
      
    def __len__(self):
        return len(self.labels)


def collect_traj(num_trajs, model_class, model_kwargs, loss_fct, opt_class, opt_kwargs, train_loader, test_loader, train_steps = 10):
    train_kwargs = {'opt_class':opt_class,
                    'opt_kwargs':opt_kwargs, 
                    'loss_fct':loss_fct, 
                    'train_loader':train_loader, 
                    'test_loader':test_loader, 
                    'train_steps':train_steps
                    }

    param_list = []
    for i in range(num_trajs):
        # init new model
        model = BaseMLP(**model_kwargs).to(device)
        # train it
        trained_params = train_model(model, **train_kwargs)
        param_list.append(trained_params)
    
    return param_list



num_samples = 100

simple_train_data = Simple2dData(num_samples)
simple_train_loader = torch.utils.data.DataLoader(simple_train_data, batch_size=len(simple_train_data))

simple_test_data = Simple2dData(num_samples)
simple_test_loader = torch.utils.data.DataLoader(simple_test_data, batch_size=len(simple_train_data))



lr = 1e-2
num_trajs = 1000
train_steps = 20
base_model_kwargs = {'input_dim':2, 'hidden_dim':[20,20], 'output_dim':1}
collect_train_traj_kwargs = {'num_trajs': num_trajs,
                       'model_class': BaseMLP,
                       'model_kwargs':base_model_kwargs,
                       'loss_fct': nn.BCEWithLogitsLoss(),
                       'opt_class': torch.optim.Adam,
                       'opt_kwargs': {'lr':lr},
                       'train_loader': simple_train_loader,
                       'test_loader': simple_test_loader,
                       'train_steps': train_steps
}

simple_trajs = torch.tensor(collect_traj(**collect_train_traj_kwargs))
simple_train_trajs, simple_test_trajs = torch.split(simple_trajs, [800,200])

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Step 1: train loss = 0.72980, test loss = 0.68848, test acc = 0.5100
Step 2: train loss = 0.69240, test loss = 0.65667, test acc = 0.8400
Step 3: train loss = 0.65687, test loss = 0.62426, test acc = 0.9400
Step 4: train loss = 0.62161, test loss = 0.59076, test acc = 0.9600
Step 5: train loss = 0.58492, test loss = 0.55577, test acc = 0.9600
Step 6: train loss = 0.54656, test loss = 0.51934, test acc = 0.9600
Step 7: train loss = 0.50695, test loss = 0.48200, test acc = 0.9600
Step 8: train loss = 0.46659, test loss = 0.44417, test acc = 0.9600
Step 9: train loss = 0.42572, test loss = 0.40632, test acc = 0.9600
Step 10: train loss = 0.38497, test loss = 0.36924, test acc = 0.9600
Step 11: train loss = 0.34481, test loss = 0.33349, test acc = 0.9600
Step 12: train loss = 0.30585, test loss = 0.29984, test acc = 0.9600
Step 13: train loss = 0.26886, test loss = 0.26883, test acc = 0.9600
Step 14: train loss = 0.23434, tes

In [None]:
class Simple2dParam(torch.utils.data.Dataset):

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx][0], self.data[idx]
    
    def __len__(self):
        return len(self.data)

simple_param_train_data = Simple2dParam(simple_train_trajs)
simple_param_train_loader = torch.utils.data.DataLoader(simple_param_train_data)

simple_param_test_data = Simple2dParam(simple_test_trajs)
simple_param_test_loader = torch.utils.data.DataLoader(simple_param_test_data)

# 
input_dim = simple_param_train_data[0][1].shape[1]
time_stamps = torch.arange(0, simple_param_train_data[0][1].shape[0], 1, dtype=float, device=device)


# set up model
model_of_simple_kwargs = {'input_dim': input_dim+1, # +1 for time  
                'hidden_dim': 0, #is currently not used in model
                'output_dim':input_dim
                }
model_of_simple = MLP(**model_of_simple_kwargs).to(device)

train_steps = 1000
lr = 3e-4

optimizer = torch.optim.Adam(model_of_simple.parameters(), lr=lr)
loss_fct = nn.MSELoss()

# training
step = 0
epoch = 0
eval_freq = 300
while step < train_steps:

    epoch += 1
    print(f'\nStarting epoch {epoch}')
    model_of_simple.train()
    for y_0, y in iter(simple_param_train_loader):
        step += 1
        
        y_0 = y_0.squeeze().to(device)
        y = y.squeeze().to(device)
        
        # train step
        optimizer.zero_grad()
        pred = teq.odeint_adjoint(model_of_simple, y_0, time_stamps, adjoint_options=dict(norm='seminorm'))
        loss = loss_fct(pred, y.squeeze())
        print(f'Step {step}: train loss = {loss.item():1.5f}')
        loss.backward()
        optimizer.step()

        if step % eval_freq == 0:
            # eval on test
            model_of_simple.eval()
            test_loss = 0
            num_test_batches = 0
            for y_0, y in iter(simple_param_test_loader):
                num_test_batches += 1
                y_0 = y_0.squeeze().to(device)
                y = y.squeeze().to(device)
                
                # train step
                pred = teq.odeint_adjoint(model_of_simple, y_0, time_stamps, adjoint_options=dict(norm='seminorm'))
                loss = loss_fct(pred, y.squeeze()).item()
                test_loss += loss
            test_loss /= num_test_batches
            print(f'\nStep {step}: test loss = {test_loss:1.5f}\n')

        if step == train_steps:
            break

    




Starting epoch 1
Step 1: train loss = 0.64222
Step 2: train loss = 0.54524
Step 3: train loss = 0.42102
Step 4: train loss = 0.38758
Step 5: train loss = 0.32948
Step 6: train loss = 0.34763
Step 7: train loss = 0.30884
Step 8: train loss = 0.34624
Step 9: train loss = 0.28655
Step 10: train loss = 0.30418
Step 11: train loss = 0.21219
Step 12: train loss = 0.22283
Step 13: train loss = 0.22033
Step 14: train loss = 0.20041
Step 15: train loss = 0.21799
Step 16: train loss = 0.20298
Step 17: train loss = 0.18072
Step 18: train loss = 0.20506
Step 19: train loss = 0.17890
Step 20: train loss = 0.17655
Step 21: train loss = 0.16036
Step 22: train loss = 0.15133
Step 23: train loss = 0.15932
Step 24: train loss = 0.14549
Step 25: train loss = 0.14215
Step 26: train loss = 0.11927
Step 27: train loss = 0.11606
Step 28: train loss = 0.12839
Step 29: train loss = 0.10864
Step 30: train loss = 0.10353
Step 31: train loss = 0.11699
Step 32: train loss = 0.10577
Step 33: train loss = 0.10227
S

In [78]:
simple_param_test_data = Simple2dParam(simple_test_trajs)
simple_param_test_loader = torch.utils.data.DataLoader(simple_param_test_data)

for i in range(len(simple_param_test_data)):
    
    print(f'\nEvaluating model {i+1}')
    params, _ = simple_param_test_data[i]
    params = params.to(device)

    # use learned dynamics model to evolve the parameters
    evolved_params = teq.odeint_adjoint(model_of_simple, params, time_stamps, adjoint_options={'norm': 'seminorm'})

    base_model_kwargs = {'input_dim':2, 'hidden_dim':[20,20], 'output_dim':1}
    eval_kwargs = {'model_class':BaseMLP,
                  'model_kwargs': base_model_kwargs,
                  'all_params':evolved_params,
                  'data_loader':simple_train_loader,
                  'device': device}

    eval_model_trajectory(**eval_kwargs)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss at time 7 is 0.40446
Acc at time 7 is 0.47000
Loss at time 8 is 0.39234
Acc at time 8 is 0.48000
Loss at time 9 is 0.38085
Acc at time 9 is 0.47000
Loss at time 10 is 0.37010
Acc at time 10 is 0.48000
Loss at time 11 is 0.35951
Acc at time 11 is 0.49000
Loss at time 12 is 0.34912
Acc at time 12 is 0.48000
Loss at time 13 is 0.33895
Acc at time 13 is 0.48000
Loss at time 14 is 0.32874
Acc at time 14 is 0.46000
Loss at time 15 is 0.31860
Acc at time 15 is 0.47000
Loss at time 16 is 0.30872
Acc at time 16 is 0.47000
Loss at time 17 is 0.29914
Acc at time 17 is 0.48000
Loss at time 18 is 0.29029
Acc at time 18 is 0.48000
Loss at time 19 is 0.28200
Acc at time 19 is 0.50000
Loss at time 20 is 0.27469
Acc at time 20 is 0.50000

Evaluating model 88
Loss at time 0 is 0.44955
Acc at time 0 is 0.50000
Loss at time 1 is 0.43255
Acc at time 1 is 0.52000
Loss at time 2 is 0.42146
Acc at time 2 is 0.50000
Loss at time 3 is 0.41304

In [None]:
while True:
  pass