In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
import pandas as pd
import numpy as np
import os
import pickle
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
NUM_EPOCH = 100
# NUM_CLASSES = 10
BATCH_SIZE = 100
LEARNING_RATE = 0.01

In [3]:
data_file = 'data/ecog_mel60_bi_dataset.pkl'
with open(data_file, 'rb') as f:
    dataset = pickle.load(f)

In [4]:
train_loader = torch.utils.data.DataLoader(dataset=dataset[0],
                                           batch_size=BATCH_SIZE, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=dataset[1],
                                          batch_size=BATCH_SIZE, 
                                          shuffle=True)

In [5]:
# class BinarizedF(Function):
#     def forward(self, input):
#         self.save_for_backward(input)
#         a = torch.ones_like(input)
#         b = -torch.ones_like(input)
#         output = torch.where(input>=0,a,b)
#         return output
#     def backward(self, output_grad):
#         input, = self.saved_tensors
#         input_abs = torch.abs(input)
#         ones = torch.ones_like(input)
#         zeros = torch.zeros_like(input)
#         input_grad = torch.where(input_abs<=1,ones, zeros)
#         return input_grad

In [8]:
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
#         self.mp = nn.MaxPool2d(3)
        self.gap = nn.AvgPool2d(1)
        self.fc = nn.Linear(120000, 384)
#         self.fc2 = nn.Linear(120000, 6)
        
    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        in_size = x.size(0)
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        x3 = self.gap(x2)
        x4 = x3.view(in_size, -1)
        x5 = self.fc(x4)
#         x5 = torch.squeeze(x5)
#         return (x5>0.5).float().requires_grad_()
        return x5

In [9]:
model = Net()
criterion = nn.L1Loss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=LEARNING_RATE)

In [8]:
model.eval()

Net(
  (conv1): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(5, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (gap): AvgPool2d(kernel_size=1, stride=1, padding=0)
  (fc): Linear(in_features=120000, out_features=384, bias=True)
)

In [7]:
# Train the model
total_step = len(train_loader)
train_loss_list = []
test_loss_list = []
test_accuracy_list = []

for epoch in range(NUM_EPOCH):
    model.train()
    for i, (ecog, mel) in enumerate(train_loader):
#         images = images.to(device)
#         labels = labels.to(device)
        
        # Forward pass
#         outputs = model(images).double()
#         labels = labels.long()
        outputs = model(ecog)
        labels = mel
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i) % 100 == 0:
            train_loss_list.append(loss.item())
            print ('Epoch [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, NUM_EPOCH, loss.item()))
#             model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
#             with torch.no_grad():
#                 correct = 0
#                 total = 0
#                 for images, labels in test_loader:
#                     images = images.to(device)
#                     labels = labels.to(device)
#                     outputs = model(images)
#                     _, predicted = torch.max(outputs.data, 1)
#                     total += labels.size(0)
#                     correct += (predicted == labels).sum().item()

#                 print('Test Accuracy of the model on the 250 test images: {} %'.format(100 * correct / total))
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
#         data, target = Variable(data, volatile=True), Variable(target)
#         data = data.to(device)
#         target = target.to(device)
        output = model(data)
        # sum up batch loss
        target = target
        test_loss += F.l1_loss(output, target, size_average=False).data
        # get the index of the max log-probability
#         pred = output.data.max(1, keepdim=True)[1]
#         correct += pred.eq(target.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    test_loss_list.append(test_loss.item())
#     test_accuracy_list.append((100. * correct / len(test_loader.dataset)).item())
    print('\nTest set: Average loss: {:.4f}\n'.format(
        test_loss))
#             print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
#                    .format(epoch+1, NUM_EPOCH, i+1, total_step, loss.item()))


Epoch [1/100], Loss: 0.0693





Test set: Average loss: 11.9267

Epoch [2/100], Loss: 0.0311

Test set: Average loss: 12.0483

Epoch [3/100], Loss: 0.0315

Test set: Average loss: 12.0037

Epoch [4/100], Loss: 0.0310

Test set: Average loss: 11.8124

Epoch [5/100], Loss: 0.0309

Test set: Average loss: 11.7603

Epoch [6/100], Loss: 0.0306

Test set: Average loss: 11.4218

Epoch [7/100], Loss: 0.0298

Test set: Average loss: 11.3773

Epoch [8/100], Loss: 0.0298

Test set: Average loss: 11.1477

Epoch [9/100], Loss: 0.0291

Test set: Average loss: 11.0589

Epoch [10/100], Loss: 0.0287

Test set: Average loss: 10.8936

Epoch [11/100], Loss: 0.0280

Test set: Average loss: 10.8059

Epoch [12/100], Loss: 0.0276

Test set: Average loss: 10.4985

Epoch [13/100], Loss: 0.0277

Test set: Average loss: 10.4486

Epoch [14/100], Loss: 0.0270

Test set: Average loss: 10.3010

Epoch [15/100], Loss: 0.0273

Test set: Average loss: 10.1252

Epoch [16/100], Loss: 0.0262

Test set: Average loss: 10.0383

Epoch [17/100], Loss: 0.0265


In [None]:
img = dataset[1][1]
img_arr = np.asarray([t.numpy() for t in img])
img = Variable(torch.from_numpy(img_arr[0])).view(1,1,96,125)
img = Variable(img).view(1,96,125)

In [None]:
mel_out = model(img)

In [None]:
mel_out_np = mel_out.numpy()
mel_out_np = mel_out_np.reshape(128,3)
# np.save('data/mel_out_np.npy', mel_out_np)

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline 

In [None]:
plt.imshow(mel_out_np, interpolation='nearest')
plt.show()