In [1]:
%load_ext autoreload

In [None]:
"""
Rainbow passage = ["When the sunlight strikes raindrops in the air", "they act as a prism and form a rainbow", "The rainbow is a division of white", "light into many beautiful colors", 
           "These take the shape of a long round arch", "with its path high above", "and its two ends apparently beyond the horizon", "There is according to legend", "a boiling pot of gold at one end",
           "People look but no one ever finds it", "When a man looks for something beyond his reach", " his friends say he is looking for the", "pot of gold at the end of the rainbow",
           "Throughout the centuries people have", "explained the rainbow in various ways", "Some have accepted it as", "a miracle without physical explanation", "To the Hebrews it was a token that there",
           "would be no more universal floods", "The Greeks used to imagine that it was", "a sign from the gods to foretell war or heavy rain", "The Norsemen considered the rainbow as a bridge", 
           "over which the gods passed from earth to their", "home in the sky Others have tried to explain", "the phenomenon physically Aristotle thought that the", "rainbow was caused by reflection of the suns",
           "rays by the rain Since then physicists have", "found that it is not reflection but refraction", "by the raindrops which causes the rainbows", "Many complicated ideas about the", "rainbow have been formed", 
           "The difference in the rainbow depends considerably", "upon the size of the drops and the width of the", "colored band increases as the size of the drops increases", 
           "The actual primary rainbow observed is", "said to be the effect of super imposition", "of a number of bows If the red of the second bow falls", 
           "upon the green of the first the result is", "to give a bow with an abnormally wide yellow band", "since red and green light when mixed form yellow", "This is a very common type of bow", 
           "one showing mainly red and yellow",  "with little or no green or blue"]

Each word was read character by character. Cues to read the characters were given 1.5 seconds apart. There are pauses between words but no pauses between characters.
DATA is an array of shape (number of characters, 22 channels, 7500 time steps).

Due to a software glitch, a couple of characters are missing for some subjects. Labels are provided separately for each subject.
"""

In [2]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from spdLearning import spdNN
from spdLearning import optimizers 
from spdLearning import trainTest
from spdLearning import spdNet

In [3]:
class BaseDataset(Dataset):
    """Map-style dataset that pairs stored samples with their labels.

    Indexing returns ``(sample_as_float32, label)``; the sample is cast with
    ``astype('float32')`` on every access, the label is returned unchanged.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset size is driven by the samples, not the labels.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        return sample.astype('float32'), label

In [4]:
# Evaluation runs on the CPU; `dev` keeps the device name as a plain string.
device = torch.device("cpu")
dev = device.type

In [5]:
# Participant to evaluate; `subject` is the folder name used under Experiment2/.
subjectNumber = 4
subject = f"Subject{subjectNumber}"

In [6]:
# Load this subject's recordings and the matching per-character labels.
DATA = np.load(f"Experiment2/{subject}/rainbowPassage.npy")
Labels = np.load(f"Experiment2/{subject}/rainbowPassageLabels.npy")

# z-score every channel of every character trial along the last (time) axis;
# the small epsilon avoids division by zero for flat channels.
mean = DATA.mean(axis=-1)
std = DATA.std(axis=-1)
DATA = (DATA - mean[..., None]) / (std[..., None] + 1e-5)

In [7]:
# Recording/model dimensions:
#   numberChannels  - channels per character trial
#   windowLength    - time samples per character trial
#   numberAlphabets - number of output classes (one per letter)
numberChannels, windowLength, numberAlphabets = 22, 7500, 26

In [8]:
# One (numberChannels x numberChannels) matrix per labelled character:
#   C = X @ X.T / windowLength, with X that character's (channels x time) trial.
# Only the first len(Labels) trials are used.
covarianceMatrices = np.zeros((len(Labels), numberChannels, numberChannels))
for charIndex in range(len(Labels)):
    trial = DATA[charIndex]
    covarianceMatrices[charIndex] = (trial @ trial.T) * (1 / windowLength)

In [9]:
# The whole recording is used for evaluation (no train/test split in this notebook).
testFeatures, testLabels = covarianceMatrices, Labels

In [10]:
# Deterministic (unshuffled) batches so per-batch predictions keep the label order.
testDataset = BaseDataset(testFeatures, testLabels)
testDataloader = DataLoader(dataset=testDataset, shuffle=False, batch_size=32)

In [11]:
# SPD-matrix network constructed with `numberAlphabets` (presumably the number of
# output classes), moved to the evaluation device.
model = spdNet.learnSPDMatrices(numberAlphabets).to(device)
lossFunction = nn.CrossEntropyLoss()

# Count and report the trainable parameters.
numParams = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(numParams)

# NOTE(review): this optimizer is never stepped anywhere in this notebook.
spdOptimizer = optimizers.MixOptimizer(model.parameters(), lr = 0.05)

7926


In [12]:
# Load the saved weights for this subject. `map_location` remaps any tensors that
# were saved on another device (e.g. a GPU) onto `device`, so this also works on
# CPU-only machines.
checkpoint = torch.load('Experiment2/' + subject + '/spdNet.pt', map_location=device)

In [13]:
# Strict loading (raises on missing/unexpected keys); the returned summary is the
# cell's displayed output.
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [14]:
def testOperationk(model, device, dataloader, Loss, k):
    with torch.no_grad():
        model.eval()
        testLoss, accuracy, correct, total = 0, 0, 0, 0
        PREDICT = []

        for data, target in dataloader:
            target = target.long()
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = Loss(output, target)
            
            testLoss += loss.data.item()
            _, topkPreds = torch.topk(output.data, k, dim = 1)
            total += target.size(0)
            correct += sum([target[i].item() in topkPreds[i].cpu().numpy() for i in range(target.size(0))])
            PREDICT.append(topkPreds.cpu().numpy())
    accuracy = 100. * correct/total
    return testLoss/total, accuracy, PREDICT

In [15]:
# Evaluate top-1 .. top-5 accuracy in one loop instead of five copy-pasted calls.
# `testLoss` and `prediction` keep the values from the last (k = 5) run, as before.
topkAccuracies = {}
for k in range(1, 6):
    testLoss, topkAccuracies[k], prediction = testOperationk(
        model, device, testDataloader, lossFunction, k)

# Keep the individual names that later cells print.
testAccuracy1, testAccuracy2, testAccuracy3, testAccuracy4, testAccuracy5 = (
    topkAccuracies[k] for k in range(1, 6))

In [16]:
# Test loss returned by the last evaluation call above.
print(f"{testLoss}")

0.10309173213742209


In [17]:
# Report top-1 .. top-5 accuracy (percent), one line per k.
topkResults = (testAccuracy1, testAccuracy2, testAccuracy3, testAccuracy4, testAccuracy5)
for rank, acc in enumerate(topkResults, start=1):
    print(f"Top-{rank} accuracy: ", acc)

Top-1 accuracy:  20.74281709880869
Top-2 accuracy:  35.529081990189205
Top-3 accuracy:  45.900490539593555
Top-4 accuracy:  52.83812193412754
Top-5 accuracy:  58.51436580238262


In [18]:
%autoreload