In [1]:
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from PIL import Image
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
from torch.utils.data import Dataset, DataLoader
from torchvision import utils, transforms
from glob import glob
import random
import os
from siamese_script import resnet101, normalize_data
import pickle
# import horovod.torch as hvd

In [2]:
# Use CUDA only when it is actually available, so the notebook also runs
# (more slowly) on a CPU-only machine instead of crashing at `.cuda()`.
use_gpu = torch.cuda.is_available()

In [3]:
# Build the network and restore the trained siamese weights for inference.
CHECKPOINT_PATH = '/nethome/asilva9/brains/model_checkpoints/siamese_final.pth.tar'

model = resnet101()
# Inference only: eval() switches train-time layers (e.g. BatchNorm/Dropout,
# if the network has them) to their deterministic inference behaviour.
model.eval()
if use_gpu:
    model = model.cuda()
    checkpoint = torch.load(CHECKPOINT_PATH)
else:
    # Map every tensor in the checkpoint onto the CPU.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

model.load_state_dict(checkpoint['model_state_dict'])

sample duration 218
last duration 12
Last size 6


<All keys matched successfully>

In [4]:
class EmbedNIFTIDataset(Dataset):
    """
    Create a dataset class in PyTorch for reading NIfTI volumes.

    The volume file list comes from ``data_dir`` if given, otherwise from
    ``data_path``.

    Args:
        data_path (str, optional): path to a comma-delimited text file listing
            NIfTI file paths (one per line, or several on a line separated by
            commas). Used only when ``data_dir`` is not given.
        transform (Callable, optional): transform applied to each volume
            (probably None or ``transforms.ToTensor()``). When given, its
            result gets a leading dimension via ``unsqueeze(0)`` -- assumes the
            transform returns a torch tensor.
        data_dir (str, optional): directory scanned for ``*.nii*`` files
            (e.g. ``.nii`` and ``.nii.gz``); takes precedence over ``data_path``.
        label (optional): when not None, ``__getitem__`` returns
            ``(volume, label)``; when None it returns just the volume.

    Raises:
        ValueError: if neither ``data_path`` nor ``data_dir`` is given.
    """

    def __init__(self, data_path=None, transform=None, data_dir=None, label=None):
        if data_dir is not None:
            # Sorted so the sample order (and thus CSV row order) is reproducible.
            data_fns = sorted(glob(os.path.join(data_dir, '*.nii*')))
        elif data_path is not None:
            # dtype=str: the file holds paths; the default float dtype cannot parse them.
            # ndmin=1 keeps a single-entry file as a 1-element array.
            listed = np.loadtxt(data_path, delimiter=',', dtype=str, ndmin=1)
            data_fns = [fn.strip() for fn in listed.reshape(-1) if fn.strip()]
        else:
            raise ValueError('either data_path or data_dir must be given')
        self.data_fns = data_fns
        self.transform = transform
        self.label = label
        self.data_len = len(self.data_fns)

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        vol_in = self.data_fns[idx]
        main_vol = nib.load(vol_in).get_fdata(dtype=np.float32)
        # Swap the first two axes of the loaded voxel array.
        main_vol = np.moveaxis(main_vol, 1, 0)
        sample = main_vol
        if self.transform is not None:
            sample = self.transform(sample).unsqueeze(0)
        if self.label is None:
            return sample
        return sample, self.label

# Labelled datasets: volumes under .../good get label 1, under .../bad label 0,
# each volume converted to a tensor with ToTensor.
pos_dataset = EmbedNIFTIDataset(
    transform=transforms.ToTensor(),
    label=1,
    data_dir='/media/data/Track_2/good',
)
neg_dataset = EmbedNIFTIDataset(
    transform=transforms.ToTensor(),
    label=0,
    data_dir="/media/data/Track_2/bad",
)


In [5]:
# One volume per batch; DataLoader's default (no shuffling) keeps file order.
pos_loader, neg_loader = (
    torch.utils.data.DataLoader(ds, batch_size=1) for ds in (pos_dataset, neg_dataset)
)

In [6]:
# Compute one flattened embedding per volume (positives first, then negatives)
# and save embeddings and labels as two row-aligned CSV files.
model.eval()  # inference mode for layers such as BatchNorm/Dropout, if present
embeds = []
labels = []
# No autograd graph is needed for inference; this saves (GPU) memory per volume.
with torch.no_grad():
    for loader in (pos_loader, neg_loader):
        for img, label in loader:
            if use_gpu:
                img = img.cuda()
            img = normalize_data(img)
            output = model.get_embed(img)
            embeds.append(output.detach().cpu().numpy().reshape(-1))
            labels.append(label.item())

np.savetxt('embeds.csv', np.array(embeds), delimiter=',', newline='\n')
np.savetxt('labels.csv', np.array(labels), delimiter=',', newline='\n')

In [16]:
def generate_embeddings(data_path='', model_path='./siamese_final.pth.tar', data_name='all', use_gpu=True):
    """Compute one embedding per NIfTI volume listed in ``data_path`` and save them.

    Args:
        data_path (str): comma-delimited text file of NIfTI paths, as read by
            ``EmbedNIFTIDataset(data_path=...)``.
        model_path (str): checkpoint whose ``'model_state_dict'`` entry is
            loaded into a fresh ``resnet101()``.
        data_name (str): prefix of the output file ``<data_name>embeds.csv``
            (note: no separator is inserted).
        use_gpu (bool): run on CUDA; when False the checkpoint is loaded with
            ``map_location='cpu'``.

    Returns:
        np.ndarray: one flattened embedding per row, in dataset order.
    """
    # No label is passed, so each batch is just the image tensor.
    dataset = EmbedNIFTIDataset(data_path=data_path,
                                transform=transforms.ToTensor())
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    model = resnet101()
    if use_gpu:
        model = model.cuda()
        checkpoint = torch.load(model_path)
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference only: eval mode for layers such as BatchNorm/Dropout, if present.
    model.eval()

    embeds = []
    with torch.no_grad():  # no autograd graph needed; saves memory
        for img in data_loader:
            if use_gpu:
                img = img.cuda()
            img = normalize_data(img)
            output = model.get_embed(img)
            embeds.append(output.detach().cpu().numpy().reshape(-1))
    embeds = np.array(embeds)
    np.savetxt(data_name + 'embeds.csv', embeds, delimiter=',', newline='\n')
    return embeds

def generate_labels(data_dir='/media/data/Track_2', data_name='all'):
    """Score every NIfTI volume in ``data_dir`` and save the scores to CSV.

    Uses the notebook-level ``model`` and ``use_gpu``. For each volume, the
    first element of ``model([img])`` is soft-maxed over dim 1 and
    ``1 - p[0][0]`` is recorded -- presumably one minus the probability of
    class 0, assuming that output is shaped (batch, num_classes); TODO confirm.

    Args:
        data_dir (str): directory scanned by ``EmbedNIFTIDataset``.
        data_name (str): prefix of the output file ``<data_name>_pred_labels.csv``.
    """
    # label=0 is only a placeholder so batches are (img, label) pairs; it is ignored.
    dataset = EmbedNIFTIDataset(data_dir=data_dir, label=0, transform=transforms.ToTensor())
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # Inference only: eval mode for layers such as BatchNorm/Dropout, if present.
    model.eval()
    labels = []
    with torch.no_grad():  # no autograd graph needed; saves memory
        for img, _ in data_loader:
            if use_gpu:
                img = img.cuda()
            img = normalize_data(img)
            output = model([img])[0]
            output = F.softmax(output, dim=1)
            labels.append(1 - output[0][0].item())
    np.savetxt(data_name + '_pred_labels.csv', np.array(labels), delimiter=',', newline='\n')

In [17]:
generate_labels(data_name='good', data_dir='/media/data/Track_2/good')  # writes good_pred_labels.csv

In [18]:
generate_labels(data_name='bad', data_dir='/media/data/Track_2/bad')  # writes bad_pred_labels.csv