In [None]:
# Copyright (c) 2021  IBM Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

import numpy as np
import os
from tqdm import tqdm
import cv2

from sign_dataset import SignDataset
import modules.datautils as datautils
from modules.models import GraphAttentionNet
from modules import losses
from modules.training import train_graph_attention_net
from modules import metrics
import matplotlib.pyplot as plt
import visualization as vis
from tqdm import tqdm

# Load data

In [None]:
# Root folder of the preprocessed dataset; the paths below are built from it.
data_folder = 'dataset/processed_data/'

# categories.npy is loaded with allow_pickle=True and unwrapped with .item();
# the result is indexed like a dict on the next line. Unpickling can execute
# arbitrary code, so only load files from a trusted source.
categories_path = os.path.join(data_folder, 'categories.npy')
dict_categories = np.load(categories_path, allow_pickle=True).item()
n_categories = len(dict_categories['cat_relabel'])

# 'train' split (max_data=370, augment_crop=True) and 'test' split
# (augment_crop=False), both read from the same folder.
sign_dataset = SignDataset(
    data_folder, n_categories, 'train',
    max_data=370, augment_crop=True,
)
test_dataset = SignDataset(
    data_folder, n_categories, 'test',
    augment_crop=False,
)

# Shuffled loader over the training dataset (batch_size=4).
sign_dataloader = datautils.GroupDataLoader(
    sign_dataset, batch_size=4, shuffle=True,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class SignFeatExtractor(nn.Module):
    """Convolutional encoder mapping a batch of images to 64-d feature vectors.

    Four 5x5 stride-2 convolutions, each followed by GroupNorm and LeakyReLU
    (negative slope 0.2), feed a 16x16 stride-2 convolution whose output is
    globally average-pooled.

    Args:
        dim_input: Number of channels of the input images.
    """

    def __init__(self, dim_input):
        super(SignFeatExtractor, self).__init__()

        self.dim_input = dim_input

        # The first convolution consumes `dim_input` channels (previously a
        # hard-coded 7, which silently ignored the constructor argument).
        self.conv1 = nn.Conv2d(dim_input, 64, [5, 5], stride=2, padding=2)
        self.norm1 = nn.GroupNorm(4, 64)

        self.conv2 = nn.Conv2d(64, 128, [5, 5], stride=2, padding=2)
        self.norm2 = nn.GroupNorm(4, 128)

        self.conv3 = nn.Conv2d(128, 128, [5, 5], stride=2, padding=2)
        self.norm3 = nn.GroupNorm(4, 128)

        self.conv4 = nn.Conv2d(128, 128, [5, 5], stride=2, padding=2)
        self.norm4 = nn.GroupNorm(4, 128)

        self.conv5 = nn.Conv2d(128, 64, [16, 16], stride=2)
        # NOTE: norm5 is registered but not applied in forward(). It is kept so
        # the module's parameter names (state_dict keys) stay unchanged.
        self.norm5 = nn.GroupNorm(4, 64)

        self.leaky_relu_p = 0.2

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (N, dim_input, H, W). H and W must be large
                enough that the feature map entering conv5 is at least 16x16
                (conv5 uses a 16x16 kernel without padding).

        Returns:
            Tensor of shape (N, 64).
        """
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )
        for conv, norm in stages:
            x = F.leaky_relu(norm(conv(x)), self.leaky_relu_p)

        x = self.conv5(x)
        # flatten(1) removes only the pooled 1x1 spatial axes. A bare squeeze()
        # would also drop the batch axis when N == 1 and return shape (64,).
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)

        return x
    
# 7 input channels = 3 color channels + 4 possible sign categories.
node_feature_extractor = SignFeatExtractor(7)

# Sanity check: run a random batch of 5 inputs through the extractor and print
# the shape of the result.
with torch.no_grad():
    dummy_features = node_feature_extractor(torch.randn(5, 7, 313, 256))
    print(dummy_features.shape)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Graph attention network that uses the CNN defined above as its node feature
# extractor. dim_node_input=64 presumably matches the extractor's 64-d output
# (TODO confirm against GraphAttentionNet).
model = GraphAttentionNet(2,
                          dim_node_input=64,
                          dim_edge_input=64,
                          dim_edge_output=1,
                          n_heads=4,
                          use_residual=True,
                          use_norm=True,
                          node_feature_extractor=node_feature_extractor)
model = model.to(device)


In [None]:
# Adam over all model parameters with learning rate 3e-4.
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)
# Optional step-wise learning-rate decay (currently disabled):
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2], gamma=0.1)

In [None]:
# Group-clustering loss. The list presumably names the fields the loss reads
# (predicted edges, ground-truth affinity matrix, labelled-group indices) --
# verify against losses.SemisupervisedGroupClusteringLoss.
losses_group_affinity = losses.SemisupervisedGroupClusteringLoss(
    ['output_edge', 'gt_aff_mat', 'list_idx_label_group'],
    balance=True,
    fl_gamma=2,
)

# Register it under the name "clustering" as the only entry of the collection;
# the third argument (1.0) is presumably its weight.
loss_collection = losses.LossCollection()
loss_collection.add_loss("clustering", losses_group_affinity, 1.0)

In [None]:
# Label for this experiment; it is formatted into the output paths passed to
# the training call.
run_name = "run_1"

In [None]:
# Training configuration.
n_epochs = 2000
freq_save_model = 10
it_model = 0

# Output paths are keyed by the run name and a zero-padded model index.
# Judging by the prefixes these are presumably the TensorBoard log directory
# and the checkpoint directory -- confirm against train_graph_attention_net.
dir_tb_log = 'tb_synth_{}/{:04d}/'.format(run_name, it_model)
dir_model_save = 'model_save_{}/{:04d}/'.format(run_name, it_model)

train_graph_attention_net(
    model,
    sign_dataloader,
    loss_collection,
    optimizer,
    n_epochs,
    dir_tb_log,
    dir_model_save,
    freq_save_model,
    device,
    # scheduler=scheduler
)


# Test

In [None]:
# NOTE(review): this cell does not switch the model to eval mode -- confirm
# whether GraphAttentionNet has layers that behave differently in training.
with torch.no_grad():
    # Test image to inspect.
    sample = test_dataset[388]
    output_node, output_edge = model(None, sample=sample)
    # Pass the edge outputs through a sigmoid, then symmetrize by averaging
    # the matrix with its transpose.
    oe_s = torch.sigmoid(output_edge)
    oe_s = (oe_s + oe_s.T) / 2

# Convert once and reuse the array for both prediction panels.
affinity_np = oe_s.detach().cpu().numpy()

# Left: symmetrized sigmoid output; middle: the same thresholded at 0.5;
# right: the sample's gt_aff_mat.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
ax_soft, ax_hard, ax_gt = axs
ax_soft.imshow(affinity_np, vmin=0, vmax=1)
ax_hard.imshow(affinity_np > 0.5, vmin=0, vmax=1)
ax_gt.imshow(sample.gt_aff_mat, vmin=0, vmax=1)

In [None]:
# Cluster the nodes of the currently selected sample. Inference runs under
# torch.no_grad(), consistent with the test-set evaluation loop, so no
# autograd graph is recorded for this forward pass.
with torch.no_grad():
    output_clustering = model.infer_clusters(None, sample=sample)
output_block_labels = output_clustering['group_id']

In [None]:
# Locate and read the raw image that belongs to the current sample.
image_fileloc = os.path.join('dataset/raw/images/', sample.details['filename'])
print(image_fileloc)
img_bgr = cv2.imread(image_fileloc)
# cv2.imread reports failure by returning None instead of raising; fail here
# with a clear message rather than inside cvtColor.
if img_bgr is None:
    raise FileNotFoundError('Could not read image: {}'.format(image_fileloc))
# OpenCV loads images as BGR; convert to RGB. (img_rgb is not used by the plot
# below, which draws on the image rebuilt from the node features.)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
bboxes = sample.details['segm']

fig, axs = plt.subplots(1, 1, figsize=(15, 15))
# vis.visualize_groups(axs, np.repeat(im_bw[:,:,None], 3, axis=2) , bboxes, output_block_labels, None)
# Take the first three channels of node_feature[0], move channels last, and
# rescale with x/2 + 0.5 (presumably undoing a [-1, 1] normalization -- TODO
# confirm against SignDataset).
img_rgb_tmp = (sample.node_feature[0][:3].permute([1,2,0]).detach().cpu().numpy())/2+0.5
vis.visualize_groups(axs, img_rgb_tmp, bboxes, output_block_labels, None)
axs.axis('off')

# Run test on test set

In [None]:
# Dataset indices evaluated below: 370 through 411 inclusive.
idx_data_test = np.arange(370, 412)

# Per-sample clustering accuracy, keyed by dataset index.
dict_clus_acc = {}

# Evaluate each selected sample in turn.
for it_data in tqdm(idx_data_test):
    # Fetch the sample and run cluster inference without recording gradients.
    with torch.no_grad():
        sample = test_dataset[it_data]
        output_clustering = model.infer_clusters(None, sample=sample)
        output_block_labels = output_clustering['group_id']

    # Compare the sample's node_group_id against the predicted group ids.
    clus_acc_it = metrics.compute_cluster_accuracy(
        sample.node_group_id,
        output_block_labels,
    )
    dict_clus_acc[it_data] = clus_acc_it

# Mean accuracy over all evaluated samples.
all_clus_acc = list(dict_clus_acc.values())
avg_clus_acc = np.mean(all_clus_acc)
print("Average clustering acc: {:.4f}".format(avg_clus_acc))