In [1]:
import numpy as np
import pandas as pd
import os
import tifffile
import cv2
from os.path import join, isfile, exists
import torch
from torchview import draw_graph
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from torchcam.utils import overlay_mask
from torchcam.methods import SmoothGradCAMpp
import os
from PIL import Image
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt
import time
import copy
import warnings
import contextlib
import sys

In [2]:
# Default persistence-landscape grid: m landscape functions sampled at n points.
# TDANet's input layer reads these globals (m * n features).
m, n = 40, 40

def strTmpt(t):
    """Format an integer timepoint as a zero-padded folder tag.

    Examples: 1 -> 't001', 5 -> 't005', 70 -> 't070', 210 -> 't210'.

    The previous version only special-cased t == 1 and prefixed a single '0'
    otherwise, so other single-digit timepoints became e.g. 't05'.
    """
    return f't{t:03d}'

def prepl(t,a,b):
    """Evaluate the tent function of the persistence interval (a, b) at t.

    Returns 0 when t is outside the open interval (t <= a or t >= b), t - a on
    the rising half up to and including the midpoint (a + b) / 2, and b - t on
    the falling half.
    """
    outside = t <= a or b <= t
    if outside:
        return 0
    rising = a < t <= (a + b) / 2
    return t - a if rising else b - t

def prepl_list(t, intervals):
    """Evaluate every interval's tent function at t, preserving the input order.

    Each interval is indexable with [0] = start and [1] = end.
    """
    return list(map(lambda interval: prepl(t, interval[0], interval[1]), intervals))

def prepare_Column(m, t, intervals):
    """Return the first m persistence-landscape values at t, largest first.

    The k-th entry is the k-th largest tent-function value (see ``prepl``)
    among the intervals at t. By definition that value is 0 when there are
    fewer than k intervals. The list is therefore zero-padded to length m, so
    every column has the same width regardless of the number of intervals. The
    previous version only truncated, so sparse diagrams produced short columns
    and a landscape matrix narrower than m.
    """
    pl_Vals = sorted(prepl_list(t, intervals), reverse=True)[:m]
    # Pad missing landscape functions with zeros.
    pl_Vals.extend([0] * (m - len(pl_Vals)))
    return pl_Vals

def prepare_Input(filePath):
    """Read a persistence-interval CSV and return its rows as [start, end] pairs.

    The file must contain 'starts' and 'ends' columns; pairs keep file order.
    """
    intervals = pd.read_csv(filePath)
    return [[birth, death] for birth, death in zip(intervals['starts'], intervals['ends'])]

def get_condition(samp,condFile,condName = 'conds',sampName = 'samps'):
    """Look up the treatment condition recorded for one sample.

    Args:
        samp: the sample name to find.
        condFile: .csv file (anything ``pd.read_csv`` accepts) with the sample
            names and conditions stored in separate columns. It is re-read on
            every call.
        condName: name of the column holding the conditions.
        sampName: name of the column holding the sample names.

    Returns:
        The ``condName`` value for the first occurrence of ``samp``.

    Raises:
        ValueError: if ``samp`` does not occur in the ``sampName`` column.
    """
    conditions = pd.read_csv(condFile)
    sample_names = list(conditions[sampName])
    row = sample_names.index(samp)
    return conditions[condName][row]
      
class HomDataset(Dataset):
    """Labeled persistence-landscape dataset for a single timepoint.

    Persistence intervals are read from CSV files (columns 'starts' and 'ends',
    see ``prepare_Input``). The files are stored in folders named by timepoint
    and labeled by sample. Each file becomes a float32 matrix whose rows are
    the n sample points and whose columns are the top landscape values at each
    point (see ``prepare_Column``).

    Args:
        homRoots: [path to 0-dim hom folders, path to 1-dim hom folders]. The
            timepoint tag is appended by plain string concatenation, so each
            prefix should end with a path separator.
        time: the timepoint, either an int (formatted with ``strTmpt``, e.g.
            70 -> 't070') or an already formatted tag such as 't070' or 't210'.
        condFile: .csv file with the treatment condition for each sample
            (read via ``get_condition``).
        dims: 'Dim0' or 'Dim1' for a single homology dimension. Any other value
            uses both, with int(m/2) landscape functions per dimension stacked
            side by side (Dim0 columns first).
        transform: optional callable applied to each matrix in ``__getitem__``.
        m: number of persistence-landscape functions.
        n: number of points at which the landscapes are sampled.
        width: distance between sampled points.
        start: the first sample point.
        labelMap: optional condition -> class-index dict. Defaults to the
            notebook-level ``label_map``. Samples whose condition is not in the
            mapping are skipped.
    """

    def __init__(self,homRoots, time,condFile,dims, transform=None, m = 40, n = 40, width = 1, start=1, labelMap=None):
        self.transform = transform
        self.pls = []
        self.labels = []
        # Looked up at call time, so the global label_map defined later in the
        # notebook cell is visible here.
        mapping = label_map if labelMap is None else labelMap
        # Bug fix: the folder tag used to be built from ``t``, which is not a
        # parameter (NameError, or silently a stale notebook-global ``t``),
        # instead of the ``time`` argument.
        tmpt = time if isinstance(time, str) else strTmpt(time)
        # The sample points are the same for every file, so build them once.
        scalepoints = [start + i * width for i in range(n)]

        if dims in ('Dim0', 'Dim1'):
            root_dir = (homRoots[0] if dims == 'Dim0' else homRoots[1]) + tmpt
            # os.listdir order is arbitrary. Sorting makes the sample order, and
            # therefore the seeded train_test_split downstream, reproducible.
            files = sorted(f for f in os.listdir(root_dir) if isfile(join(root_dir, f)))
            for file in files:
                label = mapping.get(self._condition_for(file, condFile), -1)
                if label == -1:
                    continue  # unknown condition: skip before parsing intervals
                intvls = prepare_Input(join(root_dir, file))
                self.pls.append(self._landscape(m, scalepoints, intvls))
                self.labels.append(label)
        else:
            dir_one = homRoots[1] + tmpt
            dir_zero = homRoots[0] + tmpt
            # Only samples present in both dimensions' folders are used.
            files = sorted(f for f in os.listdir(dir_one)
                           if isfile(join(dir_one, f)) and isfile(join(dir_zero, f)))
            half = int(m / 2)
            for file in files:
                label = mapping.get(self._condition_for(file, condFile), -1)
                if label == -1:
                    continue
                plMat_zero = self._landscape(half, scalepoints, prepare_Input(join(dir_zero, file)))
                plMat_one = self._landscape(half, scalepoints, prepare_Input(join(dir_one, file)))
                self.pls.append(np.hstack((plMat_zero, plMat_one)))
                self.labels.append(label)

    @staticmethod
    def _condition_for(file, condFile):
        """Parse the sample name out of a filename and return its condition."""
        # Two filename layouts exist. If the stem before the first '.' has 6
        # characters, the sample name is file[1:6]. Otherwise a '0' is inserted
        # before file[4] to rebuild the 5-character sample name.
        if len(file[:file.find('.')]) == 6:
            return get_condition(file[1:6], condFile)
        return get_condition(file[1:4] + '0' + file[4], condFile)

    @staticmethod
    def _landscape(k, scalepoints, intvls):
        """Stack the top-k landscape values at each sample point into a float32 matrix (one row per point)."""
        return np.asarray([prepare_Column(k, s, intvls) for s in scalepoints], dtype=np.float32)

    def __len__(self):
        """Number of labeled samples kept."""
        return len(self.pls)

    def __getitem__(self, idx):
        """Return (matrix, label) for sample ``idx``, applying ``transform`` to the matrix if set."""
        pl = self.pls[idx]
        label = self.labels[idx]
        if self.transform is not None:
            pl = self.transform(pl)
        return pl, label

def train_model(model, criterion, optimizer, dataloaders, num_epochs=100, debug=False):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Args:
        model: network to train. Batches are moved to the notebook-level ``device``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataloaders: dict with 'train' and 'val' DataLoaders yielding (inputs, labels).
        num_epochs: epochs 0..num_epochs inclusive are run, i.e. num_epochs + 1
            passes, so the debug log shows both 'Epoch 0/..' and 'Epoch num_epochs/..'.
        debug: if True, print loss/accuracy every 5 epochs.

    Returns:
        (model, loss_values). ``model`` is reloaded with the best-validation
        weights. ``loss_values`` is a flat list of per-epoch average batch
        losses, alternating train, val, train, val, ...
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    loss_values = []
    for epoch in range(num_epochs + 1):
        log_now = debug and epoch % 5 == 0
        if log_now:
            print(f'Epoch {epoch}/{num_epochs}')

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            total = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Build the autograd graph only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item()
                running_corrects += (preds == labels).sum().item()
                total += labels.size(0)

            # Average of the per-batch loss values (each batch weighted equally).
            epoch_loss = running_loss / len(dataloaders[phase])
            loss_values.append(epoch_loss)
            epoch_acc = running_corrects / total
            if log_now:
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # A strict '>' keeps the earliest epoch among ties. If validation
            # accuracy never exceeds 0, the initial weights are restored below.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        if log_now:
            print('-' * 10)

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # '.4f' (was '4f', i.e. minimum width 4 with six decimals) to match the
    # per-epoch log format.
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, loss_values

def get_metrics(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` and return (accuracy, precision, recall, f1).

    The model is put in eval mode and run without gradients. Inputs are moved to
    the notebook-level ``device``. Precision, recall and F1 are macro-averaged
    (unweighted mean over the classes present in the true or predicted labels).
    """
    true_labels = []
    predicted_labels = []
    model.eval()

    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images.to(device))
            predicted = outputs.argmax(dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    # zero_division=0 yields the same value sklearn's default uses for a class
    # that is never predicted/present, but without an UndefinedMetricWarning.
    precision = precision_score(true_labels, predicted_labels, average='macro', zero_division=0)
    recall = recall_score(true_labels, predicted_labels, average='macro', zero_division=0)
    f1 = f1_score(true_labels, predicted_labels, average='macro', zero_division=0)
    accuracy = accuracy_score(true_labels, predicted_labels)
    return accuracy, precision, recall, f1

# Treatment condition -> class index. HomDataset drops samples whose condition
# is not listed here.
label_map = {cond: idx for idx, cond in enumerate(["BMP4", "CHIR", "DS", "DS+CHIR", "WT"])}

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class TDANet(nn.Module):
    """Fully connected classifier over a flattened persistence-landscape matrix.

    Args:
        in_features: length of the flattened input. Defaults to the notebook-level
            ``m * n``, read at construction time as before.
        num_classes: number of output classes. The default of 5 matches ``label_map``.

    ``forward`` returns raw logits. The former trailing ``nn.Softmax()`` was
    removed for two reasons:
    - The model is trained with ``nn.CrossEntropyLoss``, which already applies
      log-softmax, so a softmax first squashes the gradients.
    - Softmax without an explicit ``dim`` is deprecated.
    Argmax predictions are unchanged. Use ``torch.softmax(logits, dim=1)`` when
    probabilities are needed.
    """

    def __init__(self, in_features=None, num_classes=5):
        super().__init__()
        if in_features is None:
            in_features = m * n
        self.flatten = nn.Flatten()
        self.runthru = nn.Sequential(
            nn.Linear(in_features, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes),
        )

    def forward(self, x):
        """Flatten all non-batch dimensions of ``x`` and return (batch, num_classes) logits."""
        x = self.flatten(x)
        return self.runthru(x)
    
# Converts each float32 landscape matrix to a tensor. Per torchvision's ToTensor,
# a 2-D array gains a leading channel axis and float data is not rescaled.
transform = transforms.Compose([transforms.ToTensor()])

def get_dataloaders(homRoots,time, condFile, dims, landFunct = 40, train = True,
                    batch_size=8, test_size=0.2, random_state=42):
    """Build DataLoaders over a HomDataset for one timepoint.

    Args:
        homRoots: [path to 0-dim hom folders, path to 1-dim hom folders].
        time: the timepoint at which we are training/testing (passed to HomDataset).
        condFile: .csv file with the treatment condition for each sample.
        dims: homology dimension to use: 'Dim0', 'Dim1', or anything else for both.
        landFunct: number of persistence-landscape functions (HomDataset's ``m``).
        train: if True, split into train/validation loaders. If False (just
            testing an already existing model), return one unshuffled loader
            over the whole dataset.
        batch_size: batch size of every loader (previously hard-coded 8).
        test_size: fraction of samples held out for validation (previously 0.2).
        random_state: seed for the train/validation split (previously 42).

    Returns:
        {'train': DataLoader, 'val': DataLoader} if ``train``, else a single DataLoader.
    """
    dataset = HomDataset(homRoots, time, condFile, dims, m=landFunct, transform=transform)
    if not train:
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # A fixed random_state keeps the split reproducible. A second call with a
    # different condFile therefore validates on the same samples, provided both
    # datasets contain the same files in the same order.
    train_dataset, test_dataset = train_test_split(dataset, test_size=test_size, random_state=random_state)
    return {
        'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        'val': DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
    }

In [3]:
def train_at_Tmpt(homRoots,t,condFileTr, condFileTe, dims,ldfs,saveModel = 1, modelName = 'MyTDANetModel.pt', num_epochs=50):
    """Train a TDANet at timepoint ``t`` and score it on the validation split.

    Args:
        homRoots: [path to 0-dim hom folders, path to 1-dim hom folders].
        t: the timepoint at which we are training/testing.
        condFileTr: .csv file with the treatment condition for each sample, used
            as training labels.
        condFileTe: .csv file with the treatment condition for each sample, used
            as evaluation labels. The two may differ, e.g. to train on
            randomized labels and test on the actual labels as a control.
        dims: homology dimension to use: 'Dim0', 'Dim1', or anything else for both.
        ldfs: number of persistence-landscape functions.
        saveModel: if 1, the whole trained model object is saved with ``torch.save``.
        modelName: filename of the model if saved.
        num_epochs: forwarded to ``train_model`` (previously hard-coded 50).

    Returns:
        (accuracy, precision, recall, f1, model_trained) on the validation split.
    """
    model = TDANet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # stderr is silenced for the whole run (e.g. library warnings).
    # redirect_stderr restores it and closes the devnull handle even if
    # something raises. The previous manual swap leaked the handle and left
    # sys.stderr pointing at devnull after an exception.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stderr(devnull):
        dataloaders = get_dataloaders(homRoots, t, condFileTr, dims, landFunct=ldfs)
        model_trained, loss_values = train_model(model, criterion, optimizer, dataloaders, num_epochs=num_epochs)
        if condFileTe == condFileTr:
            # Same labels: rebuilding would produce an identical validation loader.
            val_loader = dataloaders['val']
        else:
            # NOTE(review): this split matches the training split only if both
            # condition files keep the same set of samples (HomDataset drops
            # unknown conditions). Otherwise training samples may leak into
            # evaluation.
            val_loader = get_dataloaders(homRoots, t, condFileTe, dims, landFunct=ldfs)['val']
        accuracy, precision, recall, fOne = get_metrics(model_trained, val_loader)
        if saveModel == 1:
            torch.save(model_trained, modelName)
    return accuracy, precision, recall, fOne, model_trained