In [None]:
!pip install torchinfo -q

In [None]:
# Импорт библиотек
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
# import gc
import sys
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torchvision
import torchmetrics
from torchinfo import summary
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
import torchvision.transforms as T
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

from sklearn.model_selection import train_test_split
import cv2
from PIL import Image

import warnings
# NOTE(review): this silences *all* Python warnings, including library
# deprecation warnings; consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')
# Show full cell contents (e.g. long image file paths) when displaying DataFrames
pd.set_option("display.max_colwidth", None)
import logging
# Handle on the "lightning" logger from the standard `logging` module
logger = logging.getLogger("lightning")
# Fix the global random seed to 42 so runs are reproducible
pl.seed_everything(42)

In [None]:
# Environment check: list the Kaggle input folders and report package versions
print(os.listdir("../input"))
print('Python       :', sys.version.split('\n')[0])
print('Numpy        :', np.__version__)
print('Pandas       :', pd.__version__)
print('PyTorch      :', torch.__version__)
print('Lightning    :', pl.__version__)

# Report whether PyTorch can see a CUDA device. This is informational only:
# the Trainer chooses the accelerator itself (accelerator="auto").
train_on_gpu = torch.cuda.is_available()
print('CUDA is available!  Training on GPU ...' if train_on_gpu else 'CUDA is not available.  Training on CPU ...')

In [None]:
# Task hyperparameters
BATCH_SIZE = 16   # images per mini-batch (train, validation and test loaders)
LR = 1e-4         # learning rate handed to the model's AdamW optimizer
VAL_SPLIT = 0.15  # fraction of labelled images held out for validation

In [None]:
# Extract the dataset archive into the current working directory
# (assumed to be /kaggle/working, which the next command relies on).
!tar zxf '../input/mds-misis-dl-flower-photos-classificationn/flower_photos.tgz'
# Remove the licence file: every file under flower_photos/ is later collected
# and labelled by its parent folder, so it would otherwise become a bogus sample.
!rm /kaggle/working/flower_photos/LICENSE.txt
# Copy the submission template next to the notebook
!cp ../input/mds-misis-dl-flower-photos-classificationn/sample_submission.csv ./

In [None]:
# Source-data locations:
#   DATA_PATH - root of the extracted archive; scanned recursively for images
#   PATH      - NOTE(review): not referenced by any cell in this notebook
DATA_PATH, PATH = 'flower_photos/', "../working/flowers/"

In [None]:
def full_path_files(dir: str) -> list:
    """Recursively list every file below ``dir``.

    Directory entries are visited in sorted (lexicographic) order, so the
    result, and therefore any downstream seeded train/validation split, does
    not depend on the filesystem's arbitrary ``os.listdir`` order.

    Args:
        dir: root directory to scan. The name shadows the builtin ``dir`` but
            is kept for backward compatibility with keyword callers.

    Returns:
        List of file paths, each built with ``os.path.join`` starting from ``dir``.
        A sub-directory's files appear at the position of that sub-directory.
    """
    files_list: list = []
    for name_path in sorted(os.listdir(dir)):
        full_name = os.path.join(dir, name_path)
        if os.path.isdir(full_name):
            # extend in place instead of `files_list + ...`, which re-copied
            # the whole accumulated list for every sub-directory
            files_list.extend(full_path_files(full_name))
        else:
            files_list.append(full_name)
    return files_list

def collect_data(names: list) -> pd.DataFrame:
    """Pair each file path with a class name derived from its parent folder.

    The second-to-last '/'-separated path component is upper-cased, and one
    trailing "S" is stripped, e.g. ``'flower_photos/roses/1.jpg'`` -> ``'ROSE'``.

    Args:
        names: file paths, '/'-separated.

    Returns:
        DataFrame with columns ``'Id'`` (the input paths, in order) and
        ``'Category'`` (the derived class names).
    """
    def _singular(folder: str) -> str:
        upper = folder.upper()
        return upper[:-1] if upper[-1] == "S" else upper

    categories = [_singular(path.split('/')[-2]) for path in names]
    return pd.DataFrame({'Id': names, 'Category': categories})

In [None]:
# Recursively collect every extracted file and label it by its parent folder name
train_df = collect_data(full_path_files(DATA_PATH))
train_df.head()

In [None]:
# Distribution of labelled images across classes
sns.set_style("darkgrid")
class_counts = train_df['Category'].value_counts()
# One palette colour per class: a fixed n_colors=3 made bars beyond the
# third reuse colours when there are more than three classes.
sns.set_palette('tab10', n_colors=len(class_counts))
fig, ax = plt.subplots()
sns.barplot(x=class_counts.index, y=class_counts.values, ax=ax)
ax.set(title='Images per class', xlabel='Category', ylabel='Number of images');

In [None]:
# Encode the class names in 'Category' as integer indices (alphabetical order).
# NOTE: not idempotent: re-running this cell re-encodes the already-integer
# labels and turns class_index_to_label into an identity map.
classes = sorted(train_df['Category'].unique())
class_label_to_index: dict = dict(zip(classes, range(len(classes))))
class_index_to_label: dict = {index: label for label, index in class_label_to_index.items()}
train_df['Category'] = train_df['Category'].map(class_label_to_index)
train_df.head()

In [None]:
# Show both lookup tables: class name -> index and index -> class name
class_label_to_index, class_index_to_label

In [None]:
# Hold out VAL_SPLIT of the labelled images for validation.
# stratify keeps per-class proportions equal in both parts; random_state makes
# the split reproducible without relying on the global NumPy seed.
# NOTE: this rebinds train_df to the training part, so re-running the cell
# would split the already-reduced frame again; re-run from the labelling cells.
train_df, val_df = train_test_split(
    train_df,
    test_size=VAL_SPLIT,
    stratify=train_df['Category'],
    random_state=42,
)
# Unlabelled test images, listed in the submission template
test_df = pd.read_csv('./sample_submission.csv')

In [None]:
# Sizes of the training, validation and test frames
train_df.shape,val_df.shape,test_df.shape

In [None]:
# Preview the training part of the split
train_df.head()

In [None]:
# Dataset over a DataFrame of image paths and (for train/validation) labels
class FlowersDataset(Dataset):
    """Map-style dataset of flower images read from disk.

    Args:
        df: DataFrame with an ``'Id'`` column of image file paths and, unless
            ``is_test`` is set, a ``'Category'`` column of integer labels.
        augments: callable applied to the PIL image (e.g. a ``T.Compose``).
        is_test: when True, ``__getitem__`` returns only the image.
    """

    def __init__(self, df, augments, is_test = False):
        super().__init__()
        # drop=True: positional 0..n-1 index for the .loc lookups below,
        # without keeping the old index as an extra 'index' column.
        self.df      = df.reset_index(drop=True)
        self.augs    = augments
        self.is_test = is_test

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``image`` (test) or ``(image, label)`` for row ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_filepath = self.df.loc[index, 'Id']
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            raise FileNotFoundError(f"Could not read image: {image_filepath}")
        # OpenCV decodes to BGR channel order; PIL and the ImageNet
        # normalisation applied in the transforms expect RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        image = self.augs(image)

        if self.is_test:
            return image
        label = self.df.loc[index, 'Category']
        return image, label
    
    
class FlowerDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train / validation / test DataFrames.

    Args:
        train_df, val_df: frames with 'Id' (image path) and integer 'Category' columns.
        test_df: frame with an 'Id' column; labels are not read.
        batch_size: images per batch for all three loaders.
        input_dims: images are resized to (input_dims, input_dims).
        num_workers: DataLoader worker processes; 0 (default) loads in the
            main process.
    """

    def __init__(self, train_df, val_df, test_df, batch_size:int = 8, input_dims:int = 224, num_workers:int = 0):
        super().__init__()
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df

        self.batch_size  = batch_size
        self.num_workers = num_workers

        # ImageNet channel means and stds, matching the pretrained backbone
        mean = [0.485, 0.456, 0.406]
        std  = [0.229, 0.224, 0.225]

        # Training pipeline: resize, then random flips / affine jitter as augmentation
        self.train_augs = T.Compose([
            T.Resize(size=(input_dims, input_dims)),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

        # Validation and test share one deterministic pipeline (no augmentation)
        eval_augs = T.Compose([
            T.Resize(size=(input_dims, input_dims)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.valid_augs = eval_augs
        self.test_augs  = eval_augs

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when None)."""
        if stage in ('fit', None):
            self.flowers_train = FlowersDataset(self.train_df, self.train_augs)
        # 'validate' (trainer.validate) needs the validation set as well
        if stage in ('fit', 'validate', None):
            self.flowers_valid = FlowersDataset(self.val_df, self.valid_augs)
        if stage in ('test', None):
            self.flowers_test = FlowersDataset(self.test_df, self.test_augs, is_test = True)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """DataLoader with this module's shared batch size and worker settings."""
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.batch_size,
                          pin_memory=True, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.flowers_train, shuffle=True)

    def val_dataloader(self):
        # No shuffling: validation metrics do not benefit from random order
        return self._make_loader(self.flowers_valid, shuffle=False)

    def test_dataloader(self):
        # Keep test_df row order so predictions line up with the submission rows
        return self._make_loader(self.flowers_test, shuffle=False)

In [None]:
class FlowerClassifier(pl.LightningModule):
    """Pretrained ResNet-34 followed by a small MLP head for flower classification.

    Args:
        output_dims: number of classes (size of the final linear layer).
        learning_rate: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    After ``trainer.test(...)``, ``self.results`` is a DataFrame with one
    ``'Category'`` column of predicted class indices, in test-loader order.
    """

    def __init__(self, output_dims: int, learning_rate:float, weight_decay:float):
        super().__init__()
        # Makes the constructor arguments available as self.hparams.*
        self.save_hyperparameters()

        # Pretrained ResNet-34. Its own fc layer is kept, so the head below
        # consumes fc.out_features values per image.
        self.classifier  = torchvision.models.resnet34(pretrained=True, progress=True)
        base_output_dims = self.classifier.fc.out_features

        self.lin1   = nn.Sequential(nn.BatchNorm1d(base_output_dims),  nn.Dropout(0.2), nn.ReLU(inplace=True))
        self.lin2   = nn.Sequential(nn.Linear(base_output_dims, 1024), nn.BatchNorm1d(1024), nn.Dropout(0.5), nn.ReLU())
        self.lin3   = nn.Sequential(nn.Linear(1024, 512),  nn.BatchNorm1d(512),  nn.Dropout(0.5), nn.ReLU())
        self.output = nn.Sequential(nn.Linear(512, self.hparams.output_dims))

        self.accuracy = torchmetrics.Accuracy()

        # Test-time state: predictions gathered in test_step, published in test_epoch_end
        self.results    = pd.DataFrame()
        self.test_preds = []

    def forward(self, x):
        """Return unnormalised class logits for a batch of images."""
        out = self.classifier(x)
        out = self.lin3(self.lin2(self.lin1(out)))
        return self.output(out)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        image, clas = batch
        y_hat       = self(image)
        loss        = F.cross_entropy(y_hat, clas)
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        image, clas = batch
        logits = self(image)
        loss   = F.cross_entropy(logits, clas)
        # "val_loss" is monitored by the LR scheduler (see configure_optimizers)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        metric = self.accuracy(logits, clas)
        self.log("accuracy", metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def on_test_epoch_start(self):
        # Start from an empty list so repeated trainer.test(...) calls do not
        # accumulate predictions from earlier runs.
        self.test_preds = []

    def test_step(self, batch, batch_idx, *args, **kwargs):
        # The test dataset yields images only (no labels)
        images = batch
        logits = self(images)
        # Predicted class index = position of the largest logit in each row
        self.test_preds.extend(logits.argmax(dim=1).tolist())
        # NOTE(review): with multi-device runs each process only sees its shard
        # of the test set, so test_preds would be partial; confirm single-device.

    def test_epoch_end(self, *args, **kwargs):
        # Rebuild rather than assign into the old frame, whose length may differ
        self.results = pd.DataFrame({'Category': self.test_preds})

    def configure_optimizers(self, *args, **kwargs):
        opt = optim.AdamW(self.parameters(), lr = self.hparams.learning_rate, weight_decay = self.hparams.weight_decay)
        # Lower the LR when the monitored val_loss stops decreasing
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min")

        pl_scheduler = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "frequency": 1,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
            # fail loudly if "val_loss" is not logged
            "strict": True
        }

        return [opt], [pl_scheduler]

In [None]:
# Scratch check (NOTE(review): leftover exploration; candidate for removal).
# Demonstrates the per-row argmax reduction used to turn logits into
# predicted class indices: torch.max(..., 1).indices as a Python list.
a = torch.randn(4, 4)
print(a)
torch.max(a, 1).indices.tolist()

In [None]:
# CSV metric logger. It is named csv_logger so it does not overwrite `logger`,
# the `logging` logger created in the import cell.
csv_logger = pl.loggers.CSVLogger(save_dir="/kaggle/working/", name="kaggle_misis_flowers", version="001")

# Rich progress bar with a custom colour theme
rich_progress_bar = RichProgressBar(
    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        progress_bar_pulse="#6206E0",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
    )
)

trainer_callbacks = [
    rich_progress_bar,
    # Stop after 20 validation checks without improvement in val_loss
    pl.callbacks.EarlyStopping(monitor="val_loss", patience=20),
    # Log the learning rate every step (ReduceLROnPlateau changes it)
    pl.callbacks.LearningRateMonitor("step"),
]

trainer = pl.Trainer(
    devices="auto",
    accelerator="auto",
    callbacks=trainer_callbacks,
    logger=csv_logger,
    max_epochs=40,
    gradient_clip_val=0.5,
)

In [None]:
# Wire the data pipeline (224x224 inputs) and the classifier together
dataModule = FlowerDataModule(
    train_df,
    val_df,
    test_df,
    batch_size=BATCH_SIZE,
    input_dims=224,
)
model = FlowerClassifier(
    len(class_label_to_index),  # one output per class
    learning_rate=LR,
    weight_decay=0.01,
)

In [None]:
# Layer-by-layer summary for a dummy batch of BATCH_SIZE 3x224x224 images
# (224 matches input_dims used by the data module).
# NOTE(review): torchinfo runs a forward pass on random data; depending on the
# torchinfo version's train/eval mode default this may update BatchNorm running
# statistics; confirm before relying on it right before training.
summary(model, (BATCH_SIZE, 3, 224, 224))

In [None]:
# Train, validating each epoch, on the data module's train/val loaders
trainer.fit(model, datamodule = dataModule)

In [None]:
# Predict on the unlabelled test loader; predictions are stored on the model.
# NOTE(review): `model` is passed directly, so these are its weights as left by
# fit (the last epoch), not necessarily the best-val_loss epoch.
trainer.test(model, datamodule=dataModule)

In [None]:
# Preview the predicted class indices collected during testing
model.results.head()

In [None]:
# Build the submission file: predicted class indices -> class names
submission = test_df.copy()  # copy, so test_df itself is not modified
# model.results has one row per test image, in test-loader order (no shuffling),
# so its values are taken positionally; a length mismatch raises ValueError.
predicted_indices = model.results['Category'].tolist()
# Plain dict indexing raises KeyError on an unexpected index instead of writing NaN
submission['Category'] = [class_index_to_label[i] for i in predicted_indices]
submission.to_csv('submission.csv', index=False)