## import dependencies

In [None]:
import os, sys, time
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
seed = 42
torch.manual_seed(seed)

from torchvision import transforms
import torch.utils.data as data_utils
import PIL

from torchvision.models import resnet50
# from sklearn.model_selection import KFold

## define parameters

In [None]:
## define parameters
# Tiles per slide, batch size and epoch budget for this run.
num_of_tiles = 2000
BATCH_N = 1
EPOCH_N = 500

# Tile images, slide-level metadata, and the directory that receives
# checkpoints and loss logs (its name encodes batch size and tile count).
IMAGE_DATAPATH = "/gstore/scratch/u/jea/HKoeppen_POPLAR_MIL_tiles/"
META_DATA_CSV = "/gstore/home/lix233/meta_POP.csv"
MODEL_FILEPATH = f"/gstore/home/lix233/miltest_bs={BATCH_N}_numtiles={num_of_tiles}_equalization/"

In [None]:
# Create the output directory, including any missing parents. exist_ok=True
# replaces the separate os.path.exists() check, which could race with another
# process creating the same directory.
os.makedirs(MODEL_FILEPATH, exist_ok=True)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

## read meta file

In [None]:
meta_df = pd.read_csv(META_DATA_CSV)

### gather the names of all tiles under each image folder

In [None]:
# Flatten the per-image folders into one list of tile file names.
all_tiles = [
    tile_name
    for folder in os.listdir(IMAGE_DATAPATH)
    for tile_name in os.listdir(IMAGE_DATAPATH + folder)
]

# The image id is the first two underscore-separated fields of a tile name.
all_tiles_imgid = [
    fields[0] + '_' + fields[1]
    for fields in (tile_name.split('_') for tile_name in all_tiles)
]

meta_img = pd.DataFrame({"img_id": all_tiles_imgid, "tile_ID": all_tiles})

### join the two data frames together

In [None]:
# Left-join the tile names onto the slide metadata through the shared img_id.
main_df = pd.merge(meta_df, meta_img, how='left', on='img_id')
print(main_df.shape)

In [None]:
# Keep only the metadata rows that were matched to a tile by the left join.
main_df = main_df.loc[main_df['tile_ID'].notna()]
print(main_df.shape)

In [None]:
# One-hot encode the immunophenotype label and attach the indicator columns;
# both sides get a fresh 0..n-1 index so the column-wise concat lines up.
phenotype_dummies = pd.get_dummies(main_df['immunophenotype']).reset_index(drop=True)
main_df = pd.concat([main_df.reset_index(drop=True), phenotype_dummies], axis=1)

In [None]:
# Flag the rows whose DATATYPE is "train"; the per-split frames start empty.
main_df = main_df.assign(ISTRAIN=main_df['DATATYPE'].eq("train"))
train_df, test_df = pd.DataFrame(), pd.DataFrame()

In [None]:
# Loop through each patient (img_id) and select the tiles to use.
#
# The selected pieces are collected in lists and concatenated once at the end:
# DataFrame.append was removed in pandas 2.0 and re-copied the whole frame on
# every call. Rows already present in train_df / test_df are kept, which is
# what repeated appends did.
train_parts = [] if train_df.empty else [train_df]
test_parts = [] if test_df.empty else [test_df]

uniq_ids = np.unique(main_df.img_id)
for img_id in uniq_ids:
    tmp_df = main_df[main_df.img_id == img_id].reset_index(drop=True)
    n_avail = tmp_df.shape[0]

    # The split is read from the patient's first row.
    parts = train_parts if tmp_df.ISTRAIN[0] else test_parts

    if n_avail < num_of_tiles:
        # Too few tiles: cycle through them until num_of_tiles are selected.
        sel_indx = np.arange(num_of_tiles) % n_avail
        parts.append(tmp_df.loc[sel_indx])

    elif n_avail > 10000:
        # Many tiles: take four consecutive slices of one random permutation.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)

        for grp in range(4):
            # .copy() so the tile_ID edit does not write into a slice of tmp_df.
            app_df = tmp_df.loc[rnd_indx[grp * num_of_tiles:(grp + 1) * num_of_tiles]].copy()
            # NOTE(review): the suffixed tile_ID is later used by TileBags to
            # build the file path, and all four groups keep the same img_id so
            # TileBags puts them into one bag -- confirm this is intended.
            app_df['tile_ID'] = app_df['tile_ID'] + "_" + str(grp)
            parts.append(app_df)

    else:
        # Enough tiles: draw num_of_tiles at random without replacement.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)[0:num_of_tiles]
        parts.append(tmp_df.loc[rnd_indx])

if train_parts:
    train_df = pd.concat(train_parts)
if test_parts:
    test_df = pd.concat(test_parts)


In [None]:
# Give both splits a clean 0..n-1 index and expose them by phase name.
train_df, test_df = (frame.reset_index(drop=True) for frame in (train_df, test_df))

meta_dict = dict(train=train_df, test=test_df)

In [None]:
# Report, per split, the row count divided by num_of_tiles and the number of
# distinct img_ids.
print(f"In the training set there are {train_df.shape[0] / num_of_tiles} observations")
print(f"From {len(np.unique(train_df.img_id))} unique patients")
print(f"In the test set there are {test_df.shape[0] / num_of_tiles} observations")
print(f"From {len(np.unique(test_df.img_id))} unique patients")

In [None]:
# Per-phase preprocessing. Both phases resize to 224 and convert to a tensor;
# only training adds random horizontal/vertical flips. Histogram equalization,
# colour jitter and ImageNet normalization are not applied.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
    ]),
    test=transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
    ]),
)

In [None]:
class TileBags(data_utils.Dataset):
    """Dataset that yields one bag of tiles per unique ``img_id``.

    Args:
        img_path: Root directory. A tile is read from
            ``img_path + <folder> + '/' + <tile_ID>``, where ``<folder>`` is
            the first two underscore-separated fields of the tile name.
        meta_df: Frame with ``img_id``, ``tile_ID`` and the label columns
            ``Desert``, ``Excluded`` and ``Inflamed``.
        num_tiles: Stored on the instance; not read by any method here.
        transforms: Optional callable applied to each PIL image. It has to
            return tensors, otherwise the tiles cannot be stacked.
    """

    def __init__(self, img_path, meta_df, num_tiles, transforms=None):
        self.img_path = img_path
        self.meta_df = meta_df
        self.num_tiles = num_tiles
        # Sorted unique ids; position in this list is the dataset index.
        self.id_list = list(np.unique(self.meta_df['img_id']))
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(tiles, label)`` for the ``idx``-th unique ``img_id``.

        ``tiles`` stacks every transformed tile of the bag along a new first
        dimension; ``label`` holds the three label columns of the bag's first
        row.
        """
        # Build the row mask once and reuse it for the tiles and the label.
        in_bag = self.meta_df['img_id'].isin([self.id_list[idx]])
        image_subset = self.meta_df.loc[in_bag, 'tile_ID'].to_list()

        image_tiles = []
        for tile in image_subset:
            fields = tile.split('_')
            tile_folder = fields[0] + '_' + fields[1]

            # The context manager closes the file handle deterministically.
            with PIL.Image.open(self.img_path + tile_folder + '/' + tile) as image:
                if self.transforms is not None:
                    image = self.transforms(image)   ### (H,W,C) -> (C,H,W)

            image_tiles.append(image)

        # (num_tiles, channel, height, width)
        image_tiles = torch.stack(image_tiles, dim=0)

        tab_label = self.meta_df.loc[in_bag, ['Desert', 'Excluded', 'Inflamed']].iloc[0]

        # Convert through NumPy: torch.tensor() on a label-indexed Series falls
        # back to integer-position lookups, which newer pandas deprecates.
        return image_tiles, torch.tensor(tab_label.to_numpy())

    def __len__(self):
        """Return the number of bags (unique ``img_id`` values)."""
        return len(self.id_list)

In [None]:
### Create training and test datasets
image_datasets = {
    phase: TileBags(IMAGE_DATAPATH, meta_dict[phase], num_of_tiles, transforms=data_transforms[phase])
    for phase in ('train', 'test')
}

### Create training and test dataloaders (both are shuffled)
dataloaders_dict = {
    phase: data_utils.DataLoader(dataset, batch_size=BATCH_N, shuffle=True, num_workers=0)
    for phase, dataset in image_datasets.items()
}

In [None]:
import matplotlib.pyplot as plt

# Draw one training batch and show tile 333 of its first bag; the tile is
# moved from (C, H, W) to (H, W, C) for imshow.
one_batch = next(iter(dataloaders_dict['train']))
tile_chw = one_batch[0][0, 333]
plt.imshow(np.transpose(tile_chw, (1, 2, 0)))

In [None]:
# Same statements as the previous cell: draws a fresh batch from the shuffled
# training loader and shows tile 333 of its first bag as an (H, W, C) image.
one_batch = next(iter(dataloaders_dict['train']))
import matplotlib.pyplot as plt
plt.imshow(np.transpose(one_batch[0][0,333,:], (1,2,0)))

In [None]:
# Define MIL model
class MultiResNet(nn.Module):
    """Attention-pooled multiple-instance classifier on frozen ResNet-50 features.

    Each tile is embedded by a pretrained ResNet-50 whose final ``fc`` layer is
    replaced with ``nn.Identity`` (2048 features per tile). A small attention
    network scores every tile, the scores are softmax-normalized across the
    tiles of a bag, and the attention-weighted sum of the tile features is
    mapped to three class probabilities.

    Args:
        num_tiles: Stored on the instance; ``forward`` takes the bag size from
            the input shape instead.
        tab_dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_tiles = num_of_tiles, tab_dropout = 0.0):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.L = 2048   # width of the ResNet-50 tile embedding
        self.D = 128    # hidden width of the attention network
        self.num_tiles = num_tiles

        # Pretrained backbone with all of its parameters frozen.
        self.img_model = resnet50(weights = 'ResNet50_Weights.DEFAULT')
        for param in self.img_model.parameters():
            param.requires_grad = False

        # Drop the ImageNet classifier so the backbone outputs raw features.
        self.img_model.fc = nn.Identity()

        # Attention score: one scalar per tile.
        self.img_attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, 1)
        )

        # Classifier head. dim=1 is the class axis of the (batch, 3) input and
        # is what the implicit-dim Softmax resolved to for 2-D input.
        # NOTE(review): this file trains the model with nn.CrossEntropyLoss,
        # which applies log-softmax itself; feeding it probabilities looks
        # unintended. Left as is because existing checkpoints were trained
        # this way -- confirm before changing.
        self.img_cls = nn.Sequential(
            nn.Linear(self.L, 3),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Classify bags of tiles.

        Args:
            x: Tensor of shape ``(batch_size, num_tiles, C, H, W)``.

        Returns:
            Tensor of shape ``(batch_size, 3)`` with softmax outputs.
        """
        sh = x.shape
        # Fold the tile axis into the batch axis so the CNN sees plain images.
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])

        # Keep the frozen backbone in eval mode even while the wrapper is in
        # train mode, and skip autograd for it.
        self.img_model.eval()
        with torch.no_grad():
            img_feats_fc = self.img_model(x)

        ## (batch_size, num_tiles, 2048)
        img_feats_fc = img_feats_fc.reshape(sh[0], sh[1], -1)

        ## (batch_size, num_tiles, 1), normalized over the tiles of each bag
        Atten = self.img_attention(img_feats_fc)
        Atten = F.softmax(Atten, dim=1)

        ## (batch_size, 2048): attention-weighted sum of tile features
        MM = torch.sum(img_feats_fc * Atten, dim=1)

        ## (batch_size, num_cls)
        img_prob = self.img_cls(MM)

        return img_prob


### early stopping

In [None]:
class EarlyStopping:
    """Early stops the training if test loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path=None, trace_func=print,
                best_score=None, test_loss_min = np.inf, curr_epoch=0):
        """
        Args:
            patience (int): How long to wait after last time test loss improved.
                            Default: 7
            verbose (bool): If True, prints a message each time the counter increases.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Directory the checkpoints are saved to.
                            Default: Current working directory
            trace_func (function): trace print function.
                            Default: print
            best_score (float): Best score (negated loss) seen so far, for resuming.
                            Default: None
            test_loss_min (float): Lowest test loss seen so far, for resuming.
                            Default: infinity (``np.inf``; the ``np.Inf`` alias
                            was removed in NumPy 2.0)
            curr_epoch (int): Epoch counter to start from.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.epoch = curr_epoch
        self.best_score = best_score
        self.early_stop = False
        self.test_loss_min = test_loss_min
        self.delta = delta
        self.path = path
        self.fname = ''
        self.trace_func = trace_func

    def __call__(self, test_loss, model, optimizer, scheduler):
        """Record one epoch's test loss; checkpoint on improvement.

        Sets ``early_stop`` once ``patience`` consecutive calls brought no
        improvement of at least ``delta``.
        """
        score = -test_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(test_loss, model, optimizer, scheduler)
        elif score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(test_loss, model, optimizer, scheduler)
            self.counter = 0
        else:
            # Also reached for a NaN loss: every comparison with NaN is False,
            # so testing for improvement (rather than for "worse") keeps a NaN
            # epoch from being stored as the new best.
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        self.epoch += 1

    def save_checkpoint(self, test_loss, model, optimizer, scheduler):
        '''Saves model when test loss decreases.'''

        self.trace_func(f'Test loss decreased ({self.test_loss_min:.6f} --> {test_loss:.6f}).  Saving model ...')

        if self.path is None:
            self.path = os.getcwd()

        out_fname = "model_" + '%.6f' % test_loss + "_" + str(self.epoch) + ".ckpt"
        self.fname = os.path.join(self.path, out_fname)

        # update test loss value
        self.test_loss_min = test_loss

        # save the info we need to resume training if needed
        checkpoint_dict = {
            'epoch' : self.epoch,
            'test_loss_min' : self.test_loss_min,
            'state_dict' : model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }

        torch.save(checkpoint_dict, self.fname)

        # save a file pointing to best model
        with open(os.path.join(self.path, 'best_model.txt'), 'w') as f:
            f.write("%s\n" % self.fname)
            

## Define model

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` and return it with the best checkpoint's weights loaded.

    Every epoch runs a ``'train'`` and a ``'test'`` pass over the module-level
    ``dataloaders_dict``. After the test pass the epoch loss drives both
    ``scheduler.step`` and an ``EarlyStopping`` tracker (patience 15) that
    writes checkpoints into ``MODEL_FILEPATH``. Per-epoch loss and accuracy
    are appended to ``train_loss.txt`` / ``test_loss.txt`` in that directory.

    Args:
        model: Network to optimize; called as ``model(imgs)``.
        criterion: Loss called as ``criterion(outputs, outcomes)`` with the
            float label tensor from the loader.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler whose ``step`` accepts the epoch test loss.
        num_epochs: Maximum number of epochs to run.

    Returns:
        ``model`` after loading the ``state_dict`` of the best checkpoint.
    """
    since = time.time()
    early_stopping = EarlyStopping(patience=15, delta=1e-6, path=MODEL_FILEPATH)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and test phase
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            n_samples = 0

            for one_batch in dataloaders_dict[phase]:
                imgs = one_batch[0].to(device)
                outcomes = one_batch[1].float().to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(imgs)
                    _, preds = torch.max(outputs, 1)

                    loss = criterion(outputs, outcomes)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: the loss is weighted by batch size, so the epoch
                # averages below are per sample.
                batch_size = outcomes.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += (torch.sum(preds == torch.max(outcomes, 1).indices)).item()
                n_samples += batch_size

            # Divide by the number of samples seen. Dividing the weighted sum
            # by the number of batches was only correct for a batch size of 1.
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            log_name = 'train_loss.txt' if is_train else 'test_loss.txt'
            with open(os.path.join(MODEL_FILEPATH, log_name), 'a') as f:
                f.write("%s\n" % epoch_loss + "%s\n" % epoch_acc)

            if not is_train:
                # Step LR in test phase
                scheduler.step(epoch_loss)

                # Track Early Stopping in Test Phase
                early_stopping(epoch_loss, model, optimizer, scheduler)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

        if early_stopping.early_stop:
            print("Early stopping")
            break

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # The value is whatever `criterion` computes, so it is not labelled "MSE".
    print('Best test loss: {:6f}'.format(early_stopping.test_loss_min))

    # load best model weights
    best_model_dict = torch.load(early_stopping.fname)
    model.load_state_dict(best_model_dict['state_dict'])

    return model


## Initialize model

In [None]:
# our model uses frozen layers from resnet for feature extraction
feature_extract = True

model_ft = MultiResNet().to(device)

# Walk the named parameters once, listing and collecting the trainable ones.
print("Params to learn:")
trainable_params = []
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        trainable_params.append(param)
        print("\t", name)

# With feature extraction only the trainable subset goes to the optimizer;
# otherwise it receives the model's full parameter iterator.
params_to_update = trainable_params if feature_extract else model_ft.parameters()


In [None]:
# SGD with momentum over params_to_update, i.e. the parameters collected in
# the previous cell rather than every parameter of the model.
optimizer_ft = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Loss passed to train_model, which calls it on the model outputs and the
# float label tensor.
criterion_ft = nn.CrossEntropyLoss()

# Alternative kept for reference: decay LR by a factor of 0.1 every 10 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# Used instead: lower the LR when the monitored value (mode='min') stops
# decreasing; factor and patience are left at the library defaults.
plat_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='min', verbose=True)


# Start training

In [None]:
model_ft = train_model(model_ft, criterion_ft, optimizer_ft, plat_lr_scheduler, num_epochs=EPOCH_N)

# Prediction

In [None]:
import os, sys, time
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
seed = 42
torch.manual_seed(seed)

from torchvision import transforms
import torch.utils.data as data_utils
import PIL

from torchvision.models import resnet50
from sklearn.metrics import confusion_matrix 
# from sklearn.model_selection import KFold

In [None]:
os.listdir('/gstore/home/lix233/miltest_bs=1_numtiles=2000_wo_coloraug_wo_normalization')

In [None]:
## define parameters
BATCH_N = 1
EPOCH_N = 500
num_of_tiles = 2000

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)    


class TileBags(data_utils.Dataset):
    """Dataset that yields one bag of tiles per unique ``img_id``.

    Args:
        img_path: Root directory. A tile is read from
            ``img_path + <folder> + '/' + <tile_ID>``, where ``<folder>`` is
            the first two underscore-separated fields of the tile name.
        meta_df: Frame with ``img_id``, ``tile_ID`` and the label columns
            ``Desert``, ``Excluded`` and ``Inflamed``.
        num_tiles: Stored on the instance; not read by any method here.
        transforms: Optional callable applied to each PIL image. It has to
            return tensors, otherwise the tiles cannot be stacked.
    """

    def __init__(self, img_path, meta_df, num_tiles, transforms=None):
        self.img_path = img_path
        self.meta_df = meta_df
        self.num_tiles = num_tiles
        # Sorted unique ids; position in this list is the dataset index.
        self.id_list = list(np.unique(self.meta_df['img_id']))
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(tiles, label)`` for the ``idx``-th unique ``img_id``.

        ``tiles`` stacks every transformed tile of the bag along a new first
        dimension; ``label`` holds the three label columns of the bag's first
        row.
        """
        # Build the row mask once and reuse it for the tiles and the label.
        in_bag = self.meta_df['img_id'].isin([self.id_list[idx]])
        image_subset = self.meta_df.loc[in_bag, 'tile_ID'].to_list()

        image_tiles = []
        for tile in image_subset:
            fields = tile.split('_')
            tile_folder = fields[0] + '_' + fields[1]

            # The context manager closes the file handle deterministically.
            with PIL.Image.open(self.img_path + tile_folder + '/' + tile) as image:
                if self.transforms is not None:
                    image = self.transforms(image)   ### (H,W,C) -> (C,H,W)

            image_tiles.append(image)

        # (num_tiles, channel, height, width)
        image_tiles = torch.stack(image_tiles, dim=0)

        tab_label = self.meta_df.loc[in_bag, ['Desert', 'Excluded', 'Inflamed']].iloc[0]

        # Convert through NumPy: torch.tensor() on a label-indexed Series falls
        # back to integer-position lookups, which newer pandas deprecates.
        return image_tiles, torch.tensor(tab_label.to_numpy())

    def __len__(self):
        """Return the number of bags (unique ``img_id`` values)."""
        return len(self.id_list)
    
# Define MIL model
class MultiResNet(nn.Module):
    """Attention-pooled multiple-instance classifier on frozen ResNet-50 features.

    Each tile is embedded by a pretrained ResNet-50 whose final ``fc`` layer is
    replaced with ``nn.Identity`` (2048 features per tile). A small attention
    network scores every tile, the scores are softmax-normalized across the
    tiles of a bag, and the attention-weighted sum of the tile features is
    mapped to three class probabilities.

    Args:
        num_tiles: Stored on the instance; ``forward`` takes the bag size from
            the input shape instead.
        tab_dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_tiles = num_of_tiles, tab_dropout = 0.0):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.L = 2048   # width of the ResNet-50 tile embedding
        self.D = 128    # hidden width of the attention network
        self.num_tiles = num_tiles

        # Pretrained backbone with all of its parameters frozen.
        self.img_model = resnet50(weights = 'ResNet50_Weights.DEFAULT')
        for param in self.img_model.parameters():
            param.requires_grad = False

        # Drop the ImageNet classifier so the backbone outputs raw features.
        self.img_model.fc = nn.Identity()

        # Attention score: one scalar per tile.
        self.img_attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, 1)
        )

        # Classifier head. dim=1 is the class axis of the (batch, 3) input and
        # is what the implicit-dim Softmax resolved to for 2-D input.
        self.img_cls = nn.Sequential(
            nn.Linear(self.L, 3),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Classify bags of tiles.

        Args:
            x: Tensor of shape ``(batch_size, num_tiles, C, H, W)``.

        Returns:
            Tensor of shape ``(batch_size, 3)`` with softmax outputs.
        """
        sh = x.shape
        # Fold the tile axis into the batch axis so the CNN sees plain images.
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])

        # Keep the frozen backbone in eval mode and skip autograd for it.
        self.img_model.eval()
        with torch.no_grad():
            img_feats_fc = self.img_model(x)

        ## (batch_size, num_tiles, 2048)
        img_feats_fc = img_feats_fc.reshape(sh[0], sh[1], -1)

        ## (batch_size, num_tiles, 1), normalized over the tiles of each bag
        Atten = self.img_attention(img_feats_fc)
        Atten = F.softmax(Atten, dim=1)

        ## (batch_size, 2048): attention-weighted sum of tile features
        MM = torch.sum(img_feats_fc * Atten, dim=1)

        ## (batch_size, num_cls)
        img_prob = self.img_cls(MM)

        return img_prob


In [None]:
# our model uses frozen layers from resnet for feature extraction
feature_extract = True

model_ft = MultiResNet().to(device)

# Walk the named parameters once, listing and collecting the trainable ones.
print("Params to learn:")
trainable_params = []
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        trainable_params.append(param)
        print("\t", name)

# With feature extraction only the trainable subset goes to the optimizer;
# otherwise it receives the model's full parameter iterator.
params_to_update = trainable_params if feature_extract else model_ft.parameters()

# Optimizer, loss and plateau scheduler built with the same settings as in
# the training section.
optimizer_ft = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)
criterion_ft = nn.CrossEntropyLoss()
plat_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='min', verbose=True)


In [None]:
# our model uses frozen layers from resnet for feature extraction
# This cell constructs model_ft again, replacing the instance created in the
# previous cell; optimizer_ft from that cell still holds the earlier
# instance's parameters. The checkpoint weights are loaded into this new
# instance in a later cell.
feature_extract = True

model_ft = MultiResNet().to(device)

# Default to the full parameter iterator; narrowed to the trainable subset
# below when feature_extract is set.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    # Only list the trainable parameters; params_to_update stays unchanged.
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)


In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can also be loaded on a CPU-only machine.
checkpoint = torch.load(
    '/gstore/home/lix233/miltest_bs=1_numtiles=2000_wo_coloraug_wo_normalization/model_0.897802_31.ckpt',
    map_location=device,
)

In [None]:
model_ft.load_state_dict(checkpoint['state_dict'])

In [None]:
epoch = checkpoint['epoch']
loss = checkpoint['test_loss_min']

In [None]:
print(epoch)
print(loss)

## test on OAK

In [None]:
IMAGE_DATAPATH_test1 = '/gstore/scratch/u/jea/HKoeppen_OAK_MIL_tiles/'

In [None]:
# Flatten the per-image folders into one list of tile file names.
all_tiles_test1 = [
    tile_name
    for folder in os.listdir(IMAGE_DATAPATH_test1)
    for tile_name in os.listdir(IMAGE_DATAPATH_test1 + folder)
]

# The image id is the first two underscore-separated fields of a tile name.
all_tiles_imgid_test1 = [
    fields[0] + '_' + fields[1]
    for fields in (tile_name.split('_') for tile_name in all_tiles_test1)
]

meta_img_test1 = pd.DataFrame({"img_id": all_tiles_imgid_test1, "tile_ID": all_tiles_test1})

In [None]:
### read meta file
meta_df_test1 = pd.read_csv('/gstore/home/lix233/meta_OAK.csv')

In [None]:
### join the two data frames together
main_df_test1 = pd.merge(meta_df_test1, meta_img_test1, how='left', on='img_id')
print(main_df_test1.shape)

# Keep only the metadata rows that were matched to a tile by the left join.
main_df_test1 = main_df_test1.loc[main_df_test1['tile_ID'].notna()]
print(main_df_test1.shape)

# One-hot encode the immunophenotype label and attach the indicator columns.
phenotype_dummies_test1 = pd.get_dummies(main_df_test1['immunophenotype']).reset_index(drop=True)
main_df_test1 = pd.concat([main_df_test1.reset_index(drop=True), phenotype_dummies_test1], axis=1)


In [None]:
test1_df = pd.DataFrame()

In [None]:
## loop through each patient (img_id) and select the tiles to use
#
# The selected pieces are collected in a list and concatenated once at the
# end: DataFrame.append was removed in pandas 2.0 and re-copied the whole
# frame on every call. Rows already present in test1_df are kept.
test1_parts = [] if test1_df.empty else [test1_df]

uniq_ids = np.unique(main_df_test1.img_id)
for img_id in uniq_ids:
    tmp_df = main_df_test1[main_df_test1.img_id == img_id].reset_index(drop=True)
    n_avail = tmp_df.shape[0]

    if n_avail < num_of_tiles:
        # Too few tiles: cycle through them until num_of_tiles are selected.
        sel_indx = np.arange(num_of_tiles) % n_avail
        test1_parts.append(tmp_df.loc[sel_indx])

    elif n_avail > 10000:
        # Many tiles: take four consecutive slices of one random permutation.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)

        for grp in range(4):
            # .copy() so the tile_ID edit does not write into a slice of tmp_df.
            app_df = tmp_df.loc[rnd_indx[grp * num_of_tiles:(grp + 1) * num_of_tiles]].copy()
            # NOTE(review): the suffixed tile_ID is later used by TileBags to
            # build the file path, and all four groups keep the same img_id so
            # TileBags puts them into one bag -- confirm this is intended.
            app_df['tile_ID'] = app_df['tile_ID'] + "_" + str(grp)
            test1_parts.append(app_df)

    else:
        # Enough tiles: draw num_of_tiles at random without replacement.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)[0:num_of_tiles]
        test1_parts.append(tmp_df.loc[rnd_indx])

if test1_parts:
    test1_df = pd.concat(test1_parts)



In [None]:
## reset indices for test1_df
test1_df = test1_df.reset_index(drop=True)

In [None]:
meta_dict = {
    'test': test1_df,
}

In [None]:
# Report the row count divided by num_of_tiles and the number of distinct img_ids.
print(f"In the test set1 there are {test1_df.shape[0] / num_of_tiles} observations")
print(f"From {len(np.unique(test1_df.img_id))} unique patients")

In [None]:
# Inference preprocessing: resize to 224 and convert to a tensor. Histogram
# equalization and ImageNet normalization are not applied.
data_transforms = dict(
    test=transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
    ]),
)


In [None]:
### Create the test dataset
image_datasets = {
    phase: TileBags(IMAGE_DATAPATH_test1, meta_dict[phase], num_of_tiles, transforms=data_transforms[phase])
    for phase in ('test',)
}

### Create the test dataloader (unshuffled, so bags keep the dataset order)
dataloaders_dict = {
    phase: data_utils.DataLoader(dataset, batch_size=BATCH_N, shuffle=False, num_workers=0)
    for phase, dataset in image_datasets.items()
}


In [None]:
model_ft.eval()

In [None]:
# Run the model over every bag and collect its outputs next to the labels.
# Batches are gathered in a list and concatenated once: growing a DataFrame
# inside the loop re-copies all earlier rows on every iteration and starts by
# concatenating onto an empty frame, which newer pandas deprecates.
prob_cols = ['prob_Desert', 'prob_Excluded', 'prob_Inflamed']
label_cols = ['Desert', 'Excluded', 'Inflamed']
batch_frames = []

with torch.no_grad():
    for one_batch in dataloaders_dict['test']:
        imgs = one_batch[0].to(device)
        outcomes = one_batch[1].float().to(device)

        y_pred = model_ft(imgs)

        y_pred_df = pd.DataFrame(y_pred.cpu().numpy(), columns=prob_cols)
        outcomes_df = pd.DataFrame(outcomes.cpu().numpy(), columns=label_cols)
        batch_frames.append(pd.concat([y_pred_df, outcomes_df], axis=1))

if batch_frames:
    pred_test1 = pd.concat(batch_frames, axis=0, ignore_index=True)
else:
    # No bags at all: keep the empty frame with the expected columns.
    pred_test1 = pd.DataFrame(columns=prob_cols + label_cols)


In [None]:
pred_test1.to_csv('MIL_OAK.csv', index = False)

In [None]:
# pd.DataFrame({"ID":uniq_ids})

In [None]:
# Attach the slide ids column-wise and write a second CSV. Rows are paired by
# position: row i lines up with uniq_ids[i] because the loader is unshuffled
# and TileBags orders its bags by np.unique(img_id), like uniq_ids.
pd.concat([pred_test1.reset_index(drop = True), pd.DataFrame({"ID":uniq_ids})], axis=1).to_csv('MIL_OAK_v2.csv', index = False)

## test on IMP130

In [None]:
import os, sys, time
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
seed = 42
torch.manual_seed(seed)

from torchvision import transforms
import torch.utils.data as data_utils
import PIL

from torchvision.models import resnet50
from sklearn.metrics import confusion_matrix 


In [None]:
## define parameters
BATCH_N = 1
EPOCH_N = 500
num_of_tiles = 2000

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)    


class TileBags(data_utils.Dataset):
    """Dataset that yields one bag of tiles per unique ``img_id``.

    Args:
        img_path: Root directory. A tile is read from
            ``img_path + <folder> + '/' + <tile_ID>``, where ``<folder>`` is
            the first underscore-separated field of the tile name.
        meta_df: Frame with ``img_id``, ``tile_ID`` and the label columns
            ``Desert``, ``Excluded`` and ``Inflamed``.
        num_tiles: Stored on the instance; not read by any method here.
        transforms: Optional callable applied to each PIL image. It has to
            return tensors, otherwise the tiles cannot be stacked.
    """

    def __init__(self, img_path, meta_df, num_tiles, transforms=None):
        self.img_path = img_path
        self.meta_df = meta_df
        self.num_tiles = num_tiles
        # Sorted unique ids; position in this list is the dataset index.
        self.id_list = list(np.unique(self.meta_df['img_id']))
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(tiles, label)`` for the ``idx``-th unique ``img_id``.

        ``tiles`` stacks every transformed tile of the bag along a new first
        dimension; ``label`` holds the three label columns of the bag's first
        row.
        """
        # Build the row mask once and reuse it for the tiles and the label.
        in_bag = self.meta_df['img_id'].isin([self.id_list[idx]])
        image_subset = self.meta_df.loc[in_bag, 'tile_ID'].to_list()

        image_tiles = []
        for tile in image_subset:
            tile_folder = tile.split('_')[0]

            # The context manager closes the file handle deterministically.
            with PIL.Image.open(self.img_path + tile_folder + '/' + tile) as image:
                if self.transforms is not None:
                    image = self.transforms(image)   ### (H,W,C) -> (C,H,W)

            image_tiles.append(image)

        # (num_tiles, channel, height, width)
        image_tiles = torch.stack(image_tiles, dim=0)

        tab_label = self.meta_df.loc[in_bag, ['Desert', 'Excluded', 'Inflamed']].iloc[0]

        # Convert through NumPy: torch.tensor() on a label-indexed Series falls
        # back to integer-position lookups, which newer pandas deprecates.
        return image_tiles, torch.tensor(tab_label.to_numpy())

    def __len__(self):
        """Return the number of bags (unique ``img_id`` values)."""
        return len(self.id_list)
    
# Define MIL model
class MultiResNet(nn.Module):
    """Attention-pooled multiple-instance classifier on frozen ResNet-50 features.

    Each tile is embedded by a pretrained ResNet-50 whose final ``fc`` layer is
    replaced with ``nn.Identity`` (2048 features per tile). A small attention
    network scores every tile, the scores are softmax-normalized across the
    tiles of a bag, and the attention-weighted sum of the tile features is
    mapped to three class probabilities.

    Args:
        num_tiles: Stored on the instance; ``forward`` takes the bag size from
            the input shape instead.
        tab_dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_tiles = num_of_tiles, tab_dropout = 0.0):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        self.L = 2048   # width of the ResNet-50 tile embedding
        self.D = 128    # hidden width of the attention network
        self.num_tiles = num_tiles

        # Pretrained backbone with all of its parameters frozen.
        self.img_model = resnet50(weights = 'ResNet50_Weights.DEFAULT')
        for param in self.img_model.parameters():
            param.requires_grad = False

        # Drop the ImageNet classifier so the backbone outputs raw features.
        self.img_model.fc = nn.Identity()

        # Attention score: one scalar per tile.
        self.img_attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, 1)
        )

        # Classifier head. dim=1 is the class axis of the (batch, 3) input and
        # is what the implicit-dim Softmax resolved to for 2-D input.
        self.img_cls = nn.Sequential(
            nn.Linear(self.L, 3),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Classify bags of tiles.

        Args:
            x: Tensor of shape ``(batch_size, num_tiles, C, H, W)``.

        Returns:
            Tensor of shape ``(batch_size, 3)`` with softmax outputs.
        """
        sh = x.shape
        # Fold the tile axis into the batch axis so the CNN sees plain images.
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])

        # Keep the frozen backbone in eval mode and skip autograd for it.
        self.img_model.eval()
        with torch.no_grad():
            img_feats_fc = self.img_model(x)

        ## (batch_size, num_tiles, 2048)
        img_feats_fc = img_feats_fc.reshape(sh[0], sh[1], -1)

        ## (batch_size, num_tiles, 1), normalized over the tiles of each bag
        Atten = self.img_attention(img_feats_fc)
        Atten = F.softmax(Atten, dim=1)

        ## (batch_size, 2048): attention-weighted sum of tile features
        MM = torch.sum(img_feats_fc * Atten, dim=1)

        ## (batch_size, num_cls)
        img_prob = self.img_cls(MM)

        return img_prob



In [None]:
# our model uses frozen layers from resnet for feature extraction
feature_extract = True

model_ft = MultiResNet().to(device)

# Walk the named parameters once, listing and collecting the trainable ones.
print("Params to learn:")
trainable_params = []
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        trainable_params.append(param)
        print("\t", name)

# With feature extraction only the trainable subset goes to the optimizer;
# otherwise it receives the model's full parameter iterator.
params_to_update = trainable_params if feature_extract else model_ft.parameters()

# Optimizer, loss and plateau scheduler built with the same settings as in
# the training section.
optimizer_ft = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)
criterion_ft = nn.CrossEntropyLoss()
plat_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, mode='min', verbose=True)


In [None]:
# our model uses frozen layers from resnet for feature extraction
# This cell constructs model_ft again, replacing the instance created in the
# previous cell; optimizer_ft from that cell still holds the earlier
# instance's parameters. The checkpoint weights are loaded into this new
# instance in a later cell.
feature_extract = True

model_ft = MultiResNet().to(device)

# Default to the full parameter iterator; narrowed to the trainable subset
# below when feature_extract is set.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    # Only list the trainable parameters; params_to_update stays unchanged.
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)


In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can also be loaded on a CPU-only machine.
checkpoint = torch.load(
    '/gstore/home/lix233/miltest_bs=1_numtiles=2000_wo_coloraug_wo_normalization/model_0.897802_31.ckpt',
    map_location=device,
)

In [None]:
model_ft.load_state_dict(checkpoint['state_dict'])

In [None]:
epoch = checkpoint['epoch']
loss = checkpoint['test_loss_min']

In [None]:
print(epoch)
print(loss)

In [None]:
IMAGE_DATAPATH_test2 = '/gstore/scratch/u/jea/HKoeppen_IMP130_MIL_tiles/'

In [None]:
# Flatten the per-image folders into one list of tile file names.
all_tiles_test2 = [
    tile_name
    for folder in os.listdir(IMAGE_DATAPATH_test2)
    for tile_name in os.listdir(IMAGE_DATAPATH_test2 + folder)
]

# For this cohort the image id is the first underscore-separated field only.
all_tiles_imgid_test2 = [tile_name.split('_')[0] for tile_name in all_tiles_test2]

meta_img_test2 = pd.DataFrame({"img_id": all_tiles_imgid_test2, "tile_ID": all_tiles_test2})

In [None]:
meta_df_test2 = pd.read_csv('/gstore/home/lix233/meta_IMP130.csv')

In [None]:
### join the two data frames together
main_df_test2 = pd.merge(meta_df_test2, meta_img_test2, how='left', on='img_id')
print(main_df_test2.shape)

# Keep only the metadata rows that were matched to a tile by the left join.
main_df_test2 = main_df_test2.loc[main_df_test2['tile_ID'].notna()]
print(main_df_test2.shape)

# One-hot encode the immunophenotype label and attach the indicator columns.
phenotype_dummies_test2 = pd.get_dummies(main_df_test2['immunophenotype']).reset_index(drop=True)
main_df_test2 = pd.concat([main_df_test2.reset_index(drop=True), phenotype_dummies_test2], axis=1)


In [None]:
test2_df = pd.DataFrame()

In [None]:
## loop through each patient (img_id) and select the tiles to use
#
# The selected pieces are collected in a list and concatenated once at the
# end: DataFrame.append was removed in pandas 2.0 and re-copied the whole
# frame on every call. Rows already present in test2_df are kept.
test2_parts = [] if test2_df.empty else [test2_df]

uniq_ids = np.unique(main_df_test2.img_id)
for img_id in uniq_ids:
    tmp_df = main_df_test2[main_df_test2.img_id == img_id].reset_index(drop=True)
    n_avail = tmp_df.shape[0]

    if n_avail < num_of_tiles:
        # Too few tiles: cycle through them until num_of_tiles are selected.
        sel_indx = np.arange(num_of_tiles) % n_avail
        test2_parts.append(tmp_df.loc[sel_indx])

    elif n_avail > 10000:
        # Many tiles: take four consecutive slices of one random permutation.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)

        for grp in range(4):
            # .copy() so the tile_ID edit does not write into a slice of tmp_df.
            app_df = tmp_df.loc[rnd_indx[grp * num_of_tiles:(grp + 1) * num_of_tiles]].copy()
            # NOTE(review): the suffixed tile_ID is later used by TileBags to
            # build the file path, and all four groups keep the same img_id so
            # TileBags puts them into one bag -- confirm this is intended.
            app_df['tile_ID'] = app_df['tile_ID'] + "_" + str(grp)
            test2_parts.append(app_df)

    else:
        # Enough tiles: draw num_of_tiles at random without replacement.
        rng = np.random.default_rng(seed)
        rnd_indx = rng.permutation(tmp_df.index)[0:num_of_tiles]
        test2_parts.append(tmp_df.loc[rnd_indx])

if test2_parts:
    test2_df = pd.concat(test2_parts)


In [None]:
## reset indices for test2_df
test2_df = test2_df.reset_index(drop=True)

meta_dict = {
    'test': test2_df,
}

In [None]:
# Report the row count divided by num_of_tiles and the number of distinct img_ids.
print(f"In the test set2 there are {test2_df.shape[0] / num_of_tiles} observations")
print(f"From {len(np.unique(test2_df.img_id))} unique patients")



In [None]:
# Inference preprocessing: resize to 224 and convert to a tensor. Histogram
# equalization and ImageNet normalization are not applied.
data_transforms = dict(
    test=transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
    ]),
)


In [None]:
### Create the test dataset
image_datasets = {
    phase: TileBags(IMAGE_DATAPATH_test2, meta_dict[phase], num_of_tiles, transforms=data_transforms[phase])
    for phase in ('test',)
}

### Create the test dataloader (unshuffled, so bags keep the dataset order)
dataloaders_dict = {
    phase: data_utils.DataLoader(dataset, batch_size=BATCH_N, shuffle=False, num_workers=0)
    for phase, dataset in image_datasets.items()
}


In [None]:
model_ft.eval()

In [None]:
# Run the model over every bag and collect its outputs next to the labels.
# Batches are gathered in a list and concatenated once: growing a DataFrame
# inside the loop re-copies all earlier rows on every iteration and starts by
# concatenating onto an empty frame, which newer pandas deprecates.
prob_cols = ['prob_Desert', 'prob_Excluded', 'prob_Inflamed']
label_cols = ['Desert', 'Excluded', 'Inflamed']
batch_frames = []

with torch.no_grad():
    for one_batch in dataloaders_dict['test']:
        imgs = one_batch[0].to(device)
        outcomes = one_batch[1].float().to(device)

        y_pred = model_ft(imgs)

        y_pred_df = pd.DataFrame(y_pred.cpu().numpy(), columns=prob_cols)
        outcomes_df = pd.DataFrame(outcomes.cpu().numpy(), columns=label_cols)
        batch_frames.append(pd.concat([y_pred_df, outcomes_df], axis=1))

if batch_frames:
    pred_test2 = pd.concat(batch_frames, axis=0, ignore_index=True)
else:
    # No bags at all: keep the empty frame with the expected columns.
    pred_test2 = pd.DataFrame(columns=prob_cols + label_cols)


In [None]:
pred_test2.to_csv('MIL_IMP130.csv', index = False)

In [None]:
# Attach the slide ids column-wise and write a second CSV. Rows are paired by
# position: row i lines up with uniq_ids[i] because the loader is unshuffled
# and TileBags orders its bags by np.unique(img_id), like uniq_ids.
pd.concat([pred_test2.reset_index(drop = True), pd.DataFrame({"ID":uniq_ids})], axis=1).to_csv('MIL_IMP130_v2.csv', index = False)