In [1]:
import os, sys
import argparse
import cv2
import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import json
import pandas as pd
CODE_SPACE=os.path.abspath('Test_Minist/')
sys.path.append(CODE_SPACE)
os.chdir(CODE_SPACE)

import utils.config as config
from utils.config import CfgNode
from utils.transforms_utils import get_imagenet_mean_std, normalize_img, pad_to_crop_sz, resize_by_scaled_short_side
import matplotlib.pyplot as plt

import glob
from PIL import Image
import torch.multiprocessing as mp 
from utils.segformer import get_configured_segformer, SegModel
from tqdm import tqdm 

In [2]:
from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader
from torchvision import transforms, utils, datasets

In [3]:
classes = {
    'facade' : 1,
    'molding' : 6,
    'cornice' : 10,
    'pillar': 11,
    'window' : 9,
    'door' : 4,
    'sill' : 2,
    'blind' : 5,
    'balcony' : 3,
    'shop': 8,
    'deco': 7,
    'background' : 0,
}

In [4]:
def make_palette(num_classes=12):
    """
    Inputs:
        num_classes: the number of classes
    Outputs:
        palette: the colormap as a k x 3 array of RGB colors
    """
    palette = np.zeros((num_classes, 3), dtype=np.uint8)
    for k in range(0, num_classes):
        label = k
        i = 0
        while label:
            palette[k, 0] |= (((label >> 0) & 1) << (7 - i))
            palette[k, 1] |= (((label >> 1) & 1) << (7 - i))
            palette[k, 2] |= (((label >> 2) & 1) << (7 - i))
            label >>= 3
            i += 1
    idx1 = np.arange(0, num_classes, 2)[::-1]
    idx2 = np.arange(1, num_classes, 2)
    idx = np.concatenate([idx1[:, None], idx2[:, None]], axis=1).flatten()
    palette = palette[idx]
    palette[num_classes - 1, :] = [255, 255, 255]
    return palette

PALETTE = make_palette(12)

In [5]:
def color_seg(seg, palette):
    color_out = palette[seg.reshape(-1)].reshape(seg.shape + (3,))
    return color_out

def color_map_list(class_num):
    map1 =  np.asarray([
        [0, 0, 170], #background
        [0, 0 ,255], #facade
        [0, 85, 255], #window
        [0, 170, 255], #door
        [85, 255, 170], #cornice
        [255, 170, 0], #sill
        [170, 255, 85], #balcony
        [255, 255, 0], #blind
        [0, 255, 255], #deco
        [255, 85, 0], #molding
        [255, 0, 0], #pillar
        [170, 0, 0], #shop
    ])
    idx1 = np.arange(0, map1.shape[0], 2)[::-1]
    idx2 = np.arange(1, map1.shape[0], 2)
    idx = np.concatenate([idx1[:, None], idx2[:, None]], axis=1).flatten()
    map1 = map1[idx]

    pa = np.ones((class_num, 3), dtype=np.uint8) * 255
    pa[:map1.shape[0], :] = map1
    return pa

In [6]:
class CustomDataset(torch.utils.data.Dataset):
    
    def __init__(
        self,
        images_dir,
        classes,
        augmentation=None, 
        preprocessing=None
    ):    
        self.images = [f'{images_dir}{i}' for i in os.listdir(images_dir) if i[-3:]=='jpg']
        self.masks = [f'{i[:-3]}png' for i in self.images]
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.classes = classes

    def __getitem__(self, i):
        image = np.array(Image.open(self.images[i]))
        mask = np.array(Image.open(self.masks[i]))
        
        mask = self.get_mask(mask)
        
        
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.images)
    
    def get_mask(self, rgb_mask):
        w, h = rgb_mask.shape
        mask = np.zeros((w, h, len(self.classes)))
        for i in self.classes.values():
            idx = rgb_mask==i
            mask[idx,i] = 1.
        return mask

In [7]:
dataset = CustomDataset('../data/base/', classes)

In [15]:
from torch.nn.modules.loss import _Loss

class TDiceLoss(_Loss):
    __name__ = 'TDiceLoss'
    
    def __init__(self, gt_embs_list, dice_weight=1, cross_weight=3):
        super().__init__()
        self.nll_loss = utils.losses.CrossEntropyLoss()
        self.dice_weight = dice_weight
        self.cross_weight = cross_weight
        self.gt_embs_list = gt_embs_list
        self.softmax = nn.Softmax2d()
        

    def forward(self, outputs, targets):
        
        outputs = self.get_prediction(outputs)
        outputs = F.softmax(outputs, dim=1)
        loss = self.nll_loss(outputs, targets) * self.cross_weight
        if self.dice_weight:
            eps = 1e-15
            dice_target = (targets == 1).float()
            dice_output = outputs
            intersection = (dice_output * dice_target).sum()
            union = dice_output.sum() + dice_target.sum() + eps
            loss += (1 - torch.log(2 * intersection / union)) * self.dice_weight
        return loss
    
    def get_prediction(self, embs):
        prediction = []
        logits = []
        B = embs.shape[0]
        for b in range(B):
            score = embs[b,...]
            score = score.unsqueeze(0)
            emb = self.gt_embs_list
            emb = emb / emb.norm(dim=1, keepdim=True)
            score = score / score.norm(dim=1, keepdim=True)
            score = score.permute(0, 2, 3, 1) @ emb.t()
            # [N, H, W, num_cls] You maybe need to remove the .t() based on the shape of your saved .npy
            score = score.permute(0, 3, 1, 2)  # [N, num_cls, H, W]
            prediction.append(score.max(1)[1])
            logits.append(score)
        if len(prediction) == 1:
            prediction = prediction[0]
            logit = logits[0]
        else:
            prediction = torch.cat(prediction, dim=0)
            logit = torch.cat(logits, dim=0)
        return logit

In [16]:
class Fscore(_Loss):
    __name__ = 'Fscore'
    
    def __init__(self, gt_embs_list):
        super().__init__()
        self.fscore = utils.metrics.Fscore(threshold=0.05)
        self.gt_embs_list = gt_embs_list
        self.softmax = nn.Softmax2d()
        

    def forward(self, outputs, targets):
        
        outputs = self.get_prediction(outputs)
        outputs = self.softmax(outputs)
        loss = self.fscore(outputs, targets)
        return loss
    
    def get_prediction(self, embs):
        prediction = []
        logits = []
        B = embs.shape[0]
        for b in range(B):
            score = embs[b,...]
            score = score.unsqueeze(0)
            emb = self.gt_embs_list
            emb = emb / emb.norm(dim=1, keepdim=True)
            score = score / score.norm(dim=1, keepdim=True)
            score = score.permute(0, 2, 3, 1) @ emb.t()
            # [N, H, W, num_cls] You maybe need to remove the .t() based on the shape of your saved .npy
            score = score.permute(0, 3, 1, 2)  # [N, num_cls, H, W]
            prediction.append(score.max(1)[1])
            logits.append(score)
        if len(prediction) == 1:
            prediction = prediction[0]
            logit = logits[0]
        else:
            prediction = torch.cat(prediction, dim=0)
            logit = torch.cat(logits, dim=0)
        return logit

In [17]:
model = SegModel(num_classes=512,
                  load_imagenet_model=False,
                  imagenet_ckpt_fpath='')
checkpoint = torch.load('models/segformer_7data.pth', map_location='cpu')
ckpt_filter = {k: v for k, v in checkpoint.items() if 'criterion.0.criterion.weight' not in k}
model.load_state_dict(ckpt_filter, strict=False)

_IncompatibleKeys(missing_keys=['segmodel.encoder.patch_embed1.proj.weight', 'segmodel.encoder.patch_embed1.proj.bias', 'segmodel.encoder.patch_embed1.norm.weight', 'segmodel.encoder.patch_embed1.norm.bias', 'segmodel.encoder.patch_embed2.proj.weight', 'segmodel.encoder.patch_embed2.proj.bias', 'segmodel.encoder.patch_embed2.norm.weight', 'segmodel.encoder.patch_embed2.norm.bias', 'segmodel.encoder.patch_embed3.proj.weight', 'segmodel.encoder.patch_embed3.proj.bias', 'segmodel.encoder.patch_embed3.norm.weight', 'segmodel.encoder.patch_embed3.norm.bias', 'segmodel.encoder.patch_embed4.proj.weight', 'segmodel.encoder.patch_embed4.proj.bias', 'segmodel.encoder.patch_embed4.norm.weight', 'segmodel.encoder.patch_embed4.norm.bias', 'segmodel.encoder.block1.0.norm1.weight', 'segmodel.encoder.block1.0.norm1.bias', 'segmodel.encoder.block1.0.attn.q.weight', 'segmodel.encoder.block1.0.attn.q.bias', 'segmodel.encoder.block1.0.attn.kv.weight', 'segmodel.encoder.block1.0.attn.kv.bias', 'segmodel.en

In [18]:
# from transformers import CLIPTokenizer, XCLIPModel
# text_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32")
# tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32")
# text = [i for i,k in sorted(classes.items(), key=lambda x: x[1])]
# inputs = tokenizer(text, padding=True, return_tensors="pt")
# text_features = text_model.get_text_features(**inputs).detach().cuda()
# np.save('emb.npy', text_features.cpu().numpy())

In [19]:
text_features = torch.tensor(np.load('emb.npy')).cuda().float()

In [20]:
print(text_features)

tensor([[-0.0780, -0.0664, -0.1835,  ..., -0.5559,  0.0776,  0.1547],
        [ 0.0210, -0.0333, -0.1442,  ..., -0.7032, -0.0023, -0.4330],
        [-0.1037,  0.0747, -0.1227,  ..., -0.2483,  0.0738,  0.0639],
        ...,
        [ 0.1801,  0.2484, -0.2756,  ..., -0.4271, -0.1999, -0.0577],
        [-0.0171,  0.1701, -0.5017,  ..., -0.1415,  0.1229, -0.0184],
        [-0.1442,  0.3843, -0.2616,  ..., -0.5479,  0.3423, -0.1578]],
       device='cuda:0')


In [21]:
import segmentation_models_pytorch as smp
import albumentations as album
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.augmentations.transforms import ToFloat

def get_training_augmentation():
    train_transform = [    
#         album.Resize(1024,1024),
        album.PadIfNeeded(256,256, always_apply=True, value=0),
        album.RandomCrop(256,256, always_apply=True),
        album.OneOf([
                album.CLAHE(p=1),
                album.RandomBrightnessContrast(p=1),
                album.RandomGamma(p=1),
                album.OpticalDistortion(p=1),
            ],
            p=0.9,
        ),
        album.OneOf([
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
                album.Rotate(limit = 45, p=1),
            ],p=0.9
        ),
        album.OneOf([
            album.ElasticTransform(p=1),
            album.Posterize(p=1),
            album.RandomGamma((40,150),p=1),
            album.MultiplicativeNoise((0.8, 1.37),p=1),
            album.HueSaturationValue(p=1),
        ], p=0.9),
        album.Normalize(),
    ]
    return album.Compose(train_transform)


def get_validation_augmentation():   
    test_transform = [
        album.PadIfNeeded(256,256, always_apply=True, value=0),
        album.Resize(512,512),
        album.Normalize(),
    ]
    return album.Compose(test_transform)


def get_preprocessing(preprocessing_fn=None):
    
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
#     _transform.append(ToFloat(max_value=255))
    _transform.append(ToTensorV2(True))
    
        
    return album.Compose(_transform)

In [22]:
from segmentation_models_pytorch import utils

In [23]:
train_dataset = CustomDataset(
    '../data/base/',
    classes=classes,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(),
)

valid_dataset = CustomDataset(
    '../data/extended/',
    classes=classes,
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(),
)

# Get train and val data loaders
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True,  drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [24]:
TRAINING = True

# Set num of epochs
EPOCHS = 50

# Set device: `cuda` or `cpu`
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define loss function
loss = TDiceLoss(text_features)#utils.losses.DiceLoss() + utils.losses.BCELoss()
# define metrics
metrics = [
     Fscore(text_features)
#     utils.metrics.Accuracy(0.9),
#     utils.metrics.IoU(threshold=0.9),
#     utils.metrics.Fscore(threshold=0.9),
]

# define optimizer
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.7),
])

# define learning rate scheduler (not used in this NB)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, 3, gamma=0.5
)

In [25]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss,
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
%%time
# torch.distributed.init_process_group(backend='nccl')
model = torch.nn.DataParallel(model)

if TRAINING:

    best_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for i in range(0, EPOCHS):

        # Perform training & validation
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)
        print(valid_logs['Fscore'])

        # Save model if a better val IoU score is obtained
        if best_score < valid_logs['Fscore']:
            best_score = valid_logs['Fscore']
            torch.save(model, './best_model.pth')
            print('Model saved!')


Epoch: 0
train: 100%|███████████████████████████████████████████████████████████████████████| 126/126 [02:18<00:00,  1.10s/it, TDiceLoss - 10.6, Fscore - 0.1506]
valid: 100%|██████████████████████████████████████████████████████████████████████| 228/228 [01:27<00:00,  2.60it/s, TDiceLoss - 10.58, Fscore - 0.1495]
0.14946113922062768
Model saved!

Epoch: 1
train:  49%|███████████████████████████████████▍                                    | 62/126 [02:13<02:07,  1.99s/it, TDiceLoss - 10.6, Fscore - 0.1509]