# Imports

In [None]:
# Setup Notebook
%matplotlib inline
%load_ext autoreload
%autoreload 2
%cd ../

In [None]:
# Imports
import os, gc
# Hide TensorFlow's C++ log output; must be set before `import tensorflow` below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import absl.logging
# Only show absl log messages at ERROR level or above.
absl.logging.set_verbosity(absl.logging.ERROR)
from pathlib import Path

import matplotlib.pyplot as plt
import tensorflow as tf
#import tensorflow_addons as tfa

from utils.train import Trainer
from utils.distiller import Distiller
# NOTE(review): the star imports below hide where helper names (e.g. read_yaml,
# and possibly `np`) come from; consider importing names explicitly.
from utils.tools import *
from utils.preprocess import *
from utils.visualize import *
from utils.training_tools import *

In [None]:
#select the working GPU
GPU_ID = 4  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. when this cell is re-run); keep the current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0,
    # whatever its physical index (the old code recorded f'GPU:{id}').
    devices.append('GPU:0')

In [None]:
cfg = read_yaml('cfg/cfg_2.yaml')
cfg['ID'] = 0
cfg['SEED'] = 42
cfg['METHOD'] = 'None'
cfg['UNISTYLE'] = False
cfg['WHITEN_LAYERS'] = []
cfg['TEST'] = True
cfg['WEIGHTS'] = None
cfg['IMG_SIZE_TEST'] = [224,224]
model_root= Path('bin/Benchmark/Test')
label = 'KD'
model_path = model_root.joinpath(label)
# Sorted so the checkpoint order (and any index-based selection later on) does
# not depend on the arbitrary order returned by os.listdir.
weights = sorted(model_path.joinpath(f) for f in os.listdir(model_path) if f.endswith('.h5')) # new_hilr_chard_KD_1.h5
# targets = ['tree_2', 'chard', 'lettuce', 'vineyard'] #, 'pear', 'zucchini', 'vineyard_real', 'misc']
targets = ['freiburg']

In [None]:
# Stop execution 
class StopExecution(Exception):
    """Exception raised to halt "Run All" at a cell without printing a traceback."""

    def _render_traceback_(self):
        # IPython uses the returned lines as the traceback to display. An empty
        # list shows nothing; returning None can break front-ends that join or
        # iterate the returned lines.
        return []

# Test Function

In [None]:
from utils.mobilenet_v3 import MobileNetV3Large
from utils.models import build_model_binary, build_model_multi
from pathlib import Path

def get_single_model(self, weights=None, feats=True, whiten=False):
    """Build one MobileNetV3-Large based binary segmentation model.

    `self` must expose `cfg` (configuration dict) and `model_dir` (Path used to
    resolve the Cityscapes checkpoint and `weights`).

    Args:
        weights: optional checkpoint path, relative to `self.model_dir`, loaded
            into the final model.
        feats: forwarded as `return_feats`; also enables WCTA / FWCTA
            unconditionally (otherwise only when named in cfg['TEACHERS']).
        whiten: request whitening layers; honoured only when cfg['UNISTYLE']
            is set and cfg['METHOD'] is 'KD'.

    Returns:
        The model built by `build_model_binary`.
    """
    cfg = self.cfg
    use_whitening = whiten and cfg['UNISTYLE'] and cfg['METHOD'] in ['KD']
    whiten_layers = cfg['WHITEN_LAYERS'] if use_whitening else []

    img_size = cfg['IMG_SIZE']
    backbone = MobileNetV3Large(
        input_shape=(img_size, img_size, 3),
        alpha=1.0,
        minimalistic=False,
        include_top=False,
        weights='imagenet',
        input_tensor=None,
        classes=cfg['N_CLASSES'],
        pooling='avg',
        dropout_rate=False,
        include_preprocessing=cfg['NORM'] == 'tf',
        mode=cfg['METHOD'],
        p=cfg['PADAIN']['P'],
        eps=float(cfg['PADAIN']['EPS']),
        whiten_layers=whiten_layers,
        wcta=cfg['WCTA'] if (feats or 'wcta' in cfg['TEACHERS']) else False,
        backend=tf.keras.backend,
        layers=tf.keras.layers,
        models=tf.keras.models,
        utils=tf.keras.utils,
    )

    if not cfg['CITYSCAPES']:
        pre_trained_model = backbone
    else:
        # Backbone wrapped in a 20-class head and initialised from Cityscapes.
        pre_trained_model = build_model_multi(backbone, False, 20)
        pre_trained_model.load_weights(self.model_dir.joinpath('lr_aspp_pretrain_cityscapes.h5'))

    if cfg['FREEZE_BACKBONE']:
        pre_trained_model.trainable = False

    # binary segmentation model
    model = build_model_binary(
        pre_trained_model, False, cfg['N_CLASSES'],
        sigmoid=cfg['LOSS'] == 'iou',
        mode=cfg['METHOD'],
        p=cfg['PADAIN']['P'],
        eps=float(cfg['PADAIN']['EPS']),
        fwcta=cfg['FWCTA'] if (feats or 'fwcta' in cfg['TEACHERS']) else False,
        return_feats=feats,
    )

    if weights:
        model.load_weights(self.model_dir.joinpath(weights))

    del pre_trained_model, backbone
    gc.collect()
    return model

def get_teacher(self):
    """Build the multi-output teacher ensemble and store it on ``self.teacher``.

    `self` must expose ``cfg`` (configuration dict) and a
    ``get_single_model(weights, feats=...)`` method.

    Teacher checkpoints are ``teachers/<cfg['TEACHERS']>/teacher_<domain>.h5``
    for every domain in cfg['SOURCE'] except cfg['TARGET'], or the single
    ``teachers/erm/teacher_<TARGET>.h5`` when cfg['ERM_TEACHERS'] is set.

    Returns:
        The ensemble model (one output per teacher). It is also stored on
        ``self.teacher``; returning it lets callers write
        ``model = get_teacher(trainer)`` (previously None was returned).
    """
    domains = [w for w in self.cfg['SOURCE'] if w != self.cfg['TARGET']]
    if self.cfg['ERM_TEACHERS']:
        weights = [f'teachers/erm/teacher_{self.cfg["TARGET"]}.h5']
    else:
        weights = [f'teachers/{self.cfg["TEACHERS"]}/teacher_{w}.h5' for w in domains]
    print(f'Loaded Teachers: {domains}')

    models = [self.get_single_model(w, feats=False) for w in weights]

    # Feed one shared input to every teacher; the ensemble returns all outputs.
    model_input = tf.keras.Input(shape=(self.cfg['IMG_SIZE'], self.cfg['IMG_SIZE'], 3))
    model_outputs = [model(model_input) for model in models]
    self.teacher = tf.keras.Model(inputs=model_input, outputs=model_outputs)

    del models
    gc.collect()
    return self.teacher

In [None]:
# Test function
def test_fn(cfg,
            model_path,
            targets,
            strategy=None, 
            ensemble=False):
    """Evaluate every checkpoint on every target domain.

    Args:
        cfg: configuration dict; cfg['TARGET'] is overwritten for each target.
        model_path: iterable of checkpoint paths to evaluate.
        targets: iterable of target-domain names.
        strategy: optional tf.distribute strategy forwarded to the trainer.
        ensemble: evaluate the teacher ensemble built by ``get_teacher``
            instead of loading the checkpoint.

    Returns:
        List of metric values (``metric.numpy()``), one per
        (checkpoint, target) pair, in checkpoint-major order.
    """
    res = []
    for model_name in model_path:
        print(str(model_name))
        for t in targets:
            tf.keras.backend.clear_session()
            cfg['TARGET'] = t

            trainer_cls = Distiller if cfg['METHOD'] == 'KD' else Trainer
            trainer = trainer_cls(cfg, logger=None, strategy=strategy, test=True)

            if ensemble:
                # get_teacher reads `self.cfg` / calls `self.get_single_model`
                # and stores the result on `.teacher`; it must receive the
                # trainer, not the cfg dict (that raised AttributeError).
                # NOTE(review): requires a trainer class that defines
                # get_single_model, and the ensemble has one output per teacher;
                # confirm trainer.evaluate accepts that.
                get_teacher(trainer)
                trainer.model = trainer.teacher
            else:
                trainer.model.load_weights(str(model_name))

            loss, metric = trainer.evaluate(trainer.ds_test, 'test')
            value = metric.numpy()
            print(value)
            print('')
            res.append(value)
    return res

In [None]:
# `np` is not imported explicitly before this point of the notebook (it is at
# most available through the `utils` star imports), so import it here.
import numpy as np

for target in targets:
    print(f'{target}\n')
    res = test_fn(cfg,
                  weights,
                  [target],
                  strategy=None,
                  ensemble=False)
    # Avoid numpy's "mean of empty slice" warning when no checkpoint was found.
    mean_metric = np.mean(res) if res else float('nan')
    print(f'{target}: {mean_metric}\n')

In [None]:
# Halt "Run All" here without printing a traceback.
raise StopExecution()

In [None]:
# Fixed subset of checkpoints to visualise. The list is re-derived from the
# checkpoint directory (sorted) instead of re-indexing `weights`, so re-running
# the cell does not index an already-filtered list (IndexError), and the choice
# does not depend on os.listdir ordering.
SELECTED_WEIGHT_IDX = [1, 4, 7, 12]
_all_weights = sorted(model_path.joinpath(f) for f in os.listdir(model_path) if f.endswith('.h5'))
weights = [_all_weights[i] for i in SELECTED_WEIGHT_IDX]

In [None]:
# Visualize function
def visualize_fn(cfg,
                 model_path,
                 targets,
                 strategy=None,
                 n=1,
                 conf=0.0,
                 soft=False,
                 save=False):
    """Plot input, ground truth and prediction for the first `n` test samples.

    For every checkpoint in `model_path` and every target in `targets`, a fresh
    ``Trainer`` is built and the checkpoint loaded. Then, for each of the first
    `n` batches of ``trainer.ds_test``, three figures are shown: the input image,
    the ground-truth mask and the sigmoid of the model output (first sample
    of the batch).

    Args:
        cfg: configuration dict; cfg['TARGET'] is overwritten for each target.
        model_path: iterable of checkpoint paths.
        targets: iterable of target-domain names.
        strategy: optional tf.distribute strategy forwarded to ``Trainer``.
        n: number of samples per (checkpoint, target); at least one is shown.
            (The previous loop showed n + 1 samples.)
        conf: if non-zero, threshold the prediction at this value.
        soft: with `conf`, keep the scores above the threshold instead of a
            binary mask.
        save: also save each figure as a PDF under ``./demo/``. The prediction
            file name uses the global ``label``.
    """
    plt.rcParams['figure.figsize'] = [4, 4]
    if save:
        Path('./demo').mkdir(parents=True, exist_ok=True)

    def _finish(path):
        # Common tail of every figure: hide axes, optionally save, then display.
        plt.axis('off')
        if save:
            plt.savefig(path, bbox_inches='tight', pad_inches=0)
        plt.show()

    for model_name in model_path:
        print(str(model_name))
        for t in targets:
            cfg['TARGET'] = t
            tf.keras.backend.clear_session()
            trainer = Trainer(cfg, logger=None, strategy=strategy)
            trainer.model.load_weights(str(model_name))

            c = n
            for image, y in trainer.ds_test:
                # NOTE(review): shown without rescaling, i.e. assumes pixel
                # values already in [0, 255]; confirm against cfg['NORM'].
                plt.imshow(tf.cast(image[0], tf.uint8), alpha=1.0)
                _finish(f'./demo/Input_{cfg["TARGET"]}_{c}.pdf')

                plt.imshow(y[0], alpha=1.0)
                _finish(f'./demo/GT_{cfg["TARGET"]}_{c}.pdf')

                out = trainer.model.predict(image[:1], verbose=0)[0][0]
                out = tf.math.sigmoid(out)
                if conf and soft:
                    plt.imshow(out * tf.cast(out > conf, tf.float32), alpha=1)
                elif conf:
                    plt.imshow(out > conf, alpha=1)
                else:
                    plt.imshow(out, alpha=1.)
                _finish(f'./demo/{label}_{cfg["TARGET"]}_{c}.pdf')

                c -= 1
                if c <= 0:
                    break

In [None]:
# Plot (and save as PDF) input / ground truth / prediction for every selected
# checkpoint on every target.
ls = visualize_fn(cfg, weights, targets, strategy=None, conf=0., soft=False, save=True)

# Development

In [None]:
# Development configuration: single-sample batches of the XDED model.
cfg = read_yaml('utils/cfg.yaml')
cfg.update({
    'SEED': 0,
    'NAME': 'test',
    'ID': 0,
    'BATCH_SIZE': 1,
    'TARGET': 'misc',
    'METHOD': 'XDED',
    'WHITEN_LAYERS': [],
    'ERM_TEACHERS': False,
})

model_name = 'bin/Benchmark/XDED_01_new/XDED_01_new_vineyard_real_new_XDED_5.h5'

In [None]:
def predict_some_samples(trainer, n=1, conf=0.0, save=False, mode=None):
    """Show (and optionally save) input and prediction for the first `n` test batches.

    Args:
        trainer: object exposing ``ds_test`` (yielding ``(image, label)``
            batches), ``cfg``, and ``model`` (or ``teacher`` when mode == 'KD').
        n: number of samples to show; at least one is always shown.
        conf: if non-zero, show the binary mask ``sigmoid(out) > conf``.
        save: also save PNGs under ``./demo/``.
        mode: 'KD' to predict with ``trainer.teacher`` instead of ``trainer.model``.
    """
    if save:
        Path('./demo').mkdir(parents=True, exist_ok=True)

    for image, _ in trainer.ds_test:
        if mode == 'KD':
            out = trainer.teacher.predict(image[:1], verbose=0)[0]
        else:
            out = trainer.model.predict(image[:1], verbose=0)[0][0]
        out = tf.math.sigmoid(out)#/trainer.cfg['KD']['T']) # 

        # NOTE(review): maps [-1, 1] back to [0, 255], i.e. assumes inputs were
        # scaled to [-1, 1]; confirm against cfg['NORM'].
        i = tf.cast((image[0] + 1) * 127.5, tf.uint8)
        plt.imshow(i, alpha=1.) 
        plt.axis('off')
        if save:
            plt.savefig(f'./demo/in_{trainer.cfg["TARGET"]}_{n-1}.png', bbox_inches='tight')
        plt.show()

        plt.imshow(out > conf if conf else out, alpha=1)
        plt.axis('off')
        if save:
            plt.savefig(f'./demo/out_{trainer.cfg["TARGET"]}_{trainer.cfg["METHOD"]}_{n-1}.png', bbox_inches='tight')
        # Display (and release) the prediction figure; without this it stayed
        # the current figure and the next input image was drawn on top of it.
        plt.show()

        n -= 1
        if n <= 0:
            break

In [None]:
# Load the checkpoint for each target and show a few predictions.
for _target in ['lettuce']:
    cfg['TARGET'] = _target
    if cfg['METHOD'] == 'KD':
        trainer = Distiller(cfg, logger=None, strategy=None, test=False)
    else:
        trainer = Trainer(cfg, logger=None, strategy=None, test=True)
    trainer.model.load_weights(model_name)
    predict_some_samples(trainer, n=9, conf=0., save=True, mode=None)
    # trainer.evaluate(trainer.ds_test, 'test')

# TFLite Conversion

In [None]:
# Halt "Run All" before the TFLite cells. A bare `raise` outside an except block
# only fails with "RuntimeError: No active exception to reraise".
raise StopExecution

In [None]:
# Convert the in-memory Keras model to TFLite and write it under cfg['MODEL_PATH'].
converter = tf.lite.TFLiteConverter.from_keras_model(trainer.model)
converter.experimental_new_converter = True
tflite_model = converter.convert()

name_model_tflite = 'lavanda.tflite'
tflite_model_file = Path(cfg['MODEL_PATH']) / name_model_tflite
tflite_model_file.write_bytes(tflite_model)

In [None]:
# Load the TFLite model and allocate tensors.
# Read from the location the conversion cell writes to
# (cfg['MODEL_PATH']/lavanda.tflite) rather than a hard-coded "bin/" path.
interpreter = tf.lite.Interpreter(model_path=str(Path(cfg['MODEL_PATH']).joinpath('lavanda.tflite')))
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

In [None]:
# Show the details dict of the interpreter's first output tensor.
output_details[0]

# XDED

In [None]:
import tensorflow as tf

In [None]:
class pixelwise_XDEDLoss(tf.keras.losses.Loss):
    """Pixel-wise XDED distillation loss for binary segmentation.

    For each ground-truth class (1 and 0), the logits of all pixels of that
    class are averaged. That class-mean logit is the soft target for each of
    those pixels. The loss is the temperature-scaled KL divergence
    KL(target || prediction) * T**2, divided by the number of pixels.

    Call as ``loss(main_out, gts)``: `main_out` are logits, `gts` the binary
    labels (any shapes with the same number of elements).

    Args:
        temp_factor: softmax temperature T.
    """

    def __init__(self, temp_factor=2.0):
        super(pixelwise_XDEDLoss, self).__init__()
        self.temp_factor = temp_factor
        # KLDivergence(y_true, y_pred) = sum(y_true * log(y_true / y_pred)).
        self.kl_div = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.SUM)
        self.CLASS_NUM = 1

    def _to_class_logits(self, logits):
        # Softmax over a single logit is always 1, which made the KL identically
        # zero. Expand z to [0, z] so softmax yields [1 - sigmoid(z), sigmoid(z)].
        if logits.shape[-1] == 1:
            logits = tf.concat([tf.zeros_like(logits), logits], axis=-1)
        return logits

    def xded_loss(self, input, target):
        T = self.temp_factor
        p_student = tf.nn.softmax(self._to_class_logits(input) / T, axis=-1)
        p_target = tf.nn.softmax(self._to_class_logits(target) / T, axis=-1)
        # Dynamic pixel count: the static shape[0] is None inside tf.function.
        n_pixels = tf.cast(tf.shape(input)[0], p_student.dtype)
        # Keras KL takes (y_true, y_pred): target first, prediction second.
        return self.kl_div(p_target, p_student) * (T ** 2) / n_pixels

    def call(self, main_out, gts):
        flat_gts = tf.reshape(gts, [-1, 1])
        flat_out = tf.reshape(main_out, (-1, self.CLASS_NUM))

        pos = tf.cast(tf.equal(flat_gts, 1), flat_out.dtype)
        neg = tf.cast(tf.equal(flat_gts, 0), flat_out.dtype)

        # Class-mean logits; divide_no_nan keeps the loss finite when a class
        # is absent from the batch (reduce_mean of an empty mask is NaN).
        pos_mean = tf.math.divide_no_nan(tf.reduce_sum(flat_out * pos), tf.reduce_sum(pos))
        neg_mean = tf.math.divide_no_nan(tf.reduce_sum(flat_out * neg), tf.reduce_sum(neg))

        # Every pixel gets the mean logit of its own class. The negative-class
        # target was previously computed but never used.
        # NOTE(review): targets are not wrapped in tf.stop_gradient; confirm
        # whether gradients should flow through the class means.
        flat_targets = pos_mean * pos + neg_mean * neg
        return self.xded_loss(flat_out, flat_targets)

In [None]:
# Synthetic predictions uniform in [0, 10) and random binary {0, 1} masks.
_toy_shape = (64, 224, 224, 1)
ys = 10 * tf.random.uniform(_toy_shape)
y = tf.cast(tf.random.uniform(_toy_shape, maxval=2, dtype=tf.int32), tf.float32)

In [None]:
# `np` is not imported explicitly earlier in the notebook (at most via the
# `utils` star imports), so import it here to run on a fresh kernel.
import numpy as np

# Sanity-check value ranges of the synthetic tensors.
print(np.min(ys), np.max(ys))
print(np.min(y), np.max(y))

In [None]:
@tf.function
def loss(y_pred, y):
    """Evaluate a fresh pixelwise_XDEDLoss on (y_pred, y) inside a tf.function."""
    return pixelwise_XDEDLoss()(y_pred, y)

# Test*

In [None]:
import os

# Prepend the CUDA 11.8 library dirs to LD_LIBRARY_PATH. The previous code
# stored "$LD_LIBRARY_PATH" literally (Python does not expand shell variables),
# and its second assignment discarded the first.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH at process start, so this only
# affects subprocesses started from the kernel; set it before launching Jupyter
# for the kernel itself.
_cuda_dirs = ['/usr/local/cuda-11.8/lib64', '/usr/local/cuda-11.8/extras/CUPTI/lib64']
_existing = [d for d in os.environ.get('LD_LIBRARY_PATH', '').split(os.pathsep)
             if d and d not in _cuda_dirs]  # keeps re-runs from duplicating entries
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(_cuda_dirs + _existing)

import tensorflow as tf

In [None]:
# Printed explicitly: as a bare expression in the middle of the cell its value
# was never displayed.
print(tf.sysconfig.get_build_info())

# GPU setup
GPU_ID = 0  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs:", len(gpus), gpus)
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. after an earlier GPU cell); keep current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0.
    devices.append('GPU:0')

In [None]:
import os

# Prepend the CUDA 11.0 library dirs to LD_LIBRARY_PATH. The previous code
# stored "$LD_LIBRARY_PATH" literally (Python does not expand shell variables),
# and its second assignment discarded the first.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH at process start, so this only
# affects subprocesses started from the kernel.
# NOTE(review): another cell of this notebook points at CUDA 11.8; confirm which
# toolkit matches the installed TensorFlow build.
_cuda_dirs = ['/usr/local/cuda-11.0/lib64', '/usr/local/cuda-11.0/extras/CUPTI/lib64']
_existing = [d for d in os.environ.get('LD_LIBRARY_PATH', '').split(os.pathsep)
             if d and d not in _cuda_dirs]  # keeps re-runs from duplicating entries
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(_cuda_dirs + _existing)

from pathlib import Path
import matplotlib.pyplot as plt
import tensorflow as tf
from utils.tools import read_yaml
from utils.data import load_multi_dataset, split_data, random_flip, random_resize_crop, random_jitter, random_grayscale

# GPU setup
GPU_ID = 0  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs:", len(gpus), gpus)
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. after an earlier GPU cell); keep current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0.
    devices.append('GPU:0')

# Per-channel mean / std used to un-normalise images for display.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

cfg = read_yaml('utils/cfg.yaml')
cfg['TARGET'] = 'lettuce'
cfg['NAME'] = 'test'
cfg['ID'] = 0
cfg['STYLE_AUG'] = False
cfg['RND_FLIP'] = 0.0
cfg['RND_CROP'] = 1.0
cfg['RND_GREY'] = 0.0
cfg['RND_JITTER'] = 0.0
cfg['RND_JITTER_RNG'] = 0.0
data_dir = Path(cfg['DATA_PATH'])

def get_data(cfg, data_dir, batch_size=8):
    """Build the train / val / test tf.data pipelines for cfg['TARGET'].

    Every domain in cfg['SOURCE'] other than the target is a source domain; each
    domain is a sub-directory of `data_dir`. The training set is cached,
    shuffled, augmented (flip, resize-crop and, if cfg['STYLE_AUG'], colour
    jitter and grayscale), batched and prefetched. The test set, if any, is
    shuffled, batched and prefetched.

    Args:
        cfg: configuration dict (TARGET, SOURCE, STYLE_AUG and RND_* keys are read).
        data_dir: pathlib.Path of the dataset root.
        batch_size: batch size for train and test (default 8, the previously
            hard-coded value).

    Returns:
        (ds_train, ds_val, ds_test). ds_val is returned as produced by
        split_data; ds_test may be None.
    """
    autotune = tf.data.experimental.AUTOTUNE

    target_dataset = data_dir.joinpath(cfg['TARGET'])
    source_dataset = sorted(data_dir.joinpath(d) for d in cfg['SOURCE'] if d != cfg['TARGET'])

    ds_source, ds_target = load_multi_dataset(source_dataset, target_dataset, cfg)
    ds_train, ds_val, ds_test = split_data(ds_source, ds_target, cfg)

    train_len = len(ds_train)
    ds_train = ds_train.cache().shuffle(train_len)
    ds_train = ds_train.map(lambda x, y: random_flip(x, y, p=cfg['RND_FLIP']), autotune)
    ds_train = ds_train.map(lambda x, y: random_resize_crop(x, y, min_p=cfg['RND_CROP']), autotune)
    if cfg['STYLE_AUG']:
        ds_train = ds_train.map(lambda x, y: random_jitter(x, y, p=cfg['RND_JITTER'], r=cfg['RND_JITTER_RNG']), autotune)
        ds_train = ds_train.map(lambda x, y: random_grayscale(x, y, p=cfg['RND_GREY']), autotune)
    ds_train = ds_train.batch(batch_size, drop_remainder=True).prefetch(autotune)

    # NOTE(review): the validation split is not sized here, so the log below
    # always reports "Val 0".
    val_len = 0

    if ds_test is not None:
        test_len = len(ds_test)
        # Shuffle the *test* set. This line used to call
        # ds_train.shuffle(test_len), which re-shuffled the already batched
        # training pipeline by whole batches. max(..., 1): buffer must be > 0.
        ds_test = ds_test.shuffle(max(test_len, 1))
        ds_test = ds_test.batch(batch_size, drop_remainder=False).prefetch(autotune)
    else:
        test_len = 0

    print(f'Loaded data: Train {train_len}, Val {val_len}, Test {test_len}')
    return ds_train, ds_val, ds_test

In [None]:
# Build the train / val / test pipelines for cfg['TARGET'].
ds_train, ds_val, ds_test = get_data(cfg, data_dir)

In [None]:
# Inspect the first image of the first test batch: print its raw min/max/mean,
# then undo the per-channel standardisation with `mean` / `std` and show it.
# NOTE(review): the loop rebinds the globals `step`, `x`, `y` and `i`
# (overwriting any earlier `y`).
for step, (x, y) in enumerate(ds_test, 1):
    # aug_x = _instance_norm_block(x, mode='KD_WCTA', training=True)
    for i in range(x.shape[0]):
        print(tf.reduce_min(x[i]), tf.reduce_max(x[i]), tf.reduce_mean(x[i]))
        plt.imshow(x[i]*std+mean)
        print(tf.reduce_min(x[i]*std+mean), tf.reduce_max(x[i]*std+mean), tf.reduce_mean(x[i]*std+mean))
        plt.show()
        # plt.imshow(aug_x[i]*std+mean)
        # plt.show()
        break  # only the first image of the batch
    break  # only the first batch

In [None]:
from PIL import Image
import numpy as np

In [None]:
# Load one frame, keep its first three channels, scale to [0, 1] and
# standardise per channel with the mean / std below.
_raw = np.array(Image.open('../AgriSeg_Dataset/lettuce/lettuce_1/images/Image10/Image0001.png'))[:, :, :3]
img = (_raw / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

In [None]:
# Value range and mean of the standardised image.
img.min(), img.max(), img.mean()

In [None]:
# Move to the project root only if `utils/` is not already visible; a bare
# `%cd ../` climbs one more directory level every time it runs.
import os
if not os.path.isdir('utils'):
    os.chdir('../')
import tensorflow as tf
import keras
from utils.mobilenet_v3 import MobileNetV3Large
from utils.models import build_model_binary

#select the working GPU
GPU_ID = 7  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. after an earlier GPU cell); keep current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0,
    # whatever its physical index (the old code recorded 'GPU:7').
    devices.append('GPU:0')

# Frozen MobileNetV3-Large backbone plus the binary segmentation model on top.
# Pass tf.keras modules, as get_single_model does, instead of mixing in the
# standalone `keras` package (which may be a different Keras than tf.keras).
pre_trained_model = MobileNetV3Large(
    input_shape=(224, 224, 3),
    alpha=1.0,
    minimalistic=False,
    include_top=False,
    weights="imagenet",
    input_tensor=None,
    classes=1,
    pooling='avg',
    dropout_rate=False,
    include_preprocessing=True,
    # NOTE(review): get_single_model passes `fwcta` to build_model_binary, not
    # to the backbone; confirm MobileNetV3Large actually uses it.
    mode="KD", p=0, eps=1e-5, whiten_layers=[], wcta=True, fwcta=True,
    backend=tf.keras.backend, layers=tf.keras.layers, models=tf.keras.models, utils=tf.keras.utils
    )

pre_trained_model.trainable = False

model = build_model_binary(pre_trained_model, False, 1,
    sigmoid=False, mode="KD",
    p=0, eps=1e-5, return_feats=False)

In [None]:
# List each top-level layer of `model` with its trainable flag.
for layer in model.layers:
    print(layer.name, layer.trainable)

In [None]:
# Smoke test: one random 224x224x3 input through the model with training=False.
model(tf.random.normal((1,224,224,3)), training=False)

# Gradient Filtering

In [None]:
import tensorflow as tf

#select the working GPU
GPU_ID = 0  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. after an earlier GPU cell); keep current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0.
    devices.append('GPU:0')

In [None]:
# Toy batch of two 1x3x3 samples: binary labels `y`, predictions `pred_t`
# (differing only at one near-1 entry) and per-pixel auxiliary losses.
_label_rows = [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
y = tf.constant([[_label_rows], [_label_rows]])
print(y)
pred_t = tf.constant([[[[0.6, 0.8, 0.2], [0.3, 0.1, 0.9991], [0.7, 0.0, 0.8]]],
                      [[[0.6, 0.8, 0.2], [0.3, 0.1, 0.9999], [0.7, 0.0, 0.8]]]])
print(pred_t)
_aux_rows = [[0.4, 0.8, 0.2], [0.3, 0.9, 0.1], [0.7, 0.0, 0.2]]
aux_loss = tf.constant([[_aux_rows], [_aux_rows]])
print(aux_loss)

In [None]:
def loss_filter(p, n=0.999):
    """Element-wise confidence weight for a tensor of probabilities `p`.

    Returns a float32 tensor shaped like `p` equal to
      1                            where p <= n,
      ((n + 1 - 2p) / (1 - n))**2  where n < p <= (1 + n) / 2,
      0                            where p > (1 + n) / 2.
    """
    ramp = ((n + 1 - 2 * p) / (1 - n)) ** 2
    upper = (1 + n) / 2
    weights = tf.where(p <= n,
                       tf.ones_like(ramp),
                       tf.where(p > upper, tf.zeros_like(ramp), ramp))
    return tf.cast(weights, tf.float32)

In [None]:
# Vectorised confidence weights for the toy predictions.
loss_filter(pred_t)

In [None]:
# Confidence

# Scalar counterpart of the vectorised `loss_filter`, applied element by
# element with tf.map_fn. It has its own name so it no longer shadows
# `loss_filter` (re-running the earlier `loss_filter(pred_t)` cell would
# otherwise call this scalar version on a whole tensor).
def loss_filter_scalar(p, n=0.999):
    """Confidence weight for one probability `p` (see `loss_filter`)."""
    if p <= n:
        return 1.0
    elif p > (1 + n) / 2:
        return 0.0
    else:
        return ((n + 1 - 2 * p) / (1 - n)) ** 2

old_loss = tf.reduce_mean(aux_loss)
print(old_loss)
w = tf.map_fn(loss_filter_scalar, tf.reshape(pred_t, [-1]), fn_output_signature=tf.float32)
w = tf.reshape(w, tf.shape(pred_t))
print(w)
# New names instead of overwriting `aux_loss`: the old cell reduced it to a
# scalar, which silently broke the "Error" cell that reuses it.
weighted_aux_loss = aux_loss * w
print(weighted_aux_loss)
conf_loss = tf.reduce_mean(weighted_aux_loss)
print(conf_loss)

In [None]:
# Error

# Keep the auxiliary loss only where the binarised prediction (pred_t > 0.5)
# agrees with the label `y`, then average over the non-zero kept entries.
# Results go to new names instead of overwriting `aux_loss`, so re-running
# this cell gives the same output.
# NOTE(review): expects `aux_loss` to still be the per-pixel tensor from the
# toy-data cell.
pred_t_bin = tf.math.greater(pred_t, 0.5)
print(pred_t_bin)
old_loss = tf.reduce_mean(aux_loss)
print(old_loss)
mask = tf.equal(pred_t_bin, tf.cast(y, tf.bool))
print(mask)
masked_aux_loss = aux_loss * tf.cast(mask, tf.float32)
print(masked_aux_loss)
n_kept = tf.math.count_nonzero(masked_aux_loss)
print(n_kept)
# 0 when nothing is kept, like the previous `if n > 0 else 0.0`, but also
# usable inside tf.function.
error_loss = tf.math.divide_no_nan(tf.reduce_sum(masked_aux_loss), tf.cast(n_kept, tf.float32))
print(error_loss)

# Teacher Soup

In [None]:
# Setup Notebook
%matplotlib inline
%load_ext autoreload
%autoreload 2
%cfg Completer.use_jedi = False
%cd ../

In [None]:
import gc
from pathlib import Path
import tensorflow as tf
from utils.tools import read_yaml
from utils.mobilenet_v3 import MobileNetV3Large
from utils.models import build_model_binary
from utils.training_tools import uniform_soup

#select the working GPU
GPU_ID = 0  # physical index of the card to use
gpus = tf.config.experimental.list_physical_devices('GPU')
devices = []
if not gpus:
    print('No GPU found: running on CPU.')
else:
    try:
        tf.config.experimental.set_visible_devices(gpus[GPU_ID], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[GPU_ID], True)
    except RuntimeError as err:
        # TF refuses to change visibility / memory growth once the GPU runtime
        # is initialised (e.g. after an earlier GPU cell); keep current setup.
        print(f'GPU setup unchanged: {err}')
    # With a single visible card TF exposes it as logical device GPU:0.
    devices.append('GPU:0')

In [None]:
class Tester():
    """Builds student and teacher models from a config, without training or data.

    Args:
        cfg: configuration dict; cfg['MODEL_PATH'] is the directory that
            relative weight paths are resolved against.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.model_dir = Path(cfg['MODEL_PATH'])
        self.model = None    # student (or weight soup) model
        self.teacher = None  # multi-output teacher ensemble

    def get_student(self):
        """Build the student (whitening enabled) unless a model already exists.

        A model may already exist, e.g. the weight soup set by get_teacher when
        cfg['SOUP'] is true.
        """
        if self.model is not None:
            return
        self.model = self.get_single_model(whiten=True)

    def get_single_model(self, weights=None, feats=True, whiten=False):
        """Build one MobileNetV3-Large based binary segmentation model.

        Args:
            weights: optional checkpoint path relative to `self.model_dir`.
            feats: forwarded as `return_feats`; also enables WCTA / FWCTA
                unconditionally (otherwise only when named in cfg['TEACHERS']).
            whiten: request whitening layers; honoured only when
                cfg['UNISTYLE'] is set and cfg['METHOD'] is 'KD'.

        Returns:
            The model built by build_model_binary.
        """
        whiten_layers = self.cfg['WHITEN_LAYERS'] if whiten \
                        and self.cfg['UNISTYLE'] \
                        and self.cfg['METHOD'] in ['KD'] else []

        backbone = MobileNetV3Large(input_shape=(self.cfg['IMG_SIZE'], self.cfg['IMG_SIZE'], 3),
                                    alpha=1.0,
                                    minimalistic=False,
                                    include_top=False,
                                    weights='imagenet',
                                    input_tensor=None,
                                    classes=self.cfg['N_CLASSES'],
                                    pooling='avg',
                                    dropout_rate=False,
                                    include_preprocessing=self.cfg['NORM']=='tf',
                                    mode=self.cfg['METHOD'], p=self.cfg['PADAIN']['P'],
                                    eps=float(self.cfg['PADAIN']['EPS']),
                                    whiten_layers=whiten_layers,
                                    wcta=self.cfg['WCTA'] if feats or 'wcta' in self.cfg['TEACHERS'] else False, 
                                    backend=tf.keras.backend, layers=tf.keras.layers, models=tf.keras.models, 
                                    utils=tf.keras.utils
                                    )

        if self.cfg['CITYSCAPES']:
            # Imported here: this section's import cell only brings in
            # build_model_binary, so build_model_multi was a NameError when the
            # section is run on its own.
            from utils.models import build_model_multi
            pre_trained_model = build_model_multi(backbone, False, 20)
            pre_trained_model.load_weights(self.model_dir.joinpath('lr_aspp_pretrain_cityscapes.h5'))
        else:
            pre_trained_model = backbone

        if self.cfg['FREEZE_BACKBONE']:
            pre_trained_model.trainable = False

        # binary segmentation model
        model = build_model_binary(pre_trained_model, False, self.cfg['N_CLASSES'], 
                                    sigmoid=self.cfg['LOSS']=='iou', mode=self.cfg['METHOD'],
                                    p=self.cfg['PADAIN']['P'], eps=float(self.cfg['PADAIN']['EPS']),
                                    fwcta=self.cfg['FWCTA'] if feats or 'fwcta' in self.cfg['TEACHERS'] else False,
                                    return_feats=feats)

        if weights:
            model.load_weights(self.model_dir.joinpath(weights))

        del pre_trained_model
        del backbone
        gc.collect()

        return model

    def get_teacher(self):
        """Build the teacher ensemble on ``self.teacher``.

        Teachers are ``teachers/<TEACHERS>/teacher_<domain>.h5`` for every source
        domain except TARGET, or ``teachers/erm/teacher_<TARGET>.h5`` when
        cfg['ERM_TEACHERS'] is set. When cfg['SOUP'] is set, ``self.model`` is
        also replaced by ``uniform_soup`` of the teacher checkpoints.
        """
        domains = [w for w in self.cfg['SOURCE'] if w != self.cfg['TARGET']]
        if self.cfg['ERM_TEACHERS']:
            weights = [f'teachers/erm/teacher_{self.cfg["TARGET"]}.h5']
        else:
            weights = [f'teachers/{self.cfg["TEACHERS"]}/teacher_{w}.h5' for w in domains]
        print(f'Loaded Teachers: {domains}')

        models = [self.get_single_model(w, feats=False) for w in weights]

        if self.cfg['SOUP']:
            # average teacher weights
            self.model = uniform_soup(self.get_single_model(feats=True), [self.model_dir.joinpath(w) for w in weights])

        model_input = tf.keras.Input(shape=(self.cfg['IMG_SIZE'], self.cfg['IMG_SIZE'], 3))
        model_outputs = [model(model_input) for model in models]
        # ensemble_output = tf.keras.layers.Average()(model_outputs)
        # self.teacher = tf.keras.Model(inputs=model_input, outputs=ensemble_output)
        self.teacher = tf.keras.Model(inputs=model_input, outputs=model_outputs)

        del models
        gc.collect()

In [None]:
cfg = read_yaml('cfg/cfg_3.yaml')
cfg.update({
    'TARGET': 'tree_2',
    'NAME': 'test',
    'METHOD': 'KD',
    'ID': 0,
    'ERM_TEACHERS': False,
    'TEST': False,
})
if cfg['TEACHERS'] is None:
    # Derive the teacher folder name, e.g. "<NORM>_style_wcta" or "<NORM>_geom".
    _aug_kind = 'style' if cfg['STYLE_AUG'] else 'geom'
    _wcta_suffix = '_wcta' if cfg['WCTA'] else ''
    cfg['TEACHERS'] = f"{cfg['NORM']}_{_aug_kind}{_wcta_suffix}"

trainer = Tester(cfg)

In [None]:
# Build the teacher ensemble (and, with cfg['SOUP'], the averaged model), then
# the student; get_student does nothing if a model already exists.
trainer.get_teacher()
trainer.get_student()

In [None]:
import glob

# Count *.jpg, *.png and all files in the misc image folder.
# NOTE(review): hard-coded absolute path; consider a configurable data root.
_MISC_IMG_DIR = "/ssd1/sa58728/AgriSeg_Dataset/misc/misc_1/images"
list_1, list_2, list_3 = (glob.glob(f"{_MISC_IMG_DIR}/{pattern}") for pattern in ("*.jpg", "*.png", "*"))
len(list_1), len(list_2), len(list_3)

In [None]:
from PIL import Image

# Re-encode image_5.jpg in place as a JPEG (presumably its content did not match
# the .jpg extension).
_img_path = '/ssd1/sa58728/AgriSeg_Dataset/misc/misc_1/images/image_5.jpg'
with Image.open(_img_path) as _src:
    # Read the pixels (and close the file) before overwriting the same path.
    # JPEG cannot store alpha / palette modes: saving those raises OSError.
    i = _src.copy() if _src.mode in ('RGB', 'L', 'CMYK') else _src.convert('RGB')
i.save(_img_path, "jpeg")

In [None]:
import glob
import os

# Rename *.jpeg files to *.jpg. The previous loop iterated over `list_1` (the
# *.jpg glob), whose names end in ".jpg", so replace('.jpeg', '.jpg') left them
# unchanged and every rename was a no-op.
_img_dir = '/ssd1/sa58728/AgriSeg_Dataset/misc/misc_1/images'
for f in glob.glob(os.path.join(_img_dir, '*.jpeg')):
    dst = f[:-len('.jpeg')] + '.jpg'  # change only the suffix, not '.jpeg' elsewhere in the path
    if os.path.exists(dst):
        print(f'Skipping {f}: {dst} already exists')
        continue
    os.rename(f, dst)

In [None]:
from pathlib import Path
import imghdr

# Report files with an image extension whose actual content is not an image,
# or is a format TensorFlow's image decoders do not accept.
# Uses its own name instead of rebinding `data_dir`, which would replace the
# dataset-root Path used by get_data(cfg, data_dir) with a plain string.
# NOTE(review): `imghdr` is deprecated since Python 3.11 and removed in 3.13.
misc_img_dir = "/ssd1/sa58728/AgriSeg_Dataset/misc/misc_1/images/"
image_extensions = [".png", ".jpg", ".jpeg"]  # add there all your images file extensions

img_type_accepted_by_tf = ["bmp", "gif", "jpeg", "png", "jpg"]
for filepath in Path(misc_img_dir).rglob("*"):
    if filepath.suffix.lower() not in image_extensions:
        continue
    img_type = imghdr.what(filepath)
    if img_type is None:
        print(f"{filepath} is not an image")
    elif img_type not in img_type_accepted_by_tf:
        print(f"{filepath} is a {img_type}, not accepted by TensorFlow")