In [1]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# reloaded before each cell runs (edits to project code apply without a kernel restart).
%load_ext autoreload
%autoreload 2

In [2]:
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader, MetadataCatalog
import matplotlib.pyplot as plt
import torch.nn.functional as F
from detectron2.engine import DefaultTrainer
from transformers import CLIPProcessor
from detectron2.modeling import build_model
import logging
import os
import torch

from coseg.data.dataset_mappers import TrainMapper
from coseg.model.model import CoSeg_wrapper
from coseg.model.lang_model import CLIPLang



In [3]:
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably this silences the Hugging Face tokenizers fork warning
# once data-loader workers start — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
# Start from detectron2's default configuration.
cfg = get_cfg()
# Allow the merge below to introduce keys that the default config does not define.
cfg.set_new_allowed(True)
# Overlay the project settings; the path is relative to the notebook's working directory.
# The warning logged for this cell shows the file is read with yaml.unsafe_load,
# so only merge config files you trust.
cfg.merge_from_file('./configs/base_config.yaml')

Loading config ./configs/base_config.yaml with yaml.unsafe_load. Your machine may be at risk if the file contains malicious content.


In [6]:
class Trainer(DefaultTrainer):
    """``DefaultTrainer`` subclass that swaps in the project's data mapper and
    registers a CLIP-derived label vocabulary on the model it builds.

    Both overrides are classmethods, matching the hooks they replace on
    ``DefaultTrainer``, so they can be called on the class or on an instance.
    """

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training data loader with ``TrainMapper`` as the mapper.

        Args:
            cfg: The detectron2 config node.

        Returns:
            Whatever ``build_detection_train_loader`` returns for ``cfg``.
        """
        # NOTE(review): the second positional argument is presumably an
        # "is_train" flag — confirm against TrainMapper's signature.
        return build_detection_train_loader(
            cfg,
            mapper=TrainMapper(cfg, True),
        )

    @classmethod
    def build_model(cls, cfg):
        """Build the model and register its label vocabulary.

        The class names of the first training dataset are turned into text
        prompts, encoded with ``CLIPLang``, and handed to
        ``model.register_vocabulary``.

        Args:
            cfg: The detectron2 config node; ``cfg.DATASETS.TRAIN[0]`` selects
                the metadata the labels are read from.

        Returns:
            The model returned by ``build_model(cfg)`` after the vocabulary
            has been registered.
        """
        # Inside this method, `build_model` resolves to the module-level
        # detectron2 function imported by the notebook, not to this classmethod.
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))

        # Register vocabulary
        lang_model = CLIPLang().eval()
        processor = CLIPProcessor.from_pretrained(lang_model.clip_version)

        meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        # Build a fresh list so the metadata's own class list is not mutated
        # by the append below.
        labels = [f"a photo of a {label}" for label in meta.stuff_classes]
        # The extra 'unlabeled' entry is added last and without the prompt template.
        labels.append('unlabeled')
        inputs = processor(labels, padding=True, return_tensors='pt')
        with torch.no_grad():
            label_embeddings = lang_model(**inputs)['text_embeds']
        label_embeddings.requires_grad_(False)
        # NOTE(review): the embeddings are not moved to any device here —
        # confirm that register_vocabulary handles device placement.
        model.register_vocabulary(label_embeddings)
        del lang_model

        return model
             

In [7]:
# Instantiate the trainer from the merged config. The log recorded for this cell
# shows the dataset being loaded and a batched data loader being created.
trainer = Trainer(cfg)



[32m[06/14 16:16:18 d2.data.datasets.coco]: [0mLoaded 118287 images with semantic segmentation from /scratch/t.tovi/datasets/coco-stuff/COCO_stuff_images/train2017
[32m[06/14 16:16:18 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[06/14 16:16:18 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[06/14 16:16:18 d2.data.common]: [0mSerializing 118287 elements to byte tensors and concatenating them all ...
[32m[06/14 16:16:19 d2.data.common]: [0mSerialized dataset takes 24.70 MiB
[32m[06/14 16:16:19 d2.data.build]: [0mMaking batched data loader with batch_size=16


In [8]:
# Run the training loop. In the recorded run it was stopped manually
# (KeyboardInterrupt) after roughly 60 iterations.
trainer.train()

[32m[06/14 16:16:19 d2.engine.train_loop]: [0mStarting training from iteration 0


  return F.conv2d(input, weight, bias, self.stride,


[32m[06/14 16:16:29 d2.utils.events]: [0m eta: 5:08:21  iter: 19  total_loss: 1.113    time: 0.4640  last_time: 0.4349  data_time: 0.0986  last_data_time: 0.0877   lr: 1.9981e-05  max_mem: 9676M
[32m[06/14 16:16:38 d2.utils.events]: [0m eta: 5:06:02  iter: 39  total_loss: 0.8919    time: 0.4582  last_time: 0.4612  data_time: 0.0843  last_data_time: 0.0832   lr: 3.9961e-05  max_mem: 9676M
[32m[06/14 16:16:47 d2.utils.events]: [0m eta: 5:05:01  iter: 59  total_loss: 0.8183    time: 0.4550  last_time: 0.4631  data_time: 0.0845  last_data_time: 0.0916   lr: 5.9941e-05  max_mem: 9676M
[32m[06/14 16:16:48 d2.engine.hooks]: [0mOverall training speed: 59 iterations in 0:00:27 (0.4583 s / it)
[32m[06/14 16:16:48 d2.engine.hooks]: [0mTotal training time: 0:00:27 (0:00:00 on hooks)
[32m[06/14 16:16:48 d2.utils.events]: [0m eta: 5:04:58  iter: 61  total_loss: 0.8152    time: 0.4549  last_time: 0.4500  data_time: 0.0841  last_data_time: 0.0785   lr: 6.094e-05  max_mem: 9676M


KeyboardInterrupt: 

In [None]:
# Display the first entry of the sample's 'masks'.
# NOTE(review): `sample` is not defined in any cell visible here; it relies on
# leftover kernel state, so this cell would fail after Restart & Run All.
# TODO: define `sample` explicitly.
plt.imshow(sample['masks'][0])

In [None]:
# Move the image's first axis to the end, squash the values through a sigmoid
# so they fall in (0, 1), and plot the result.
# NOTE(review): `sample` comes from kernel state not defined in the visible cells.
plt.imshow(torch.sigmoid(sample['image'].permute(1, 2, 0)))

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from PIL import Image

# Open one COCO-Stuff train2017 image. The absolute path is machine-specific,
# so this cell is not portable as written.
im = Image.open('/scratch/t.tovi/datasets/coco-stuff/COCO_stuff_images/train2017/000000369973.jpg')

In [None]:
# Open the annotation PNG with file stem 000000369973 from the
# annotations_detectron2/train2017 directory (absolute, machine-specific path).
an = Image.open('/scratch/t.tovi/datasets/coco-stuff/annotations_detectron2/train2017/000000369973.png')

In [None]:
# Display the annotation image's size; PIL reports it as (width, height).
an.size

In [None]:
# Build a standalone TrainMapper with the same arguments that
# Trainer.build_train_loader passes, so its `tf_gens` can be applied by hand below.
mapper = TrainMapper(cfg, True)

In [None]:
# NOTE(review): these imports sit mid-notebook; consider moving them to the import cell at the top.
import detectron2.data.transforms as T
import numpy as np

# Convert both PIL objects to NumPy arrays. This rebinds `im` and `an`, so any
# cell run afterwards sees arrays, not PIL images.
im, an = np.array(im), np.array(an)

# Bundle the image with its annotation (passed as `sem_seg`), then apply the
# mapper's `tf_gens` to the bundle.
# NOTE(review): `tf_gens` is presumably the mapper's list of augmentations — confirm.
aug_input = T.AugInput(im, sem_seg=an)
aug_input, transforms = T.apply_transform_gens(mapper.tf_gens, aug_input)

In [None]:
# Show the image held by `aug_input` after the transforms were applied.
plt.imshow(aug_input.image)

In [None]:
# NOTE(review): scratch arithmetic with unexplained constants (500, 224, 102);
# its purpose is not evident from this notebook — document it or remove the cell.
500 * 224/102