## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import math
import scipy.io as sio
from pathlib import Path

## Load data

### Fmri data

In [2]:
NUM_SUBJS = 8
subjects_fmri = [] #stores all 8 subject fmri np arrays

fMRI_folder = Path('./doi_10.5061_dryad.gt413__v1')
assert fMRI_folder.exists(), f"Foldder: {fMRI_folder} does not exist."

for subj_id in range(8):
    fmri_file_name = str(subj_id) + '_masked_2d.npy'
    fmri = np.load(fMRI_folder / fmri_file_name)
    assert isinstance(fmri, np.ndarray), f"Imported fmri_scan for subject {subj_id} is not of type numpy.ndarray"
    assert(fmri.ndim) == 2, f"Imported fmri_scan for subject {subj_id} is not 2 dimensional"
    subjects_fmri.append(fmri)

### Word features

In [3]:
feature_matrix = np.zeros((5176,195)) #stores the feature vectors as a row for each word
feature_names = [] #stores the names of all features in order
feature_types = {} #stores the types of features and all the names of the features for each type

features = sio.loadmat(fMRI_folder / 'story_features.mat')
feature_count = 0
for feature_type in features['features'][0]:
    feature_types[feature_type[0][0]] = []
    if isinstance(feature_type[1][0], str):
        feature_types[feature_type[0][0]].append(feature_type[1][0])
        feature_names.append(feature_type[1][0])
    else:
        for feature in feature_type[1][0]:
            feature_types[feature_type[0][0]].append(feature[0])
            feature_names.append(feature[0])
    feature_matrix[:, feature_count:feature_count+feature_type[2].shape[1]] = feature_type[2] #adds the (5176xN) feature values to the feature matrix for the current feature group
    feature_count += feature_type[2].shape[1]

### Word values and timings

In [4]:
words_info = [] #stores tuples of (word, time, features) sorted by time appeared

mat_file = fMRI_folder / 'subject_1.mat' #only looks at the first subject file, somewhere it said all the timings were the same so this should be safe
mat_contents = sio.loadmat(mat_file)
for count, row in enumerate(mat_contents['words'][0]):
    word_value = row[0][0][0][0]
    time = row[1][0][0]
    word_tuple = (word_value, time, feature_matrix[count,:])
    words_info.append(word_tuple)

## Align fmri and word features

Still working on this part - Harrison

In [5]:
#class for storing all the information for each sample

class sample: 
    def __init__(self, subj_id, time, input_voxels, output_voxels, input_words, input_word_features, output_words, output_word_features):
        self.subj_id = subj_id #id of subject
        
        self.time = time #time at which the scan occurred
        
        assert input_voxels.shape[0] == 4, "input_voxels does not contain 4 scans"
        self.input_voxels = input_voxels #2d array of 4 TRs of voxels at time of scan
        
        assert output_voxels.shape[0] == 4, "output_voxels does not contain 4 scans"
        self.output_voxels = output_voxels #2d array of 4 TRs of voxels associated with next 4 words
        
        assert len(input_words) == 4, "input_words does not contain 4 words"
        self.input_words = input_words #list of 4 words associated with scan
        
        assert len(output_words) == 4, "output_words does not contain 4 words"
        self.output_words = output_words #list of 4 words that come after scan
        
        assert input_word_features.shape[0] == 4, "input_word_features does not contain 4 words"
        self.input_word_features = input_word_features #np array of size (4,nFeatures) storing the features for the 4 words
        
        assert output_word_features.shape[0] == 4, "output_word_features does not contain 4 words"
        self.output_word_features = output_word_features #np array of size (4,nFeatures) storing the features for the 4 words afterwards
    
    def get_subj_id(self):
        return self.subj_id
    
    def get_time(self):
        return self.time
    
    def get_input_voxels(self):
        return self.input_voxels
    
    def get_output_voxels(self):
        return self.output_voxels
    
    def get_input_words(self):
        return self.input_words
    
    def get_input_word_features(self):
        return self.input_word_features
    
    def get_output_words(self):
        return self.output_words
    
    def get_output_word_features(self):
        return self.output_word_features

In [6]:
subjects_samples = [[] for i in range(NUM_SUBJS)] #stores lists of all the samples for each subject

#still working on this, need to deal with the issue where a rest happens 
word_count = 0
while word_count < len(words_info) - 8:
    #gets the 4 input words, and the 4 consecutive words while verifying they were read in sequence
    scan_words = []
    start_time = words_info[word_count][1]
    in_sequence = True #tracks if the words are in sequence or not
    for i in range(8):
        word_info = words_info[word_count + i]
        if word_info[1] != start_time + 0.5*i:
            #if some of the words are not in sequence, skip forward 1 word after innter loop
            in_sequence = False
        scan_words.append(word_info[0])
    if not in_sequence:
        word_count +=1
        continue
    #gets word features for input and output words
    input_word_features = feature_matrix[word_count:word_count+4, :]
    output_word_features = feature_matrix[word_count+4:word_count+8,:]
    #gets index of fmri that comes 2 seconds after reading first word
    fmri_time = start_time + 2 #effect of reading words is assumed to start 2 seconds after and end 8 seconds after
    fmri_index = fmri_time//2 #since a scan happens every two seconds, the index is the time divided by 2
    if not isinstance(fmri_index, np.int32):
        #if the first word is not aligned with the fmri scan (i.e. its not the first word in a TR)
        word_count += 1
        continue
    for count, subject in enumerate(subjects_fmri):
        new_sample = sample(count, 
                            start_time, 
                            subject[fmri_index:fmri_index+4,:], #gets the scans 2,4,6,8 seconds after reading
                            subject[fmri_index+2:fmri_index+6,:], #gets the scans 4,6,8,10 seconds after reading
                            scan_words[0:4], 
                            input_word_features,
                            scan_words[4:8],
                            output_word_features)
        subjects_samples[count].append(new_sample)
    print("Created sample:")
    print("\tScan time:", str(start_time))
    print("\tInput words:", str(scan_words[0:4]))
    print("\tOutput_words", str(scan_words[4:8]))
    #if successful, skip forward to the next set of 4 words
    word_count += 4

print("Total number of samples:", str(len(subjects_samples[0])))

Created sample:
	Scan time: 20
	Input words: ['Harry', 'had', 'never', 'believed']
	Output_words ['he', 'would', 'meet', 'a']
Created sample:
	Scan time: 22
	Input words: ['he', 'would', 'meet', 'a']
	Output_words ['boy', 'he', 'hated', 'more']
Created sample:
	Scan time: 24
	Input words: ['boy', 'he', 'hated', 'more']
	Output_words ['than', 'Dudley,', 'but', 'that']
Created sample:
	Scan time: 26
	Input words: ['than', 'Dudley,', 'but', 'that']
	Output_words ['was', 'before', 'he', 'met']
Created sample:
	Scan time: 28
	Input words: ['was', 'before', 'he', 'met']
	Output_words ['Draco', 'Malfoy.', 'Still,', 'first-year']
Created sample:
	Scan time: 30
	Input words: ['Draco', 'Malfoy.', 'Still,', 'first-year']
	Output_words ['Gryffindors', 'only', 'had', 'Potions']
Created sample:
	Scan time: 32
	Input words: ['Gryffindors', 'only', 'had', 'Potions']
	Output_words ['with', 'the', 'Slytherins,', 'so']
Created sample:
	Scan time: 34
	Input words: ['with', 'the', 'Slytherins,', 'so']
	Out

	Input words: ['hand', '--', 'a', 'foot']
	Output_words ['from', 'the', 'ground', 'he']
Created sample:
	Scan time: 896
	Input words: ['from', 'the', 'ground', 'he']
	Output_words ['caught', 'it,', 'just', 'in']
Created sample:
	Scan time: 898
	Input words: ['caught', 'it,', 'just', 'in']
	Output_words ['time', 'to', 'pull', 'his']
Created sample:
	Scan time: 900
	Input words: ['time', 'to', 'pull', 'his']
	Output_words ['broom', 'straight,', 'and', 'he']
Created sample:
	Scan time: 902
	Input words: ['broom', 'straight,', 'and', 'he']
	Output_words ['toppled', 'gently', 'onto', 'the']
Created sample:
	Scan time: 904
	Input words: ['toppled', 'gently', 'onto', 'the']
	Output_words ['grass', 'with', 'the', 'Remembrall']
Created sample:
	Scan time: 906
	Input words: ['grass', 'with', 'the', 'Remembrall']
	Output_words ['clutched', 'safely', 'in', 'his']
Created sample:
	Scan time: 908
	Input words: ['clutched', 'safely', 'in', 'his']
	Output_words ['fist.', '+', '"HARRY', 'POTTER!"']
Cre

Created sample:
	Scan time: 2102
	Input words: ['"They\'re', 'in', 'here', 'somewhere,"']
	Output_words ['they', 'heard', 'him', 'mutter,']
Created sample:
	Scan time: 2104
	Input words: ['they', 'heard', 'him', 'mutter,']
	Output_words ['"probably', 'hiding."', '+', '"This']
Created sample:
	Scan time: 2106
	Input words: ['"probably', 'hiding."', '+', '"This']
	Output_words ['way!"', 'Harry', 'mouthed', 'to']
Created sample:
	Scan time: 2108
	Input words: ['way!"', 'Harry', 'mouthed', 'to']
	Output_words ['the', 'others', 'and,', 'petrified,']
Created sample:
	Scan time: 2110
	Input words: ['the', 'others', 'and,', 'petrified,']
	Output_words ['they', 'began', 'to', 'creep']
Created sample:
	Scan time: 2112
	Input words: ['they', 'began', 'to', 'creep']
	Output_words ['down', 'a', 'long', 'gallery']
Created sample:
	Scan time: 2114
	Input words: ['down', 'a', 'long', 'gallery']
	Output_words ['full', 'of', 'suits', 'of']
Created sample:
	Scan time: 2116
	Input words: ['full', 'of', 's

## Initialize model

In [114]:
class LinearModel(nn.Module):
    def __init__(self, num_features, num_outputs):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(num_features, num_outputs)

    def forward(self, x):
        out = self.linear(x)
        return out

## Build samples and separate into folds

In [115]:
subject_0_samples = subjects_samples[0] #only using the first subjects samples currently
fmri_size = subject_0_samples[0].get_input_voxels().shape[1]
word_feature_size = subject_0_samples[0].get_input_word_features().shape[1]
samples = np.zeros((len(subject_0_samples), (fmri_size) + (word_feature_size * 8)))
#samples = np.zeros((len(subject_0_samples), word_feature_size * 8))
for i in range(len(subject_0_samples)):
    samples[i, :fmri_size] = subject_0_samples[i].get_input_voxels()[2] #get feature 6 seconds after reading 4 words
#     samples[i, fmri_size: fmri_size*2] = subject_0_samples[i].get_output_voxels()[2] #get feature 6 seconds after reading next 4 words
    samples[i, fmri_size:fmri_size + word_feature_size*4] = subject_0_samples[i].get_input_word_features().flatten()
#     samples[i, word_feature_size*4:] = np.sum(subject_0_samples[i].get_output_word_features(), axis=0)
    samples[i, fmri_size + word_feature_size*4:] = subject_0_samples[i].get_output_word_features().flatten()

In [116]:
NUM_FOLDS = 3
np.random.shuffle(samples) #shuffles rows
print(samples.shape)
folds = np.array(np.split(samples, NUM_FOLDS))
print(folds.shape)

(1287, 43333)
(3, 429, 43333)


## Train model

In [131]:
#performs stochastic gradient descent for num_epoch epochs
def train_model(model, samples, num_features, num_outputs, alpha=1e-7, momentum=0.9, num_epochs=50):
    print("Training for ", num_epochs, " epochs:")
    optimizer = torch.optim.SGD(model.parameters(), lr=alpha, momentum=momentum)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        epoch_loss = 0
        #randomly shuffle samples 
        np.random.shuffle(samples)
        for sample in samples:
            optimizer.zero_grad()
            x = torch.tensor(sample[:num_features]).float().to(device)
            y = torch.tensor(sample[num_features:]).float().to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
        if epoch % 1 == 0:
            epoch_accuracy = get_accuracy(model, samples, num_features, num_outputs)
            print("\tEpoch: ", epoch, ", Train Loss: ", epoch_loss, ", Train Accuracy: ", epoch_accuracy)
    epoch_accuracy = get_accuracy(model, samples, num_features, num_outputs)
    print("\tEpoch: ", epoch, ", Train Loss: ", epoch_loss, ", Train Accuracy: ", epoch_accuracy)

In [132]:
#gets the model accuracy by getting the prediction for a sample and seeing 
#if it closer to the labels of that sample or another randomly selected sample
def get_accuracy(model, samples, num_features, num_outputs):
    correct = 0
    for sample in samples:
        #gets the input and labels from current sample
        correct_x = torch.tensor(sample[:num_features]).float().to(device)
        correct_labels = torch.tensor(sample[num_features:]).float().to(device)
        #gets labels from a random sample
        rand = random.randint(0, samples.shape[0]-1)
        random_labels = torch.tensor(samples[rand,num_features:]).float().to(device)
        #gets prediction on current sample and computes euclidean distances from both correct and random labels
        pred = model(correct_x)
        correct_distance = torch.linalg.norm(correct_labels - pred)
        #print("correct distance: ", str(correct_distance))
        random_distance = torch.linalg.norm(random_labels - pred)
        #print("random distance: ", str(random_distance))
        #if distance to correct labels is less than distance to random labels then the prediction is considered correct
        if correct_distance < random_distance:
            correct += 1
    return correct / samples.shape[0]

In [133]:
#for each fold split, creates a model and trains it on n-1 folds and then tests it on the last fold
def cross_validate(folds, num_folds, num_features, num_outputs, num_epochs=50):
    for i in range(num_folds):
        print("Fold: ", i)
        model = LinearModel(num_features, num_outputs).to(device)
        loss_fn = nn.MSELoss()
        #gets all folds except one and reshapes them into a 2d array
        train_samples = np.reshape(np.delete(folds, i, axis=0), (folds.shape[1]*(num_folds-1), folds.shape[2])) 
        test_samples = folds[i]
        train_model(model, train_samples, num_features, num_outputs, num_epochs=num_epochs)
        test_loss = 0
        for sample in test_samples:
            x = torch.tensor(sample[:num_features]).float().to(device)
            y = torch.tensor(sample[num_features:]).float().to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_loss += loss.item()
        test_accuracy = get_accuracy(model, test_samples, num_features, num_outputs)
        print("\tFold: ", i, ", Test Loss: ", test_loss, ", Test Accuracy: ", test_accuracy)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
cross_validate(folds, NUM_FOLDS, fmri_size + word_feature_size*4, word_feature_size*4, num_epochs=500)

cuda
Fold:  0
Training for  500  epochs:
	Epoch:  0 , Train Loss:  1025102.5688171387 , Train Accuracy:  0.49417249417249415
	Epoch:  1 , Train Loss:  110160.85584640503 , Train Accuracy:  0.5104895104895105
	Epoch:  2 , Train Loss:  104102.3522529602 , Train Accuracy:  0.4965034965034965
	Epoch:  3 , Train Loss:  91302.08600616455 , Train Accuracy:  0.48834498834498835
	Epoch:  4 , Train Loss:  88250.85731124878 , Train Accuracy:  0.5046620046620046
	Epoch:  5 , Train Loss:  79035.21173095703 , Train Accuracy:  0.47086247086247085
	Epoch:  6 , Train Loss:  76621.67371749878 , Train Accuracy:  0.5233100233100233
	Epoch:  7 , Train Loss:  72748.39073944092 , Train Accuracy:  0.5151515151515151
	Epoch:  8 , Train Loss:  70307.11055755615 , Train Accuracy:  0.5046620046620046
	Epoch:  9 , Train Loss:  69865.36672973633 , Train Accuracy:  0.479020979020979
	Epoch:  10 , Train Loss:  66186.27717971802 , Train Accuracy:  0.4976689976689977
	Epoch:  11 , Train Loss:  64394.73640060425 , Train

	Epoch:  99 , Train Loss:  35294.34007644653 , Train Accuracy:  0.506993006993007
	Epoch:  100 , Train Loss:  35551.26582527161 , Train Accuracy:  0.5303030303030303
	Epoch:  101 , Train Loss:  38598.607067108154 , Train Accuracy:  0.4662004662004662
	Epoch:  102 , Train Loss:  36179.10408401489 , Train Accuracy:  0.4766899766899767
	Epoch:  103 , Train Loss:  35965.091526031494 , Train Accuracy:  0.506993006993007
	Epoch:  104 , Train Loss:  35585.167823791504 , Train Accuracy:  0.5093240093240093
	Epoch:  105 , Train Loss:  35121.51633834839 , Train Accuracy:  0.5174825174825175
	Epoch:  106 , Train Loss:  35733.762533187866 , Train Accuracy:  0.5571095571095571
	Epoch:  107 , Train Loss:  38011.03334236145 , Train Accuracy:  0.5174825174825175
	Epoch:  108 , Train Loss:  37298.80922317505 , Train Accuracy:  0.513986013986014
	Epoch:  109 , Train Loss:  36264.79595184326 , Train Accuracy:  0.5163170163170163
	Epoch:  110 , Train Loss:  35491.42925834656 , Train Accuracy:  0.497668997

	Epoch:  197 , Train Loss:  30525.813556671143 , Train Accuracy:  0.5186480186480187
	Epoch:  198 , Train Loss:  31502.944679260254 , Train Accuracy:  0.5
	Epoch:  199 , Train Loss:  30472.469877243042 , Train Accuracy:  0.5104895104895105
	Epoch:  200 , Train Loss:  28824.93549156189 , Train Accuracy:  0.5
	Epoch:  201 , Train Loss:  29601.339752197266 , Train Accuracy:  0.4976689976689977
	Epoch:  202 , Train Loss:  29870.273433685303 , Train Accuracy:  0.5058275058275058
	Epoch:  203 , Train Loss:  30617.156578063965 , Train Accuracy:  0.4696969696969697
	Epoch:  204 , Train Loss:  30702.411211013794 , Train Accuracy:  0.5034965034965035
	Epoch:  205 , Train Loss:  29071.629495620728 , Train Accuracy:  0.5128205128205128
	Epoch:  206 , Train Loss:  28796.667585372925 , Train Accuracy:  0.5058275058275058
	Epoch:  207 , Train Loss:  28400.508710861206 , Train Accuracy:  0.513986013986014
	Epoch:  208 , Train Loss:  30216.748067855835 , Train Accuracy:  0.4988344988344988
	Epoch:  209

	Epoch:  295 , Train Loss:  26265.739570617676 , Train Accuracy:  0.5011655011655012
	Epoch:  296 , Train Loss:  25549.017435073853 , Train Accuracy:  0.5128205128205128
	Epoch:  297 , Train Loss:  26874.791015625 , Train Accuracy:  0.5
	Epoch:  298 , Train Loss:  27016.092670440674 , Train Accuracy:  0.5384615384615384
	Epoch:  299 , Train Loss:  25825.954540252686 , Train Accuracy:  0.5128205128205128
	Epoch:  300 , Train Loss:  25154.544946670532 , Train Accuracy:  0.4825174825174825
	Epoch:  301 , Train Loss:  26799.300714492798 , Train Accuracy:  0.4965034965034965
	Epoch:  302 , Train Loss:  26461.381748199463 , Train Accuracy:  0.5337995337995338
	Epoch:  303 , Train Loss:  26842.902069091797 , Train Accuracy:  0.5128205128205128
	Epoch:  304 , Train Loss:  25400.984632492065 , Train Accuracy:  0.506993006993007
	Epoch:  305 , Train Loss:  26266.038816452026 , Train Accuracy:  0.5291375291375291
	Epoch:  306 , Train Loss:  25872.078393936157 , Train Accuracy:  0.5174825174825175

	Epoch:  393 , Train Loss:  22911.19664955139 , Train Accuracy:  0.5023310023310024
	Epoch:  394 , Train Loss:  23091.27799320221 , Train Accuracy:  0.5116550116550117
	Epoch:  395 , Train Loss:  22595.063990592957 , Train Accuracy:  0.5186480186480187
	Epoch:  396 , Train Loss:  23897.56623363495 , Train Accuracy:  0.5874125874125874
	Epoch:  397 , Train Loss:  23219.675636291504 , Train Accuracy:  0.5093240093240093
	Epoch:  398 , Train Loss:  22856.05347442627 , Train Accuracy:  0.5093240093240093
	Epoch:  399 , Train Loss:  24330.48450279236 , Train Accuracy:  0.5792540792540792
	Epoch:  400 , Train Loss:  22282.56940460205 , Train Accuracy:  0.5174825174825175
	Epoch:  401 , Train Loss:  22599.767177581787 , Train Accuracy:  0.5186480186480187
	Epoch:  402 , Train Loss:  23259.876014709473 , Train Accuracy:  0.5198135198135199
	Epoch:  403 , Train Loss:  22634.854122161865 , Train Accuracy:  0.578088578088578
	Epoch:  404 , Train Loss:  25132.448266983032 , Train Accuracy:  0.6212

	Epoch:  491 , Train Loss:  23467.879214286804 , Train Accuracy:  0.5407925407925408
	Epoch:  492 , Train Loss:  21668.15242958069 , Train Accuracy:  0.5186480186480187
	Epoch:  493 , Train Loss:  23194.121576309204 , Train Accuracy:  0.5512820512820513
	Epoch:  494 , Train Loss:  20504.008853912354 , Train Accuracy:  0.5512820512820513
	Epoch:  495 , Train Loss:  20754.273182868958 , Train Accuracy:  0.5396270396270396
	Epoch:  496 , Train Loss:  21139.377179145813 , Train Accuracy:  0.527972027972028
	Epoch:  497 , Train Loss:  20479.864924430847 , Train Accuracy:  0.6072261072261073
	Epoch:  498 , Train Loss:  20527.138355255127 , Train Accuracy:  0.548951048951049
	Epoch:  499 , Train Loss:  20416.6328458786 , Train Accuracy:  0.5233100233100233
	Epoch:  499 , Train Loss:  20416.6328458786 , Train Accuracy:  0.5291375291375291
	Fold:  0 , Test Loss:  9402.655109405518 , Test Accuracy:  0.4988344988344988
Fold:  1
Training for  500  epochs:
	Epoch:  0 , Train Loss:  900113.748603820

	Epoch:  88 , Train Loss:  36764.95198249817 , Train Accuracy:  0.5337995337995338
	Epoch:  89 , Train Loss:  38662.53456115723 , Train Accuracy:  0.4988344988344988
	Epoch:  90 , Train Loss:  38035.02023124695 , Train Accuracy:  0.5198135198135199
	Epoch:  91 , Train Loss:  37965.80377006531 , Train Accuracy:  0.4976689976689977
	Epoch:  92 , Train Loss:  37426.00679397583 , Train Accuracy:  0.4813519813519814
	Epoch:  93 , Train Loss:  39642.46802330017 , Train Accuracy:  0.534965034965035
	Epoch:  94 , Train Loss:  37939.55895614624 , Train Accuracy:  0.48834498834498835
	Epoch:  95 , Train Loss:  37412.21424674988 , Train Accuracy:  0.5337995337995338
	Epoch:  96 , Train Loss:  35089.5016784668 , Train Accuracy:  0.486013986013986
	Epoch:  97 , Train Loss:  35876.693771362305 , Train Accuracy:  0.5233100233100233
	Epoch:  98 , Train Loss:  36380.175884246826 , Train Accuracy:  0.5128205128205128
	Epoch:  99 , Train Loss:  37185.73833847046 , Train Accuracy:  0.513986013986014
	Epoc

	Epoch:  186 , Train Loss:  31085.70902824402 , Train Accuracy:  0.5431235431235432
	Epoch:  187 , Train Loss:  29701.733924865723 , Train Accuracy:  0.5011655011655012
	Epoch:  188 , Train Loss:  30986.323846817017 , Train Accuracy:  0.48717948717948717
	Epoch:  189 , Train Loss:  31570.818475723267 , Train Accuracy:  0.5256410256410257
	Epoch:  190 , Train Loss:  32018.07599067688 , Train Accuracy:  0.527972027972028
	Epoch:  191 , Train Loss:  28848.772075653076 , Train Accuracy:  0.5186480186480187
	Epoch:  192 , Train Loss:  29580.43487739563 , Train Accuracy:  0.5116550116550117
	Epoch:  193 , Train Loss:  29618.711038589478 , Train Accuracy:  0.4965034965034965
	Epoch:  194 , Train Loss:  29076.269857406616 , Train Accuracy:  0.5081585081585082
	Epoch:  195 , Train Loss:  31690.259393692017 , Train Accuracy:  0.5815850815850816
	Epoch:  196 , Train Loss:  32762.641164779663 , Train Accuracy:  0.5058275058275058
	Epoch:  197 , Train Loss:  29467.75159072876 , Train Accuracy:  0.5

	Epoch:  284 , Train Loss:  26986.70409965515 , Train Accuracy:  0.5221445221445221
	Epoch:  285 , Train Loss:  26608.117395401 , Train Accuracy:  0.5442890442890443
	Epoch:  286 , Train Loss:  26869.527296066284 , Train Accuracy:  0.5034965034965035
	Epoch:  287 , Train Loss:  25811.060800552368 , Train Accuracy:  0.5023310023310024
	Epoch:  288 , Train Loss:  25898.282348632812 , Train Accuracy:  0.48484848484848486
	Epoch:  289 , Train Loss:  25576.698776245117 , Train Accuracy:  0.5431235431235432
	Epoch:  290 , Train Loss:  26987.650791168213 , Train Accuracy:  0.4755244755244755
	Epoch:  291 , Train Loss:  26308.47180557251 , Train Accuracy:  0.486013986013986
	Epoch:  292 , Train Loss:  26300.489686965942 , Train Accuracy:  0.5372960372960373
	Epoch:  293 , Train Loss:  25731.1981048584 , Train Accuracy:  0.5163170163170163
	Epoch:  294 , Train Loss:  26701.338554382324 , Train Accuracy:  0.49067599067599066
	Epoch:  295 , Train Loss:  25629.198944091797 , Train Accuracy:  0.543

	Epoch:  382 , Train Loss:  22849.3450756073 , Train Accuracy:  0.5361305361305362
	Epoch:  383 , Train Loss:  23321.9152507782 , Train Accuracy:  0.5186480186480187
	Epoch:  384 , Train Loss:  22775.698093414307 , Train Accuracy:  0.5268065268065268
	Epoch:  385 , Train Loss:  22468.49440097809 , Train Accuracy:  0.6002331002331003
	Epoch:  386 , Train Loss:  23423.62002182007 , Train Accuracy:  0.4836829836829837
	Epoch:  387 , Train Loss:  25201.62550354004 , Train Accuracy:  0.5955710955710956
	Epoch:  388 , Train Loss:  23459.754244804382 , Train Accuracy:  0.5046620046620046
	Epoch:  389 , Train Loss:  23992.79571533203 , Train Accuracy:  0.5477855477855478
	Epoch:  390 , Train Loss:  24427.120782852173 , Train Accuracy:  0.5151515151515151
	Epoch:  391 , Train Loss:  22425.739183425903 , Train Accuracy:  0.5128205128205128
	Epoch:  392 , Train Loss:  22816.391801834106 , Train Accuracy:  0.5093240093240093
	Epoch:  393 , Train Loss:  23290.35464477539 , Train Accuracy:  0.493006

	Epoch:  479 , Train Loss:  21551.766553878784 , Train Accuracy:  0.5932400932400932
	Epoch:  480 , Train Loss:  22753.199312210083 , Train Accuracy:  0.49533799533799533
	Epoch:  481 , Train Loss:  21319.183799743652 , Train Accuracy:  0.5990675990675991
	Epoch:  482 , Train Loss:  20553.888019561768 , Train Accuracy:  0.5244755244755245
	Epoch:  483 , Train Loss:  21207.68498802185 , Train Accuracy:  0.5128205128205128
	Epoch:  484 , Train Loss:  21748.253987312317 , Train Accuracy:  0.5221445221445221
	Epoch:  485 , Train Loss:  22307.186884880066 , Train Accuracy:  0.5058275058275058
	Epoch:  486 , Train Loss:  21377.405329704285 , Train Accuracy:  0.6177156177156177
	Epoch:  487 , Train Loss:  21296.119289398193 , Train Accuracy:  0.506993006993007
	Epoch:  488 , Train Loss:  20632.754786491394 , Train Accuracy:  0.6655011655011654
	Epoch:  489 , Train Loss:  21225.582414627075 , Train Accuracy:  0.5046620046620046
	Epoch:  490 , Train Loss:  21048.61483001709 , Train Accuracy:  0