In [1]:
#imports 

#PyTorch
import torch
from torch.nn import functional as F
from torch import nn

#general
from glob import glob
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

#imaging
from monai import transforms
import nibabel

#custom
from model128 import *
from dataset_visualize import *

In [2]:
# Per-volume pipeline applied to each normalised numpy array:
# add a channel dimension, then convert the array to a torch tensor.
# NOTE(review): AddChannel is deprecated in newer MONAI releases (suggested
# replacement: EnsureChannelFirst(channel_dim="no_channel"), which returns a
# MetaTensor) — check before upgrading MONAI.
transforms_monai = transforms.Compose([
    transforms.AddChannel(),
    transforms.ToTensor(),
])

In [3]:
class aedataset_T2(torch.utils.data.Dataset):
    """Dataset of T2 volumes, z-score normalised over their non-zero voxels.

    Each item is ``(image, path)``. ``image`` is the output of ``transforms``
    cast to ``torch.float``. ``path`` is the file the volume was loaded from;
    it is returned so extracted features can be matched back to their scan.
    """

    def __init__(self, datafile, transforms):
        """
        Args:
            datafile: indexable sequence of paths to image files readable by
                ``nibabel.load`` (the original docstring calls them
                "unbiased brain" files).
            transforms: callable applied to the normalised numpy array. It must
                return a ``torch.Tensor``, because ``.type`` is called on the
                result.
        """
        self.unbiased_brain = datafile
        self.transforms = transforms

    def __len__(self):
        """Number of volumes, i.e. one item per path in ``datafile``."""
        return len(self.unbiased_brain)

    def __getitem__(self, idxx: int):
        """Load, normalise and transform the volume at position ``idxx``.

        Normalisation is ``(img - mean) / std``. The mean and std are computed
        over the non-zero voxels only, but the result is applied to every
        voxel, so background zeros become ``-mean / std``.

        Returns:
            tuple: ``(image, path)``, where ``image`` is a float32 tensor.

        Raises:
            ValueError: if the volume has no non-zero voxels, or the std of its
                non-zero voxels is zero or non-finite. Otherwise the z-score
                would be NaN/inf and would silently corrupt downstream features.
        """
        path = self.unbiased_brain[idxx]
        img = nib.load(path).get_fdata()

        # Foreground mask and statistics are computed once and reused.
        foreground = img[img != 0]
        if foreground.size == 0:
            raise ValueError(
                f"Cannot normalise {path!r}: volume has no non-zero voxels"
            )
        mean = foreground.mean()
        std = foreground.std()
        if not np.isfinite(std) or std == 0:
            raise ValueError(
                f"Cannot normalise {path!r}: std of non-zero voxels is {std}"
            )
        img = (img - mean) / std

        img = self.transforms(img)
        img = img.type(torch.float)
        return img, path



In [4]:
# Input file list for the T2 dataset. glob() returns paths in an arbitrary,
# filesystem-dependent order, so sort them to make dataset / feature order
# reproducible across runs and machines.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
T2 = sorted(glob("/T2/*"))

In [6]:
#T2 ckpt
# Checkpoint whose "state_dict" is loaded into the T2 model; a relative path,
# resolved against the kernel's current working directory.
T2_ckpt: str = "T2.ckpt"

In [7]:
# Inference device. Keep this in agreement with the model-loading cell, which
# sets device = torch.device("cuda:1") before the model and data are moved;
# the previous value here ("cuda:6") was overwritten before ever being used.
device = torch.device("cuda:1")

In [9]:
# Wrap the T2 file list; each item is (normalised volume tensor, source path).
T2_ds = aedataset_T2(T2, transforms_monai)

In [10]:
# Inference loader. shuffle=False keeps batches in the order of T2_ds (the
# file list), so extracted features line up with the collected names.
T2_dataloader = torch.utils.data.DataLoader(
    T2_ds,
    batch_size=16,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

In [11]:
# Build the T2 model on the GPU and restore its trained weights.
device = torch.device("cuda:1")
# NOTE(review): 0.001 is presumably the learning rate — confirm in model128.engine_AE.
model_T2 = engine_AE(0.001).to(device)

# map_location places the checkpoint tensors directly on `device`.
checkpoint = torch.load(T2_ckpt, map_location=device)
model_T2.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [12]:
# Run every volume through the model and keep its second output (the
# bottleneck), together with the source path, in DataLoader order.
bottle_neck_T2 = []
img_names_T2 = []
model_T2 = model_T2.eval()  # eval mode: dropout / batch-norm layers (if any) behave as at inference
with torch.no_grad():  # feature extraction only; no autograd graph needed
    for batch_img, batch_names in tqdm(T2_dataloader, total=len(T2_dataloader)):
        _recon, latent = model_T2(batch_img.to(device))
        # extend() over the numpy batch adds one array per row
        # (assumes dim 0 of the model output is the batch dimension).
        bottle_neck_T2.extend(latent.cpu().numpy())
        img_names_T2.extend(batch_names)

100%|██████████| 2703/2703 [1:03:11<00:00,  1.40s/it]


In [15]:
import pickle

In [16]:
# Persist the per-volume bottleneck features (a list of numpy arrays, in the
# same order as img_names_T2). Pickler(f).dump(obj) is equivalent to pickle.dump(obj, f).
with open('T2_128.pkl', 'wb') as f:
    pickle.Pickler(f).dump(bottle_neck_T2)

In [17]:
# Persist the source path of each feature vector, index-aligned with T2_128.pkl.
# NOTE(review): this file has no ".pkl" extension, unlike T2_128.pkl; the name is
# kept as-is so existing readers still find it.
with open("T2_img_names", "wb") as f:
    pickle.Pickler(f).dump(img_names_T2)