In [1]:
import sys
sys.path.append('../')
from torchvision.utils import make_grid
from torch.optim import Adam
from torch.autograd import Variable
import torch.cuda as cuda
import torch.nn as nn
import torch
from torch.distributions import Normal
from mlp.datasets import QSM_slices
import os
import numpy as np
import util
import sklearn.preprocessing as skp
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Get case IDs
case_list = open('/home/ali/RadDBS-QSM/data/docs/cases_90','r')
lines = case_list.read()
lists = np.loadtxt(case_list.name,comments="#", delimiter=",",unpack=False,dtype=str)
case_id = []
for lines in lists:     
    case_id.append(lines[-9:-7])

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528_wldd.csv'
motor_df = util.filter_scores(file_dir,'pre-dbs updrs','stim','pre op levadopa equivalent dose (mg)','CORNELL ID')
# Find cases with all required scores
subs,pre_imp,post_imp,pre_updrs_off,ledd = util.get_full_cases(motor_df,
                                                          'CORNELL ID',
                                                          'OFF (pre-dbs updrs)',
                                                          'ON (pre-dbs updrs)',
                                                          'OFF meds ON stim 6mo',
                                                          'pre op levadopa equivalent dose (mg)')
print(ledd)
# Load extracted features
npy_dir = '/home/ali/RadDBS-QSM/data/npy/'
phi_dir = '/home/ali/RadDBS-QSM/data/phi/phi/'
roi_path = '/data/Ali/atlas/mcgill_pd_atlas/PD25-subcortical-labels.csv'
n_rois = 6
Phi_all, X_all, R_all, K_all, ID_all = util.load_featstruct(phi_dir,npy_dir+'X/',npy_dir+'R/',npy_dir+'K/',n_rois,1595,False)
ids = np.asarray(ID_all).astype(int)

# Find overlap between scored subjects and feature extraction cases
c_cases = np.intersect1d(np.asarray(case_id).astype(int),np.asarray(subs).astype(int))
# Complete case indices with respect to feature matrix
c_cases_idx = np.in1d(ids,c_cases)
X_all_c = X_all[c_cases_idx,0:4,:]
K_all_c = K_all[c_cases_idx,0:4,:]
R_all_c = R_all[c_cases_idx,0:4,:]
# Re-index the scored subjects with respect to complete cases
s_cases_idx = np.in1d(subs,ids[c_cases_idx])
subs_init = subs[s_cases_idx]
pre_imp_init = pre_imp[s_cases_idx]
post_imp_init = post_imp[s_cases_idx]
pre_updrs_off_init = pre_updrs_off[s_cases_idx]
ledd_init = ledd[s_cases_idx]
per_change_init = post_imp_init
subs = np.asarray(ID_all,dtype=float)[np.in1d(np.asarray(ID_all,dtype=float),subs_init)]
subs0 = subs_init
pre_imp = np.zeros((1,len(subs))).T
post_imp = np.zeros((1,len(subs))).T
pre_updrs_off = np.zeros((1,len(subs))).T
ledd = np.zeros((1,len(subs))).T
per_change = np.zeros((1,len(subs))).T
for j in np.arange(len(subs)):
    pre_imp[j] = pre_imp_init[subs_init == subs[j]]
    post_imp[j] = post_imp_init[subs_init == subs[j]]
    pre_updrs_off[j] = pre_updrs_off_init[subs_init == subs[j]]
    ledd[j] = ledd_init[subs_init == subs[j]]
    per_change[j] = per_change_init[subs_init == subs[j]]

subsc = subs
X_all_c = X_all_c.reshape(X_all_c.shape[0],-1)
X_all_c = np.append(X_all_c,pre_updrs_off,axis=1)
X_all_c = np.append(X_all_c,ledd,axis=1)
print(np.unique(R_all_c))

       Unnamed: 0                PRE-OP           Unnamed: 2  \
0      CORNELL ID  Apathy Off (pre-dbs)  Apathy ON (pre-dbs)   
1              67                    na                   13   
2   only Ct data                     na                   na   
3              74                    na                   na   
4              84                    na                   22   
..            ...                   ...                  ...   
87             52                   NaN                  NaN   
88             53                   NaN                  NaN   
89             54                   NaN                  NaN   
90             55                   NaN                  NaN   
91             56                   NaN                  NaN   

             Unnamed: 3          Unnamed: 4    Unnamed: 5  \
0   OFF (pre-dbs updrs)  ON (pre-dbs updrs)  mri (pre-op)   
1                    60                  41      3/9/2020   
2                    43                  12     

In [22]:
data_dir = os.listdir('../mlp/tensor_slices_0')
scaler = skp.StandardScaler()
X = scaler.fit_transform(X_all_c)
train_dataset = QSM_slices(data_dir=data_dir,aug_state=1, factor=0, X=X,subsc=subsc,targets=per_change,prefix='../mlp/')
data_loader = DataLoader(train_dataset, batch_size=X.shape[0], shuffle=True)
img_size = 64 # Image size
batch_size = 45  # Batch size

# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
z_size = 128
generator_layer_size = [256, 512, 1024]
discriminator_layer_size = [1024, 512, 256]
class_num = 0
N = len(np.unique(np.round(per_change,1)))
# Training
epochs = 30
learning_rate = 1e-3

In [23]:
class Generator(nn.Module):
    def __init__(self, generator_layer_size, z_size, img_size, class_num):
        super().__init__()
        
        self.z_size = z_size
        self.img_size = img_size
        
        self.label_emb = nn.Embedding(class_num, class_num)
        
        self.model = nn.Sequential(
            nn.Linear(self.z_size + class_num, generator_layer_size[0]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(generator_layer_size[0], generator_layer_size[1]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(generator_layer_size[1], generator_layer_size[2]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(generator_layer_size[2], self.img_size * self.img_size),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        
        # Reshape z
        z = z.view(-1, self.z_size)
        # One-hot vector to embedding vector
        c = labels.reshape(-1,1)
        # Concat image & label
        x = z
        #x = torch.cat([z, c], 1)
        # Generator out
        out = self.model(x)
        
        return out.view(-1, self.img_size, self.img_size)

In [24]:
class Discriminator(nn.Module):
    def __init__(self, discriminator_layer_size, img_size, class_num):
        super().__init__()
        
        self.label_emb = nn.Embedding(class_num, class_num)
        self.img_size = img_size
        
        self.model = nn.Sequential(
            nn.Linear(self.img_size * self.img_size + class_num, discriminator_layer_size[0]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[0], discriminator_layer_size[1]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[1], discriminator_layer_size[2]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[2], 1),
            nn.Sigmoid()
        )
    
    def forward(self, x, labels):
        
        # Reshape fake image
        x = x.view(-1, self.img_size * self.img_size)
        # One-hot vector to embedding vector
        c = labels.reshape(-1,1)

        # Concat image & label
        #x = torch.cat([x, c], 1)

        # Discriminator out
        out = self.model(x)
        
        return out.squeeze()

In [25]:
# Define generator
generator = Generator(generator_layer_size, z_size, img_size, class_num).to(device)
# Define discriminator
discriminator = Discriminator(discriminator_layer_size, img_size, class_num).to(device)
# Loss function
criterion = nn.BCELoss()

In [26]:
# Optimizer
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

In [27]:
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    
    # Init gradient
    g_optimizer.zero_grad()
    
    # Building z
    z = Variable(torch.randn(batch_size, z_size)).to(device)
    
    # Building fake labels
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, class_num, batch_size))).to(device)
    
    # Generating fake images
    fake_images = generator(z, fake_labels)
    
    # Disciminating fake images
    validity = discriminator(fake_images, fake_labels)
    
    # Calculating discrimination loss (fake images)
    g_loss = criterion(validity, Variable(torch.ones(batch_size)).to(device))
    
    # Backword propagation
    g_loss.backward()
    
    #  Optimizing generator
    g_optimizer.step()
    
    return g_loss.data

In [28]:
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    
    # Init gradient 
    d_optimizer.zero_grad()

    # Disciminating real images
    real_validity = discriminator(real_images, labels)
    
    # Calculating discrimination loss (real images)
    real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).to(device))
    
    # Building z
    z = Variable(torch.randn(batch_size, z_size)).to(device)
    
    # Building fake labels
    fake_labels = Variable(torch.LongTensor(np.random.randint(0, class_num, batch_size))).to(device)
    
    # Generating fake images
    fake_images = generator(z, fake_labels)
    
    # Disciminating fake images
    fake_validity = discriminator(fake_images, fake_labels)
    
    # Calculating discrimination loss (fake images)
    fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).to(device))
    
    # Sum two losses
    d_loss = real_loss + fake_loss
    
    # Backword propagation
    d_loss.backward()
    
    # Optimizing discriminator
    d_optimizer.step()
    
    return d_loss.data

In [29]:

for epoch in range(epochs):
    
    print('Starting epoch {}...'.format(epoch+1))
    
    for i, (images, X, labels) in enumerate(data_loader):
        
        # Train data
        real_images = Variable(images).to(device)
        labels = torch.zeros_like(labels)
        labels = Variable(labels).to(device)
        
        # Set generator train
        generator.train()
        
        # Train discriminator
        d_loss = discriminator_train_step(len(real_images), discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        
        # Train generator
        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)
    
    # Set generator eval
    generator.eval()
    
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    
    # Building z 
    z = Variable(torch.randn(len(data_loader.dataset), z_size)).to(device)

    # Labels
    labels = Variable(torch.Tensor((np.round(per_change.reshape(-1,1),1)))).to(device)
    labels = torch.zeros_like(labels)
    # Generating images
    print('Passing sample images to generator:')
    sample_images = generator(z, labels).unsqueeze(1).data.cpu()


# Show images
fig, ax = plt.subplots(figsize=(20, 20))
grid = make_grid(torch.vstack((250*sample_images,real_images.cpu())), nrow=9, normalize=True).permute(1,2,0).numpy()
ax.imshow(np.rot90(grid))
fig.show()


Starting epoch 1...


ValueError: high <= 0