In [63]:
import torch
import matplotlib.pyplot as plt
from torchvision import io, transforms
from torchvision.utils import Image, ImageDraw
from torchvision.transforms.functional import to_pil_image
import numpy as np
%matplotlib inline

In [64]:
img = Image.open('/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/disease1B_scan/disease1B_scan - 001.png').convert('RGB')
transform = transforms.Compose([
    transforms.ToTensor()])
img = transform(img)
img.shape

torch.Size([3, 2508, 1906])

#### Channels:
- FITC/525 nm : SYTO 13 : DNA (Grey)
- Cy3/568 nm : Alexa 532 : PanCK (Yellow)
- Texas Red/615 nm : Alexa 594 : CD45 (Cyan)
- Cy5/666 nm : Cy5 : Custom (Magenta)

**SYTO** Deep Red Nucleic Acid Stain is cell-permeant dye that specifically stains the nuclei of live, dead, or fixed cells.

**pan-CK** (AE1/AE3) and EMA are epithelium-specific antibodies. As the basic component of cellular structure of normal epithelial cells and epithelial cancer cells, they are often used to differentiate tumors according to whether they originate from the epithelium or not.

**CD45** is a signalling molecule that is an essential regulator of T and B cell antigen receptor signalling.
**CD10+CD31** – Proximal nephrons and endothelial cells (Custom)

In [65]:
def create_datasets(dict_images, n_patches = 9, ref_group = 'normal'):
    PATCH_SIZE = 256
    f = int(np.sqrt(n_patches))
    IMG_SIZE = PATCH_SIZE * f
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    transform = transforms.Compose([
      transforms.ToTensor()
    ])
    
    images = {}
    for group in dict_images.keys():
        dataset = []
        for img in dict_images[group]:
            ## import image
            img = Image.open('geomx_data/'+img+'.png').convert('RGB')
            ## convert img to tensor
            img = transform(img)
            ## resize image
            resized_img = resize(img)
            ## create patches
            patches = resized_img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
            ## reshape data
            for i in range(f):
                for j in range(f):
                    sub_img = patches[:, i, j]
                    if group == ref_group:
                        data_target = (sub_img, 0)
                    else:
                        data_target = (sub_img, 1)
                    dataset.append(data_target)
            images[group] = dataset
    return images

class Custom_Dataset(torch.utils.data.dataset.Dataset):
    def __init__(self, _dataset):
        self.dataset = _dataset

    def __getitem__(self, index):
        example, target = self.dataset[index]
        return example, target

    def __len__(self):
        return len(self.dataset)

In [66]:
from os import listdir
from os.path import isfile, join

path_to_images = "/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/disease1B_scan/"
onlyfiles = [f for f in listdir(path_to_images) if isfile(join(path_to_images, f))]

onlyfiles

['disease1B_scan - 002.png',
 'disease1B_scan - 023.png',
 'disease1B_scan - 024.png',
 'disease1B_scan - 021.png',
 'disease1B_scan - 006.png',
 'disease1B_scan - 005.png',
 'disease1B_scan - 004.png',
 'disease1B_scan - 019.png',
 'disease1B_scan - 016.png',
 'disease1B_scan - 012.png',
 'disease1B_scan - 022.png',
 'disease1B_scan - 020.png',
 'disease1B_scan - 001.png',
 'disease1B_scan - 003.png',
 'disease1B_scan - 007.png',
 'disease1B_scan - 010.png',
 'disease1B_scan - 015.png',
 'disease1B_scan - 014.png',
 'disease1B_scan - 017.png',
 'disease1B_scan - 009.png',
 'disease1B_scan - 013.png',
 'disease1B_scan - 008.png',
 'disease1B_scan - 018.png',
 'disease1B_scan - 011.png']

In [67]:
def process_image(path_to_images, x, group, ref_group, n_patches):
    PATCH_SIZE = 256
    f = int(np.sqrt(n_patches))
    IMG_SIZE = PATCH_SIZE * f
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    transform = transforms.Compose([
     transforms.ToTensor()
    ])
    ## import image
    img = Image.open(path_to_images+x).convert('RGB')
    ## convert img to tensor
    img = transform(img)
    ## resize image
    resized_img = resize(img)
    ## create patches
    patches = resized_img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
    dataset=[]
    ## reshape data
    for i in range(f):
        for j in range(f):
            sub_img = patches[:, i, j]
            if group == ref_group:
                data_target = (sub_img, 0)
            else:
                data_target = (sub_img, 1)
            dataset.append(data_target)
    return dataset

In [68]:
def create_datasets(dict_images={}, n_patches = 9, ref_group = 'normal', path_to_images=[]):
    images = {}
    for group in dict_images.keys():
        for sample in dict_images[group]:
            path = path_to_images+sample+"/"
            onlyfiles = [f for f in listdir(path) if isfile(join(path, f))]
            l = [process_image(path, x, group, ref_group, n_patches) for x in onlyfiles]
            flat_list = [item for sublist in l for item in sublist]
            images[group] = flat_list
    return images

class Custom_Dataset(torch.utils.data.dataset.Dataset):
    def __init__(self, _dataset):
        self.dataset = _dataset

    def __getitem__(self, index):
        example, target = self.dataset[index]
        return example, target

    def __len__(self):
        return len(self.dataset)

In [69]:
train = {'dkd': ['disease1B_scan', 'disease2B_scan'],
         'normal': ['normal2B_scan']}
valid = {'dkd': ['disease3_scan'],
         'normal': ['normal3_scan']}
test = {'dkd': ['disease4_scan'],
        'normal': ['normal4_scan']}

In [None]:
path_to_images = "/scratch/st-singha53-1/datasets/geomx/dkd/geomx_pngs/"

print("train loader")
train_datasets = create_datasets(dict_images = train, n_patches = 9, ref_group = 'normal', path_to_images=path_to_images)
train_loader = torch.utils.data.DataLoader(dataset=Custom_Dataset(train_datasets['dkd'] + train_datasets['normal']),
                                           batch_size=1,
                                           shuffle=False)
print("validation loader")
valid_datasets = create_datasets(dict_images = valid, n_patches = 9, ref_group = 'normal', path_to_images=path_to_images)
valid_loader = torch.utils.data.DataLoader(dataset=Custom_Dataset(valid_datasets['dkd'] + valid_datasets['normal']),
                                           batch_size=1,
                                           shuffle=False)
print("test loader")
test_datasets = create_datasets(dict_images = test, n_patches = 9, ref_group = 'normal', path_to_images=path_to_images)
test_loader = torch.utils.data.DataLoader(dataset=Custom_Dataset(test_datasets['dkd'] + test_datasets['normal']),
                                           batch_size=1,
                                           shuffle=False)

train loader
validation loader


In [None]:
from torch import nn, optim
from torchvision import models

# Hyperparameters
in_channel = 3
num_classes = 2
learning_rate = 1e-3
num_epochs = 10

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
torch.manual_seed(0)
np.random.seed(0)

net = models.alexnet()

net.load_state_dict(torch.load('../alexnet-owt-4df8aa71.pth'))
net

In [None]:
net.classifier[6] = nn.Linear(4096, 2)
net

In [None]:
for param in net.parameters():
    param.requires_grad = False
net.classifier[6].weight.requires_grad = True
net.classifier[6].bias.requires_grad = True
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
net

In [None]:
for param in net.parameters():
    print(param.requires_grad)

In [None]:
def accuracy(out, labels):
    _,pred = torch.max(out, dim=1)
    return torch.sum(pred==labels).item()

net.to(device)
net.train()

n_epochs = 50
print_every = 10
valid_loss_min = np.Inf
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(train_loader)
for epoch in range(1, n_epochs+1):
    running_loss = 0.0
    correct = 0
    total=0
    print(f'Epoch {epoch}\n')
    for batch_idx, (data_, target_) in enumerate(train_loader):
        data_, target_ = data_.to(device), target_.to(device)
        optimizer.zero_grad()
        
        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _,pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred==target_).item()
        total += target_.size(0)
        if (batch_idx) % 20 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss/total_step)
    print(f'\\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct/total):.4f}')
    batch_loss = 0
    total_t=0
    correct_t=0
    with torch.no_grad():
        net.eval()
        for data_t, target_t in (valid_loader):
            data_t, target_t = data_t.to(device), target_t.to(device)
            outputs_t = net(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _,pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t==target_t).item()
            total_t += target_t.size(0)
        val_acc.append(100 * correct_t/total_t)
        val_loss.append(batch_loss/len(valid_loader))
        network_learned = batch_loss < valid_loss_min
        print(f'validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t/total_t):.4f}\\n')

        
        if network_learned:
            valid_loss_min = batch_loss
            torch.save(net.state_dict(), 'model.pt')
            print('Improvement-Detected, save-model')
    net.train()

In [None]:
fig = plt.figure(figsize=(20,10))
plt.title("Train-Validation Accuracy")
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='validation')
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('accuracy', fontsize=12)
plt.legend(loc='best')
plt.ylim(0,110)
plt.axhline(y=50, color='gray', linestyle='--')

In [None]:
total_t=0
correct_t=0
with torch.no_grad():
    net.eval()
    for data_t, target_t in (test_loader):
        data_t, target_t = data_t.to(device), target_t.to(device)
        outputs_t = net(data_t)
        loss_t = criterion(outputs_t, target_t)
        batch_loss += loss_t.item()
        _,pred_t = torch.max(outputs_t, dim=1)
        correct_t += torch.sum(pred_t==target_t).item()
        total_t += target_t.size(0)
        print(pred_t.item(), target_t.item())
        
100 * correct_t/total_t