In [None]:
import sys, os, math
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import importlib

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split

sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, vector_to_Cov
from utils_modules.vicreg import vicreg_loss, vicreg_loss_pairs
import utils_modules.data as utils_data

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device: %s' % device)

## Load maps

In [None]:
# load maps and parameters
n_params = 2
home_dir = ... # maps and parameters directory
dset_size = 1000 # data set size
splits    = 15   # number of realizations per parameter set

# os.path.join works whether or not home_dir ends with a path separator
maps = np.load(os.path.join(home_dir, 'Maps_Mtot_SIMBA_LH_z=0.00.npy'))
maps_size = maps.shape[-1]
# group the realizations of each parameter set along axis 1 and add a channel axis:
# (dset_size, splits, 1, maps_size, maps_size)
maps = maps.reshape(dset_size, splits, 1, maps_size, maps_size)

# keep the first n_params columns and add a singleton axis to repeat along
params = np.loadtxt(os.path.join(home_dir, 'params_SIMBA.txt'))[:, None, :n_params]
if params.shape[0] != dset_size:
    # fail here rather than later with mismatched maps/parameters
    raise ValueError('Expected {:d} parameter sets, found {:d}'.format(dset_size, params.shape[0]))
params  = np.repeat(params, splits, axis = 1) # reshape the parameters to match the shape of the maps
minimum = np.array([0.1, 0.6])
maximum = np.array([0.5, 1.0])
params  = (params - minimum)/(maximum - minimum) # rescale parameters

# pre-process the maps data set
rescale     = True   # take log10 of the maps
standardize = True   # subtract the global mean and divide by the global std
verbose     = True

if rescale:
    # NOTE(review): log10 yields -inf/NaN for non-positive pixels — confirm the maps are strictly positive
    maps = np.log10(maps)
if standardize:
    # accumulate the statistics in float64
    maps_mean, maps_std = np.mean(maps, dtype=np.float64), np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 0].min(), params[:, :, 0].max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 1].min(), params[:, :, 1].max()))

    if rescale: print('Rescale: ', rescale)
    if standardize: print('Standardize: ', standardize)

# move everything to the compute device as float32 tensors
maps   = torch.tensor(maps).float().to(device)
params = torch.tensor(params).float().to(device)

## Train the model

In [None]:
save_dir    = ... # output directory (placeholder: set to a real path before running)

# exist_ok=True makes this a no-op when the directory is already there and
# avoids the race between checking for the directory and creating it
os.makedirs(save_dir, exist_ok=True)

In [None]:
# InfoNCE loss, adapted from
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial17/SimCLR.html
def nce_loss(features, temperature=0.1):
    """Compute the mean InfoNCE loss over a batch of embeddings.

    `features` is a 2-D tensor with one embedding per row. The positive for
    each row is the row `N//2` positions away (cyclically), where `N` is the
    number of rows; every other row acts as a negative.

    Returns a scalar tensor: the mean over rows of
    `-sim(row, positive)/T + logsumexp(sim(row, others)/T)`.
    """
    n_rows = features.shape[0]
    # pairwise cosine similarity between all rows, shape (N, N)
    pairwise = F.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=-1)
    # the diagonal marks each row's similarity to itself
    diagonal = torch.eye(n_rows, dtype=torch.bool, device=pairwise.device)
    # shifting the diagonal by N//2 rows marks the positive partner of each row
    positives = torch.roll(diagonal, shifts=n_rows//2, dims=0)
    # exclude self-similarity from the softmax, then apply the temperature
    logits = pairwise.masked_fill(diagonal, -9e15) / temperature
    per_row = torch.logsumexp(logits, dim=-1) - logits[positives]
    return per_row.mean()

In [5]:
def _mean_nce_loss(net, mlp_net, loader, temperature):
    """Return the sample-weighted mean InfoNCE loss of `mlp_net(net(.))` over `loader`.

    Both networks are switched to eval mode and no gradients are tracked.
    """
    net.eval()
    mlp_net.eval()
    total_loss, points = 0., 0
    with torch.no_grad():
        for x, y in loader:
            x = x.float().to(device)
            bsz = x.shape[0]
            # stack view 0 on top of view 1 so positives sit bsz rows apart
            imgs = torch.cat((x[:, 0], x[:, 1]), dim=0)
            feats = mlp_net(net(imgs))
            loss = nce_loss(feats, temperature)
            total_loss += loss.item()*bsz
            points += bsz
    return total_loss/points


def run_training(fmodel, floss, 
                 net, mlp_net, 
                 optimizer, scheduler, 
                 train_loader, valid_loader,
                 epochs, verbose=True,
                 temperature = 0.1):
    """Train an encoder and a projection head with the InfoNCE loss.

    Relies on the module-level `device` and `nce_loss`.

    Parameters
    ----------
    fmodel : path where the state dict of `net` is saved whenever the
        validation loss improves (the weights of `mlp_net` are not saved).
    floss : path of the text log; one line `epoch train_loss valid_loss` is
        written per epoch. The file is overwritten at epoch 0.
    net, mlp_net : encoder and projection head; the loss is computed on
        `mlp_net(net(images))`.
    optimizer : optimizer stepped once per training batch.
    scheduler : stepped once per epoch with the validation loss.
    train_loader, valid_loader : iterables yielding `(x, y)` batches where
        `x[:, 0]` and `x[:, 1]` are the two views of each sample; `y` is unused.
    epochs : number of training epochs.
    verbose : print the losses at every epoch.
    temperature : temperature passed to `nce_loss`.

    Returns
    -------
    (net, mlp_net) with the weights reached after the last epoch (the best
    checkpoint is not reloaded).
    """
    # validation loss of the untrained model is the baseline to beat
    min_loss_valid = _mean_nce_loss(net, mlp_net, valid_loader, temperature)
    if verbose: print('Min validation loss: ', min_loss_valid)

    # do a loop over the different epochs
    for epoch in range(epochs):

        # training
        net.train()
        mlp_net.train()
        total_loss, points = 0., 0
        for x, y in train_loader:
            x = x.float().to(device)
            bsz = x.shape[0]

            # combine the two views into one dataset
            imgs = torch.cat((x[:, 0], x[:, 1]), dim=0)
            # encode all images and compute the loss
            feats = mlp_net(net(imgs))
            loss = nce_loss(feats, temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()*bsz
            points += bsz

        loss_train = total_loss/points

        # validation (leaves both networks in eval mode)
        loss_valid = _mean_nce_loss(net, mlp_net, valid_loader, temperature)

        # save model if it is better
        if loss_valid < min_loss_valid:
            if verbose:
                print('saving model;  epoch %d; %.4e %.4e'\
                      %(epoch, loss_train, loss_valid))
            torch.save(net.state_dict(), fmodel)
            min_loss_valid = loss_valid
        else:
            if verbose:
                print('epoch %d; %.4e %.4e'\
                      %(epoch,loss_train,loss_valid))

        # start a fresh log on the first epoch, append afterwards
        with open(floss, 'w' if epoch == 0 else 'a') as f:
            f.write('%d %.4e %.4e \n'%(epoch, loss_train, loss_valid))
        scheduler.step(loss_valid)

    return net, mlp_net


In [None]:
# hyper-parameters
batch_size = 256
temp       = 0.1   # InfoNCE temperature; used for both training and the output file names
lr         = 1e-3
epochs     = 150
last_layer = 128   # dimension of the encoder output
seed       = 42    # defined here so the cell runs on a fresh kernel; pick any fixed value
train_frac, valid_frac, test_frac = 0.7, 0.2, 0.1

torch.manual_seed(seed)

# divide the data into train, validation, and test sets
train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params, 
                                                    train_frac, valid_frac, test_frac, 
                                                    seed = seed, VICReg=True)

train_loader = DataLoader(train_dset, batch_size, shuffle = True)
# evaluation loaders are not shuffled so the batches are the same at every epoch
valid_loader = DataLoader(valid_dset, batch_size, shuffle = False)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

# define the model: ResNet-18 with a single-channel first convolution
model = torchvision.models.resnet18(num_classes=last_layer)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.to(device);

# define the expander model
mlp_exp_units = [4*last_layer, 4*last_layer]
expander_net = Expander(mlp_exp_units, last_layer, bn = True).to(device)

# define the optimizer, scheduler
optimizer = torch.optim.AdamW([*model.parameters(), *expander_net.parameters()], 
                                  lr=lr, betas=(0.9, 0.999), eps=1e-8, amsgrad=False)  
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm against the installed version
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                           factor=0.3,verbose=True)

# checkpoint and loss log share one run tag
run_tag = 'SimCLR_T_{:.3f}_bs_{:d}_{:d}hl_{:d}'.format(temp, batch_size,
                                                       len(mlp_exp_units),
                                                       last_layer)
fmodel = os.path.join(save_dir, run_tag + '.pt')
fout   = os.path.join(save_dir, run_tag + '.txt')

# forward the temperature explicitly so it always matches the file names
net, mlp = run_training(fmodel, fout, model, expander_net, optimizer, scheduler, 
                                    train_loader, valid_loader, epochs=epochs,
                                    temperature=temp)

In [None]:
# columns of the log: epoch, training loss, validation loss
# ndmin=2 keeps the 2-D layout even when only one epoch has been logged
losses = np.loadtxt(fout, ndmin=2)
start_epoch = 0  # skip the first epochs to zoom in on the later part of the curve

fig, ax = plt.subplots()
ax.plot(losses[start_epoch:, 0], losses[start_epoch:, 1], label = 'Training loss')
ax.plot(losses[start_epoch:, 0], losses[start_epoch:, 2], label = 'Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('InfoNCE loss')
ax.legend(loc = 'best');