In [None]:
#imports 

#PyTorch
import torch
from torch.nn import functional as F
from torch import nn

#general
import os
from glob import glob
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

#imaging
from monai import transforms
import nibabel

#custom
from model128 import *
from dataset_visualize import *

In [None]:
# Per-volume preprocessing applied after normalisation:
# add a leading channel axis, then convert the numpy array to a torch tensor.
# NOTE(review): AddChannel is deprecated in newer MONAI releases in favour of
# EnsureChannelFirst(channel_dim="no_channel") — confirm the installed version.
_monai_steps = [
    transforms.AddChannel(),
    transforms.ToTensor(),
]
transforms_monai = transforms.Compose(_monai_steps)

In [None]:
class aedataset_T1(torch.utils.data.Dataset):
    """Map-style dataset yielding normalised image volumes and their file paths.

    Each item is loaded with nibabel, z-scored using the mean/std of its
    non-zero voxels, passed through ``transforms`` and cast to ``torch.float``.
    """

    def __init__(self, datafile, transforms):
        """
        Provide list of files for unbiased brain.

        Args:
            datafile: sequence of paths to the image files; one item per path.
            transforms: callable applied to the normalised numpy array. It must
                return a ``torch.Tensor`` because ``__getitem__`` calls
                ``.type(torch.float)`` on the result.
        """
        self.unbiased_brain = datafile
        self.transforms = transforms

    def __len__(self):
        return len(self.unbiased_brain)

    @staticmethod
    def _zscore_nonzero(img):
        """Z-score ``img`` using statistics computed over its non-zero voxels.

        Zero voxels (presumably background — TODO confirm) are excluded from
        the statistics but are still shifted/scaled like every other voxel.
        Degenerate inputs are handled explicitly instead of yielding NaN/inf:
        an image with no non-zero voxels is returned unchanged, and a zero
        std skips the division (mean subtraction only).
        """
        foreground = img[img != 0]  # compute the mask once, not twice
        if foreground.size == 0:
            return img
        img = img - foreground.mean()
        std = foreground.std()
        if std > 0:
            img = img / std
        return img

    def __getitem__(self, idxx: int):
        """Return ``(volume_tensor, file_path)`` for item ``idxx``."""
        MRI_name = self.unbiased_brain[idxx]
        img = nib.load(MRI_name).get_fdata()
        img = self._zscore_nonzero(img)
        img = self.transforms(img)
        img = img.type(torch.float)
        return img, MRI_name


In [None]:
# All T1 volumes to embed. glob() returns matches in arbitrary (filesystem)
# order, so sort them to make the order of the saved features/names
# reproducible across runs and machines.
T1 = sorted(glob("/T1/*"))

In [None]:
#T1 ckpt
# Checkpoint file for the T1 model, resolved relative to the notebook's working
# directory; it is read with torch.load and its "state_dict" entry is loaded
# into model_T1.
T1_ckpt = "T1.ckpt"

In [None]:
# Prefer the originally targeted GPU (cuda:7). Fall back to CPU when CUDA is
# unavailable or fewer than CUDA_INDEX + 1 GPUs are visible, so the notebook
# still runs end-to-end on other machines.
CUDA_INDEX = 7
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_INDEX:
    device = torch.device(f"cuda:{CUDA_INDEX}")
else:
    device = torch.device("cpu")

In [None]:
# NOTE(review): 1e-3 is engine_AE's first positional argument (defined in
# model128; presumably a learning rate — confirm). Weights are restored from
# the checkpoint afterwards.
model_T1 = engine_AE(1e-3)

In [None]:
# Move parameters/buffers to the selected device before loading weights.
model_T1 = model_T1.to(device=device)

In [None]:
# Restore trained weights from the checkpoint's "state_dict" entry.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(T1_ckpt, map_location=device)
state_dict = checkpoint["state_dict"]
model_T1.load_state_dict(state_dict)

In [None]:
# Dataset over every globbed T1 file, preprocessed with the MONAI pipeline.
T1_ds = aedataset_T1(T1, transforms_monai)

In [None]:
# Cap loader workers at the CPUs actually available (32 oversubscribes smaller
# machines), and only pin host memory when batches are copied to a CUDA device.
NUM_WORKERS = min(32, os.cpu_count() or 1)
T1_dataloader = torch.utils.data.DataLoader(
    T1_ds,
    batch_size=16,
    pin_memory=device.type == "cuda",
    num_workers=NUM_WORKERS,
    shuffle=False,  # keep file order so saved outputs follow the T1 list
)

In [None]:
# Run the model over every volume and collect, per image, the second model
# output (`lin1`, used here as the bottleneck features) plus its file path.
bottle_neck = []
img_names = []
model_T1.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)
with torch.no_grad():
    for img, name in tqdm(T1_dataloader, total=len(T1_dataloader)):
        img = img.to(device)
        recon, lin1 = model_T1(img)
        # Copy to host memory per batch. Extending with device tensors would
        # keep every batch's output (its per-sample views share the batch
        # storage) resident on the GPU for the whole run.
        bottle_neck.extend(lin1.cpu())
        img_names.extend(name)

In [None]:
# Sanity check: exactly one feature vector and one file name per input volume.
assert len(bottle_neck) == len(img_names) == len(T1_ds), "feature/name count mismatch"
len(bottle_neck)

In [None]:
# Convert each per-image feature tensor to a host-side numpy array for pickling.
T1_128 = list(map(lambda vec: vec.cpu().numpy(), bottle_neck))

In [None]:
import pickle

In [None]:
# Persist the per-image feature arrays (same order as img_names).
with open('T1_128.pkl', 'wb') as out_fh:
    out_fh.write(pickle.dumps(T1_128))

In [None]:
# Persist the file path of each feature vector, aligned with T1_128.pkl.
# NOTE(review): this file has no ".pkl" extension; the name is kept as-is in
# case downstream code expects it.
with open("T1_img_names", "wb") as out_fh:
    out_fh.write(pickle.dumps(img_names))