## Install Dependencies

In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
# NOTE(review): PyG is installed unpinned from git HEAD while the scatter/sparse wheels target
# torch 1.11 + cu113 — consider pinning a PyG release that matches that torch version.
%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
%pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# timm is imported and used by the model cells below but was never installed here.
%pip install -q timm
%pip install -q --upgrade --no-cache-dir gdown
# Recent gdown releases take the Drive file ID positionally; the old `--id` flag is deprecated
# and dropped in newer releases (which `--upgrade` above installs).
!gdown --no-cookies '1eG-iNJQhK6IXEKTAnAvCGy2Jyz126f5l'
# -o overwrites without prompting, so re-running the cell does not hang on an interactive prompt.
!unzip -q -o geomat.zip

## Set Up Dataset

In [None]:
import torch_geometric.transforms as T
from geomat import GeoMat
from torch_geometric.loader import DataLoader

# Assuming GeoMat follows PyG's Dataset convention: `pre_transform` is applied once when the raw
# data is processed, and `transform` is applied every time a sample is fetched.
pre_transform = T.NormalizeScale()
transform = T.FixedPoints(500)

# Positional arguments are presumably (root, train, transform, pre_transform) — confirm against GeoMat.
train_dataset, test_dataset = (
    GeoMat("geomat", is_train, transform, pre_transform) for is_train in (True, False)
)

# Both loaders share batch size and worker count; only the training split is shuffled.
loader_options = {"batch_size": 32, "num_workers": 6}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)


## Create Model

In [None]:

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool

import convnext

class Net(torch.nn.Module):
    """Point-cloud + image classifier.

    A two-layer DynamicEdgeConv (DGCNN-style) encoder embeds each point cloud, a timm
    ConvNeXt-Base backbone embeds the paired image (``data.image``), and an MLP head
    classifies the concatenation of both embeddings. ``forward`` returns per-graph
    log-probabilities of shape ``(num_graphs, out_channels)``.

    Args:
        out_channels: Number of output classes.
        k: Number of nearest neighbours used by each DynamicEdgeConv.
        aggr: Neighbourhood aggregation passed to DynamicEdgeConv.
        freeze_backbone: If True, the image backbone's weights are excluded from gradient
            updates and the backbone stays in eval mode even when ``model.train()`` is called.
            Defaults to False, which keeps the original behaviour: the backbone is trained.
    """

    def __init__(self, out_channels, k=20, aggr="max", freeze_backbone=False):
        super().__init__()
        # Per-point input is [pos | x] = 3 + 3 channels. EdgeConv's MLP receives two
        # concatenated per-point feature vectors per edge, hence the factor of 2.
        self.conv1 = DynamicEdgeConv(MLP([2 * (3 + 3), 64], act="LeakyReLU", act_kwargs={"negative_slope": 0.2}, dropout=0.8), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 64], act="LeakyReLU", act_kwargs={"negative_slope": 0.2}, dropout=0.8), k, aggr)
        # Per-point fusion of both EdgeConv outputs; pooled to one vector per graph in forward().
        self.fc1 = MLP([64 + 64, 1024], act="LeakyReLU", act_kwargs={"negative_slope": 0.2}, dropout=0.8)
        # Head input: 1024 pooled point features + 2304 flattened image features.
        # NOTE(review): 2304 = 64 channels * 36 positions, i.e. this assumes a 6x6 map out of
        # get_features_concat for the dataset's image size — confirm.
        self.fc2 = MLP([1024 + 2304, 512, 256, out_channels], dropout=0.8)

        self.freeze_backbone = freeze_backbone
        # Device placement is left to the caller's `.to(device)` instead of a hard-coded
        # `.cuda()`, so the CPU fallback chosen by the caller actually works.
        # NOTE(review): `pretrained=True` is not passed, so unless the local `convnext` module
        # changes this, the backbone starts from random weights — confirm.
        self.img_model = timm.create_model("convnext_base", num_classes=2, drop_path_rate=0.8)
        if freeze_backbone:
            # eval() alone does not freeze anything: the weights would still be returned by
            # model.parameters() and updated by the optimizer.
            self.img_model.requires_grad_(False)
        # With freeze_backbone=False this eval() is undone by the first model.train() call,
        # and the backbone is fine-tuned together with the rest of the network.
        self.img_model.eval()
        # 1x1 conv reducing the backbone feature channels from 1920 to 64.
        # NOTE(review): 1920 = 128+256+512+1024 (ConvNeXt-Base stage widths), suggesting
        # get_features_concat concatenates all stages — confirm.
        self.filter_conv = nn.Conv2d(1920, 64, 1)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode (drop-path disabled)."""
        super().train(mode)
        if self.freeze_backbone:
            self.img_model.eval()
        return self

    def forward(self, data):
        """Classify a batch of point clouds with their paired images.

        Args:
            data: A PyG batch exposing ``pos``, ``x``, ``batch`` and ``image``.

        Returns:
            Tensor of log-probabilities, one row per graph in the batch.
        """
        # Move inputs to wherever this module's parameters live rather than assuming CUDA.
        device = self.filter_conv.weight.device
        pos, x, batch = data.pos.to(device), data.x.to(device), data.batch.to(device)

        # Presumably images are stored channels-last (B, H, W, C); permute to (B, C, H, W).
        # NOTE(review): get_features_concat is not a stock timm ConvNeXt method; presumably it is
        # provided by the local `convnext` module — confirm.
        image = data.image.to(device).permute(0, 3, 1, 2).float()
        features = self.filter_conv(self.img_model.get_features_concat(image))

        x1 = self.conv1(torch.cat((pos, x), dim=1).float(), batch)
        x2 = self.conv2(x1, batch)
        out = self.fc1(torch.cat((x1, x2), dim=1))
        out = global_max_pool(out, batch)
        out = self.fc2(torch.cat((out, features.flatten(start_dim=1)), dim=1))
        return F.log_softmax(out, dim=1)




## Training and test loop

In [None]:
import numpy as np
import sklearn.metrics as metrics
import timm
from tqdm import tqdm
from timm.loss import LabelSmoothingCrossEntropy


def train():
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Reads the notebook globals ``model``, ``optimizer``, ``criterion`` and ``train_loader``.

    Returns:
        tuple: (mean loss per graph, accuracy, balanced accuracy) over the epoch.
    """
    model.train()
    total_loss = 0.0
    targets, predictions = [], []
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks run.
        out = model(data)
        # Put labels on the device the model produced its output on, instead of assuming CUDA.
        loss = criterion(out, data.y.to(out.device))
        loss.backward()
        optimizer.step()

        # Assumes criterion returns a batch mean; weighting by graph count gives a per-graph epoch mean.
        total_loss += loss.item() * data.num_graphs
        targets.append(data.y.cpu().numpy())
        predictions.append(out.argmax(dim=1).detach().cpu().numpy())

    targets = np.concatenate(targets)
    predictions = np.concatenate(predictions)
    return (
        total_loss / len(train_loader.dataset),
        metrics.accuracy_score(targets, predictions),
        metrics.balanced_accuracy_score(targets, predictions),
    )


def test():
    """Return the classification accuracy of the global ``model`` on ``test_loader``.

    Reads the notebook globals ``model`` and ``test_loader``.
    """
    model.eval()
    correct = 0
    # No gradients are needed anywhere in evaluation.
    with torch.no_grad():
        for data in test_loader:
            pred = model(data).argmax(dim=1)
            # Compare on the prediction's device instead of assuming CUDA.
            correct += pred.eq(data.y.to(pred.device)).sum().item()

    return correct / len(test_loader.dataset)

## Initialize model and train

In [None]:
SEED = 0
NUM_EPOCHS = 10

# Seed torch's global RNG (weight init, DataLoader shuffling) and numpy so runs are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net(out_channels=19, k=20).to(device)
optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
# NOTE(review): step_size=20 exceeds NUM_EPOCHS, so the learning rate is never halved in this
# run — confirm this is intended (or lower step_size / raise NUM_EPOCHS).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

for epoch in range(NUM_EPOCHS):
    loss, train_acc, balanced_train_acc = train()
    test_acc = test()

    print(f"Epoch {epoch:03d}, Train Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Balanced Train Acc: {balanced_train_acc:.4f}, Test: {test_acc:.4f}")
    scheduler.step()