In [1]:
from nilearn import datasets, plotting, image
import nibabel as nib
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.interfaces.fmriprep import load_confounds
import numpy as np
import os
import pandas as pd

import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d
from torch_geometric.nn import EdgeConv, GCNConv, GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from torch_geometric.utils import from_networkx

import networkx as nx
from networkx.convert_matrix import from_numpy_array

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
import torch_geometric as tg

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import wandb
import random
import functions as f
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

In [5]:
#generate correlation functions
corr_matrices_root = 'ADNI_full/corr_matrices'
full_corr_path_list = sorted(os.listdir(corr_matrices_root))
corr_matrices_full = []
for i in range(len(full_corr_path_list)):
    path = os.path.join(corr_matrices_root, full_corr_path_list[i])
    corr_matrices_full.append(np.loadtxt(path, delimiter=','))

In [6]:
#generate the label files
labels_full = np.loadtxt('ADNI_full/label_full.csv', dtype=str, delimiter=',')

In [7]:
##remove "SMC" values
idx = f.filter_SMC_patient_info()
corr_matrices = [corr_matrices_full[i] for i in idx]
labels = [labels_full[i] for i in idx]
for i in range(len(labels)):
    if labels[i] == 'CN':
        labels[i] = 0

    elif labels[i] == 'EMCI' or labels[i] == 'MCI' or labels[i] == 'LMCI':
        labels[i] = 1

    elif labels[i] == 'AD':
        labels[i] = 2

    else:
        print('Error: incorrect label')

In [4]:
#binary classification
cn = filter_group('CN')
ad = filter_group('AD')
bin_idx = sorted(cn + ad)
corr_matrices = [corr_matrices_full[i] for i in bin_idx]
labels = [labels_full[i] for i in bin_idx]
for i in range(len(labels)):
    if labels[i] == 'CN':
        labels[i] = 0

    elif labels[i] == 'AD':
        labels[i] = 1

    else:
        print('Error: incorrect label')

NameError: name 'filter_group' is not defined

In [8]:
assert len(labels) == len(corr_matrices)

In [10]:
class ADNI_dataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, threshold=0.4):
        self.threshold = threshold
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """ Converts raw data into GNN-readable format by constructing
        graphs out of connectivity matrices.

        """
        graphs=[]
        for i in range(len(corr_matrices)):
            corr_matrix = corr_matrices[i]
            n_rois = corr_matrix.shape[0]
            edge_matrix = np.zeros((n_rois,n_rois))
            for j in range(n_rois):
                for k in range(n_rois):
                    if np.abs(corr_matrix[ j , k ]) < self.threshold:
                        edge_matrix[ j , k ] = 0
                    else:
                        edge_matrix[ j , k ] = corr_matrix[ j , k]

            corr_matrix_nx = from_numpy_array(edge_matrix)

            deg_dict = dict(corr_matrix_nx.degree())
            bc_dict = nx.betweenness_centrality(corr_matrix_nx)
            cc_dict = nx.clustering(corr_matrix_nx)
            # Compute the global efficiency of the graph
            ge = nx.global_efficiency(corr_matrix_nx)

            le_dict = {}

            # loop over all nodes in the graph
            for node in corr_matrix_nx.nodes():
                # find the subgraph of neighbors of the current node
                subgraph = corr_matrix_nx.subgraph(corr_matrix_nx.neighbors(node))

                # calculate the efficiency of the subgraph
                if subgraph.number_of_nodes() > 1:
                    efficiency = nx.global_efficiency(subgraph)
                else:
                    efficiency = 0.0

                # store the efficiency in the dictionary
                le_dict[node] = efficiency


            # Compute the participation coefficient and ratio of local to global efficiency of each node
            ratio_le_ge = np.array(list(le_dict.values())) / ge

            # Convert the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency dictionaries to NumPy arrays
            deg_array = np.array(list(deg_dict.values()))
            bc_array = np.array(list(bc_dict.values()))
            le_array = np.array(list(le_dict.values()))



            cc_array = np.array(list(cc_dict.values()))
            ratio_le_ge_array = ratio_le_ge

            # Normalize the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency arrays to have zero mean and unit variance
            deg_array_norm = (deg_array - np.mean(deg_array)) / np.std(deg_array)
            bc_array_norm = (bc_array - np.mean(bc_array)) / np.std(bc_array)
            le_array_norm = (le_array - np.mean(le_array)) / np.std(le_array)
            ratio_le_ge_array_norm = (ratio_le_ge_array - np.mean(ratio_le_ge_array)) / np.std(ratio_le_ge_array)
            cc_array_norm = (cc_array - np.mean(cc_array)) / np.std(cc_array)

            # Concatenate the degree, participation coefficient, betweenness centrality, local efficiency, and ratio of local to global efficiency arrays to form a single feature vector
            x_conc = torch.tensor(np.concatenate((deg_array_norm, bc_array_norm, le_array_norm, cc_array_norm, ratio_le_ge_array_norm)), dtype=torch.float)
            x = torch.reshape(x_conc , (5 , n_rois)).T


            corr_matrix_data = from_networkx(corr_matrix_nx)
            corr_matrix_data.x = x
            corr_matrix_data.y = labels[i]
            #pcorr_matrix_data.pos = coordinates

            # Add to running list of all dataset items
            graphs.append(corr_matrix_data)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

78


In [11]:
dataset = ADNI_dataset('ADNI_0.5')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

In [12]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, dropout, input_dim ,output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, output_dim)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

Processing...



Dataset: ADNI_dataset(78):
Number of graphs: 78
Number of features: 5
Number of classes: 2

Data(edge_index=[2, 444], weight=[444], x=[116, 5], y=[1], num_nodes=116)
Number of nodes: 116
Number of edges: 444
Average node degree: 3.83
Has isolated nodes: True
Has self-loops: False
Is undirected: True


Done!


In [13]:
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
         #out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         out = model(data)
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.#
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [114]:
tot_test_acc = []
kf = StratifiedKFold(n_splits=5, shuffle=False)
for train_val_idx, test_idx in kf.split(dataset, dataset.data.y):
    X_train_val = [dataset[i] for i in train_val_idx]
    X_test      = [dataset[i] for i in test_idx]
    Y_train_val = [dataset.data.y[i] for i in train_val_idx]
    Y_test      = [dataset.data.y[i] for i in test_idx]

    X_train, X_valid, Y_train, Y_valid = train_test_split(X_train_val, Y_train_val , test_size=0.125,
                                                    random_state=42, stratify=Y_train_val)

    print(f'Number of training graphs: {len(X_train)}')
    print(f'Number of validation graphs: {len(X_valid)}')
    print(f'Number of test graphs: {len(X_test)}')

    lab = [0 , 1 , 2]
    class_freq = []
    for i in lab:
        class_freq.append(np.count_nonzero(torch.Tensor(Y_train) == i))
    class_freq = torch.FloatTensor(class_freq)
    class_weights = 1 / class_freq
    class_weights /= class_weights.sum()

    train_loader = DataLoader(X_train, batch_size=16, shuffle=True)
    valid_loader = DataLoader(X_valid, batch_size=len(X_valid), shuffle=True)
    test_loader = DataLoader(X_test, batch_size=len(X_test), shuffle=False)

    model = GCN(hidden_channels=8, dropout=0.5, input_dim=dataset.num_node_features, output_dim=dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    model.train()
    for epoch in range(200):
        total_loss = 0
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.num_graphs
            model.eval()
            correct = 0
            for data_v in valid_loader:
                out = model(data_v.x, data_v.edge_index, data_v.batch)
                pred = out.argmax(dim=1)
                print(pred)
                correct += int((pred == data_v.y).sum())
            acc = correct / len(valid_loader.dataset)
        print(f'Epoch {epoch}, Loss {total_loss / len(dataset)}, Valid Accuracy: {acc:.4f}')


    model.eval()
    correct = 0
    for data in test_loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    acc = correct / len(test_loader.dataset)
    print(f'Test Accuracy: {acc:.4f}')
    tot_test_acc.append(acc)

print(f'Average Test Accuracy: {sum(tot_test_acc) / len(tot_test_acc)}')
print(f'Max test accuracy: {max(tot_test_acc)}')
print(f'Standard Deviation: {np.std(tot_test_acc)}')



Number of training graphs: 119
Number of validation graphs: 18
Number of test graphs: 35
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.7808080738367036, Valid Accuracy: 0.2778
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 



tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.7628925394180209, Valid Accuracy: 0.2778
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.760597419600154, Valid Accuracy: 0.2778
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tenso

KeyboardInterrupt: 

In [16]:
def filter_group(group):
    df = pd.read_csv('/Users/georgepulickal/Documents/ADNI_FULL/patient_info.csv')
    labels = df['Research Group']
    label_idx_list = [i for i in range(len(labels)) if labels[i] == group]
    return label_idx_list



Number of training graphs: 54
Number of validation graphs: 8
Number of test graphs: 16
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.49248994772250837, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.4892475070097508, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 2, Loss 0.4872228579643445, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 3, Loss 0.48414854819958025, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 4, Loss 0.4840450



tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.4833725644991948, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.4811198543279599, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 2, Loss 0.4797940697425451, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 3, Loss 0.4801004712398236, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 4, Loss 0.4788733399831332, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0,



tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.48368711807788944, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 2, Loss 0.48224825583971465, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 3, Loss 0.4825303630951123, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 4, Loss 0.48192053727614576, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 5, Loss 0.48187375679994243, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0



tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.4924958424690442, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.49063305518566036, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 2, Loss 0.4902943044136732, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 3, Loss 0.4904091511017237, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 4, Loss 0.4899148757641132, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 5, Loss 0.49014



tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 0, Loss 0.5012639050300305, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 1, Loss 0.4964125591974992, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 2, Loss 0.49446819455195695, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 3, Loss 0.4952301772741171, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Epoch 4, Loss 0.4932329081572019, Valid Accuracy: 0.6250
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0

In [132]:
cn = filter_group('CN')
ad = filter_group('AD')
print(len(ad))

31
