In [None]:
# Mount Google Drive inside the Colab VM; later cells read their inputs from this
# mount and write model checkpoints back to it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/hgg'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/hgg


In [None]:
!mkdir data
!cp hgg/MyDrive/WQE/data/*.npy ./data/

In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

### Training data

In [None]:
# Cohorts used for pre-training; each one has a matrix stored at data/<cohort>_m.npy.
train_cancers = ["BLCA", "BRCA", "COAD", "ESCA", "HNSC", "KIRC", "KIRP", "LIHC", "LUAD", "LUSC", "PRAD", "THCA", "UCEC"]

# Load every cohort first and join them once along axis 1. Concatenating inside the
# loop re-copies the whole accumulated matrix on every iteration.
all_data = np.concatenate(
    [np.load("data/" + cancer + "_m.npy") for cancer in train_cancers],
    axis=1,
)

# Transpose so the concatenated axis becomes axis 0.
# Original author's shape note: (6203, 4861) -- presumably samples x features.
all_x = np.transpose(all_data)

In [None]:
# Labels are stored one-hot: every column except the last two feeds `y_tissue`,
# the last two columns feed `y_type`.
onehot_y = np.load("data/onehot_labels.npy")
tissue_onehot = onehot_y[:, :-2] == 1
type_onehot = onehot_y[:, -2:] == 1

# Fail with a readable message instead of an IndexError from an empty np.where result.
assert tissue_onehot.any(axis=1).all(), "every row needs a 1 among the tissue columns"
assert type_onehot.any(axis=1).all(), "every row needs a 1 among the type columns"

# argmax over a boolean row yields the index of its first True entry, i.e. the same
# value as np.where(row == 1)[0][0], without a Python-level loop over rows.
y_tissue = tissue_onehot.argmax(axis=1)
y_type = type_onehot.argmax(axis=1)

# One (tissue, type) integer pair per row of all_x.
assert len(all_x) <= len(y_tissue), "fewer label rows than samples in all_x"
all_y = np.stack((y_tissue, y_type), axis=1)[:len(all_x)]

In [None]:
# Hold out 25% of the samples for validation; the fixed seed makes the split repeatable.
# Shapes noted by the original author: (4652, 4861) for train, (1551, 4861) for validation.
train_x, validate_x, train_y, validate_y = train_test_split(
    all_x,
    all_y,
    test_size=0.25,
    random_state=8,
)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing ``X[i]`` with ``Y[i]``.

    Args:
        X: indexable collection of samples.
        Y: indexable collection of labels; must have the same length as ``X``.

    Raises:
        AssertionError: if ``X`` and ``Y`` differ in length.
    """

    def __init__(self, X, Y):
        assert len(X) == len(Y), "X and Y must contain the same number of samples"
        self.X = X
        self.Y = Y
        self.length = len(X)

    def __len__(self):
        """Return the number of (sample, label) pairs."""
        return self.length

    def __getitem__(self, i):
        """Return the ``(sample, label)`` pair stored at index ``i``."""
        return self.X[i], self.Y[i]

    # Declared static: the function takes no ``self``, so without the decorator it
    # could only be reached as ``Dataset.collate_fn`` and failed when called on an
    # instance (the instance was bound to ``batch``).
    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``(x, y)`` pairs into two tensors.

        Args:
            batch: iterable of ``(x, y)`` pairs as returned by ``__getitem__``.

        Returns:
            Tuple ``(batch_x, batch_y)`` built with ``torch.as_tensor``.
        """
        batch_x = torch.as_tensor([x for x, _ in batch])
        batch_y = torch.as_tensor([y for _, y in batch])
        return batch_x, batch_y

In [None]:
# Wrap the training split and serve it in shuffled mini-batches of 8.
train_data = Dataset(train_x, train_y)
train_dataloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=8)

# Sanity check: print the shapes and labels of the first mini-batch only.
for batch_idx, first_batch in enumerate(train_dataloader):
    data, target = first_batch
    print("Batch", batch_idx, ":\n", data.shape, "\n", target.shape)
    print(target)
    break

In [None]:
# Wrap the validation split; order is kept fixed (no shuffling) so runs are comparable.
validate_data = Dataset(validate_x, validate_y)
validate_dataloader = torch.utils.data.DataLoader(validate_data, shuffle=False, batch_size=8)

# Sanity check: print shapes, labels and raw features of the first mini-batch only.
for batch_idx, first_batch in enumerate(validate_dataloader):
    data, target = first_batch
    print("Batch", batch_idx, ":\n", data.shape, "\n", target.shape)
    print(target)
    print(data)
    break

### DNN pretraining

In [None]:
class MLP4(nn.Module):
    """MLP with a shared three-layer trunk and two softmax classification heads.

    The input passes through dropout (p=0.2) and three ``Linear -> BatchNorm1d ->
    ReLU -> Dropout(0.5)`` stages; the result feeds a tissue head and a type head.

    Args:
        feat_in: number of input features.
        tissue_num: number of outputs of the tissue head.
        type_num: number of outputs of the type head.
        hidden: widths of the three trunk layers (only the first three entries are read).

    NOTE(review): ``forward`` returns softmax probabilities. The training code in this
    notebook passes them to ``nn.CrossEntropyLoss``, which applies its own log-softmax
    and expects unnormalised logits -- confirm that this double normalisation is intended.
    """

    def __init__(self, feat_in = 4861, tissue_num = 13, type_num = 2, hidden = [256, 256, 128]):
        super().__init__()
        width1, width2, width3 = hidden[0], hidden[1], hidden[2]

        def post_linear(width):
            # Normalisation, non-linearity and regularisation applied after a trunk layer.
            return nn.Sequential(nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(0.5))

        # Dropout applied directly to the input features.
        self.dropout = nn.Dropout(0.2)

        # Shared trunk followed by one linear head per task.
        self.fc1 = nn.Linear(feat_in, width1)
        self.fc2 = nn.Linear(width1, width2)
        self.fc3 = nn.Linear(width2, width3)
        self.fc4_tissue = nn.Linear(width3, tissue_num)
        self.fc4_type = nn.Linear(width3, type_num)

        self.actv1 = post_linear(width1)
        self.actv2 = post_linear(width2)
        self.actv3 = post_linear(width3)
        self.actv4 = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(tissue_probs, type_probs)`` for a batch ``x``."""
        trunk = ((self.fc1, self.actv1), (self.fc2, self.actv2), (self.fc3, self.actv3))

        shared = self.dropout(x)
        for linear, post in trunk:
            shared = post(linear(shared))

        tissue_probs = self.actv4(self.fc4_tissue(shared))
        type_probs = self.actv4(self.fc4_type(shared))
        return tissue_probs, type_probs

In [None]:
def test_model_final(model, test_loader, criterion):
    """Evaluate a two-head classifier on ``test_loader`` and print summary metrics.

    Prints the collected labels and predictions, per-class precision / recall /
    F-score of the combined (tissue, type) label, the mean loss per batch and the
    fraction of samples for which BOTH heads are correct.

    Args:
        model: module whose forward pass returns ``(tissue_scores, type_scores)``.
        test_loader: iterable of ``(inputs, labels)`` batches; ``labels`` is indexed
            as ``labels[:, 0]`` (tissue) and ``labels[:, 1]`` (type).
        criterion: callable ``criterion(scores, targets)`` returning a scalar loss;
            it is applied to each head and the two losses are summed.

    Returns:
        The summed loss averaged over the number of batches (a 0-dim tensor when
        ``criterion`` returns tensors).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Follow the model's own device instead of a hard-coded "cuda:0".
    device = next(model.parameters()).device

    with torch.no_grad():
        model.eval()

        running_loss_tissue = 0.0
        correct, total = 0, 0
        pred_batches, true_batches = [], []
        tissue_label_batches, type_label_batches = [], []
        tissue_pred_batches, type_pred_batches = [], []

        for inputs, labels in test_loader:
            inputs = inputs.float().to(device)
            tissue_labels = labels[:, 0].to(device)
            type_labels = labels[:, 1].to(device)

            outputs = model(inputs)
            loss = criterion(outputs[0], tissue_labels) + criterion(outputs[1], type_labels)
            running_loss_tissue += loss

            total += labels.size(0)
            tissue_pred = outputs[0].argmax(dim=1)
            type_pred = outputs[1].argmax(dim=1)
            # A sample counts as correct only when both heads are right.
            both_right = (tissue_pred == tissue_labels) & (type_pred == type_labels)
            correct += both_right.sum().item()

            # Combined class id. tissue * 10 + type is the encoding test_model uses and,
            # unlike the previous tissue * (type + 1), keeps distinct (tissue, type)
            # pairs distinct (old scheme: tissue 0 -> 0 for any type; (2, 0) and (1, 1)
            # both -> 2). Assumes type ids stay below 10.
            pred_batches.append((tissue_pred * 10 + type_pred).cpu())
            true_batches.append((tissue_labels * 10 + type_labels).cpu())

            tissue_label_batches.append(tissue_labels)
            type_label_batches.append(type_labels)
            tissue_pred_batches.append(tissue_pred)
            type_pred_batches.append(type_pred)

        if total == 0:
            raise ValueError("test_loader yielded no samples")

        pred_all = torch.cat(pred_batches, 0).numpy()
        true_all = torch.cat(true_batches, 0).numpy()
        tissue_labels_all = torch.cat(tissue_label_batches, 0)
        type_labels_all = torch.cat(type_label_batches, 0)
        tissue_pred_all = torch.cat(tissue_pred_batches, 0)
        type_pred_all = torch.cat(type_pred_batches, 0)

        # Metrics are reported only for classes that were actually predicted.
        predicted_classes = np.unique(pred_all)
        p_r_f = precision_recall_fscore_support(true_all, pred_all, average=None, labels=predicted_classes)[:3]
        print(tissue_labels_all, type_labels_all)
        print(tissue_pred_all, type_pred_all)
        print(predicted_classes)
        print(p_r_f)

        running_loss_tissue /= len(test_loader)
        print('Validating Loss (tissue):', running_loss_tissue, 'Acc:', correct/total)
        return running_loss_tissue

In [None]:
# Pre-train MLP4 on the tissue and type tasks jointly (the two head losses are summed).
device = torch.device('cuda')
net = MLP4()
net = net.to(device).float()
# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits, while MLP4.forward
# already applies a softmax -- confirm the double normalisation is intended.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(5000):  # loop over the dataset multiple times
    net.train()
    running_loss, correct, total = 0, 0, 0
    for inputs, labels in train_dataloader:
        # Send the batch to the same device as the model rather than a hard-coded "cuda:0".
        inputs = inputs.float().to(device)
        tissue_labels = labels[:, 0].to(device)
        type_labels = labels[:, 1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs[0], tissue_labels) + criterion(outputs[1], type_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += labels.size(0)
        # A sample counts as correct only when both heads are right.
        tissue_hit = outputs[0].argmax(dim=1) == tissue_labels
        type_hit = outputs[1].argmax(dim=1) == type_labels
        correct += (tissue_hit & type_hit).sum().item()

    # One summary line per epoch: mean loss per batch and joint accuracy.
    if total:
        n_batches = len(train_dataloader)
        print('[%d, %5d] tissue loss: %.3f' %
              (epoch + 1, n_batches, running_loss / n_batches), 'Acc:', correct/total)

    if epoch % 20 == 0:
        test_model_final(net, validate_dataloader, nn.CrossEntropyLoss())
    if epoch % 100 == 0:
        # Checkpoint the whole module (not only its state_dict) to the mounted Drive.
        print("Save model...")
        torch.save(net, "./hgg/MyDrive/model_3.pt")

print('Finished Training')

### Support / query data

In [None]:
# Split the held-out test samples evenly into a support half and a query half.
# NOTE(review): `test_x` / `test_y` are not defined in the part of the notebook visible
# here -- confirm they are created by an earlier cell before this one runs.
#test_y = [i[1] for i in test_y]
support_x, query_x, support_y, query_y = train_test_split(
    test_x,
    test_y,
    test_size=0.5,
    random_state=6,
)

In [None]:
# Wrap the query half and serve it in fixed-order batches of 100.
query_data = Dataset(query_x, query_y)
query_dataloader = torch.utils.data.DataLoader(query_data, shuffle=False, batch_size=100)

# Sanity check: print shapes, labels and raw features of the first batch only.
for batch_idx, first_batch in enumerate(query_dataloader):
    data, target = first_batch
    print("Batch", batch_idx, ":\n", data.shape, "\n", target.shape)
    print(target)
    print(data)
    break

In [None]:
# Wrap the support half and serve it in fixed-order batches of 100.
support_data = Dataset(support_x, support_y)
support_dataloader = torch.utils.data.DataLoader(support_data, shuffle=False, batch_size=100)

# Sanity check: print shapes, labels and raw features of the first batch only.
for batch_idx, first_batch in enumerate(support_dataloader):
    data, target = first_batch
    print("Batch", batch_idx, ":\n", data.shape, "\n", target.shape)
    print(target)
    print(data)
    break

### Few-shot learning

In [None]:
# Load the pre-trained network and keep its weights for the few-shot model.
# The pre-training loop writes its checkpoint to "model_3.pt"; the previous
# "model3.pt" here matched no file that this notebook produces.
# NOTE(review): torch.load un-pickles a whole module -- only load trusted checkpoints.
net = torch.load("./hgg/MyDrive/model_3.pt")
state = net.state_dict()

In [None]:
class MLP5(nn.Module):
    """Same layers as MLP4, plus a ``features`` method exposing the trunk embedding.

    The attribute names match MLP4's, so an MLP4 ``state_dict`` can be loaded into an
    MLP5 instance.

    Args:
        feat_in: number of input features.
        tissue_num: number of outputs of the tissue head.
        type_num: number of outputs of the type head.
        hidden: widths of the three trunk layers (only the first three entries are read).
    """

    def __init__(self, feat_in = 4861, tissue_num = 13, type_num = 2, hidden = [256, 256, 128]):
        super().__init__()
        width1, width2, width3 = hidden[0], hidden[1], hidden[2]

        def post_linear(width):
            # Normalisation, non-linearity and regularisation applied after a trunk layer.
            return nn.Sequential(nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(0.5))

        # Dropout applied directly to the input features.
        self.dropout = nn.Dropout(0.2)

        # Shared trunk followed by one linear head per task.
        self.fc1 = nn.Linear(feat_in, width1)
        self.fc2 = nn.Linear(width1, width2)
        self.fc3 = nn.Linear(width2, width3)
        self.fc4_tissue = nn.Linear(width3, tissue_num)
        self.fc4_type = nn.Linear(width3, type_num)

        self.actv1 = post_linear(width1)
        self.actv2 = post_linear(width2)
        self.actv3 = post_linear(width3)
        self.actv4 = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(tissue_probs, type_probs)``: softmax heads on top of ``features(x)``."""
        embedding = self.features(x)
        tissue_probs = self.actv4(self.fc4_tissue(embedding))
        type_probs = self.actv4(self.fc4_type(embedding))
        return tissue_probs, type_probs

    def features(self, x):
        """Return the output of the third trunk stage (width ``hidden[2]``) for ``x``."""
        trunk = ((self.fc1, self.actv1), (self.fc2, self.actv2), (self.fc3, self.actv3))

        embedding = self.dropout(x)
        for linear, post in trunk:
            embedding = post(linear(embedding))
        return embedding

In [None]:
def test_model(model, test_loader, criterion, z_proto_tissue, z_proto_type, best_acc):
    """Evaluate nearest-prototype classification and checkpoint ``model`` on improvement.

    Each query embedding from ``model.features`` is scored against every prototype by
    negative ``torch.cdist`` distance. The predicted class is the ROW INDEX of the
    closest prototype and is compared with the raw label id, so row ``k`` of each
    prototype matrix must be the prototype of class ``k``.

    Args:
        model: module exposing ``features(inputs)``.
        test_loader: iterable of ``(inputs, labels)`` batches; ``labels`` is indexed
            as ``labels[:, 0]`` (tissue) and ``labels[:, 1]`` (type).
        criterion: callable ``criterion(scores, targets)`` returning a scalar loss tensor.
        z_proto_tissue: tissue prototypes, one row per tissue class.
        z_proto_type: type prototypes, one row per type class.
        best_acc: best joint accuracy seen so far.

    Returns:
        ``max(best_acc, accuracy)`` where accuracy is the fraction of samples for
        which both the tissue and the type prediction are correct.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Follow the model's own device rather than relying on a global `device`.
    device = next(model.parameters()).device

    with torch.no_grad():
        model.eval()

        running_loss_tissue = 0.0
        correct, total = 0, 0
        pred_batches, true_batches = [], []
        tissue_label_batches, type_label_batches = [], []
        tissue_pred_batches, type_pred_batches = [], []

        for inputs, labels in test_loader:
            tissue_labels = labels[:, 0].to(device)
            type_labels = labels[:, 1].to(device)
            inputs = inputs.float().to(device)

            z_query = model.features(inputs)

            # Closer prototype => higher score.
            tissue_scores = -torch.cdist(z_query, z_proto_tissue)
            type_scores = -torch.cdist(z_query, z_proto_type)

            loss = criterion(tissue_scores, tissue_labels) + criterion(type_scores, type_labels)
            running_loss_tissue += loss.item()

            total += labels.size(0)
            tissue_pred = tissue_scores.argmax(dim=1)
            type_pred = type_scores.argmax(dim=1)
            # A sample counts as correct only when both predictions are right.
            both_right = (tissue_pred == tissue_labels) & (type_pred == type_labels)
            correct += both_right.sum().item()

            # Combined class id for the per-class report (assumes type ids below 10).
            pred_batches.append((tissue_pred * 10 + type_pred).cpu())
            true_batches.append((tissue_labels * 10 + type_labels).cpu())

            tissue_label_batches.append(tissue_labels)
            type_label_batches.append(type_labels)
            tissue_pred_batches.append(tissue_pred)
            type_pred_batches.append(type_pred)

        if total == 0:
            raise ValueError("test_loader yielded no samples")

        pred_all = torch.cat(pred_batches, 0).numpy()
        true_all = torch.cat(true_batches, 0).numpy()
        tissue_labels_all = torch.cat(tissue_label_batches, 0)
        type_labels_all = torch.cat(type_label_batches, 0)
        tissue_pred_all = torch.cat(tissue_pred_batches, 0)
        type_pred_all = torch.cat(type_pred_batches, 0)

        # Metrics are reported only for classes that were actually predicted.
        predicted_classes = np.unique(pred_all)
        p_r_f = precision_recall_fscore_support(true_all, pred_all, average=None, labels=predicted_classes)[:3]
        print(predicted_classes)
        print(np.array(p_r_f))

        running_loss_tissue /= len(test_loader)
        accuracy = correct/total
        print('-------------------------Validating Loss:', running_loss_tissue, 'Acc:', accuracy)

        if accuracy > best_acc:
            best_acc = accuracy
            print("---------------------------------------------------------------Save model...")
            # Save the model that was evaluated; the previous code saved the global
            # `few_shot_net` regardless of which model was passed in.
            torch.save(model, "./hgg/MyDrive/model_fewshot.pt")
            print(tissue_labels_all, type_labels_all)
            print(tissue_pred_all, type_pred_all)

        return best_acc

In [None]:
# Fine-tune the pre-trained trunk with a prototype-based (nearest class mean) objective.
device = torch.device('cuda')
few_shot_net = MLP5()
few_shot_net.load_state_dict(state)
few_shot_net = few_shot_net.to(device).float()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(few_shot_net.parameters(), lr=0.00001)

best_acc = 0
support_dataloader, query_dataloader = train_dataloader, validate_dataloader


def episode_prototypes(embeddings, labels):
    """Return ``(prototypes, targets)`` for the classes present in one batch.

    ``prototypes[k]`` is the mean embedding of the k-th smallest label value in
    ``labels`` and ``targets[i]`` is the row of ``prototypes`` that sample ``i``
    belongs to, so ``targets`` always indexes valid score columns.
    """
    classes, targets = torch.unique(labels, return_inverse=True)
    prototypes = torch.stack([embeddings[labels == c].mean(0) for c in classes])
    return prototypes, targets


def support_set_prototypes(model, loader, class_counts):
    """Return one prototype matrix per label column, computed over all of ``loader``.

    Row ``k`` of the i-th matrix is the mean eval-mode embedding of the samples whose
    ``labels[:, i]`` equals ``k``, for ``k`` in ``range(class_counts[i])``. Leaves
    ``model`` in eval mode.

    Raises:
        ValueError: if ``loader`` is empty or a class has no sample in it.
    """
    model.eval()
    embeddings, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            embeddings.append(model.features(inputs.float().to(device)))
            all_labels.append(labels.to(device))
    if not embeddings:
        raise ValueError("support loader yielded no samples")
    embeddings = torch.cat(embeddings)
    all_labels = torch.cat(all_labels)

    matrices = []
    for column, n_classes in enumerate(class_counts):
        rows = []
        for c in range(n_classes):
            members = embeddings[all_labels[:, column] == c]
            if members.size(0) == 0:
                raise ValueError("no support sample for class %d of label column %d" % (c, column))
            rows.append(members.mean(0))
        matrices.append(torch.stack(rows))
    return matrices


# Number of classes per task, read from the heads of the loaded network.
n_tissue_classes = few_shot_net.fc4_tissue.out_features
n_type_classes = few_shot_net.fc4_type.out_features

for epoch in range(200):  # loop over the dataset multiple times

    few_shot_net.train()
    running_loss, correct, total = 0, 0, 0
    for inputs, labels in support_dataloader:
        tissue_labels = labels[:, 0].to(device)
        type_labels = labels[:, 1].to(device)
        inputs = inputs.float().to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        z_support = few_shot_net.features(inputs)

        # Scores have one column per class PRESENT in this batch, so the targets must
        # be the within-batch class indices. Using the raw class ids (as before) gives
        # out-of-range targets whenever a batch does not contain every class.
        proto_tissue, tissue_targets = episode_prototypes(z_support, tissue_labels)
        proto_type, type_targets = episode_prototypes(z_support, type_labels)

        # NOTE(review): the scores are soft-maxed and then passed to CrossEntropyLoss,
        # which applies log-softmax again -- confirm this is intended.
        tissue_scores = (-torch.cdist(z_support, proto_tissue)).softmax(dim=1)
        type_scores = (-torch.cdist(z_support, proto_type)).softmax(dim=1)

        loss = criterion(tissue_scores, tissue_targets) + criterion(type_scores, type_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += tissue_labels.size(0)
        # A sample counts as correct only when both predictions are right.
        tissue_hit = tissue_scores.argmax(dim=1) == tissue_targets
        type_hit = type_scores.argmax(dim=1) == type_targets
        correct += (tissue_hit & type_hit).sum().item()

    # One summary line per epoch: mean loss per batch and joint accuracy.
    if total:
        n_batches = len(support_dataloader)
        print('[%d, %5d] training loss: %.3f'  %
              (epoch + 1, n_batches, running_loss / n_batches), 'Acc:', correct/total)

    if epoch % 3 == 0:
        # test_model compares the index of the closest prototype with the raw class id,
        # so it needs one prototype per class, ordered by class id. The prototypes of
        # the last training batch cover only the classes that happened to be in it.
        z_proto_tissue, z_proto_type = support_set_prototypes(
            few_shot_net, support_dataloader, (n_tissue_classes, n_type_classes))
        best_acc = test_model(few_shot_net, query_dataloader, nn.CrossEntropyLoss(), z_proto_tissue, z_proto_type, best_acc)

print('Finished Training')