In [1]:
import numpy as np
import pandas as pd
import os
import tifffile
import cv2
from os.path import join, isfile, exists
import torch
from torchview import draw_graph
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from torchcam.utils import overlay_mask
from torchcam.methods import SmoothGradCAMpp
from torchcam.methods import CAM
import os
from PIL import Image
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt
import time
import copy
import warnings
import contextlib
import sys

In [None]:
# Root directory of the converted image files (presumably the per-sample
# folders that ImageDataset reads) — machine-specific; edit for your setup.
imagePath = "C:/Users/aruys/Dropbox (GaTech)/CZIConverted"

def strTmpt(t):
    """Return the filename tag for timepoint ``t``: 't' plus ``t`` zero-padded to 3 digits.

    Examples: 1 -> 't001', 5 -> 't005', 42 -> 't042', 123 -> 't123'.
    Values with more than three digits are not truncated (1234 -> 't1234').
    """
    return f"t{int(t):03d}"

def get_condition(samp, condFile, condName='conds', sampName='samps'):
    """Return the treatment condition recorded for one sample.

    Args:
        samp: sample name (e.g. a sample folder name). Compared as text, so
            ``'22001'`` and ``22001`` are equivalent.
        condFile: .csv file (path or file-like object) with the sample names
            and their conditions stored in separate columns.
        condName: name of the column holding the conditions.
        sampName: name of the column holding the sample names.

    Returns:
        The value of ``condName`` on the first row whose ``sampName`` equals ``samp``.

    Raises:
        ValueError: if no row matches ``samp``.
    """
    # Read the sample column as text: otherwise pandas parses numeric-looking
    # names (e.g. 22001) as integers and they never equal the folder-name strings.
    df = pd.read_csv(condFile, dtype={sampName: str})
    matches = (df[sampName] == str(samp)).to_numpy().nonzero()[0]
    if len(matches) == 0:
        raise ValueError(f"sample {samp!r} not found in column {sampName!r} of {condFile!r}")
    return df[condName].iloc[matches[0]]

class ImageDataset(Dataset):
    """In-memory dataset of labeled sample images, all from a single timepoint.

    ``root_dir`` is expected to hold one sub-folder per sample. Only folders whose
    names are 5 characters long and start with '22' or '26' are used. Inside each
    folder, files whose names contain ``_t<N>_c002`` (channel 2) with
    ``N == timestamp`` are loaded:

    * ``.png``: every matching file is loaded and converted to RGB.
    * ``.tif``: the first readable matching file is min-max rescaled to 8 bits,
      color-mapped with ``cv2.COLORMAP_DEEPGREEN`` and converted to an RGB PIL
      image; no further files in that folder are considered.

    Every image in a folder gets the integer class of that folder's treatment
    condition (looked up with ``get_condition``); samples whose condition is not
    in the label mapping are skipped.

    Args:
        root_dir: directory of the per-sample image folders.
        condFile: .csv file listing the treatment condition of each sample.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        timestamp: timepoint whose images are loaded.
        label_mapping: dict from condition name to integer class; defaults to the
            module-level ``label_map``.
    """

    # Captures the timepoint number of channel-2 ('c002') files, e.g. '..._t012_c002...'.
    _TIMEPOINT_RE = re.compile(r'_t(\d+)_c002')

    def __init__(self, root_dir, condFile, transform=None, timestamp=1, label_mapping=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        mapping = label_map if label_mapping is None else label_mapping

        # Sorted so the dataset order (and therefore any seeded train/test split)
        # does not depend on the filesystem's listing order.
        for folder in sorted(os.listdir(root_dir)):
            if not (len(folder) == 5 and folder[:2] in ('22', '26')):
                continue
            folder_path = os.path.join(root_dir, folder)
            if not os.path.isdir(folder_path):
                continue

            # Resolve the folder's class once, outside the per-file loop, so every
            # matching image in the folder receives it.
            label = mapping.get(get_condition(folder, condFile), -1)
            if label == -1:
                continue

            for file in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file)
                if not isfile(file_path):
                    continue
                # In our case channel 1 'c001' and channel 2 'c002' must be
                # distinguished; only channel 2 at the requested timepoint is used.
                match = self._TIMEPOINT_RE.search(file)
                if match is None or int(match.group(1)) != timestamp:
                    continue

                if file_path.endswith(".png"):
                    with Image.open(file_path) as im:
                        self.images.append(im.convert("RGB"))
                    self.labels.append(label)
                elif file_path.endswith(".tif"):
                    img = self._load_tif(file_path)
                    if img is None:
                        # Unreadable file: skip it instead of reusing a previous image.
                        continue
                    self.images.append(img)
                    self.labels.append(label)
                    # Only one .tif per sample and timepoint is used.
                    break

    @classmethod
    def _load_tif(cls, file_path):
        """Read a .tif, rescale it to 8 bits and color-map it; return None if unreadable."""
        try:
            image_array = tifffile.imread(file_path)
        except TypeError:
            return None
        img_col = cv2.applyColorMap(cls._rescale_to_uint8(image_array), cv2.COLORMAP_DEEPGREEN)
        # NOTE(review): OpenCV color maps are produced in BGR order while PIL reads
        # the array as RGB, so red/blue end up swapped. Kept as-is so inputs match
        # previously trained models; apply cv2.cvtColor(img_col, cv2.COLOR_BGR2RGB)
        # if true colors matter.
        return Image.fromarray(img_col).convert("RGB")

    @staticmethod
    def _rescale_to_uint8(arr):
        """Min-max rescale ``arr`` to 0-255 and return it as uint8.

        The arithmetic is done in float64: computing ``255 * (arr - min)`` in an
        integer dtype (e.g. 16-bit TIFF data) silently overflows. A constant
        image maps to all zeros instead of dividing by zero.
        """
        arr = np.asarray(arr, dtype=np.float64)
        lo, hi = arr.min(), arr.max()
        if hi == lo:
            return np.zeros(arr.shape, dtype=np.uint8)
        return (255.0 * (arr - lo) / (hi - lo)).astype(np.uint8)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        # apply transformation
        if self.transform is not None:
            img = self.transform(img)
        return img, label

def train_model(model, criterion, optimizer, dataloaders, num_epochs=100, debug=False, device=None):
    """Train ``model`` and return it loaded with its best-validation-accuracy weights.

    Each epoch makes one pass over ``dataloaders['train']`` with gradient updates,
    then one pass over ``dataloaders['val']`` without gradients. The weights of the
    first epoch reaching the highest validation accuracy are kept (the initial
    weights if no epoch exceeds 0 accuracy) and loaded back into ``model``.

    Args:
        model: network to train (modified in place).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters to train.
        dataloaders: dict with 'train' and 'val' iterables of (inputs, labels) batches.
        num_epochs: number of epochs to run.
        debug: if True, print losses/accuracies at epoch 1 and every 5th epoch.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        (model, loss_values), where loss_values holds the per-sample mean loss of
        each phase in order [train_1, val_1, train_2, val_2, ...]
        (2 * num_epochs entries).

    Raises:
        ValueError: if a dataloader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    loss_values = []

    for epoch in range(1, num_epochs + 1):
        verbose = debug and (epoch == 1 or epoch % 5 == 0)
        if verbose:
            print(f'Epoch {epoch}/{num_epochs}')

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0
            total = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                preds = outputs.argmax(dim=1)
                batch_size = labels.size(0)
                # Assumes criterion averages over the batch (the default 'mean'
                # reduction); weighting by batch size yields a per-sample mean so a
                # short final batch does not skew the epoch loss.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                total += batch_size

            if total == 0:
                raise ValueError(f"dataloaders[{phase!r}] yielded no samples")

            epoch_loss = running_loss / total
            epoch_acc = running_corrects / total
            loss_values.append(epoch_loss)
            if verbose:
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        if verbose:
            print('-' * 10)

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model, loss_values

def get_metrics(model, test_dataloader, device=None):
    """Evaluate ``model`` on ``test_dataloader`` and return classification metrics.

    Args:
        model: trained classifier whose outputs are per-class scores (argmax = prediction).
        test_dataloader: iterable of (images, labels) batches.
        device: device images are moved to; defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        (accuracy, precision, recall, f1); precision, recall and f1 are
        macro-averaged over classes. A class with no predicted (or no true)
        samples contributes 0 to the corresponding macro average.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    true_labels = []
    predicted_labels = []
    model.eval()

    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images.to(device))
            predicted_labels.extend(outputs.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    if not true_labels:
        raise ValueError("test_dataloader yielded no samples")

    # zero_division=0 gives the same values as sklearn's default ('warn') without
    # emitting an UndefinedMetricWarning for classes that are never predicted.
    precision = precision_score(true_labels, predicted_labels, average='macro', zero_division=0)
    recall = recall_score(true_labels, predicted_labels, average='macro', zero_division=0)
    f1 = f1_score(true_labels, predicted_labels, average='macro', zero_division=0)
    accuracy = accuracy_score(true_labels, predicted_labels)
    return accuracy, precision, recall, f1

def get_dataloaders(imRoot, time, condFile, train=True, img_transform=None, batch_size=8):
    """Build DataLoaders over the images of one timepoint.

    Args:
        imRoot: directory of the per-sample image folders.
        time: timepoint at which we are training/testing (passed to ImageDataset
            as ``timestamp``).
        condFile: .csv file with the treatment condition of each sample.
        train: True when training a model: the dataset is split 80/20 (seeded with
            random_state=42) and ``{'train': shuffled loader, 'val': unshuffled
            loader}`` is returned. False when only testing an existing model: a
            single unshuffled loader over the whole dataset is returned.
        img_transform: transform applied to each image; defaults to the
            module-level ``transform``.
        batch_size: batch size of every loader.
    """
    from torch.utils.data import Subset

    dataset = ImageDataset(
        imRoot,
        condFile,
        transform=transform if img_transform is None else img_transform,
        timestamp=time,
    )

    if not train:
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Split sample indices instead of the dataset object. A seeded, non-stratified
    # split depends only on the number of samples, so this gives the same partition
    # as splitting the dataset directly, while images are transformed lazily
    # when batches are drawn.
    train_idx, test_idx = train_test_split(
        list(range(len(dataset))), test_size=0.2, random_state=42
    )
    return {
        'train': DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True),
        'val': DataLoader(Subset(dataset, test_idx), batch_size=batch_size, shuffle=False),
    }

# Integer class index for each treatment condition.
label_map = {
    "BMP4": 0,
    "CHIR": 1,
    "DS": 2,
    "DS+CHIR": 3,
    "WT": 4,
}
# Network input as (channels, height, width); images are resized to (height, width).
input_shape = (3, 224, 224)
num_classes = len(label_map)

# Preprocessing for every image: resize, convert to a tensor, then normalize
# each of the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(input_shape[1:]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
def train_at_Tmpt(t, condFileTr, condFileTe, saveModel=1, modelName='MyResNetModel.pt', imRoot=None):
    """Fine-tune an ImageNet-pretrained ResNet-18 on the images of one timepoint.

    All backbone weights are frozen; only a new final fully connected layer
    (``num_classes`` outputs) is trained, with SGD (lr=0.001, momentum=0.9) for
    50 epochs. Metrics are computed on the 'val' split of the dataset labeled
    from ``condFileTe``. stderr is silenced while training and evaluating.

    Args:
        t: timepoint at which we are training/testing.
        condFileTr: .csv file with the treatment condition of each sample, used
            for training.
        condFileTe: .csv file with the treatment condition of each sample, used
            for testing. Separate files allow e.g. training on randomized labels
            and testing on the actual labels, as a control.
        saveModel: if 1, save the whole trained model with ``torch.save``.
        modelName: filename of the model if saved.
        imRoot: directory of the image folders. Defaults to a global ``imRoot``
            if one is defined, otherwise to ``imagePath``.

    Returns:
        (accuracy, precision, recall, f1, model_trained)
    """
    if imRoot is None:
        imRoot = globals().get('imRoot')
        if imRoot is None:
            imRoot = imagePath

    # Prefer the weights API (torchvision >= 0.13); the legacy `pretrained=True`
    # flag loads the same ImageNet weights on older versions.
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.DEFAULT)
    else:
        model = models.resnet18(pretrained=True)

    for param in model.parameters():
        param.requires_grad = False
    # The replacement head's parameters are newly created and therefore trainable.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
    model = model.to(device)

    # Silence stderr (e.g. library warnings) during training and evaluation. The
    # context managers restore stderr and close the devnull handle even if an
    # exception is raised.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stderr(devnull):
        dataloaders = get_dataloaders(imRoot, t, condFileTr)
        model_trained, _ = train_model(model, criterion, optimizer, dataloaders, num_epochs=50)
        dataloaders_actuallabel = get_dataloaders(imRoot, t, condFileTe)
        accuracy, precision, recall, fOne = get_metrics(model_trained, dataloaders_actuallabel['val'])
        if saveModel == 1:
            torch.save(model_trained, modelName)

    return accuracy, precision, recall, fOne, model_trained