In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers.mesh_pool import MeshPool
from models.layers.mesh_conv import MeshConv

In [6]:
def get_norm_args(norm_layer, nfeats_list):
    """Build the per-layer keyword arguments for a normalization factory.

    Args:
        norm_layer: either a class named ``NoNorm`` or a factory for a
            normalization layer. Wrapped factories are expected to expose the
            underlying class as ``.func`` (as ``functools.partial`` objects
            do); an unwrapped class is accepted as well.
        nfeats_list: iterable with the feature/channel count of each layer.

    Returns:
        A list with one kwargs dict per entry of ``nfeats_list``:
        ``{'fake': True}`` for ``NoNorm``, ``{'num_channels': f}`` for
        ``GroupNorm`` and ``{'num_features': f}`` for any ``BatchNorm*`` class.

    Raises:
        NotImplementedError: if the normalization layer is not recognised.
    """
    # NoNorm is recognised by the name of the object itself.
    if getattr(norm_layer, '__name__', None) == 'NoNorm':
        return [{'fake': True} for _ in nfeats_list]

    # Unwrap `.func` when present; otherwise inspect the object directly so an
    # unknown layer ends in NotImplementedError instead of AttributeError.
    target = getattr(norm_layer, 'func', norm_layer)
    name = getattr(target, '__name__', type(target).__name__)

    if name == 'GroupNorm':
        return [{'num_channels': f} for f in nfeats_list]
    # The concrete torch classes are BatchNorm1d/2d/3d, so match the prefix
    # rather than the bare string 'BatchNorm' (which never matched them).
    if name.startswith('BatchNorm'):
        return [{'num_features': f} for f in nfeats_list]
    raise NotImplementedError('normalization layer [%s] is not found' % name)

In [7]:
class MeshConvNet(nn.Module):
    """Network for learning a global shape descriptor (classification).

    The network is a stack of ``MResConv -> norm -> ReLU -> MeshPool`` stages,
    one per entry of ``conv_res``, followed by a 1-D average pool and two
    fully connected layers.

    Args:
        norm_layer: normalization factory; it is called with the keyword
            arguments produced by ``get_norm_args``.
        nf0: number of input feature channels.
        conv_res: sequence with the output channel count of each stage.
        nclasses: size of the final output layer.
        input_res: stored as ``self.res[0]``; not otherwise used by this class.
        pool_res: sequence with the argument given to ``MeshPool`` for each
            stage; its last entry is also the ``AvgPool1d`` kernel size.
        fc_n: width of the hidden fully connected layer.
        nresblocks: forwarded to ``MResConv`` as its ``skips`` argument.

    Raises:
        ValueError: if ``pool_res`` has fewer entries than ``conv_res``.
    """
    def __init__(self, norm_layer, nf0, conv_res, nclasses, input_res, pool_res, fc_n, nresblocks=3):
        super(MeshConvNet, self).__init__()
        conv_res = list(conv_res)
        pool_res = list(pool_res)
        if len(pool_res) < len(conv_res):
            raise ValueError(
                'pool_res needs one entry per conv_res entry (got %d and %d)'
                % (len(pool_res), len(conv_res)))
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # One conv/norm/pool triple per stage, registered as conv{i}/norm{i}/
        # pool{i}. forward() iterates over the same range, so the number of
        # stages follows len(conv_res) instead of being fixed at three.
        for i, ki in enumerate(self.k[:-1]):
            setattr(self, 'conv{}'.format(i), MResConv(ki, self.k[i + 1], nresblocks))
            setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
            setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))

        self.gp = torch.nn.AvgPool1d(self.res[-1])
        self.fc1 = nn.Linear(self.k[-1], fc_n)
        self.fc2 = nn.Linear(fc_n, nclasses)

    def forward(self, x, mesh):
        """Run every stage on ``x``/``mesh`` and return the ``fc2`` output."""
        for i in range(len(self.k) - 1):
            x = getattr(self, 'conv{}'.format(i))(x, mesh)
            x = F.relu(getattr(self, 'norm{}'.format(i))(x))
            x = getattr(self, 'pool{}'.format(i))(x, mesh)

        x = self.gp(x)
        # Flatten to one row of k[-1] features per sample for the linear head.
        x = x.view(-1, self.k[-1])

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class MResConv(nn.Module):
    """Residual stack of mesh convolutions.

    ``conv0`` maps ``in_channels`` to ``out_channels``. It is followed by
    ``skips`` stages of ``ReLU -> BatchNorm2d -> MeshConv`` that keep the
    channel count. The output of ``conv0`` is added to the output of the last
    stage, and a final ReLU is applied.
    """

    def __init__(self, in_channels, out_channels, skips=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skips = skips
        self.conv0 = MeshConv(in_channels, out_channels, bias=False)
        # Stages are registered 1-based: bn1/conv1, bn2/conv2, ...
        for idx in range(1, skips + 1):
            setattr(self, 'bn%d' % idx, nn.BatchNorm2d(out_channels))
            setattr(self, 'conv%d' % idx, MeshConv(out_channels, out_channels, bias=False))

    def forward(self, x, mesh):
        """Apply the convolution stack with the residual connection."""
        out = self.conv0(x, mesh)
        residual = out
        for idx in range(1, self.skips + 1):
            norm = getattr(self, 'bn%d' % idx)
            conv = getattr(self, 'conv%d' % idx)
            out = conv(norm(F.relu(out)), mesh)
        # In-place add, as before; with skips == 0 `out` and `residual` are
        # the same tensor, so this doubles the conv0 output.
        out += residual
        return F.relu(out)

In [8]:
from torch.utils.data import DataLoader
from data.classification_data import ClassificationData  # Custom dataset
from data.base_dataset import collate_fn  # Custom collate function

class Options:
    """Plain container of run settings, standing in for a parsed options object.

    Every attribute is set to a fixed default in ``__init__``.

    NOTE: ``ClassifierModel`` (defined later in this notebook) reads further
    attributes that are not set here, e.g. ``is_train``, ``norm``, ``ncf``,
    ``pool_res``, ``fc_n``, ``lr`` and ``dataset_mode``. They must be
    provided before that class is instantiated with this object.
    """

    def __init__(self):
        # --- experiment identity and locations ---
        self.name = "human_class"
        self.dataroot = "datasets/human_class"  # root directory of the dataset
        self.checkpoints_dir = "checkpoints"
        self.export_folder = None  # where to export results, if anywhere
        self.phase = "train"  # "train" or "test"

        # --- data loading ---
        self.batch_size = 8
        self.num_threads = 4  # worker processes used by the DataLoader
        self.serial_batches = False  # False -> the DataLoader shuffles
        self.max_dataset_size = float('inf')  # upper bound on dataset size
        self.ninput_edges = 40000  # number of input edges per mesh
        self.num_aug = 20  # number of augmentations

        # --- hardware ---
        self.gpu_ids = [0]  # set to [] to run on the CPU

        # --- schedule ---
        self.epoch_count = 1  # epoch number to start counting from
        self.niter = 50  # iterations at the base learning rate
        self.niter_decay = 50  # iterations with learning-rate decay

        # --- logging, checkpointing, evaluation ---
        self.print_freq = 100  # how often log messages are printed
        self.save_latest_freq = 1000  # how often the latest model is saved
        self.save_epoch_freq = 10  # save a checkpoint every this many epochs
        self.run_test_freq = 5  # how often the test pass is run
        self.verbose_plot = False  # verbose plotting (e.g. weight plots)

# Build the configuration object shared by the dataset and the loader.
opt = Options()

# Echo the two settings most worth double-checking before a run.
print("Using GPU IDs:", opt.gpu_ids)
print("Batch size:", opt.batch_size)

# The dataset is constructed directly from `opt`.
dataset = ClassificationData(opt)

# Batches are shuffled unless `opt.serial_batches` is set; the project's
# `collate_fn` assembles each batch.
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    num_workers=opt.num_threads,
    shuffle=not opt.serial_batches,
    batch_size=opt.batch_size,
)

Using GPU IDs: [0]
Batch size: 8
loaded mean / std from cache


In [None]:
import torch
from models import networks
from os.path import join
from util.util import seg_accuracy, print_network

class ClassifierModel:
    """Training / evaluation wrapper around a ``MeshConvNet``.

    Builds the network and loss from an options object and, when training,
    the optimizer and learning-rate scheduler. It exposes the per-batch steps
    (``set_input`` followed by ``optimize_parameters`` or ``test``) and
    checkpoint save/load.

    The options object must provide ``gpu_ids``, ``is_train``,
    ``checkpoints_dir``, ``name``, ``nclasses``, ``norm``, ``num_groups``,
    ``input_nc``, ``ncf``, ``ninput_edges``, ``pool_res``, ``fc_n``,
    ``resblocks``, ``init_type``, ``init_gain`` and ``dataset_mode``. When
    training it must also provide ``lr``, ``beta1`` and ``continue_train``.
    ``which_epoch`` is needed whenever a checkpoint is loaded.
    """

    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.is_train = opt.is_train
        # First listed GPU when any are given, otherwise the CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = join(opt.checkpoints_dir, opt.name)
        self.optimizer = None
        self.scheduler = None  # only created when training
        self.edge_features = None
        self.labels = None
        self.mesh = None
        self.soft_label = None
        self.loss = None
        self.nclasses = opt.nclasses

        norm_layer = networks.get_norm_layer(norm_type=opt.norm, num_groups=opt.num_groups)
        net = MeshConvNet(norm_layer, opt.input_nc, opt.ncf, self.nclasses, opt.ninput_edges, opt.pool_res, opt.fc_n, opt.resblocks)
        self.net = networks.init_net(net, opt.init_type, opt.init_gain, self.gpu_ids)

        # train(False) is equivalent to eval().
        self.net.train(self.is_train)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

        if self.is_train:
            self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.scheduler = networks.get_scheduler(self.optimizer, opt)
            print_network(self.net)

        if not self.is_train or opt.continue_train:
            self.load_network(opt.which_epoch)

    def set_input(self, data):
        """Store one batch on the model.

        Reads ``data['edge_features']`` and ``data['label']`` (numpy arrays)
        and ``data['mesh']``. ``data['soft_label']`` is read as well in
        segmentation mode when not training.
        """
        input_edge_features = torch.from_numpy(data['edge_features']).float()
        labels = torch.from_numpy(data['label']).long()
        # set inputs
        self.edge_features = input_edge_features.to(self.device).requires_grad_(self.is_train)
        self.labels = labels.to(self.device)
        self.mesh = data['mesh']
        if self.opt.dataset_mode == 'segmentation' and not self.is_train:
            self.soft_label = torch.from_numpy(data['soft_label'])

    def forward(self):
        """Run the network on the stored batch and return its output."""
        out = self.net(self.edge_features, self.mesh)
        return out

    def backward(self, out):
        """Compute the cross-entropy loss against the stored labels and backpropagate."""
        self.loss = self.criterion(out, self.labels)
        self.loss.backward()

    def optimize_parameters(self):
        """One optimization step: zero grads, forward, backward, update."""
        self.optimizer.zero_grad()
        out = self.forward()
        self.backward(out)
        self.optimizer.step()

    def load_network(self, which_epoch):
        """load model from disk"""
        save_filename = '%s_net.pth' % which_epoch
        load_path = join(self.save_dir, save_filename)
        net = self.net
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        print('loading the model from %s' % load_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted location.
        state_dict = torch.load(load_path, map_location=str(self.device))
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata
        net.load_state_dict(state_dict)

    def save_network(self, which_epoch):
        """save model to disk"""
        import os  # local import: only os.path.join is imported at file level

        save_filename = '%s_net.pth' % (which_epoch)
        save_path = join(self.save_dir, save_filename)
        # torch.save does not create missing directories.
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
        # Unwrap DataParallel exactly as load_network does, so the saved keys
        # never carry a 'module.' prefix and a bare module saves too.
        net = self.net
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        torch.save(net.cpu().state_dict(), save_path)
        # Saving moved the weights to the CPU; move them back when on GPU.
        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            self.net.cuda(self.gpu_ids[0])

    def update_learning_rate(self):
        """update learning rate (called once every epoch)"""
        self.scheduler.step()
        lr = self.optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def test(self):
        """tests model
        returns: number correct and total number
        """
        with torch.no_grad():
            out = self.forward()
            # Index of the largest score along dim 1.
            pred_class = out.max(1)[1]
            label_class = self.labels
            self.export_segmentation(pred_class.cpu())
            correct = self.get_accuracy(pred_class, label_class)
        return correct, len(label_class)

    def get_accuracy(self, pred, labels):
        """computes accuracy for classification / segmentation

        Raises:
            NotImplementedError: if ``opt.dataset_mode`` is neither
                'classification' nor 'segmentation'.
        """
        mode = self.opt.dataset_mode
        if mode == 'classification':
            return pred.eq(labels).sum()
        if mode == 'segmentation':
            return seg_accuracy(pred, self.soft_label, self.mesh)
        # Previously fell through and failed on an unbound local.
        raise NotImplementedError('dataset mode [%s] is not supported' % mode)

    def export_segmentation(self, pred_seg):
        """In segmentation mode, pass each mesh its row of ``pred_seg``."""
        if self.opt.dataset_mode == 'segmentation':
            for meshi, mesh in enumerate(self.mesh):
                mesh.export_segments(pred_seg[meshi, :])
