In [1]:
from common.data_pipeline.MMCBNU_6000.dataset import DatasetLoader as mmcbnu
from common.data_pipeline.FV_USM.dataset import DatasetLoader as fvusm
from common.util.data_pipeline.dataset_chainer import DatasetChainer
from common.util.enums import EnvironmentType

environment = EnvironmentType.PYTORCH
datasets = DatasetChainer(
    datasets=[
        mmcbnu(included_portion=1, environment_type=environment),
        fvusm(included_portion=0, environment_type=environment),
    ]
)
train, test, validation = datasets.get_dataset(environment)

2023-09-05 12:06:40,441 - root - INFO - Preprocessing MMCBNU6000 dataset for train split.
100%|██████████| 4200/4200 [00:02<00:00, 1779.19it/s]
2023-09-05 12:06:42,927 - root - INFO - Preprocessing MMCBNU6000 dataset for test split.
100%|██████████| 1200/1200 [00:00<00:00, 1845.86it/s]
2023-09-05 12:06:43,613 - root - INFO - Preprocessing MMCBNU6000 dataset for validation split.
100%|██████████| 600/600 [00:00<00:00, 1810.45it/s]
2023-09-05 12:06:43,973 - root - INFO - Preprocessing FV_USM dataset for train split.
0it [00:00, ?it/s]
2023-09-05 12:06:43,977 - root - INFO - Preprocessing FV_USM dataset for test split.
0it [00:00, ?it/s]
2023-09-05 12:06:43,980 - root - INFO - Preprocessing FV_USM dataset for validation split.
0it [00:00, ?it/s]
2023-09-05 12:06:43,984 - root - INFO - Concatenating train set
2023-09-05 12:06:44,057 - root - INFO - Concatenating test set
2023-09-05 12:06:44,080 - root - INFO - Concatenating validation set


In [2]:
for t in train.dataset.data:
    print(t.shape)

(4200, 1, 60, 128)
(4200, 100, 1, 1)


In [3]:
from common.util.enums import DatasetSplitType


# datasets.get_files(DatasetSplitType.TRAIN)

In [4]:
import torch
import torch.nn as nn

import torch.nn as nn
from common.gcn_lib.dense.torch_vertex import DynConv2d


# gcn_lib is downloaded from https://github.com/lightaime/deep_gcns_torch
class GrapherModule(nn.Module):
    """Grapher module with graph conv and FC layers"""

    def __init__(self, in_channels, hidden_channels, k=9, dilation=1, drop_path=0.0):
        super(GrapherModule, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = nn.Sequential(
            DynConv2d(in_channels, hidden_channels, k, dilation, act=None),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
        )
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.reshape(B, C, -1, 1).contiguous()
        shortcut = x
        x = self.fc1(x)
        x = self.graph_conv(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x.reshape(B, C, H, W)


class FFNModule(nn.Module):
    """Feed-forward Network"""

    def __init__(self, in_channels, hidden_channels, drop_path=0.0):
        super(FFNModule, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0), nn.BatchNorm2d(hidden_channels), nn.GELU()
        )
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_channels, 100, 1, stride=1, padding=0),
            nn.BatchNorm2d(100),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x


class ViGBlock(nn.Module):
    """ViG block with Grapher and FFN modules"""

    def __init__(self, channels, k, dilation, drop_path=0.0):
        super(ViGBlock, self).__init__()
        self.grapher = GrapherModule(channels, channels * 2, k, dilation, drop_path)
        self.ffn = FFNModule(channels, channels * 4, drop_path)

    def forward(self, x):
        x = self.grapher(x)
        x = self.ffn(x)
        return x


# Instantiate the model
model = ViGBlock(1, 1, 1, 0)

# Print the model architecture
print(model)

ViGBlock(
  (grapher): GrapherModule(
    (fc1): Sequential(
      (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (graph_conv): Sequential(
      (0): DynConv2d(
        (gconv): EdgeConv2d(
          (nn): BasicConv(
            (0): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (dilated_knn_graph): DenseDilatedKnnGraph(
          (_dilated): DenseDilated()
        )
      )
      (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (fc2): Sequential(
      (0): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (drop_path): Identity()
  )
  (ffn): FFNModule(
    (fc1): Sequential(
      (0): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(4, eps=1e-05, mo

In [5]:
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train:
        optimizer.zero_grad()
        outputs = model(inputs.float())
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

    # Validation loop (optional)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in validation:
            outputs = model(inputs.float())
            val_loss += criterion(outputs, labels.float())
            predicted = (outputs > 0.8).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(
        f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Accuracy: {(correct/total)*100:.2f}%"
    )
model.train()

ValueError: Using a target size (torch.Size([10, 100, 1, 1])) that is different to the input size (torch.Size([10, 100, 60, 128])) is deprecated. Please ensure they have the same size.