In [1]:
from morphocycle.models.coatnet import create_model

In [4]:


model = create_model(num_classes=4,
                    pretrained=True,
                    )

In [5]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [33]:
def calculate_loss(self, inputs, labels):
    logits = self.forward(inputs)
    loss = self.criterion(logits, labels.long())
    y_prob = torch.softmax(logits, dim=1)[:, 1]
    y_hat = torch.argmax(y_prob, dim=1)
    return loss, y_prob, y_hat, logits


In [34]:
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import tifffile as tfl
import torchvision.transforms as T



class CellCycleData(Dataset):
    """
    Dataset class for the single cell dataset
    """

    def __init__(
        self,
        img_dir=None,
    ):
        # Set all input args as attributes
        self.__dict__.update(locals())
        self.img_dir = Path(img_dir)

        # Get all the image files
        img_files = list(self.img_dir.glob("**/*/*/*.tif"))
        self.img_files = [i for i in img_files if ("NotKnown" not in str(i)) and
                          ("G1-S" not in str(i)) and
                          ("S-G2" not in str(i)) and
                          ("G2-M" not in str(i))]

        # Get all the labels
        labels = [str(x.parent.name) for x in self.img_files]
        self.label_dict = {"G1": 0,
                           # "G1-S": 10,
                           "S": 1,
                           # "S-G2": 30,
                           "G2": 2,
                           # "G2-M": 50,
                           "MorG1": 3}

        self.labels = [self.label_dict[x] for x in labels]
        self.track_ids = [str(x.parent.parent.name) for x in self.img_files]
        self.slide_ids = [str(x.parent.parent.parent.name) for x in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        img = tfl.imread(img_path)
        # v_min, v_max = img.min(), img.max()
        # new_min, new_max = 0, 255
        img = img/(img.max() + 1e-5)
        # if (img > 0).sum() > (img > -1).sum():
        label = torch.tensor(self.labels[idx])
        # else:
        transform = T.Compose([T.ToTensor(),
                               T.RandomHorizontalFlip(p=0.5),
                               T.RandomVerticalFlip(p=0.5),
                               T.RandomRotation(degrees=90),
                               T.RandomPerspective(distortion_scale=0.5, p=0.5),
                               T.Resize((64, 64))
                               ])

        img = transform(img)
        img = img.expand(3, *img.shape[1:]).type(torch.FloatTensor)

        track_id = self.track_ids[idx]
        slide_id = self.slide_ids[idx]


        return img, label, track_id, slide_id

In [35]:
dset = CellCycleData("/media/mvries/Derek_Jeeters/PCNA_cell_cycle_marker/data_analysis/CycleData")

In [36]:
model.load_from_checkpoint("/home/mvries/Documents/GitHub/MorphoCycle/"\
                           "logs/Coatnet_MorphoCycle/qwkc6lr8/checkpoints/epoch=18-step=336452.ckpt")

COATNet(
  (criterion): CrossEntropyLoss()
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (dow

In [37]:
model.cuda()
model.eval()

COATNet(
  (criterion): CrossEntropyLoss()
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (dow

In [38]:
from torch.utils.data import DataLoader

dload = DataLoader(dset, batch_size=1, shuffle=False)

In [39]:
from tqdm import tqdm

all_results = []

for i, d in tqdm(enumerate(dload)):
#     print(d[0].isnan().any())

    logits = model(d[0].cuda())
    y_prob = torch.softmax(logits, dim=1)
    y_hat = torch.argmax(y_prob, dim=1)
    all_results.append({
            "logits": logits.detach().cpu().numpy(),
            "Y_prob": y_prob.detach().cpu().numpy(),
            "Y_hat": y_hat.detach().cpu().numpy(),
            "label": d[1],
        })

143038it [37:35, 63.41it/s]


In [18]:
y_hats = [int(i['Y_hat']) for i in all_results]
y_true = [int(i['label']) for i in all_results]

In [22]:
from sklearn.metrics import accuracy_score, confusion_matrix

accuracy_score(y_true, y_hats)

0.13087086290001496

In [30]:
np.unique(y_true)

array([0, 1, 2, 3])

In [24]:
confusion_matrix(y_true, y_hats, normalize='true')

array([[0.68573368, 0.07305936, 0.12717512, 0.08361101, 0.03042083],
       [0.65293261, 0.07351913, 0.17649544, 0.07954943, 0.0175034 ],
       [0.64711085, 0.07670055, 0.17396281, 0.08118542, 0.02104036],
       [0.67039246, 0.0666741 , 0.12384324, 0.09404616, 0.04504404],
       [0.        , 0.        , 0.        , 0.        , 0.        ]])

In [46]:
tfl.imsave("test_input.tif", torch.squeeze(d[0]).numpy())

In [48]:
y_hat

tensor([1], device='cuda:0')

In [25]:
img = tfl.imread("/media/mvries/Derek_Jeeters/PCNA_cell_cycle_marker/data_analysis/C_6_added_as_no_00045/trackid_4/MorG1/timepoint_239.tif")

In [30]:
img = np.clip(img, 0, None)
img = img/(img.max() + 1e-5)

(img > 0).sum() 

1873

In [31]:
(0.5 * img.shape[0] * img.shape[1])

2177.5

In [12]:
max_probs = torch.stack([x["Y_hat"] for x in all_results])

In [20]:
max_probs

tensor([[6, 4, 3,  ..., 1, 3, 6],
        [6, 6, 6,  ..., 6, 6, 6],
        [6, 6, 6,  ..., 6, 6, 6],
        ...,
        [6, 6, 6,  ..., 6, 6, 6],
        [6, 6, 6,  ..., 6, 6, 6],
        [6, 6, 6,  ..., 6, 6, 6]])

In [74]:
y_hat = torch.argmax(y_prob, dim=1)

In [19]:
y_hat.shape

torch.Size([256])

In [31]:
torch.cat((torch.ones((256)), torch.ones((256)), torch.zeros((120))), dim=0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 

In [77]:
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

In [78]:
criterion(logits, d[1].double().long())

tensor(2.1008, grad_fn=<NllLossBackward0>)

In [81]:
logits.mean()

tensor(-0.0108, grad_fn=<MeanBackward0>)