In [1]:
import os
import timm
import logging
import argparse
import pandas as pd
from typing import Optional
from datetime import datetime

import sys
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchmetrics import Accuracy, F1Score, Specificity

from pytorch_lightning import LightningModule
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.strategies import ParallelStrategy
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning import Trainer
# from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin

from utils.dataset import PapsDetDataset, train_transforms, val_transforms, test_transforms, MAX_IMAGE_SIZE
from utils.collate import collate_fn
# from utils.sampler_by_group import GroupedBatchSampler, create_area_groups
# from utils.losses import SupConLoss, FocalLoss

from cls_utils.block import Bottleneck, TwoMLPHead, RoIPool

In [2]:
# Number of output classes; handed to the classification head further down.
num_classes = 6

# ResNet-50 backbone created without a classifier (num_classes=0) and without
# global pooling (global_pool=''), so calling it returns a spatial feature map.
model = timm.create_model(
    'resnet50',
    pretrained=False,
    num_classes=0,
    global_pool='',
)

In [3]:
# Channel width of the backbone's final feature map. Read it from the backbone
# that already exists (`model`) instead of instantiating a second ResNet-50
# only to inspect its classifier; timm exposes the value as `num_features`.
in_features = model.num_features
in_features

2048

In [4]:
# model

In [5]:
# Extra Bottleneck block placed after the backbone. Its second argument is a
# quarter of the input channel count.
quarter_width = in_features // 4
interlayer = Bottleneck(in_features, quarter_width)

In [6]:
# RoI pooling down to a single 1x1 cell per box.
# 1/32 equals the ratio between the 64x64 feature map and the 2048x2048 input
# used in this notebook. Original note: boxes are not normalized.
# NOTE(review): presumably the second argument is the scale applied to
# pixel-space boxes -- confirm against cls_utils.block.RoIPool.
roipool = RoIPool((1, 1), 1 / 32)

In [7]:
# A quarter of the backbone width, passed as the second argument of the head.
intermediate_channels = in_features // 4
# The name used to be misspelled; keep the old spelling bound as an alias so
# any cell that still refers to it keeps working.
intermdeiate_channels = intermediate_channels

# Classification head: maps pooled RoI features to `num_classes` outputs.
mlp = TwoMLPHead(in_features, intermediate_channels, num_classes)

In [13]:
x = torch.randn(2,3,2048, 2048)

In [14]:
x = model(x)
x.shape

torch.Size([2, 2048, 64, 64])

In [15]:
x = interlayer(x)
x.shape

torch.Size([2, 2048, 64, 64])

In [18]:
# Five values per box; the last four are box coordinates in input-image pixels.
# NOTE(review): this assumes cls_utils.block.RoIPool follows the
# torchvision.ops.roi_pool box layout, where column 0 is the index of the image
# inside the batch. The batch here holds 2 images, so the value 3 used before
# pointed past the end of the batch; the boxes now refer to images 0 and 1.
bbox = torch.tensor(
    [[0, 100., 200., 200., 250.],
     [1, 130., 250., 200., 250.]],
    dtype=torch.float32,
)
print(bbox)
print(bbox.shape)
# Pool one 1x1 feature vector per box from the feature map `x`.
roi = roipool(x, bbox)
roi.shape

tensor([[  0., 100., 200., 200., 250.],
        [  1., 130., 250., 200., 250.]])
torch.Size([2, 5])


torch.Size([2, 2048, 1, 1])

In [19]:
# Run the head on the pooled RoI features: one row of `num_classes` values
# per box.
out = mlp(roi)
out.size()

torch.Size([2, 6])

In [26]:
# Two boxes with four coordinates each and no leading index column.
bbox = torch.tensor([
    [100., 200., 200., 250.],
    [130., 250., 200., 250.],
])
bbox.size()

torch.Size([2, 4])

In [27]:
# Column of ones with one row per box, shaped (N, 1) so it can be concatenated
# in front of the box coordinates.
num_boxes = bbox.size(0)
labels = torch.ones(num_boxes, 1)
labels.size()

torch.Size([2, 1])

In [29]:
# Prepend the label column to the coordinates, giving five values per box.
torch.cat((labels, bbox), dim=1)

tensor([[  1., 100., 200., 200., 250.],
        [  1., 130., 250., 200., 250.]])