Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RELTR1.0 #64

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/RELTR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from openks.models import OpenKSModel

# 列出已加载模型
OpenKSModel.list_modules()

# 算法模型选择配置
args = {
"MODEL.DEVICE": 'cpu'
}
platform = 'PyTorch'
executor = 'RELTRExtract'
model = 'pytorch-RELTRExtractor'
print("根据配置,使用 {} 框架,{} 执行器训练 {} 模型。".format(platform, executor, model))
print("-----------------------------------------------")
# 模型训练
executor = OpenKSModel.get_module(platform, executor)

text_ner = executor(args=args)
text_ner.run(mode="train")
print("-----------------------------------------------")
392 changes: 392 additions & 0 deletions openks/models/pytorch/RELTR.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions openks/models/pytorch/visual_entity_modules/datasets1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from .coco import build as build_coco


def get_coco_api_from_dataset(dataset):
for _ in range(10):
# if isinstance(dataset, torchvision.datasets.CocoDetection):
# break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco


def build_dataset(image_set, args):
if args.dataset == 'vg' or args.dataset_file == 'oi':
return build_coco(image_set, args)
raise ValueError(f'dataset {args.dataset} not supported')
182 changes: 182 additions & 0 deletions openks/models/pytorch/visual_entity_modules/datasets1/coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Institute of Information Processing, Leibniz University Hannover.

"""
dataset (COCO-like) which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path
import json
import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

import datasets.transforms as T

class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms, return_masks):
super(CocoDetection, self).__init__(img_folder, ann_file)
self._transforms = transforms
self.prepare = ConvertCocoPolysToMask(return_masks)

#TODO load relationship
with open('/'.join(ann_file.split('/')[:-1])+'/rel.json', 'r') as f:
all_rels = json.load(f)
if 'train' in ann_file:
self.rel_annotations = all_rels['train']
elif 'val' in ann_file:
self.rel_annotations = all_rels['val']
else:
self.rel_annotations = all_rels['test']

self.rel_categories = all_rels['rel_categories']

def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx)
image_id = self.ids[idx]
rel_target = self.rel_annotations[str(image_id)]

target = {'image_id': image_id, 'annotations': target, 'rel_annotations': rel_target}

img, target = self.prepare(img, target)
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks


class ConvertCocoPolysToMask(object):
def __init__(self, return_masks=False):
self.return_masks = return_masks

def __call__(self, image, target):
w, h = image.size

image_id = target["image_id"]
image_id = torch.tensor([image_id])

anno = target["annotations"]

anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)

classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)

if self.return_masks:
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_to_mask(segmentations, h, w)

keypoints = None
if anno and "keypoints" in anno[0]:
keypoints = [obj["keypoints"] for obj in anno]
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
num_keypoints = keypoints.shape[0]
if num_keypoints:
keypoints = keypoints.view(num_keypoints, -1, 3)

keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
if self.return_masks:
masks = masks[keep]
if keypoints is not None:
keypoints = keypoints[keep]

# TODO add relation gt in the target
rel_annotations = target['rel_annotations']

target = {}
target["boxes"] = boxes
target["labels"] = classes
if self.return_masks:
target["masks"] = masks
target["image_id"] = image_id
if keypoints is not None:
target["keypoints"] = keypoints

# for conversion to coco api
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
target["area"] = area[keep]
target["iscrowd"] = iscrowd[keep]

target["orig_size"] = torch.as_tensor([int(h), int(w)])
target["size"] = torch.as_tensor([int(h), int(w)])
# TODO add relation gt in the target
target['rel_annotations'] = torch.tensor(rel_annotations)

return image, target


def make_coco_transforms(image_set):

normalize = T.Compose([
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

if image_set == 'train':
return T.Compose([
T.RandomHorizontalFlip(),
T.RandomSelect(
T.RandomResize(scales, max_size=1333),
T.Compose([
T.RandomResize([400, 500, 600]),
#T.RandomSizeCrop(384, 600), # TODO: cropping causes that some boxes are dropped then no tensor in the relation part! What should we do?
T.RandomResize(scales, max_size=1333),
])
),
normalize])

if image_set == 'val':
return T.Compose([
T.RandomResize([800], max_size=1333),
normalize,
])

raise ValueError(f'unknown {image_set}')


def build(image_set, args):

ann_path = args.ann_path
img_folder = args.img_folder

#TODO: adapt vg as coco
if image_set == 'train':
ann_file = ann_path + 'train.json'
elif image_set == 'val':
if args.eval:
ann_file = ann_path + 'test.json'
else:
ann_file = ann_path + 'val.json'

dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=False)
return dataset
Loading