In [1]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
import random
import os
import re

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything

from model import MRI_LSTM
from data_loader import DataGenerator

import sklearn.metrics as sm
import scipy.stats as ss


from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objs as go
import matplotlib.pyplot as plt
# Initialise plotly's offline notebook mode so the iplot() call in the results
# section can render its figure inline.
init_notebook_mode(connected=True)

In [2]:
# Pick the checkpoint with the lowest validation loss among those saved before
# epoch `limit`. Checkpoint filenames are expected to look like
# "weights.epoch=<E>-val_loss=<L>.h5.ckpt"; other files in the folder are skipped.
limit = 80  # limit epochs: only checkpoints with epoch < limit are eligible
save_loss = float("inf")  # lowest validation loss seen so far (no magic sentinel)
save_state = ""  # filename of the best checkpoint found

ckpt_pattern = re.compile(r"(weights\.epoch=(?P<epoch>.*)-val_loss=(?P<loss>.*)\.h5\.ckpt)")

# sorted() makes tie-breaking between equal losses deterministic; os.listdir
# order is arbitrary.
for sfile in sorted(os.listdir("./ckpt_final/")):
    matched = ckpt_pattern.match(sfile)
    if matched is None:
        # Stray entries (hidden files, .ipynb_checkpoints, ...) are not checkpoints.
        continue
    epoch = int(matched["epoch"])
    valid_loss = float(matched["loss"])

    if epoch < limit and valid_loss < save_loss:
        save_loss = valid_loss
        save_state = sfile

if not save_state:
    # Fail here with a clear message instead of later inside torch.load().
    raise FileNotFoundError(
        "No checkpoint with epoch < {} found in ./ckpt_final/".format(limit))

In [3]:
# Report the checkpoint chosen by the search cell and its validation loss.
print(f"Best save_state: {save_state}")
print(f"Lowest valid_loss: {save_loss}")

Best save_state: weights.epoch=33-val_loss=13.07878104.h5.ckpt
Lowest valid_loss: 13.07878104


In [2]:
def Encoder_Block(in_channels, out_channels):
    """Build one Conv3x3 -> InstanceNorm -> MaxPool(2) -> ReLU encoder stage.

    The padded 3x3 convolution keeps the spatial size, and the 2x2 / stride-2
    max-pool then halves it. An (N, in_channels, H, W) input therefore becomes
    (N, out_channels, H // 2, W // 2).

    The layer order must not change: the Sequential indices appear in the
    state_dict keys of saved checkpoints.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

In [3]:
class MRI_LSTM(nn.Module):
    """Slice-by-slice CNN encoder feeding an LSTM, with a scalar regression head.

    ``forward`` walks dim 2 of a 5-D input ``x`` of shape (B, C, H, W, D):
    - each slice ``x[:, :, i]`` is encoded by ``encoder`` + ``post_proc`` into
      a ``feat_embed_dim``-vector;
    - the vectors are fed to ``lstm`` one time step at a time;
    - the final hidden state is mapped to one value per sample.

    NOTE(review): this redefinition shadows the ``MRI_LSTM`` imported from
    ``model`` at the top of the notebook. Submodule names and order must match
    the checkpoint's ``state_dict`` keys, so do not rename them.
    """

    def __init__(self):
        super().__init__()

        self.feat_embed_dim = 2  # size of each per-slice feature vector
        self.latent_dim = 128    # LSTM hidden size

        # Encoder: five Conv/InstanceNorm/MaxPool/ReLU stages. Each stage
        # halves both spatial dims of a slice.
        self.encoder = nn.Sequential(
            Encoder_Block(1, 32),
            Encoder_Block(32, 64),
            Encoder_Block(64, 128),
            Encoder_Block(128, 256),
            Encoder_Block(256, 256),
        )

        # Post processing: 1x1 conv to 64 channels, average-pool, then project
        # to feat_embed_dim. The pooled map must be 1x1 for the view in encode().
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.AvgPool2d([3, 2]),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, self.feat_embed_dim, 1),
        )

        # LSTM over the sequence of slice features.
        self.n_layers = 1
        self.lstm = nn.LSTM(
            self.feat_embed_dim, self.latent_dim, self.n_layers, batch_first=True)

        # Regressor head: latent_dim -> 64 -> ReLU -> 1.
        self.lstm_post = nn.Linear(self.latent_dim, 64)
        self.regressor = nn.Sequential(
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def init_hidden(self, batch):
        """Return zero initial LSTM states ``(h_0, c_0)``.

        Args:
            batch: tensor whose first dimension is the batch size.

        Returns:
            Two zero tensors of shape (n_layers, batch.size(0), latent_dim) on
            the device and with the dtype of the model's parameters.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch.size(0), self.latent_dim)
        # Plain (non-grad) zeros. The states are recreated for every batch, so
        # marking them requires_grad=True only cost autograd memory and work.
        return weight.new_zeros(shape), weight.new_zeros(shape)

    def encode(self, x, h_t, c_t):
        """Run the slice encoder and the LSTM along dim 2 of ``x``.

        Args:
            x: 5-D tensor (B, C, H, W, D); the slices ``x[:, :, i]`` are the time steps.
            h_t, c_t: initial LSTM states holding n_layers * B * latent_dim
                elements each, e.g. from ``init_hidden``.

        Returns:
            (B, latent_dim) hidden state of the last LSTM layer after the final slice.
        """
        B, _, H, _, _ = x.size()
        h_t = h_t.reshape(self.n_layers, B, self.latent_dim)
        c_t = c_t.reshape(self.n_layers, B, self.latent_dim)
        for i in range(H):
            feat = self.post_proc(self.encoder(x[:, :, i, :, :]))
            feat = feat.view(B, 1, self.feat_embed_dim)  # a single time step
            # Carry the returned (h_n, c_n) into the next step. For a single
            # time step, h_n of the last layer equals the LSTM output.
            _, (h_t, c_t) = self.lstm(feat, (h_t, c_t))
        return h_t[-1]

    def forward(self, batch, h_0, c_0):
        """Predict one value per sample.

        Args:
            batch: ``(x, y_true)`` pair. ``y_true`` is not used by the forward pass.
            h_0, c_0: initial LSTM states, e.g. from ``init_hidden``.

        Returns:
            ``(y_pred, embedding)`` with shapes (B, 1) and (B, latent_dim).
        """
        x, _ = batch
        param = next(self.parameters())
        # Follow the model's device and dtype instead of hard-coding .cuda().
        # This works on CPU as well and accepts non-float32 inputs.
        x = x.to(device=param.device, dtype=param.dtype)
        embedding = self.encode(x, h_0, c_0)
        y_pred = self.regressor(self.lstm_post(embedding))
        return y_pred, embedding

In [4]:
# Expose only one physical GPU to this process, since other GPUs may be in use
# by someone else. NOTE(review): CUDA only honours CUDA_VISIBLE_DEVICES if it is
# set before CUDA is first initialised in this kernel.
DEVICE = 5
os.environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE}"

# DataLoader settings.
BATCH_SIZE = 16
NUM_WORKERS = 5

# Checkpoint folder and test-set CSV.
LOAD_PATH = "./ckpt_final/"
TEST_PATH = "../sipam_test.csv"

In [5]:
# Test-set loader. shuffle=False keeps the dataset order, so collected results
# are deterministic and index-aligned with test_scans.
test_scans = DataGenerator(TEST_PATH, data_col="9dof_2mm_vol")
test_loader = DataLoader(
    dataset=test_scans,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [6]:
# Load the checkpoint chosen by the search cell.
if not save_state:
    raise FileNotFoundError("No checkpoint selected; save_state is empty.")
# Use the configured LOAD_PATH rather than repeating the literal folder, and
# join the path properly instead of concatenating strings.
model_path = os.path.join(LOAD_PATH, save_state)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(model_path, map_location="cuda:0")

In [7]:
# The loaded checkpoint keeps the model weights under "state_dict". Unwrap only
# while the dict is still wrapped, so re-running this cell does not raise KeyError.
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

In [8]:
# Rebuild the network on GPU 0 (the single device left visible by
# CUDA_VISIBLE_DEVICES) and load the selected weights with strict key matching.
# Then switch to eval mode so Dropout is disabled; the cell shows the module repr.
model = MRI_LSTM()
model = model.to(0)
model.load_state_dict(state_dict)
model.eval()

MRI_LSTM(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(128, 256, ker

In [9]:
# Run the whole test set through the model. For each batch, collect the
# targets, the predictions and the LSTM embeddings.
results = {
    "true_ages": [],
    "pred_ages": [],
    "embeds": [],
}

print("Testing started!")
with torch.no_grad():  # inference only: do not build autograd graphs
    for batch_idx, (X_in, y) in enumerate(test_loader):
        X_in = X_in.to(device=0, dtype=torch.float)

        h_0, c_0 = model.init_hidden(X_in)
        # Pass the converted input, not the raw batch, so the cast above
        # actually takes effect.
        y_pred, embedding = model((X_in, y), h_0, c_0)

        results["true_ages"].append(y.cpu().numpy())
        results["pred_ages"].append(y_pred.cpu().numpy())
        results["embeds"].append(embedding.cpu().numpy())

        print("Done processing {}/{} batches".format(batch_idx + 1, len(test_loader)))

Testing started!
Done processing 1/138 batches
Done processing 2/138 batches
Done processing 3/138 batches
Done processing 4/138 batches
Done processing 5/138 batches
Done processing 6/138 batches
Done processing 7/138 batches
Done processing 8/138 batches
Done processing 9/138 batches
Done processing 10/138 batches
Done processing 11/138 batches
Done processing 12/138 batches
Done processing 13/138 batches
Done processing 14/138 batches
Done processing 15/138 batches
Done processing 16/138 batches
Done processing 17/138 batches
Done processing 18/138 batches
Done processing 19/138 batches
Done processing 20/138 batches
Done processing 21/138 batches
Done processing 22/138 batches
Done processing 23/138 batches
Done processing 24/138 batches
Done processing 25/138 batches
Done processing 26/138 batches
Done processing 27/138 batches
Done processing 28/138 batches
Done processing 29/138 batches
Done processing 30/138 batches
Done processing 31/138 batches
Done processing 32/138 batches


In [10]:
# Flatten the per-batch arrays into 1-D vectors. The isinstance guard makes
# the cell idempotent: a re-run leaves the already-flattened arrays unchanged
# instead of failing inside np.concatenate.
for key in ("true_ages", "pred_ages"):
    if isinstance(results[key], list):
        results[key] = np.concatenate(results[key]).ravel()

In [11]:
# Compute test-set metrics. Each value is wrapped in a one-element list.
true_ages, pred_ages = results["true_ages"], results["pred_ages"]
stats = {
    "rmse": [sm.mean_squared_error(true_ages, pred_ages) ** 0.5],
    "mae": [sm.mean_absolute_error(true_ages, pred_ages)],
    "corr": [ss.pearsonr(true_ages, pred_ages)[0]],
    # Pearson correlation between the true value and the residual (pred - true).
    "rcorr": [ss.pearsonr(true_ages, pred_ages - true_ages)[0]],
}

# Print stats
for name, values in stats.items():
    print(f"{name}: {values[0]}")

rmse: 3.619030476770937
mae: 2.860799789428711
corr: 0.8764812511107893
rcorr: -0.33386457905011546


In [13]:
# Scatter of predicted vs. true age, one marker per test sample.
# NOTE(review): the trace name "Embeds" is a misnomer (the points are
# predictions); it is only visible if the legend is shown.
scatter = go.Scatter(
    x=results["true_ages"],
    y=results["pred_ages"],
    mode="markers",
    name="Embeds",
)
layout = dict(
    title="Sagittal-Slice LSTM Results",
    xaxis=dict(title="True Ages"),
    yaxis=dict(title="Predicted Ages"),
)
fig = dict(data=[scatter], layout=layout)
iplot(fig, filename="Distributions")