# Check GPUs

In [None]:
# Show the NVIDIA GPUs visible to this machine and their current utilisation.
!nvidia-smi

# Load libraries

In [None]:
import warnings

# Ignore every warning category. The blanket rule (category=Warning) already
# covers DeprecationWarning and FutureWarning; they are repeated explicitly.
for _warning_category in (Warning, DeprecationWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_warning_category)
del _warning_category

# Third-party `shutup` helper, used here to silence warnings globally as well.
import shutup
shutup.please()

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns  # for heatmaps
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline

import os
import time
import pathlib
import random
from tqdm import tqdm
import umap

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

In [None]:
from data.load_data import *
from data import *
from data.transform.utils import *

from simmim.vision_transformer import ViT
from simmim.simmim import SimMIM
from pretrain import *


from simmim.optimizer import build_pretrain_optimizer, build_finetune_optimizer
from simmim.lr_scheduler import build_scheduler

# Fix seed

In [None]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def seed_everything(seed):
    """Seed every random-number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    every CUDA device). It also exports ``PYTHONHASHSEED`` and switches cuDNN
    to deterministic, non-benchmarking kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    # Only affects Python interpreters started after this call; the running
    # process's hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
# Global seed; later cells re-seed with seed_everything(seed) before building models.
seed = 42

seed_everything(seed)

# Load data

In [None]:
# ---- data-loading configuration ----------------------------------------------
data_type   = 'spectrogram'   # 'spectrogram', 'time-series'
num_workers = 4               # DataLoader worker processes
y_sampling  = None            # option: None, 'oversampling', 'undersampling'
activities  = []              # forwarded to filtering_activities_and_label_encoding (presumably [] = keep all — TODO confirm)
sampling    = 'weight'        # 'weight' = use the full set of training labels
batch_size  = 64

# Modality identifiers passed to import_multiple_modalities.
# Currently disabled: 'MarkovTransitionField', 'exp_7_amp_spec_only', 'exp_9_phdiff_spec_only'
namings = [
    'exp_15_pwr_spectrograms',
    'exp_10_amp_spec_only_STFT',
    'exp_11_phdiff_spec_only_STFT',
]

In [None]:
# Load every modality listed in `namings` for the configured data type.
multimodal_data = import_multiple_modalities(
    data_type=data_type,
    namings=namings,
)

In [None]:
# From here on the loaded modalities are handled as one multimodal input, so
# `data_type` is switched away from the value used for loading.
data_type = 'multimodal_spectrogram'
views = 'associated'
axis = 3  # axis argument forwarded to split_multimodal_data

if data_type == 'multimodal_spectrogram':
    # Train/test split, then activity filtering and label encoding (yields the encoder `lb`).
    X_train, X_test, y_train, y_test = split_multimodal_data(
        multimodal_data, views=views, axis=axis
    )
    X_train, X_test, y_train, y_test, lb = filtering_activities_and_label_encoding(
        X_train, X_test, y_train, y_test, activities
    )

# The per-modality container is no longer needed; release it.
del multimodal_data

In [None]:
# Validation loader from the project's `combine1` helper. The train loader and
# class weights it also returns are not needed for self-supervised pre-training.
# Pass the configured `y_sampling` (None), as the fine-tuning loop does, rather
# than the string 'None'.
_, valid_loader, _ = combine1(X_train, X_test, y_train, y_test,
                              sampling, lb, batch_size, num_workers,
                              y_sampling=y_sampling)

# Pre-training is self-supervised: iterate over the raw training inputs only.
# Batch size and worker count come from the data-loading configuration.
pretrain_set = DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

# Visualise modalities

In [None]:
# Plot each 224-pixel-wide slice of the first training sample as its own figure
# (presumably one slice per concatenated modality).
if data_type == 'multimodal_spectrogram':
    first_sample = pretrain_set.dataset[0][0]
    for i in range(int(pretrain_set.dataset.shape[-1] / 224)):
        left, right = 224 * i, 224 * (i + 1)
        plt.figure()
        plt.imshow(first_sample[0:224, left:right], cmap='jet', origin='lower')

# Build model

In [None]:
from models.hybridvit import *
from simmim.simmim_cnn import *

In [None]:
# ---- model hyper-parameters --------------------------------------------------
# (height, width) of one sample; the width is taken from the data.
img_size = (224, pretrain_set.dataset[0].shape[2])
patch_size = 224 # [32,32]  [16,16]
in_channels = 1
# Take the class count from the fitted label encoder instead of hard-coding 6,
# so filtering `activities` cannot leave the classification head mis-sized.
num_classes = len(lb.classes_)
dim = 512
depth = 3
n_heads = 4
mlp_dim = 512
dropout = 0.1
emb_dropout = 0.1
n_filter_list = [1, 16, 32, 64]
seq_pool = False
positional_embedding = True

# ---- pre-training settings ---------------------------------------------------
epochs = 500
lr = 5e-4
multi_gpus = False
weight_decay = 0.05
network = 'hyb'  # 'hyb' (HybridViT) or 'vit'
exp_name = 'SiMMiM_STFT_PWR_0.6_masking'  # 0.6 = SimMIM masking ratio used by get_model

def get_model(img_size, patch_size, in_channels, num_classes, dim, depth,
              n_heads, mlp_dim, dropout, emb_dropout, n_filter_list,
              seq_pool, positional_embedding, network, masking_ratio=0.6):
    """Build a SimMIM masked-image-modelling wrapper around a transformer encoder.

    Args:
        img_size: input size, forwarded to the encoder as ``image_size``.
        patch_size: patch size forwarded to the encoder.
        in_channels: number of input channels (``channels``).
        num_classes: number of output classes forwarded to the encoder.
        dim, depth, n_heads, mlp_dim: transformer width, depth, attention
            heads and MLP size, forwarded to the encoder.
        dropout: encoder dropout rate.
        emb_dropout: embedding dropout, forwarded to HybridViT only.
        n_filter_list: filter list forwarded to HybridViT only.
        seq_pool: forwarded to HybridViT only.
        positional_embedding: forwarded to HybridViT only.
        network: ``'hyb'`` for HybridViT or ``'vit'`` for ViT.
        masking_ratio: fraction of patches SimMIM masks. Defaults to 0.6, the
            value this notebook's experiment name refers to.

    Returns:
        A ``SimMIM`` instance wrapping the encoder (reachable as ``.encoder``).

    Raises:
        ValueError: if ``network`` is neither ``'hyb'`` nor ``'vit'``.
    """
    if network == 'hyb':
        encoder = HybridViT(
            image_size = img_size,
            patch_size = patch_size,
            num_classes = num_classes,
            dim = dim,
            depth = depth,
            heads = n_heads,
            mlp_dim = mlp_dim,
            channels = in_channels,
            dropout = dropout,
            n_filter_list = n_filter_list,
            emb_dropout = emb_dropout,
            seq_pool = seq_pool,
            positional_embedding = positional_embedding
        )
    elif network == 'vit':
        # emb_dropout, n_filter_list, seq_pool and positional_embedding are not
        # forwarded to the plain ViT backbone.
        encoder = ViT(
            image_size = img_size,
            patch_size = patch_size,
            num_classes = num_classes,
            dim = dim,
            depth = depth,
            heads = n_heads,
            mlp_dim = mlp_dim,
            channels = in_channels,
            dropout = dropout
        )
    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise ValueError(f"Unknown network {network!r}; expected 'hyb' or 'vit'.")

    return SimMIM(
        encoder = encoder,
        masking_ratio = masking_ratio
    )

In [None]:
# Build the model once (presumably to check that the configuration instantiates).
# NOTE(review): the pre-training cell below re-seeds and rebuilds `simmim`,
# so this instance is discarded there.
simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth, 
                   n_heads, mlp_dim, dropout, emb_dropout, n_filter_list, 
                   seq_pool, positional_embedding, network)

# Pre-training phase

In [None]:
# Re-seed so the freshly initialised weights are reproducible.
seed_everything(seed)

simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth,
                   n_heads, mlp_dim, dropout, emb_dropout, n_filter_list,
                   seq_pool, positional_embedding, network)

# Run on all visible GPUs or on the single `device`. DataParallel must wrap the
# freshly built `simmim` itself (not the unrelated global `model`), and the
# target device comes from the wrapper's own device_ids.
if multi_gpus:
    simmim = nn.DataParallel(simmim, device_ids=list(range(torch.cuda.device_count())), output_device=0)
    simmim = simmim.to(f'cuda:{simmim.device_ids[0]}')
else:
    simmim = simmim.to(device)

# Positional args presumably (epsilon, betas, lr, weight_decay, model), mirroring
# the keyword call to build_finetune_optimizer later on — TODO confirm.
optimizer = build_pretrain_optimizer(1e-8, (0.9, 0.999), lr, weight_decay, simmim)

# 'multistep' schedule; argument names suggest 10 warm-up epochs, then decay by
# 0.1 every 30 epochs, with one step per pre-training batch.
lr_scheduler = build_scheduler(scheduler = 'multistep', num_epochs = epochs, warmup_epochs = 10, optimizer = optimizer,
                               num_batches = len(pretrain_set), decay_rate = 0.1, decay_epochs = 30)

simmim, record = pretrain(simmim, optimizer, lr_scheduler, epochs, pretrain_set, valid_loader, device, exp_name, lb, embedding = 'no')

In [None]:
import shutil

# Optional clean-up of this experiment's pre-training artefacts.
# Disabled by default: the evaluation, clustering and fine-tuning cells below
# load the checkpoint from results/saved_models/pretrain/<exp_name>/, so an
# unconditional delete here breaks a top-to-bottom run. Set to True to discard
# a run.
delete_pretrain_artifacts = False

if delete_pretrain_artifacts:
    for artifact_dir in (f'{source_dir}/results/saved_models/pretrain/' + exp_name,
                         f'{source_dir}/logs/pretrain/' + exp_name):
        if os.path.isdir(artifact_dir):
            shutil.rmtree(artifact_dir)

# Unsupervised deep clustering

# Evaluation of the pre-trained model without fine-tuning

In [None]:
# Number of trainable parameters in the whole SimMIM model.
# Tensor.numel() stays an int even for 0-dim parameters, whereas np.prod of an
# empty shape returns the float 1.0.
params = sum(p.numel() for p in simmim.parameters() if p.requires_grad)
print(params)

In [None]:
# Pre-training loss curves. The font size is set before drawing: Matplotlib reads
# rcParams when each text artist is created, so updating it afterwards would
# only affect later figures.
plt.rcParams.update({'font.size': 16})

plt.figure(figsize=(10, 5))
# Index each curve by its own recorded length so the x-axis always matches the
# logged values (e.g. if training stopped before `epochs`).
plt.plot(np.arange(1, len(record['train_loss']) + 1), record['train_loss'])
plt.plot(np.arange(1, len(record['val_loss']) + 1), record['val_loss'])
plt.legend(['Training Loss', 'Validation Loss'])
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.title('SimMIM pre-training loss')
plt.show()

In [None]:
seed_everything(seed)

# Take the encoder out of the pre-trained SimMIM model. Unwrap DataParallel first
# when multi-GPU training was used, because the wrapper has no `.encoder`.
base_model = simmim.module if isinstance(simmim, nn.DataParallel) else simmim
model = base_model.encoder

model_dir = f'{source_dir}/results/saved_models/pretrain/' + exp_name + '/'
# NOTE(review): assumes the directory holds a single checkpoint; sorting makes
# the choice deterministic if there are several.
checkpoint_path = model_dir + sorted(os.listdir(model_dir))[0]
# Load onto the CPU (evaluation below runs on the CPU anyway). strict=False
# tolerates key mismatches. NOTE(review): if the checkpoint holds the full SimMIM
# state (keys prefixed 'encoder.'), nothing would be loaded — check the
# missing/unexpected keys returned by load_state_dict.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'), strict=False)

cmtx, cls = evaluation(model.cpu(), valid_loader, label_encoder=lb)

# Row-normalise (each true-label row sums to ~1), then strip the label prefixes.
df = (cmtx.div(cmtx.sum(1).tolist(), axis=0)).round(2)
df.columns = df.columns.str.replace('predict :', '', regex=False)
df.index = df.index.str.replace('actual:', '', regex=False)

CMAP = 'Blues'
FMT = 'g'
# Set the font size before drawing so it applies to this figure.
plt.rcParams.update({'font.size': 22})
plt.figure(figsize=(20, 10))
sns.heatmap(df, cmap=CMAP, annot=True, fmt=FMT)
plt.title('Pre-trained encoder without fine-tuning (row-normalised)')
plt.xlabel('Predicted')
plt.ylabel('True labels')
plt.show()

# Deep Clustering

In [None]:
# Rebuild the model and load the pre-trained encoder weights for feature extraction.
simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth,
                   n_heads, mlp_dim, dropout, emb_dropout, n_filter_list,
                   seq_pool, positional_embedding, network)

model = simmim.encoder

model_dir = f'{source_dir}/results/saved_models/pretrain/' + exp_name + '/'

# NOTE(review): assumes a single checkpoint in the directory; sorting keeps the
# choice deterministic. map_location lets a GPU-saved checkpoint load on CPU.
model.load_state_dict(torch.load(model_dir + sorted(os.listdir(model_dir))[0], map_location='cpu'), strict=False)

# Replace the classification head so the forward pass returns the features
# that would have fed it (presumably the pooled token embedding).
model.mlp_head = nn.Identity()

from sklearn.cluster import KMeans, SpectralClustering
from scipy.optimize import linear_sum_assignment as linear_assignment

# eval() disables dropout (dropout/emb_dropout are 0.1), so the embeddings are
# deterministic. no_grad() avoids building an autograd graph for the whole
# validation set.
model.eval()
with torch.no_grad():
    embed = model(valid_loader.dataset[:][0].cpu()).cpu()

# Reduce the embeddings to 25 dimensions with UMAP before clustering.
hle = umap.UMAP(
    random_state=0,
    metric='euclidean',
    n_components=25,
    n_neighbors=10,
    min_dist=0.0).fit_transform(embed.detach().numpy())

# One cluster per class known to the label encoder.
sc = SpectralClustering(
            n_clusters=len(lb.classes_),
            random_state=42,
            affinity='nearest_neighbors')
y_pred = sc.fit_predict(hle)

y_true = valid_loader.dataset[:][1].detach().numpy()

def acc(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster-to-label mapping.

    Builds the (cluster id, true label) contingency matrix and runs the
    Hungarian algorithm (``linear_assignment``) to find the mapping that
    maximises the number of matched samples.

    Args:
        y_true: true integer labels, array-like of shape ``(n_samples,)``.
        y_pred: predicted integer cluster ids, array-like of shape ``(n_samples,)``.

    Returns:
        Accuracy in ``[0, 1]``.

    Raises:
        ValueError: if the inputs differ in size or are empty.
    """
    y_true = np.asarray(y_true).astype(np.int64)
    y_pred = np.asarray(y_pred).astype(np.int64)
    if y_pred.size != y_true.size:
        raise ValueError(f'y_true and y_pred differ in size: {y_true.size} vs {y_pred.size}')
    if y_true.size == 0:
        raise ValueError('acc() needs at least one sample')

    # Square matrix large enough for every cluster id and every label.
    n = int(max(y_pred.max(), y_true.max())) + 1
    contingency = np.zeros((n, n), dtype=np.int64)
    np.add.at(contingency, (y_pred, y_true), 1)

    # linear_assignment minimises cost, so turn match counts into costs.
    row_ind, col_ind = linear_assignment(contingency.max() - contingency)
    return contingency[row_ind, col_ind].sum() / y_pred.size

# Accuracy of the spectral clusters against the true validation labels.
acc(y_true, y_pred)

# Fine-tuning

In [None]:
# Training settings
simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth, 
                   n_heads, mlp_dim, dropout, emb_dropout, n_filter_list, 
                   seq_pool, positional_embedding, network)

model_parameters = filter(lambda p: p.requires_grad, simmim.encoder.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)

model_parameters = pd.DataFrame({
    'img_size'      : [img_size],
    'patch_size'    : [patch_size], 
    'in_channels'   : in_channels,
    'num_classes'   : num_classes,
    'dim'           : dim,
    'depth'         : depth,
    'n_heads'       : n_heads,
    'mlp_dim'       : mlp_dim
})

parameters = {
    'num_parameters' : [params],
    'learning rate' : [lr],
    'optimizer' : ['AdamW'],
    'Weight decay': [str(0.01)],
    'Scheduler' : ['StepLR'],
    }

num_parameters = pd.DataFrame(parameters)


my_yticks = ['PWR1', 'PWR2', 'PWR3', 'AMP_STFT_NUC1', 'AMP_STFT_NUC2', 'PHDIFF_STFT_NUC1', \
             'PHDIFF_STFT_NUC2', 'MTF_DIFF_N1', 'MTF_DIFF_N2', 'MTF_DWT_N1', 'MTF_DWT_N2', \
             'AMP_SCAL_N1', 'AMP_SCAL_N2', 'AMP_DIFF_N1', 'AMP_DIFF_N2']

modalities = {}
for i, modality in enumerate(my_yticks[0:15 ]):
    modalities['modality ' + str(i)] = [modality]
data = pd.DataFrame(modalities)

In [None]:
import gc

# Collect unreferenced Python objects, then release cached CUDA memory
# (empty_cache is a no-op when CUDA is unavailable or uninitialised).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from finetune import finetune, evaluation, cmtx_table, save_model, record_log

In [None]:
# Build a fresh SimMIM model on `device`, keep only its encoder and load the
# pre-trained weights into it.
# NOTE(review): the fine-tuning loop below rebuilds and reloads the encoder for
# every sampling condition, so this cell appears to serve only as a load check.
simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth, 
                   n_heads, mlp_dim, dropout, emb_dropout, n_filter_list, 
                   seq_pool, positional_embedding, network).to(device)
model = simmim.encoder
del simmim  # only the encoder is kept
model_dir = f'{source_dir}/results/saved_models/pretrain/' + exp_name + '/'

# strict=False tolerates key mismatches between the checkpoint and the encoder.
model.load_state_dict(torch.load(model_dir + os.listdir(model_dir)[0]), strict = False)

In [None]:
##################################################### Lab-Finetuning-phase #####################################################

# ---- fine-tuning hyper-parameters (these override the pre-training values) ----
epochs = 50 # 200
layer_decay = 0.1
base_lr = 1e-3
eps = 1e-8
betas = (0.9, 0.999)
weight_decay = 0.05
warmup_epochs = 10 # 10
decay_rate = None
decay_epochs = 20 # 10
dropout = 0.1
emb_dropout = 0.1
# `depth` is reused unchanged from the model hyper-parameters.

# Labelled-data budgets, one fine-tuning run each: 1 image per class, then
# 2.5 %-20 % of the training set divided by 6 (presumably per class for 6 classes).
# ('weight' = use all training labels; currently disabled.)
samplings = [1] + [int(frac * len(X_train) // 6) for frac in (0.025, 0.05, 0.10, 0.15, 0.20)]

size_train_exp_name = ['1_img_per_class', '0.025', '0.05', '0.10', '0.15', '0.20'] # , 'all'
assert len(samplings) == len(size_train_exp_name), 'one run name is needed per sampling condition'

# Create each per-experiment output directory (and any missing parents)
# independently, so a partially created set is completed instead of skipped.
for subdir in ('results/saved_models/finetune', 'results/records/finetune', 'logs/finetune'):
    os.makedirs(f'{source_dir}/{subdir}/' + exp_name, exist_ok=True)

for sampling, run_name in zip(samplings, size_train_exp_name):
    seed_everything(seed)

    # Fresh encoder initialised from the pre-trained checkpoint for every run.
    simmim = get_model(img_size, patch_size, in_channels, num_classes, dim, depth,
                       n_heads, mlp_dim, dropout, emb_dropout, n_filter_list,
                       seq_pool, positional_embedding, network).to(device)

    model = simmim.encoder
    del simmim
    model_dir = f'{source_dir}/results/saved_models/pretrain/' + exp_name + '/'

    model.load_state_dict(torch.load(model_dir + os.listdir(model_dir)[0]), strict = False)

    print('\n\nsampling: ', sampling, '\n')

    exp_name_ft = exp_name + '/' + run_name

    # Batch size by labelled-set size. Every case assigns batch_size; the previous
    # `< 20` / `> 20` pair skipped sampling == 20 and silently reused a stale value.
    if sampling == 'weight':
        batch_size = 64
    elif sampling < 20:
        batch_size = 16
    else:
        batch_size = 32

    # create dataloader class
    lab_finetune_loader, lab_validatn_loader, class_weight = combine1(X_train, X_test,
                                                                      y_train, y_test,
                                                                      sampling, lb, batch_size, num_workers,
                                                                      y_sampling = y_sampling)

    print("class: ", lb.classes_)
    print("class_size: ", 1 - class_weight)

    # At least one scheduler step per epoch, even when the subset is smaller than a batch.
    n_batches = max(1, len(lab_finetune_loader.dataset) // batch_size)

    # criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # optimizer
    optimizer = build_finetune_optimizer(layer_decay = layer_decay, base_lr = base_lr, epsilon = eps,
                             betas = betas, depth = depth, weight_decay = weight_decay, model = model)
    # lr_scheduler
    lr_scheduler = build_scheduler(scheduler = 'cosinelr', num_epochs = epochs, warmup_epochs = warmup_epochs,
                                   optimizer = optimizer, num_batches = n_batches, decay_rate = decay_rate,
                                   decay_epochs = decay_epochs)

    model, record = finetune(model, criterion, lr_scheduler, optimizer, epochs, lab_finetune_loader, lab_validatn_loader, device, exp_name_ft, lb, embedding = 'no')

    ################################### SAVE RESULTS ################################################

    # Reload the checkpoint saved for this run (presumably written by finetune()).
    model_dir = f'{source_dir}/results/saved_models/finetune/' + exp_name_ft + '/'
    model.load_state_dict(torch.load(model_dir + os.listdir(model_dir)[0]), strict = False)

    # finetuning

    cmtx, cls = evaluation(model, lab_validatn_loader, label_encoder = lb)

    metrics = pd.DataFrame({'accuracy': [cls['accuracy'][0]], 'precision': [cls['macro avg']['precision']],
                        'recall': [cls['macro avg']['recall']], 'f1-score': [cls['macro avg']['f1-score']]})

    record_log(exp_name_ft, metrics, model_parameters, data, num_parameters)

    ######################################################################################################

    # Drop per-run objects and free GPU memory before the next sampling condition.
    del model, criterion, optimizer, record, cmtx, cls, lr_scheduler
    del lab_finetune_loader, lab_validatn_loader
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# The fine-tuning loop already deletes these objects at the end of every
# iteration, so a plain `del` here raises NameError after a completed run.
# Remove them only if they still exist (e.g. after an interrupted run).
for _leftover_name in ('model', 'criterion', 'optimizer', 'lr_scheduler',
                       'lab_finetune_loader', 'lab_validatn_loader'):
    globals().pop(_leftover_name, None)
del _leftover_name

In [None]:
import gc

# Collect unreferenced Python objects, then release cached CUDA memory
# (empty_cache is a no-op when CUDA is unavailable or uninitialised).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import shutil

# Optional clean-up of the saved-model, log and record directories created for
# this experiment's fine-tuning runs. Disabled by default so a top-to-bottom run
# does not destroy the results the fine-tuning loop just produced. Set to True
# to discard them.
delete_finetune_artifacts = False

if delete_finetune_artifacts:
    for artifact_dir in (f'{source_dir}/results/saved_models/finetune/' + exp_name,
                         f'{source_dir}/logs/finetune/' + exp_name,
                         f'{source_dir}/results/records/finetune/' + exp_name):
        if os.path.isdir(artifact_dir):
            shutil.rmtree(artifact_dir)