# Initial setup
Set up the Dataset classes.

In [None]:
import torch
import torchvision
from pathlib import Path
from torch.utils.data import Dataset
from torchvision.io import decode_image
from datasets import load_dataset
import itertools
import math
class MMRFineTune(Dataset):
    """Fine-tuning dataset mixing contrast, "other" and "train" images.

    Indices are laid out as ``[contrast | other | train]``:

    * the first ``max_contrast`` indices read from the ``train`` split of the
      ``zh-plus/tiny-imagenet`` dataset,
    * the next ``n_other`` indices read ``<root>/other/*.jpg``,
    * the remaining ``n_train`` indices read ``<root>/train/*.jpg``.

    Every sample gets a single bounding box spanning 10%..90% of the image
    width and height, with label 1 for ``train`` images and 0 otherwise.
    """

    @staticmethod
    def istrain(path: Path) -> bool:
        """Return True when the directory containing *path* has 'train' in its name."""
        # `str` has no `.contains`; use the `in` operator. `parent.name` also
        # avoids an IndexError for single-component paths.
        return 'train' in path.parent.name

    def __init__(self, root: Path, transform=None):
        """Index the image files under *root*.

        Args:
            root: Dataset directory (``str``, ``Path`` or any ``os.PathLike``).
            transform: Optional callable invoked as ``transform(image, target)``
                that returns the transformed ``(image, target)`` pair.
        """
        # Path() handles str/Path/PathLike uniformly and raises TypeError for
        # anything else, instead of silently leaving self.root unset.
        self.root = Path(root)
        self.transform = transform
        # Sorted so that an index always refers to the same file across runs.
        self.other_list = sorted(self.root.glob('other/*.jpg'))
        self.train_list = sorted(self.root.glob('train/*.jpg'))
        ds = load_dataset("zh-plus/tiny-imagenet")
        self.contrast = ds['train']
        self.max_contrast = 0  # max num of contrast samples
        self.n_other = len(self.other_list)
        self.n_train = len(self.train_list)

    def __len__(self):
        return self.max_contrast + self.n_other + self.n_train

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample *idx*.

        ``target`` is a dict with keys ``boxes``, ``labels``, ``image_id`` and
        ``area``.

        Raises:
            IndexError: if *idx* is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            # Without this, negative indices would fall into the contrast branch.
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")

        if idx < self.max_contrast:
            image = torchvision.transforms.functional.pil_to_tensor(self.contrast[idx]['image'])
            is_train = False
        elif idx < self.max_contrast + self.n_other:
            img_path = self.other_list[idx - self.max_contrast]
            image = decode_image(str(img_path))
            is_train = False
        else:
            img_path = self.train_list[idx - self.max_contrast - self.n_other]
            image = decode_image(str(img_path))
            is_train = True

        num_objs = 1
        _, h, w = image.shape
        # One XYXY box covering the central 80% of the image in each dimension.
        boxes = torch.zeros((num_objs, 4), dtype=torch.float)
        boxes[0, 0] = math.floor(w * 0.1)
        boxes[0, 1] = math.floor(h * 0.1)
        boxes[0, 2] = math.floor(w * 0.9)
        boxes[0, 3] = math.floor(h * 0.9)

        # Label 1 marks images from the "train" folder, 0 everything else.
        labels = torch.full((num_objs,), int(is_train), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": torchvision.tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(h, w)),
            "labels": labels,
            "image_id": idx,
            "area": area,
        }

        if self.transform:
            image, target = self.transform(image, target)
        return image, target
class MMRVideos(Dataset):
    """Dataset of consecutive frame pairs read from ``<root>/train/<video>/*.png``.

    Each item is ``(frame, next_frame)`` taken from the same video. The last
    frame of a video has no successor, so it is never used as the first
    element of a pair; ``len(self)`` therefore equals the number of frames
    minus the number of (non-empty) videos.
    """

    @staticmethod
    def _frame_sort_key(frame):
        """Natural-sort key so that ``2.png`` orders before ``10.png``."""
        import re

        # re.split with a capturing group always alternates text / digits,
        # starting with text, so keys compare element-wise without type clashes.
        pieces = re.split(r'(\d+)', frame.stem)
        return [int(piece) if pos % 2 else piece for pos, piece in enumerate(pieces)]

    def read_files(self):
        """Scan the ``train`` directory and build the frame index tables."""
        frame_store = []
        last_frame = []
        train_dir = Path(self.root) / 'train'
        # Sorted so that item indices are stable between runs.
        for video in sorted(train_dir.iterdir()):
            if not video.is_dir():
                continue
            frames = [str(frame) for frame in sorted(video.glob('*.png'), key=self._frame_sort_key)]
            if not frames:
                # An empty folder must not add a stray "last frame" flag,
                # which would shift every later flag by one.
                continue
            frame_store += frames
            last_frame += [False] * (len(frames) - 1) + [True]
        self.frame_store = frame_store
        self.last_frame = last_frame
        self.frame_offset = list(itertools.accumulate(last_frame))
        # Physical indices (into frame_store) of every frame that has a
        # successor in the same video; logical item i maps to pair_index[i].
        self.pair_index = [pos for pos, is_last in enumerate(last_frame) if not is_last]
        self.max_id = len(self)

    def __init__(self, root: Path, transform=None):
        """Index all videos below ``root/train``.

        Args:
            root: Dataset directory containing a ``train`` sub-directory.
            transform: Optional callable applied separately to each frame.
        """
        self.root = root
        self.transform = transform
        self.read_files()

    def __len__(self):
        return len(self.pair_index)

    def __getitem__(self, lidx):
        """Return ``(frame, next_frame)`` for logical pair index *lidx*.

        Raises:
            IndexError: if *lidx* is out of range.
        """
        # The previous `lidx + frame_offset[lidx]` looked the offset up at the
        # logical position instead of the physical one, so it could land on a
        # video's last frame when videos are short.
        idx = self.pair_index[lidx]
        image = decode_image(self.frame_store[idx])
        nx_frame = decode_image(self.frame_store[idx + 1])
        if self.transform:
            image = self.transform(image)
            nx_frame = self.transform(nx_frame)
        return image, nx_frame

### Test run of the model to check that the data pipeline works

In [None]:
from torchvision.io import decode_image
# Decode one sample frame from disk; `person_int` is reused by later cells.
person_int = decode_image(str(Path("./data/train/woodbridge") / "8.png"))
#weights = torchvision.models.get_weight('FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1')
# Build mobilenet_v3_large with the IMAGENET1K_V2 weights, then fetch the
# matching weights entry to obtain its preprocessing transforms.
model = torchvision.models.get_model('mobilenet_v3_large', weights='IMAGENET1K_V2')
weights = torchvision.models.get_weight('MobileNet_V3_Large_Weights.IMAGENET1K_V2')
transforms = weights.transforms()

# Apply the preprocessing and add a leading dimension of size 1.
person_float = transforms(person_int).unsqueeze(0)
# Outputs captured by hooks created with get_activation, keyed by hook name.
# Previously this dict was never defined, so a registered hook would raise
# NameError on its first call. Re-running this cell resets the store.
activation = {}


def get_activation(name):
    """Return a forward hook that stores the module output under *name*.

    The hook has the ``(module, input, output)`` signature expected by
    ``register_forward_hook`` and saves ``output.detach()`` into the
    module-level ``activation`` dict.
    """
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
# Disabled experiments: strip the pooling/classifier head and capture the
# output of `model.features` through a forward hook instead.
#model.avgpool = torch.nn.Identity()
#model.classifier = torch.nn.Identity() 
#model.features.register_forward_hook(get_activation("feats"))
# Switch to evaluation mode and run the preprocessed sample through the model.
model = model.eval()
outputs = model(person_float)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F


plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    """Plot one image tensor, or a list of them, side by side without axis ticks.

    Each tensor is detached, converted to a PIL image and drawn in its own
    column of a single-row figure.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for col, tensor in enumerate(images):
        pil_img = F.to_pil_image(tensor.detach())
        axis = axes[0, col]
        axis.imshow(np.asarray(pil_img))
        # Hide tick marks and tick labels on both axes.
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])



from torchvision.utils import draw_keypoints
# NOTE(review): `outputs` is produced in the cell above by mobilenet_v3_large,
# which is presumably a classifier returning a plain tensor rather than a list
# of detection dicts. These 'keypoints'/'scores' lookups look like leftovers
# from a keypoint-detection example and will likely fail — confirm.
kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']

# Keep only the detections whose score exceeds the threshold.
detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
# Draw the selected keypoints onto the original image and display it.
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)

# Fine tune backbone for this task
The model requires an image feature encoder. An ImageNet-pretrained network is used as the backbone.
To further enhance the extraction, features are extracted via a feature pyramid.
This is also intended to be flexible for future modifications.

In [None]:
from torchvision.models.mobilenetv3 import _mobilenet_v3_conf
def get_model_traindetector(num_classes=2, dropout=0.2):
    """Build a MobileNetV3-Large classifier with a new, trainable output layer.

    The ImageNet-pretrained network (``IMAGENET1K_V2``) is loaded and its
    classifier is replaced by ``Linear -> Hardswish -> Dropout -> Linear``.
    The first linear layer receives the pretrained ``classifier.0`` weights;
    the final layer maps to *num_classes* and is freshly initialised.
    Only the final layer (``classifier.3``) keeps ``requires_grad=True``;
    every other parameter is frozen.

    Args:
        num_classes: Number of output classes of the new head.
        dropout: Dropout probability used inside the head.

    Returns:
        The modified ``torchvision`` MobileNetV3-Large model.
    """
    model = torchvision.models.mobilenet_v3_large(weights="IMAGENET1K_V2")

    # Read the head dimensions from the pretrained classifier instead of
    # hard-coding them or going through the private `_mobilenet_v3_conf`.
    pretrained_fc = model.classifier[0]
    feat_dim = pretrained_fc.in_features
    last_channel = pretrained_fc.out_features

    model.classifier = torch.nn.Sequential(
        torch.nn.Linear(feat_dim, last_channel),
        torch.nn.Hardswish(inplace=True),
        torch.nn.Dropout(p=dropout, inplace=True),
        torch.nn.Linear(last_channel, num_classes),
    )  # train+other

    # The head only contains Linear layers with parameters. Layer 0 is
    # initialised here as well (and overwritten below) so the random-number
    # stream matches the previous implementation.
    for layer in model.classifier.modules():
        if isinstance(layer, torch.nn.Linear):
            torch.nn.init.normal_(layer.weight, 0, 0.01)
            torch.nn.init.zeros_(layer.bias)

    # Carry the pretrained first classifier layer over into the new head.
    # The loaded model already holds these weights, so no second state-dict
    # load is needed.
    with torch.no_grad():
        model.classifier[0].weight.copy_(pretrained_fc.weight)
        model.classifier[0].bias.copy_(pretrained_fc.bias)

    # Parameters whose name starts with one of these prefixes stay trainable.
    trainable_prefixes = ('classifier.3',)
    for name, param in model.named_parameters():
        if not name.startswith(trainable_prefixes):
            param.requires_grad = False
    return model

In [None]:
from detector.engine import train_one_epoch
import torchvision.transforms.v2 as v2
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# `model` is the instance trained below; `model_raw` is a second, independently
# built instance that is not used again in this cell.
model = get_model_traindetector()
model_raw = get_model_traindetector()


# Preprocessing: resize, centre-crop to 224, convert to scaled float32 and
# normalise with the ImageNet mean/std.
transform = v2.Compose([
 v2.Resize(232),
 v2.CenterCrop(224),
 v2.ToDtype(torch.float32, scale=True),
 v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), #taken from imagenet config
])
def my_collate(batch):
    """Split a batch of samples into ``[first_elements, second_elements]``.

    Each sample is indexed at positions 0 and 1; the results are returned as
    two plain lists without any stacking or tensor conversion.
    """
    data = []
    for sample in batch:
        data.append(sample[0])
    target = []
    for sample in batch:
        target.append(sample[1])
    return [data, target]

# Dataset and loader for fine-tuning. The custom collate function is disabled,
# so the DataLoader's default collation is used.
dataset = MMRFineTune('finetune', transform=transform)
criterion = torch.nn.CrossEntropyLoss()
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    #collate_fn=my_collate,
)
# Optimise only the parameters that still require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.001,
    momentum=0.9,
    weight_decay=0.0005
)
# Multiply the learning rate by 0.1 every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)
# train for `num_epochs` epochs
# NOTE(review): `model` is not moved to `device` in this cell — confirm that
# train_one_epoch takes care of it.
num_epochs = 10
for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluation on a test dataset is currently disabled
    #evaluate(model, data_loader_test, device=device)

In [None]:
# Load the COCO-pretrained SSDLite320 detector with a MobileNetV3-Large backbone.
weights = torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights.COCO_V1
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=weights)
#model(person_float)
# get number of input features for the classifier
#model.roi_heads.box_predictor.cls_score.in_features
#import copy
#from customssd import SSDLiteHead
#model.head.regression_head
# Request the graph node named 'head' and expose it under the key 'reghead'.
return_nodes = { 'head':'reghead'}
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
# The extractor is not assigned to a name; as the last expression of the
# notebook cell it is only displayed.
create_feature_extractor(model, return_nodes=return_nodes)

In [None]:
####TEST ONLY
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, ToPILImage
from torchvision.transforms import v2
import matplotlib.pyplot as plt
#import utils
# Test-only preprocessing: resize to 320x320 and convert to scaled float32.
# The crop steps are disabled.
transform = v2.Compose([
 v2.Resize(size=(320,320)),
 #v2.CenterCrop((600, 480)),
 #v2.CenterCrop((64, 64)),
 v2.ToDtype(torch.float32, scale=True),
])
def my_collate(batch):
    """Collate samples into ``[data, target]`` lists without stacking.

    ``data`` holds element 0 of every sample and ``target`` holds element 1,
    both in batch order.
    """
    data = list(map(lambda sample: sample[0], batch))
    target = list(map(lambda sample: sample[1], batch))
    return [data, target]

# Build the dataset and a loader that keeps samples as plain lists
# (my_collate does no stacking).
dataset = MMRFineTune('finetune', transform=transform)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=my_collate,
)
# Pull a single batch and copy it into fresh containers.
images, targets = next(iter(data_loader))
images = list(images)
targets = [dict(t) for t in targets]

from importlib import reload
# `from detector.x import y` does not bind the name `detector`, so the package
# has to be imported explicitly before it can be reloaded. The submodule is
# reloaded as well so that edits to detector/ssdlite.py are picked up.
import detector
import detector.ssdlite
reload(detector)
reload(detector.ssdlite)
from detector.ssdlite import custom_ssdlite320_mobilenet_v3_large

weights = torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights.COCO_V1
model = custom_ssdlite320_mobilenet_v3_large(num_classes=2, trainable_backbone_layers=1, weights=weights)
model.eval()
# NOTE(review): the model is in eval mode here; torchvision detection models
# return only detections (no losses) in that mode — confirm what the custom
# model returns.
output = model(images, targets)
print(output)

In [None]:
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

# List the graph node names of the current global `model` (train-mode and
# eval-mode graphs); useful for choosing `return_nodes` below.
train_nodes, eval_nodes = get_graph_node_names(model)



class FrameEncoder(torch.nn.Module):
    """MobileNetV3-Large feature extractor followed by a feature pyramid network.

    Three intermediate nodes of the ImageNet-pretrained network are tapped
    (exposed as ``l14``, ``l15`` and ``l16``) and passed through a
    ``FeaturePyramidNetwork`` with ``outdim`` output channels and a
    ``LastLevelMaxPool`` extra block.
    """

    def __init__(self, outdim=10):
        super().__init__()
        tapped_nodes = {
            'features.14.add': 'l14',
            'features.15.add': 'l15',
            'features.16': 'l16',
        }
        mobilenet = torchvision.models.get_model('mobilenet_v3_large', weights='IMAGENET1K_V2')
        self.body = create_feature_extractor(mobilenet, return_nodes=tapped_nodes)

        # Push a dummy batch through the body once to discover how many
        # channels each tapped feature map has.
        probe = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            probe_out = self.body(probe)
        channels_per_level = [feature.shape[1] for feature in probe_out.values()]

        # Feature pyramid on top of the tapped levels.
        self.out_channels = outdim
        self.fpn = FeaturePyramidNetwork(
            channels_per_level,
            out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool(),
        )

    def forward(self, x):
        """Return the FPN output computed from the tapped backbone features."""
        return self.fpn(self.body(x))
class KeypointExtractor(torch.nn.Module):
    """Predict a vertical and a horizontal bin index per image.

    A ``FrameEncoder`` produces feature maps with ``max_keypoints`` channels.
    They are pooled to one vector per image and fed to two linear classifiers
    with ``n_bins * image_h`` and ``n_bins * image_w`` outputs respectively.

    NOTE(review): the previous version could not run. ``Tuple``, ``image_h``
    and ``image_w`` were undefined, the backbone was never called, and
    ``torch.cat`` was given tensors instead of a sequence. How the backbone
    features are reduced (which pyramid level, which pooling) is an assumption
    made here — confirm against the intended design.
    """

    def __init__(self, image_dim: tuple):
        """
        Args:
            image_dim: Image size as a 2-tuple; assumed to be
                ``(height, width)`` — TODO confirm the order.
        """
        super().__init__()
        image_h, image_w = image_dim
        max_keypoints = 10
        n_bins = 2
        self.backbone = FrameEncoder(outdim=max_keypoints)
        # Input size follows the backbone's channel count rather than a
        # separately hard-coded 10.
        self.vertical_classifier = torch.nn.Linear(max_keypoints, n_bins * image_h)
        self.h_classifier = torch.nn.Linear(max_keypoints, n_bins * image_w)

    def forward(self, x):
        """Return ``(keypoints, confidences)``.

        Both results concatenate the vertical values first and the horizontal
        values second along dimension 0. ``keypoints`` holds the arg-max bin
        indices and ``confidences`` the corresponding maximum logits.
        """
        feats = self.backbone(x)
        # Assumption: use the deepest tapped level and average over the two
        # trailing (spatial) dimensions to get one vector per image.
        pooled = feats['l16'].mean(dim=(-2, -1))
        h_feat = self.vertical_classifier(pooled)
        w_feat = self.h_classifier(pooled)
        h_conf, h_feat = torch.max(h_feat, dim=1)
        w_conf, w_feat = torch.max(w_feat, dim=1)
        # torch.cat expects a sequence of tensors as its first argument.
        keypoints = torch.cat((h_feat, w_feat), 0)
        return keypoints, torch.cat((h_conf, w_conf), 0)
# Instantiate the encoder once with the default number of output channels.
m = FrameEncoder()

### Training architecture

The attempted method follows XXX et al., using a motion-difference reconstruction loss.