In [None]:
#Import required packages
import cv2 as cv
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

from Helper_functions import load_dict

from AI_functions import resnet18,CellDataset_supervised,data_generator
import torch.nn as nn
import torch.optim as optim

from Helper_functions import log_pol_scale, log_pol
import seaborn as sns


In [None]:
# Load every dataset, filter cells, and build one train / test CellDataset per dataset.
# The order of DATASETS is the row/column order of every results matrix below.
DATASETS = ["Retina_1_2", "Retina_0_0", "Colon", "Choroid", "BoneMarrow_sample1"]

# NOTE(review): absolute local path; consider making this configurable (env var / relative path).
DATA_ROOT = r"C:\Users\Thibaut Goldsborough\Documents\Seth_BoneMarrow\Data"

image_dim = 64                  # Dim of the final images
nuclear_channel = "Ch7"
cellmask_channel = "Ch1_mask"
Prediction_Channels = ['Ch07']  # labels are read from the "Scaled_<channel>" columns
Channels = ['Ch1']              # Channel(s) to be fed to the NN
Thresh = 50                     # cells need "Gradient RMS_M01_Ch01" > Thresh to be kept
cutoff = 30000                  # maximum number of cells kept per dataset

LABELS_DICT = {}
train_data_list = []
test_data_list = []

for N, DATASET in enumerate(DATASETS):
    basepath = os.path.join(DATA_ROOT, DATASET)
    outpath = os.path.join(basepath, "Outputs")
    df = pd.read_csv(os.path.join(outpath, "cell_info1.csv"))
    cell_names = df["Cell_ID"].to_numpy()

    image_dict = load_dict(outpath, cell_names, image_dim)

    # Stack the requested channels of every cell; cells missing channels are reported and skipped.
    images_with_index = []
    for image_i in image_dict:
        if len(image_dict[image_i].keys()) >= len(Channels):
            image = cv.merge([image_dict[image_i][ch] for ch in Channels])
            images_with_index.append((int(image_i), image))
        else:
            print(image_i)

    images = np.array([pair[1] for pair in images_with_index])
    names = np.array([pair[0] for pair in images_with_index])
    # Labels and filters are taken from df, so the images must be in exactly df's row order.
    # np.array_equal also fails cleanly when a cell was skipped (length mismatch).
    assert np.array_equal(names, df['Cell_ID'].to_numpy()), \
        f"{DATASET}: image order/count does not match cell_info1.csv"

    DNA_pos = df["DNA_pos"].to_numpy()
    Touches_Boundary = df["Touches_boundary"].to_numpy()
    labels = df[["Scaled_" + channel for channel in Prediction_Channels]].to_numpy()

    gradient_rms = df["Gradient RMS_M01_Ch01"].to_numpy()
    plt.hist(gradient_rms, bins=200)
    plt.axvline(x=Thresh, color="red")
    plt.title(f"{DATASET}: Gradient RMS_M01_Ch01 (red line = Thresh)")
    plt.show()

    # Keep cells that do not touch the boundary and whose gradient RMS exceeds Thresh
    # (an extra `DNA_pos == 1` condition was used at some point).
    idx_to_keep = (Touches_Boundary == 0) & (gradient_rms > Thresh)
    print(len(images))
    images = images[idx_to_keep][:cutoff]
    names = names[idx_to_keep][:cutoff]
    labels = labels[idx_to_keep][:cutoff]

    LABELS_DICT[DATASET] = labels

    plt.hist(labels[:, 0], bins=100)
    plt.title(f"{DATASET}: labels after filtering")
    plt.show()

    # Shift intensities by |min|, which moves a negative minimum to 0.
    # NOTE(review): if the minimum is positive this shifts *up* by min instead — confirm intent.
    raw_min = images.min()
    mini = int(round(abs(raw_min), 0))
    images = images + abs(raw_min)

    print(len(images))

    train, test, batch_size, mean_loader, std_loader = data_generator(
        images, labels, names, mini, train_test_split=0.8, batch_size=100, sample=False)
    print(mean_loader, std_loader, mini)

    train_data, train_data1, train_labels, train_ID = train
    test_data, test_data1, test_labels, test_ID = test

    # NOTE(review): after ToTensor the image is (C, H, W), so x[0, 0] is the whole first
    # row of channel 0; the Lambda subtracts that row column-wise rather than a single
    # top-left background pixel (x[0, 0, 0]). Confirm which one is intended.
    transform_train = transforms.Compose([
        transforms.ToPILImage('L'),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=18, fill=mini),
        transforms.ToTensor(),
        transforms.Normalize(mean_loader, std_loader),
        transforms.Lambda(lambda x: x - x[0, 0]),
    ])

    transform_validation = transforms.Compose([
        transforms.ToPILImage('L'),
        transforms.ToTensor(),
        transforms.Normalize(mean_loader, std_loader),
        transforms.Lambda(lambda x: x - x[0, 0]),
    ])

    train_data = CellDataset_supervised(train_data1, train_labels, train_ID, transform_train)
    test_data = CellDataset_supervised(test_data1, test_labels, test_ID, transform_validation)

    train_data_list.append(train_data)
    test_data_list.append(test_data)



from torch.utils.data import ConcatDataset

# Pool every dataset's split into one combined train set and one combined test set.
train_cat = ConcatDataset(list(train_data_list))
test_cat = ConcatDataset(list(test_data_list))

train_cat_loader = DataLoader(train_cat, batch_size=524, shuffle=True, drop_last=True)
test_cat_loader = DataLoader(test_cat, batch_size=524, shuffle=True)


    

In [None]:
# from Helper_functions import log_pol
# for N,DATASET in enumerate(DATASETS):
#     basepath = r"C:\Users\Thibaut Goldsborough\Documents\Seth_BoneMarrow\Data\\" +DATASET
#     outpath = basepath + "\\Outputs"
#     df=pd.read_csv(outpath+"\\cell_info.csv")
#     labels=df[["Intensity_MC_"+channel for channel in Prediction_Channels]].to_numpy()
#     logs=log_pol(labels,slope=1,c=1000)
#     if N==0:
#         LABELS=logs.copy()
#     else:
#         print(np.shape(LABELS),np.shape(logs))
#         LABELS=np.vstack((LABELS,logs))
#     plt.hist(logs,bins=100);
#     plt.show()

# plt.hist(LABELS,bins=1000);
# print(np.mean(LABELS))
# print(np.std(LABELS))

In [None]:
def train_epoch(NN, device, dataloader, loss_fn, optimizer,noise_factor=0):
    """Run one supervised training epoch and return the mean batch loss.

    Args:
        NN: model to train; it is put in train mode here.
        device: device the image and label batches are moved to.
        dataloader: iterable of ``(images, labels, ids)`` batches; ids are ignored.
        loss_fn: loss called as ``loss_fn(prediction, target)``.
        optimizer: optimizer over ``NN``'s parameters; stepped once per batch.
        noise_factor: currently unused; kept so existing call sites keep working.

    Returns:
        Mean of the per-batch losses, or nan if ``dataloader`` yields no batches.
    """
    NN.train()
    train_loss = []
    # Supervised: labels are used as the regression target; the ID element is ignored.
    for image_batch, labels_batch, _ in dataloader:
        image_batch = image_batch.to(device)
        labels_batch = labels_batch.to(device)
        output = NN(image_batch)
        # PyTorch losses take (input, target); the order matters for asymmetric losses.
        loss = loss_fn(output, labels_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    if not train_loss:
        return float("nan")
    return np.mean(train_loss)

def spearman(x,y):
    """Return the correlation coefficient between tensors ``x`` and ``y``.

    NOTE: despite its name this is the *Pearson* (linear) correlation of the raw
    values. No ranking is applied, so it is not Spearman's rank correlation.

    Both inputs are flattened first, so that a ``(B,)`` tensor and a ``(B, 1)`` tensor
    are paired element-wise. Without flattening they would broadcast into a
    ``(B, B)`` product whose sum is ~0.

    Args:
        x, y: tensors with the same number of elements.

    Returns:
        0-d tensor in [-1, 1]. If either input is constant the result is undefined
        (0/0): expect nan, or a meaningless value caused by float rounding.

    Raises:
        ValueError: if ``x`` and ``y`` have different numbers of elements.
    """
    x = x.reshape(-1)
    y = y.reshape(-1)
    if x.numel() != y.numel():
        raise ValueError(f"spearman: x has {x.numel()} elements but y has {y.numel()}")
    vx = x - torch.mean(x)
    vy = y - torch.mean(y)
    cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))
    return cost

### Validation function
def validation_epoch(NN, device, dataloader, loss_fn):
    """Evaluate ``NN`` on ``dataloader`` and return the prediction/label correlation.

    Predictions and labels from all batches are pooled, and one correlation is computed
    with ``spearman``. Despite its name, ``spearman`` computes Pearson's r on the raw values.
    Pooling yields the correlation over the whole evaluation set, instead of an average
    of per-batch correlations, which is not the same quantity.

    Args:
        NN: model to evaluate; it is switched to eval mode here.
        device: device the image batches are moved to.
        dataloader: iterable of ``(images, labels, ids)`` batches; ids are ignored.
        loss_fn: unused; kept so existing call sites keep working.

    Returns:
        The correlation as a Python float, or nan if ``dataloader`` yields no batches.
    """
    NN.eval()
    all_labels = []
    all_outputs = []
    with torch.no_grad():  # No need to track the gradients
        for image_batch, labels_batch, _ in dataloader:
            output = NN(image_batch.to(device))
            # Flatten per batch so labels and predictions pair up element-wise.
            all_labels.append(labels_batch.reshape(-1).cpu())
            all_outputs.append(output.reshape(-1).cpu())

    if not all_outputs:
        return float("nan")
    return float(spearman(torch.cat(all_labels), torch.cat(all_outputs)))


In [None]:
# Train one model per dataset and evaluate it on every dataset's test split.
# Metrics from the last N_EVAL_EPOCHS epochs are summed into the results matrices
# and averaged for the heatmap.
N_EVAL_EPOCHS = 2
num_epochs = 10
lr = 1e-3
num_classes = len(Prediction_Channels)

# Short names in the same order as DATASETS / train_data_list / test_data_list.
HEATMAP_LABELS = ["Retina_1", "Retina_0", "Colon", "Choroid", "BM"]
assert len(HEATMAP_LABELS) == len(test_data_list) == len(train_data_list)

train_results_matrix = np.zeros((len(train_data_list), len(test_data_list)))
test_results_matrix = np.zeros((len(train_data_list), len(test_data_list)))
n_eval_epochs = min(N_EVAL_EPOCHS, num_epochs)  # epochs actually accumulated
first_eval_epoch = num_epochs - n_eval_epochs

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_dict = {"L1": nn.L1Loss(), "MSE": nn.MSELoss()}
loss_fn = loss_dict["L1"]

models = []

for i, train_dataset in enumerate(train_data_list):
    print(DATASETS[i])
    ConvNet_simple = resnet18(channel_num=len(Channels), num_classes=num_classes)
    ConvNet_simple.to(device)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, ConvNet_simple.parameters()), lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
    print("background is ", next(iter(train_loader))[0][0][0][0][0])

    for epoch in range(num_epochs):
        train_loss = train_epoch(ConvNet_simple, device, train_loader, loss_fn, optimizer)

        if epoch >= first_eval_epoch:
            for j, test_dataset in enumerate(test_data_list):
                validation_loader = DataLoader(test_dataset, batch_size=128, shuffle=True, drop_last=True)
                val_loss = validation_epoch(ConvNet_simple, device, validation_loader, loss_fn)

                train_results_matrix[i, j] += train_loss
                test_results_matrix[i, j] += val_loss
                print(DATASETS[j], val_loss)
        print('\n EPOCH', epoch + 1, ' \t train loss', train_loss)

    # One heatmap per trained model, showing the rows filled in so far.
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.set_aspect('equal')
    s = sns.heatmap(test_results_matrix / n_eval_epochs, ax=ax,
                    xticklabels=HEATMAP_LABELS, yticklabels=HEATMAP_LABELS, cmap='Spectral_r',
                    cbar_kws={'label': 'Pearson correlation', 'shrink': 0.7})
    s.set(xlabel='Test Datasets', ylabel='Train Datasets')
    plt.show()

    models.append(ConvNet_simple)


In [None]:
# Run every trained model on every test set. Keep the labels / predictions / images of
# one (train, test) pair for the inspection cells below. The indices follow the dataset
# order used to build `models` and `test_data_list` (do NOT reorder DATASETS here).
INSPECT_TRAIN_IDX = 4  # index into models
INSPECT_TEST_IDX = 3   # index into test_data_list


def _collect_predictions(model, loader):
    """Run ``model`` over ``loader``; return flat labels, flat predictions, images and the last image batch."""
    model.eval()  # inference mode for any BatchNorm/Dropout layers
    labels_flat, preds_flat, image_list = [], [], []
    last_batch = None
    with torch.no_grad():  # No need to track the gradients
        for image_batch, labels_batch, _ in loader:
            image_batch = image_batch.to(device)
            output = model(image_batch)
            labels_flat += list(labels_batch.cpu().flatten().numpy())
            preds_flat += list(output.cpu().flatten().numpy())
            image_list += list(image_batch.cpu().numpy())
            last_batch = image_batch
    return labels_flat, preds_flat, image_list, last_batch


IMS = []
OUTS = []
INPUTS = []
for j, test_dataset in enumerate(test_data_list):
    validation_loader = DataLoader(test_dataset, batch_size=128, shuffle=True, drop_last=True)
    for i, model in enumerate(models):
        inputs, outputs, images, last_batch = _collect_predictions(model, validation_loader)
        if i == INSPECT_TRAIN_IDX and j == INSPECT_TEST_IDX:
            IMS.append(images)
            OUTS.append(outputs)
            INPUTS.append(inputs)
    if last_batch is not None:
        # Number of distinct pixel values in the last batch seen (sanity check on the transforms).
        print(len(np.unique(last_batch.cpu().flatten().numpy())))
    plt.hist(inputs, bins=100)


In [None]:

# Prediction distribution and prediction-vs-label scatter for the inspected (train, test) pair.
plt.hist(OUTS[0], bins=100)
plt.xlabel("Prediction")
plt.ylabel("Count")
plt.show()

plt.scatter(INPUTS[0], OUTS[0], s=0.1)
plt.xlabel("Label")
plt.ylabel("Prediction");


In [None]:
# Images of the inspected pair whose prediction exceeds HIGH_PRED_THRESHOLD.
HIGH_PRED_THRESHOLD = 0.25
a = np.array(IMS[0])[np.array(OUTS[0]) > HIGH_PRED_THRESHOLD]

In [None]:
# Left: a high-prediction image from `a`; right: the image at the same index in IMS[0].
# Bounded by len(a) so fewer than N_SHOW selected images does not raise IndexError.
N_SHOW = 23
for i in range(min(N_SHOW, len(a))):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(a[i][0], vmin=-10, vmax=10)
    print(a[i][0][0, 0])  # top-left pixel of a[i][0]
    ax2.imshow(IMS[0][i][0], vmin=-10, vmax=10)
    plt.show()


In [None]:
# Correlation of a constant "predict the mean label" vector with the predictions.
# A constant vector has zero variance, so the correlation is undefined (0/0): expect nan
# or a meaningless value from float rounding. Uses its own name so `a` (the
# high-prediction images used above) is not overwritten.
# `inputs` / `outputs` are the last (model, test set) pair from the prediction loop.
mean_baseline = np.full(np.shape(inputs), np.mean(inputs))
spearman(torch.Tensor(mean_baseline), torch.Tensor(outputs))

In [None]:
# Prediction vs label for the last (model, test set) pair of the prediction loop.
plt.scatter(inputs, outputs, s=1)
plt.xlabel("Label")
plt.ylabel("Prediction");

In [None]:
# Train-vs-test heatmap of the validation correlation.
# The training cell sums the metric over its last min(2, num_epochs) epochs, so divide by that.
# `spearman` computes Pearson's r, hence the colorbar label.
LABELS = ["Retina_1", "Retina_0", "Colon", "Choroid", "BM"]  # same order as DATASETS in the loading cell
assert len(LABELS) == test_results_matrix.shape[0] == test_results_matrix.shape[1]
n_accumulated_epochs = min(2, num_epochs)
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_aspect('equal')
s = sns.heatmap(test_results_matrix / n_accumulated_epochs, ax=ax,
                xticklabels=LABELS, yticklabels=LABELS, cmap='Spectral_r',
                cbar_kws={'label': 'Pearson correlation', 'shrink': 0.7}, annot=True, linewidths=.5)
s.set(xlabel='Test Datasets', ylabel='Train Datasets')
#plt.savefig("conf_mat_tissues_corr.jpg",dpi=1000,bbox_inches='tight')
