# Imports

In [None]:
import mpose

In [None]:
# NOTE(review): MPOSE does not appear to be used by the segmentation workflow
# below; `dataset` is only used by the next cell — confirm this cell is needed.
dataset = mpose.MPOSE()

In [None]:
# Fetch the MPOSE data; the result is only displayed as cell output, not stored.
# (seq_id=True presumably also returns sequence ids — confirm against mpose docs.)
dataset.get_data(seq_id=True)

In [None]:
# Setup Notebook
%matplotlib inline
# Automatically reload edited imported modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Disable Jedi-based tab completion.
%config Completer.use_jedi = False

In [None]:
# Imports
import os
# Silence TensorFlow's C++ logging ('3' also hides ERROR messages); it only takes
# effect if set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import absl.logging
# Only report absl log messages at ERROR level or above.
absl.logging.set_verbosity(absl.logging.ERROR)

from pathlib import Path

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm

import tensorflow as tf
#import tensorflow_addons as tfa


from utils.train import Trainer
from utils.distiller import Distiller
# NOTE(review): these star imports hide where names such as `read_yaml` come
# from (and possibly `np`) — consider importing them explicitly.
from utils.tools import *
from utils.preprocess import *
from utils.visualize import *
from utils.training_tools import *

# Per-channel RGB std/mean (the standard ImageNet statistics); used below to undo
# input normalisation before plotting images.
IMN_STD = [0.229, 0.224, 0.225]
IMN_MEAN = [0.485, 0.456, 0.406]

In [None]:
# Evaluation config: load the base YAML, then override the run-specific keys.
config = read_yaml('utils/config.yaml')

config['ID'] = 0
config['SEED'] = 42
config['METHOD'] = 'KD'
config['UNISTYLE'] = True
config['WHITEN_LAYERS'] = []
config['DATA_PATH'] = '../AgriSeg_Dataset/'

# Checkpoints to evaluate live in bin/Benchmark/<label>/*.h5.
model_root = Path('bin/Benchmark/')
label = 'KD_001'
model_path = model_root.joinpath(label)
# Sorted so checkpoints (and the per-checkpoint results printed later) come out
# in a reproducible order; directory listing order is arbitrary.
weights = sorted(p for p in model_path.iterdir() if p.name.endswith('.h5'))
targets = ['pear', 'zucchini', 'vineyard_real', 'misc']

In [None]:
# Stop execution
class StopExecution(Exception):
    """Raise to halt "Run All" at the current cell without printing a traceback."""

    def _render_traceback_(self):
        # IPython calls this hook to obtain the traceback as a list of strings;
        # an empty list keeps that contract while rendering nothing.
        return []

In [None]:
# GPU setup: expose only the first GPU and enable on-demand memory growth.
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs:", physical_devices)

#select the working GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    for g in [0]:
        tf.config.experimental.set_memory_growth(gpus[g], True)
        devices.append(f'GPU:{g}')
else:
    # Indexing gpus[0] would raise IndexError on a CPU-only machine.
    print("No GPU found; running on CPU.")
#strategy = tf.distribute.MirroredStrategy(devices=devices)

# Test Function

In [None]:
# Utils
from utils.mobilenet_v3 import MobileNetV3Large
from utils.models import build_model_multi, build_model_binary
from pathlib import Path

def get_single_model(config, weights=None, feats=True, whiten=False, model_dir=None):
    """Build a MobileNetV3-Large binary segmentation model, optionally loading weights.

    Args:
        config: run configuration; reads IMG_SIZE, N_CLASSES, METHOD,
            PADAIN['P'], PADAIN['EPS'], LOSS, UNISTYLE and WHITEN_LAYERS.
        weights: checkpoint file to load; falsy keeps the ImageNet-initialised
            backbone (``weights='imagenet'``) with a fresh head.
        feats: forwarded to ``build_model_binary`` as ``return_feats``.
        whiten: use ``config['WHITEN_LAYERS']``; only honoured when
            ``config['UNISTYLE']`` is set and METHOD is 'KD'.
        model_dir: directory ``weights`` is relative to; if None, ``weights``
            is used as given.

    Returns:
        The Keras model, with ``trainable`` set to True.
    """
    use_whiten = whiten and config['UNISTYLE'] and config['METHOD'] in ['KD']
    whiten_layers = config['WHITEN_LAYERS'] if use_whiten else []
    padain_p = config['PADAIN']['P']
    padain_eps = float(config['PADAIN']['EPS'])

    backbone = MobileNetV3Large(
        input_shape=(config['IMG_SIZE'], config['IMG_SIZE'], 3),
        alpha=1.0,
        minimalistic=False,
        include_top=False,
        weights='imagenet',
        input_tensor=None,
        classes=config['N_CLASSES'],
        pooling='avg',
        dropout_rate=False,
        mode=config['METHOD'], p=padain_p,
        eps=padain_eps,
        whiten_layers=whiten_layers,
        backend=tf.keras.backend, layers=tf.keras.layers, models=tf.keras.models,
        utils=tf.keras.utils)

    # binary segmentation model
    model = build_model_binary(backbone, False, config['N_CLASSES'],
                               sigmoid=config['LOSS'] == 'iou', mode=config['METHOD'],
                               p=padain_p, eps=padain_eps,
                               return_feats=feats)

    if weights:
        # model_dir=None previously crashed with AttributeError on .joinpath.
        weights_path = Path(weights) if model_dir is None else Path(model_dir).joinpath(weights)
        # str() for consistency with the other load_weights calls in this notebook.
        model.load_weights(str(weights_path))
    model.trainable = True

    return model
 

def get_teacher(config):
    """Build an ensemble that averages the outputs of per-domain teacher models.

    One teacher (``bin/teachers/teacher_aug_<domain>.h5``) is loaded for every
    domain in ``config['SOURCE']`` other than ``config['TARGET']``; all teachers
    share a single input and their outputs are averaged.
    """
    source_domains = [d for d in config['SOURCE'] if d != config['TARGET']]
    print(f'Loaded Teachers: {source_domains}')

    teacher_models = []
    for domain in source_domains:
        ckpt_name = f'teachers/teacher_aug_{domain}.h5'
        teacher_models.append(
            get_single_model(config, ckpt_name, feats=False, model_dir=Path('bin/')))

    side = config['IMG_SIZE']
    shared_input = tf.keras.Input(shape=(side, side, 3))
    averaged = tf.keras.layers.Average()([m(shared_input) for m in teacher_models])
    return tf.keras.Model(inputs=shared_input, outputs=averaged)

In [None]:
# Test function
def test_fn(config,
            model_path,
            targets,
            strategy=None,
            ensemble=False):
    """Evaluate every checkpoint on every target domain and collect the test metric.

    Args:
        config: run configuration; ``config['TARGET']`` is overwritten per target.
        model_path: iterable of checkpoint paths.
        targets: target-domain names to evaluate on.
        strategy: optional tf.distribute strategy passed to the trainer.
        ensemble: evaluate the averaged teacher ensemble (``get_teacher``)
            instead of loading each checkpoint.

    Returns:
        List of metric values (``metric.numpy()``), ordered checkpoint-major
        then target.
    """
    res = []
    for model_name in model_path:
        print(str(model_name))
        for t in targets:
            tf.keras.backend.clear_session()
            config['TARGET'] = t

            if config['METHOD'] != 'KD':
                trainer = Trainer(config, logger=None, strategy=strategy, test=True)
            else:
                trainer = Distiller(config, logger=None, strategy=strategy, test=True)

            if ensemble:
                # Must replace trainer.model itself: assigning to a misspelled
                # attribute left the un-loaded model in place for evaluation.
                trainer.model = get_teacher(config)
            else:
                trainer.model.load_weights(str(model_name))

            _, metric = trainer.evaluate(trainer.ds_test, 'test')
            value = metric.numpy()
            print(value)
            print('')
            res.append(value)
    return res

In [None]:
# Visualize function
def _show_image(img, save_path=None):
    """Display ``img`` without axes; if ``save_path`` is given, save the figure there first."""
    plt.imshow(img, alpha=1.0)
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()


def visualize_fn(config,
                 model_path,
                 targets,
                 strategy=None,
                 n=1,
                 conf=0.0,
                 soft=False,
                 save=False,
                 label=None):
    """Plot input, ground truth and prediction for the first ``n`` test batches.

    Every checkpoint is visualised on ``targets[0]`` only. For each of the first
    ``n`` test batches, the first image, its ground-truth mask and the sigmoid of
    the model output are shown.

    Args:
        config: run configuration; ``config['TARGET']`` is overwritten.
        model_path: iterable of checkpoint paths.
        targets: target domains; only ``targets[0]`` is used.
        strategy: optional tf.distribute strategy passed to the trainer.
        n: number of test batches to visualise per checkpoint.
        conf: if non-zero, threshold the probability map at ``conf``.
        soft: with ``conf``, keep the probabilities above the threshold instead
            of showing a hard boolean mask.
        save: also save every figure as a PDF under ./demo/.
        label: file-name prefix for saved predictions; defaults to the name of
            the checkpoint's parent directory (e.g. ``KD_001`` for
            ``bin/Benchmark/KD_001/<ckpt>.h5``).
    """
    plt.rcParams['figure.figsize'] = [4, 4]

    for model_name in model_path:
        print(str(model_name))
        config['TARGET'] = targets[0]
        tf.keras.backend.clear_session()
        # Mirror test_fn: KD checkpoints are loaded through a Distiller.
        if config['METHOD'] != 'KD':
            trainer = Trainer(config, logger=None, strategy=strategy, test=True)
        else:
            trainer = Distiller(config, logger=None, strategy=strategy, test=True)
        trainer.model.load_weights(str(model_name))

        prefix = label if label is not None else Path(model_name).parent.name
        target = config['TARGET']

        for idx, (image, y) in enumerate(trainer.ds_test):
            if idx >= n:  # exactly n samples (the old countdown showed n + 1)
                break
            suffix = f'{target}_{idx}.pdf'

            # Undo the ImageNet normalisation for display.
            rgb = tf.cast((image[0] * IMN_STD + IMN_MEAN) * 255.0, tf.uint8)
            _show_image(rgb, f'./demo/Input_{suffix}' if save else None)
            _show_image(y[0], f'./demo/GT_{suffix}' if save else None)

            out = trainer.model.predict(image[:1], verbose=0)[0][0]
            out = tf.math.sigmoid(out)
            if conf:
                out = out * tf.cast(out > conf, tf.float32) if soft else out > conf
            _show_image(out, f'./demo/{prefix}_{suffix}' if save else None)

In [None]:
# Explicit import: numpy is otherwise not imported before this cell (only
# possibly through the star imports above).
import numpy as np

# Evaluate every checkpoint on each target domain and report the mean metric.
for target in targets:
    print(f'{target}\n')
    res = test_fn(config,
                  weights,
                  [target],
                  strategy=None,
                  ensemble=False)
    print(f'{target}: {np.mean(res)}\n')

In [None]:
# Stop execution: halt "Run All" here; the cells below are only reached when
# run individually.
raise StopExecution

In [None]:
# Visualize: input / ground truth / prediction for each checkpoint, saved to ./demo/.
ls = visualize_fn(
    config,
    weights,
    targets,
    strategy=None,
    conf=0.0,
    soft=False,
    save=True,
)

# Development

In [None]:
# Development config: base YAML plus overrides for a single-checkpoint check.
config = read_yaml('utils/config.yaml')
dev_overrides = {
    'SEED': 0,
    'NAME': 'test',
    'ID': 0,
    'BATCH_SIZE': 1,
    'TARGET': 'misc',
    'METHOD': 'XDED',
    'WHITEN_LAYERS': [],
    'ERM_TEACHERS': False,
}
for key, value in dev_overrides.items():
    config[key] = value

model_name = 'bin/Benchmark/XDED_01_new/XDED_01_new_vineyard_real_new_XDED_5.h5'

In [None]:
IMN_STD = [0.229, 0.224, 0.225]
IMN_MEAN = [0.485, 0.456, 0.406]

def predict_some_samples(trainer, n=1, conf=0.0, save=False, mode=None):
    """Show the first input image and its predicted mask for ``n`` test batches.

    Args:
        trainer: object exposing ``ds_test``, ``config`` and ``model``
            (or ``teacher`` when ``mode == 'KD'``).
        n: number of test batches to visualise; nothing is shown if ``n <= 0``.
        conf: if non-zero, show the boolean mask ``sigmoid(logits) > conf``.
        save: also save both figures as PNGs under ./demo/.
        mode: 'KD' predicts with ``trainer.teacher`` instead of ``trainer.model``.
    """
    if n <= 0:
        return
    for image, _ in trainer.ds_test:
        # NOTE(review): the student output is indexed [0][0] (presumably first of
        # several outputs, then first sample) while the teacher's is [0] — confirm.
        if mode == 'KD':
            out = trainer.teacher.predict(image[:1], verbose=0)[0]
        else:
            out = trainer.model.predict(image[:1], verbose=0)[0][0]
        out = tf.math.sigmoid(out)#/trainer.config['KD']['T']) #

        # Undo the ImageNet normalisation for display.
        i = tf.cast((image[0] * IMN_STD + IMN_MEAN) * 255.0, tf.uint8)
        plt.imshow(i, alpha=1.)
        plt.axis('off')
        if save:
            plt.savefig(f'./demo/in_{trainer.config["TARGET"]}_{n-1}.png', bbox_inches='tight')
        plt.show()

        if conf:
            plt.imshow(out>conf, alpha=1)
        else:
            plt.imshow(out, alpha=1)
        plt.axis('off')
        if save:
            plt.savefig(f'./demo/out_{trainer.config["TARGET"]}_{trainer.config["METHOD"]}_{n-1}.png', bbox_inches='tight')
        # Without this the next input image is drawn on the same axes and the
        # prediction is never displayed.
        plt.show()

        n -= 1
        if n <= 0:
            break

In [None]:
# Preview predictions of the development checkpoint on the 'lettuce' domain.
config['TARGET'] = 'lettuce'
if config['METHOD'] == 'KD':
    trainer = Distiller(config, logger=None, strategy=None, test=False)
else:
    trainer = Trainer(config, logger=None, strategy=None, test=True)
trainer.model.load_weights(model_name)
predict_some_samples(trainer, n=9, conf=0., save=True, mode=None)
# Optional: trainer.evaluate(trainer.ds_test, 'test')

# TFLite Conversion

In [None]:
# Stop "Run All" before the TFLite export. A bare `raise` here has no active
# exception to re-raise and fails with a confusing RuntimeError instead.
raise StopExecution

In [None]:
# Export the current trainer's Keras model as a TFLite flatbuffer under config['MODEL_PATH'].
name_model_tflite = 'lavanda.tflite'
tflite_model_file = Path(config['MODEL_PATH']) / name_model_tflite

converter = tf.lite.TFLiteConverter.from_keras_model(trainer.model)
converter.experimental_new_converter = True
tflite_model = converter.convert()
tflite_model_file.write_bytes(tflite_model)

In [None]:
# Load the exported TFLite model and allocate tensors.
# Read from the same location the export cell writes to (config['MODEL_PATH']);
# a hard-coded 'bin/' path would silently load a stale file if MODEL_PATH differs.
tflite_model_path = Path(config['MODEL_PATH']).joinpath('lavanda.tflite')
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_path))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

In [None]:
# Inspect the details of the interpreter's first output tensor.
first_output_detail = output_details[0]
first_output_detail

# XDED

In [None]:
import tensorflow as tf

In [None]:
class pixelwise_XDEDLoss(tf.keras.losses.Loss):
    """Pixel-wise XDED loss for binary segmentation (TensorFlow port).

    Mirrors the PyTorch ``pixelwise_XDEDLoss`` defined later in this notebook:
    each pixel's target logit is the (gradient-stopped) mean logit of all pixels
    sharing its ground-truth label, and the loss is the temperature-scaled KL
    divergence between softmax distributions taken over the pixel axis.

    ``call`` keeps this notebook's argument order ``(main_out, gts)``; when used
    as ``loss_obj(y_pred, y)`` the logits therefore arrive in Keras' ``y_true``
    slot. Pixels labelled neither 0 nor 1 keep their own logit as target.
    """

    def __init__(self, temp_factor=2.0):
        super(pixelwise_XDEDLoss, self).__init__()
        self.temp_factor = temp_factor
        # Kept for backward compatibility only: xded_loss computes the KL
        # divergence directly because KLDivergence clips probabilities to
        # [epsilon, 1], which distorts softmax values spread over many pixels.
        self.kl_div = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.SUM)
        self.CLASS_NUM = 1

    def xded_loss(self, input, target):
        """Return KL(softmax(target/T) || softmax(input/T)) * T**2 / N.

        Args:
            input: logits of shape [N, CLASS_NUM], N = number of pixels.
            target: target logits of the same shape.
        """
        t = self.temp_factor
        # Softmax over axis 0 (pixels), as in the PyTorch version; over the
        # size-1 channel axis it would be identically 1 and the loss always 0.
        log_p = tf.nn.log_softmax(input / t, axis=0)
        log_q = tf.nn.log_softmax(target / t, axis=0)
        kl = tf.reduce_sum(tf.exp(log_q) * (log_q - log_p))
        # Dynamic pixel count: the static shape is None inside tf.function.
        n_pixels = tf.cast(tf.shape(input)[0], kl.dtype)
        return kl * (t ** 2) / n_pixels

    def call(self, main_out, gts):
        """Compute the loss for logits ``main_out`` and 0/1 labels ``gts``.

        Args:
            main_out: logits with one value per pixel, e.g. [batch, H, W, 1].
            gts: labels with the same number of elements as ``main_out``.
        """
        flat_out = tf.reshape(main_out, (-1, self.CLASS_NUM))            # [N, 1]
        flat_gts = tf.reshape(tf.cast(gts, flat_out.dtype), (-1, 1))     # [N, 1]

        fg_mask = tf.cast(tf.equal(flat_gts, 1), flat_out.dtype)
        bg_mask = tf.cast(tf.equal(flat_gts, 0), flat_out.dtype)
        other_mask = 1.0 - fg_mask - bg_mask

        # Per-class mean logit; divide_no_nan gives 0 instead of NaN when a class
        # is absent (no pixel then uses that mean).
        fg_mean = tf.math.divide_no_nan(tf.reduce_sum(flat_out * fg_mask), tf.reduce_sum(fg_mask))
        bg_mean = tf.math.divide_no_nan(tf.reduce_sum(flat_out * bg_mask), tf.reduce_sum(bg_mask))

        flat_targets = fg_mean * fg_mask + bg_mean * bg_mask + flat_out * other_mask
        # Targets act as fixed teachers, like .detach() in the PyTorch version.
        flat_targets = tf.stop_gradient(flat_targets)
        return self.xded_loss(flat_out, flat_targets)

In [None]:
# Synthetic TF batch: logits in [0, 10) and random 0/1 labels, both [64, 224, 224, 1].
synthetic_shape = (64, 224, 224, 1)
ys = tf.random.uniform(synthetic_shape) * 10
y = tf.cast(tf.random.uniform(synthetic_shape, maxval=2, dtype=tf.int32), tf.float32)

In [None]:
# Explicit import: numpy is not otherwise imported before this cell.
import numpy as np

# Value ranges of the synthetic logits and labels.
print(np.min(ys), np.max(ys))
print(np.min(y), np.max(y))

In [None]:
@tf.function
def loss(y_pred, y):
    """Graph-compiled helper: evaluate a fresh ``pixelwise_XDEDLoss`` on ``(y_pred, y)``."""
    return pixelwise_XDEDLoss()(y_pred, y)

In [None]:
# Evaluate the TF XDED loss on the synthetic batch (result is displayed).
loss(ys,y)

In [None]:
import torch
from torch.nn import functional as F

In [None]:
class pixelwise_XDEDLoss(torch.nn.Module):
    """PyTorch implementation of the pixel-wise XDED loss.

    For every label value present in ``gts`` except the ignore index 255, the
    target logit of each pixel with that label is the detached mean logit over
    all such pixels. The loss is the temperature-scaled KL divergence between
    softmax distributions taken over the pixel axis.

    Note: this redefines the name ``pixelwise_XDEDLoss`` and shadows the
    TensorFlow class of the same name defined earlier in the notebook.
    """

    def __init__(self, temp_factor=2.0):
        super(pixelwise_XDEDLoss, self).__init__()
        self.temp_factor = temp_factor
        self.kl_div = torch.nn.KLDivLoss(reduction="sum")
        self.CLASS_NUM = 1

    def xded_loss(self, input, target):
        """Return KL(softmax(target/T) || softmax(input/T)) * T**2 / N over dim 0."""
        log_p = torch.log_softmax(input / self.temp_factor, dim=0)
        q = torch.softmax(target / self.temp_factor, dim=0)
        return self.kl_div(log_p, q) * (self.temp_factor ** 2) / input.size(0)

    def forward(self, main_out, gts):
        """Compute the loss.

        Args:
            main_out: logits, e.g. [batch, CLASS_NUM, H, W]. NOTE(review): the
                reshape to (-1, CLASS_NUM) lines pixels up with ``gts`` only
                when CLASS_NUM == 1 for channels-first input.
            gts: integer labels with one entry per pixel, e.g. [batch, H, W].
        """
        flat_gts = gts.reshape(-1)                         # [batch*H*W]
        flat_out = main_out.reshape(-1, self.CLASS_NUM)    # [batch*H*W, CLASS_NUM]

        # Start from the (detached) logits so ignored pixels target themselves.
        flat_targets = main_out.clone().detach().reshape(-1, self.CLASS_NUM)

        for f_gt in flat_gts.unique().tolist():
            if f_gt == 255:  # ignore index
                continue
            cur_gt_idx = flat_gts == f_gt
            flat_targets[cur_gt_idx, :] = flat_out[cur_gt_idx].mean(0).detach()

        return self.xded_loss(flat_out, flat_targets)

In [None]:
# Synthetic torch batch: channels-first logits [64, 1, 224, 224] in [0, 1)
# and random 0/1 labels [64, 224, 224].
ys = torch.rand(64, 1, 224, 224)
y = torch.randint(high=2, size=(64, 224, 224))

In [None]:
# Instantiate the PyTorch XDED loss (default temperature).
l = pixelwise_XDEDLoss(temp_factor=2.0)

In [None]:
# Evaluate the PyTorch XDED loss on the synthetic batch (result is displayed).
l(ys,y)

# Test*

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import torch
import tensorflow as tf
from utils.instance_norm import _instance_norm_block
from utils.tools import read_yaml
from utils.data import load_multi_dataset, split_data, random_flip, random_resize_crop, random_jitter, random_grayscale

# GPU setup: expose only the first GPU and enable on-demand memory growth.
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs:", physical_devices)
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    for g in [0]:
        tf.config.experimental.set_memory_growth(gpus[g], True)
        devices.append(f'GPU:{g}')
else:
    # Indexing gpus[0] would raise IndexError on a CPU-only machine.
    print("No GPU found; running on CPU.")

# ImageNet per-channel statistics, used below to un-normalise images for display.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Config for the data-pipeline sanity checks; augmentation settings overridden.
config = read_yaml('utils/config.yaml')
config['TARGET'] = 'lettuce'
config['NAME'] = 'test'
config['ID'] = 0
config['STYLE_AUG'] = False
config['RND_FLIP'] = 0.0
config['RND_CROP'] = 1.0
config['RND_GREY'] = 0.0
config['RND_JITTER'] = 0.0
config['RND_JITTER_RNG'] = 0.0
data_dir = Path(config['DATA_PATH'])

def get_data(config, data_dir, batch_size=8):
    """Build train/val/test tf.data pipelines, holding out ``config['TARGET']``.

    Every domain in ``config['SOURCE']`` except ``config['TARGET']`` is a source
    domain; ``config['TARGET']`` is the held-out test domain.

    Args:
        config: reads TARGET, SOURCE, RND_FLIP, RND_CROP, STYLE_AUG and, when
            STYLE_AUG is set, RND_JITTER, RND_JITTER_RNG, RND_GREY (plus whatever
            ``load_multi_dataset`` / ``split_data`` read).
        data_dir: ``Path`` with one sub-directory per domain.
        batch_size: batch size for the train and test pipelines.

    Returns:
        ``(ds_train, ds_val, ds_test)``. ``ds_val`` is returned exactly as
        produced by ``split_data``; ``ds_test`` may be None.
    """
    target_dataset = data_dir.joinpath(config['TARGET'])
    source_dataset = sorted(data_dir.joinpath(d) for d in config['SOURCE'] if d != config['TARGET'])

    ds_source, ds_target = load_multi_dataset(source_dataset, target_dataset, config)
    ds_train, ds_val, ds_test = split_data(ds_source, ds_target, config)

    autotune = tf.data.experimental.AUTOTUNE

    train_len = len(ds_train)
    ds_train = ds_train.cache()
    ds_train = ds_train.shuffle(train_len)
    ds_train = ds_train.map(lambda x, y: random_flip(x, y, p=config['RND_FLIP']), autotune)
    ds_train = ds_train.map(lambda x, y: random_resize_crop(x, y, min_p=config['RND_CROP']), autotune)
    if config['STYLE_AUG']:
        ds_train = ds_train.map(lambda x, y: random_jitter(x, y, p=config['RND_JITTER'], r=config['RND_JITTER_RNG']), autotune)
        ds_train = ds_train.map(lambda x, y: random_grayscale(x, y, p=config['RND_GREY']), autotune)
    ds_train = ds_train.batch(batch_size, drop_remainder=True)
    ds_train = ds_train.prefetch(autotune)

    # NOTE(review): ds_val is not processed and its size is not counted here.
    val_len = 0

    if ds_test is not None:
        test_len = len(ds_test)
        # ds_test = ds_test.cache()
        # The test set keeps its loading order so the sanity-check cells are
        # reproducible. (A stray `ds_train.shuffle(test_len)` here used to
        # re-shuffle the already-batched training set.)
        ds_test = ds_test.batch(batch_size, drop_remainder=False)
        ds_test = ds_test.prefetch(autotune)
    else:
        test_len = 0

    print(f'Loaded data: Train {train_len}, Val {val_len}, Test {test_len}')
    return ds_train, ds_val, ds_test

In [None]:
# Build the train/val/test pipelines for config['TARGET'] from data_dir.
ds_train, ds_val, ds_test = get_data(config, data_dir)

In [None]:
# Sanity-check the test pipeline: value ranges of the first normalised image,
# then display it after undoing the ImageNet normalisation.
for x, y in ds_test:
    first_img = x[0]
    denorm = first_img * std + mean
    print(tf.reduce_min(first_img), tf.reduce_max(first_img), tf.reduce_mean(first_img))
    plt.imshow(denorm)
    print(tf.reduce_min(denorm), tf.reduce_max(denorm), tf.reduce_mean(denorm))
    plt.show()
    # aug_x = _instance_norm_block(x, mode='KD_WCTA', training=True)
    # plt.imshow(aug_x[0]*std+mean)
    # plt.show()
    break

In [None]:
from PIL import Image
import numpy as np

In [None]:
# Load one lettuce frame by hand (first three channels), scale to [0, 1] and
# normalise with the ImageNet statistics to compare against the pipeline output.
raw_rgb = np.array(Image.open('../AgriSeg_Dataset/lettuce/lettuce_1/images/Image10/Image0001.png'))[:, :, :3]
img = (raw_rgb / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

In [None]:
# Min / max / mean of the manually normalised image.
img.min(), img.max(), img.mean()

In [1]:
from utils.mobilenet_v3_2 import MobileNetV3Large

# Standalone backbone from utils.mobilenet_v3_2: ImageNet weights, no
# classification head (include_top=False), 'avg' pooling.
backbone_kwargs = dict(
    input_shape=(224, 224, 3),
    alpha=1.0,
    minimalistic=False,
    include_top=False,
    weights="imagenet",
    input_tensor=None,
    classes=1,
    pooling='avg',
    dropout_rate=False,
    classifier_activation="softmax",
    include_preprocessing=True,
)
model = MobileNetV3Large(**backbone_kwargs)

2024-02-29 05:41:38.542547: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-02-29 05:41:38.549656: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-02-29 05:41:38.549717: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-02-29 05:41:38.550031: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compi

In [3]:
# Print the layer-by-layer summary of the backbone.
model.summary()

Model: "MobilenetV3large"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 224, 224, 3)  0           ['input_1[0][0]']                
                                                                                                  
 Conv (Conv2D)                  (None, 112, 112, 16  432         ['rescaling[0][0]']              
                                )                                                                 
                                                                                   