In [None]:
import sys, os, math
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import importlib

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split

sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, vector_to_Cov
from utils_modules.vicreg import vicreg_loss, vicreg_loss_pairs
import utils_modules.data as utils_data

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device: %s' % device)

## Load maps

In [None]:
# load maps and parameters
n_params = 2
home_dir = ... # maps and parameters directory
# os.path.join works whether or not home_dir ends with a path separator
maps   = np.load(os.path.join(home_dir, 'Maps_Mtot_SIMBA_LH_z=0.00.npy'))
dset_size = 1000 # data set size (number of parameter sets)
splits    = 15   # number of realizations per parameter set
maps_size = maps.shape[-1]
# group realizations by parameter set -> (dset_size, splits, 1, maps_size, maps_size);
# assumes the file stores the `splits` realizations of each parameter set contiguously
maps   = maps.reshape(dset_size, splits, 1, maps_size, maps_size) # prepare maps for VICReg

params = np.loadtxt(os.path.join(home_dir, 'params_SIMBA.txt'))[:, None, :n_params]
# np.repeat below would silently accept a row-count mismatch with the maps, so check it here
assert params.shape[0] == dset_size, \
    'expected {} parameter rows, found {}'.format(dset_size, params.shape[0])
params  = np.repeat(params, splits, axis = 1) # reshape the parameters to match the shape of the maps
# NOTE(review): presumably the sampling ranges of the first n_params simulation parameters — confirm
minimum = np.array([0.1, 0.6])
maximum = np.array([0.5, 1.0])
params  = (params - minimum)/(maximum - minimum) # linearly map [minimum, maximum] onto [0, 1]

# pre-process the maps data set
rescale     = True   # take log10 of the maps
standardize = True   # subtract the global mean and divide by the global std
verbose     = True   # also read by later cells

if rescale:
    maps = np.log10(maps)
if standardize:
    # NOTE(review): the statistics are computed over the full data set, including the
    # samples that later end up in the validation and test splits
    maps_mean, maps_std = np.mean(maps, dtype=np.float64), np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std
    
if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 0].min(), params[:, :, 0].max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 1].min(), params[:, :, 1].max()))
    
    if rescale: print('Rescale: ', rescale)
    if standardize: print('Standardize: ', standardize)

# cast straight to float32 on the target device; torch.tensor(x).float() would first
# materialise a full float64 copy of the (large) maps array
maps   = torch.as_tensor(maps, dtype=torch.float32, device=device)
params = torch.as_tensor(params, dtype=torch.float32, device=device)

## Train the model

In [None]:
save_dir    = ... # output directory for model checkpoints and loss logs

# exist_ok=True replaces the exists()/makedirs() pair: no check-then-create race,
# and re-running the cell is a no-op once the directory exists
os.makedirs(save_dir, exist_ok=True)

In [5]:
def _embed_pairs(net, mlp_net, x, n_pairs):
    """Embed views (2p, 2p+1) of every sample in `x` for p in range(n_pairs).

    Returns two lists of length n_pairs holding mlp_net(net(view)) for the first
    and the second view of each pair, respectively. Assumes x has at least
    2*n_pairs views along dim 1.
    """
    embeds1, embeds2 = [], []
    for pair in range(n_pairs):
        # slicing dim 1 gives a strided view, hence contiguous() before the conv stack
        embeds1.append(mlp_net(net(x[:, 2*pair].contiguous())))
        embeds2.append(mlp_net(net(x[:, 2*pair + 1].contiguous())))
    return embeds1, embeds2


def _run_epoch(net, mlp_net, loader, n_pairs, weights, loss_fn, optimizer=None):
    """Make one pass over `loader` and return per-sample means of (loss, inv, var, cov).

    When `optimizer` is given, the networks are put in train mode and updated on
    every batch. Otherwise they are put in eval mode and run under no-grad.
    Raises ValueError if the loader yields no samples.
    """
    training = optimizer is not None
    net.train(training)
    mlp_net.train(training)

    sums, points = [0., 0., 0., 0.], 0
    with torch.set_grad_enabled(training):
        for x, _ in loader:
            bsz = x.shape[0]  # the last batch can be smaller than the loader's batch size
            embeds1, embeds2 = _embed_pairs(net, mlp_net, x, n_pairs)
            terms = loss_fn(embeds1, embeds2, n_pairs, *weights)  # (loss, inv, var, cov)
            if training:
                optimizer.zero_grad()
                terms[0].backward()
                optimizer.step()
            # weight every batch by its own size so the result is a per-sample mean
            sums = [s + t.detach()*bsz for s, t in zip(sums, terms)]
            points += bsz

    if points == 0:
        raise ValueError('data loader yielded no samples')
    return tuple(float(s)/points for s in sums)


def run_training(fmodel, floss, 
                 net, mlp_net, 
                 optimizer, scheduler, 
                 train_loader, valid_loader,
                 batch_size, epochs, splits, n_pairs = 1,
                 inv_weight = 1, var_weight = 0, cov_weight = 0,
                 verbose = True, loss_fn = None):
    """Train the encoder `net` and expander `mlp_net` with a VICReg-style loss.

    Embeddings are mlp_net(net(view)). The loss compares views (2p, 2p+1) of
    each sample for p < n_pairs. Loaders must yield (x, y) pairs; y is ignored.

    Args:
        fmodel: path where net.state_dict() (encoder only) is saved whenever the
            validation loss improves on the best value seen so far. The baseline
            is the validation loss of the model before training.
        floss: path of the loss log. It is truncated at epoch 0 and appended to
            afterwards. Each line is
            `epoch train_loss valid_loss valid_inv valid_var valid_cov`.
        net, mlp_net: encoder and expander modules.
        optimizer: optimizer over the parameters of both modules.
        scheduler: scheduler whose step() receives the epoch's validation loss.
        train_loader, valid_loader: training / validation data loaders.
        batch_size, splits: unused; kept for backward compatibility.
        epochs: number of training epochs.
        n_pairs: number of view pairs per sample.
        inv_weight, var_weight, cov_weight: loss weights passed to `loss_fn`.
        verbose: print progress messages.
        loss_fn: callable (embeds1, embeds2, n_pairs, inv_w, var_w, cov_w) ->
            (loss, inv, var, cov) tensors. Defaults to `vicreg_loss_pairs`.

    Returns:
        (net, mlp_net) in their final state. This is not necessarily the saved
        best checkpoint.
    """
    if loss_fn is None:
        loss_fn = vicreg_loss_pairs  # resolved at call time, not at definition time
    weights = (inv_weight, var_weight, cov_weight)

    # validation loss of the model before training: the first value a checkpoint must beat
    min_loss_valid = _run_epoch(net, mlp_net, valid_loader, n_pairs, weights, loss_fn)[0]
    if verbose: print('Min validation loss: ', min_loss_valid)
    
    # loop over the epochs
    for epoch in range(epochs): 
        loss_train = _run_epoch(net, mlp_net, train_loader, n_pairs, weights, loss_fn,
                                optimizer = optimizer)[0]
        loss_valid, inv_loss, var_loss, cov_loss = _run_epoch(net, mlp_net, valid_loader,
                                                              n_pairs, weights, loss_fn)

        # save model if it is better (encoder weights only; the expander is not saved)
        if loss_valid < min_loss_valid:
            if verbose:
                print('saving model;  epoch %d; %.4e %.4e'%(epoch, loss_train, loss_valid))
            torch.save(net.state_dict(), fmodel)
            min_loss_valid = loss_valid
        elif verbose:
            print('epoch %d; %.4e %.4e'%(epoch, loss_train, loss_valid))

        # the with-block guarantees the log is closed even if the write fails
        with open(floss, 'w' if epoch == 0 else 'a') as f:
            f.write('%d %.4e %.4e %.4e %.4e %.4e\n'%(epoch, loss_train, loss_valid, 
                                                     inv_loss, var_loss, cov_loss))
        scheduler.step(loss_valid)
        
    return net, mlp_net

In [None]:
# VICReg loss weights: one training run per (inv, var, cov) triple
inv_arr = [25] 
var_arr = [25] 
cov_arr = [1] 
assert len(inv_arr) == len(var_arr) == len(cov_arr), 'weight lists must have equal length'

# divide the data into train, validation, and test sets
seed = 1
batch_size = 50
train_frac, valid_frac, test_frac = 0.7, 0.2, 0.1
n_views = 10
n_pairs = n_views // 2  # views are consumed as (2p, 2p+1) pairs by run_training
train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params, 
                                                    train_frac, valid_frac, test_frac, 
                                                    seed = seed, VICReg=True, n_views=n_views)


train_loader = DataLoader(train_dset, batch_size, shuffle = True)
# NOTE: the VICReg variance/covariance terms are presumably batch statistics, so a shuffled
# validation loader would make the model-selection loss change from epoch to epoch
valid_loader = DataLoader(valid_dset, batch_size, shuffle = False)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

if verbose: print('\n Split the data into train, validation, and test sets.')
######################################
# hyperparameters shared by every configuration
lr         = 1e-3
epochs     = 150
last_layer = 128  # output size of the ResNet encoder (input size of the expander)

for num_config, (inv, var, cov) in enumerate(zip(inv_arr, var_arr, cov_arr)):
    # reseed so weight initialization and training-batch order are reproducible per config
    torch.manual_seed(seed)

    # define the model: ResNet-18 whose first convolution takes single-channel maps
    model = torchvision.models.resnet18(num_classes=last_layer)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.to(device)

    # define the expander model
    mlp_exp_units = [4*last_layer, 4*last_layer]
    expander_net = Expander(mlp_exp_units, last_layer, bn = True).to(device)

    # define the optimizer, scheduler
    optimizer = torch.optim.AdamW([*model.parameters(), *expander_net.parameters()], 
                                 lr=lr, betas=(0.9, 0.999), eps=1e-8, amsgrad=False)  
    # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                           factor=0.3, verbose=True)

    # output paths; os.path.join does not rely on save_dir ending with a separator.
    # `fout` is also read by the plotting cell.
    tag    = '{:d}_{:d}_{:d}_n_pairs_{:}_lr_{:.2e}'.format(inv, var, cov, n_pairs, lr)
    fmodel = os.path.join(save_dir, 'model_' + tag + '.pt')
    fout   = os.path.join(save_dir, 'losses_' + tag + '.txt')
    
    net, mlp = run_training(fmodel, fout, 
                            model, expander_net, 
                            optimizer, scheduler, 
                            train_loader, valid_loader, n_pairs=n_pairs,
                            batch_size=batch_size, epochs=epochs, splits=splits,
                            inv_weight = inv, var_weight = var, cov_weight = cov)
    print('Done with config = {:d}'.format(num_config+1))

In [None]:
# plot VICReg loss functions
# ndmin=2 keeps the (epochs, columns) layout even when the log holds a single epoch
losses = np.loadtxt(fout, ndmin=2)
start_epoch = 0
end_epoch = 200  # slice end; values past the number of logged epochs are harmless

# log columns written by run_training: epoch, train loss, valid loss, valid inv/var/cov
epochs_logged = losses[start_epoch:end_epoch, 0]

fig, ax = plt.subplots(figsize = (10, 6))
ax.plot(epochs_logged, losses[start_epoch:end_epoch, 1], label = 'Training loss')
ax.plot(epochs_logged, losses[start_epoch:end_epoch, 2], label = 'Validation loss')
ax.set(xlabel = 'Epoch', ylabel = 'VICReg loss', title = 'VICReg training and validation loss')
ax.legend(loc = 'best')
plt.show()