In [1]:
import torch
from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes

import igl
from igl import doublearea,grad
import numpy as np
import pickle
import os

import scipy.sparse as sparse
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [None]:

def save_checkpoint(epoch, model, optimizer, scheduler, save_cur=False, save_dir='.'):
    """Save model/optimizer/scheduler state to ``save_dir``.

    Args:
        epoch: current epoch number (stored in the checkpoint).
        model, optimizer, scheduler: objects whose ``state_dict()`` is saved.
        save_cur: if True, write ``ckpt_best_val.pth``; otherwise a checkpoint is
            written only every 100 epochs as ``ckpt_epoch_<epoch>.pth``.
        save_dir: target directory (defaults to the current directory).

    Relies on a module-level ``logger`` (created in the training cell).
    """
    logger.info('==> Saving...')
    if save_cur:
        save_path = os.path.join(save_dir, 'ckpt_best_val.pth')
    elif epoch % 100 == 0:
        save_path = os.path.join(save_dir, f'ckpt_epoch_{epoch}.pth')
    else:
        print("not saving checkpoint")
        return

    # Only collect state dicts when something is actually written.
    state = {
        'save_path': save_path,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_path)
    logger.info("Saved in {}".format(save_path))


In [3]:
import math
import pickle, gzip
import os

import torch
from torch import nn
from torch.nn.parameter import Parameter


# from utils import sparse2tensor, spmatmul
def sparse2tensor(m):
    """Convert a scipy.sparse matrix/array to a float32 torch sparse COO tensor.

    Any scipy sparse format is accepted; it is converted with ``tocoo()`` first.
    Indices are int64 and values float32. The result is not coalesced, so
    duplicate entries of ``m`` are kept as-is.

    Raises:
        TypeError: if ``m`` is not a scipy sparse matrix/array.
    """
    if not sparse.issparse(m):
        raise TypeError("sparse2tensor expects a scipy.sparse matrix, got {}".format(type(m).__name__))
    coo = m.tocoo()
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    # torch.tensor copies, so the result never aliases the scipy buffer.
    values = torch.tensor(coo.data, dtype=torch.float32)
    return torch.sparse_coo_tensor(indices, values, tuple(coo.shape))

def spmatmul(den, sp):
    """Apply a sparse operator along the last (vertex) axis of a batched dense tensor.

    Args:
        den: dense tensor of shape (batch_size, in_chan, #V), e.g. (1, 40961, 4).
        sp: sparse tensor of shape (new_len, #V). The dense operand is promoted
            to float64, so ``sp`` must be float64 as well.

    Returns:
        float64 tensor of shape (batch_size, in_chan, new_len).
    """
    batch_size, in_chan, nv = list(den.size())
    new_len = sp.size()[0]
    # Fold batch and channels into columns so one spmm handles everything.
    columns = den.permute(2, 1, 0).contiguous().view(nv, -1)
    product = torch.spmm(sp, columns.double())
    unfolded = product.view(new_len, in_chan, batch_size).contiguous()
    return unfolded.permute(2, 1, 0)

class _MeshConv(nn.Module):
    """Base class for two-hemisphere mesh convolutions.

    Holds the learnable coefficients and bias, and registers each hemisphere's
    mesh operators (from a left and a right mesh pickle) as buffers:
    ``l_G``/``r_G`` (gradient V->F, 3#F x #V), ``l_NS``/``r_NS`` and
    ``l_EW``/``r_EW`` (per-face north-south / east-west vector fields, #F x 3).
    Channels are split evenly between the hemispheres, so coefficient tensors use
    ``in_channels // 2`` and ``out_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, l_mesh_file, r_mesh_file, stride=1, bias=True):
        assert stride in [1, 2]
        super(_MeshConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.l_ncoeff = 2
        self.l_coeffs = Parameter(torch.Tensor(out_channels // 2, in_channels // 2, self.l_ncoeff))

        self.r_ncoeff = 2
        self.r_coeffs = Parameter(torch.Tensor(out_channels // 2, in_channels // 2, self.r_ncoeff))

        self.l_l_coeffs = Parameter(torch.Tensor(in_channels // 2, in_channels // 2))
        self.r_l_coeffs = Parameter(torch.Tensor(in_channels // 2, in_channels // 2))

        # Initialise before loading meshes so the seeded RNG draw order is unchanged.
        self.set_coeffs()

        # Load each mesh file exactly once; subclasses read further keys from these dicts.
        self.l_pkl = self._load_mesh(l_mesh_file)
        self.r_pkl = self._load_mesh(r_mesh_file)

        self.l_nv = self._register_mesh_buffers('l', self.l_pkl)
        self.r_nv = self._register_mesh_buffers('r', self.r_pkl)

    @staticmethod
    def _load_mesh(path):
        """Unpickle a mesh file, closing the handle afterwards."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted mesh files.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def _register_mesh_buffers(self, prefix, pkl):
        """Register ``<prefix>_G``/``_NS``/``_EW`` buffers and return the vertex count."""
        G = sparse2tensor(pkl['G'].tocoo())  # gradient matrix V->F, 3#F x #V
        NS = torch.tensor(pkl['NS'], dtype=torch.float32)  # north-south vector field, #F x 3
        EW = torch.tensor(pkl['EW'], dtype=torch.float32)  # east-west vector field, #F x 3
        self.register_buffer(prefix + "_G", G)
        self.register_buffer(prefix + "_NS", NS)
        self.register_buffer(prefix + "_EW", EW)
        return pkl['V'].shape[0]

    def set_coeffs(self):
        """Uniform init in [-1/sqrt(in_channels * 2), +1/sqrt(in_channels * 2)] for all weights and the bias."""
        n = self.in_channels * self.l_ncoeff

        stdv = 1. / math.sqrt(n)
        self.l_coeffs.data.uniform_(-stdv, stdv)
        self.r_coeffs.data.uniform_(-stdv, stdv)
        self.l_l_coeffs.data.uniform_(-stdv, stdv)
        self.r_l_coeffs.data.uniform_(-stdv, stdv)

        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

class MeshConv(_MeshConv):
    """Two-hemisphere mesh convolution mixing identity and Laplacian features.

    The input's channel axis is split in half: the first half is convolved on the
    left mesh and the second half on the right mesh. For each hemisphere the
    per-vertex features ``[x, L @ x]`` are mixed with learnable coefficients,
    optionally smoothed once more with ``L`` (``agg=True``). The two halves are
    concatenated (left first), the bias is added, and the result is returned as
    float32.

    Only ``stride=1`` is supported.
    """

    # Operators cast to float64 on every forward: spmatmul promotes its dense operand to
    # float64, so the sparse side must match. Assigning to a registered buffer name
    # updates the buffer, so after the first call these casts are no-ops.
    _OPERATOR_BUFFERS = ("l_G", "l_L", "l_EW", "l_NS", "l_F2V",
                         "r_G", "r_L", "r_EW", "r_NS", "r_F2V")

    def __init__(self, in_channels, out_channels, l_mesh_file, r_mesh_file, stride=1, bias=True, agg=False):
        if stride == 2:
            # The stride-2 path was never ported to the two-hemisphere layout.
            raise NotImplementedError("MeshConv: stride=2 is not supported for two-hemisphere meshes")
        super(MeshConv, self).__init__(in_channels, out_channels, l_mesh_file, r_mesh_file, stride, bias)
        l_pkl = self.l_pkl
        r_pkl = self.r_pkl

        self.l_nv_prev = l_pkl['V'].shape[0]
        self.r_nv_prev = r_pkl['V'].shape[0]

        # pkl['L'] is registered as-is, so the pickle must already hold a torch tensor.
        self.register_buffer("l_L", l_pkl['L'])
        self.register_buffer("l_F2V", sparse2tensor(l_pkl['F2V'].tocoo()))

        self.register_buffer("r_L", r_pkl['L'])
        self.register_buffer("r_F2V", sparse2tensor(r_pkl['F2V'].tocoo()))
        self.agg = agg

    def _hemisphere(self, x, L, nv_prev, coeffs):
        """Mix [identity, Laplacian] features of one hemisphere: (B, C/2, V) -> (B, out/2, V)."""
        laplacian = spmatmul(x, L)
        identity = x[..., :nv_prev]
        feat = torch.stack([identity, laplacian], dim=-1)
        # (B, 1, C/2, V, 2) * (out/2, C/2, 1, 2) -> sum over input channels and features.
        out = torch.sum(torch.sum(torch.mul(feat.unsqueeze(1), coeffs.unsqueeze(2)), dim=2), dim=-1)
        if self.agg:
            out = spmatmul(out, L)
        return out

    def forward(self, input):
        for name in self._OPERATOR_BUFFERS:
            setattr(self, name, getattr(self, name).double())

        half_dim = input.shape[1] // 2
        out = self._hemisphere(input[:, :half_dim, :], self.l_L, self.l_nv_prev, self.l_coeffs)
        out2 = self._hemisphere(input[:, half_dim:, :], self.r_L, self.r_nv_prev, self.r_coeffs)

        out = torch.cat([out, out2], dim=1)
        if self.bias is not None:
            out += self.bias.unsqueeze(-1)
        return out.float()


class MeshConv_transpose(_MeshConv):
    """Transposed mesh convolution that pads the input up to the next mesh resolution (legacy, single-mesh).

    NOTE(review): this class is stale relative to _MeshConv and cannot run as written:
      * _MeshConv.__init__ takes (in_channels, out_channels, l_mesh_file,
        r_mesh_file, stride=1, bias=True); the super() call below passes
        `stride` as r_mesh_file and `bias` as stride.
      * _MeshConv only defines l_/r_-prefixed attributes (l_pkl, l_nv, l_G,
        l_NS, l_EW, l_coeffs, ...), not the self.pkl / self.nv / self.G /
        self.NS / self.EW / self.coeffs read here.
    It is not referenced elsewhere in this section; port it or delete it.
    """
    def __init__(self, in_channels, out_channels, mesh_file, stride=2, bias=True):
        assert(stride == 2)
        super(MeshConv_transpose, self).__init__(in_channels, out_channels, mesh_file, stride, bias)
        pkl = self.pkl
        self.nv_prev = self.pkl['nv_prev']
        # number of vertices added when going from the coarse to the fine mesh
        self.nv_pad = self.nv - self.nv_prev
        L = sparse2tensor(pkl['L'].tocoo()) # laplacian matrix V->V
        F2V = sparse2tensor(pkl['F2V'].tocoo()) # F->V, #V x #F
        self.register_buffer("L", L)
        self.register_buffer("F2V", F2V)
        
    def forward(self, input):
        # NOTE(review): pads with ONES (torch.ones), not zeros, up to the next mesh
        # resolution -- confirm which was intended.
        ones_pad = torch.ones(*input.size()[:2], self.nv_pad).to(input.device)
        input = torch.cat((input, ones_pad), dim=-1)
        # compute gradient
        grad_face = spmatmul(input, self.G)
        grad_face = grad_face.view(*(input.size()[:2]), 3, -1).permute(0, 1, 3, 2) # gradient, 3 component per face
        laplacian = spmatmul(input, self.L)
        identity = input
        # project face gradients onto the east-west / north-south fields, then map faces -> vertices
        grad_face_ew = torch.sum(torch.mul(grad_face, self.EW), keepdim=False, dim=-1)
        grad_face_ns = torch.sum(torch.mul(grad_face, self.NS), keepdim=False, dim=-1)
        grad_vert_ew = spmatmul(grad_face_ew, self.F2V)
        grad_vert_ns = spmatmul(grad_face_ns, self.F2V)

        # four feature maps per vertex, mixed by a (out, in, 4) coefficient tensor
        feat = [identity, laplacian, grad_vert_ew, grad_vert_ns]
        out = torch.stack(feat, dim=-1)
        out = torch.sum(torch.sum(torch.mul(out.unsqueeze(1), self.coeffs.unsqueeze(2)), dim=2), dim=-1)
        out += self.bias.unsqueeze(-1)
        return out


In [None]:

from sklearn.metrics import confusion_matrix
import logging
def train(model, device, train_loader, optimizer, epoch, logger):
    """Run one training epoch and log the per-sample mean loss and accuracy.

    Args:
        model: classifier returning (batch, n_classes) scores.
        device: torch device to move batches to.
        train_loader: yields (data, target); data is permuted (0, 2, 1) before the
            model call (presumably (batch, #V, channels) -> (batch, channels, #V)).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch (unused, kept for the call signature).
        logger: logging.Logger used for the summary line.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        data = data.permute(0, 2, 1)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        loss = F.cross_entropy(output, target)
        # cross_entropy returns the batch mean; weight by batch size for a true per-sample mean.
        batch_size = target.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
        loss.backward()
        optimizer.step()
    # Count samples actually seen (the sampler may differ from len(dataset)).
    denom = max(seen, 1)
    logger.info('Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%) \r'.format(
                loss_sum / denom, correct, seen,
                100. * correct / denom))
    

def val(model, device, test_loader, logger, best):
    """Evaluate on the validation loader, log results and record accuracy.

    Args:
        model, device: as in ``train``.
        test_loader: validation loader yielding (data, target).
        logger: logging.Logger for the summary lines.
        best: list of past validation accuracies; the current one is appended in place.

    Returns:
        (best, acc, mean_loss): the updated list, this epoch's accuracy in [0, 1],
        and the per-sample mean cross-entropy.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.permute(0, 2, 1)
            output = model(data)
            batch_size = target.size(0)
            # weight the batch-mean loss so the final value is a per-sample mean
            loss_sum += F.cross_entropy(output, target).item() * batch_size
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += batch_size

    denom = max(seen, 1)
    test_loss = loss_sum / denom
    acc = correct / denom
    logger.info('Val set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%) \r'.format(
                test_loss, correct, seen, 100. * acc))

    best.append(acc)
    logger.info('Best Val Accuracy:({:.4f}%) \r'.format(100. * max(best)))
    return best, acc, test_loss

def test(model, device, test_loader, logger, best, append = False):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data= data.permute(0,2,1)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            
            # Calculate the confusion matrix
            conf_matrix = confusion_matrix(torch.flatten(target).cpu().numpy(), torch.flatten(pred).cpu().numpy())
            # Extract values from the confusion matrix
            true_negatives, false_positives, false_negatives, true_positives = conf_matrix.ravel()
            # Calculate sensitivity (recall)
            sensitivity = true_positives / (true_positives + false_negatives)
            specifity = true_negatives / (true_negatives + false_positives)

            # Print the confusion matrix and sensitivity
            print("Sensitivity (Recall):", sensitivity)
            print("Sensitivity (Recall):", specifity)

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    
    # if acc >= 0.89: 
    logger.info('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%) \r'.format(
                    test_loss, correct, len(test_loader.dataset),
                    100. * correct / len(test_loader.dataset)))

    if append:
            best.append(correct / len(test_loader.dataset))
            logger.info('Best Test Accuracy:({:.4f}%) \r'.format(100. * max(best)))
            
    return best, acc

In [None]:
 

import torch
import torch.nn as nn
import torch.nn.functional as F
import sys; 
import os
import numpy as np

class MaxPool(nn.Module):
    """Max-pool vertex features over precomputed neighbourhood patches.

    ``forward`` keeps the first ``nv_prev`` vertices, gathers each patch's vertex
    indices from ``l_<level>_patches.npy`` and takes the max over the patch.

    Args:
        level: mesh level whose patch file is loaded; must be > 0.
        nv_prev: number of leading vertices kept before pooling.
        patches_dir: directory containing the ``l_<level>_patches.npy`` files.
    """

    def __init__(self, level, nv_prev, patches_dir="/home/justin/Mesh/Matrices_out"):
        super().__init__()
        self.level = level
        self.nv_prev = nv_prev
        # attribute name kept (including its spelling) for compatibility with existing code
        self.neihboring_patches = None

        if self.level > 0:
            patches_file = os.path.join(patches_dir, "l_%d_patches.npy" % level)
            self.neihboring_patches = np.load(patches_file)

    def forward(self, x):
        if self.neihboring_patches is None:
            raise RuntimeError("MaxPool(level=%d) has no neighbourhood patches; level must be > 0" % self.level)
        tmp = x[..., :self.nv_prev]
        # (B, C, #patches, patch_size) -> max over each patch
        out, _ = torch.max(tmp[:, :, self.neihboring_patches], -1)

        return out

class DownSamp(nn.Module):
    """Node/edge/face message-passing block on one mesh level.

    With ``e2n`` true, the block works in four stages:
      1. Vertex features are lifted to edges (0.5 * each endpoint) and faces
         (0.3 * each corner).
      2. Features are mixed with sparse operators loaded from
         ``<matrix_dir>/<hami>_lvl_<lvl+1>.pkl``: the B1/B1t/B2/B2t incidence-like
         and L0u/L1d/L1u/L2d Laplacian-like matrices, as named in the pickle.
      3. The mixed features are pushed back to vertices.
      4. The output is truncated to the first ``nv_prev`` vertices.
    With ``e2n`` false, the block only truncates.

    Args:
        nv_prev: number of leading vertices kept in the output.
        lvl: mesh level; the operator file for ``lvl + 1`` is used.
        in_chan: input channels. Output channels are ``2 * in_chan`` when
            ``down`` is true, else ``in_chan``.
        hami: hemisphere tag used in the operator filename (e.g. 'l' / 'r').
        e2n: enable message passing (otherwise the block only truncates).
        down: widen channels (True) or keep them and add a learnable affine skip (False).
        matrix_dir: directory containing the operator pickles.
    """

    # Sparse operators read from the pickle; cached as float64 on the input's device.
    _SPARSE_KEYS = ('B1', 'B1t', 'B2', 'B2t', 'L1d', 'L1u', 'L2d', 'L0u')

    def __init__(self, nv_prev, lvl, in_chan, hami, e2n=False, down=True,
                 matrix_dir='/home/justin/Mesh/Matrices_out'):
        super().__init__()
        self.nv_prev = nv_prev
        self.lvl = lvl + 1
        self.e2n = e2n
        if e2n:
            out_chan = 2 * in_chan if down else in_chan
            # Layers are created in the same order as before, so seeded weight
            # initialisation and state_dict keys are unchanged.
            # face-level
            self.e2f = nn.Linear(in_chan, out_chan, bias=True)
            self.ff = nn.Linear(in_chan, out_chan, bias=True)
            self.f = nn.Linear(in_chan, out_chan, bias=True)

            # edge-level
            self.ee1 = nn.Linear(in_chan, out_chan, bias=True)
            self.ee2 = nn.Linear(in_chan, out_chan, bias=True)
            self.f2e = nn.Linear(out_chan, out_chan, bias=True)
            self.n2e = nn.Linear(in_chan, out_chan, bias=True)
            self.e = nn.Linear(in_chan, out_chan, bias=True)

            # node-level
            self.nn = nn.Linear(in_chan, out_chan, bias=True)
            self.e22n = nn.Linear(out_chan, out_chan, bias=True)
            self.n = nn.Linear(in_chan, out_chan, bias=True)

            if not down:
                self.affine_a = nn.Parameter(torch.ones([1, in_chan, 1]))
                self.affine_b = nn.Parameter(torch.zeros([1, in_chan, 1]))

            self.f_bn1 = nn.BatchNorm1d(out_chan)
            self.e_bn1 = nn.BatchNorm1d(out_chan)
            self.n_bn1 = nn.BatchNorm1d(out_chan)

        self.hami = hami
        self.ch = in_chan
        self.down = down
        self.matrix_dir = matrix_dir
        # Operator cache (plain attributes, not buffers, so state_dict is unaffected).
        self._ops = None
        self._ops_device = None

    def _operators(self, device):
        """Load this level's operators once per device (float64 sparse + long index tensors)."""
        if self._ops is None or self._ops_device != device:
            path = os.path.join(self.matrix_dir, '{}_lvl_{}.pkl'.format(self.hami, self.lvl))
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(path, "rb") as fh:
                raw = pickle.load(fh)
            ops = {key: raw[key].to(device).double() for key in self._SPARSE_KEYS}
            ops['E'] = torch.as_tensor(raw['E']).long().to(device)
            ops['F'] = torch.as_tensor(raw['F']).long().to(device)
            self._ops, self._ops_device = ops, device
        return self._ops

    @staticmethod
    def _chan_linear(layer, t):
        # nn.Linear acts on the last axis: (B, C, N) -> (B, N, C) -> layer -> (B, C', N)
        return layer(t.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, x):
        if not self.e2n:
            return x[..., :self.nv_prev]

        ops = self._operators(x.device)
        E, Fa = ops['E'], ops['F']
        lin = self._chan_linear

        # Lift vertex features onto edges and faces (vectorised over the batch).
        edge_features = (0.5 * x[:, :, E[:, 0]]) + (0.5 * x[:, :, E[:, 1]])
        face_features = (0.3 * x[:, :, Fa[:, 0]]) + (0.3 * x[:, :, Fa[:, 1]]) + (0.3 * x[:, :, Fa[:, 2]])

        # face-level
        FF = spmatmul(lin(self.ff, face_features), ops['L2d'])
        E2F = spmatmul(lin(self.e2f, edge_features), ops['B2t'])
        face_features = F.relu(self.f_bn1((FF + E2F + lin(self.f, face_features)).float()))

        # edge-level (uses the updated face features)
        EE1 = spmatmul(lin(self.ee1, edge_features), ops['L1d'])
        EE2 = spmatmul(lin(self.ee2, edge_features), ops['L1u'])
        F2E = spmatmul(lin(self.f2e, face_features), ops['B2'])
        N2E = spmatmul(lin(self.n2e, x), ops['B1t'])
        edge_features = F.relu(self.e_bn1((EE1 + EE2 + F2E + N2E + lin(self.e, edge_features)).float()))

        # node-level (uses the updated edge features)
        NN = spmatmul(lin(self.nn, x), ops['L0u'])
        E2N = spmatmul(lin(self.e22n, edge_features), ops['B1'])
        node_sum = (NN + E2N + lin(self.n, x)).float()

        if self.down:
            final_node = F.relu(self.n_bn1(node_sum))
        else:
            final_node = F.relu(self.n_bn1(node_sum) + self.affine_a * x + self.affine_b)

        return final_node[..., :self.nv_prev].float()

class ResBlock(nn.Module):
    """Two-hemisphere coarsening block built from three stacked DownSamp stages per hemisphere.

    The input channel axis is split in half: the first half goes through the
    left stack (``seq1``) and the second through the right stack (``r_seq1``).
    The results are concatenated (left first) and passed through ReLU. Despite
    the name, ``forward`` has no residual/skip connection.

    ``in_chan``, ``neck_chan`` and ``out_chan`` are totals over both hemispheres
    and are halved internally. Only ``coarsen=True`` is implemented.
    """

    def __init__(self, in_chan, neck_chan, out_chan, level, coarsen, mesh_folder, e2n = False):
        super().__init__()
        if not coarsen:
            # The non-coarsening path referenced conv1/bn1/conv3/... which were never defined.
            raise NotImplementedError("ResBlock: only coarsen=True is implemented")

        in_chan = int(in_chan // 2)
        neck_chan = int(neck_chan // 2)
        out_chan = int(out_chan // 2)

        l = level - 1 if coarsen else level
        self.coarsen = coarsen
        mesh_file = os.path.join(mesh_folder, "left_icosphere_{}.pkl".format(l))
        mesh_file2 = os.path.join(mesh_folder, "right_icosphere_{}.pkl".format(l))
        # conv2 is only used to read the level-l vertex count, but it stays a registered
        # submodule so parameter counts, seeded init and checkpoint keys are unchanged.
        self.conv2 = MeshConv(neck_chan, neck_chan, mesh_file, mesh_file2, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.nv_prev = self.conv2.l_nv_prev

        # Left ('l') and right ('r') stacks: two channel-preserving stages, then one widening stage.
        self.down1 = DownSamp(self.nv_prev, l, in_chan, 'l', e2n, down=False)
        self.down11 = DownSamp(self.nv_prev, l - 1, in_chan, 'l', e2n, down=False)
        self.down111 = DownSamp(self.nv_prev, l - 1, in_chan, 'l', e2n)
        self.down3 = DownSamp(self.nv_prev, l, in_chan, 'r', e2n, down=False)
        self.down33 = DownSamp(self.nv_prev, l - 1, in_chan, 'r', e2n, down=False)
        self.down333 = DownSamp(self.nv_prev, l - 1, in_chan, 'r', e2n)

        self.diff_chan = (in_chan != out_chan)

        self.seq1 = nn.Sequential(self.down1, self.down11, self.down111)
        self.r_seq1 = nn.Sequential(self.down3, self.down33, self.down333)

    def forward(self, x):
        half = x.shape[1] // 2
        l_x = x[:, :half, :]
        r_x = x[:, half:, :]
        x1 = self.seq1(l_x)
        r_x1 = self.r_seq1(r_x)
        out = torch.cat([x1, r_x1], dim=1)
        return self.relu(out)
class ConvBNReLU1D(nn.Module):
    """Thin wrapper around a single ``nn.Conv1d`` (pointwise by default).

    Despite the name, no BatchNorm or activation is applied, and the
    ``activation`` argument is accepted but ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
        super(ConvBNReLU1D, self).__init__()
        conv = nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.net = nn.Sequential(conv)

    def forward(self, x):
        return self.net(x)


def make_weights(labels, nclasses):
    """Per-sample inverse-class-frequency weights (e.g. for WeightedRandomSampler).

    Each sample with label ``c`` in ``range(nclasses)`` gets ``1 / count(c)``.
    Labels outside that range get weight 0. Classes absent from ``labels`` are
    handled correctly; counts are not taken positionally from ``np.unique``.

    Args:
        labels: 1-D sequence/array of integer class labels.
        nclasses: number of classes ``0 .. nclasses - 1``.

    Returns:
        float numpy array with one weight per sample.
    """
    labels = np.array(labels)
    weight_arr = np.zeros(labels.shape, dtype=float)
    for cls in range(nclasses):
        mask = labels == cls
        count = np.count_nonzero(mask)
        if count:
            weight_arr[mask] = 1.0 / count
    return weight_arr

class Model(nn.Module):
    """Two-hemisphere mesh classifier.

    Pipeline:
      1. MeshConv on the level-6 left/right icospheres, then BatchNorm and LeakyReLU.
      2. MaxPool to the first 10242 vertices.
      3. Four coarsening ResBlocks (levels 5 -> 2).
      4. Per-hemisphere global average pooling.
      5. An MLP head, then sigmoid.

    Args:
        mesh_folder: directory with ``{left,right}_icosphere_<level>.pkl``.
        feat: base channel width.
        nclasses: NOTE(review): currently unused -- the head always outputs 2 scores.

    Input is expected as (batch, 2, #V). ``in_conv`` has 2 input channels, split
    one per hemisphere (presumably left then right -- confirm against the data).
    """

    def __init__(self, mesh_folder, feat=16, nclasses=10):
        super().__init__()
        mf_l = os.path.join(mesh_folder, "left_icosphere_6.pkl")
        mf_r = os.path.join(mesh_folder, "right_icosphere_6.pkl")

        # Construction order is significant for seeded initialisation; keep it.
        self.in_conv = MeshConv(2, 2 * feat, l_mesh_file=mf_l, r_mesh_file=mf_r, stride=1, agg=False)
        self.in_bn = nn.BatchNorm1d(2 * feat)
        self.relu = nn.LeakyReLU(inplace=True)
        self.in_block = nn.Sequential(self.in_conv, self.in_bn, self.relu)

        self.maxp = MaxPool(5, 10242)

        L_relu = nn.LeakyReLU()
        self.block2 = ResBlock(in_chan=2 * feat, neck_chan=4 * feat, out_chan=4 * feat, level=5, coarsen=True, mesh_folder=mesh_folder, e2n=True)
        self.block3 = ResBlock(in_chan=4 * feat, neck_chan=8 * feat, out_chan=8 * feat, level=4, coarsen=True, mesh_folder=mesh_folder, e2n=True)
        self.block4 = ResBlock(in_chan=8 * feat, neck_chan=16 * feat, out_chan=16 * feat, level=3, coarsen=True, mesh_folder=mesh_folder, e2n=True)
        self.block5 = ResBlock(in_chan=16 * feat, neck_chan=32 * feat, out_chan=32 * feat, level=2, coarsen=True, mesh_folder=mesh_folder, e2n=True)

        self.avg = nn.AvgPool1d(kernel_size=self.block5.nv_prev)  # output shape batch x channels x 1
        self.D = nn.Sequential(nn.Linear(32 * feat, 32 * feat), L_relu,
                               nn.Linear(32 * feat, 8 * feat), L_relu,
                               nn.Linear(8 * feat, 2))

    def forward(self, x):
        x = self.in_block(x)
        x = self.maxp(x).float()

        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)

        # Squeeze only the pooled axis: a bare squeeze() also drops the batch axis
        # when batch size is 1, which breaks the concatenation below.
        half = x.shape[1] // 2
        x_l = self.avg(x[:, :half, :]).squeeze(-1)
        x_r = self.avg(x[:, half:, :]).squeeze(-1)
        x = torch.cat([x_l, x_r], dim=1)
        x = self.D(x)
        # NOTE(review): callers feed this into F.cross_entropy, which expects raw logits;
        # the sigmoid squashes scores into (0, 1) and weakens gradients. Argmax is unaffected.
        return torch.sigmoid(x)



import torch
import random
from sklearn.model_selection import train_test_split, StratifiedKFold

random_seed = 54

torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)



# X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
# KFold(n_splits = k, shuffle = True, random_state = 1)

k = 5     # k-cross validation에서 k의 값이 5
test_bests = []
val_test= []
# skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=65)
skf = StratifiedShuffleSplit(n_splits=k, test_size = 0.2, train_size = 0.8, random_state = 42)
for i, (trainval_idx, test_idx) in enumerate(skf.split(X, y)):
    # Split data into training and test sets for the current fold
    X_trainval, X_test = X[trainval_idx], X[test_idx]
    y_trainval, y_test = y[trainval_idx], y[test_idx]
    
    # Further split the training set into training and validation sets
    print(len(y_test), len(X))
    val_po = (len(X) * 0.1) / (len(X) * 0.8)
    train_po = 1 - val_po
    
    skf = StratifiedShuffleSplit(n_splits=1, test_size = val_po, train_size = train_po)

    for i, (train_idx, val_idx) in enumerate(skf.split(X_trainval, y_trainval)):
        X_train, X_val = X_trainval[train_idx], X_trainval[val_idx]
        y_train, y_val = y_trainval[train_idx], y_trainval[val_idx]


    from torch.utils.data import DataLoader
    train_list = []
    for i,j in enumerate(list(X_train)):
        train_list.append([j,list(y_train)[i]])

    val_list = []
    for i,j in enumerate(list(X_val)):
        val_list.append([j,list(y_val)[i]])


    test_list = []
    for i,j in enumerate(list(X_test)):
        test_list.append([j,list(y_test)[i]])

    print(len(train_list), len(val_list), len(test_list))

    weights = make_weights(y_train, 2)
    weights = torch.DoubleTensor(weights)
    print(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    train_loader = DataLoader(train_list, batch_size = 32, shuffle = False, sampler = sampler)
    val_loader = DataLoader(val_list, batch_size = len(val_list))
    test_loader = DataLoader(test_list, batch_size = len(test_list))


    # train_loader
    train_features, train_labels = next(iter(train_loader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")

    val_features, val_labels = next(iter(val_loader))
    print(f"Feature batch shape: {val_features.size()}")
    print(f"Labels batch shape: {val_labels.size()}")

    test_features, test_labels = next(iter(test_loader))
    print(f"Feature batch shape: {test_features.size()}")
    print(f"Labels batch shape: {test_labels.size()}")

    
    from torch.utils.data import DataLoader
    import torch.optim as optim
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
    from torch.optim.lr_scheduler import CosineAnnealingLR

    %cd /home/justin/Mesh
    #Model parameters: template path, feature dimension, output class number
    model = Model(".",16, 2)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # optimizer = optim.SGD(model.parameters(), lr = 0.05, momentum = 0.9)
    # scheduler = CosineAnnealingLR(optimizer,
                                    # T_max = 100, # Maximum number of iterations.
                                    # eta_min = 0.01 / 50) # Minimum learning rate.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.7)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3,9,12,15,18], gamma=0.6)
    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join('.', "0911last_decay.txt"))
    logger.addHandler(fh)
    test_acc = None
    logger.info("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
    best = []
    tbest = []
    best_loss = 0
    tolerance = 0

    for epoch in range(1, 50 + 1):
            logger.info("[Epoch {}]".format(epoch))
            train(model, device, train_loader, optimizer, epoch, logger)
            #validate every 100 epoch
            # if epoch % 50 == 0:
            best, acc, val_loss= val( model, device, val_loader, logger,best)
            # if test_acc is not None:
                    #   print(test_acc)
            if acc >= best_loss:
                best_loss = acc
                tolerance = 0
                # save_checkpoint(epoch,model, optimizer, scheduler, save_cur=True)

            else:
                tolerance += 1
                # save_checkpoint(epoch,model, optimizer, scheduler, save_cur=True)
            print(tolerance)
            if acc == max(best):
                    # save_checkpoint(epoch,model, optimizer, scheduler, save_cur=True)
                    t_best, test_acc= test( model, device, test_loader, logger, tbest, append = True)
                    vals = [acc, test_acc]


            else:
                    t_best, test_acc= test( model, device, test_loader, logger, tbest, append = False)
            if epoch <= 8:
                tolerance = 0
                
            if tolerance <= 8:
                # break
                save_checkpoint(epoch,model, optimizer, scheduler, save_cur=True)
            else:
                break
            scheduler.step()
            # if epoch == 15:
            #     break
    val_test.append(vals)
    test_bests.append(tbest)


In [None]:
import pickle
import torch
def sparse_diag_identity(n):
    """Return the n x n identity as a float32 sparse COO tensor.

    The size is passed explicitly, so ``n == 0`` yields an empty 0 x 0 tensor
    instead of failing on empty float-typed indices.
    """
    idx = torch.arange(n)
    return torch.sparse_coo_tensor(torch.stack([idx, idx]), torch.ones(n), (n, n))
def sparse_diag(tensor):
    """Return a sparse COO matrix whose diagonal entry ``[k, k]`` is ``tensor[k]``.

    Any trailing dimensions of ``tensor`` become dense (hybrid) dimensions of
    the result. The size is passed explicitly so an empty ``tensor`` works, and
    indices are created on ``tensor``'s device.
    """
    n = tensor.shape[0]
    idx = torch.arange(n, device=tensor.device)
    size = (n, n) + tuple(tensor.shape[1:])
    return torch.sparse_coo_tensor(torch.stack([idx, idx]), tensor, size)

# Build normalized vertex/edge/face operators from pre-computed matrices for
# mesh prefix 'l' at level 6 (presumably the left hemisphere).
# NOTE(review): the DownSamp layers read their own operators from
# .../Norm_M/*.pkl inside forward(), so the results of this cell may not be
# used by the model at all — confirm before relying on them.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('/home/justin/Mesh/Matrices_out/{}_lvl_{}.pkl'.format('l',6), "rb") as fh:
    l = pickle.load(fh)

# Only L0 is consumed below; L1 and L2 are recomputed from B1/B2 further down.
L1 = torch.abs(l['L1d'] + l['L1u'])
L2 = torch.abs(l['L2d'])
L0 = torch.abs(l['L0u'])
B1 = l["B1"]
B1t = l["B1t"]
B2 = l["B2"]
B2t = l["B2t"]

# These equal 10*4**6+2, 30*4**6 and 20*4**6 — presumably the vertex, edge
# and face counts of a level-6 icosphere.
x0 = 40962
x1 = 122880
x2 = 81920

# d0 = diag(row sums of |B1|) with B1 viewed as an (x0, x1) matrix;
# d0_inv is its inverse with 1/0 mapped to 0.
B1_v_abs, B1_i = torch.abs(B1.coalesce().values()), B1.coalesce().indices()
B1_sum = torch.sparse.sum(torch.sparse_coo_tensor(B1_i, B1_v_abs, (x0, x1)), dim=1)
B1_sum_values = B1_sum.to_dense()
B1_sum_indices = torch.tensor([i for i in range(x0)])
d0_diag_indices = torch.stack([B1_sum_indices, B1_sum_indices], dim=0)
d0 = torch.sparse_coo_tensor(d0_diag_indices, B1_sum_values, (x0, x0))
B1_sum_inv_values = torch.nan_to_num(1. / B1_sum_values, nan=0., posinf=0., neginf=0.)
d0_inv = torch.sparse_coo_tensor(d0_diag_indices, B1_sum_inv_values, (x0, x0))
# Vertex operator: L0 <- I + (L0 @ d0_inv) @ diag(-1 / (diag(d0_inv) + 1)).
L0 = torch.sparse.mm(L0, d0_inv)
L0_factor_values = -1 / (B1_sum_inv_values + 1)
L0_factor = torch.sparse_coo_tensor(d0_diag_indices, L0_factor_values, (x0, x0))
L0_bias_values = torch.ones(d0.shape[0])
L0bias = torch.sparse_coo_tensor(d0_diag_indices, L0_bias_values, (x0, x0))
L0 = L0bias + torch.sparse.mm(L0, L0_factor)

# Edge operator:
#   D1_inv = 0.5 * d0_inv, D2 = diag(max(row sums of |B2|, 1)), D3 = I / 3,
#   L1 = (I + A_1u) diag(1 / (diag(D2) + 1)) + (D2 + I)(A_1d + I).
D1_inv = torch.sparse_coo_tensor(d0_diag_indices, 0.5 * B1_sum_inv_values, (x0, x0))
B2_v_abs, B2_i = torch.abs(B2.coalesce().values()), B2.coalesce().indices()
# NOTE(review): B2 is rebuilt here with size (x1, x1) although its column count
# is x2 (see D3); this only works because x2 < x1 and only row sums are used.
D2diag_1 = torch.sparse.sum(torch.sparse_coo_tensor(B2_i, B2_v_abs, (x1, x1)), dim=1).to_dense()
D2diag = torch.maximum(D2diag_1, torch.ones(D2diag_1.shape[0]))
D2_indices = [i for i in range(D2diag.shape[0])]
D2_indices = torch.tensor([D2_indices, D2_indices])
D2 = torch.sparse_coo_tensor(D2_indices, D2diag, (x1, x1))
D2_inv = torch.sparse_coo_tensor(D2_indices, 1 / D2diag, (x1, x1))
D3_values = (1 / 3.) * torch.ones(B2.shape[1])
D3_indices = [i for i in range(B2.shape[1])]
D3_indices = torch.tensor([D3_indices, D3_indices])
D3 = torch.sparse_coo_tensor(D3_indices, D3_values, (x2, x2))
A_1u = D2 - torch.sparse.mm(torch.sparse.mm(B2, D3), B2.t())
A_1d = D2_inv - torch.sparse.mm(torch.sparse.mm(B1.t(), D1_inv), B1)
A_1u_norm = torch.sparse.mm((sparse_diag_identity(A_1u.shape[0]) + A_1u),
                                    (torch.sparse_coo_tensor(D2_indices, 1 / (D2diag + 1), (x1, x1))))
A_1d_norm = torch.sparse.mm((D2 + sparse_diag_identity(D2.shape[0])),
                                    (A_1d + sparse_diag_identity(A_1d.shape[0])))
L1 = A_1u_norm + A_1d_norm

# Face operator: L2 = 2 * ((I + B2^T diag(1 / (rowsum|B2| + 1)) B2) + I).
B2_sum = D2diag_1
B2_sum_inv = 1 / (B2_sum + 1)
D5inv = sparse_diag(B2_sum_inv)
A_2d = sparse_diag_identity(B2.shape[1]) + torch.sparse.mm(torch.sparse.mm(B2.t(), D5inv), B2)
A_2d_norm = torch.sparse.mm((2 * sparse_diag_identity(B2.shape[1])),
                                    (A_2d + sparse_diag_identity(A_2d.shape[0])))
L2 = A_2d_norm

# Cross-level transfer operators; each name spells out its matrix product.
B2D3 = torch.sparse.mm(B2, D3)
D2B1TD1inv = (1 / np.sqrt(2.)) * torch.sparse.mm(torch.sparse.mm(D2, B1.t()), D1_inv)
D1invB1 = (1 / np.sqrt(2.)) * torch.sparse.mm(D1_inv, B1)
B2TD2inv = torch.sparse.mm(B2.t(), D5inv)

In [None]:
 

import torch
import torch.nn as nn
import torch.nn.functional as F
import sys; 
import os
import pickle
import torch
def sparse_diag_identity(n):
    """Return the ``n x n`` identity matrix as a sparse COO tensor.

    The size is passed explicitly, so ``n == 0`` yields an empty ``(0, 0)``
    tensor instead of failing on float-typed empty index lists.
    """
    idx = torch.arange(n)
    return torch.sparse_coo_tensor(torch.stack([idx, idx]), torch.ones(n), (n, n))
def sparse_diag(tensor):
    """Return a sparse COO matrix whose diagonal entry ``[k, k]`` is ``tensor[k]``.

    Any trailing dimensions of ``tensor`` become dense (hybrid) dimensions of
    the result. The size is passed explicitly so an empty ``tensor`` works, and
    indices are created on ``tensor``'s device.
    """
    n = tensor.shape[0]
    idx = torch.arange(n, device=tensor.device)
    size = (n, n) + tuple(tensor.shape[1:])
    return torch.sparse_coo_tensor(torch.stack([idx, idx]), tensor, size)
class MaxPool(nn.Module):
    """Mesh pooling that blends a patch-wise max with the vertex's own value.

    Output vertex ``k`` is ``0.5 * max(x[..., patches[k]]) + 0.5 * x[..., k]``
    for ``k < nv_prev``. This assumes the patch array has ``nv_prev`` rows — TODO confirm.

    Args:
        level: mesh level; selects ``l_<level>_patches.npy``. For ``level <= 0``
            no patch file is loaded and ``forward`` raises ``RuntimeError``.
        nv_prev: number of vertices kept after pooling.
        patches_dir: directory holding the patch files. The default keeps the
            original ``./Matrices_out`` location.
    """

    def __init__(self, level, nv_prev, patches_dir="./Matrices_out"):
        super().__init__()
        self.level = level
        self.nv_prev = nv_prev
        # Attribute name (including its spelling) kept for compatibility.
        self.neihboring_patches = None

        if self.level > 0:
            patches_file = os.path.join(patches_dir, "l_%d_patches.npy" % level)
            # Presumably an integer index array of shape (nv_prev, patch_size) — confirm.
            self.neihboring_patches = np.load(patches_file)

    def forward(self, x):
        """Pool ``x`` (indexed as ``x[:, :, vertices]``, i.e. channels on dim 1) onto ``nv_prev`` vertices."""
        if self.neihboring_patches is None:
            raise RuntimeError(
                "MaxPool(level=%d) has no neighbourhood patches to pool over" % self.level)
        patch_max = torch.max(x[:, :, self.neihboring_patches], -1)[0]
        return 0.5 * patch_max + x[..., :self.nv_prev] * 0.5


class DownSamp(nn.Module):
    """One message-passing step over the vertices, edges and faces of a mesh.

    Vertex features are lifted onto edges and faces. Each level is then updated
    from its neighbours through pre-computed sparse operators, and the updated
    edge features are pushed back to the vertices. With ``down=True`` the
    channel count doubles and the result is pooled by ``MaxPool`` onto the
    vertex count of level ``lvl``. With ``e2n=False`` the layer only truncates
    ``x`` to ``nv_prev`` vertices.

    Args:
        nv_prev: number of vertices kept by the pass-through (``e2n=False``) path.
        lvl: level of the pooled output; operators are read for level ``lvl + 1``.
        in_chan: number of input channels.
        hami: file prefix selecting the operator set (callers pass ``'l'``/``'r'``).
        e2n: enable the simplicial message passing.
        down: double the channels and pool the result.
        matrix_dir: directory holding ``<hami>_lvl_<lvl + 1>.pkl``. The default
            keeps the original location.
    """

    def __init__(self, nv_prev, lvl, in_chan, hami, e2n=False, down=True,
                 matrix_dir='/home/justin/Mesh/Mesh/Norm_M'):
        super().__init__()
        self.nv_prev = nv_prev
        self.lvl = lvl + 1
        self.e2n = e2n
        self.down = down
        if e2n:
            # Downsampling layers double the channel width.
            out_chan = 2 * in_chan if down else in_chan

            # face-level
            self.e2f = nn.Linear(in_chan, out_chan, bias=True)
            self.ff = nn.Linear(in_chan, out_chan, bias=True)
            self.f = nn.Linear(in_chan, out_chan, bias=True)

            # edge-level (f2e consumes the already-updated face features)
            self.ee1 = nn.Linear(in_chan, out_chan, bias=True)
            self.ee2 = nn.Linear(in_chan, out_chan, bias=True)
            self.f2e = nn.Linear(out_chan, out_chan, bias=True)
            self.n2e = nn.Linear(in_chan, out_chan, bias=True)
            self.e = nn.Linear(in_chan, out_chan, bias=True)

            # node-level (e22n consumes the already-updated edge features)
            self.nn = nn.Linear(in_chan, out_chan, bias=True)
            self.e22n = nn.Linear(out_chan, out_chan, bias=True)
            self.n = nn.Linear(in_chan, out_chan, bias=True)

            self.affine_a = nn.Parameter(torch.ones([1, out_chan, 1]))
            self.affine_b = nn.Parameter(torch.zeros([1, out_chan, 1]))

            self.f_bn1 = nn.BatchNorm1d(out_chan)
            self.e_bn1 = nn.BatchNorm1d(out_chan)
            self.n_bn1 = nn.BatchNorm1d(out_chan)

            if down:
                # 30*4**lvl - 20*4**lvl + 2 (= E - F + 2 by Euler's formula) is
                # presumably the icosphere vertex count at level `lvl`.
                self.pool = MaxPool(lvl, (30 * (4 ** lvl) - 20 * (4**lvl) + 2))

        self.hami = hami
        self.ch = in_chan
        self.matrix_dir = matrix_dir
        # Operators are read from disk once and cached here. They are plain
        # attributes, so they do not enter the state_dict.
        self._operators = None
        self._operators_device = None

    def _load_operators(self, device):
        """Return the cached index arrays and |operators| (float64) on ``device``."""
        if self._operators is not None and self._operators_device == device:
            return self._operators

        path = os.path.join(self.matrix_dir, '{}_lvl_{}.pkl'.format(self.hami, self.lvl))
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as fh:
            mats = pickle.load(fh)

        def op(key):
            # Absolute values, moved once to the target device, as float64 for spmatmul.
            return torch.abs(mats[key]).to(device).double()

        self._operators = {
            'E': torch.as_tensor(mats['E'], dtype=torch.long, device=device),
            'F': torch.as_tensor(mats['F'], dtype=torch.long, device=device),
            'L0': op('L0un'),
            'L1d': op('L1dn'),
            'L1u': op('L1un'),
            'L2': op('L2dn'),
            'B2D3': op('B2D3'),
            'D2B1TD1inv': op('D2B1TD1inv'),
            'D1invB1': op('D1invB1'),
            'B2TD2inv': op('B2TD2inv'),
        }
        self._operators_device = device
        return self._operators

    def forward(self, x):
        """Update the vertex features ``x`` (indexed as ``x[:, :, vertices]``; assumed (batch, channels, vertices))."""
        if not self.e2n:
            return x[..., :self.nv_prev]

        ops = self._load_operators(x.device)
        E, Fa = ops['E'], ops['F']

        def lin(layer, feats):
            # nn.Linear acts on the last dim, so move channels last and back.
            return layer(feats.permute(0, 2, 1)).permute(0, 2, 1)

        # Lift vertex features onto edges (mean of both endpoints) and faces.
        edge_features = 0.5 * x[:, :, E[:, 0]] + 0.5 * x[:, :, E[:, 1]]
        # NOTE(review): faces weight each corner by 0.3 (sum 0.9), not 1/3;
        # kept as-is so trained weights stay valid — confirm intent.
        face_features = (0.3 * x[:, :, Fa[:, 0]]) + (0.3 * x[:, :, Fa[:, 1]]) + (0.3 * x[:, :, Fa[:, 2]])

        # face-level
        FF = spmatmul(lin(self.ff, face_features), ops['L2'])
        E2F = spmatmul(lin(self.e2f, edge_features), ops['B2TD2inv'])
        face_features = (1/2.) * F.leaky_relu(self.f_bn1((FF + E2F + lin(self.f, face_features)).float()))

        # edge-level (self.e sees the edge features from before this update)
        EE1 = spmatmul(lin(self.ee1, edge_features), ops['L1d'])
        EE2 = spmatmul(lin(self.ee2, edge_features), ops['L1u'])
        F2E = spmatmul(lin(self.f2e, face_features), ops['B2D3'])
        N2E = spmatmul(lin(self.n2e, x), ops['D2B1TD1inv'])
        edge_features = (1/3.) * F.leaky_relu(self.e_bn1((EE1 + EE2 + F2E + N2E + lin(self.e, edge_features)).float()))

        # node-level
        NN = spmatmul(lin(self.nn, x), ops['L0'])
        E2N = spmatmul(lin(self.e22n, edge_features), ops['D1invB1'])
        final_node = (1/2.) * F.leaky_relu(
            self.n_bn1(self.affine_a * (NN + E2N + lin(self.n, x)).float() + self.affine_b))

        if self.down:
            return self.pool(final_node)
        return final_node.float()

class ResBlock(nn.Module):
    """Coarsening block that processes the two channel halves independently.

    The first half of the channels runs through three ``'l'`` DownSamp layers
    and the second half through three ``'r'`` layers. With ``e2n=True`` the last
    layer of each chain doubles the width and pools. The two halves are then
    concatenated again on the channel axis.

    Only ``coarsen=True`` is supported by ``forward``.
    """

    def __init__(self, in_chan, neck_chan, out_chan, level, coarsen, mesh_folder, e2n = False):
        super().__init__()

        # Channel counts are given for both halves together; each half gets half.
        in_chan = int(in_chan//2)
        neck_chan = int(neck_chan//2)
        out_chan = int(out_chan//2)

        l = level-1 if coarsen else level
        self.coarsen = coarsen
        mesh_file = os.path.join(mesh_folder, "left_icosphere_{}.pkl".format(l))
        mesh_file2 = os.path.join(mesh_folder, "right_icosphere_{}.pkl".format(l))
        # NOTE(review): conv2 is never called in forward(); it is used for
        # `l_nv_prev` and keeps existing checkpoints' state_dict keys valid.
        self.conv2 = MeshConv(neck_chan, neck_chan, mesh_file,mesh_file2, stride=1)
        self.relu = nn.ReLU(inplace=True)

        self.nv_prev = self.conv2.l_nv_prev
        self.down1 = DownSamp(self.nv_prev,l,in_chan,'l',e2n, down = False)
        self.down11 = DownSamp(self.nv_prev,l,in_chan,'l',e2n, down = False)
        self.down111 = DownSamp(self.nv_prev,l,in_chan,'l',e2n)
        self.down3 = DownSamp(self.nv_prev,l,in_chan,'r',e2n, down = False)
        self.down33 = DownSamp(self.nv_prev,l,in_chan,'r',e2n, down = False)
        self.down333 = DownSamp(self.nv_prev,l,in_chan,'r',e2n)

        self.diff_chan = (in_chan != out_chan)

        if coarsen:
            self.seq1 = nn.Sequential(self.down1, self.down11, self.down111)
            self.r_seq1 = nn.Sequential(self.down3, self.down33, self.down333)

    def forward(self, x):
        """Run the two channel halves through their chains and re-concatenate them."""
        if not self.coarsen:
            # The chains are only built when coarsen=True.
            raise RuntimeError("ResBlock.forward requires coarsen=True")
        half = x.shape[1] // 2
        left = self.seq1(x[:, :half, :])
        right = self.r_seq1(x[:, half:, :])
        return torch.cat([left, right], dim=1)
class ConvBNReLU1D(nn.Module):
    """Thin wrapper around a single ``nn.Conv1d`` (kernel size 1 by default).

    Despite the name, no batch norm and no activation are applied, and the
    ``activation`` argument is accepted but ignored. The convolution is held in
    ``self.net`` (an ``nn.Sequential``) so state_dict keys are ``net.0.*``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
        super(ConvBNReLU1D, self).__init__()
        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
        )
        layers = [conv]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        output = self.net(x)
        return output
class Model(nn.Module):
    """Two-half mesh classifier: input conv, pooling, four ResBlocks, MLP head.

    Args:
        mesh_folder: folder containing ``left_icosphere_<l>.pkl`` and
            ``right_icosphere_<l>.pkl`` templates.
        feat: base feature width.
        nclasses: number of outputs; each one passes through a sigmoid.
    """

    def __init__(self, mesh_folder, feat=16, nclasses=10):
        super().__init__()
        mf_l = os.path.join(mesh_folder, "left_icosphere_6.pkl")
        mf_r = os.path.join(mesh_folder, "right_icosphere_6.pkl")

        self.in_conv = MeshConv(8, 2*feat, l_mesh_file=mf_l,r_mesh_file=mf_r, stride=1, agg = False)
        self.in_bn = nn.BatchNorm1d(2*feat)
        self.relu = nn.LeakyReLU(inplace=True)
        self.in_block = nn.Sequential(self.in_conv, self.in_bn, self.relu)

        self.maxp = MaxPool(5, 10242)
        self.block2 = ResBlock(in_chan=2*feat, neck_chan= 4*feat, out_chan=4*feat, level=5, coarsen=True, mesh_folder=mesh_folder, e2n = True)
        self.block3 = ResBlock(in_chan=4*feat, neck_chan= 8*feat, out_chan=8*feat, level=4, coarsen=True, mesh_folder=mesh_folder, e2n = True)
        self.block4 = ResBlock(in_chan=8*feat, neck_chan=16*feat, out_chan=16*feat, level=3, coarsen=True, mesh_folder=mesh_folder, e2n = True)
        self.block5 = ResBlock(in_chan=16*feat, neck_chan=32*feat, out_chan=32*feat, level=2, coarsen=True, mesh_folder=mesh_folder, e2n = True)

        # Average over all remaining vertices: output shape batch x channels x 1.
        self.avg = nn.AvgPool1d(kernel_size=self.block5.nv_prev)
        L_relu = nn.LeakyReLU()
        # Use nn.Softmax(dim=...) instead of the final sigmoid for mutually exclusive classes.
        self.out_layer= nn.Sequential(nn.Linear(32*feat,16*feat),L_relu,nn.Linear(16*feat,8*feat),L_relu,nn.Linear(8*feat,nclasses))

    def forward(self, x):
        x = self.in_block(x)
        x = self.maxp(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)

        half = x.shape[1] // 2
        # Squeeze only the pooled length-1 axis: a bare squeeze() would also
        # drop the batch axis for a batch of one and break the cat below.
        x_l = self.avg(x[:, :half, :]).squeeze(-1)
        x_r = self.avg(x[:, half:, :]).squeeze(-1)
        x = torch.cat([x_l, x_r], dim=1)
        x = self.out_layer(x)

        return torch.sigmoid(x)


import torch
import random
from sklearn.model_selection import train_test_split, StratifiedKFold

random_seed = 55

torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)



#Load data set X and y
#X is the cortical surface thickness data [N, 40962, 2] where first channel is the left/right hemisphere, second channel is the other hemisphere.
#y is the label (AD/NC or AGE) [N,1] 

k = 5     # k-cross validation에서 k의 값이 5
test_bests = []
val_test= []
skf = StratifiedShuffleSplit(n_splits=k, test_size = 0.2, train_size = 0.8, random_state = 42)
for i, (trainval_idx, test_idx) in enumerate(skf.split(X, y)):
    # Split data into training and test sets for the current fold
    X_trainval, X_test = X[trainval_idx], X[test_idx]
    y_trainval, y_test = y[trainval_idx], y[test_idx]
    
    # Further split the training set into training and validation sets
    print(len(y_test), len(X))
    skf = StratifiedShuffleSplit(n_splits=1, test_size = 0.1, train_size = 0.9)
    for i, (train_idx, val_idx) in enumerate(skf.split(X_trainval, y_trainval)):
        X_train, X_val = X_trainval[train_idx], X_trainval[val_idx]
        y_train, y_val = y_trainval[train_idx], y_trainval[val_idx]


    from torch.utils.data import DataLoader
    train_list = []
    for i,j in enumerate(list(X_train)):
        train_list.append([j,list(y_train)[i]])

    val_list = []
    for i,j in enumerate(list(X_val)):
        val_list.append([j,list(y_val)[i]])


    test_list = []
    for i,j in enumerate(list(X_test)):
        test_list.append([j,list(y_test)[i]])

    print(len(train_list), len(val_list), len(test_list))
    train_loader = DataLoader(train_list, batch_size = 32, shuffle = True, drop_last = True)
    val_loader = DataLoader(val_list, batch_size = len(val_list))
    test_loader = DataLoader(test_list, batch_size = len(test_list))


    # train_loader
    train_features, train_labels = next(iter(train_loader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")

    val_features, val_labels = next(iter(val_loader))
    print(f"Feature batch shape: {val_features.size()}")
    print(f"Labels batch shape: {val_labels.size()}")

    test_features, test_labels = next(iter(test_loader))
    print(f"Feature batch shape: {test_features.size()}")
    print(f"Labels batch shape: {test_labels.size()}")

    
    from torch.utils.data import DataLoader
    import torch.optim as optim
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
    from torch.optim.lr_scheduler import CosineAnnealingLR

    %cd /home/justin/Mesh
    #Model parameters: template path, feature dimension, output class number
    model = Model(".",16, 2)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join('.', "absolute_aggregation_16.txt"))
    logger.addHandler(fh)
    test_acc = None
    logger.info("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
    best = []
    tbest = []
    best_loss = 100
    tolerance = 0

    for epoch in range(1, 40 + 1):
            logger.info("[Epoch {}]".format(epoch))
            train(model, device, train_loader, optimizer, epoch, logger)
            best, acc, val_loss= val( model, device, val_loader, logger,best)
            if val_loss <= best_loss:
                best_loss = val_loss
                tolerance = 0
            else:
                tolerance += 1
            print(tolerance)
            if acc == max(best):
                    t_best, test_acc= test( model, device, test_loader, logger, tbest, append = True)
                    vals = [acc, test_acc]
            else:
                    t_best, test_acc= test( model, device, test_loader, logger, tbest, append = False)
            if epoch <= 10:
                tolerance = 0
                
            if tolerance <= 8:
                # break
                save_checkpoint(epoch,model, optimizer, scheduler, save_cur=True)
            else:
                break
            scheduler.step()
    val_test.append(vals)
    test_bests.append(tbest)


# Ends at 35, stops at 9. Five results (presumably one per fold); the value after "->" is their mean:
# 91.4634, 89.0244, 93.9024, 89.0244, 92.6829 -> mean 91.2195


# 0.8378378378378378, 0.7837837837837838, 0.8648648648648649, 0.8648648648648649, 0.8918918918918919 -> mean 0.8486486486486486


# 0.9777777777777777, 0.9777777777777777, 1.0, 0.9111111111111111, 0.9555555555555556 -> mean 0.9644444444444444