<a href="https://colab.research.google.com/github/LukasMosser/neural_rock_typing/blob/main/notebooks/Neural%20Rock%20Typing%20-%20Train%20Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

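## Neural Rock Typing: Train Model

This notebook trains a `NeuralRockModel` thin-section classifier on the Leg194 images for a selected label set (`Dunham`, `Lucia` or `DominantPore`). It logs runs to Weights & Biases and TensorBoard, and keeps the checkpoint with the best validation F1.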

In [None]:
import os
import subprocess
import sys

# Name of the repository directory created by `git clone` below.
REPO_DIR = "neural_rock_typing"

if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    import urllib.parse  # `import urllib` alone does not guarantee urllib.parse is loaded
    from getpass import getpass

    # Guard the clone/chdir/install steps so that re-running this cell is safe.
    # Once os.chdir() has run, the cwd is already the repo and a second
    # `git clone` / chdir would fail.
    if os.path.basename(os.getcwd()) != REPO_DIR:
        if not os.path.isdir(REPO_DIR):
            user = input('User name: ')
            password = getpass('Password: ')
            # Percent-encode both parts (safe='' also encodes '/') so characters
            # such as '@', ':' or '/' cannot corrupt the URL.
            clone_url = 'https://{0}:{1}@github.com/LukasMosser/neural_rock_typing.git'.format(
                urllib.parse.quote(user, safe=''), urllib.parse.quote(password, safe=''))
            # Argument list, no shell: credentials are never parsed by a shell.
            subprocess.run(['git', 'clone', clone_url, REPO_DIR], check=True)
            # `git clone` stores the credentialed URL as `origin` in .git/config;
            # replace it with the plain URL so the password is not left on disk.
            subprocess.run(['git', '-C', REPO_DIR, 'remote', 'set-url', 'origin',
                            'https://github.com/LukasMosser/neural_rock_typing.git'], check=True)
            del password, clone_url  # drop the secret from the notebook namespace
        os.chdir(REPO_DIR)
        # Use the kernel's own interpreter so packages land in the running environment.
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'], check=True)
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-e', '.'], check=True)

    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
else:
    print('Not running on CoLab')

## Login to Weights & Biases for Logging

In [None]:
# Authenticate the Weights & Biases CLI/client so that the WandbLogger created
# later in this notebook can log runs (prompts for an API key if none is stored).
!wandb login

## Basic Imports

In [None]:
import sys
import argparse
import albumentations as A
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from neural_rock.dataset import ThinSectionDataset
from neural_rock.utils import set_seed
from neural_rock.model import NeuralRockModel
from neural_rock.plot import visualize_batch

## Data Link & Hyperparameters

In [None]:
# Expose the Google Drive copy of the data as ./data. The guard makes the cell
# re-runnable: it avoids the FileExistsError raised when ./data already exists.
# lexists() also matches a dangling symlink left by a previous session.
if not os.path.lexists("./data"):
    os.symlink("../drive/MyDrive/deeprock/data", "./data")

labelset = 'Dunham'  # alternatives: 'Lucia', 'DominantPore'
wandb_name = 'lukas-mosser'  # run name passed to WandbLogger
learning_rate = 3e-4
batch_size = 16
weight_decay = 1e-5
dropout = 0.5
# NOTE(review): learning_rate, weight_decay and dropout are not passed to
# NeuralRockModel(...) in this notebook (only num_classes is). Confirm whether the
# model picks them up elsewhere, or whether they are silently ignored.

# Number of times each split is repeated via ConcatDataset within one epoch.
train_dataset_mult = 10
val_dataset_mult = 50

epochs = 100
check_val_every = 10  # run validation every N training epochs

seed = 42

## Load Dataset and Transforms

In [None]:
set_seed(seed, cudnn=True, benchmark=True)

# Root folder of the thin-section images (shared by the train and val splits).
DATA_ROOT = "./data/Images_PhD_Miami/Leg194"

# p=1.0 is the version-portable spelling of always_apply=True. The latter is
# deprecated (and later removed) in albumentations, and random() < 1.0 always holds.
data_transforms = {
    # Training: flip/rotate/crop plus noise and hue/value jitter, then resize + normalize.
    'train': A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=360, p=1.0),
        A.RandomCrop(width=512, height=512),
        A.GaussNoise(),
        A.HueSaturationValue(hue_shift_limit=255, sat_shift_limit=0, val_shift_limit=50, p=1.0),
        A.Resize(width=224, height=224),
        A.Normalize(),
    ]),
    # Validation still uses a *random* 512x512 crop. Presumably each repeat of the
    # val set below therefore scores different crops (assumes ThinSectionDataset
    # applies `transform` on every __getitem__; confirm).
    'val': A.Compose([
        A.RandomCrop(width=512, height=512),
        A.Resize(width=224, height=224),
        A.Normalize(),
    ]),
}

train_dataset_base = ThinSectionDataset(DATA_ROOT, labelset,
                                        transform=data_transforms['train'], train=True, seed=seed)
val_dataset_base = ThinSectionDataset(DATA_ROOT, labelset,
                                      transform=data_transforms['val'], train=False, seed=seed)

# Repeat each split so that one epoch iterates over it several times.
train_dataset = ConcatDataset([train_dataset_base] * train_dataset_mult)
val_dataset = ConcatDataset([val_dataset_base] * val_dataset_mult)

## Initialize Model & Prepare for Training

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=1, pin_memory=True, prefetch_factor=10)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=1, pin_memory=True, prefetch_factor=10)

wandb_logger = WandbLogger(name=wandb_name, project='neural-rock', entity='ccg')
# Store the settings that shape this run next to its W&B record for reproducibility.
wandb_logger.experiment.config.update({
    "labelset": labelset,
    "model": 'VGG_FC',
    "batch_size": batch_size,
    "epochs": epochs,
    "seed": seed,
    "train_dataset_mult": train_dataset_mult,
    "val_dataset_mult": val_dataset_mult,
})

tensorboard_logger = TensorBoardLogger("lightning_logs", name=labelset)
# Keep the checkpoint with the highest "val/f1". The name must match the metric
# NeuralRockModel logs.
checkpointer = ModelCheckpoint(monitor="val/f1", verbose=True, mode="max")
# NOTE(review): `gpus` and `progress_bar_refresh_rate` are PyTorch Lightning 1.x
# arguments. Newer releases replace them (accelerator/devices, progress-bar callback).
trainer = pl.Trainer(gpus=-1, max_epochs=epochs, benchmark=True,
                     logger=[wandb_logger, tensorboard_logger],
                     callbacks=[checkpointer],
                     # use the configured interval rather than a duplicated literal
                     check_val_every_n_epoch=check_val_every,
                     progress_bar_refresh_rate=20)

model = NeuralRockModel(num_classes=len(train_dataset_base.class_names))

## Train Model

In [None]:
# Pass the loaders positionally. The `train_dataloader=` keyword was renamed to
# `train_dataloaders=` in PyTorch Lightning 1.4 and later removed, whereas the
# positional order (model, train loader(s), val loader(s)) works in every version.
trainer.fit(model, train_loader, val_loader)