In [67]:
import torch
import torch.nn as nn

import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
import torchvision

import pytorch_lightning as pl

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split



In [88]:
from typing import Any


class GraphAttentionModulation_LINEAR(nn.Module):
    def __init__(self, dim_embedding, max_num_boxes) -> None:
        super().__init__()

        self.dim_embedding = dim_embedding
        self.num_boxes = max_num_boxes

        self.W = nn.Bilinear(max_num_boxes, dim_embedding, dim_embedding)

        self.W1 = nn.Linear(dim_embedding, dim_embedding)

        self.gelu = nn.GELU()

        self.layer_norm = nn.LayerNorm(dim_embedding)

    def forward(self, x):

        # compute the clip score
        # x: (batch_size, n_boxes, dim_embedding)
        # c_s: (batch_size, n_boxes, n_boxes)
        # the clip score is the similarity between each pair of boxes
        # c_s[i, j] is the similarity between box i and box j
        # the similarity will be used to compute the attention score
        # considering the similarity between each pair of boxes
        # as the link in between the boxes

        # normalize the embedding
        x /= x.norm(dim = -1, keepdim = True)
        c_s = torch.bmm(x, x.permute(0, 2, 1)) 

        # find a transformation of the embedding
        # based on the similarity between each pair of boxes
        # c_s W^T -> attention_score
        # attention_score: (batch_size, n_boxes, n_boxes)
        # attention_score Embedding -> (batch_size, n_boxes, dim_embedding)
        x = self.W(c_s, x)
        del c_s

        # apply a non-linear transformation
        x = self.gelu(x)

        # apply a layer normalization
        x = self.layer_norm(x)

        # apply a linear transformation 
        x = self.W1(x)

        # apply a non-linear transformation
        x = self.gelu(x)

        # apply a layer normalization
        x = self.layer_norm(x)

        return x

##############################################################################
##############################################################################

class GraphBoxRegressor(nn.Module):
    def __init__(self, dim_embedding = 512, max_num_boxes = 48) -> None:
        super().__init__()

        self.dim_embedding = dim_embedding
        self.max_num_boxes = max_num_boxes

        self.gam = GraphAttentionModulation_LINEAR(dim_embedding, max_num_boxes)

        self.flatten = nn.Flatten(start_dim=1, end_dim=- 1)

        self.regressor = nn.Sequential(
            nn.Linear(dim_embedding * max_num_boxes + max_num_boxes * 4, dim_embedding),
            nn.GELU(),
            nn.LayerNorm(dim_embedding),
            nn.Linear(dim_embedding, dim_embedding//2),
            nn.GELU(),
            nn.LayerNorm(dim_embedding//2),
            nn.Linear(dim_embedding//2, 4)
        )

    def forward(self, x, boxes):

        # apply the graph attention modulation
        # x: (batch_size, n_boxes, dim_embedding)
        x = self.gam(x)

        # concatenate the boxes to the embedding
        # x: (batch_size, n_boxes, dim_embedding + 4)
        x = torch.cat([x, boxes], dim = -1)

        # flatten the embedding
        # x: (batch_size, n_boxes * dim_embedding)
        x = self.flatten(x)

        # apply the regressor
        # x: (batch_size, 4)
        x = self.regressor(x)

        return x
    

##############################################################################
##############################################################################


class GraphBoxRegressorLightning(pl.LightningModule):
    def __init__(self, dim_embedding = 512, max_num_boxes = 48) -> None:
        super().__init__()

        self.dim_embedding = dim_embedding
        self.max_num_boxes = max_num_boxes

        self.gam_l = GraphBoxRegressor(dim_embedding, max_num_boxes)


        self.MSE = nn.MSELoss()
        self.MAE = nn.L1Loss()
        self.HUBER = nn.SmoothL1Loss()  

        # GENERALIZED_BOX_IOU_LOSS https://arxiv.org/abs/1902.09630
        self.generalized_box_iou_loss = torchvision.ops.generalized_box_iou_loss
        # DISTANCE_BOX_IOU_LOSS https://arxiv.org/abs/1911.08287
        self.distance_box_iou_loss = torchvision.ops.distance_box_iou_loss
        # COMPLETE_BOX_IOU_LOSS https://arxiv.org/abs/1911.08287
        self.complete_box_iou_loss = torchvision.ops.complete_box_iou_loss


    def forward(self, x, boxes):
        return self.gam_l(x, boxes)
    
    def training_step(self, batch, batch_idx):

        x, boxes, target = batch

        pred = self(x, boxes)

        # keep track of the losses

        mse_loss = self.MSE(pred, target)

        mae_loss = self.MAE(pred, target)

        huber_loss = self.HUBER(pred, target)

        g_box_iou_loss = self.generalized_box_iou_loss(pred, target, reduction = 'mean')

        d_box_iou_loss = self.distance_box_iou_loss(pred, target, reduction = 'mean')

        c_box_iou_loss = self.complete_box_iou_loss(pred, target, reduction = 'mean')

        self.log('train_mse_loss', mse_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('train_mae_loss', mae_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('train_huber_loss', huber_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('train_g_box_iou_loss', g_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('train_d_box_iou_loss', d_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('train_c_box_iou_loss', c_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)

        return g_box_iou_loss

    def validation_step(self, batch, batch_idx):
            
        x, boxes, target = batch

        pred = self(x, boxes)

        # keep track of the losses

        mse_loss = self.MSE(pred, target)

        mae_loss = self.MAE(pred, target)

        huber_loss = self.HUBER(pred, target)

        g_box_iou_loss = self.generalized_box_iou_loss(pred, target, reduction = 'mean')

        d_box_iou_loss = self.distance_box_iou_loss(pred, target, reduction = 'mean')

        c_box_iou_loss = self.complete_box_iou_loss(pred, target, reduction = 'mean')

        self.log('val_mse_loss', mse_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('val_mae_loss', mae_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('val_huber_loss', huber_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('val_g_box_iou_loss', g_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('val_d_box_iou_loss', d_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)
        self.log('val_c_box_iou_loss', c_box_iou_loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)

        return g_box_iou_loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr = 1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 10)
        return [optimizer], [scheduler]





In [70]:
# set up the dataloaders

X = torch.randn(1000, 48, 512)
boxes = torch.randn(1000, 48, 4)
target = boxes[:, 0,:]

BATCH_SIZE = 64

train_size = int(len(X)*0.8)
val_size = len(X) - train_size

dataset = torch.utils.data.TensorDataset(X, boxes, target)

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)

In [89]:
# set up the model

model = GraphBoxRegressorLightning().cuda()

# early stopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# early_stopping = EarlyStopping('val_g_box_iou_loss', patience = 3, mode = 'min')
# checkpoint = ModelCheckpoint('checkpoints', monitor = 'val_g_box_iou_loss', save_top_k = 1, mode = 'min')

# set up the trainer

trainer = pl.Trainer(accelerator='auto', max_epochs = 10)

# train the model
trainer.fit(model, train_loader, val_loader)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | gam_l | GraphBoxRegressor | 25.7 M
1 | MSE   | MSELoss           | 0     
2 | MAE   | L1Loss            | 0     
3 | HUBER | SmoothL1Loss      | 0     
--------------------------------------------
25.7 M    Trainable params
0         Non-trainable params
25.7 M    Total params
102.651   Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 9: 100%|██████████| 13/13 [00:01<00:00,  8.24it/s, v_num=5, train_mse_loss_step=8.900, train_mae_loss_step=2.760, train_huber_loss_step=2.270, train_g_box_iou_loss_step=0.993, train_d_box_iou_loss_step=0.995, train_c_box_iou_loss_step=1.320, val_mse_loss_step=9.420, val_mae_loss_step=2.960, val_huber_loss_step=2.460, val_g_box_iou_loss_step=0.994, val_d_box_iou_loss_step=1.010, val_c_box_iou_loss_step=1.520, val_mse_loss_epoch=9.880, val_mae_loss_epoch=2.960, val_huber_loss_epoch=2.460, val_g_box_iou_loss_epoch=0.998, val_d_box_iou_loss_epoch=1.010, val_c_box_iou_loss_epoch=1.300, train_mse_loss_epoch=9.990, train_mae_loss_epoch=2.970, train_huber_loss_epoch=2.480, train_g_box_iou_loss_epoch=0.998, train_d_box_iou_loss_epoch=1.010, train_c_box_iou_loss_epoch=1.320]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 13/13 [00:02<00:00,  6.45it/s, v_num=5, train_mse_loss_step=8.900, train_mae_loss_step=2.760, train_huber_loss_step=2.270, train_g_box_iou_loss_step=0.993, train_d_box_iou_loss_step=0.995, train_c_box_iou_loss_step=1.320, val_mse_loss_step=9.420, val_mae_loss_step=2.960, val_huber_loss_step=2.460, val_g_box_iou_loss_step=0.994, val_d_box_iou_loss_step=1.010, val_c_box_iou_loss_step=1.520, val_mse_loss_epoch=9.880, val_mae_loss_epoch=2.960, val_huber_loss_epoch=2.460, val_g_box_iou_loss_epoch=0.998, val_d_box_iou_loss_epoch=1.010, val_c_box_iou_loss_epoch=1.300, train_mse_loss_epoch=9.990, train_mae_loss_epoch=2.970, train_huber_loss_epoch=2.480, train_g_box_iou_loss_epoch=0.998, train_d_box_iou_loss_epoch=1.010, train_c_box_iou_loss_epoch=1.320]


In [5]:



torch.bmm(torch.randn(64,100, 512), torch.randn(64, 512, 100)).shape

torch.Size([64, 100, 100])