In [1]:
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/spinal-bone-feature-detection
!unzip -q ./datasets/PACS.zip -d /tmp/PACS

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/MyDrive/spinal-bone-feature-detection
replace /tmp/PACS/photo/dog/056_0001.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.models.mobilenetv3 import MobileNet_V3_Small_Weights

from collections import OrderedDict
from tqdm.notebook import tqdm

# PACS Datasets

In [3]:
# means and standard deviations ImageNet because the network is pretrained
means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Define transforms to apply to each image
transf = transforms.Compose ([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

# Define datasets root
DIR_PHOTO = "/tmp/PACS/photo"
DIR_ART = "/tmp/PACS/art_painting"
DIR_CARTOON = "/tmp/PACS/cartoon"
DIR_SKETCH = "/tmp/PACS/sketch"

# Prepare Pytorch train / test Datasets
photo_dataset = torchvision.datasets.ImageFolder (DIR_PHOTO, transform=transf)
art_dataset = torchvision.datasets.ImageFolder (DIR_ART, transform=transf)
cartoon_dataset = torchvision.datasets.ImageFolder (DIR_CARTOON, transform=transf)
sketch_dataset = torchvision.datasets.ImageFolder (DIR_SKETCH, transform=transf)
all_datasets = torch.utils.data.ConcatDataset([photo_dataset, art_dataset, cartoon_dataset, sketch_dataset])

# GNN for Image Classification

In [5]:
class GNN(nn.Module):
    def __init__(self, in_features, edge_features, out_feature, device, ratio=(1,)):
        super(GNN, self).__init__()
        self.edge_net = EdgeNet(in_features=in_features, num_features=edge_features, device=device, ratio=ratio)
        self.node_net = NodeNet(in_features=in_features, num_features=out_feature, device=device, ratio=ratio) # Set edge to node
        self.mask_val = -1 # mask value for no-gradient edges

    def label2edge(self, targets): # Convert node labels to affinity mask for backprop
        num_sample = targets.size()[1]
        label_i = targets.unsqueeze(-1).repeat(1, 1, num_sample)
        label_j = label_i.transpose(1, 2)
        edge = torch.eq(label_i, label_j).float()
        target_edge_mask = (torch.eq(label_i, self.mask_val) + torch.eq(label_j, self.mask_val)).type(torch.bool)
        source_edge_mask = ~target_edge_mask
        edge *= source_edge_mask.float()
        return edge[0], source_edge_mask

    def forward(self, init_node_feat):
        edge_feat, edge_sim = self.edge_net(init_node_feat) # Compute normalized and not normalized affinity matrix
        logits_gnn = self.node_net(init_node_feat, edge_feat) # Get edge feature and class logits
        return logits_gnn, edge_sim

In [6]:
class EdgeNet(nn.Module):
    def __init__(self, in_features, num_features, device, ratio=(1,)):
        super(EdgeNet, self).__init__()
        num_features_list = [num_features * r for r in ratio]
        layer_list = OrderedDict()
        self.device = device

        for l in range(len(num_features_list)):
            layer_list['conv%d' % l] = nn.Conv2d(
                in_channels = num_features_list[l - 1] if l > 0 else in_features,
                out_channels = num_features_list[l],
                kernel_size = 1, bias = False
            )
            layer_list['norm%d' % l] = nn.BatchNorm2d(num_features=num_features_list[l])
            layer_list['relu%d' % l] = nn.LeakyReLU()

        # Add final similarity kernel
        layer_list['conv_out'] = nn.Conv2d(in_channels=num_features_list[-1], out_channels=1, kernel_size=1)
        self.sim_network = nn.Sequential(layer_list).to(device)


    def forward(self, node_feat):
        node_feat = node_feat.unsqueeze(dim=0) # (1, bs, dim)
        num_tasks = node_feat.size(0) # 1
        num_data = node_feat.size(1) # bs

        x_i = node_feat.unsqueeze (2) # (1, bs, 1, dim)
        x_j = torch.transpose (x_i, 1, 2) # (1, 1, bs, dim)
        x_ij = torch.abs (x_i - x_j) # (1, bs, bs, dim)
        x_ij = torch.transpose (x_ij, 1, 3) # (1, dim, bs, bs)

        # Compute similarity / dissimilarity (batch_size x feat_size x num_samples x num_samples)
        sim_val = (torch.sigmoid(self.sim_network(x_ij)).squeeze(1).squeeze(0).to(self.device)) # (bs, bs)

        # Normalize affinity matrix
        force_edge_feat = (torch.eye(num_data).unsqueeze (0).repeat(num_tasks, 1, 1).to(self.device)) # (1, bs, bs)
        edge_feat = sim_val + force_edge_feat # (bs, bs)
        edge_feat = edge_feat + 1e-6 # Add small value to avoid nan
        edge_feat = edge_feat / torch.sum( edge_feat, dim =1).unsqueeze (1) # Normalize
        return edge_feat, sim_val # (bs, bs), (bs, bs)

In [7]:
class NodeNet(nn.Module):
    def __init__(self, in_features, num_features, device, ratio =(1,)):
        super(NodeNet, self).__init__()
        num_features_list = [num_features * r for r in ratio]
        layer_list = OrderedDict()
        self.device = device

        for l in range(len(num_features_list)):
            layer_list['conv%d' % l] = nn.Conv2d(
                in_channels = num_features_list[l - 1] if l > 0 else in_features * 2,
                out_channels = num_features_list[l],
                kernel_size =1,
                bias =False,
            )
            layer_list['norm%d' % l] = nn.BatchNorm2d(num_features=num_features_list[l])

            if l < len(num_features_list) - 1: layer_list['relu%d' % l] = nn.LeakyReLU()
        self.network = nn.Sequential(layer_list).to(device)


    def forward(self, node_feat, edge_feat):
        # node_feat: (bs, dim), edge_feat: (bs, bs)
        node_feat = node_feat.unsqueeze(dim=0) # (1, bs, dim)
        num_tasks = node_feat.size(0) # 1
        num_data = node_feat.size(1) # bs

        # Get eye matrix(batch_size x node_size x node_size) only use inter dist.
        diag_mask = 1.0 - torch.eye(num_data).unsqueeze(0).repeat(num_tasks, 1, 1).to(self.device) # (1, bs, bs)

        # Set diagonal as zero and normalize
        edge_feat = F.normalize(edge_feat * diag_mask, p=1, dim=-1) # (bs, bs)

        # Compute attention and aggregate
        aggr_feat = torch.bmm(edge_feat.squeeze(1), node_feat) # (bs, dim)
        node_feat = torch.cat([node_feat, aggr_feat], -1).transpose(1, 2) # (1, 2 * dim, bs)

        # Non-linear transform
        node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2) # (1, bs, dim)
        return node_feat.squeeze(-1).squeeze(0) # (bs, dim)

In [None]:
class Model(nn.Module):
    def __init__(self, num_classes=7):
        super(Model, self).__init__()
        self.backbone = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        self.backbone.classifier = nn.Sequential()
        self.gnn = GNN(in_features=576, edge_features=576, out_feature=num_classes, device='cuda', ratio=(1,))

    def forward(self, x):
        x = self.backbone(x)
        x, edge_sim = self.gnn(x)
        return x, edge_sim

# Training Utilities

In [None]:
def train_model(model, train_loader, val_loader, optimizer, num_epochs, device, num_classes):
    epoch_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    correct, total, best_val_accuracy = 0, 0, 0
    criterion, criterion_edge = nn.CrossEntropyLoss(), nn.BCELoss()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        total_batches = len(train_loader)
        loop = tqdm(enumerate(train_loader), total=total_batches)

        for i, (inputs, labels) in loop:
            inputs, labels = inputs.to(device), labels.to(device) # Move inputs and labels to the device
            outputs, edge_sim = model(inputs) # Forward pass
            loss_cls = criterion(outputs, labels) # Cls loss

            # Edge loss
            edge_gt, edge_mask = model.gcn.label2edge(labels.unsqueeze(dim=0))
            loss_edge = criterion_edge(
                edge_sim.masked_select(edge_mask), edge_gt.masked_select(edge_mask)
            )
            loss = 0.3 * loss_cls + loss_edge # Total loss
            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum()
            total += labels.size(0)

            # Backward pass and optimization
            optimizer.zero_grad() # Zero the parameter gradients
            loss.backward()
            optimizer.step()

            # Update the progress bar
            loop.set_description(f"[EPOCH {epoch+1}/{num_epochs}] {i + 1}/{total_batches}")
            loop.set_postfix(loss=loss.item(), accuracy=(correct / total).item())

        epoch_loss = running_loss / total_batches
        epoch_losses.append(epoch_loss)
        train_accuracy = (correct / total).item()
        train_accuracies.append(train_accuracy)

        # Validation
        val_loss, val_accuracy = evaluate_model(model, val_loader, device, num_classes)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        print(f"=> Loss: {epoch_loss:.4f} - Accuracy: {train_accuracy:.4f} - Val Loss: {val_loss:.4f} - Val Accuracy: {val_accuracy:.4f}")

        # Save the best model based on validation accuracy
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), "best_model.pth")
    return epoch_losses, val_losses, train_accuracies, val_accuracies


def evaluate_model(model, val_loader, device, num_classes):
    model.eval()
    correct, total, running_loss = 0, 0, 0
    criterion, criterion_edge = nn.CrossEntropyLoss(), nn.BCELoss()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, edge_sim = model(inputs)
            loss_cls = criterion(outputs, labels)
            edge_gt, edge_mask = model.gcn.label2edge(labels.unsqueeze(dim=0))
            loss_edge = criterion_edge(
                edge_sim.masked_select(edge_mask), edge_gt.masked_select(edge_mask)
            )
            loss = 0.3 * loss_cls + loss_edge
            running_loss += loss.item()

            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum()
            total += labels.size(0)
    return running_loss / len(val_loader), (correct / total).item() # val_loss, val_accuracy

# Experiments

In [None]:
train_size = int(0.8 * len(all_datasets))
val_size = len(all_datasets) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(all_datasets, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=4)

In [43]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)



In [44]:
%%time
train_model(
    model, train_loader, val_loader,
    optimizer, num_epochs=10, device=device,
    num_classes=len(all_datasets.datasets[0].classes)
)

  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.6753 - Accuracy: 0.7150 - Val Loss: 0.6787 - Val Accuracy: 0.6713


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.3563 - Accuracy: 0.8057 - Val Loss: 0.5559 - Val Accuracy: 0.7619


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.2402 - Accuracy: 0.8524 - Val Loss: 0.4688 - Val Accuracy: 0.7959


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.1625 - Accuracy: 0.8837 - Val Loss: 0.4394 - Val Accuracy: 0.8569


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.1160 - Accuracy: 0.9054 - Val Loss: 0.5300 - Val Accuracy: 0.8359


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.0960 - Accuracy: 0.9206 - Val Loss: 0.5595 - Val Accuracy: 0.8589


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.0897 - Accuracy: 0.9315 - Val Loss: 0.5037 - Val Accuracy: 0.8499


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.0882 - Accuracy: 0.9397 - Val Loss: 0.4795 - Val Accuracy: 0.8719


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.0779 - Accuracy: 0.9462 - Val Loss: 0.3876 - Val Accuracy: 0.8959


  0%|          | 0/15 [00:00<?, ?it/s]

=> Loss: 0.0715 - Accuracy: 0.9515 - Val Loss: 0.4185 - Val Accuracy: 0.8914
CPU times: user 2min 5s, sys: 21.6 s, total: 2min 27s
Wall time: 5min 8s


([0.6752504547437032,
  0.35634762843449913,
  0.24016335109869638,
  0.1625174840291341,
  0.11602991024653117,
  0.09602185040712356,
  0.08967399150133133,
  0.08816636949777604,
  0.07785323709249496,
  0.07151216119527817],
 [0.6787283271551132,
  0.5558868050575256,
  0.46878352761268616,
  0.4394068643450737,
  0.5299849808216095,
  0.5595251172780991,
  0.5036904066801071,
  0.47954872250556946,
  0.3876422420144081,
  0.41850610822439194],
 [0.7149739861488342,
  0.8057292103767395,
  0.8523871898651123,
  0.8836914300918579,
  0.9053646326065063,
  0.9205729365348816,
  0.9314732551574707,
  0.9396647810935974,
  0.9462239146232605,
  0.9515364766120911],
 [0.6713356971740723,
  0.7618809938430786,
  0.7958979606628418,
  0.8569284677505493,
  0.8359180092811584,
  0.8589295148849487,
  0.8499249815940857,
  0.871936023235321,
  0.8959479928016663,
  0.8914457559585571])

In [38]:
import gc
gc.collect()
torch.cuda.empty_cache()