In [None]:
from typing import List

import matplotlib.pyplot as plt
from torchvision import io, transforms
from torchvision.utils import Image, ImageDraw
from torchvision.transforms.functional import to_pil_image
import numpy as np

%matplotlib inline

## import sample image

In [None]:
# Read one sample scan as RGB and turn it into a tensor via ToTensor.
sample_png = '/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/disease1B_scan/disease1B_scan - 001.png'
pil_img = Image.open(sample_png).convert('RGB')
transform = transforms.Compose([transforms.ToTensor()])
img = transform(pil_img)
img.shape

#### Channels:	
 - FITC/525 nm : SYTO 13 : DNA (Grey)
 - Cy3/568 nm : Alexa 532 : PanCK (Yellow)
 - Texas Red/615 nm : Alexa 594 : CD45 (Cyan)
 - Cy5/666 nm : Cy5 : Custom (Magenta)

**SYTO** Deep Red Nucleic Acid Stain is a cell-permeant dye that specifically stains the nuclei of live, dead, or fixed cells.

**pan-CK** (AE1/AE3) and EMA are epithelium-specific antibodies. As the basic component of cellular structure of normal epithelial cells and epithelial cancer cells, they are often used to differentiate tumors according to whether they originate from the epithelium or not.

**CD45** is a signalling molecule that is an essential regulator of T and B cell antigen receptor signalling.

**CD10+CD31** – Proximal nephrons and endothelial cells (Custom)

In [None]:
# Convert the tensor back to a PIL image so the notebook renders it as the cell output.
to_pil_image(img)

In [None]:
# Same sample scan as loaded earlier in the notebook: read as RGB, convert with ToTensor.
sample_png = '/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/disease1B_scan/disease1B_scan - 001.png'
pil_img = Image.open(sample_png).convert('RGB')
transform = transforms.Compose([transforms.ToTensor()])
img = transform(pil_img)
img.shape

In [None]:
# Render the tensor as a PIL image (cell output).
to_pil_image(img)

## image resize

In [None]:
# Target geometry: an f x f grid of PATCH_SIZE-pixel tiles.
f = 3
PATCH_SIZE = 256
IMG_SIZE = PATCH_SIZE*f

# Resize the image tensor to IMG_SIZE x IMG_SIZE so it tiles evenly.
resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
resized_img = resize(img)
resized_img.shape

In [None]:
# Render the resized tensor as a PIL image (cell output).
to_pil_image(resized_img)

## Patches

In [None]:
# Slide a non-overlapping PATCH_SIZE window along dimension 1, then along
# dimension 2; patches[:, i, j] is then the tile at grid position (i, j).
row_tiles = resized_img.unfold(1, PATCH_SIZE, PATCH_SIZE)
patches = row_tiles.unfold(2, PATCH_SIZE, PATCH_SIZE)

fig, ax = plt.subplots(f, f, figsize=(16, 16))
dataset = []
for cell in range(f * f):
    # Walk the grid row by row.
    i, j = divmod(cell, f)
    sub_img = patches[:, i, j]
    # Each stored patch gets a leading batch dimension of size 1.
    dataset.append(sub_img.unsqueeze(0))
    ax[i][j].imshow(to_pil_image(sub_img))
    ax[i][j].axis('off')

In [None]:
import os

# NOTE: this value is not displayed, because it is not the last expression of the cell.
os.getcwd()
# List the working directory (long format, hidden files, human-readable sizes, oldest first).
!ls -lahtr

In [None]:
import os

# Show the current working directory; relative paths used below are resolved against it.
os.getcwd()

In [None]:
import torch
import torchvision
from torchvision import models
from torchvision import transforms

# Seed both RNGs before the model is constructed.
np.random.seed(0)
torch.manual_seed(0)

# Build AlexNet and fill it with the weights stored in the local checkpoint file.
model = models.alexnet()
pretrained_state = torch.load('alexnet-owt-4df8aa71.pth')
model.load_state_dict(pretrained_state)
# eval() switches the model to inference mode and returns it, so the cell displays the model.
model.eval()

In [None]:
import torch
import torch.nn.functional as F

from PIL import Image

import os
import json
import numpy as npd
from matplotlib.colors import LinearSegmentedColormap

import torchvision
from torchvision import models
from torchvision import transforms

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Saliency
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz


In [None]:
# !wget -P / https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
# Parse the class-index JSON file into a dict.
labels_path = 'imagenet_class_index.json'
with open(labels_path) as labels_file:
    idx_to_labels = json.loads(labels_file.read())

In [None]:
# First patch of the grid (already carries a batch dimension).
input = dataset[0]

# Class probabilities, then the single most probable class and its score.
output = F.softmax(model(input), dim=1)
prediction_score, pred_label_idx = output.topk(1)
pred_label_idx.squeeze_()

# The label file is keyed by the class index as a string; entry [1] is the name printed.
class_key = str(pred_label_idx.item())
predicted_label = idx_to_labels[class_key][1]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')

In [None]:
# Integrated Gradients attributions of the predicted class with respect to the input patch.
integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(input, target=pred_label_idx, n_steps=200)

# Custom colormap for the heat map: white at 0, black from 0.25 upwards.
cmap_stops = [(0, '#ffffff'),
              (0.25, '#000000'),
              (1, '#000000')]
default_cmap = LinearSegmentedColormap.from_list('custom blue', cmap_stops, N=256)

# Both arrays get their first axis moved to the end before being handed to the plot helper.
attr_last_axis = np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0))
image_last_axis = np.transpose(input.squeeze().cpu().detach().numpy(), (1,2,0))

# Show the original image next to the positive-attribution heat map.
_ = viz.visualize_image_attr_multiple(attr_last_axis,
                                      image_last_axis,
                                      methods=["original_image", "heat_map"],
                                      signs=['all', 'positive'],
                                      cmap=default_cmap,
                                      show_colorbar=True)

## Train AlexNet with train and validation sets of images

In [None]:
def create_datasets(dict_images, n_patches = 9, ref_group = 'normal'):
    """Build, per group, a list of ``(patch, label)`` pairs from PNG images.

    Each image ``geomx_data/<name>.png`` is read as RGB, converted with
    ``ToTensor``, resized to a square of ``256 * f`` pixels and cut into an
    ``f x f`` grid of 256-pixel patches, where ``f = int(sqrt(n_patches))``.

    Parameters
    ----------
    dict_images : dict
        Maps a group name to an iterable of image names (without extension).
    n_patches : int, optional
        Requested number of patches per image. The square root is truncated,
        so a value that is not a perfect square yields fewer patches.
    ref_group : str, optional
        Group whose patches are labelled 0; all other groups are labelled 1.

    Returns
    -------
    dict
        Maps every group in ``dict_images`` to its list of
        ``(patch, label)`` tuples. A group without images maps to ``[]``.
    """
    PATCH_SIZE = 256
    f = int(np.sqrt(n_patches))
    IMG_SIZE = PATCH_SIZE * f

    images = {}
    for group, names in dict_images.items():
        label = 0 if group == ref_group else 1
        dataset = []
        for name in names:
            # Close the file as soon as the pixels have been converted.
            with Image.open('geomx_data/'+name+'.png') as raw:
                tensor = transforms.ToTensor()(raw.convert('RGB'))
            resized_img = transforms.Resize((IMG_SIZE, IMG_SIZE))(tensor)
            # Non-overlapping windows along dimensions 1 and 2.
            patches = resized_img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
            dataset.extend((patches[:, i, j], label) for i in range(f) for j in range(f))
        # Assigned once per group, outside the image loop, so that a group
        # with no images still gets an (empty) entry.
        images[group] = dataset
    return images

class Custom_Dataset(torch.utils.data.dataset.Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of ``(example, target)`` pairs."""

    def __init__(self, _dataset):
        # Keep a reference to the underlying collection; nothing is copied.
        self.dataset = _dataset

    def __len__(self):
        """Number of stored pairs."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the pair stored at ``index`` as an ``(example, target)`` tuple."""
        pair = self.dataset[index]
        # Unpacking requires every stored item to hold exactly two elements.
        example, target = pair
        return (example, target)

In [None]:
from os import listdir
from os.path import isfile, join

# Names of the regular files (sub-directories excluded) inside one scan folder.
path_to_images = "/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/disease1B_scan/"
onlyfiles = list(filter(lambda entry: isfile(join(path_to_images, entry)),
                        listdir(path_to_images)))

onlyfiles



In [None]:
import csv

In [None]:
def sread(filename: str):
    """Read a tab-delimited text file into a list of rows.

    Parameters
    ----------
    filename : str
        Path of the tab-separated file to read.

    Returns
    -------
    list of list of str
        One inner list per record, holding that record's fields.
    """
    # newline='' leaves line-ending handling to the csv module, as its
    # documentation requires; otherwise newlines embedded in quoted fields
    # are not interpreted correctly.
    with open(filename, newline='') as csvfile:
        return list(csv.reader(csvfile, delimiter='\t'))

In [None]:
def change_image_name(f):
    """Convert an image file name into a ``"<scan> | <number> | Geometric Segment"`` label.

    The name is split on whitespace. The first token is used as the scan
    name and the first three characters of the third token as the number.
    A name with fewer than three tokens raises ``IndexError``.
    """
    tokens = f.split()
    scan_name, number = tokens[0], tokens[2][:3]
    return " ".join([scan_name, "|", number, "|", "Geometric Segment"])

In [None]:
# Parse the count matrix into a list of rows (first row = column labels).
thetest1 = sread('Kidney_Q3Norm_TargetCountMatrix.txt')
# Show only the first few rows; displaying the whole matrix floods the notebook output.
thetest1[:3]

In [None]:
import pandas as pd

# Load the tab-separated count matrix, using the TargetName column as the row index.
eset = pd.read_csv(
    'Kidney_Q3Norm_TargetCountMatrix.txt',
    sep="\t",
    index_col="TargetName",
)
eset.head()

In [None]:
# Re-reads the same tab-separated count matrix, indexed by the TargetName column.
eset = pd.read_csv('Kidney_Q3Norm_TargetCountMatrix.txt', delimiter="\t", index_col="TargetName")

In [None]:

# Row labels of the count matrix that start with 'FRM'.
list(filter(lambda target: target.startswith('FRM'), eset.index))

In [None]:
# Call predict_gene_exp once per row label of the count matrix.
# The loop variable is `gene`; the original passed the undefined name `gene_name`.
[predict_gene_exp(gene) for gene in eset.index]

In [None]:
def predict_gene_exp(gene_name):
    """Placeholder for the per-gene prediction routine.

    The function had no body, which is a syntax error. It now fails
    explicitly until an implementation is provided.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("predict_gene_exp has not been implemented yet")

In [None]:
def corresponding_gene_expression_number(image_name, gene_name='FRMD3', eset=None):
    """Look up the expression value of one gene for the segment shown in an image.

    Parameters
    ----------
    image_name : str
        Image file name; ``change_image_name`` turns it into the column label.
    gene_name : str, optional
        Row label (``TargetName``) to read. Defaults to ``'FRMD3'``.
    eset : pandas.DataFrame, optional
        Already loaded count matrix indexed by ``TargetName``. When omitted,
        the matrix is read from ``Kidney_Q3Norm_TargetCountMatrix.txt``.
        Pass it in to avoid re-reading the file for every image.

    Returns
    -------
    The value stored at row ``gene_name`` and the image's column.

    Raises
    ------
    KeyError
        If the gene or the column is not present in the matrix.
    """
    if eset is None:
        eset = pd.read_csv('Kidney_Q3Norm_TargetCountMatrix.txt', delimiter="\t", index_col="TargetName")
    # One conversion is enough: applying change_image_name to its own
    # output returns the same label.
    column = change_image_name(image_name)
    return eset.loc[gene_name, column]

In [None]:
def is_image_name_in_list(f):
    """Return True when the image file ``f`` has a column in the count matrix.

    The first row of ``Kidney_Q3Norm_TargetCountMatrix.txt`` is read with
    ``sread`` and searched for the label produced by ``change_image_name``.
    """
    header_row = sread('Kidney_Q3Norm_TargetCountMatrix.txt')[0]
    return change_image_name(f) in header_row
    

In [None]:
def image_list(loi):
    """Keep only the image files that have a column in the count matrix.

    Parameters
    ----------
    loi : iterable of str
        Image file names.

    Returns
    -------
    list of str
        The names, in their original order, whose ``change_image_name`` label
        appears in the first row of ``Kidney_Q3Norm_TargetCountMatrix.txt``.
    """
    names = list(loi)
    if not names:
        # Nothing to check, so the matrix file does not need to be read.
        return []
    # Read the header once for the whole list instead of re-parsing the
    # full matrix for every file.
    rows = sread('Kidney_Q3Norm_TargetCountMatrix.txt')
    known_columns = set(rows[0]) if rows else set()
    return [name for name in names if change_image_name(name) in known_columns]
            

In [None]:
def process_image(path_to_images, image_name, group, ref_group, n_patches=9, patch_size=256):
    """Cut one image into square patches, each paired with its gene-expression value.

    The image is read as RGB, converted with ``ToTensor``, resized to a
    square of ``patch_size * grid`` pixels and cut into a ``grid x grid``
    set of patches, where ``grid = int(sqrt(n_patches))``.

    Parameters
    ----------
    path_to_images : str
        Directory that contains the image.
    image_name : str
        File name of the image inside ``path_to_images``.
    group, ref_group
        Accepted for interface compatibility; not used by this function.
    n_patches : int, optional
        Requested number of patches. The square root is truncated. The
        default of 9 matches the 3 x 3 grid configured in the notebook.
    patch_size : int, optional
        Side length of a patch in pixels (default 256).

    Returns
    -------
    list of tuple
        ``(patch, value)`` pairs in row-major grid order, where ``value`` is
        the result of ``corresponding_gene_expression_number(image_name)``.
    """
    grid = int(np.sqrt(n_patches))
    side = patch_size * grid
    # The transforms are built here from the arguments, so the result no
    # longer depends on notebook-level variables (transform, resize,
    # PATCH_SIZE, f) defined in other cells.
    with Image.open(join(path_to_images, image_name)) as raw:
        img = transforms.ToTensor()(raw.convert('RGB'))
    resized_img = transforms.Resize((side, side))(img)
    # Non-overlapping windows along dimensions 1 and 2.
    patches = resized_img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    gene_num = corresponding_gene_expression_number(image_name)
    return [(patches[:, i, j], gene_num) for i in range(grid) for j in range(grid)]

In [None]:
def _labelled_patches(image_path, target, grid, patch_size):
    """Return the ``grid * grid`` patches of one image, each paired with ``target``."""
    side = patch_size * grid
    with Image.open(image_path) as raw:
        img = transforms.ToTensor()(raw.convert('RGB'))
    resized_img = transforms.Resize((side, side))(img)
    patches = resized_img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    return [(patches[:, i, j], target) for i in range(grid) for j in range(grid)]


def create_datasets(dict_images={}, n_patches = 9, ref_group = 'normal', path_to_images=[]):
    """Build, per group, a list of ``(patch, gene_expression)`` pairs.

    For every sample folder of every group, the image files that have a
    column in the count matrix (see ``image_list``) are cut into patches and
    paired with the value returned by ``corresponding_gene_expression_number``.

    Parameters
    ----------
    dict_images : dict, optional
        Maps a group name to a list of sample folder names.
    n_patches : int, optional
        Requested number of patches per image. The square root is truncated,
        so a value that is not a perfect square yields fewer patches.
    ref_group : str, optional
        Accepted for interface compatibility; not used, because the targets
        are expression values rather than class labels.
    path_to_images : str
        Directory that contains the sample folders. The list default is kept
        only for signature compatibility; a path string must be supplied
        whenever ``dict_images`` lists any sample.

    Returns
    -------
    dict
        Maps every group to the patches of all its samples.
    """
    PATCH_SIZE = 256
    grid = int(np.sqrt(n_patches))

    images = {}
    for group, samples in dict_images.items():
        # Accumulate across samples: assigning per sample would leave only
        # the last sample of each group in the result.
        group_items = images.setdefault(group, [])
        for sample in samples:
            sample_dir = join(path_to_images, sample)
            # listdir order is arbitrary; sort it so runs are reproducible.
            files = [name for name in sorted(listdir(sample_dir)) if isfile(join(sample_dir, name))]
            for image_name in image_list(files):
                target = corresponding_gene_expression_number(image_name)
                # n_patches is honoured here instead of relying on notebook globals.
                group_items.extend(
                    _labelled_patches(join(sample_dir, image_name), target, grid, PATCH_SIZE)
                )
    return images

class Custom_Dataset(torch.utils.data.dataset.Dataset):
    """Expose an indexable collection of ``(example, target)`` pairs as a torch ``Dataset``."""

    def __init__(self, _dataset):
        # The collection is stored as-is (no copy).
        self.dataset = _dataset

    def __len__(self):
        """Return how many pairs the wrapped collection holds."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item ``index`` as an ``(example, target)`` tuple."""
        # Two-way unpacking fails for items that do not hold exactly two elements.
        example, target = self.dataset[index]
        item = (example, target)
        return item

In [None]:
# Scan folders assigned to each split, keyed by group.
train = dict(dkd=['disease1B_scan', 'disease2B_scan'], normal=['normal2B_scan'])
valid = dict(dkd=['disease3_scan'], normal=['normal3_scan'])
test = dict(dkd=['disease4_scan'], normal=['normal4_scan'])

In [None]:
path_to_images = "/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/"


def _datasets_and_loader(message, split):
    """Print ``message``, build the per-group datasets of ``split`` and wrap them in a DataLoader."""
    print(message)
    datasets = create_datasets(dict_images = split, n_patches = 4, ref_group = 'normal', path_to_images=path_to_images)
    # 'dkd' items come first, then 'normal'; one item per batch, fixed order.
    loader = torch.utils.data.DataLoader(dataset=Custom_Dataset(datasets['dkd'] + datasets['normal']),
                                         batch_size=1,
                                         shuffle=False)
    return datasets, loader


train_datasets, train_loader = _datasets_and_loader("train loader", train)
valid_datasets, valid_loader = _datasets_and_loader("validation loader", valid)
test_datasets, test_loader = _datasets_and_loader("test loader", test)

In [None]:
# Number of batches in the training loader.
len(train_loader)

In [None]:
import os

# Show the current working directory (the checkpoint files below are opened by relative path).
os.getcwd()

In [None]:
from torch import nn, optim

# Reproducibility: seed both RNGs before the model is constructed.
torch.manual_seed(0)
np.random.seed(0)

# Hyperparameters
learning_rate = 1e-3
num_epochs = 10
num_classes = 2
in_channel = 3

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model: AlexNet filled with the weights from the local checkpoint file.
net = models.alexnet()
pretrained_state = torch.load('alexnet-owt-4df8aa71.pth')
net.load_state_dict(pretrained_state)
net

In [None]:
# Swap the last classifier layer for a single-output linear layer (regression head).
# The current layer takes 4096 inputs, so in_features reproduces the hard-coded width.
net.classifier[6] = nn.Linear(net.classifier[6].in_features, 1)
net

In [None]:
# Freeze every parameter, then re-enable gradients for the new head only.
head = net.classifier[6]
for frozen in net.parameters():
    frozen.requires_grad = False
for trainable in (head.weight, head.bias):
    trainable.requires_grad = True
net.to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
net

In [None]:
# One line per parameter tensor: True if it will receive gradients.
for weight in net.parameters():
    print(weight.requires_grad)

In [None]:
# First (example, target) pair of the training set.
data_, target_ = train_loader.dataset[0]

In [None]:
# Display the model architecture.
net

In [None]:
# Take the first training pair, add a leading batch dimension to the example,
# convert both to float tensors and move them to the selected device.
data_, target_ = train_loader.dataset[0]
data_ = torch.FloatTensor(data_).unsqueeze(0).to(device)
target_ = torch.FloatTensor(np.array(target_)).to(device)
net.to(device)

type(data_), type(target_), print(device)
# outputs = net(data_)

In [None]:
# Display the target tensor prepared in the previous cell.
target_

In [None]:
# Run one forward pass on the prepared example and display the raw network output.
outputs = net(data_)
outputs

In [None]:
# Train the regression head for n_epochs, validating after every epoch and
# saving the weights whenever the summed validation loss improves.
n_epochs = 30
valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
train_loss = []          # mean training loss per epoch
val_loss = []            # mean validation loss per epoch

for epoch in range(1, n_epochs+1):
    print(f'Epoch {epoch}\n')

    # ---- training pass ----
    net.train()
    running_loss = 0.0
    for data_, target_ in train_loader:
        data_ = data_.to(device, dtype=torch.float)
        # Reshape (batch,) targets to (batch, 1) so they line up with the
        # network output instead of being broadcast inside MSELoss.
        target_ = target_.to(device, dtype=torch.float).view(-1, 1)

        optimizer.zero_grad()
        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    train_loss.append(running_loss/len(train_loader))
    # Report this epoch's loss (not the running mean over all epochs so far).
    print(f'\ntrain-loss: {train_loss[-1]:.4f}')

    # ---- validation pass ----
    batch_loss = 0  # summed validation loss of this epoch
    total_t = 0     # number of validation targets seen in this epoch
    net.eval()
    with torch.no_grad():
        for data_t, target_t in valid_loader:
            data_t = data_t.to(device, dtype=torch.float)
            target_t = target_t.to(device, dtype=torch.float).view(-1, 1)
            outputs_t = net(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            total_t += target_t.size(0)
    val_loss.append(batch_loss/len(valid_loader))
    print(f'validation loss: {val_loss[-1]:.4f}')

    # Keep the weights of the best epoch so far.
    if batch_loss < valid_loss_min:
        valid_loss_min = batch_loss
        torch.save(net.state_dict(), 'model.pt')
        print('Improvement-Detected, save-model')

# Leave the network in training mode, as the original loop did.
net.train()

In [None]:
# Per-epoch training and validation loss curves.
fig = plt.figure(figsize=(20,10))
# The curves are MSE losses, so the title says "Loss" rather than "Accuracy".
plt.title("Train-Validation Loss")
plt.plot(train_loss, label='train')
plt.plot(val_loss, label='validation')
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('MSE loss', fontsize=12)
plt.legend(loc='best');

In [None]:
# Collect predictions and targets over the test set.
# NOTE(review): this evaluates the network as left by the last training
# epoch, not the best checkpoint saved to 'model.pt' — confirm which is intended.
pred = []
actual = []
# Start the accumulators from zero; otherwise they continue from the values
# left behind by the last validation pass of the training cell.
batch_loss = 0
total_t = 0
net.eval()
with torch.no_grad():
    for data_t, target_t in test_loader:
        data_t = data_t.to(device, dtype=torch.float)
        # (batch,) -> (batch, 1) so the targets line up with the network output.
        target_t = target_t.to(device, dtype=torch.float).view(-1, 1)
        outputs_t = net(data_t)
        loss_t = criterion(outputs_t, target_t)
        batch_loss += loss_t.item()
        total_t += target_t.size(0)
        # Flatten so this also works when a batch holds more than one item.
        pred.extend(outputs_t.view(-1).tolist())
        actual.extend(target_t.view(-1).tolist())

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Scatter of actual vs. predicted values with a least-squares line.
x = np.array(actual)
y = np.array(pred)

# Line of best fit: y = a*x + b
a, b = np.polyfit(x, y, 1)

fig, ax = plt.subplots()
ax.scatter(x, y, color='purple')
ax.plot(x, a*x+b, color='steelblue', linestyle='--', linewidth=2)
ax.set_xlabel('actual')
ax.set_ylabel('predicted')

# Place the fitted equation in axes-fraction coordinates so it stays inside
# the plot whatever the range of the data is.
ax.text(0.05, 0.95, 'y = ' + '{:.2f}'.format(b) + ' + {:.2f}'.format(a) + 'x',
        size=14, transform=ax.transAxes, verticalalignment='top');

In [None]:
# Correlation-coefficient matrix of the actual and predicted values (off-diagonal entry = their correlation).
np.corrcoef(actual, pred)