In [None]:
import math
import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Location of the project checkout. The default is the original author's path;
# set the NEURAL_ADMIXTURE_ROOT environment variable to run the notebook from
# a different machine or directory without editing this cell.
root = os.environ.get('NEURAL_ADMIXTURE_ROOT', '/home/albertdm99/Uni/TFG/neural-admixture')
# NumPy .npz archive from which the notebook later reads the 'snps' array.
data_path = f'{root}/data/all_chm_combined_snps_world_2M_with_labels.npz'

In [None]:
class _L0Norm(nn.Module):
    def __init__(self, origin, loc_mean=0, loc_sdev=0.01, beta=2 / 3, gamma=-0.1,
                 zeta=1.1, fix_temp=True):
        """
        Common machinery for layers whose weights are gated by an L0 mask.

        One trainable location value is kept per weight of the wrapped layer.
        :param origin: wrapped layer exposing a ``weight`` tensor, e.g. nn.Linear(..)
        :param loc_mean: mean of the normal distribution used to draw the initial locations
        :param loc_sdev: standard deviation of that normal distribution
        :param beta: initial temperature
        :param gamma: lower end of the interval the gate is stretched to
        :param zeta: upper end of the interval the gate is stretched to
        :param fix_temp: if True the temperature stays a plain number, otherwise it is a trainable parameter
        """
        super().__init__()
        self._origin = origin
        self._size = origin.weight.size()
        initial_loc = torch.zeros(self._size).normal_(loc_mean, loc_sdev)
        self.loc = nn.Parameter(initial_loc)
        if fix_temp:
            self.temp = beta
        else:
            self.temp = nn.Parameter(torch.zeros(1).fill_(beta))
        # Scratch buffer that receives the uniform noise on every training pass.
        self.register_buffer("uniform", torch.zeros(self._size))
        self.gamma = gamma
        self.zeta = zeta
        self.gamma_zeta_ratio = math.log(-gamma / zeta)
        self.sigmoid = nn.Sigmoid()

    def _hard_sigmoid(self, x):
        """Clip ``x`` element-wise to the interval [0, 1]."""
        lower = torch.zeros_like(x)
        upper = torch.ones_like(x)
        clipped_below = torch.max(x, lower)
        return torch.min(clipped_below, upper)

    def _get_mask(self):
        """
        Build the weight mask and its penalty term.

        :return: ``(mask, penalty)``; in evaluation mode the mask is computed
                 from ``loc`` alone and the penalty is the integer 0
        """
        stretch = self.zeta - self.gamma
        if not self.training:
            gate = self.sigmoid(self.loc) * stretch + self.gamma
            return self._hard_sigmoid(gate), 0

        # Training: refill the noise buffer in place, then build the noisy gate.
        self.uniform.uniform_()
        noise = self.uniform
        logistic_noise = torch.log(noise) - torch.log(1 - noise)
        gate = self.sigmoid((logistic_noise + self.loc) / self.temp)
        gate = gate * stretch + self.gamma
        penalty = self.sigmoid(self.loc - self.temp * self.gamma_zeta_ratio).sum()
        return self._hard_sigmoid(gate), penalty


class L0Linear(_L0Norm):
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        """
        Linear layer whose weight matrix is multiplied by an L0 mask.

        :param in_features: size of each input sample
        :param out_features: size of each output sample
        :param bias: whether the wrapped nn.Linear has a bias term
        :param kwargs: forwarded unchanged to ``_L0Norm.__init__``
        """
        linear = nn.Linear(in_features, out_features, bias=bias)
        super().__init__(linear, **kwargs)

    def forward(self, input):
        """
        Apply the masked linear map.

        :return: ``(output, penalty)`` where ``penalty`` comes from ``_get_mask``
        """
        mask, penalty = self._get_mask()
        masked_weight = self._origin.weight * mask
        return F.linear(input, masked_weight, self._origin.bias), penalty

In [None]:
class ConstrainedLinear(torch.nn.Module):
    """Linear map whose effective weights are ``sigmoid(W)``, i.e. lie in (0, 1).

    ``W`` has shape ``(input_size, output_size)`` and is initialised with
    ``kaiming_normal_``; the optional bias ``b`` starts at one.
    """

    def __init__ (self, input_size, output_size, bias=True):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(input_size, output_size))
        # In-place initialisation; the parameter object itself is unchanged.
        nn.init.kaiming_normal_(self.W)
        self.bias = bias
        if bias:
            self.b = nn.Parameter(torch.ones(output_size))

    def forward(self, x):
        """Return ``x @ sigmoid(W)``, plus ``b`` when the layer has a bias."""
        squashed = torch.sigmoid(self.W)
        if not self.bias:
            return torch.mm(x, squashed)
        return torch.addmm(self.b, x, squashed)

In [None]:
class AdmixtureAE(torch.nn.Module):
    """Autoencoder made of an L0-gated linear encoder and a sigmoid-weighted linear decoder.

    The encoder output goes through a softmax over dim 1 to form a
    ``k``-dimensional hidden state, which the decoder maps back to
    ``num_features`` values.

    :param k: size of the hidden state
    :param num_features: number of input (and reconstructed) features
    :param beta_l0: ``beta`` forwarded to the L0 encoder
    :param gamma_l0: ``gamma`` forwarded to the L0 encoder
    :param zeta_l0: ``zeta`` forwarded to the L0 encoder
    :param lambda_l0: weight of the L0 penalty in the training loss
    """

    def __init__(self, k, num_features, beta_l0=2/3, gamma_l0=-0.1, zeta_l0=1.1, lambda_l0=0.1):
        super().__init__()
        self.k = k
        self.num_features = num_features
        self.beta_l0, self.gamma_l0, self.zeta_l0 = beta_l0, gamma_l0, zeta_l0
        self.lambda_l0 = lambda_l0
        self.encoder = L0Linear(self.num_features, self.k, bias=False, beta=self.beta_l0, gamma=self.gamma_l0, zeta=self.zeta_l0)
        self.decoder = ConstrainedLinear(self.k, num_features, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """
        Encode, normalise and decode a batch.

        :param X: 2-D float tensor with ``num_features`` columns
        :return: ``(reconstruction, hidden_state, l0_penalty / batch_size)``
        """
        enc, l0_pen = self.encoder(X)
        hid_state = self.softmax(enc)
        reconstruction = self.decoder(hid_state)
        return reconstruction, hid_state, l0_pen/X.shape[0]

    def train(self, trX=True, optimizer=None, loss_f=None, num_epochs=None, device=None, batch_size=0, valX=None, display_logs=True):
        """
        Fit the model, or switch training mode when called like ``nn.Module.train``.

        This method shadows ``nn.Module.train``. PyTorch itself calls
        ``self.train(mode)`` with a single boolean (for instance from
        ``eval()``), so a boolean first argument is delegated to the base
        class instead of being treated as training data.

        :param trX: training samples (indexable, with ``.shape``), or a bool mode flag
        :param optimizer: optimizer stepping this model's parameters
        :param loss_f: callable ``loss_f(reconstruction, batch)`` returning a scalar tensor
        :param num_epochs: number of passes over ``trX``
        :param device: device batches are moved to; defaults to the device of the model parameters
        :param batch_size: samples per step; values below 1 use the whole set as one batch
        :param valX: optional validation samples
        :param display_logs: print per-epoch losses when True
        :return: ``(train_loss, val_loss)`` of the last epoch, ``(None, None)`` if no epoch ran;
                 ``self`` when used as a mode switch
        :raises TypeError: if training data is given without optimizer, loss_f or num_epochs
        """
        if isinstance(trX, bool):
            return super().train(trX)
        if optimizer is None or loss_f is None or num_epochs is None:
            raise TypeError('train() requires optimizer, loss_f and num_epochs when given training data')
        if device is None:
            # Batches must end up where the parameters live.
            device = next(self.parameters()).device
        # Defined up front so that num_epochs == 0 returns instead of raising.
        tr_loss, val_loss = None, None
        for ep in range(num_epochs):
            if display_logs:
                print(f'------------- EPOCH {ep+1} -------------')
            tr_loss, val_loss = self._run_epoch(trX, optimizer, loss_f, batch_size, valX, device)
            if display_logs:
                print(f'Mean training loss: {tr_loss}')
                if val_loss is not None:
                    print(f'Mean validation loss: {val_loss}')
        return tr_loss, val_loss

    def _batch_generator(self, X, batch_size=0):
        """Yield float32 tensors covering ``X`` in order; one tensor with everything if ``batch_size < 1``."""
        if batch_size < 1:
            yield torch.tensor(X, dtype=torch.float32)
        else:
            for i in range(0, X.shape[0], batch_size):
                yield torch.tensor(X[i:i+batch_size], dtype=torch.float32)

    def _validate(self, valX, loss_f, batch_size, device):
        """
        Sum ``loss_f`` over ``valX`` without updating the model.

        Runs in evaluation mode and without gradient tracking, then restores
        the previous training mode.

        :return: accumulated (not averaged) validation loss as a Python number
        """
        was_training = self.training
        self.eval()
        acum_val_loss = 0
        try:
            with torch.no_grad():
                for X in self._batch_generator(valX, batch_size):
                    # The target must sit on the same device as the reconstruction.
                    X = X.to(device)
                    rec, _, _ = self(X)
                    acum_val_loss += loss_f(rec, X).item()
        finally:
            self.train(was_training)
        return acum_val_loss

    def _run_step(self, X, optimizer, loss_f):
        """Run one optimisation step on batch ``X`` and return its (detached) loss tensor."""
        self.zero_grad()
        rec, _, l0_pen = self(X)
        loss = loss_f(rec, X)+self.lambda_l0*l0_pen
        loss.backward()
        optimizer.step()
        return loss.detach()

    def _run_epoch(self, trX, optimizer, loss_f, batch_size, valX, device):
        """
        Train on every batch of ``trX`` once and optionally validate.

        :return: ``(summed training loss / len(trX), summed validation loss / len(valX) or None)``
        """
        tr_loss, val_loss = 0, None
        for X in self._batch_generator(trX, batch_size):
            step_loss = self._run_step(X.to(device), optimizer, loss_f)
            tr_loss += step_loss.item()
        if valX is not None:
            val_loss = self._validate(valX, loss_f, batch_size, device) / valX.shape[0]
        return tr_loss / trX.shape[0], val_loss


## Data

In [None]:
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code -- only load archives from a trusted source.
# The context manager closes the archive as soon as the array has been read.
with np.load(data_path, allow_pickle=True) as npzfile:
    snps = npzfile['snps']

## Training

In [None]:
# The training loop moves every batch to `device`, so the parameters must live
# there too; otherwise a CUDA run fails with a device mismatch.
model = AdmixtureAE(k=8, num_features=snps.shape[1], lambda_l0=0.01).to(device)

In [None]:
# Training hyper-parameters.
num_epochs = 10
learning_rate = 0.1
# Summed (not averaged) squared error.
loss_f = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# `device` is a required positional argument of AdmixtureAE.train; omitting it raised a TypeError.
model.train(snps, optimizer, loss_f, num_epochs, device, batch_size=32)