# 📍<span style="font-family:cursive;"> Overview</span>
* In this notebook, we build our best model: a bidirectional Gated Recurrent Unit (GRU) network in Keras, drawing on the references listed at the end
* Data augmentation with albumentations
* Custom dataset classes
* This notebook uses a GPU

# 📚<span style="font-family:cursive;"> Libraries </span>

In [None]:
# Libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import os
import glob
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import cv2
from tqdm import tqdm

from plotly.offline import iplot
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff

import tensorflow as tf
from tensorflow.keras import layers

import torch
import torchvision.models as models
import torch.nn as nn

from tensorflow.keras.optimizers import Adam
from tensorflow.keras import models
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

from skimage.io import imshow, imread, imsave
from skimage.transform import rotate, AffineTransform, warp,rescale, resize, downscale_local_mean
from skimage import color,data
from skimage.exposure import adjust_gamma
from skimage.util import random_noise

# 📝<span style="font-family:cursive;"> Data Preparation</span>


In [None]:
data_dir = '../input/seti-breakthrough-listen'
train_merger = os.path.join(data_dir, 'train_labels.csv')
train_labels = pd.read_csv(train_merger)
print(f'train_label_csv : {train_labels.shape[0]}')
# Add the .npy path for each id for easier processing. Files live in a
# subfolder named after the id's first character; derive the path from
# data_dir instead of repeating the dataset root as a second literal.
train_labels['path'] = train_labels['id'].apply(
    lambda x: os.path.join(data_dir, 'train', x[0], f'{x}.npy'))
train_labels.head()

# 📗<span style="font-family:cursive;"> Albumentations</span>

In [None]:
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
from typing import *

class Transform:
    """Build an albumentations pipeline from a ``{name: kwargs}`` spec.

    Each key names an albumentations transform class (e.g. ``"VerticalFlip"``)
    and its value is the keyword-argument dict used to construct it. A
    ``ToTensorV2`` step is always appended last, so the pipeline's output is
    converted to a PyTorch tensor.

    Args:
        aug_kwargs: Ordered mapping of transform class name -> constructor kwargs.
    """

    def __init__(self, aug_kwargs: Dict[str, Dict[str, Any]]):
        # Resolve transform classes on the ``albumentations`` module imported in
        # this cell. The short ``A`` alias is only bound in a later cell, so
        # relying on it made this class fail with NameError when used earlier.
        steps = [getattr(albumentations, name)(**kwargs)
                 for name, kwargs in aug_kwargs.items()]
        steps.append(ToTensorV2(p=1))
        self.transform = albumentations.Compose(steps)

    def __call__(self, image):
        """Apply the pipeline to ``image`` and return the transformed image."""
        return self.transform(image=image)['image']

In [None]:
class ModeTransform():
    """Map-style dataset that loads ``.npy`` cadence files listed in a DataFrame.

    Args:
        df_frame: DataFrame with a ``path`` column of ``.npy`` files. A
            ``target`` column is required unless ``mode == 'test'``.
        config: Stored on the instance; not read by this class.
        channel_mode: How the loaded stack is rearranged; one of
            ``'spatial_6ch'``, ``'spatial_3ch'``, ``'6_channel'``, ``'3_channel'``.
        mode: ``'test'`` makes ``__getitem__`` return the image only; any
            other value returns ``(image, label)``.
        target: Stored on the instance; not read by this class.
        transform: Optional callable applied to the numpy image. When falsy,
            the image is converted with ``torch.from_numpy(...).float()``.

    Raises:
        ValueError: if ``channel_mode`` is not one of the supported values.
        KeyError: if ``df_frame`` has no ``target`` column and ``mode != 'test'``.
    """

    # Supported values for ``channel_mode``.
    _CHANNEL_MODES = ('spatial_6ch', 'spatial_3ch', '6_channel', '3_channel')

    def __init__(self, df_frame, config, channel_mode, mode, target, transform):
        # Fail fast: an unknown mode previously fell through every branch and
        # silently returned the raw, un-reshaped array.
        if channel_mode not in self._CHANNEL_MODES:
            raise ValueError(
                f"unknown channel_mode {channel_mode!r}; "
                f"expected one of {self._CHANNEL_MODES}"
            )
        self.df_frame = df_frame
        self.channel_mode = channel_mode
        self.config = config
        self.target = target
        self.file_names = df_frame['path'].values
        if 'target' in df_frame.columns:
            self.labels = df_frame['target'].values
        elif mode == 'test':
            # Unlabelled inference frames are fine: labels are never read in test mode.
            self.labels = None
        else:
            raise KeyError("'target' column is required unless mode == 'test'")
        self.transform = transform
        self.mode = mode

    def __len__(self):
        """Return the number of rows in ``df_frame``."""
        return len(self.df_frame)

    def __getitem__(self, idx):
        """Return the image (test mode) or ``(image, label)`` for row ``idx``."""
        # NOTE(review): an earlier debug print recorded raw files as
        # (6, 273, 256); the shape comments below assume that layout.
        image = np.load(self.file_names[idx])
        if self.channel_mode == 'spatial_6ch':
            # Stack all observations along rows, no transpose: (6*273, 256) = (1638, 256).
            image = np.vstack(image.astype(np.float32))
        elif self.channel_mode == 'spatial_3ch':
            # Every other observation, stacked along rows, then transposed:
            # (256, 3*273) = (256, 819).
            image = np.vstack(image[::2].astype(np.float32)).transpose((1, 0))
        elif self.channel_mode == '6_channel':
            # Channels-last: (273, 256, 6).
            image = np.transpose(image.astype(np.float32), (1, 2, 0))
        elif self.channel_mode == '3_channel':
            # Every other observation, channels-last: (273, 256, 3).
            image = np.transpose(image[::2].astype(np.float32), (1, 2, 0))
        else:
            raise ValueError(f"unknown channel_mode {self.channel_mode!r}")

        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image).float()

        if self.mode == 'test':
            return image
        label = torch.tensor(self.labels[idx]).float()
        return image, label

In [None]:
import albumentations as A

# Augmentation spec consumed by Transform: albumentations class name -> kwargs.
CONFIG = {
    "TRAIN_TRANSFORMS": {
        "VerticalFlip": {"p": 0.5},
        "HorizontalFlip": {"p": 0.5},
        "Resize": {"height": 640, "width": 640, "p": 1},
    },
}
config = CONFIG

# Parameters
params_train = dict(mode='train', channel_mode='spatial_6ch', target=True)

train_dset = ModeTransform(
    train_labels,
    config,
    transform=Transform(config["TRAIN_TRANSFORMS"]),
    **params_train,
)

# Plot image[0] of the first two augmented samples, titled with their label.
for i in (0, 1):
    image, label = train_dset[i]
    plt.imshow(image[0])
    plt.title(f'label: {label}')
    plt.show()
image.shape

# 🎛<span style="font-family:cursive;"> Custom Dataset</span>

In [None]:
class SETIDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of cadence arrays loaded from ``.npy`` files.

    Files are read from ``<directory>/<id[0]>/<id>.npy``.

    Args:
        df: DataFrame with an ``id`` column, plus ``target`` when ``target`` is True.
        directory: Root directory containing the per-prefix subfolders.
        batch_size: Rows per batch; the last batch may be smaller.
        random_state: Seed for this dataset's private shuffling RNG.
        shuffle: Reshuffle rows at construction and after every epoch.
        target: If True, ``__getitem__`` returns ``(signals, labels)``;
            otherwise it returns ``signals`` only.
    """

    def __init__(self, df, directory, batch_size, random_state, shuffle, target):
        super().__init__()
        # A private RNG keeps shuffling reproducible without reseeding NumPy's
        # global generator as a hidden side effect on the rest of the notebook.
        self._rng = np.random.RandomState(random_state)
        self.directory = directory
        self.df = df
        self.target = target
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.ext = '.npy'
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches (ceil of rows / batch_size)."""
        return int(np.ceil(self.df.shape[0] / self.batch_size))

    def __getitem__(self, idx):
        """Load batch ``idx`` as a float32 array (plus labels when ``target``)."""
        start_idx = idx * self.batch_size
        # .iloc makes the slice explicitly positional, whatever index labels the
        # frame carries (e.g. an unshuffled validation split not starting at 0).
        batch = self.df.iloc[start_idx:start_idx + self.batch_size]

        signals = np.stack([
            np.load(os.path.join(self.directory, fname[0], fname + self.ext))
            for fname in batch['id']
        ])
        # Swap the last two axes; assumed (batch, 6, 273, 256) -> (batch, 6, 256, 273).
        signals = np.transpose(signals, (0, 1, 3, 2)).astype('float32')

        if self.target:
            return signals, batch['target'].values
        return signals

    def on_epoch_end(self):
        """Reshuffle the rows (when ``shuffle`` is set) using the private RNG."""
        if self.shuffle:
            self.df = self.df.sample(frac=1, random_state=self._rng).reset_index(drop=True)

In [None]:
# Training labels, plus the submission template whose ids index the test files.
train, sub = (
    pd.read_csv('../input/seti-breakthrough-listen/train_labels.csv'),
    pd.read_csv('../input/seti-breakthrough-listen/sample_submission.csv'),
)

In [None]:
# Shuffle once with a fixed seed so the 80/20 train/validation split is
# reproducible across runs (it was previously unseeded).
sample_df = train.sample(frac=1, random_state=42).reset_index(drop=True)

split = int(sample_df.shape[0] * 0.8)
# Reset both indices so each split is 0-based; positional slicing and
# label-based indexing then agree for downstream loaders.
train_df = sample_df.iloc[:split].reset_index(drop=True)
valid_df = sample_df.iloc[split:].reset_index(drop=True)

In [None]:
# Loader parameters: validation and test never shuffle, and the test loader
# yields inputs only (no labels).
params_train = {
    'batch_size': 64,
    'shuffle': True,
    'random_state': 42,
    'target': True,
}
params_valid = {**params_train, 'shuffle': False}
params_test = {**params_valid, 'target': False}

train_dset = SETIDataset(train_df, "../input/seti-breakthrough-listen/train", **params_train)
valid_dset = SETIDataset(valid_df, "../input/seti-breakthrough-listen/train", **params_valid)
test_dset = SETIDataset(sub, "../input/seti-breakthrough-listen/test", **params_test)

# 🧪<span style="font-family:cursive;"> Build Model</span>

In [None]:
def build_model(unit, input_shape=(6, 256, 273)):
    """Build and compile a TimeDistributed bidirectional-GRU binary classifier.

    Each of the ``input_shape[0]`` slices is treated as a ``(steps, features)``
    sequence. It is encoded by two stacked bidirectional GRUs and averaged over
    time (GlobalAveragePooling1D). The per-slice vectors are flattened and fed
    to a ReLU dense layer and a single sigmoid output.

    Args:
        unit: Hidden size of each GRU direction.
        input_shape: Per-sample input shape ``(slices, steps, features)``. The
            default matches the layout SETIDataset produces after its axis swap.

    Returns:
        A compiled ``tf.keras.Model`` (Adam, binary cross-entropy, AUC metric).
    """
    inputs = layers.Input(shape=input_shape)

    gru1 = layers.Bidirectional(layers.GRU(unit, return_sequences=True))
    gru2 = layers.Bidirectional(layers.GRU(unit, return_sequences=True))
    pool = layers.GlobalAveragePooling1D()

    x = layers.TimeDistributed(gru1, name="bi_gru_1")(inputs)
    x = layers.TimeDistributed(gru2, name="bi_gru_2")(x)
    x = layers.TimeDistributed(pool, name="pool")(x)

    x = layers.Flatten()(x)
    x = layers.Dense(128, activation="relu")(x)
    outputs = layers.Dense(1, activation="sigmoid", name="sigmoid")(x)

    # tf.keras.Model sidesteps the notebook-level ``models`` name, which is
    # bound to both torchvision.models and tensorflow.keras.models above.
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(
        "adam",
        loss="binary_crossentropy",
        # Explicit name keeps the history keys 'auc'/'val_auc' stable. An
        # unnamed metric becomes 'auc_1', 'auc_2', ... on repeated builds
        # in the same session.
        metrics=[tf.keras.metrics.AUC(name="auc")],
    )
    model.summary()

    return model

In [None]:
model = build_model(unit=128)

# Keep only the weights from the epoch with the best monitored value
# (ModelCheckpoint monitors val_loss by default).
model_save = ModelCheckpoint(
    "model_weights.h5",
    save_best_only=True,
    save_weights_only=True,
)

history = model.fit(
    train_dset,
    validation_data=valid_dset,
    epochs=10,
    workers=4,
    use_multiprocessing=True,
    callbacks=[model_save],
)

# 📈<span style="font-family:cursive;"> Visualizations</span>

In [None]:
# Locate the AUC key instead of hard-coding 'auc': Keras suffixes metric names
# ('auc_1', ...) when a metric is re-created in the same session.
auc_key = next((k for k in history.history if k.startswith('auc')), 'auc')
auc = history.history[auc_key]
val_auc = history.history['val_' + auc_key]
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(auc) + 1)

# Set the style before creating the axes so it actually applies to them.
sns.set_style("white")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
plt.suptitle('Train history', size=15)

# The tracked metric is AUC, not accuracy; label it accordingly.
ax1.plot(epochs, auc, "o", color="blue", label="Training AUC")
ax1.plot(epochs, val_auc, "-", color="blue", label="Validation AUC")
ax1.set_title("Training and validation AUC")
ax1.legend()

# Marker/line style only in the format string; colour set once via color=.
ax2.plot(epochs, loss, "o", color="red", label="Training loss")
ax2.plot(epochs, val_loss, "-", color="red", label="Validation loss")
ax2.set_title("Training and validation loss")
ax2.legend()

plt.show()

# 🏁<span style="font-family:cursive;"> Final Submission</span>

In [None]:
# Restore the best checkpoint written during training, then score the test set.
model.load_weights('model_weights.h5')
y_pred = model.predict(test_dset, verbose=1, workers=4, use_multiprocessing=True)

In [None]:
# model.predict returns shape (n_samples, 1) for the single sigmoid unit.
# Flatten it so the column receives exactly one score per row, instead of
# relying on pandas to squeeze a 2-D array into a single column.
sub['target'] = np.ravel(y_pred)
sub.to_csv('submission.csv', index=False)
sub.head()

## <span style="font-family:cursive;"> References</span>

References used in this notebook:
* [https://keras.io/api/layers/recurrent_layers/gru/](https://keras.io/api/layers/recurrent_layers/gru/)
* [https://www.programcreek.com/python/example/97114/keras.layers.recurrent.GRU](https://www.programcreek.com/python/example/97114/keras.layers.recurrent.GRU)
* [https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21](https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21)
* [https://www.kaggle.com/xhlulu/openvaccine-gru-with-keras-tuner](https://www.kaggle.com/xhlulu/openvaccine-gru-with-keras-tuner)
* [https://www.kaggle.com/c/seti-breakthrough-listen/discussion/239339](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/239339)

## <span style="font-family:cursive;"> If you found this useful, please upvote. Thank you for your attention 🙂</span>