In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Will use pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt
import numpy as np
import os

from networks import torNet, FashionCNN, efficientNet
from utils import get_data

In [None]:
def matplotlib_imshow(img, one_channel=False):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: The image tensor to plot. Dim 0 is averaged away when
            ``one_channel`` is set and moved last otherwise, so the tensor is
            presumably laid out as (channels, height, width) -- TODO confirm
            against the dataset transforms.
        one_channel: Optional, Bool. When True the image is collapsed to a
            single channel (mean over dim 0) and drawn with the "Greys"
            colormap; when False it is drawn as a multi-channel image.

    Returns:
        None. The image is drawn through ``plt.imshow`` as a side effect.
    """

    # Detach first: ``Tensor.numpy()`` refuses tensors that still require grad,
    # so this keeps the helper usable outside of a ``no_grad`` context.
    img = img.detach()

    if one_channel:
        img = img.mean(dim=0)
    # unnormalize -- assumes the inputs were normalized with mean 0.5 / std 0.5
    # (TODO confirm against the transforms used by get_data)
    img = img / 2 + 0.5
    npimg = img.cpu().numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # Move dim 0 to the end, the layout plt.imshow expects for colour images
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

## Lightning model

In [None]:
class Net(pl.LightningModule):
    """Pytorch Lightning Module for defining and training the model

    This is a Pytorch Lightning class which helps us to rewrite the pytorch code
    without the boilerplate code: it bundles the forward pass, the training and
    validation steps, the optimizer and the dataloaders around a wrapped model.
    """

    def __init__(self, model, val_image_count=6, best_model_path='./../Models/best_model.pth',
                 batch_size=32, num_workers=4):
        """Init the model

        Args:
            model: Model to train
            val_image_count: No of images to log during validation
            best_model_path: File the best ``state_dict`` of ``model`` is saved to
            batch_size: Batch size used by both dataloaders
            num_workers: No of worker processes used by both dataloaders
        """

        super().__init__()
        self.model = model

        self.val_image_count = val_image_count
        self.best_model_path = best_model_path
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Class names; filled in by val_dataloader(). Until then predictions
        # are labelled with their numeric class id.
        self.classes = None

        # Saving the best model
        # NOTE: despite its name this holds the HIGHEST validation accuracy seen
        # so far; the attribute name is kept so existing readers keep working.
        self.lowest_val_acc = 0
        self.best_epoch = None

    def forward(self, x):
        """Forward method for the net

        Args:
            x: Any tensor to input in the model

        Returns:
            Whatever the wrapped model returns for ``x``
        """

        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Training step

        Args:
            batch: Batches of input data from train_dataloader
            batch_idx: Index of the Batch or the step count

        Returns:
            A dict containing loss and logs for tensorboard
        """

        x, y = batch
        yt = self(x)
        loss = F.cross_entropy(yt, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Validation step

        Args:
            batch: Batches of input data from val_dataloader
            batch_idx: Index of the Batch or the step count

        Returns:
            A dict containing validation loss, validation accuracy and the
            number of samples in the batch (used to weight the epoch average)
        """

        x, y = batch

        yt = self(x)
        loss = F.cross_entropy(yt, y)

        preds = torch.argmax(yt, dim=1)
        batch_size = len(y)
        val_acc = torch.sum(y == preds).item() / float(batch_size)

        # Only the first batch of each validation run is rendered, and only
        # when a logger is attached to the module.
        if batch_idx == 0 and self.logger is not None:
            count = self.val_image_count
            # Probability of the predicted class == largest softmax entry per row
            probs = F.softmax(yt[:count], dim=1).max(dim=1)[0].tolist()
            figure = self.plot_classes_preds(preds[:count], probs, x[:count], y[:count])
            self.logger.experiment.add_figure('predictions vs. actuals',
                                              figure,
                                              global_step=self.current_epoch)

        output = {
            'val_loss': loss,
            'val_acc': torch.tensor(val_acc),
            'val_batch_size': batch_size,
        }

        return output

    def validation_epoch_end(self, outputs):
        """After all the Validation steps are completed

        Args:
            outputs: All the outputs of validation steps
        Returns:
            results: A dict containing progress bar update and tensorboard logs
                (empty when there were no validation outputs)
        """

        if not outputs:
            # Nothing was validated, so there is no accuracy to average.
            return {}

        # Weight every batch by its size so a short final batch does not skew
        # the epoch accuracy. Outputs without a size count as one sample each,
        # which reduces to the plain mean over batches.
        weighted_acc = 0.0
        sample_count = 0
        for output in outputs:
            n_samples = output.get('val_batch_size', 1)
            weighted_acc += output['val_acc'].item() * n_samples
            sample_count += n_samples

        val_acc_mean = torch.tensor(weighted_acc / sample_count)
        val_acc_value = val_acc_mean.item()
        tqdm_dict = {'val_acc': val_acc_value}

        # Saving_best_epoch
        if val_acc_value > self.lowest_val_acc:
            self.lowest_val_acc = val_acc_value
            self.best_epoch = self.current_epoch
            # Create the target directory on demand so the save cannot fail
            # just because the folder is missing.
            save_dir = os.path.dirname(self.best_model_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model.state_dict(), self.best_model_path)

        logs = {'val_acc': val_acc_mean, 'epoch': self.current_epoch}
        # best_epoch stays None until some epoch beats the initial accuracy of
        # 0; it is only logged once it has a value.
        if self.best_epoch is not None:
            logs['best_epoch'] = self.best_epoch

        results = {
            'progress_bar': tqdm_dict,
            'log': logs
        }
        return results

    def _class_name(self, class_id):
        """Return the display name for ``class_id`` (the id itself if names are unknown)"""

        if self.classes is None:
            return str(class_id)
        return self.classes[class_id]

    def plot_classes_preds(self, preds, probs, images, labels):
        """Plot the images in a matplotlib figure

        Args:
            preds: Predicted label ids of images
            probs: Prediction probabilities of the class
            images: Images to be plotted
            labels: Labels corresponding to preds

        Returns:
            fig: The matplotlib figure containing subplots of image and labels
        """

        image_count = images.shape[0]
        fig = plt.figure(figsize=(12, 12))
        for idx in range(image_count):
            # One row of subplots; the new axes become the current axes that
            # matplotlib_imshow draws on.
            ax = fig.add_subplot(1, image_count, idx+1, xticks=[], yticks=[])
            matplotlib_imshow(images[idx], one_channel=True)
            pred_id = int(preds[idx])
            label_id = int(labels[idx])
            ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
                self._class_name(pred_id),
                probs[idx] * 100.0,
                self._class_name(label_id)),
                        color=("green" if pred_id == label_id else "red"))
        return fig

    def configure_optimizers(self):
        """Configuring the optimizers to use for training"""

        # Alternative: torch.optim.Adam(self.parameters(), lr=0.001)
        return torch.optim.RMSprop(self.parameters(), lr=0.0001, weight_decay=1e-6)

    def train_dataloader(self):
        """Dataloader for training"""

        train = get_data(train=True)
        loader = DataLoader(train, batch_size=self.batch_size, num_workers=self.num_workers,
                            shuffle=True, pin_memory=True)
        return loader

    def val_dataloader(self):
        """Dataloader for validation

        Also stores the dataset's ``classes`` on the module so that logged
        figures can show class names.
        """

        val = get_data()
        self.classes = val.classes
        loader = DataLoader(val, batch_size=self.batch_size, num_workers=self.num_workers,
                            shuffle=False, pin_memory=True)
        return loader

## Training

In [None]:
# Model to train
model = Net(efficientNet())

# Checkpoint in case there is an interruption: keep only the single best
# checkpoint, ranked by the highest 'val_acc'.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), './lightning_logs/checkpoints'),
    save_top_k=1,  # an integer count; the previous `True` only worked because bool is an int
    verbose=True,
    monitor='val_acc',
    mode='max',
    prefix=''
)

# Trainer to train the model on all GPUs for 20 epochs.
# Falls back to the CPU when no CUDA device is visible, so the notebook still runs there.
gpus = -1 if torch.cuda.is_available() else None
trainer = Trainer(max_epochs=20, gpus=gpus, checkpoint_callback=checkpoint_callback)
# To resume an interrupted run, additionally pass e.g.
#   resume_from_checkpoint='./lightning_logs/Checkpoints_efficientNet/_ckpt_epoch_19.ckpt'

In [None]:
# Start training. Only the module is passed in: the dataloaders and the
# optimizer are defined as methods on Net (train_dataloader, val_dataloader,
# configure_optimizers).
trainer.fit(model)   # Fit the model