In [None]:
import os, sys
# Paths (relative to the notebook's working directory) that must be importable:
# the project root plus the HySpecLab package and the bundled IPDL / Sparse modules.
project_dir = os.path.join(os.getcwd(), '../..')
hyspeclab_dir = os.path.join(project_dir, 'HySpecLab')
ipdl_dir = os.path.join(project_dir, 'modules/IPDL')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

# Append each path once, in order, so re-running this cell does not grow sys.path.
for extra_path in (project_dir, hyspeclab_dir, ipdl_dir, sparse_dir):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

import config
import numpy as np
from torch import nn
from matplotlib import pyplot as plt 

In [None]:
from HySpecLab.dataset import DermaDataset

# Dataset splits merged into the training set; each becomes a full directory path.
train_dir = ['train', 'validation']
dataset_dir = [os.path.join(config.DERMA_DATASET_DIR, split) for split in train_dir]

from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import torch

class DermaDatasetTorch(Dataset):
    """Map-style torch ``Dataset`` over the samples returned by ``DermaDataset``.

    All features and labels are loaded once in ``__init__`` via
    ``DermaDataset(dataset_dir).get(dataframe=False)``; ``__getitem__`` converts
    a single sample to tensors on demand.

    Args:
        dataset_dir: passed unchanged to ``DermaDataset`` (in this notebook, a
            list of split directories).
    """

    def __init__(self, dataset_dir):
        super().__init__()
        dataset = DermaDataset(dataset_dir)
        # presumably x is (n_samples, n_features) and y holds integer class ids — TODO confirm
        self.x, self.y = dataset.get(dataframe=False)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        Features are float32. The label is cast to int64 (``torch.long``) because
        ``torch.nn.functional.one_hot`` only accepts int64 indices; without the
        cast, e.g. an int32 label array would yield int32 tensors and break
        ``one_hot`` in the training loop.
        """
        features = torch.tensor(self.x[idx], dtype=torch.float32)
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return features, label


dataset = DermaDatasetTorch(dataset_dir)

from torch.utils.data import DataLoader

BATCH_SIZE = 128
# The model's BatchNorm1d layers raise in training mode on a batch with a single
# sample, which the final batch would be whenever len(dataset) % BATCH_SIZE == 1.
# Drop that lone leftover sample only in that case; otherwise keep every sample.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(dataset) % BATCH_SIZE == 1,
)

In [None]:
from Sparse import KWinners

# Feature-selection front end followed by a small MLP classifier (116 -> 64 -> 24 -> 2).
# Input is expected as (batch, 116, 1): a 1x1 Conv1d linearly mixes the 116 input
# channels before Flatten. KWinners(116, 25) is meant to keep the K=25 most relevant
# bands (per the original comment); the plotting cell treats its zero outputs as
# "not selected" — exact KWinners semantics live in the Sparse module.
model = nn.Sequential(
    nn.Conv1d(116, 116, 1, bias=False),
    nn.Flatten(start_dim=1),
    KWinners(116, 25),  # K most relevant bands
    nn.BatchNorm1d(116),
    nn.Linear(116, 64),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(64),
    nn.Linear(64, 24),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(24),
    nn.Linear(24, 2),
    nn.Softmax(dim=1),  # probabilities; the training cell pairs this with BCELoss
)

# He (Kaiming) initialization for the band-mixing convolution only.
conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv1d)]
for conv in conv_layers:
    nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

In [None]:
from tqdm import tqdm
from torch.nn.functional import one_hot

n_epoch = 20
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# The model ends in nn.Softmax, so it is trained with BCELoss on one-hot targets.
# NOTE(review): emitting raw logits with BCEWithLogitsLoss / CrossEntropyLoss would
# be more numerically stable; changing that requires editing the model cell too.
criterion = nn.BCELoss()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = model.to(device)

epoch_iterator = tqdm(
    range(n_epoch),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.4f" % 1},
)
model.train()
loss_value = []  # per-batch training loss over the whole run
for epoch in epoch_iterator:
    epoch_start = len(loss_value)
    # `input` shadows the builtin; the name is kept because later cells may read
    # the last training batch from it.
    for input, target in loader:
        input = input.to(device)

        out = model(input[:, :, None])  # Conv1d expects a trailing length-1 axis
        # Pin num_classes to the model's output width: without it one_hot() infers
        # the class count from the batch's largest label, so a batch containing only
        # class 0 would give a (B, 1) target and BCELoss would fail on the shape
        # mismatch with `out`. `.long()` because one_hot requires int64 indices.
        target = one_hot(target.long(), num_classes=out.size(1)).float().to(device)
        loss = criterion(out, target)
        loss_value.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the mean loss of this epoch only (not the running mean of all epochs).
    epoch_losses = loss_value[epoch_start:]
    if epoch_losses:
        epoch_iterator.set_postfix(tls="%.4f" % np.mean(epoch_losses))

In [None]:
# Visualize which input features survive the K-winners layer for one sample.
model.eval()
# Use the device the model currently lives on, so this cell does not depend on
# names defined inside the training cell.
model_device = next(model.parameters()).device
# Draw a batch explicitly instead of reusing `input` leaked from the training
# loop (hidden state that breaks if cells are run out of order).
sample_batch, _ = next(iter(loader))
sample_batch = sample_batch.to(model_device)

with torch.no_grad():
    # model[0:3] is Conv1d -> Flatten -> KWinners; Conv1d needs a trailing
    # length-1 axis. The whole batch is passed through and sample 0 is kept.
    kwinner_out = model[0:3](sample_batch[:, :, None])[0].cpu()

# Raw input of sample 0, with every feature the K-winners layer zeroed set to 0.
selected = sample_batch[0].cpu().clone()
selected[kwinner_out == 0] = 0

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Selected features')
ax[0].plot(selected)
ax[0].set_title('Original values')
ax[0].set(xlabel='Band index', ylabel='Input value')
ax[1].plot(kwinner_out)
ax[1].set_title('Layer output')
ax[1].set(xlabel='Band index', ylabel='K-winners activation')
plt.show()
