## Install Required Modules

In [None]:
# Install the extra packages this notebook relies on (wandb is also upgraded).
# NOTE(review): astroNN is not imported in any visible cell; the dataset is
# read directly with h5py below — confirm whether this install is still needed.
!pip install astroNN
!pip install timm
!pip install --upgrade wandb

## Import Required Libraries

In [None]:
import os
import glob
import re
import tqdm
import random
import gc
import pickle
import math
import h5py

import timm

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import numpy as np
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import tensorflow as tf


## Get Dataset

In [None]:
# Load the Galaxy10 images and labels from the HDF5 file.
# The file handle is named `h5`, not `F`, so it does not shadow
# `torch.nn.functional as F` imported earlier in the notebook.
with h5py.File('../input/galaxy10/Galaxy10.h5', 'r') as h5:
    images = np.array(h5['images'])
    labels = np.array(h5['ans'])

# Labels stay as class indices (galaxy10.__getitem__ converts them to int64
# for CrossEntropyLoss), so no one-hot encoding is applied.
labels = labels.astype(np.float32)
# Images are cast to float32. The galaxy10 dataset divides by 255.0 itself, and
# torchvision's ToTensor leaves non-uint8 arrays unscaled, so they are scaled once.
images = images.astype(np.float32)

In [None]:
# Sanity-check array shapes: labels first, then images.
for arr in (labels, images):
    print(arr.shape)

In [None]:
# Inspect the first label value.
first_label = labels[0]
print(first_label)

In [None]:
# Display the Python type of a single image entry.
images[0].__class__

In [None]:
# Hold out 33% of the samples (fixed random_state) as the validation split.
# NOTE(review): consider stratify=labels if the class distribution is skewed.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.33,
    random_state=42,
)

## Initialize Weights & Biases (W&B)

In [None]:
# Read the W&B API key from the Kaggle secret labelled "wandb-key" so the
# key is not written into the notebook itself.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb-key")

In [None]:
# Authenticate with W&B using the key fetched from Kaggle secrets.
wandb.login(key=wandb_key)

## Config and Seeding

In [None]:
# Experiment settings; this dict is also passed to WandbLogger as the run config.
CONFIG = dict(
    seed=2021,
    img_size=69,
    model_name="tf_efficientnet_b0_ns",
    embedding_size=256,
    train_batch_size=64,
    valid_batch_size=64,
    num_classes=10,
    # First CUDA device when available, otherwise CPU.
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

In [None]:
# Seed all RNGs for reproducibility via Lightning's public helper.
# The `pl.utilities.seed.seed_everything` path is deprecated in newer
# Lightning releases; `pl.seed_everything` is the stable entry point.
pl.seed_everything(CONFIG['seed'])

## Wrap in Dataset Class

In [None]:
class galaxy10(Dataset):
    """Map-style dataset pairing images with integer class targets.

    Args:
        images: indexable collection of images; each entry is copied with
            ``np.array`` before ``transform`` is applied.
        labels: indexable collection of labels, parallel to ``images``.
        transform: optional callable applied to each image array.

    Each item is ``(transformed_image / 255.0, label_as_int64_tensor)``.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        sample = self.get_image(self.images[i])
        # Build a 1-element float tensor, truncate to int64, then squeeze so a
        # scalar label becomes a 0-d target tensor.
        target = torch.Tensor([self.labels[i]]).long().squeeze()
        return sample, target

    def get_image(self, image):
        """Copy ``image``, apply the optional transform, and divide by 255."""
        pixels = np.array(image)
        pixels = self.transform(pixels) if self.transform else pixels
        # Assumes pixel values are on a 0-255 scale; maps them to [0, 1].
        return pixels / 255.0

## Baseline Model (Deep CNN)

In [None]:
class baseModel(pl.LightningModule):
    """timm backbone followed by a small MLP classification head.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        embedding_size: width of the last hidden layer of the MLP head.
        pretrained: whether timm loads pretrained backbone weights.
        num_classes: number of output logits.
        lr: initial Adam learning rate.
        lr_step_size: StepLR period, in scheduler steps (epochs with
            Lightning's default scheduler interval).
        lr_gamma: multiplicative LR decay applied every ``lr_step_size`` steps.

    ``forward`` returns raw logits; ``nn.CrossEntropyLoss`` applies the
    log-softmax internally, so no Softmax layer is used.
    """

    def __init__(self, model_name, embedding_size, pretrained=True,
                 num_classes=10, lr=1e-3, lr_step_size=1, lr_gamma=0.9):
        super().__init__()
        # Store init args in checkpoints so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.model = nn.Sequential(
            # NOTE(review): without num_classes=... timm presumably keeps the
            # backbone's default classifier, so the MLP below consumes class
            # logits rather than pooled features. num_classes=0 would expose
            # features instead, but that changes the architecture/checkpoints.
            timm.create_model(model_name, pretrained=pretrained, in_chans=3),
            nn.Flatten(),
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.LazyLinear(embedding_size),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.LazyLinear(num_classes),
        )
        # One stateless loss module reused by every step.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        val_loss = self.criterion(out, y)
        val_acc = (out.argmax(dim=1) == y).float().mean()
        self.log("val_loss", val_loss)
        self.log("val_acc", val_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # gamma defaults to 0.9. StepLR's own default (0.1) with step_size=1
        # would cut the LR 10x per epoch, leaving ~1e-6 after three epochs.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hparams.lr_step_size,
            gamma=self.hparams.lr_gamma,
        )
        return [optimizer], [lr_scheduler]

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Accept bare input batches as well as (x, y) pairs such as those
        # produced by the galaxy10 dataset.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(x)

In [None]:
# Array -> tensor, ensure float dtype, then resize to a square of CONFIG['img_size'].
# NOTE(review): per torchvision docs ToTensor only rescales uint8 input; the
# float32 arrays here pass through unscaled and galaxy10 divides by 255.
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    ]
)

In [None]:
# Training and validation datasets share the same preprocessing pipeline.
train_dataset = galaxy10(images=X_train, labels=y_train, transform=transformations)
valid_dataset = galaxy10(images=X_test, labels=y_test, transform=transformations)

In [None]:
# Batch sizes are read from CONFIG so the batch size actually used matches the
# value logged to W&B through WandbLogger(config=CONFIG).
# Pinned memory only helps host->GPU copies, so enable it only when CUDA exists.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CONFIG['train_batch_size'],
    shuffle=True,  # reshuffle training samples every epoch
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=CONFIG['valid_batch_size'],
    shuffle=False,  # fixed order for validation
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Print the shapes of the first training batch, if there is one.
batch = next(iter(train_dataloader), None)
if batch is not None:
    x, y = batch
    print(x.shape)
    print(y.shape)

In [None]:
# Build from CONFIG so the instantiated backbone matches the logged run config.
model = baseModel(CONFIG['model_name'], CONFIG['embedding_size'])

In [None]:
# Dry-run forward pass to materialise the nn.LazyLinear layers before the
# Trainer builds the optimizer. The input size comes from CONFIG so it tracks
# the transform's resize. Eval mode + no_grad keeps the random input from
# updating any normalization running statistics in the backbone and avoids
# building an autograd graph. Training mode is restored afterwards.
temp = torch.rand((2, 3, CONFIG['img_size'], CONFIG['img_size']))
model.eval()
with torch.no_grad():
    temp_out = model(temp)
model.train()
temp_out.shape

In [None]:
# Log this run to the "galaxy10" W&B project under the "EfficientNet" group,
# recording CONFIG as the run configuration.
wandb_logger = WandbLogger(
    project="galaxy10",
    group="EfficientNet",
    job_type='train',
    config=CONFIG,
)

In [None]:
# Keep only the single checkpoint with the lowest logged "val_loss".
checkpoint_callback = ModelCheckpoint(
    filename="efficientnet_baseline",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

In [None]:
# `gpus=` is deprecated (removed in Lightning 2.0) in favour of
# `accelerator=` / `devices=`. Mirror CONFIG["device"]: one GPU when CUDA is
# available, otherwise fall back to CPU instead of failing at startup.
trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=25,
    callbacks=[checkpoint_callback],
)

In [None]:
# Train the model, validating on valid_dataloader.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

In [None]:
# Uncomment to mark the W&B run as finished once training and logging are done.
# wandb.finish()