<a href="https://colab.research.google.com/github/aj1365/3DUNetGSFormer/blob/main/3DUNetGSFormer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tensorflow import keras
import numpy as np
from keras.layers import Conv2D, Conv3D, Flatten, Dense, Reshape, BatchNormalization, MaxPool2D
from keras.layers import Dropout, Input
from keras.models import Model
from tensorflow.keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score

from operator import truediv

from plotly.offline import init_notebook_mode

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
import os
import spectral

init_notebook_mode(connected=True)
%matplotlib inline

In [None]:
# List the devices TensorFlow reports for this runtime (cell output only).
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

In [None]:
def loadData(name):
    """Load one study area's data cube and ground-truth map from ``Data/``.

    Args:
        name: Study-area key: ``'SA1'`` (Avalon), ``'SA2'`` (GFall) or
            ``'SA3'`` (GMorne).

    Returns:
        Tuple ``(data, labels)`` read from ``<stem>.mat`` (variable ``<stem>``)
        and ``<stem>_gt.mat`` (variable ``<stem>_gt``) under ``<cwd>/Data/``.

    Raises:
        ValueError: If ``name`` is not one of the known keys. Previously an
            unknown key fell through every branch and failed with an
            UnboundLocalError at the return statement.
    """
    stems = {'SA1': 'Avalon', 'SA2': 'GFall', 'SA3': 'GMorne'}
    if name not in stems:
        raise ValueError(
            'Unknown dataset %r; expected one of %s' % (name, sorted(stems))
        )
    stem = stems[name]

    data_path = os.path.join(os.getcwd(), 'Data/')
    data = sio.loadmat(os.path.join(data_path, stem + '.mat'))[stem]
    labels = sio.loadmat(os.path.join(data_path, stem + '_gt.mat'))[stem + '_gt']
    return data, labels

In [None]:
## GLOBAL VARIABLES
# Fraction of the labelled patches held out as the test set by the first
# splitTrainTestSet call below.
test_ratio = 0.9
# Side length (pixels) of the square patch cut around each labelled pixel.
windowSize = 8

In [None]:
def splitTrainTestSet(X, y, testRatio, randomState=345):
    """Split samples and labels into stratified train and test parts.

    Args:
        X: Samples, indexed along the first axis.
        y: Labels aligned with ``X``; also used as the stratification key.
        testRatio: Passed to ``train_test_split`` as ``test_size``.
        randomState: Seed passed to ``train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.
    """
    parts = train_test_split(
        X,
        y,
        test_size=testRatio,
        random_state=randomState,
        stratify=y,
    )
    trainX, testX, trainY, testY = parts
    return trainX, testX, trainY, testY

In [None]:
def applyPCA(X, numComponents=75):
    """Reduce the last axis of a 3-D cube with whitened PCA.

    Args:
        X: Array indexed as ``X[row, col, band]``.
        numComponents: Number of principal components to keep.

    Returns:
        ``(reduced, pca)``: the cube reshaped to
        ``(X.shape[0], X.shape[1], numComponents)`` and the fitted ``PCA``.
    """
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    # Every pixel becomes one PCA sample with `bands` features.
    pixels = np.reshape(X, (-1, bands))
    pca = PCA(n_components=numComponents, whiten=True)
    reduced = pca.fit_transform(pixels)
    return np.reshape(reduced, (rows, cols, numComponents)), pca

In [None]:
def padWithZeros(X, margin=2):
    """Return a float64 copy of ``X`` with a zero border on the first two axes.

    Args:
        X: Array indexed as ``X[row, col, band]``.
        margin: Number of zero rows/columns added on each side.

    Returns:
        Array of shape ``(rows + 2*margin, cols + 2*margin, bands)``.
    """
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    padded = np.zeros((rows + 2 * margin, cols + 2 * margin, bands))
    # Drop the original cube into the centre of the zero canvas.
    padded[margin:margin + rows, margin:margin + cols, :] = X
    return padded

In [None]:
def createImageCubes(X, y, windowSize=8, removeZeroLabels = True):
    """Cut a ``windowSize`` x ``windowSize`` patch around every pixel of ``X``.

    The cube is zero-padded by ``windowSize // 2`` on each spatial side. The
    patch of pixel ``(r, c)`` covers rows ``r - margin .. r - margin +
    windowSize - 1`` (same for columns) of the original image. For even
    window sizes this is exactly the region the previous implementation
    extracted. Odd window sizes used to slice only ``2 * margin`` rows and
    failed on assignment; they now yield a patch centred on the pixel.

    Args:
        X: Data cube indexed as ``X[row, col, band]``.
        y: 2-D label map aligned with the first two axes of ``X``.
        windowSize: Patch side length in pixels.
        removeZeroLabels: When true, keep only pixels whose label is ``> 0``
            and shift the remaining labels down by one.

    Returns:
        ``(patchesData, patchesLabels)``: a float64 array of shape
        ``(n, windowSize, windowSize, bands)`` and a float64 vector of
        ``n`` labels, both in row-major pixel order.
    """
    margin = windowSize // 2
    zeroPaddedX = np.pad(
        X, ((margin, margin), (margin, margin), (0, 0)), mode='constant'
    )
    labelMap = np.asarray(y)[:X.shape[0], :X.shape[1]]

    # Select the pixels up front so no patch is allocated for pixels that
    # would be discarded anyway; both index sets are in row-major order.
    if removeZeroLabels:
        rows, cols = np.nonzero(labelMap > 0)
    else:
        rows, cols = np.indices(labelMap.shape).reshape(2, -1)

    patchesData = np.zeros((rows.shape[0], windowSize, windowSize, X.shape[2]))
    for patchIndex, (r, c) in enumerate(zip(rows, cols)):
        # Pixel (r, c) sits at (r + margin, c + margin) in the padded cube,
        # so its window starts at (r, c) there.
        patchesData[patchIndex] = zeroPaddedX[r:r + windowSize, c:c + windowSize]

    patchesLabels = labelMap[rows, cols].astype(np.float64)
    if removeZeroLabels:
        patchesLabels -= 1
    return patchesData, patchesLabels

In [None]:
# Load study area SA1 and sanitise it in place.
dataset = 'SA1'
X1 , Y1 = loadData(dataset)
#X[X>100000]=-1
# NaNs and values below -1000 are both replaced by the sentinel -1.
X1[np.isnan(X1)]=-1
X1[X1<-1000]=-1


In [None]:
# Load study area SA2 and sanitise it in place.
dataset = 'SA2'
X2 , Y2 = loadData(dataset)
#X[X>100000]=-1
# NaNs and values below -1000 are both replaced by the sentinel -1.
X2[np.isnan(X2)]=-1
X2[X2<-1000]=-1

In [None]:
# Load study area SA3 and sanitise it in place.
dataset = 'SA3'
X3 , Y3 = loadData(dataset)
#X[X>100000]=-1
# NaNs and values below -1000 are both replaced by the sentinel -1.
X3[np.isnan(X3)]=-1
X3[X3<-1000]=-1

In [None]:
# Replace each scene and label map by its labelled patches: from here on
# X*/Y* hold windowSize x windowSize patches and zero-based labels.
X1, Y1 = createImageCubes(X1, Y1, windowSize=windowSize)
X1.shape, Y1.shape


X2, Y2 = createImageCubes(X2, Y2, windowSize=windowSize)
X2.shape, Y2.shape


X3, Y3 = createImageCubes(X3, Y3, windowSize=windowSize)
X3.shape, Y3.shape

In [None]:
# Pool the patches and labels of the three study areas along the sample axis.
X = np.concatenate((X1 , X2, X3) , axis = 0)
Y = np.concatenate((Y1 , Y2, Y3) , axis = 0)

X.shape,Y.shape

In [None]:
# Append a trailing singleton axis so each patch fits the Conv3D inputs used
# below. The band count 18 is hard-coded here.
X = X.reshape((X.shape[0],windowSize,windowSize,18,1))
#X=X[:,:,:,0:10]
X.shape

In [None]:
# Stratified split; test_ratio of the samples go to testS/labelTs.
trainS, testS, labelTr, labelTs = splitTrainTestSet(X, Y, test_ratio)


In [None]:
# Drop the pooled arrays; only the train/test parts are used from here on.
del X
del Y

### ***Generative Adversarial Network***

In [None]:
# For running in python 2.x
from __future__ import print_function, unicode_literals
from __future__ import absolute_import, division
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Dropout, Dense, RepeatVector, Lambda, Reshape, Conv3D, Conv2D, Flatten, InputSpec
from keras.layers import BatchNormalization, Concatenate, Multiply, Add, Conv2DTranspose, GlobalAveragePooling2D, MaxPool2D
from keras.layers.advanced_activations import LeakyReLU, Softmax
from keras.models import Model
from tensorflow.keras import layers

In [None]:


def denseGamoGenCreate(latDim, num_class):
    """Build the conditional generator shared by all minority classes.

    A noise vector and a label vector are concatenated, projected to a
    7x7x128 map, passed through two transposed convolutions (the second with
    stride 2, giving a 14x14x1 map) and flattened to 196 values.

    Args:
        latDim: Length of the noise input.
        num_class: Length of the label input.

    Returns:
        Uncompiled Keras ``Model`` mapping ``[noise, labels]`` to the
        flattened feature vector.
    """
    noiseIn = Input(shape=(latDim, ))
    labelIn = Input(shape=(num_class, ))
    h = Concatenate()([noiseIn, labelIn])

    # Dense projection to the 7x7x128 seed map.
    h = Dense(7 * 7 * 128, use_bias=False)(h)
    h = BatchNormalization(momentum=0.9)(h)
    h = LeakyReLU()(h)
    h = Reshape((7, 7, 128))(h)

    # Stride-1 transposed convolution keeps the 7x7 grid.
    h = Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)(h)
    h = BatchNormalization(momentum=0.9)(h)
    h = LeakyReLU()(h)

    # Stride-2 transposed convolution doubles the grid to 14x14.
    h = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)(h)
    h = BatchNormalization(momentum=0.9)(h)

    flatOut = Flatten()(h)
    return Model([noiseIn, labelIn], flatOut)



def denseGenProcessCreate(numMinor, dataMinor,sh,mul):
    """Build the per-class head that turns generator output into a sample.

    The 196-value generator output is mapped to ``numMinor`` softmax weights,
    and the returned sample is the weighted sum of the ``numMinor`` real
    samples in ``dataMinor`` (each flattened to ``mul`` values), reshaped
    to ``sh``.

    Args:
        numMinor: Number of real samples in ``dataMinor``.
        dataMinor: Real samples of one class; must hold ``numMinor * mul``
            values.
        sh: Output shape of one generated sample (no batch axis).
        mul: Number of values in one flattened sample.

    Returns:
        Uncompiled Keras ``Model`` mapping a ``(196,)`` vector to a sample
        of shape ``sh``.
    """
    codeIn = Input(shape=(196,))
    weights = Dense(numMinor, activation='softmax')(codeIn)
    # (batch, numMinor) -> (batch, mul, numMinor): one weight copy per value.
    weights = RepeatVector(mul)(weights)
    flatMinor = np.reshape(dataMinor, (numMinor, mul))
    # Weighted sum over the real samples, for every one of the `mul` values.
    mixed = Lambda(
        lambda w: K.sum(w*K.transpose(K.constant(flatMinor)), axis=2)
    )(weights)
    shaped = Reshape(sh)(mixed)
    return Model(codeIn, shaped)

def denseDisCreate(sh, num_class):
    """Build the conditional discriminator.

    Args:
        sh: Shape of one input cube without the batch axis. The Reshape after
            the Conv3D indexes five axes of its output, so ``sh`` must have
            four entries.
        num_class: Width of the label vector concatenated with the flattened
            image features.

    Returns:
        Uncompiled Keras ``Model`` mapping ``[cube, labels]`` to a single
        sigmoid output per sample.
    """
    imIn=Input(shape=sh)

    x = Conv3D(filters=16, kernel_size=(1, 1, 7), activation='relu', padding='same')(imIn)
   
    # Fold the third cube axis and the Conv3D filters into one channel axis so
    # the rest of the network can use 2-D layers.
    conv3d_shape = x.shape
    x = Reshape((conv3d_shape[1], conv3d_shape[2], conv3d_shape[3]*conv3d_shape[4]))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    ### [First half of the network: downsampling block] ###
   
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(filters=32,kernel_size=(3, 3), padding="same")(x)
    #x = layers.BatchNormalization()(x)

    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(filters=32,kernel_size=(3, 3), padding="same")(x)
    #x = layers.BatchNormalization()(x)

    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual (stride 2 so it matches the pooled branch)
    residual = layers.Conv2D(filters=32, kernel_size=(3, 3),strides=2, padding="same")(
            previous_block_activation
        )
    x = layers.add([x, residual])  # Add back residual
    previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    
    x = layers.Activation("relu")(x)
    x = layers.Conv2DTranspose(filters=32, kernel_size=(3, 3), padding="same")(x)
    x = layers.BatchNormalization()(x)

    x = layers.Activation("relu")(x)
    x = layers.Conv2DTranspose(filters=32, kernel_size=(3, 3), padding="same")(x)
    x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)

        # Project residual (upsampled by 2 so it matches the main branch)
    residual = layers.UpSampling2D(2)(previous_block_activation)
    residual = layers.Conv2D(filters=32, kernel_size=(3, 3), padding="same")(residual)
    x = layers.add([x, residual])  # Add back residual
    previous_block_activation = x  # Set aside next residual

    # Flatten the feature map and condition on the class label
    
    flatten_layer = Flatten()(x)
    
    
    labels=Input(shape=(num_class,))
    disInput=Concatenate()([flatten_layer, labels])
    x=Dropout(0.5)(disInput)
    
    # Small dense head ending in a single sigmoid unit.
    disFinal1=Dense(20, activation='relu')(x)
    disFinal2=Dense(10, activation='relu')(disFinal1)
    disFinal=Dense(1, activation='sigmoid', kernel_initializer="he_normal")(disFinal2)
    
    dis=Model([imIn, labels], disFinal)
    return dis

def denseMlpCreate(sh, num_class):
    """Build the 3-D/2-D convolutional classifier trained alongside the GAN.

    Three Conv3D layers are followed by a reshape that merges the third cube
    axis with the filter axis, two Conv2D layers, global average pooling and
    a small dense head with a softmax over ``num_class`` classes.

    Args:
        sh: Shape of one input cube without the batch axis (four entries).
        num_class: Number of output classes.

    Returns:
        Uncompiled Keras ``Model`` mapping a cube to class probabilities.
    """
    cubeIn = Input(shape=sh)

    feat = Conv3D(filters=16, kernel_size=(1, 1, 7), activation='relu', padding='same')(cubeIn)
    feat = Conv3D(filters=32, kernel_size=(3, 3, 5), activation='relu',padding='same')(feat)
    feat = Conv3D(filters=32, kernel_size=(5, 5, 7), activation='relu',padding='same')(feat)

    # Merge axes 3 and 4 so the 2-D convolutions see one channel axis.
    dims = feat.shape
    feat = Reshape((dims[1], dims[2], dims[3]*dims[4]))(feat)

    feat = Conv2D(filters=64, kernel_size=(3,3), activation='relu',padding='same')(feat)
    feat = Conv2D(filters=64, kernel_size=(3,3), activation='relu',padding='same')(feat)
    feat = GlobalAveragePooling2D()(feat)
    feat = Flatten()(feat)

    feat = Dropout(0.5)(feat)

    hidden = Dense(20, activation='relu')(feat)
    hidden = Dense(10, activation='relu')(hidden)
    probs = Dense(num_class, activation="softmax", kernel_initializer="he_normal")(hidden)

    return Model(cubeIn, probs)

In [None]:
# For running in python 2.x
from __future__ import print_function, unicode_literals
from __future__ import absolute_import, division

import sys
import numpy as np
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import cdist
from keras.utils.np_utils import to_categorical

def relabel(labelTr, labelTs):
    """Renumber classes by ascending training frequency.

    New label 0 is the rarest training class and ``c - 1`` the most frequent;
    ties keep the order of the sorted unique labels (stable sort).

    Args:
        labelTr: 1-D array of training labels.
        labelTs: 1-D array of test labels.

    Returns:
        Tuple ``(labelsNewTr, labelsNewTs, c, pInClass, classMap,
        sortedUnqLab)``: the relabelled float vectors (``-1`` where a label
        does not occur in ``labelTr``), the class count, the ascending
        per-class training counts, a list with the training indices of each
        new class, and the argsort positions into the sorted unique labels.
    """
    unqLab, counts = np.unique(labelTr, return_counts=True)
    sortedUnqLab = np.argsort(counts, kind='mergesort')
    c = sortedUnqLab.shape[0]

    labelsNewTr = np.full((labelTr.shape[0],), -1.0)
    labelsNewTs = np.full((labelTs.shape[0],), -1.0)
    classMap = []
    for newLabel, position in enumerate(sortedUnqLab):
        oldLabel = unqLab[position]
        labelsNewTr[labelTr == oldLabel] = newLabel
        labelsNewTs[labelTs == oldLabel] = newLabel
        classMap.append(np.where(labelsNewTr == newLabel)[0])

    pInClass = np.sort(counts)
    return labelsNewTr, labelsNewTs, c, pInClass, classMap, sortedUnqLab

def irFind(pInClass, c, irIgnore=1):
    """Find the classes that need balancing against the last class.

    Args:
        pInClass: Per-class sample counts; the last entry is used as the
            reference (majority) count.
        c: Number of classes.
        irIgnore: Classes whose imbalance ratio is not above this value are
            treated as balanced.

    Returns:
        Tuple ``(imbalancedCls, toBalance, imbClsNum, ir)``: indices of the
        imbalanced classes, how many samples each lacks relative to the
        reference class, their number, and every class's imbalance ratio.

    Raises:
        SystemExit: If no class exceeds ``irIgnore``.
    """
    majority = pInClass[-1]
    ir = majority / pInClass
    imbalancedCls = np.arange(c)[ir > irIgnore]
    toBalance = majority - pInClass[imbalancedCls]
    imbClsNum = toBalance.shape[0]
    if imbClsNum == 0:
        sys.exit('No imbalanced classes found, exiting ...')
    return imbalancedCls, toBalance, imbClsNum, ir

def fileRead(fileName):
    """Read a comma-separated numeric table whose last column is the label.

    Args:
        fileName: Path (or file object) accepted by ``np.loadtxt``.

    Returns:
        ``(data, labels)``: all columns but the last, and the last column.
    """
    table = np.loadtxt(fileName, delimiter=',')
    features, targets = table[:, :-1], table[:, -1]
    return features, targets

def indices(pLabel, tLabel):
    """Compute class-wise and overall accuracy figures from predictions.

    Args:
        pLabel: Predicted labels.
        tLabel: True labels.

    Returns:
        Tuple ``(acsa, gm, tpr, confMat, acc)``: mean of the per-class
        recalls, their geometric mean, the per-class recalls, the confusion
        matrix (rows = true class) and the overall accuracy.
    """
    confMat = confusion_matrix(tLabel, pLabel)
    perClassTotal = np.sum(confMat, axis=1)
    hits = np.diagonal(confMat)
    tpr = hits / perClassTotal
    numClasses = confMat.shape[0]
    acsa = np.mean(tpr)
    gm = np.prod(tpr) ** (1 / numClasses)
    acc = np.sum(hits) / np.sum(perClassTotal)
    return acsa, gm, tpr, confMat, acc

def randomLabelGen(toBalance, batchSize, c):
    """Draw one-hot labels for imbalanced classes, weighted by their deficit.

    Class ``i`` is drawn with probability ``toBalance[i] / sum(toBalance)``.

    Args:
        toBalance: Number of samples each imbalanced class is missing.
        batchSize: Number of labels to draw.
        c: Total number of classes, i.e. the one-hot width.

    Returns:
        float64 array of shape ``(batchSize, c)`` with exactly one ``1`` per
        row, always within the first ``len(toBalance)`` columns.
    """
    toBalance = np.asarray(toBalance)
    cumProb = np.cumsum(toBalance / np.sum(toBalance))
    bins = np.insert(cumProb, 0, 0)
    randomValue = np.random.rand(batchSize,)
    randLabel = np.digitize(randomValue, bins) - 1
    # Floating-point rounding can leave cumProb[-1] slightly below 1.0; a draw
    # above it would land one past the last imbalanced class, so clamp it.
    randLabel = np.minimum(randLabel, toBalance.shape[0] - 1)

    # Build the one-hot matrix at full width directly instead of padding a
    # narrower to_categorical result.
    randLabel_cat = np.zeros((batchSize, c))
    randLabel_cat[np.arange(batchSize), randLabel] = 1
    return randLabel_cat

def batchDivision(n, batchSize):
    """Compute mini-batch boundaries for ``n`` samples.

    Args:
        n: Number of samples.
        batchSize: Nominal batch size; the last batch holds the remainder
            (or a full batch when ``n`` divides evenly).

    Returns:
        Tuple ``(batchDiv, numBatches, batchSizeStore)``: an int64 column of
        ``numBatches + 1`` cumulative start offsets (the last entry is the
        end of the final batch), the number of batches, and an int64 column
        with the size of every batch.
    """
    numBatches = int(np.ceil(n / batchSize))
    # A zero remainder means the final batch is a full one.
    residual = int(n % batchSize) or batchSize

    batchSizeStore = np.ones((numBatches, 1), dtype='int64')
    batchSizeStore[:-1, 0] = batchSize
    batchSizeStore[-1, 0] = residual

    batchDiv = np.zeros((numBatches + 1, 1), dtype='int64')
    batchDiv[:numBatches, 0] = np.arange(numBatches) * batchSize
    batchDiv[numBatches] = batchDiv[numBatches - 1] + residual
    return batchDiv, numBatches, batchSizeStore

def rearrange(labelsCat, numImbCls):
    """Group row indices of a one-hot label matrix by class.

    Args:
        labelsCat: 2-D array; the class of a row is the argmax along axis 1.
        numImbCls: Only classes ``0 .. numImbCls - 1`` are collected.

    Returns:
        List of ``numImbCls`` index arrays; entry ``i`` holds the rows whose
        class is ``i``.
    """
    labels = np.argmax(labelsCat, axis=1)
    return [np.where(labels == cls)[0] for cls in range(numImbCls)]

In [None]:
from __future__ import print_function, unicode_literals
from __future__ import absolute_import, division

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.layers import Input
from keras.models import Model
from tensorflow.keras.optimizers import Adam
from keras.utils.np_utils import to_categorical

In [None]:
# Paths and hyper-parameters for the GAMO training run.
data_path = 'Data/'
fileName=['SA_trainData.csv', 'SA_testData.csv']
# All models and results are written below fileStart (savePath = fileStart + '/').
fileStart=data_path + 'SavedModel/'+'Salinas_GAMO_90'
fileEnd, savePath='_Model.h5', fileStart+'/'
# NOTE(review): this single optimizer instance is passed to every compile()
# call below; confirm the installed Keras version allows sharing it.
adamOpt=Adam(0.0002, 0.5)
# latDim: generator noise length; modelSamplePd / resSamplePd: step periods
# for saving models and for recording metrics in the training loop.
latDim, modelSamplePd, resSamplePd=100,1000, 500
plt.ion()

In [None]:
# Mini-batch size and the step count at which the training loop stops.
batchSize, max_step=32,30000 #30000

In [None]:
# n / m: number of training / test samples.
n, m = trainS.shape[0], testS.shape[0]
#trainS, testS=(trainS-5)/5, (testS-5)/5


In [None]:
# Renumber the classes by ascending training frequency; c is the class count
# and pInClass the ascending per-class training counts.
labelTr, labelTs, c, pInClass, classMap, sortedUnqLab=relabel(labelTr, labelTs)

In [None]:
# Identify the classes to oversample and how many samples each one lacks.
# Exits the interpreter (sys.exit) when no class is imbalanced.
imbalancedCls, toBalance, imbClsNum, ir=irFind(pInClass, c)

In [None]:
# One-hot encoding of the relabelled training labels.
labelsCat=to_categorical(labelTr)

In [None]:
# Shuffle the training set with one random permutation applied to samples,
# labels and one-hot labels alike.
shuffleIndex=np.random.choice(np.arange(n), size=(n,), replace=False)
trainS=trainS[shuffleIndex]
labelTr=labelTr[shuffleIndex]
labelsCat=labelsCat[shuffleIndex]
# The permutation invalidates the index lists returned by relabel, so
# rebuild the per-class training indices.
classMap=list()
for i in range(c):
    classMap.append(np.where(labelTr==i)[0])

In [None]:
# model initialization
sh=(windowSize,windowSize,18,1,)
mlp=denseMlpCreate(sh,8)
mlp.compile(loss='mean_squared_error', optimizer=adamOpt)
mlp.trainable=False

dis=denseDisCreate(sh,8)
dis.compile(loss='mean_squared_error', optimizer=adamOpt)
dis.trainable=False

In [None]:
# Shared conditional generator (label width hard-coded to 8).
gen=denseGamoGenCreate(latDim,8)

In [None]:
# For every imbalanced class i build:
#   gen_processed[i]: head mixing the real samples of class i,
#   genP_mlp[i]:      gen -> gen_processed[i] -> mlp,
#   genP_dis[i]:      gen -> gen_processed[i] -> dis.
# mlp and dis had trainable=False set above before being embedded here.
gen_processed, genP_mlp, genP_dis=list(), list(), list()
for i in range(imbClsNum):
    # Real training samples of class i.
    dataMinor=trainS[classMap[i], :]
    numMinor=dataMinor.shape[0]
    print(dataMinor.shape)
    print(numMinor)
    gen_processed.append(denseGenProcessCreate(numMinor, dataMinor,sh = (windowSize,windowSize,18,1),mul = windowSize*windowSize*18 ))

    # Generator + classifier: inputs are noise and the conditioning label.
    ip1=Input(shape=(latDim,))
    ip2=Input(shape=(c,))
    op1=gen([ip1, ip2])
    op2=gen_processed[i](op1)
    op3=mlp(op2)
    genP_mlp.append(Model(inputs=[ip1, ip2], outputs=op3))
    genP_mlp[i].compile(loss='mean_squared_error', optimizer=adamOpt)

    # Generator + discriminator: ip3 is the label fed to the discriminator.
    ip1=Input(shape=(latDim,))
    ip2=Input(shape=(c,))
    ip3=Input(shape=(c,))
    op1=gen([ip1, ip2])
    op2=gen_processed[i](op1)
    op3=dis([op2, ip3])
    genP_dis.append(Model(inputs=[ip1, ip2, ip3], outputs=op3))
    genP_dis[i].compile(loss='mean_squared_error', optimizer=adamOpt)

In [None]:
# Mini-batch boundaries over the n training samples.
batchDiv, numBatches, bSStore=batchDivision(n, batchSize)
# Number of samples generated per class in each generator update.
genClassPoints=int(np.ceil(batchSize/c))
#fig, axs=plt.subplots(imbClsNum, 3)

In [None]:
# Create the output directory and its 'Pictures' sub-directory if missing.
if not os.path.exists(fileStart):
    os.makedirs(fileStart)
picPath=savePath+'Pictures'
if not os.path.exists(picPath):
    os.makedirs(picPath)

In [None]:
# One record slot per metric checkpoint, plus a final slot that the test
# evaluation after training writes via index -1.
# NOTE(review): `iter` shadows the Python builtin for the rest of the session.
iter=int(np.ceil(max_step/resSamplePd)+1)
acsaSaveTr, gmSaveTr, accSaveTr=np.zeros((iter,)), np.zeros((iter,)), np.zeros((iter,))
acsaSaveTs, gmSaveTs, accSaveTs=np.zeros((iter,)), np.zeros((iter,)), np.zeros((iter,))
confMatSaveTr, confMatSaveTs=np.zeros((iter, c, c)), np.zeros((iter, c, c))
tprSaveTr, tprSaveTs=np.zeros((iter, c)), np.zeros((iter, c))

In [None]:
# Adversarial training loop. `step` advances by 2 per mini-batch, and the
# outer while re-runs the epoch loop until max_step is reached.
step=0
while step<max_step:
    for j in range(numBatches):
        x1, x2=batchDiv[j, 0], batchDiv[j+1, 0]
        # Soft "real" targets: ones minus uniform noise from [0, 0.1).
        validR=np.ones((bSStore[j, 0],1))-np.random.uniform(0,0.1, size=(bSStore[j, 0], 1))
        # 1) Train classifier and discriminator on the real mini-batch.
        mlp.train_on_batch(trainS[x1:x2], labelsCat[x1:x2])
        dis.train_on_batch([trainS[x1:x2], labelsCat[x1:x2]], validR)

        # Soft "fake" targets: zeros plus uniform noise from [0, 0.1).
        invalid=np.zeros((bSStore[j, 0], 1))+np.random.uniform(0, 0.1, size=(bSStore[j, 0], 1))
        randNoise=np.random.normal(0, 1, (bSStore[j, 0], latDim))
        fakeLabel=randomLabelGen(toBalance, bSStore[j, 0], c)
        # Row indices of the fake batch grouped by imbalanced class.
        rLPerClass=rearrange(fakeLabel, imbClsNum)
        fakePoints=np.zeros((bSStore[j, 0],windowSize,windowSize,18,1))
        genFinal=gen.predict([randNoise, fakeLabel])
        for i1 in range(imbClsNum):
            if rLPerClass[i1].shape[0]!=0:
                # Each class has its own head, so generate class by class.
                temp=genFinal[rLPerClass[i1]]
                fakePoints[rLPerClass[i1]]=gen_processed[i1].predict(temp)

        # 2) Train classifier and discriminator on the generated mini-batch.
        mlp.train_on_batch(fakePoints, fakeLabel)
        dis.train_on_batch([fakePoints, fakeLabel], invalid)

        # 3) Update the generator through the combined models, class by class.
        for i1 in range(imbClsNum):
            validA=np.ones((genClassPoints, 1))
            randomLabel=np.zeros((genClassPoints, c))
            randomLabel[:, i1]=1
            randNoise=np.random.normal(0, 1, (genClassPoints, latDim))
            # Classifier target: ones everywhere except the conditioned class.
            oppositeLabel=np.ones((genClassPoints, c))-randomLabel
            genP_mlp[i1].train_on_batch([randNoise, randomLabel], oppositeLabel)
            # Discriminator target: all ones.
            genP_dis[i1].train_on_batch([randNoise, randomLabel, randomLabel], validA)

        # Record train/test metrics every resSamplePd steps. step only takes
        # even values, so an odd resSamplePd would hit every other multiple.
        if step%resSamplePd==0:
            saveStep=int(step//resSamplePd)

            pLabel=np.argmax(mlp.predict(trainS), axis=1)
            acsa, gm, tpr, confMat, acc=indices(pLabel, labelTr)
            print('Train: Step: ', step, 'ACSA: ', np.round(acsa, 4), 'GM: ', np.round(gm, 4))
            print('TPR: ', np.round(tpr, 2))
            acsaSaveTr[saveStep], gmSaveTr[saveStep], accSaveTr[saveStep]=acsa, gm, acc
            confMatSaveTr[saveStep]=confMat
            tprSaveTr[saveStep]=tpr

            pLabel=np.argmax(mlp.predict(testS), axis=1)
            acsa, gm, tpr, confMat, acc=indices(pLabel, labelTs)
            print('Test: Step: ', step, 'ACSA: ', np.round(acsa, 4), 'GM: ', np.round(gm, 4))
            print('TPR: ', np.round(tpr, 2))
            acsaSaveTs[saveStep], gmSaveTs[saveStep], accSaveTs[saveStep]=acsa, gm, acc
            confMatSaveTs[saveStep]=confMat
            tprSaveTs[saveStep]=tpr


        # Save all models every modelSamplePd steps (not at step 0).
        if step%modelSamplePd==0 and step!=0:
            direcPath=savePath+'gamo_models_'+str(step)
            if not os.path.exists(direcPath):
                os.makedirs(direcPath)
            gen.save(direcPath+'/GEN_'+str(step)+fileEnd)
            mlp.save(direcPath+'/MLP_'+str(step)+fileEnd)
            dis.save(direcPath+'/DIS_'+str(step)+fileEnd)
            for i in range(imbClsNum):
                gen_processed[i].save(direcPath+'/GenP_'+str(i)+'_'+str(step)+fileEnd)

        step=step+2
        if step>=max_step: break

In [None]:
# Final evaluation of the classifier on the test set; the result goes into
# the last slot of the record arrays.
pLabel=np.argmax(mlp.predict(testS), axis=1)
acsa, gm, tpr, confMat, acc=indices(pLabel, labelTs)
print('Performance on Test Set: Step: ', step, 'ACSA: ', np.round(acsa, 4), 'GM: ', np.round(gm, 4))
print('TPR: ', np.round(tpr, 2))
acsaSaveTs[-1], gmSaveTs[-1], accSaveTs[-1]=acsa, gm, acc
confMatSaveTs[-1]=confMat
tprSaveTs[-1]=tpr

In [None]:
# Save the final generator, classifier, discriminator and per-class heads
# into a directory named after the last step.
direcPath=savePath+'gamo_models_'+str(step)
if not os.path.exists(direcPath):
    os.makedirs(direcPath)
gen.save(direcPath+'/GEN_'+str(step)+fileEnd)
mlp.save(direcPath+'/MLP_'+str(step)+fileEnd)
dis.save(direcPath+'/DIS_'+str(step)+fileEnd)
for i in range(imbClsNum):
    gen_processed[i].save(direcPath+'/GenP_'+str(i)+'_'+str(step)+fileEnd)

In [None]:
# Persist the final test metrics and the full training record.
# NOTE(review): np.savez writes NumPy .npz archives (and appends '.npz' to
# names lacking it), so these are not text files despite the '.txt' names.
resSave=savePath+'Results.txt'
np.savez(resSave, acsa=acsa, gm=gm, tpr=tpr, confMat=confMat, acc=acc)
recordSave=savePath+'Record.txt'
np.savez(recordSave, acsaSaveTr=acsaSaveTr, gmSaveTr=gmSaveTr, accSaveTr=accSaveTr, acsaSaveTs=acsaSaveTs, gmSaveTs=gmSaveTs, accSaveTs=accSaveTs, confMatSaveTr=confMatSaveTr, confMatSaveTs=confMatSaveTs, tprSaveTr=tprSaveTr, tprSaveTs=tprSaveTs)

In [None]:
# Re-join the real train and test samples (and their relabelled labels) as
# the base that the generated samples are appended to.
X = np.concatenate((trainS , testS) , axis = 0)
print(X.shape)
Y = np.concatenate((labelTr , labelTs) , axis = 0)
Y.shape

In [None]:
# Per-class training counts, then the number of samples to generate per class.
unqLab, pInClass=np.unique(labelTr, return_counts=True)
print(unqLab,pInClass)
# NOTE(review): 2917 is hard-coded; presumably it is the training count of
# the largest class. Confirm, since a class with more samples than this
# would give a negative count here.
pInClass = 2917 - pInClass
print(pInClass)

In [None]:
# Generate pInClass[i1] synthetic samples for every class except the last
# and append them to X / Y.
# The generated cubes are reshaped to the per-sample shape of X instead of a
# hard-coded 12x12 window: gen_processed emits windowSize-sized cubes, so the
# old reshape failed whenever windowSize != 12.
sampleShape = X.shape[1:]
# Collect the pieces and concatenate once rather than copying X on every pass.
genX, genY = [X], [Y]
for i1 in range(pInClass.shape[0] - 1):
    numToGen = pInClass[i1]
    testNoise=np.random.normal(0, 1, (numToGen, latDim))
    testLabel=np.zeros((numToGen, c))
    testLabel[:, i1]=1
    genFinal=gen.predict([testNoise, testLabel])
    genImages=gen_processed[i1].predict(genFinal)
    genImages=np.reshape(genImages, (numToGen,) + sampleShape)
    genX.append(genImages)
    genY.append(np.argmax(testLabel, axis=1))
X = np.concatenate(genX, axis = 0)
Y = np.concatenate(genY, axis = 0)
print(X.shape)
print(np.array(Y).shape)

In [None]:
# Write the balanced dataset (real + generated) to the working directory.
# NOTE(review): the Swin Transformer section later loads
# 'Data/X_Avalon.mat' / 'Data/Y_Avalon.mat', which differ from these names
# and location; confirm the files are renamed/moved in between.
import scipy.io as sio
sio.savemat('X_Avalon90.mat', {'X':X})

sio.savemat('Y_Avalon90.mat', {'Y':Y})

### ***Swin Transformer***

In [None]:
################################           Reading mat file

# Load the balanced samples and labels stored under the keys 'X' and 'Y'.
X=sio.loadmat('Data/X_Avalon.mat')
Y=sio.loadmat('Data/Y_Avalon.mat')
X = X['X']
Y = Y['Y']
# Give Y the shape of its first row, i.e. drop the leading axis.
Y=Y.reshape(Y[0].shape)
Y.shape, X.shape

In [None]:
# Reshape each sample to (windowSize, windowSize, 18) for the 2-D model
# (band count hard-coded).
X = X.reshape((X.shape[0],windowSize,windowSize,18))

In [None]:
# Test fraction for the classifier split below (overrides the earlier 0.9).
test_ratio=0.30

In [None]:
# Stratified split of the balanced dataset, then a quick look at the label range.
Xtrain, Xtest, ytrain, ytest = splitTrainTestSet(X, Y, test_ratio)

np.min(ytrain), np.max(ytrain)


In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Hyper-parameters read as globals by get_ST_model and the Swin layers.
# NOTE(review): input_shape / image_dimension describe 12x12 images, whereas
# get_ST_model declares an 8x8x18 input and windowSize is 8; confirm which
# size is intended.
input_shape = (12, 12, 18)
patch_size = (2, 2)  # 2-by-2 sized patches
dropout_rate = 0.04  # Dropout rate
num_heads = 4  # Attention heads
embed_dim = 16  # Embedding dimension
num_mlp = 16  # MLP layer size
qkv_bias = True  # Convert embedded patches to query, key, and values with a learnable additive value
window_size = 2  # Size of attention window
shift_size = 1  # Size of shifting window
image_dimension = 12  # Initial image size

# Number of patches along each spatial axis of the image.
num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]


In [None]:
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: Rank-4 tensor ``(batch, height, width, channels)``.
        window_size: Side length of each window.

    Returns:
        Tensor of shape ``(-1, window_size, window_size, channels)`` holding
        the windows of every image, row by row.
    """
    _, height, width, channels = x.shape
    rows_of_windows = height // window_size
    cols_of_windows = width // window_size
    # Expose the window grid as separate axes ...
    grid = tf.reshape(
        x,
        shape=(-1, rows_of_windows, window_size, cols_of_windows, window_size, channels),
    )
    # ... bring the two grid axes next to each other ...
    grid = tf.transpose(grid, (0, 1, 3, 2, 4, 5))
    # ... and merge batch and grid axes into one window axis.
    return tf.reshape(grid, shape=(-1, window_size, window_size, channels))


def window_reverse(windows, window_size, height, width, channels):
    """Reassemble windows produced by ``window_partition`` into feature maps.

    Args:
        windows: Tensor of windows, ``window_size`` x ``window_size`` each.
        window_size: Side length of each window.
        height: Height of the reassembled map.
        width: Width of the reassembled map.
        channels: Number of channels.

    Returns:
        Tensor of shape ``(-1, height, width, channels)``.
    """
    rows_of_windows = height // window_size
    cols_of_windows = width // window_size
    grid = tf.reshape(
        windows,
        shape=(-1, rows_of_windows, cols_of_windows, window_size, window_size, channels),
    )
    # Interleave grid and in-window axes back into row/column order.
    grid = tf.transpose(grid, perm=(0, 1, 3, 2, 4, 5))
    return tf.reshape(grid, shape=(-1, height, width, channels))


class DropPath(layers.Layer):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample of the batch is dropped with probability
    ``drop_prob`` and the survivors are scaled by ``1 / (1 - drop_prob)``.
    Outside training, or when ``drop_prob`` is ``None`` or ``0``, the input
    is returned unchanged. Previously paths were also dropped at inference
    time and the default ``drop_prob=None`` failed on ``1 - None``.

    Args:
        drop_prob: Probability of dropping a sample's path.
    """

    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        """Apply stochastic depth to ``x``.

        Args:
            x: Input tensor; the first axis is the batch axis.
            training: Keras training flag. When the caller does not pass it,
                Keras fills it in from the enclosing call context.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if not training or not self.drop_prob:
            return x
        keep_prob = 1 - self.drop_prob
        batch_size = tf.shape(x)[0]
        # One random draw per sample, broadcast over all remaining axes.
        shape = (batch_size,) + (1,) * (x.shape.rank - 1)
        random_tensor = keep_prob + tf.random.uniform(shape, dtype=x.dtype)
        # floor() yields 1 with probability keep_prob and 0 otherwise.
        path_mask = tf.floor(random_tensor)
        return tf.math.divide(x, keep_prob) * path_mask

In [None]:
class WindowAttention(layers.Layer):
    """Multi-head self-attention inside one window, with a learned
    relative-position bias.

    Args:
        dim: Width of the q/k/v projections and of the output projection.
        window_size: ``(height, width)`` of the attention window.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        dropout_rate: Rate of the dropout layer, which is applied both to the
            attention weights and to the projected output.
    """

    def __init__(
        self, dim, window_size, num_heads, qkv_bias=True, dropout_rate=0.0, **kwargs
    ):
        super(WindowAttention, self).__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Scaling applied to the queries: 1 / sqrt(per-head width).
        self.scale = (dim // num_heads) ** -0.5
        # One dense layer produces q, k and v at once.
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        """Create the bias table and the lookup index into it."""
        # Number of distinct relative offsets between two window positions.
        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        # One learnable bias per (relative offset, head), initialised to zero.
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=tf.initializers.Zeros(),
            trainable=True,
        )
        # Pairwise (row, col) offsets between all positions of a window.
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        # Shift the offsets to start at 0, then flatten (row, col) into a
        # single row index of the bias table.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        # Stored as a non-trainable variable.
        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index), trainable=False
        )

    def call(self, x, mask=None):
        """Attend within each window.

        Args:
            x: Rank-3 tensor ``(windows, size, channels)`` where ``size`` is
                the number of positions per window.
            mask: Optional additive attention mask whose first axis is the
                number of windows per image; ``None`` disables masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, size, channels = x.shape
        # assumes channels == dim so that the reshape below is valid - TODO confirm
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        # Split the fused projection into (3, windows, heads, size, head_dim).
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        # Look up the bias of every (query, key) position pair and add it,
        # broadcast over the window axis.
        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = tf.reshape(
            self.relative_position_index, shape=(-1,)
        )
        relative_position_bias = tf.gather(
            self.relative_position_bias_table, relative_position_index_flat
        )
        relative_position_bias = tf.reshape(
            relative_position_bias, shape=(num_window_elements, num_window_elements, -1)
        )
        relative_position_bias = tf.transpose(relative_position_bias, perm=(2, 0, 1))
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            # Expose the per-image window axis so the mask (one entry per
            # window) broadcasts over batch and heads, then fold it back.
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = keras.activations.softmax(attn, axis=-1)
        else:
            attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        # Weighted sum of the values, heads merged back into the channel axis.
        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

In [None]:
class SwinTransformer(layers.Layer):
    """One Swin Transformer block: (shifted-)window attention plus an MLP,
    each wrapped in layer normalisation, drop-path and a residual connection.

    Args:
        dim: Channel width of the tokens.
        num_patch: ``(height, width)`` of the patch grid the tokens form.
        num_heads: Number of attention heads.
        window_size: Side length of the attention window. If the patch grid
            is smaller than the window, the window is shrunk to the grid and
            shifting is disabled.
        shift_size: Cyclic shift applied before windowing (0 = no shift).
        num_mlp: Hidden width of the MLP.
        qkv_bias: Whether the q/k/v projection has a bias term.
        dropout_rate: Rate used for attention/MLP dropout and for drop-path.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=8,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes

        # Clamp the window to the patch grid BEFORE creating the attention
        # layer. The clamp used to run last, so WindowAttention kept the
        # unclamped window size and its bias table no longer matched the
        # windows produced in call().
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Create the attention mask used when windows are shifted."""
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            # Three bands per axis: untouched area, last window minus the
            # shift, and the wrapped-around strip.
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Give every band combination its own region id.
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            # Position pairs from different regions get -100 (suppressed by
            # the softmax); pairs from the same region get 0.
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        """Run the block.

        Args:
            x: Rank-3 tensor ``(batch, height * width, channels)`` of tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        # Back to a 2-D grid so the windows can be cut out.
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        # First residual connection (attention branch).
        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = x_skip + x
        # Second residual connection (MLP branch).
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x

In [None]:
class PatchExtract(layers.Layer):
    """Cut an image into non-overlapping patches and flatten each patch.

    Args:
        patch_size: ``(size along axis 1, size along axis 2)`` of one patch.
            Both entries are now honoured; previously the first entry was
            used for both axes.
    """

    def __init__(self, patch_size, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[1]

    def call(self, images):
        """Extract the patches.

        Args:
            images: Rank-4 tensor ``(batch, height, width, channels)``.

        Returns:
            Tensor of shape ``(batch, number of patches, values per patch)``.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size_x, self.patch_size_y, 1),
            strides=(1, self.patch_size_x, self.patch_size_y, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        # Count the patches along both grid axes rather than squaring the
        # first one, so non-square grids are flattened correctly.
        patch_num_x = patches.shape[1]
        patch_num_y = patches.shape[2]
        return tf.reshape(
            patches, (batch_size, patch_num_x * patch_num_y, patch_dim)
        )


class PatchEmbedding(layers.Layer):
    """Project flattened patches and add a learned position embedding.

    Args:
        num_patch: Number of patches per image (size of the position table).
        embed_dim: Width of the projection and of the position embedding.
    """

    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        """Return the projected patches plus their position embeddings."""
        positions = tf.range(self.num_patch)
        projected = self.proj(patch)
        return projected + self.pos_embed(positions)


class PatchMerging(tf.keras.layers.Layer):
    """Merge every 2x2 group of neighbouring patches into one token.

    The four tokens of a group are concatenated along the channel axis and
    linearly projected to ``2 * embed_dim`` channels.

    Args:
        num_patch: ``(height, width)`` of the incoming patch grid.
        embed_dim: Channel width of the incoming tokens.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``), matching
            the other custom layers in this file.
    """

    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchMerging, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        """Merge the tokens.

        Args:
            x: Rank-3 tensor ``(batch, height * width, channels)``.

        Returns:
            Tensor of shape
            ``(batch, (height // 2) * (width // 2), 2 * embed_dim)``.
        """
        height, width = self.num_patch
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape=(-1, height, width, C))
        # The four members of every 2x2 group.
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3), axis=-1)
        x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

In [None]:
def get_ST_model():
    """Build the CNN + Swin Transformer classifier.

    A three-layer Conv2D feature extractor is followed by patch extraction,
    two Swin Transformer blocks (the second one shifted), patch merging,
    global average pooling and a dense softmax head with 8 outputs.

    Reads the notebook globals ``image_dimension``, ``patch_size``,
    ``num_patch_x``, ``num_patch_y``, ``embed_dim``, ``num_heads``,
    ``window_size``, ``shift_size``, ``num_mlp``, ``qkv_bias`` and
    ``dropout_rate``.

    Returns:
        The uncompiled Keras ``Model``; its summary is printed as a side
        effect.
    """
 
    # Input size and class count are hard-coded here.
    input_shape1 =  8, 8, 18
    output_units=8
    
 
    input1_ = Input(shape=input_shape1)


############################ Feature extractor
    
    conv_b1 = Conv2D(filters=128, kernel_size=(3, 3), padding='same',activation='relu', name='conv_b1')(input1_)
    # 1x1 pooling window with stride 1.
    max_b_1 = MaxPool2D((1,1), strides=(1,1), padding='same')(conv_b1)
    conv_b2 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',activation='relu', name='conv_b2')(max_b_1)
    conv_b3 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',activation='relu', name='conv_b3')(conv_b2)
    # NOTE(review): the layer is named 'norm_a' although the variable is norm_b.
    norm_b = BatchNormalization(name='norm_a')(conv_b3)
    
  
 
    ######################################## Swin Transformers 
    
    # NOTE(review): the crop size is image_dimension (12) while the feature
    # map above is 8x8; confirm how RandomCrop treats inputs smaller than the
    # target in the installed Keras version.
    x = layers.RandomCrop(image_dimension, image_dimension)(norm_b)
    x = layers.RandomFlip("horizontal")(x)
    x = PatchExtract(patch_size)(x)
    x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)
    # First block: plain windows (no shift).
    x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
     )(x)
    # Second block: windows shifted by shift_size.
    x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
     )(x)
    x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(50, activation="relu")(x)
    
    #############################################
    

    output_layer = Dense(units=output_units, activation='softmax')(x)
 
    model = Model(inputs=input1_, outputs=output_layer)
    model.summary()
    
    
   

    return model

In [None]:
from tensorflow.keras.utils import plot_model

In [None]:
# Build the classifier (also prints its summary).
model = get_ST_model()

In [None]:
# One-hot encode the training labels for the categorical cross-entropy loss.
import tensorflow
ytrain = tensorflow.keras.utils.to_categorical(ytrain)
ytrain.shape

In [None]:


# Compile and train the classifier on the balanced training split.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# NOTE(review): save_best_only is set without an explicit `monitor`, and
# fit() below receives no validation data; confirm the checkpoint actually
# gets written with the installed Keras version.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint("Data/ST.h5",save_best_only=True)
history = model.fit(x=Xtrain, y=ytrain, batch_size = 256, epochs=100,callbacks=model_checkpoint_callback)
