In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, BatchNorm, global_add_pool, ChebConv, global_max_pool, SAGPooling, GATConv, GATv2Conv, TransformerConv, SuperGATConv, global_mean_pool, Linear
from torch.nn import BatchNorm1d
from math import floor
import torch
import torch_geometric
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from pathlib import Path

In [2]:
class MultiLevelConvNet(nn.Module):
    """Same as EEGGraphConvNet but with fewer 
    convolutional layers
    """
    def __init__(self, **kwargs):
        super(MultiLevelConvNet, self).__init__()
        # Layers definition
        # Graph convolutional layers
        self.conv1 = GCNConv(-1, 32, cached=True, normalize=False)
        self.conv2 = GCNConv(32, 32, cached=True, normalize=False)
        self.conv3 = GCNConv(32, 64, cached=True, normalize=False)
        
        
        # Batch normalization
        self.batch_norm1 = BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.batch_norm2 = BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.batch_norm3 = BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        
        
        self.fc1 = nn.Linear(32, 64)
        self.fc2 = nn.Linear(32, 64)
        self.fc3 = nn.Linear(64, 64)
        
        # Fully connected layers
        self.classifier = nn.Sequential(
            nn.Linear(192, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 2),
        )
        
        # Xavier initializacion for fully connected layers
        self.fc1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        self.fc3.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1) if isinstance(x, nn.Linear) else None)
        
        
    def forward(self, x, edge_index, edge_weigth, batch):
        x1 = F.leaky_relu(self.batch_norm1(self.conv1(x, edge_index, edge_weigth)), negative_slope=0.01)
        x2 = F.leaky_relu(self.batch_norm2(self.conv2(x1, edge_index, edge_weigth)), negative_slope=0.01)
        x3 = F.leaky_relu(self.batch_norm3(self.conv3(x2, edge_index, edge_weigth)), negative_slope=0.01)
        
        add_pool1 = global_add_pool(x1, batch=batch)
        add_pool2 = global_add_pool(x2, batch=batch)
        add_pool3 = global_add_pool(x3, batch=batch)
        
        out1 = F.leaky_relu(self.fc1(add_pool1), negative_slope=0.01)        
        out2 = F.leaky_relu(self.fc2(add_pool2), negative_slope=0.01)        
        out3 = F.leaky_relu(self.fc3(add_pool3), negative_slope=0.01)
        
        out = torch.cat((out1, out2, out3), dim=1)        
        out = self.classifier(out)
        return out

In [3]:
def create_corrected_data_list(path):
    data_list = list()
    for file in path.iterdir():
        data_list.append(torch.load(file))
    corrected_data_list = list()
    for data in data_list:
    # print(data)
        data = torch_geometric.data.Data(
            x=torch.tensor(data.x),
            edge_index=torch.tensor(data.edge_index),
            edge_attr=torch.tensor(data.edge_attr),
            label=torch.tensor(data.label),
        )
        corrected_data_list.append(data)
    
    rm = [
      7,
      14+1,
      14+2,
      17+3,
      17+4,
      26+5,
      38+6,
      54+7,
      65+8,
      69+9
      ]

    dl = list()
    start = 0
    for r in rm:
        dl.extend(corrected_data_list[start:r])
        start = r + 1

    dl.extend(corrected_data_list[start:])
    dl_filterd = list()
    for data in dl:
        if data.label == 2:
            # print(data.label)
            # if data.label == 2:
            data.label = torch.tensor(1)
        dl_filterd.append(data)

    len(dl_filterd)
    return dl_filterd
            
    

In [4]:
path = Path('graphs/moments_pearson/')
dl_filterd = create_corrected_data_list(path)
train_dl, test_dl = train_test_split(dl_filterd, test_size=0.2, random_state=47744)
train_dataloader = torch_geometric.loader.DataLoader(dl_filterd, batch_size=1, shuffle=False, num_workers=0)
test_dataloader = torch_geometric.loader.DataLoader(test_dl, batch_size=10, shuffle=False, num_workers=0)

  edge_index=torch.tensor(data.edge_index),


In [6]:
model = MultiLevelConvNet()

In [7]:
dl_filterd

[Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=[19, 19], label=1),
 Data(x=[19, 6], edge_index=[2, 361], edge_attr=

In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)
criterion = torch.nn.CrossEntropyLoss()

print(f"Model Params: {sum(p.numel() for p in model.parameters())}")
model.double()

ValueError: Attempted to use an uninitialized parameter in <method 'numel' of 'torch._C._TensorBase' objects>. This error happens when you are using a `LazyModule` or explicitly manipulating `torch.nn.parameter.UninitializedParameter` objects. When using LazyModules Call `forward` with a dummy batch to initialize the parameters before calling torch functions