In [1]:
%matplotlib inline
import os
from matplotlib import pyplot as plt
from datetime import datetime
import numpy as np

from datetime import datetime;now = datetime.now;t00=now();

import torch
import torch.nn as nn
import torch.optim as optim

import seaborn as sns

In [2]:
# Colab only: mount Google Drive and work from the project folder, so the
# relative paths used later ('./Data/...') resolve inside it.
from google.colab import drive
drive.mount("/content/drive")
print('done, mounted')
# os.chdir instead of a bare `cd` line: that relied on IPython automagic
# (it is not valid Python and breaks when automagic is off). Print the new
# working directory, as the `cd` magic did.
os.chdir('/content/drive/My Drive/BC-MRI-AE')
print(os.getcwd())
print(now())

Mounted at /content/drive
done, mounted
2020-11-03 18:41:47.991837


/content/drive/My Drive/BC-MRI-AE


In [None]:
# Preprocessed datasets available under ./Data:
#   ABIDE-64iso-normed  (loaded below as the training data)
#   SFARI-64iso-normed  (validation load, currently commented out)
print('loading data')
t0 = now()  # start of the timer reported as "loaded in ..." at the end of this cell
def load_data(data_dir):
  """Load every ``.npy`` file in ``data_dir`` and stack them into one array.

  Files are read in sorted filename order, so the sample order is
  reproducible between runs.

  Args:
    data_dir: directory holding one ``.npy`` file per sample,
      e.g. './Data/ABIDE-64iso-normed'.

  Returns:
    np.ndarray built from the loaded arrays; its first axis indexes the files.
  """
  # endswith, not a substring test: "'.npy' in name" also matched files such
  # as 'scan.npy.bak', which np.load then fails to read.
  files = sorted(f for f in os.listdir(data_dir) if f.endswith('.npy'))
  return np.array([np.load(os.path.join(data_dir, f)) for f in files])

# Optional held-out set (currently disabled):
#data_validation = load_data('./Data/SFARI-64iso-normed');print(data_validation.shape)
#data_validation = data_validation[:,np.newaxis,:,:,:]

data = load_data('./Data/ABIDE-64iso-normed')
print(data.shape)
# Expect one 3-D volume per file. Volumes of differing shapes cannot be stacked
# into a 4-D array, so fail here with a readable message rather than an
# IndexError (or silently wrong rank) further down.
assert data.ndim == 4, f"expected (n_subjects, X, Y, Z), got shape {data.shape} dtype {data.dtype}"
data = data[:, np.newaxis]  # add the channel axis for Conv3d: (N, 1, X, Y, Z)

print(f"loaded in {now()-t0}")

loading data


In [None]:
# Report visible GPUs. get_device_name(0) raises when no CUDA device exists,
# so guard it and keep the cell runnable on a CPU-only runtime.
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print('no CUDA device available')

In [None]:
# Autoencoder class
class CAE(nn.Module):
    """3-D convolutional autoencoder.

    Encoder: five Conv3d layers (kernel 5, stride 2, padding 2), each followed
    by ReLU and then a non-affine BatchNorm3d. Decoder: five ConvTranspose3d
    layers (kernel 4, stride 2, padding 1); the first four are followed by
    ReLU and a non-affine BatchNorm3d, the last by a sigmoid.

    Channel widths run 1 -> 4k -> 8k -> 16k -> 32k -> 64k and back down to 1.
    Each encoder stage maps a spatial side s to ceil(s / 2) and each decoder
    stage maps s to 2 * s, so inputs whose spatial sides are multiples of 32
    are reconstructed at their original size. No convolution has a bias term.

    Args:
        input_shape: shape of an example input batch. It is stored as
            ``self.input_shape`` (first entry as ``self.batch_size``) and does
            not affect the layer sizes.
        k: channel-width multiplier; each width is ``int(base * k)``.
    """

    def __init__(self,input_shape,k=1):
        super().__init__()

        self.input_shape = input_shape
        self.batch_size = input_shape[0]

        # forward() only uses relu and sigmoid; lrelu and tanh are registered
        # too so the module keeps the same set of attributes.
        self.lrelu = torch.nn.LeakyReLU(negative_slope=.02)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

        widths = [int(base * k) for base in (4, 8, 16, 32, 64)]  # 4k .. 64k
        use_bias = False

        # Registration order is fixed: encoder convs, encoder norms, decoder
        # norms, decoder convs. It determines the order of parameters() and
        # state_dict(), and the order in which initial weights are drawn.
        for idx, (c_in, c_out) in enumerate(zip([1] + widths[:-1], widths), start=1):
            setattr(self, f"enc_C{idx}",
                    nn.Conv3d(in_channels=c_in, out_channels=c_out,
                              kernel_size=5, stride=2, padding=2, bias=use_bias))

        for idx, width in enumerate(widths, start=1):
            setattr(self, f"batchNormE{idx}", nn.BatchNorm3d(width, affine=False))

        # Decoder norms walk the widths back down: 32k, 16k, 8k, 4k.
        for idx, width in enumerate(reversed(widths[:-1]), start=1):
            setattr(self, f"batchNormD{idx}", nn.BatchNorm3d(width, affine=False))

        dec_out = widths[-2::-1] + [1]  # 32k, 16k, 8k, 4k, then 1 output channel
        for idx, (c_in, c_out) in enumerate(zip(widths[::-1], dec_out), start=1):
            setattr(self, f"dec_C{idx}",
                    nn.ConvTranspose3d(in_channels=c_in, out_channels=c_out,
                                       kernel_size=4, stride=2, padding=1, bias=use_bias))

    def forward(self,hello):
        """Encode and reconstruct ``hello``, a (N, 1, D, H, W) input batch."""
        x = hello
        # Encoder: conv -> ReLU -> BatchNorm, five stages.
        for idx in range(1, 6):
            conv = getattr(self, f"enc_C{idx}")
            norm = getattr(self, f"batchNormE{idx}")
            x = norm(self.relu(conv(x)))
        # Decoder: transposed conv -> ReLU -> BatchNorm, first four stages.
        for idx in range(1, 5):
            conv = getattr(self, f"dec_C{idx}")
            norm = getattr(self, f"batchNormD{idx}")
            x = norm(self.relu(conv(x)))
        # Last stage: back to one channel, squashed into [0, 1].
        return self.sigmoid(self.dec_C5(x))


In [None]:
## Convert data
# The model expects (N, 1, X, Y, Z) input, as produced by the loading cell.
assert data.ndim == 5, f"expected 5-D data (N, C, X, Y, Z), got shape {data.shape}"
D = torch.tensor(data).float()
data_batch = D[0:5]  # example batch, used only to size/probe the model
#D = D[0:10,:,:,:,:]

## Model Definition
model = CAE(data_batch.shape,k=25) # Initiate model

# Run the encoder convs alone to report the bottleneck shape. This is only a
# shape query, so skip building an autograd graph through the large encoder.
with torch.no_grad():
    latentSpaceSize = model.enc_C5(model.enc_C4(model.enc_C3(model.enc_C2(model.enc_C1(data_batch))))).shape
print(f"latentSpaceSize: {latentSpaceSize}")
latentSpaceDim = np.prod(latentSpaceSize)  # total over the probe batch
print(f"latentSpaceDim: {latentSpaceDim // latentSpaceSize[0]}")  # per sample

# This notebook previously computed the CUDA device and then hard-overrode it
# with 'cpu'; FORCE_CPU keeps that default explicit. D is moved together with
# the model, so flipping the flag doesn't leave model and data on different devices.
FORCE_CPU = True
device = torch.device("cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu")
model.to(device)
D = D.to(device)

# Weight-initialisation hook for model.apply(...)
def weights_init_uniform_rule(m):
    """Initialise ``nn.Linear`` layers: weights ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)), bias = 0.

    Meant to be passed to ``model.apply``, which calls it on every submodule.
    Any module that is not an ``nn.Linear`` is left unchanged.

    Args:
        m: an ``nn.Module``.
    """
    # isinstance rather than a class-name substring test, which also matched
    # unrelated modules whose names merely contain 'Linear'.
    if not isinstance(m, nn.Linear):
        return
    bound = 1.0 / np.sqrt(m.in_features)
    nn.init.uniform_(m.weight, -bound, bound)
    # Linear(bias=False) has m.bias = None; an unconditional fill_ crashed there.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Custom init applied in place to `model` (apply() does not create a new model).
# It only touches nn.Linear layers; CAE is built from Conv3d/ConvTranspose3d/
# BatchNorm3d only, so this is currently a no-op.
model.apply(weights_init_uniform_rule)

# Adam with L2 weight decay. Commented-out alternatives:
#optimizer = optim.Adam(model.parameters(), lr=1e-3,)
optimizer = optim.Adam(model.parameters(), lr=.01,weight_decay=.01)
#optimizer = optim.SGD(model.parameters(),lr=.001,weight_decay=.0,momentum=.0)
#optimizer = optim.RMSprop(model.parameters(),lr=.1)
#criterion = nn.MSELoss()
print(device)

## Training Parameters
num_epochs = 501
batch_size = 5
ndata = D.shape[0]
# Batching is driven by batch_size (it was hard-coded as 5 here). Only full
# batches are used: the ndata % batch_size leftover samples are skipped each epoch.
n_batches = ndata // batch_size
batches = np.arange(n_batches * batch_size).reshape(n_batches, batch_size)
track = list()

# The mount cell already changed into the project folder, so models/ is
# relative to it. The previous path re-appended 'drive/My Drive/BC-MRI-AE'
# to that working directory.
ofdir = os.path.join(os.path.curdir,'models')
print(ofdir)

session_name = 'test'

print(n_batches)

In [None]:
def myLoss(outputs,data_batch):
  """Sum of squared errors between ``outputs`` and ``data_batch``.

  Both tensors are flattened first, so they only need the same number of
  elements. The result is a sum, not a mean, so its scale grows with batch
  size and volume size.
  """
  # reshape, not view: view() raises on non-contiguous tensors (e.g. after a
  # transpose/permute); reshape copies only when it has to.
  return torch.sum(torch.square(outputs.reshape(-1) - data_batch.reshape(-1)))

In [None]:
# Sanity check: number of full mini-batches per epoch (floor of samples / batch size).
int(n_batches)

In [None]:
def _plot_progress(data_batch, outputs, track):
    """Plot training diagnostics for the most recent batch.

    Shows three figures: one slice of the first target volume next to its
    reconstruction; the recent loss history with a least-squares linear trend;
    and density plots of predicted vs. true intensities over voxels where the
    target exceeds 0.001.
    """
    target = data_batch.detach().cpu()
    recon = outputs.detach().cpu()

    # Slice index 32 along the first spatial axis: the mid-plane if volumes are
    # 64^3, as the '64iso' dataset name suggests.
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(target.numpy()[0,0,32,:,:])
    plt.subplot(1,2,2)
    plt.imshow(recon.numpy()[0,0,32,:,:])
    plt.show()

    plt.figure(figsize=(10,3))
    plt.subplot(1,2,1)
    plt.plot(track)
    plt.title('training loss')  # previously mislabelled 'training acc'
    xs = np.arange(len(track))+1
    trend = np.poly1d(np.polyfit(xs, track, 1))
    plt.subplot(1,2,2)
    plt.plot(xs,trend(xs),"r--")
    plt.title('trend')

    # Only voxels where the target is above 0.001, presumably to exclude the
    # near-zero background from both distributions.
    mask = target > 0.001
    plt.figure(figsize=(10,3))
    plt.subplot(1,2,1)
    # NOTE(review): sns.distplot is deprecated from seaborn 0.11; consider
    # sns.histplot(..., kde=True, stat='density') once the runtime allows it.
    sns.distplot(recon[mask].numpy().flatten())
    plt.title('Predicted')
    plt.subplot(1,2,2)
    sns.distplot(target[mask].numpy().flatten())
    plt.title('True')
    plt.show()


# With zero full batches the loop below never defines `outputs`, and the
# plotting step would fail with a NameError.
assert int(n_batches) > 0, "need at least one full batch of batch_size samples"

t0 = datetime.now()
for epoch in range(int(num_epochs)):
    # Shuffle indices instead of D itself: D[permutation] copied the whole
    # dataset every epoch.
    permutation = np.random.permutation(ndata)
    for batch_idx in range(int(n_batches)):
        optimizer.zero_grad()
        data_batch = D[permutation[batches[batch_idx]]].to(device)
        outputs = model(data_batch)
        train_loss = myLoss(outputs,data_batch)
        train_loss.backward()
        optimizer.step()
        track.append(train_loss.item())
        if batch_idx%50==0:
            print(f"epoch {epoch}/{num_epochs} | batch {batch_idx}/{n_batches} | time {str(datetime.now()-t0)} | loss {round(train_loss.item(),5)}")

    # Keep only the most recent 10000 batch losses for plotting.
    track = track[-10000:]

    if epoch%100==0:
        # Timestamp without ':' or spaces so the name is a portable filename.
        ofn = os.path.join(ofdir,f'{session_name}' + f'e{epoch}_' + now().strftime('%Y-%m-%d_%H-%M-%S'))
        # torch.save below is disabled, so don't claim the model was saved.
        print('checkpoint not saved (saving disabled)')
        print(ofn)
        #torch.save(model.state_dict(),ofn)

    if epoch%5==0:
        print(outputs.view(-1))
        _plot_progress(data_batch, outputs, track)

In [None]:
 print('all done')