### Import Libraries

In [1]:
import numpy as np
import argparse
import os
from sklearn import preprocessing
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import torchvision
import matplotlib.pyplot as plt

### Load the dataset

In [42]:
# load the dataset
fmri_dataset = torch.load('data/fMRI_data/demo1/digits-fmri')
# load the labels
labels = torch.load('data/images/demo1/raw_imgs/digits-labels') - 1
# print the shape
print(fmri_dataset.shape)
print(labels.shape)


(100, 3092)
(100, 1)


In [43]:
total_blocks = fmri_dataset.shape[0]
fmri_size = fmri_dataset.shape[1]
print('total blocks : '+str(total_blocks))
print('input fmri size : '+str(fmri_size))

total blocks : 100
input fmri size : 3092


In [44]:
print(labels)

[[0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]]


In [45]:
n_epochs = 100
batch_size = 10
test_size = 10

In [46]:
# define the model


class Net(nn.Module):
  """
  Initialize MLP Network
  """

  def __init__(self, actv, input_feature_num, hidden_unit_nums, output_feature_num):
    """
    Initialize MLP Network parameters

    Args:
      actv: string
        Activation function
      input_feature_num: int
        Number of input features
      hidden_unit_nums: list
        Number of units in the hidden layer
      output_feature_num: int
        Number of output features

    Returns:
      Nothing
    """
    super(Net, self).__init__()
    self.input_feature_num = input_feature_num # Save the input size for reshaping later
    self.model = nn.Sequential() # Initialize layers of MLP

    in_num = input_feature_num # Initialize the temporary input feature to each layer
    for i in range(len(hidden_unit_nums)): # Loop over layers and create each one

      out_num = hidden_unit_nums[i] # Assign the current layer hidden unit from list
      layer = nn.Linear(in_num, out_num) # Use nn.Linear to define the layer

      in_num = out_num # Assign next layer input using current layer output
      self.model.add_module('Linear_%d'%i, layer) # Append layer to the model with a name

      actv_layer = eval('nn.%s'%actv) # Assign activation function (eval allows us to instantiate object from string)
      self.model.add_module('Activation_%d'%i, actv_layer) # Append activation to the model with a name

    out_layer = nn.Linear(in_num, output_feature_num) # Create final layer
    self.model.add_module('Output_Linear', out_layer) # Append the final layer

  def forward(self, x):
    """
    Simulate forward pass of MLP Network

    Args:
      x: torch.tensor
        Input data

    Returns:
      logits: Instance of MLP
        Forward pass of MLP
    """
    # Reshape inputs to (batch_size, input_feature_num)
    # Just in case the input vector is not 2D, like an image!
    x = x.view(x.shape[0], -1)
    logits = self.model(x) # Forward pass of MLP
    return logits


In [47]:
hidden_layers_shape = [1024, 64]
activation = 'Tanh()'
net = Net(actv=activation, input_feature_num=fmri_size, hidden_unit_nums=hidden_layers_shape, output_feature_num=2)
# y = net()
# print(f'The output shape is {y.shape} for an input of shape {fmri_size.shape}')

In [53]:
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
fmri = torch.from_numpy(fmri_dataset)
fmri = fmri.type(Tensor)
print(type(fmri_dataset))
net.forward(fmri)

<class 'numpy.ndarray'>


tensor([[-0.0009, -0.1368],
        [-0.0017, -0.1358],
        [-0.0039, -0.1356],
        [-0.0193, -0.1315],
        [-0.0122, -0.1353],
        [-0.0062, -0.1308],
        [-0.0084, -0.1273],
        [-0.0122, -0.1331],
        [-0.0091, -0.1342],
        [-0.0075, -0.1321],
        [-0.0110, -0.1288],
        [-0.0114, -0.1314],
        [-0.0102, -0.1326],
        [-0.0132, -0.1323],
        [-0.0139, -0.1305],
        [-0.0068, -0.1322],
        [-0.0108, -0.1373],
        [-0.0033, -0.1332],
        [-0.0056, -0.1255],
        [-0.0009, -0.1313],
        [-0.0084, -0.1315],
        [-0.0042, -0.1268],
        [-0.0097, -0.1282],
        [-0.0134, -0.1256],
        [-0.0166, -0.1348],
        [-0.0085, -0.1321],
        [-0.0129, -0.1314],
        [-0.0121, -0.1336],
        [-0.0096, -0.1286],
        [-0.0138, -0.1320],
        [-0.0124, -0.1308],
        [-0.0116, -0.1308],
        [-0.0149, -0.1305],
        [-0.0074, -0.1287],
        [-0.0096, -0.1276],
        [-0.0076, -0