In [None]:
!pip install -r mta_requirements.txt &> /dev/null
!pip install --upgrade scikit-learn &> /dev/null

In [None]:
from modules.utils import *
# Input locations for the SNOMED CT experiment, all under one data root.
# Directory paths end with '/' so they can be joined to file names by plain
# string concatenation.
_SNOMED_ROOT = 'data/SnomedCT/'
refset_iri_dir = _SNOMED_ROOT + 'nrc_refset_iri/'        # one IRI list per refset
refset_embed_dir = _SNOMED_ROOT + 'refset_embed/'        # one embedding file per refset
embed_path = _SNOMED_ROOT + 'concept_embedding.csv'      # embeddings of all concepts
iri_path = _SNOMED_ROOT + 'concept_iri.csv'              # IRIs of all concepts
seed_path = _SNOMED_ROOT + 'nrc_seed_random_iri/'        # seed IRIs per refset

## Define Layer

In [None]:
import torch
from torch import nn

from qpth.qp import QPFunction


class SVDDLayer(nn.Module):
    """Differentiable SVDD center estimation through a batched quadratic program.

    For every episode in the batch this solves (with ``qpth``)

        min_a  a^T K a - diag(K)^T a   s.t.  a >= 0,  sum(a) = 1

    where ``K = X X^T + eps * I`` is the linear kernel of the episode's support
    embeddings ``X``. It returns the center ``sum_i a_i x_i``, reduced along ``dim``.

    Args:
        shot: unused. The number of support items is read from ``inputs`` at
            forward time; the argument is kept for constructor compatibility.
        dim: axis summed over when forming the center (the support axis).
        eps: ridge added to the kernel diagonal so the QP stays strictly convex.
    """

    def __init__(self, shot, dim=1, eps=1e-6):
        super(SVDDLayer, self).__init__()
        self._dim = dim
        self._eps = eps

    def forward(self, inputs):
        """Return one center per episode for 3-D ``inputs`` of shape (batch, shot, features)."""
        shot = inputs.shape[1]
        # Build every constant on the inputs' device/dtype. CPU float32 constants
        # fail with a device mismatch as soon as the model runs on a GPU.
        factory = {'device': inputs.device, 'dtype': inputs.dtype}
        identity = torch.eye(shot, **factory)

        kernel_matrices = torch.bmm(inputs, inputs.transpose(1, 2)) + self._eps * identity
        kernel_diags = torch.diagonal(kernel_matrices, dim1=-2, dim2=-1)

        # qpth solves  min 1/2 z^T Q z + p^T z  s.t.  G z <= h,  A z = b.
        Q = 2 * kernel_matrices
        p = -kernel_diags
        G = -identity                          # -a <= 0  <=>  a >= 0
        h = torch.zeros(shot, **factory)
        A = torch.ones(1, shot, **factory)     # sum(a) == 1
        b = torch.ones(1, **factory)
        alphas = QPFunction(verbose=-1)(
            Q,
            p,
            G.detach(),
            h.detach(),
            A.detach(),
            b.detach(),
        )

        # (batch, shot) -> (batch, shot, 1) so the weights broadcast over features.
        centers = torch.sum(alphas.unsqueeze(-1) * inputs, dim=self._dim)
        return centers


class CentersDistance(nn.Module):
    """Score each query by its negative squared Euclidean distance to its episode center.

    Args:
        dim: axis reduced when summing the squared differences (default -1, the
            last axis).
    """

    def __init__(self, dim=-1):
        super().__init__()
        self._dim = dim

    def forward(self, inputs, centers):
        """Return ``-sum((inputs - center)**2)`` along ``dim``.

        ``centers`` is expected to be one rank below ``inputs`` (e.g. (B, D) vs
        (B, Q, D)); a singleton axis is inserted at position 1 so each center
        broadcasts over its episode's queries.
        """
        offsets = inputs - centers[:, None]
        squared = offsets ** 2
        return squared.sum(dim=self._dim).neg()


## Define Model

In [None]:
import torch
from torch import nn


class EmbeddingNet(nn.Module):
    """Linear embedding applied independently to every item of every episode.

    Args:
        in_channels: size of the last input axis (the raw concept embedding).
        out_channels: size of each produced embedding.
        hidden_size: currently unused; kept for interface compatibility.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 hidden_size: int = 64) -> None:
        super(EmbeddingNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size

        # Sized from the constructor arguments; the previous layer was hard-coded
        # to 200 -> 200 and ignored them.
        self.net = nn.Sequential(
            nn.Linear(in_channels, out_channels),
        )

    def forward(self, inputs):
        """Embed ``inputs`` shaped (episodes, items, ..., in_channels) -> (episodes, items, -1)."""
        # Fold the two leading axes into one batch axis for the linear layer,
        # then restore them and flatten whatever remains per item.
        reshaped = inputs.reshape(-1, *inputs.shape[2:])
        embeddings = self.net(reshaped)
        outputs = embeddings.view(*inputs.shape[:2], -1)
        return outputs


class MetaOCCModel(nn.Module):
    """Meta one-class classifier: embed, fit a center on the support set, score queries.

    The support set passes through ``embedding_net`` followed by ``occ_layer`` to
    produce one center per episode. Queries pass through the same
    ``embedding_net`` and are scored against that center by ``CentersDistance``.
    """

    def __init__(self, embedding_net, occ_layer):
        super().__init__()
        # Submodule names and registration order determine state_dict keys; keep them.
        self._embedding_net = embedding_net
        self._net = nn.Sequential(embedding_net, occ_layer)
        self._to_logits = CentersDistance()

    def forward(self, support_inputs, query_inputs):
        """Return logits: negative squared distance of each query embedding to its center."""
        centers = self._net(support_inputs)
        embedded_queries = self._embedding_net(query_inputs)
        return self._to_logits(embedded_queries, centers)

    def infer(self, support_inputs, query_inputs):
        """Map logits to scores via ``1 + tanh(logits)``.

        The logits are non-positive, so the scores lie in [0, 1].
        """
        return torch.tanh(self(support_inputs, query_inputs)) + 1.0


## Utils

In [None]:
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
from sklearn.metrics import roc_auc_score

def evaluate(model, loader, total_episodes, shot, device=None):
    """Estimate mean per-episode accuracy and its 95% confidence half-width.

    Batches of episodes are drawn from ``loader.get_train_data(True)`` until
    ``total_episodes`` accuracies have been collected. The last batch is truncated
    so exactly ``total_episodes`` episodes count. A query is predicted positive
    when its probability is >= 0.5.

    NOTE(review): episodes come from ``get_train_data``, i.e. the training
    refsets. Confirm this is the intended evaluation split.

    Args:
        model: object exposing ``infer(support_inputs, query_inputs)``, which returns
            per-query probabilities shaped like the query labels (episodes, queries).
        loader: object exposing ``get_train_data(is_val)``, which returns
            ``(support_inputs, query_inputs, query_labels)`` tensors.
        total_episodes: number of episodes to average over; must be positive.
        shot: unused; kept for call compatibility.
        device: optional device every batch is moved to before inference.

    Returns:
        ``(mean, ci95)``: the mean accuracy and ``1.96 * std / sqrt(n)``.

    Raises:
        ValueError: if ``total_episodes`` is not positive, or if the loader yields a
            batch with no episodes (which would otherwise loop forever).
    """
    if total_episodes <= 0:
        raise ValueError(f"total_episodes must be positive, got {total_episodes}")

    accs = []
    # Pure evaluation: skip building autograd graphs.
    with torch.no_grad():
        while len(accs) < total_episodes:
            support_inputs, query_inputs, query_labels = loader.get_train_data(True)

            if device:
                support_inputs = support_inputs.to(device=device)
                query_inputs = query_inputs.to(device=device)
                query_labels = query_labels.to(device=device)

            probs = model.infer(support_inputs, query_inputs)
            preds = (probs >= 0.5).long()
            batch_accs = preds.eq(query_labels).float().mean(dim=1).cpu().numpy()
            if len(batch_accs) == 0:
                raise ValueError("loader returned a batch with no episodes")

            # Keep only as many episodes as are still needed.
            remaining = total_episodes - len(accs)
            accs.extend(batch_accs[:remaining])

    mean = np.mean(accs)
    std = np.std(accs)
    ci95 = 1.96 * std / np.sqrt(len(accs))

    return mean, ci95

## Data loading utilities

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
from  sklearn.metrics import ndcg_score, roc_auc_score

def get_refset_name(raw_name):
    """Extract a short refset name from a refset file name.

    Keeps the text before the first "Active" and returns its last "_"-separated
    field, e.g. "der2_Refset_CardiologyActive.txt" -> "Cardiology".
    """
    prefix, _, _ = raw_name.partition("Active")
    return prefix.rsplit('_', 1)[-1]

class my_Data_Loader(object):
    """Episode sampler over SNOMED CT reference sets for meta one-class training.

    Every file in ``refset_iri_dir`` is one refset (a list of concept IRIs). A
    same-named file in ``refset_embed_dir`` holds that refset's embeddings, one
    row per line. ``embed_path`` and ``iri_path`` hold the embedding and the IRI
    of every concept, row-aligned (``iri2idx`` maps an IRI to its embedding row).

    Args:
        refset_iri_dir: directory of per-refset IRI files.
        refset_embed_dir: directory of per-refset embedding files (same file names).
        embed_path: whitespace-delimited embedding table of all concepts.
        iri_path: IRI of every concept, one per line.
        seed_dir: directory of per-refset seed IRI files used by ``set_test_idx``.
            When None, the module-level ``seed_path`` is used.
    """

    def __init__(self, refset_iri_dir, refset_embed_dir, embed_path, iri_path, seed_dir=None):
        self.rnd = np.random.RandomState(42)
        self.seed_dir = seed_dir
        # os.listdir order is filesystem-dependent; sort so refset indices are reproducible.
        self.fname_list = sorted(os.listdir(refset_iri_dir))
        size = len(self.fname_list)
        shuffled_idx = self.rnd.permutation(range(size))
        split = int(size * 0.8)
        self.train_idx = shuffled_idx[:split]
        self.test_idx = shuffled_idx[split:]

        # ndmin stops single-line files from collapsing: a one-concept refset
        # would otherwise load as a flat vector (and get its components shuffled
        # as if they were rows), and a one-IRI file as a non-iterable 0-d array.
        self.refset_embeds = [np.loadtxt(os.path.join(refset_embed_dir, fname), ndmin=2)
                              for fname in self.fname_list]
        self.refset_iris = [np.loadtxt(os.path.join(refset_iri_dir, fname), dtype='str', ndmin=1)
                            for fname in self.fname_list]
        self.embeds = np.loadtxt(embed_path, ndmin=2)
        self.iris = np.loadtxt(iri_path, dtype=str, ndmin=1)
        self.iri2idx = {iri: index for index, iri in enumerate(self.iris)}

        print("Loading done!")

    def reset(self, seed=42):
        """Re-seed the sampler's random state."""
        self.rnd = np.random.RandomState(seed)

    def set_test_idx(self, test_idx, shot=5):
        """Hold out refset ``test_idx`` for testing and train on all other refsets.

        Builds the fixed test episode:
          * support: the first ``shot`` seed concepts of the held-out refset, as a
            (1, shot, 1, dim) float32 tensor;
          * queries: every concept, as a (1, n_concepts, 1, dim) float32 tensor;
          * labels: a (1, n_concepts) 0/1 array marking refset membership.
        """
        size = len(self.fname_list)
        assert test_idx < size

        self.test_idx = test_idx
        self.train_idx = list(range(size))
        self.train_idx.pop(test_idx)

        seed_dir = getattr(self, 'seed_dir', None) or seed_path
        fname = self.fname_list[test_idx]
        seed_iris = np.loadtxt(os.path.join(seed_dir, fname), dtype=str, ndmin=1)
        seed_embeds = self.embeds[[self.iri2idx[iri] for iri in seed_iris[:shot]]]

        # Build tensors straight from arrays (not nested lists of ndarrays, which
        # is very slow for the full concept table).
        self.support_inputs = torch.tensor(seed_embeds, dtype=torch.float32)[None, :, None, :]
        self.query_inputs = torch.tensor(self.embeds, dtype=torch.float32)[None, :, None, :]

        labels = np.zeros(len(self.embeds))
        members = np.atleast_1d(self.refset_iris[test_idx])
        labels[[self.iri2idx[iri] for iri in members]] = 1
        self.query_outputs = labels[np.newaxis, :]

    def get_train_data(self, is_val=False, batch_size=8, max_shot=256):
        """Sample a batch of training episodes.

        Each episode comes from one training refset:
          * ``shot`` of its shuffled embeddings form the support;
          * the next ``shot`` are positive queries;
          * up to ``shot`` random concepts outside the refset are negative queries.
        ``shot`` is half the smallest refset in the batch, capped at ``max_shot``.

        Args:
            is_val: unused; kept for call compatibility.
            batch_size: maximum number of episodes. The batch is smaller when
                there are fewer training refsets.
            max_shot: upper bound on ``shot``.

        Returns:
            ``(support, query, labels)``:
              * support: (B, shot, 1, dim) float32;
              * query: (B, 2*shot, 1, dim) float32;
              * labels: (B, 2*shot) int64, 1 for the positive queries and 0 for the
                negatives.
        """
        batch_idx = self.rnd.permutation(self.train_idx)[:batch_size]
        shot = min(min(len(self.refset_embeds[idx]) // 2 for idx in batch_idx), max_shot)
        dim = self.embeds.shape[1]

        supports, queries, labels = [], [], []
        # Iterate over the sampled refsets themselves: indexing range(batch_size)
        # fails when fewer than batch_size training refsets exist.
        for sample_idx in batch_idx:
            # RNG draw order per episode (refset shuffle, then concept shuffle)
            # determines the sampled data; keep it stable for reproducibility.
            shuffled = self.rnd.permutation(self.refset_embeds[sample_idx])
            members = set(np.atleast_1d(self.refset_iris[sample_idx]).tolist())

            negatives = []
            for concept_idx in self.rnd.permutation(len(self.embeds)):
                if len(negatives) >= shot:
                    break
                if self.iris[concept_idx] not in members:
                    negatives.append(self.embeds[concept_idx])

            supports.append(shuffled[:shot])
            queries.append(np.concatenate([shuffled[shot:2 * shot],
                                           np.reshape(negatives, (-1, dim))]))
            labels.append([1] * shot + [0] * len(negatives))

        support = torch.tensor(np.stack(supports), dtype=torch.float32).unsqueeze(2)
        query = torch.tensor(np.stack(queries), dtype=torch.float32).unsqueeze(2)
        return support, query, torch.tensor(labels, dtype=torch.long)

    def get_test_data(self):
        """Return the held-out episode built by ``set_test_idx``."""
        return self.support_inputs, self.query_inputs, self.query_outputs

## Train and evaluate (one held-out refset per run)

In [None]:
# Load every refset file plus the full concept embedding / IRI tables once;
# run() below reuses this loader for every held-out refset.
mdl = my_Data_Loader(
    refset_iri_dir=refset_iri_dir,
    refset_embed_dir=refset_embed_dir,
    embed_path=embed_path,
    iri_path=iri_path,
)

In [None]:
RESULTS_PATH = './select_nrc.txt'
N_TRAIN_STEPS = 300           # optimisation steps 0..N_TRAIN_STEPS inclusive
EVAL_EVERY = 10               # evaluate on the held-out refset every this many steps
MAX_CONSECUTIVE_FAULTS = 100  # give up instead of resampling failing batches forever


def _rank_scores(probs):
    """Replace scores by their ascending rank, scaled to (0, 1].

    The i-th entry of ``np.argsort(probs)`` (0-based) gets score ``(i + 1) / n``.
    Every concept therefore gets a distinct score, and ties in ``probs`` are
    broken in ``np.argsort`` order.
    """
    n = len(probs)
    ranked = np.empty(n, dtype=np.float64)
    ranked[np.argsort(probs)] = np.arange(1, n + 1) / n
    return ranked


def _log_test_metrics(model, device, test_idx, step, results_file):
    """Score every concept for the held-out refset; print and append NDCG and AUC."""
    support_inputs, query_inputs, query_labels = mdl.get_test_data()
    model.train(False)
    try:
        # Inference over the whole concept table: no autograd graph needed.
        with torch.no_grad():
            probs = model.infer(support_inputs.to(device), query_inputs.to(device))
    finally:
        model.train(True)

    scores = _rank_scores(probs[0].cpu().numpy())
    labels = query_labels[0]
    ndcg = ndcg_score([labels], [scores])
    auc = roc_auc_score(labels, scores)

    refset_name = mdl.fname_list[test_idx]
    st = "refset={:35s} training_iter={} meta_occ score: NDCG={:.4f} AUC={:.4f}".format(refset_name, step, ndcg, auc)
    print(st)
    results_file.write(st + '\n')
    results_file.flush()


def run(test_idx):
    """Train a fresh Meta-OCC model with refset ``test_idx`` held out.

    Every ``EVAL_EVERY`` steps, the model ranks all concepts against the
    held-out refset. The resulting NDCG and AUC are printed and appended to
    ``RESULTS_PATH``.

    Raises:
        RuntimeError: if more than ``MAX_CONSECUTIVE_FAULTS`` training batches
            in a row fail in the forward pass.
    """
    mdl.set_test_idx(test_idx)
    device = torch.device('cpu')
    mdl.reset(42)
    torch.manual_seed(42)  # reproducible weight init; the loader RNG is reset above

    model = MetaOCCModel(EmbeddingNet(200, 200), SVDDLayer(5))
    model.to(device)
    loss = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=4e-5, weight_decay=0.01)

    step = 0
    faults = 0
    with open(RESULTS_PATH, 'a') as results_file:
        while step <= N_TRAIN_STEPS:
            support_inputs, query_inputs, query_labels = mdl.get_train_data()
            support_inputs = support_inputs.to(device)
            query_inputs = query_inputs.to(device)
            query_labels = query_labels.to(device)

            try:
                loss_val = loss(model(support_inputs, query_inputs), query_labels.float())
            except Exception as exc:
                # Batches whose forward pass raises are skipped and resampled
                # (presumably qpth solver failures -- TODO confirm). Bound the
                # retries so a systematic error surfaces instead of hanging.
                faults += 1
                if faults > MAX_CONSECUTIVE_FAULTS:
                    raise RuntimeError(
                        f"{faults} consecutive failed training batches for refset {test_idx}"
                    ) from exc
                continue
            faults = 0

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if step % EVAL_EVERY == 0:
                _log_test_metrics(model, device, test_idx, step, results_file)
            step += 1

In [None]:
# Leave-one-refset-out: train and evaluate once per refset file on disk.
n_refsets = len(os.listdir(refset_iri_dir))
for test_idx in range(n_refsets):
    run(test_idx)