In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import glob
import cv2
import json
from collections import Counter
import pickle
import numpy as np

from efficientnet_pytorch import EfficientNet

import torch.optim as optim
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import datasets, transforms
from torchdiffeq import odeint

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import albumentations as A
from albumentations.core.composition import Compose
from albumentations.pytorch import ToTensorV2

from sklearn import metrics, model_selection,preprocessing
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix


import optuna
from optuna.integration import PyTorchLightningPruningCallback
os.environ["TORCH_HOME"] = "/media/hdd/Datasets/"

In [2]:
import torchsnooper as sn

# Look at data 
- Create a csv for easy loading

In [3]:
main_path = "/media/hdd/Datasets/flowers/flowers/"

In [4]:
all_ims = glob.glob(main_path+"/*/*.jpg");all_ims[0]

'/media/hdd/Datasets/flowers/flowers/rose/8775267816_726ddc6d92_n.jpg'

In [5]:
len(all_ims)

4323

In [6]:
def create_label(x):
    return x.split("/")[-2]

In [7]:
df = pd.DataFrame.from_dict({x:create_label(x) for x in all_ims} ,orient='index').reset_index()
df.columns = ["image_id","label"]

In [8]:
df.head()

Unnamed: 0,image_id,label
0,/media/hdd/Datasets/flowers/flowers/rose/87752...,rose
1,/media/hdd/Datasets/flowers/flowers/rose/26071...,rose
2,/media/hdd/Datasets/flowers/flowers/rose/56022...,rose
3,/media/hdd/Datasets/flowers/flowers/rose/89266...,rose
4,/media/hdd/Datasets/flowers/flowers/rose/88530...,rose


In [9]:
df_b = df

In [10]:
df.label.unique()

array(['rose', 'dandelion', 'tulip', 'sunflower', 'daisy'], dtype=object)

In [11]:
df.label.nunique()

5

In [12]:
temp = preprocessing.LabelEncoder()
df['label'] = temp.fit_transform(df.label.values)

In [13]:
label_map=  {i: l for i, l in enumerate(temp.classes_)}

In [14]:
df.label.nunique()

5

In [15]:
df.label.value_counts()

1    1052
4     984
2     784
0     769
3     734
Name: label, dtype: int64

In [16]:
df["kfold"] = -1
df = df.sample(frac=1).reset_index(drop=True)
stratify = StratifiedKFold(n_splits=5)
for i, (t_idx, v_idx) in enumerate(
        stratify.split(X=df.image_id.values, y=df.label.values)):
    df.loc[v_idx, "kfold"] = i
    df.to_csv("train_folds_fruits.csv", index=False)

In [17]:
pd.read_csv("train_folds_fruits.csv").head(1)

Unnamed: 0,image_id,label,kfold
0,/media/hdd/Datasets/flowers/flowers/dandelion/...,1,0


# Architecture

In [18]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

In [19]:
def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [20]:
def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)

In [21]:
class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut

In [22]:
class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

In [23]:
class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

In [24]:
class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, oui):
        self.integration_time = self.integration_time.type_as(oui)
        out = odeint(self.odefunc, oui, self.integration_time, rtol=1e-3, atol=1e-3)
        return torch.cuda.FloatTensor(oui)

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

In [25]:
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

# Create model

In [26]:

class LitModel(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=1e-4, weight_decay=0.0001):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.num_classes = num_classes

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.enet = EfficientNet.from_pretrained('efficientnet-b5',num_classes=self.num_classes)
        self.enet._bn1 = nn.Sequential(
                ODEBlock(ODEfunc(2048)),
                nn.BatchNorm2d(2048, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True),
        )
        in_features = self.enet._fc.in_features
        self.enet._fc = nn.Linear(in_features, num_classes)

#     @sn.snoop()
    def forward(self, x):
        out = self.enet(torch.cuda.FloatTensor(x))
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=2,
                                                    gamma=0.1)

        return ([optimizer], [scheduler])



    def training_step(self, train_batch, batch_idx):
        x, y = train_batch["x"], train_batch["y"]
        x =x.type(torch.cuda.FloatTensor)
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        #         loss.requires_grad = True
        acc = accuracy(preds, y)
        self.log('train_acc_step', acc)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch["x"], val_batch["y"]
        x =x.type(torch.cuda.FloatTensor)
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        #         loss.requires_grad = True
        acc = accuracy(preds, y)
        self.log('val_acc_step', acc)
        self.log('val_loss', loss)
#         print(preds.detach().cpu(), y.detach().cpu())
        y2 = torch.argmax(preds, dim = 1)
        conf = confusion_matrix(y2.detach().cpu(), y.detach().cpu())
        fig = plt.figure()
        plt.imshow(conf)
#         plt.show()
        self.logger.experiment.add_figure("Confusion matrix", fig, self.current_epoch)

In [27]:
class ImageClassDs(Dataset):
    def __init__(self,
                 df: pd.DataFrame,
                 imfolder: str,
                 train: bool = True,
                 transforms=None):
        self.df = df
        self.imfolder = imfolder
        self.train = train
        self.transforms = transforms

    def __getitem__(self, index):
        im_path = self.df.iloc[index]['image_id']
        x = cv2.imread(im_path, cv2.IMREAD_COLOR)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)

        if (self.transforms):
            x = self.transforms(image=x)['image']

        y = self.df.iloc[index]['label']
        return {
            "x": x,
            "y": y,
        }

    def __len__(self):
        return len(self.df)

# Load data

In [28]:
class ImDataModule(pl.LightningDataModule):
    def __init__(
            self,
            df,
            batch_size,
            num_classes,
            data_dir: str = "/media/hdd/Datasets/flowers/flowers/",
            img_size=(256, 256)):
        super().__init__()
        self.df = df
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = A.Compose([
            A.RandomResizedCrop(img_size, img_size, p=1.0),
            ToTensorV2(p=1.0),
        ],
            p=1.)

        self.valid_transform = A.Compose([
            A.CenterCrop(img_size, img_size, p=1.),
            A.Resize(img_size, img_size),
            A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        max_pixel_value=255.0,
                        p=1.0),
            ToTensorV2(p=1.0),
        ],
            p=1.)

    def setup(self, stage=None):
        dfx = pd.read_csv("./train_folds.csv")
        train = dfx.loc[dfx["kfold"] != 1]
        val = dfx.loc[dfx["kfold"] == 1]

        self.train_dataset = ImageClassDs(train,
                                          self.data_dir,
                                          train=True,
                                          transforms=self.train_transform)

        self.valid_dataset = ImageClassDs(val,
                                          self.data_dir,
                                          train=False,
                                          transforms=self.valid_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=12,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset,
                          batch_size=self.batch_size,
                          num_workers=12)

In [29]:
batch_size = 128
num_classes = 131
img_size = 64

In [30]:
dm = ImDataModule(df,
                  batch_size=batch_size,
                  num_classes=num_classes,
                  img_size=img_size)
class_ids = dm.setup()

# Final model

In [31]:
model = LitModel(num_classes);

Loaded pretrained weights for efficientnet-b5


In [32]:
logger = TensorBoardLogger("logs/", name = "flowers")

In [33]:
trainer = pl.Trainer(auto_select_gpus=True,
                     gpus=1,
#                      precision=16,
                     profiler=False,
                     max_epochs=3,
                     callbacks=[pl.callbacks.ProgressBar()],
                     logger=logger)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [34]:
trainer.fit(model, dm)


  | Name | Type         | Params
--------------------------------------
0 | enet | EfficientNet | 104 M 
--------------------------------------
104 M     Trainable params
0         Non-trainable params
104 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [35]:
trainer.test()

1

In [36]:
trainer.save_checkpoint('model1_flowers.ckpt')