In [None]:
import sys, os, math
import numpy as np
import torch
import matplotlib.pyplot as plt
import importlib

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split

sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data

In [None]:
# choose the compute device: CUDA when torch can see one, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device: %s'%(device))

## Load data

In [None]:
# read the maps and insert a singleton axis at position 2
maps = np.load(...)[:, :, None, :, :]
# axis 0: data set size; axis 1: number of augmentations/views per parameter set
dset_size, splits = maps.shape[:2]

# read the parameters and copy them `splits` times along a new axis 1,
# so that their two leading axes line up with those of the maps
params = np.repeat(np.load(...)[:, None, :], splits, axis = 1)

# switches for the pre-processing of the maps and for the printed summary
rescale, standardize, verbose = True, True, True

# log(1 + x) transform
if rescale:
    maps = np.log(maps+1)

# subtract the mean and divide by the standard deviation, both taken over
# the entire array and accumulated in float64
if standardize:
    maps_mean = np.mean(maps, dtype=np.float64)
    maps_std  = np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

# short report on what was loaded and which pre-processing steps ran
if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 0].min(), params[:, :, 0].max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 1].min(), params[:, :, 1].max()))

    if rescale:
        print('Rescale: ', rescale)
    if standardize:
        print('Standardize: ', standardize)

# hand both arrays to torch as float32 tensors on the selected device
maps   = torch.tensor(maps).float().to(device)
params = torch.tensor(params).float().to(device)

In [None]:
# divide the data into train, validation, and test sets
batch_size = 200
train_frac, valid_frac, test_frac = 0.8, 0.1, 0.1

# `seed` is passed to create_datasets below; it must be defined here so the
# notebook runs on a fresh kernel and the split is the same on every run
seed = 42

train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params, 
                                                    train_frac, valid_frac, test_frac, 
                                                    seed = seed, VICReg=True)

# only the training batches are reshuffled every epoch; validation and test
# batches keep a fixed composition, so the validation loss used for
# checkpoint selection is evaluated on the same batches at every epoch
train_loader = DataLoader(train_dset, batch_size, shuffle = True)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = False)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

## Train the encoder model

In [7]:
def _vicreg_pass(net, mlp_net, loader, inv_weight, var_weight, cov_weight, optimizer = None):
    """Run one pass over `loader` and return sample-weighted mean losses.

    Each batch `x` is indexed as `x[:, 0]` and `x[:, 1]` to obtain the two
    views; both are sent through `net` and then `mlp_net`, and the resulting
    embeddings are handed to `vicreg_loss`.

    When `optimizer` is given, an optimisation step is taken on every batch
    (the caller is responsible for putting the modules in train mode);
    otherwise the pass runs with gradients disabled.

    Returns:
        (loss, inv, var, cov) as Python floats, each averaged over all
        samples seen (every batch is weighted by its own size).
    """
    training = optimizer is not None
    total_loss, inv_sum, var_sum, cov_sum = 0., 0., 0., 0.
    points = 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            bsz = x.shape[0] # size of THIS batch (the last one may be smaller)

            # get VICReg summaries of the two maps
            emb_q = mlp_net(net(x[:, 0].contiguous()))
            emb_k = mlp_net(net(x[:, 1].contiguous()))

            # compute VICReg loss and its three components
            loss, inv, var, cov = vicreg_loss(emb_q, emb_k, inv_weight, var_weight, cov_weight)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # accumulate as tensors; conversion to float happens once, at the end
            total_loss = total_loss + loss.detach()*bsz
            inv_sum    = inv_sum + inv.detach()*bsz
            var_sum    = var_sum + var.detach()*bsz
            cov_sum    = cov_sum + cov.detach()*bsz
            points += bsz

    return (float(total_loss/points), float(inv_sum/points),
            float(var_sum/points), float(cov_sum/points))


def run_training(fmodel, floss, 
                 net, mlp_net, 
                 optimizer, scheduler, 
                 train_loader, valid_loader,
                 batch_size, epochs, splits, 
                 inv_weight = 1, var_weight = 0, cov_weight = 0,
                 verbose = None):
    """Train the encoder `net` and expander `mlp_net` with the VICReg loss.

    Parameters:
        fmodel: path where `net.state_dict()` is saved whenever the
            validation loss improves (only `net` is saved, not `mlp_net`).
        floss: path of the text log. It is truncated at epoch 0 and appended
            to afterwards; each line holds
            `epoch train_loss valid_loss inv var cov`, where the three
            components are those measured on the validation set.
        net, mlp_net: modules applied as `mlp_net(net(view))`.
        optimizer: optimizer stepped once per training batch.
        scheduler: stepped once per epoch with the validation loss.
        train_loader, valid_loader: iterables of `(x, y)` batches; the two
            views are taken from `x[:, 0]` and `x[:, 1]`, `y` is not used.
        batch_size, splits: accepted for interface compatibility; unused.
        epochs: number of training epochs.
        inv_weight, var_weight, cov_weight: forwarded to `vicreg_loss`.
        verbose: print progress if true. `None` (default) falls back to a
            module-level `verbose` variable if one exists, else `True`.

    Returns:
        `(net, mlp_net)` in their state after the final epoch. The best
        encoder weights (by validation loss) are the ones stored in `fmodel`;
        nothing is written there if no epoch beats the pre-training
        validation loss.
    """
    if verbose is None:
        # earlier versions read the notebook-level `verbose` flag directly
        verbose = globals().get('verbose', True)
    weights = (inv_weight, var_weight, cov_weight)

    # validation loss before any training: the baseline an epoch has to beat
    net.eval() 
    mlp_net.eval()
    min_loss_valid = _vicreg_pass(net, mlp_net, valid_loader, *weights)[0]
    if verbose: print('Min validation loss: ', min_loss_valid)

    # loop over the epochs
    for epoch in range(epochs): 

        # training
        net.train()
        mlp_net.train()
        loss_train = _vicreg_pass(net, mlp_net, train_loader, *weights,
                                  optimizer = optimizer)[0]

        # validation: total loss and its components
        net.eval() 
        mlp_net.eval()
        loss_valid, inv_loss, var_loss, cov_loss = _vicreg_pass(net, mlp_net,
                                                                valid_loader, *weights)

        # save model which performs best on validation set
        if loss_valid < min_loss_valid:
            if verbose:
                print('saving model;  epoch %d; %.4e %.4e'%(epoch, loss_train, loss_valid))
            torch.save(net.state_dict(), fmodel)
            min_loss_valid = loss_valid
        else:
            if verbose:
                print('epoch %d; %.4e %.4e'%(epoch,loss_train,loss_valid))

        # start a fresh log on the first epoch, append afterwards; the file is
        # closed after every epoch so it can be inspected while training runs
        with open(floss, 'w' if epoch == 0 else 'a') as f:
            f.write('%d %.4e %.4e %.4e %.4e %.4e\n'%(epoch, loss_train, loss_valid, 
                                                     inv_loss, var_loss, cov_loss))
        scheduler.step(loss_valid)

    return net, mlp_net

In [None]:
# weights for the loss function (invariance, variance, covariance);
# entry i of each list defines configuration i
inv_arr = [5]
var_arr = [5]
cov_arr = [1]

# hyperparameters
lr         = 2e-4
epochs     = 200
hidden     = 8
last_layer = 2*hidden

# train one fresh encoder/expander pair per weight configuration
for num_config in range(len(inv_arr)):

    # define the encoder model
    model = SummaryNet(hidden = hidden, last_layer = last_layer).to(device)

    # define the expander model
    mlp_exp_units = [4*last_layer, 4*last_layer]
    expander_net = Expander(mlp_exp_units, last_layer, bn = True).to(device)

    # define the optimizer, scheduler (both networks are optimised jointly)
    optimizer = torch.optim.AdamW([*model.parameters(), *expander_net.parameters()], 
                                 lr=lr, betas=(0.9, 0.999), eps=1e-8, amsgrad=False)  
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3)

    # output files are named after the three loss weights; the format strings
    # have exactly three fields, so only the three weights are passed in
    save_dir = ...
    inv, var, cov = inv_arr[num_config], var_arr[num_config], cov_arr[num_config]
    fmodel = save_dir + 'model_{:d}_{:d}_{:d}.pt'.format(inv, var, cov)
    fout   = save_dir + 'losses_{:d}_{:d}_{:d}.txt'.format(inv, var, cov)

    net, mlp = run_training(fmodel, fout, 
                            model, expander_net, 
                            optimizer, scheduler, 
                            train_loader, valid_loader,
                            batch_size=batch_size, epochs=epochs, splits=splits,
                            inv_weight = inv, var_weight = var, cov_weight = cov)
    

In [None]:
# read the per-epoch log of the last training run (columns: epoch, training
# loss, validation loss, ...); ndmin=2 keeps the array two-dimensional even
# when the file holds a single row, so the column indexing below still works
losses = np.loadtxt(fout, ndmin = 2)
start_epoch = 0 # first row of the log to show
plt.plot(losses[start_epoch:, 0], losses[start_epoch:, 1], label = 'Training loss')
plt.plot(losses[start_epoch:, 0], losses[start_epoch:, 2], label = 'Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc = 'best');