In [1]:
import os
import pandas as pd
import glob
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch
import gc
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from misl_dataset import train_valid_test_split

Getting the Stage Information

In [5]:
# Every entry of the extracted-features folder is turned into a patient id by
# cutting off its last 8 characters and lower-casing the remainder, so it can
# be compared against the `pid` column of the stage table.
patient_list = os.listdir('C:/Users/RaghulKrish/Desktop/UB/Spring_23/CVIP/Project/extracted_features/')
fixed_patient_list = [entry[:-8].lower() for entry in patient_list]

# Load the stage table and keep only the rows whose pid has extracted features.
total_stages = pd.read_csv('C:/Users/RaghulKrish/Desktop/UB/Spring_23/CVIP/Project/BRCA_stages.csv', delimiter=",")
mask = total_stages['pid'].isin(fixed_patient_list)
total_stages = total_stages.loc[mask]

# Summary of the remaining stage labels (count / unique / most frequent).
total_stages['stage'].describe()

count           119
unique           10
top       stage iia
freq             38
Name: stage, dtype: object

In [6]:
# Report every patient that has extracted features but no row in the stage
# table; the ids are printed in the order they appear in fixed_patient_list.
pids_with_stage = set(total_stages['pid'].values)
for i in fixed_patient_list:
    if i in pids_with_stage:
        continue
    print(i)

tcga-bh-a0b2
tcga-d8-a13z


Reading the Clustered Data file

In [178]:
# Load the clustered feature table: numbered feature columns followed by
# `cluster_num` and `pid` (see the preview in the next cell).
df = pd.read_csv(filepath_or_buffer='C:/Users/RaghulKrish/Desktop/UB/Spring_23/CVIP/Project/clustered_data.csv')

In [179]:
# Preview the first rows of the freshly loaded table.
# Disabled step, kept for reference (would drop the first column):
# df = df.drop(df.columns[0], axis=1)
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,992,993,994,995,996,997,998,999,cluster_num,pid
0,-1.461104,-0.285967,-1.397519,0.229647,0.936981,3.368677,3.793468,-0.550289,1.1074,-2.729292,...,0.022788,0.624529,0.861521,-0.310676,1.878729,-0.246086,-0.164,2.957375,7,tcga-a1-a0sb
1,-0.799152,0.231676,-0.894569,0.052585,0.519808,1.227673,2.117962,-0.776678,-0.071368,-1.076421,...,-0.977934,-1.083503,-0.85344,-0.938821,1.287605,-0.965848,-0.061129,1.716388,3,tcga-a1-a0sb
2,-1.118939,0.161997,-1.180829,-0.175141,0.825578,2.004492,2.922759,-0.73427,0.82844,-1.833263,...,-0.74132,-1.726206,-0.900631,-1.783865,-0.284666,-0.169607,0.134723,3.002424,1,tcga-a1-a0sb
3,-0.516077,0.16615,-1.896911,-0.896396,-0.133664,1.366326,1.961223,-0.625198,0.875285,-1.964036,...,-0.618743,-2.314713,-0.857852,-1.888762,-0.840151,-0.670046,0.701192,2.627288,1,tcga-a1-a0sb
4,-2.132791,-0.133278,-0.793486,0.366926,1.200107,1.700159,2.119426,-1.155944,-0.424008,-1.007438,...,-1.015913,0.863694,-0.787496,0.022276,2.507776,-1.241899,-0.154245,1.791794,3,tcga-a1-a0sb


Preprocessing

In [180]:
# Set the two non-feature columns aside so that only the numeric feature
# columns are standardised; they are attached again in the next cell.
pid = df['pid']
cluster_num = df['cluster_num']
df = df.drop(columns=['pid', 'cluster_num'])

# Standardise every remaining column, then rebuild a DataFrame that keeps the
# original column labels (the scaler returns a bare array).
scaler = StandardScaler()
n_df = scaler.fit(df).transform(df)
df = pd.DataFrame(n_df, columns=df.columns)

In [181]:
# Put the patient id and cluster assignment back next to the scaled features
# (appended after the feature columns, pid first) and preview the result.
df = df.assign(pid=pid, cluster_num=cluster_num)
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,992,993,994,995,996,997,998,999,pid,cluster_num
0,-0.71771,-1.581317,-1.057331,-0.638967,-0.404404,0.659463,0.191061,0.352148,0.904366,-1.410457,...,0.061712,1.267137,0.157217,0.095518,0.922479,-0.158488,-0.994116,-0.224805,tcga-a1-a0sb,7
1,-0.018819,-0.940304,-0.713356,-0.770011,-0.699232,-1.1087,-0.866002,0.053147,-0.388176,0.808984,...,-1.088802,-0.059723,-1.049583,-0.500579,0.513526,-0.85395,-0.892466,-1.296506,tcga-a1-a0sb,3
2,-0.356451,-1.02659,-0.909133,-0.938551,-0.483136,-0.467159,-0.358262,0.109157,0.598482,-0.207288,...,-0.816771,-0.558997,-1.082791,-1.302507,-0.574204,-0.084591,-0.69894,-0.185901,tcga-a1-a0sb,1
3,0.280051,-1.021447,-1.398873,-1.472354,-1.16106,-0.994193,-0.964888,0.253212,0.649848,-0.382888,...,-0.675846,-1.01617,-1.052688,-1.402053,-0.9585,-0.568134,-0.139195,-0.509864,tcga-a1-a0sb,1
4,-1.426878,-1.392238,-0.644224,-0.537366,-0.218445,-0.718494,-0.865079,-0.447765,-0.774852,0.901613,...,-1.132467,1.452929,-1.003179,0.411483,1.357666,-1.120681,-0.984477,-1.231387,tcga-a1-a0sb,3


In [182]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# that every later `.to(device)` call also works on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Release Python garbage first, then hand cached GPU blocks back to the driver.
gc.collect()
if device.type == 'cuda':
    torch.cuda.empty_cache()

In [190]:
num_clusters = 10

# Only two patients (positions 1 and 8 of fixed_patient_list) are selected
# here; looping over every patient is not done in this cell.
selected_pids = [fixed_patient_list[1], fixed_patient_list[8]]
mask = df['pid'].isin(selected_pids)
patient_data = df.loc[mask]

# One DataFrame per cluster id (0 .. num_clusters - 1) holding only the
# feature columns; a cluster without rows yields an empty DataFrame.
patient_cluster = [
    patient_data.loc[patient_data['cluster_num'] == cluster_id].drop(columns=['pid', 'cluster_num'])
    for cluster_id in range(num_clusters)
]

In [191]:
import torch.nn as nn
import torch

class DeepAttnMIL_Surv(nn.Module):
    """
    Deep AttnMISL model: attention pooling over per-cluster embeddings.

    Each cluster (a variable-length set of feature vectors) is embedded with a
    1x1 convolution followed by average pooling, the cluster embeddings are
    combined with a learned attention weighting, and a small MLP maps the
    pooled vector to the output scores.

    :param cluster_num: number of clusters expected in every input list
    :param in_features: length of one feature vector (default 1000)
    :param num_classes: number of output scores (default 10)
    """
    def __init__(self, cluster_num, in_features=1000, num_classes=10):
        super(DeepAttnMIL_Surv, self).__init__()
        # Conv1d with kernel size 1 applies the same linear map to every item
        # of a cluster; the adaptive pooling then averages over the items.
        self.embedding_net = nn.Sequential(nn.Conv1d(in_features, 64, 1),
                                     nn.ReLU(),
                                     nn.AdaptiveAvgPool1d(1)
                                     )

        self.attention = nn.Sequential(
            nn.Linear(64, 32), # V
            nn.Tanh(),
            nn.Linear(32, 1)  # W
        )

        self.fc6 = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(32, num_classes)
        )
        self.cluster_num = cluster_num
        self.in_features = in_features
        self.num_classes = num_classes

    def masked_softmax(self, x, mask=None):
        """
        Performs masked softmax, as simply masking post-softmax can be
        inaccurate
        :param x: [batch_size, num_items]
        :param mask: [batch_size, num_items] (or broadcastable to it); entries
                     equal to 0 receive zero weight. ``None`` means no masking.
        :return: tensor shaped like ``x`` whose rows sum to 1
        """
        if mask is not None:
            mask = mask.float()
            # Masked positions are pushed to a very negative value so that
            # they can never be picked as the row maximum below.
            x_masked = x * mask + (1 - 1 / (mask + 1e-5))
        else:
            x_masked = x
        # Subtract the row maximum before exponentiating for numerical stability.
        x_max = x_masked.max(1)[0]
        x_exp = (x - x_max.unsqueeze(-1)).exp()
        if mask is not None:
            x_exp = x_exp * mask
        return x_exp / x_exp.sum(1).unsqueeze(-1)

    def forward(self, x, mask=None):
        """
        Compute the output scores for one bag of clusters.

        :param x: list with at least ``cluster_num`` tensors, each of shape
                  [num_items_in_cluster, in_features]
        :param mask: per-cluster 0/1 weights forwarded to ``masked_softmax``
        :return: tensor of shape [1, num_classes]
        """
        res = []
        for i in range(self.cluster_num):
            hh = x[i]

            # Conv1d expects [batch, channels, length]. The features must
            # become the channel axis, which needs a transpose: a plain
            # reshape of [n, F] to [1, F, n] would mix values from different
            # items and features.
            hh = hh.transpose(0, 1).unsqueeze(0)

            output = self.embedding_net(hh)  # [1, 64, 1]
            output = output.view(output.size()[0], -1)  # [1, 64]

            res.append(output)

        h = torch.cat(res)  # [cluster_num, 64]

        A = self.attention(h)  # [cluster_num, 1]
        A = torch.transpose(A, 1, 0)  # [1, cluster_num]

        A = self.masked_softmax(A, mask)

        M = torch.mm(A, h)  # attention-weighted sum, [1, 64]

        Y_pred = self.fc6(M)

        return Y_pred

In [192]:
# Smoke test: run one forward pass of an untrained model on the per-cluster
# frames prepared above.
model = DeepAttnMIL_Surv(cluster_num = num_clusters).to(device)
tensor_list = [
    torch.tensor(cluster_df.values, dtype=torch.float32, device=device)
    for cluster_df in patient_cluster
]

# A cluster without rows gets mask 0 and a single all-zero placeholder row so
# the model still receives one tensor per cluster. Iterate over the cluster
# ids, not over the per-row `cluster_num` column.
mask = torch.ones(num_clusters, dtype=torch.float32)
for i in range(num_clusters):
    if len(patient_cluster[i]) == 0:
        mask[i] = 0
        tensor_list[i] = torch.zeros((1, patient_cluster[i].shape[1]), dtype=torch.float32, device=device)
masked_cls = mask.to(device)

lbl_pred = model(tensor_list, masked_cls)  # prediction
print(lbl_pred)

tensor([[ 0.0873,  0.1594, -0.2427, -0.0486, -0.1489, -0.1344,  0.1191, -0.0211,
         -0.1297,  0.0856]], device='cuda:0', grad_fn=<AddmmBackward0>)


In [2]:
# Training hyper-parameters.
epochs = 30
batch_size = 32
lr = 0.001

# Kasra's use
data_path = './clustered_data.csv'
label_path = './BRCA_stages.csv'

# Raghul's use
# data_path = 'C:/Users/RaghulKrish/Desktop/UB/Spring_23/CVIP/Project/clustered_data.csv'
# label_path = 'C:/Users/RaghulKrish/Desktop/UB/Spring_23/CVIP/Project/BRCA_stages.csv'

# The two 0.2 arguments are presumably the validation / test fractions --
# TODO confirm against misl_dataset.train_valid_test_split.
train_data, valid_data, test_data = train_valid_test_split(data_path, label_path, 0.2, 0.2)

# NOTE(review): the recorded run of this cell failed with
# "TypeError: object of type 'CustomDataset' has no len()". A shuffling
# DataLoader needs the dataset length, so CustomDataset (in misl_dataset) has
# to implement __len__ before the training loader can be built.
train_loader = DataLoader(train_data, batch_size, shuffle = True, num_workers = 4, pin_memory = True)
# Evaluation loaders are not shuffled, so validation / test metrics are
# computed over the same batches on every pass.
test_loader = DataLoader(test_data, batch_size, shuffle = False, num_workers = 4, pin_memory = True)
valid_loader = DataLoader(valid_data, batch_size, shuffle = False, num_workers = 4, pin_memory = True)

59506
87


TypeError: object of type 'CustomDataset' has no len()

In [None]:
# Fresh model for training, placed on the selected device and switched to
# training mode.
model = DeepAttnMIL_Surv(cluster_num=num_clusters)
model = model.to(device)
model.train()

# Per-epoch history: training loss, validation loss, validation accuracy.
train_loss, loss_valid, acc_val = [], [], []

# Cross-entropy objective optimised with SGD (momentum and weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.005)

def predict(batch):
    """
    Evaluate the global ``model`` on a single batch.

    :param batch: ``(patient, label, mask)`` as produced by the data loaders
                  (the same layout the training loop unpacks)
    :return: dict with ``'val_loss'`` (loss tensor from ``criterion``) and
             ``'val_acc'`` (accuracy in percent from ``accuracy``)
    """
    # Use the mask that belongs to this batch rather than whatever a previous
    # cell left in the global `mask`.
    patient, label, mask = batch
    patient = patient.to(device)
    label = label.to(device)
    mask = mask.to(device)

    output = model(patient, mask)
    loss = criterion(output, label)

    # Predicted class = index of the largest score in each row.
    y_pred = output.detach().argmax(dim=1).cpu().numpy()

    label = label.detach().cpu()
    # Labels appear to be one-hot / per-class rows (assumption -- confirm
    # against the dataset); already-integer class labels are passed through.
    if label.dim() > 1:
        label = label.argmax(dim=1)

    acc = accuracy(y_pred, label)
    return {'val_loss': loss, 'val_acc': acc}

def accuracy(y_pred, labels):
    """
    Percentage of predictions that equal the reference labels.

    :param y_pred: NumPy array of predicted class indices
    :param labels: CPU tensor of reference class indices (``.numpy()`` is
                   called on it)
    :return: accuracy as a percentage in [0, 100]
    """
    matches = (y_pred == labels.numpy())
    n_correct = np.sum(matches)
    n_total = labels.shape[0]
    return n_correct / n_total * 100

for i in range(epochs):
    # ---- training pass -------------------------------------------------
    # Re-enable dropout; the validation pass below switches it off.
    model.train()
    epoch_loss = []
    for batch in train_loader:
        patient, label, mask = batch
        patient = patient.to(device)
        label = label.to(device)
        # The mask is multiplied with the attention scores inside the model,
        # so it has to live on the same device as the inputs.
        mask = mask.to(device)

        output = model(patient, mask)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float so the autograd graph of every batch is not
        # kept alive for the whole epoch.
        epoch_loss.append(loss.item())

    # Mean loss over the batches of this epoch.
    epoch_loss = float(np.mean(epoch_loss))
    train_loss.append(epoch_loss)

    if i % 10 == 0:
        print(f'Epoch: {i} ---> Epoch Loss: {epoch_loss}')

    # ---- validation pass -----------------------------------------------
    model.eval()
    with torch.no_grad():
        valid_op = [predict(batch) for batch in valid_loader]
        batch_losses = [x['val_loss'] for x in valid_op]
        valid_loss = torch.stack(batch_losses).mean().item()
        batch_accs = [x['val_acc'] for x in valid_op]
        valid_acc = (np.array(batch_accs)).mean()
        acc_val.append(valid_acc)
        loss_valid.append(valid_loss)
        print(f"epoch: {i}, Validation loss: {valid_loss}, Validation Accuracy: {valid_acc}%")