In [571]:
import torch

In [572]:
genes = [0, 1, 2, 3]
ab = [0, 1, 2, 3]

data = [{'genes': [1, 0, 0, 0], 'res': [0, 0, 0, 0]}, 
        {'genes': [0, 1, 0, 0], 'res': [0, 1, 0, 0]},
        {'genes': [0, 0, 1, 0], 'res': [0, 0, 0, 0]},
        {'genes': [0, 0, 0, 1], 'res': [0, 0, 0, 0]},
        {'genes': [1, 1, 0, 0], 'res': [0, 1, 0, 0]},
        {'genes': [1, 0, 1, 0], 'res': [0, 0, 0, 0]},
        {'genes': [1, 0, 0, 1], 'res': [0, 0, 0, 0]},
        {'genes': [0, 1, 1, 0], 'res': [0, 1, 1, 0]},
        {'genes': [0, 1, 0, 1], 'res': [0, 1, 0, 1]},
        {'genes': [0, 0, 1, 1], 'res': [0, 0, 0, 0]},
        {'genes': [1, 0, 0, 0], 'res': [0, 0, 0, 0]}, 
        {'genes': [0, 0, 1, 0], 'res': [0, 0, 0, 0]},
        {'genes': [0, 0, 0, 1], 'res': [0, 0, 0, 0]},
        ]

In [573]:
class GabDataset(torch.utils.data.Dataset):
        def __init__(self, data):
                self.data = data


        def __len__(self):
                return len(self.data)

        def __getitem__(self, idx):
                d = self.data[idx]
                return torch.IntTensor(d['genes']), torch.LongTensor(d['res'])

In [574]:
class GabModel(torch.nn.Module):
        def __init__(self):
                super().__init__()
                self.genes = torch.IntTensor([0, 1, 2, 3]).to('cuda')
                self.n_known = 2
                self.g_emb = torch.nn.Embedding(5, 32, padding_idx=4)
                self.g2_emb = torch.nn.Embedding(5, 32, padding_idx=4)
                self.exist_emb = torch.nn.Embedding(3, 32, padding_idx=2)
                self.scorer = torch.nn.Linear(32, 1)
                self.layers = torch.nn.Sequential(torch.nn.Linear(32, 64),
                                                  torch.nn.ReLU(),
                                                  torch.nn.Linear(64, 128),
                                                  torch.nn.ReLU(),
                                                  torch.nn.Linear(128, 256),
                                                  torch.nn.ReLU())
                self.head1 = torch.nn.Linear(256, 2)
                self.head2 = torch.nn.Linear(256, 2)
                self.head3 = torch.nn.Linear(256, 2)
                self.head4 = torch.nn.Linear(256, 2)
        
        def forward(self, x):
                b = x.shape[0]
                g = self.genes.repeat(b, 1)
                gene_emb = self.g_emb(self.genes)
                gene_scores = self.scorer(gene_emb)
                gene_scores = torch.nn.functional.softmax(gene_scores, dim=0)
                #print("Gene scores are: {}".format(gene_scores.view(-1)))
                scores, i = torch.topk(gene_scores.view(-1), self.n_known, dim=0)
                out_i = i
                #print("Selected genes are: {}".format(i))
                if self.n_known < 4:
                        random_gene = torch.multinomial(torch.Tensor([0 if g in i else 1 for g in genes]).to('cuda'), 1)
                        #print("Random gene is: {}".format(random_gene))
                        i = torch.cat([i, random_gene])
                        g_comp = torch.Tensor([1 if g in i else 0 for g in genes]).to('cuda')
                        scores = torch.cat([scores, gene_scores.view(-1)[random_gene]])
                        scores_all = gene_scores.view(-1) #* g_comp
                sel_genes = g[:,i]
                sel_known = x[:,i]
                sel_g_emb = self.g2_emb(sel_genes)
                sel_exist_emb = self.exist_emb(sel_known)
                
                sel_emb = sel_g_emb + sel_exist_emb
                
                out_emb = self.layers(sel_emb)
                
                out_mean = torch.mean(out_emb, dim=1)
                
                
                pred_1 = self.head1(out_mean)
                pred_2 = self.head2(out_mean)
                pred_3 = self.head3(out_mean)
                pred_4 = self.head4(out_mean)
                
                return pred_1, pred_2, pred_3, pred_4, scores, out_i, scores_all

In [575]:
def compute_loss(pred_1, pred_2, pred_3, pred_4, labels_1, labels_2, labels_3, labels_4, scores):
        loss_fn = torch.nn.CrossEntropyLoss()
        loss_1 = loss_fn(pred_1, labels_1)
        loss_2 = loss_fn(pred_2, labels_2)
        loss_3 = loss_fn(pred_3, labels_3)
        loss_4 = loss_fn(pred_4, labels_4)
        
        total_pred_loss = (loss_1 + loss_2 + loss_3 + loss_4).mean()
        print("Mean CE loss: {}".format(total_pred_loss))
        total_loss = total_pred_loss ** scores.sum()
        #total_loss = total_pred_loss
        return total_loss

In [576]:
dataset = GabDataset(data)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

model = GabModel()
model.to('cuda')

optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=1e-3,
                                  weight_decay=1e-4)

epochs = 10
for e in range(epochs):
        n_correct_ab_1 = 0
        n_ab_1 = 0
        n_correct_ab_2 = 0
        n_ab_2 = 0
        n_correct_ab_3 = 0
        n_ab_3 = 0
        n_correct_ab_4 = 0
        n_ab_4 = 0
        for d, l in data_loader:
                d = d.to('cuda')
                l = d.to('cuda').long()
                labels_1 = l[:,0]
                labels_2 = l[:,1]
                labels_3 = l[:,2]
                labels_4 = l[:,3]
                pred_1, pred_2, pred_3, pred_4, scores, out_i, scores_all = model(d)
                loss = compute_loss(pred_1, pred_2, pred_3, pred_4, labels_1, labels_2, labels_3, labels_4, scores)   
                print("Loss is: {}".format(loss))
                print("Selected genes are: {}".format(out_i))
                print("Selected scores are: {}".format(scores))
                print("All scores are: {}".format(scores_all))
                hard_pred_1 = torch.argmax(pred_1, dim=1)
                hard_pred_2 = torch.argmax(pred_2, dim=1)
                hard_pred_3 = torch.argmax(pred_3, dim=1)
                hard_pred_4 = torch.argmax(pred_4, dim=1)
                
                n_correct_ab_1 += (hard_pred_1 == labels_1).sum()
                n_ab_1 += 1
                n_correct_ab_2 += (hard_pred_2 == labels_2).sum()
                n_ab_2 += 1
                n_correct_ab_3 += (hard_pred_3 == labels_3).sum()
                n_ab_3 += 1
                n_correct_ab_4 += (hard_pred_4 == labels_4).sum()
                n_ab_4 += 1
                
                
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        acc_ab_1 = n_correct_ab_1 / n_ab_1
        acc_ab_2 = n_correct_ab_2 / n_ab_2
        acc_ab_3 = n_correct_ab_3 / n_ab_3
        acc_ab_4 = n_correct_ab_4 / n_ab_4
        print("--- Epoch {} --- Accuracy AB 1: {} ---".format(e, acc_ab_1))
        print("--- Epoch {} --- Accuracy AB 2: {} ---".format(e, acc_ab_2))
        print("--- Epoch {} --- Accuracy AB 3: {} ---".format(e, acc_ab_3))
        print("--- Epoch {} --- Accuracy AB 4: {} ---".format(e, acc_ab_4))

Mean CE loss: 2.704684019088745
Loss is: 2.26690673828125
Selected genes are: tensor([2, 0], device='cuda:0')
Selected scores are: tensor([0.4596, 0.1994, 0.1636], device='cuda:0', grad_fn=<CatBackward0>)
All scores are: tensor([0.1994, 0.1636, 0.4596, 0.1775], device='cuda:0',
       grad_fn=<ViewBackward0>)
Mean CE loss: 2.8765618801116943
Loss is: 2.3731086254119873
Selected genes are: tensor([2, 0], device='cuda:0')
Selected scores are: tensor([0.4534, 0.1999, 0.1646], device='cuda:0', grad_fn=<CatBackward0>)
All scores are: tensor([0.1999, 0.1646, 0.4534, 0.1821], device='cuda:0',
       grad_fn=<ViewBackward0>)
Mean CE loss: 2.8958730697631836
Loss is: 2.4284942150115967
Selected genes are: tensor([2, 0], device='cuda:0')
Selected scores are: tensor([0.4473, 0.2004, 0.1868], device='cuda:0', grad_fn=<CatBackward0>)
All scores are: tensor([0.2004, 0.1655, 0.4473, 0.1868], device='cuda:0',
       grad_fn=<ViewBackward0>)
Mean CE loss: 2.7263598442077637
Loss is: 2.301819086074829
S