In [4]:
import os
import subprocess

# Fetch the project repository when this notebook runs on Google Colab.
REPO_URL = 'https://github.com/LukasMosser/neural_rock_typing.git'
REPO_DIR = 'neural_rock_typing'

if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    # Clone over HTTPS: a fresh Colab VM has no GitHub SSH key, so an SSH URL
    # (git@github.com:...) fails with "Permission denied (publickey)".
    # Skip the clone when the directory already exists so re-running the
    # cell does not fail on "destination path already exists".
    if not os.path.isdir(REPO_DIR):
        # check=True surfaces a failed clone immediately instead of silently
        # ignoring git's exit status.
        subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
    # NOTE(review): the clone lands in ./neural_rock_typing and is neither
    # added to sys.path nor chdir'd into, so the `neural_rock` imports and the
    # relative ./data path used later may still need setup on Colab — confirm.
else:
    print('Not running on CoLab')

Not running on CoLab


In [None]:
import sys
import argparse
import albumentations as A
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from neural_rock.dataset import ThinSectionDataset
from neural_rock.utils import set_seed
from neural_rock.model import NeuralRockModel
from neural_rock.plot import visualize_batch

In [3]:
# --- Experiment configuration ------------------------------------------------
# Label set to train on. Other options: 'Lucia', 'DominantPore'.
labelset = 'Dunham'

# Optimisation hyper-parameters.
# NOTE(review): learning_rate, weight_decay and dropout are not passed to
# NeuralRockModel in the training cell — confirm whether the model should
# receive them.
learning_rate, weight_decay, dropout = 3e-4, 1e-5, 0.5
batch_size = 16

# Number of times each dataset is repeated (via ConcatDataset) per epoch.
train_dataset_mult, val_dataset_mult = 10, 50

# Training schedule: total epochs and validation interval (in epochs).
epochs, check_val_every = 100, 10

# Global random seed.
seed = 42

In [None]:
# Weights & Biases identifier used by the training cell's WandbLogger.
# NOTE(review): it is passed to WandbLogger as `name` (the run display name);
# if it is meant to be the W&B user/team it belongs in `entity=` — confirm.
wandb_name = "lukas-mosser"


In [None]:
# --- Reproducibility ----------------------------------------------------------
set_seed(seed, cudnn=True, benchmark=True)

# Settings the original script read from the command line (`args.*`), which
# does not exist in this notebook. `prefetch_factor` in the DataLoaders below
# is only valid with worker processes, so num_workers must stay > 0.
num_workers = 4
plot = False  # set True to preview one batch from each loader

data_dir = "./data/Images_PhD_Miami/Leg194"

# --- Augmentation pipelines -----------------------------------------------------
# Train: random flip / rotation / 512x512 crop, Gaussian noise and
# hue/value jitter, then resize to 224x224 and normalise.
# Val: random 512x512 crop only, then the same resize + normalisation.
data_transforms = {
    'train': A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(360, always_apply=True),
        A.RandomCrop(width=512, height=512),
        A.GaussNoise(),
        A.HueSaturationValue(sat_shift_limit=0, val_shift_limit=50, hue_shift_limit=255, always_apply=True),
        A.Resize(width=224, height=224),
        A.Normalize(),
    ]),
    'val': A.Compose([
        A.RandomCrop(width=512, height=512),
        A.Resize(width=224, height=224),
        A.Normalize(),
    ]),
}

# --- Datasets & loaders ---------------------------------------------------------
train_dataset_base = ThinSectionDataset(data_dir, labelset,
                                        transform=data_transforms['train'], train=True, seed=seed)
val_dataset_base = ThinSectionDataset(data_dir, labelset,
                                      transform=data_transforms['val'], train=False, seed=seed)

# Repeat each dataset, presumably so an epoch sees several random crops per
# thin section — TODO confirm against ThinSectionDataset.__getitem__.
train_dataset = ConcatDataset([train_dataset_base] * train_dataset_mult)
val_dataset = ConcatDataset([val_dataset_base] * val_dataset_mult)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True, prefetch_factor=10)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True, prefetch_factor=10)

if plot:
    visualize_batch(train_loader)
    visualize_batch(val_loader)

# --- Logging, checkpointing and training ---------------------------------------
wandb_logger = WandbLogger(name=wandb_name, project='neural-rock')
tensorboard_logger = TensorBoardLogger("lightning_logs", name=labelset)
# Keep the checkpoint with the highest logged "val/f1".
checkpointer = ModelCheckpoint(monitor="val/f1", verbose=True, mode="max")
trainer = pl.Trainer(gpus=-1, max_epochs=epochs, benchmark=True,
                     logger=[wandb_logger, tensorboard_logger],
                     callbacks=[checkpointer],
                     # Use the configured interval instead of a hard-coded 10.
                     check_val_every_n_epoch=check_val_every)

model = NeuralRockModel(num_classes=len(train_dataset_base.class_names))

# Pass the loaders positionally. Lightning 1.4 renamed the training-loader
# keyword from `train_dataloader` to `train_dataloaders`, deprecated the old
# name and later removed it. Positional arguments work on both sides of that
# change.
trainer.fit(model, train_loader, val_loader)
