# Setup

In [None]:
# Mount Google Drive so the shared project folder is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Work from the project folder on the shared drive; relative paths used later
# in this notebook (e.g. "logs/...") resolve against this directory.
%cd /content/drive/Shareddrives/OPML/20210814_Mango_identification/敬倫
# Copy the dataset folder into /content/ (-r: recursive, -u: only newer files, -v: verbose).
%cp -ruv ../data/mango_1st_stage /content/

In [None]:
# Install the third-party packages imported by the cells below.
# NOTE(review): no versions are pinned. The Trainer arguments used later in this
# notebook (`gpus`, `weights_summary`, `progress_bar_refresh_rate`) and the
# `*_epoch_end` hooks may not be accepted by every pytorch_lightning release --
# confirm the working versions and pin them here.
!pip install pytorch_lightning monai multipledispatch

# Import Packages

In [None]:
import os
from datetime import datetime
from argparse import Namespace
from multipledispatch import dispatch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from monai.data import Dataset, DataLoader
from monai.transforms import *

In [None]:
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Load Data and Construct Dataset, DataLoader

In [None]:
class MangoDataModule(LightningDataModule):
    """Data module that builds MONAI datasets and dataloaders for the mango images.

    Sample lists are read from ``train.csv``, ``val.csv`` and ``test.csv`` under
    ``data_root``. Every row must provide an ``image_id`` and a ``label``
    column; letter grades are converted to class indices via ``label_dict``.
    """

    # Letter grade -> integer class index used as the classification target.
    label_dict = {
        "A": 0,
        "B": 1,
        "C": 2
    }

    @staticmethod
    @dispatch(pd.DataFrame, str, str)
    def prepare_data_list(csv_file, data_root, folder):
        """Turn an annotation table into a list of per-sample dictionaries.

        Args:
            csv_file: DataFrame with ``image_id`` and ``label`` columns.
            data_root: Root directory of the dataset.
            folder: Sub-folder of ``data_root`` that holds the image files.

        Returns:
            One dict per row with the keys ``image_id``, ``image`` (path built
            from ``data_root``, ``folder`` and the image id) and ``label``.

        Raises:
            KeyError: If the table lacks an ``image_id`` or ``label`` column.
        """
        # Only the label column is translated; values missing from label_dict
        # (e.g. labels that are already numeric) pass through unchanged.
        # Series.map replaces the deprecated DataFrame.applymap, which would
        # also have rewritten any other cell that happened to equal "A"/"B"/"C".
        csv_file = csv_file.assign(
            label=csv_file["label"].map(
                lambda x: MangoDataModule.label_dict.get(x, x)
            )
        )

        return [
            {
                "image_id": row["image_id"],
                "image": os.path.join(data_root, folder, row["image_id"]),
                "label": row["label"],
            }
            for _, row in csv_file.iterrows()
        ]

    @staticmethod
    @dispatch(str, str, str)
    def prepare_data_list(csv_file, data_root, folder):
        """Overload taking a CSV file name (relative to ``data_root``).

        The file is read with ``pandas.read_csv`` and then handled by the
        DataFrame overload above.
        """
        csv_file = pd.read_csv(os.path.join(data_root, csv_file))

        return MangoDataModule.prepare_data_list(csv_file, data_root, folder)

    def __init__(self, data_root, batch_size):
        """Store the dataset location and the batch size used by all loaders.

        Args:
            data_root: Directory containing the CSV files and image folders.
            batch_size: Batch size passed to every DataLoader.
        """
        super(MangoDataModule, self).__init__()
        self.data_root  = data_root
        self.batch_size = batch_size

    @staticmethod
    def _default_transform():
        """Build the preprocessing pipeline shared by training and evaluation."""
        return Compose([
            LoadImaged(keys="image"),                        # read the image file
            AsChannelFirstd(keys="image"),                   # move channels to the first axis
            Resized(keys="image", spatial_size=(128, 128)),  # fixed input size
            ScaleIntensityd(keys="image"),                   # rescale intensities
            ToTensord(keys=["image", "label"]),
            SelectItemsd(keys=["image_id", "image", "label"])  # drop any other keys
        ])

    def prepare_data(self):
        """Build the sample lists and the transform pipelines.

        NOTE(review): Lightning's documentation recommends assigning state in
        ``setup`` rather than ``prepare_data`` for multi-process training; this
        notebook trains on a single GPU, so the existing hook is kept.
        """
        self.data_list = {
            "training": self.prepare_data_list(
                "train.csv",
                self.data_root,
                "train"
            ),
            "validation": self.prepare_data_list(
                "val.csv",
                self.data_root,
                "dev"
            ),
            # NOTE(review): test images are read from the "dev" folder, same as
            # validation -- confirm this matches the dataset layout.
            "test": self.prepare_data_list(
                "test.csv",
                self.data_root,
                "dev"
            )
        }

        # Training and evaluation currently use the same (augmentation-free)
        # pipeline; separate instances are kept so they can diverge later.
        self.train_transform = self._default_transform()
        self.eval_transform  = self._default_transform()

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        dataset = Dataset(
            self.data_list.get("training", []),
            transform=self.train_transform
        )

        # Without shuffling every epoch would see the samples in CSV order,
        # which hurts SGD-style training (especially if rows are sorted by label).
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        dataset = Dataset(
            self.data_list.get("validation", []),
            transform=self.eval_transform
        )

        return DataLoader(dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        dataset = Dataset(
            self.data_list.get("test", []),
            transform=self.eval_transform
        )

        return DataLoader(dataset, batch_size=self.batch_size)

# Model Construction

In [None]:
class Net(LightningModule):
    """Small CNN classifier: one strided convolution followed by a linear head.

    The linear layer is sized for 3-channel 128x128 inputs (the stride-2
    convolution yields 16 feature maps of 64x64) and produces 3 class scores.
    """

    def __init__(self, lr):
        """Create the layers and the loss.

        Args:
            lr: Learning rate for the Adam optimizer; stored in ``self.hparams``
                by ``save_hyperparameters`` so checkpoints can restore it.
        """
        super(Net, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, padding=1, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(16)
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.flatten = nn.Flatten()
        self.layer = nn.Linear(16*64*64, 3)

        self.save_hyperparameters()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw class scores (logits) with 3 values per sample.

        No sigmoid/softmax is applied here: ``nn.CrossEntropyLoss`` expects
        unnormalized logits and applies log-softmax itself. Squashing the
        scores into (0, 1) first bounds the achievable loss and weakens the
        gradients. Apply ``torch.softmax(logits, dim=1)`` to obtain
        probabilities.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.layer(x)

        return x

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)

        return optimizer

    def compute_loss(self, x, y):
        """Cross-entropy between the logits ``x`` and the class indices ``y``."""
        loss = self.criterion(x, y)

        return loss

    def compute_acc(self, x, y):
        """Fraction of samples in the batch whose arg-max class equals ``y``."""
        pred = torch.argmax(x, dim=1)
        acc = torch.sum(pred == y) / x.shape[0]

        return acc

    def evaluate(self, batch):
        """Run the model on one batch and return its loss and accuracy.

        Args:
            batch: Mapping with the keys ``"image"`` and ``"label"``.

        Returns:
            Dict with the keys ``'loss'`` and ``'acc'``.
        """
        x, y = batch["image"], batch["label"]
        output = self(x)
        loss = self.compute_loss(output, y)
        acc = self.compute_acc(output, y)

        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        """Lightning hook: the returned ``'loss'`` entry is back-propagated."""
        output = self.evaluate(batch)

        return output

    def validation_step(self, batch, batch_idx):
        """Lightning hook: compute loss and accuracy for one validation batch."""
        output = self.evaluate(batch)

        return output

    def epoch_end(self, outputs, prefix):
        """Log the epoch's mean loss and accuracy plus the learning rate.

        Args:
            outputs: List of the dicts returned by the step hooks.
            prefix: Metric namespace, e.g. ``"training"`` or ``"validation"``.
        """
        meanloss = torch.mean(torch.stack([o['loss'] for o in outputs]))
        meanacc = torch.mean(torch.stack([o['acc'] for o in outputs]))

        # Logging 'step' overrides the x-axis value with the epoch index
        # (presumably so the curves are plotted per epoch). One call per hook
        # is enough; the value is identical for every metric logged below.
        self.log('step', self.trainer.current_epoch)
        self.log(f'{prefix}/loss', meanloss, prog_bar=prefix != "training")
        self.log(f'{prefix}/accuracy', meanacc, prog_bar=True)
        self.log(f'{prefix}/learning rate', self.hparams.lr, prog_bar=False)

        return None

    def training_epoch_end(self, outputs):
        """Lightning hook: aggregate and log the training metrics."""
        self.epoch_end(outputs, "training")

        return None

    def validation_epoch_end(self, outputs):
        """Lightning hook: aggregate and log the validation metrics."""
        self.epoch_end(outputs, "validation")

        return None

# Set Up Model, Data, Loggers, and Trainer

In [None]:
# Resume from an existing checkpoint when the path exists; otherwise start a
# fresh model with the default learning rate.
ckpt = "logs/mango/..."
if os.path.exists(ckpt):
    net = Net.load_from_checkpoint(ckpt)
else:
    net = Net(lr=5e-4)

# Dataset copied to the Colab runtime by the setup cell; batches of 256 images.
datamodule = MangoDataModule(data_root="/content/mango_1st_stage", batch_size=256)

In [None]:
# Minute-resolution timestamp that names this run's log/checkpoint folder.
cur_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
os.makedirs(f"logs/mango/{cur_time}", exist_ok=True)

# Both loggers write under logs/mango/<cur_time>.
tensor_board_logger = TensorBoardLogger(
    save_dir="logs",
    name="mango",
    version=cur_time,
    default_hp_metric=False,
)
csv_logger = CSVLogger(
    save_dir="logs",
    name="mango",
    version=cur_time,
)

In [None]:
# Keep the five checkpoints with the highest validation accuracy, plus the
# most recent one, in this run's log folder.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"logs/mango/{cur_time}",
    filename="{epoch:0>2}",
    monitor="validation/accuracy",
    mode="max",
    save_top_k=5,
    save_last=True,
)

In [None]:
# Single-GPU trainer that logs to TensorBoard and CSV and checkpoints on
# validation accuracy.
trainer = Trainer(
    # hardware / schedule
    gpus=1,
    max_epochs=200,
    # limit_train_batches=50,
    num_sanity_val_steps=0,
    # logging / checkpointing
    logger=[tensor_board_logger, csv_logger],
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
    # console output
    weights_summary='full',
    progress_bar_refresh_rate=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


# Training

In [None]:
# Run training and per-epoch validation with the loaders supplied by the data module.
trainer.fit(model=net, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | conv      | Conv2d           | 448   
1 | bn        | BatchNorm2d      | 32    
2 | prelu     | PReLU            | 1     
3 | dropout   | Dropout          | 0     
4 | flatten   | Flatten          | 0     
5 | layer     | Linear           | 196 K 
6 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
197 K     Trainable params
0         Non-trainable params
197 K     Total params
0.788     Total estimated model params size (MB)


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Tensorboard

In [None]:
# (Re)load the TensorBoard notebook extension.
%reload_ext tensorboard
# Open TensorBoard inline, reading every run stored under logs/.
%tensorboard --logdir logs/