In [None]:
import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
from os.path import join
from run import read_csv, save_csv
import random
import glob
import albumentations.pytorch.transforms
import albumentations as A
from PIL import Image
import os
from torchvision import transforms
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter
import torchvision
from torch import nn
import torchmetrics
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import torchmetrics
import matplotlib.pyplot as plt 
from os.path import basename


def get_device():
    """Pick the compute device and announce the choice on stdout.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    has_gpu = torch.cuda.is_available()
    if not has_gpu:
        print("Using the CPU 😞")
        return torch.device("cpu")

    print("Using the GPU 😊")
    return torch.device("cuda")
    

# --- Training / data configuration -------------------------------------------
# Size handed to A.Resize in the shared preprocessing pipeline below.
NETWORK_SIZE = (480, 480)
# Default batch size for ImgDataModule's dataloaders.
BATCH_SIZE = 64
# Default DataLoader worker count (0 keeps loading in the main process).
NUM_WORKERS = 0
# Number of target classes.
NUM_CLASSES = 50
# Default learning rate for LightningBirdClassifier.
BASE_LR = 1e-3
# Default epoch budget for train_model / train_classifier.
MAX_EPOCHS = 50

# Local training data directory and the root directory passed to the trainer.
path_train = './tests/00_test_img_input/train/'
path_experiment = './experiment/'

# Per-channel mean / std passed to A.Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Random augmentations; only the fit pipeline (MyFitTransform) includes them.
augmentations = [
    A.Rotate(limit=45, p=0.5),
    #A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.25),
    A.GaussianBlur(p=0.3),
    A.HorizontalFlip(p=0.5),
    #A.GaussNoise(var_limit=(5.0, 20.0), p=0.3),
    #A.ElasticTransform(alpha=3, sigma=2, p=0.5),
    A.Perspective(p=0.25)
    #A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=0.3),
    #A.CoarseDropout(max_holes=8, max_height=8, max_width=8, p=0.5),
]

# Deterministic preprocessing shared by the fit and predict pipelines:
# resize, ToFloat with max_value=255, normalise with the statistics above
# (max_pixel_value=1.0 because ToFloat has already run), convert to a tensor.
common_transforms = [
    A.Resize(*NETWORK_SIZE),
    A.ToFloat(max_value=255),
    A.Normalize(max_pixel_value=1.0, mean=IMAGENET_MEAN, std=IMAGENET_STD),
    A.pytorch.transforms.ToTensorV2(),
]

# ImgDataset uses MyFitTransform for the 'train'/'fit' stages and
# MyPredictTransform for 'validate', 'test' and 'predict'.
MyFitTransform = A.Compose(augmentations + common_transforms)
MyPredictTransform = A.Compose(common_transforms)


class ImgDataset(Dataset):
    """Dataset that loads images from disk and runs them through a transform.

    Args:
        img_dir: Directory the images come from. Stored but not read by this
            class, because the items in ``data`` are opened as given.
        data: For the labelled stages (``'fit'``, ``'train'``, ``'validate'``,
            ``'test'``) a sequence of ``(image_path, label)`` pairs; for
            ``'predict'`` a sequence of image paths.
        stage: One of ``'fit'``, ``'train'``, ``'validate'``, ``'test'`` or
            ``'predict'``.
        transform: Optional callable invoked as ``transform(image=ndarray)``
            that returns a mapping with an ``'image'`` entry. When omitted,
            ``MyFitTransform`` is used for ``'train'``/``'fit'`` and
            ``MyPredictTransform`` for every other supported stage.

    Raises:
        ValueError: If ``stage`` is not one of the supported values.
    """

    _FIT_STAGES = ('train', 'fit')
    _EVAL_STAGES = ('validate', 'test', 'predict')

    def __init__(self, img_dir, data, stage, transform=None):
        if stage not in self._FIT_STAGES + self._EVAL_STAGES:
            raise ValueError(f"Invalid stage: {stage!r}")

        self._image_dir = img_dir
        self._data = data
        self._stage = stage

        # Compare against None rather than relying on truthiness: a pipeline
        # object may define __len__ and evaluate as falsy. An explicitly
        # supplied transform is always honoured (it used to be dropped,
        # leaving ``_transform`` undefined).
        if transform is not None:
            self._transform = transform
        elif stage in self._FIT_STAGES:
            self._transform = MyFitTransform
        else:
            self._transform = MyPredictTransform

    def __len__(self):
        """Return the number of items in ``data``."""
        return len(self._data)

    def __getitem__(self, idx):
        """Load and transform item ``idx``.

        Returns:
            ``(image, label)`` for labelled stages, or just ``image`` for the
            ``'predict'`` stage, where ``image`` is the ``'image'`` entry of
            the transform's output.
        """
        labelled = self._stage != 'predict'
        if labelled:
            img_path, label = self._data[idx]
        else:
            img_path = self._data[idx]

        # Close the file handle deterministically once the pixels are decoded.
        with Image.open(img_path) as img:
            pixels = np.array(img.convert("RGB"))

        image = self._transform(image=pixels)['image']
        return (image, label) if labelled else image

class ImgDataModule(L.LightningDataModule):
    """Lightning data module that splits an image folder into train/val/test.

    The image directory is ``<data_dir>/images`` when ``data_dir`` is given,
    otherwise ``img_dir``. Labels come from ``<data_dir>/gt.csv`` when
    ``data_dir`` is given, otherwise from ``gt_dict``; they are looked up by
    image file name.

    Args:
        data_dir: Directory containing ``images/`` and ``gt.csv``.
        img_dir: Directory of images, used when ``data_dir`` is empty.
        gt_dict: Mapping from file name to label, used when ``data_dir`` is empty.
        batch_size: Batch size of every dataloader.
        num_workers: Worker count of every dataloader.
        split_seed: ``random_state`` for both stratified splits.
        train_share: Nominal share of the training subset. Stored only; the
            split is driven by ``valid_share`` and ``test_share``.
        valid_share: Share of all samples that goes to validation.
        test_share: Share of all samples that goes to testing. ``0`` yields an
            empty test subset.
        transform: Stored only; the dataloaders let ``ImgDataset`` choose its
            own pipeline from the stage.
        aug_transform: Stored only, same as ``transform``.
    """

    _LABELLED_STAGES = ('fit', 'train', 'validate', 'test')

    def __init__(
        self,
        data_dir = '',
        img_dir = '',
        gt_dict = None,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
        split_seed=42,
        train_share = 0.7,
        valid_share = 0.2,
        test_share = 0.1,
        transform=common_transforms,
        aug_transform=augmentations,
    ):
        super().__init__()
        self._data_dir = data_dir
        self._img_dir = img_dir
        self._gt_dict = gt_dict
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._split_seed = split_seed
        self._train_share = train_share
        self._valid_share = valid_share
        self._test_share = test_share
        self._transform = transform
        self._aug_transform = aug_transform

    def _resolve_img_dir(self):
        """Return the directory to glob for images, or raise if none is configured."""
        if self._data_dir != '':
            return join(self._data_dir, 'images')
        if self._img_dir != '':
            return self._img_dir
        raise ValueError("ImgDataModule needs either data_dir or img_dir")

    def _resolve_gt(self):
        """Return the file-name -> label mapping, or raise if none is configured."""
        if self._data_dir != '':
            return read_csv(join(self._data_dir, 'gt.csv'))
        if self._gt_dict is not None:
            return self._gt_dict
        raise ValueError("ImgDataModule needs either data_dir or gt_dict for labelled stages")

    def setup(self, stage):
        """Collect image paths and, for labelled stages, build the three splits.

        Args:
            stage: ``'fit'``, ``'train'``, ``'validate'`` or ``'test'`` to build
                the train/val/test subsets; ``'predict'`` to keep the plain
                sorted path list.

        Raises:
            RuntimeError: If ``stage`` is not one of the values above.
            ValueError: If no image directory (or, for labelled stages, no
                label source) has been configured.
        """
        if stage not in self._LABELLED_STAGES and stage != 'predict':
            raise RuntimeError(f"Invalid stage: {stage!r}")

        img_dir = self._resolve_img_dir()
        paths = sorted(glob.glob(f"{img_dir}/*"))

        if stage == 'predict':
            self._pred_data = paths
        else:
            gt_dict = self._resolve_gt()
            # basename copes with whatever separator glob returns on this OS.
            labels = [gt_dict[basename(path)] for path in paths]

            # First carve off validation + test together ...
            holdout_share = self._test_share + self._valid_share
            path_train, path_holdout, label_train, label_holdout = train_test_split(
                paths,
                labels,
                test_size=holdout_share,
                random_state=self._split_seed,
                stratify=labels,
            )

            if self._test_share > 0:
                # ... then split the holdout. test_size is relative to the
                # holdout, so the test fraction is test / (valid + test);
                # dividing by valid_share alone skews the split and exceeds
                # 1.0 whenever test_share > valid_share.
                path_val, path_test, label_val, label_test = train_test_split(
                    path_holdout,
                    label_holdout,
                    test_size=self._test_share / holdout_share,
                    random_state=self._split_seed,
                    stratify=label_holdout,
                )
            else:
                # No test subset requested: the whole holdout is validation.
                path_val, label_val = path_holdout, label_holdout
                path_test, label_test = [], []

            self._train_data = list(zip(path_train, label_train))
            self._val_data = list(zip(path_val, label_val))
            self._test_data = list(zip(path_test, label_test))

        self._img_dir = img_dir

    def _loader(self, data, stage, training=False):
        """Wrap ``data`` in an ``ImgDataset`` for ``stage`` and return its DataLoader."""
        ds = ImgDataset(img_dir=self._img_dir, data=data, stage=stage)
        # Only the training loader shuffles and drops the ragged last batch.
        return DataLoader(
            ds,
            batch_size=self._batch_size,
            shuffle=training,
            drop_last=training,
            num_workers=self._num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset (requires a prior ``setup``)."""
        return self._loader(self._train_data, 'train', training=True)

    def val_dataloader(self):
        """Ordered loader over the validation subset (requires a prior ``setup``)."""
        return self._loader(self._val_data, 'validate')

    def test_dataloader(self):
        """Ordered loader over the test subset (requires a prior ``setup``)."""
        return self._loader(self._test_data, 'test')

    def predict_dataloader(self):
        """Ordered loader over all images (requires ``setup('predict')``)."""
        return self._loader(self._pred_data, 'predict')
    
class LightningBirdClassifier(L.LightningModule):
    """EfficientNetV2-M image classifier with a frozen backbone and a small trainable head.

    Args:
        transfer: When false, build the network from ImageNet weights with a
            new classification head. When true, load a previously saved
            network from ``model_path``.
        lr: Learning rate for the Adam optimizer.
        model_path: File that ``save_model`` writes and ``load_model`` reads.
        **kwargs: Forwarded to ``LightningModule.__init__``.
    """

    def __init__(self, *, transfer=False, lr=BASE_LR, model_path='./birds_model.pt', **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.transfer = transfer
        self.num_classes = NUM_CLASSES
        # Must be set before get_model(): the transfer branch loads from it.
        self.model_path = model_path
        self.model = self.get_model()
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.classification.Accuracy(
            task="multiclass",
            num_classes=self.num_classes,
        )

    def get_model(self):
        """Return the network: freshly built (``transfer=False``) or loaded from disk."""
        if self.transfer:
            self.load_model()
            return self.model

        model = models.efficientnet_v2_m(weights="IMAGENET1K_V1")
        num_features = model.classifier[1].in_features
        # The head emits raw logits: nn.CrossEntropyLoss applies log-softmax
        # itself, so a trailing Softmax would squash the gradients.
        model.classifier[1] = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_classes),
        )

        # Freeze everything, then re-enable only the new head.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.classifier[1].parameters():
            param.requires_grad = True
        return model

    def forward(self, x):
        """Run the wrapped network; also what ``Trainer.predict`` ends up calling."""
        return self.model(x)

    def save_model(self):
        """Pickle the whole network object to ``model_path``."""
        torch.save(self.model, self.model_path)

    def load_model(self):
        """Load the pickled network from ``model_path`` and switch it to eval mode."""
        # SECURITY: this unpickles an arbitrary object graph; only load files
        # you produced yourself. weights_only=False is needed because the file
        # holds a full module rather than a state_dict.
        self.model = torch.load(
            self.model_path, map_location=get_device(), weights_only=False
        )
        self.model.eval()

    def configure_optimizers(self):
        """Adam plus a StepLR schedule that multiplies the LR by 0.1 every 10 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10, gamma=0.1
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "valid_loss",  # metric name handed to the scheduler config
                "interval": "epoch",      # step the scheduler once per epoch
                "frequency": 1            # ... on every epoch
            },
        }

    def training_step(self, batch):
        """Compute and log the training loss/accuracy for one batch."""
        return self._step(batch, "train")

    def validation_step(self, batch):
        """Compute and log the validation loss/accuracy for one batch."""
        return self._step(batch, "valid")

    def _step(self, batch, kind):
        """Shared forward pass: returns the loss and logs ``{kind}_loss`` / ``{kind}_acc``."""
        x, y = batch
        p = self.model(x)
        loss = self.loss_fn(p, y)
        acc = self.accuracy(p.argmax(dim=-1), y)

        return self._log_metrics(loss, acc, kind)

    def _log_metrics(self, loss, acc, kind):
        """Log whichever of ``loss`` / ``acc`` is not None under a ``kind`` prefix; return ``loss``."""
        metrics = {}
        if loss is not None:
            metrics[f"{kind}_loss"] = loss
        if acc is not None:
            metrics[f"{kind}_acc"] = acc
        self.log_dict(
            metrics,
            prog_bar=True,
            logger=True,
            on_step=(kind == "train"),  # per-step curves only for training
            on_epoch=True,
        )
        return loss


def train_model(
    experiment_path,
    data_module,
    model,
    max_epochs=MAX_EPOCHS,
    **trainer_kwargs,
):
    """Fit ``model`` on the train/validation loaders of ``data_module``.

    Args:
        experiment_path: Root directory handed to the trainer (``default_root_dir``).
        data_module: Object exposing ``setup``, ``train_dataloader`` and
            ``val_dataloader``.
        model: Lightning module to train.
        max_epochs: Upper bound on the number of epochs.
        **trainer_kwargs: Extra keyword arguments forwarded to ``L.Trainer``.
    """
    # Keep the three best checkpoints by validation accuracy, plus the last one.
    checkpointing = L.pytorch.callbacks.ModelCheckpoint(
        filename="{epoch}-{valid_acc:.3f}",
        monitor="valid_acc",
        mode="max",
        save_top_k=3,
        save_last=True,
    )
    progress_bar = L.pytorch.callbacks.TQDMProgressBar(leave=True)
    lr_monitor = L.pytorch.callbacks.LearningRateMonitor()

    trainer = L.Trainer(
        callbacks=[progress_bar, lr_monitor, checkpointing],
        max_epochs=max_epochs,
        default_root_dir=experiment_path,
        **trainer_kwargs,
    )

    # The loaders are passed explicitly, so the splits are prepared here.
    data_module.setup(stage='fit')
    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    
def train_classifier(train_gt, train_img_dir, fast_train=False, num_epochs=MAX_EPOCHS):
    """Train the bird classifier and return the Lightning module.

    Args:
        train_gt: Mapping from image file name to class label.
        train_img_dir: Directory containing the training images.
        fast_train: When true, continue from the saved model
            (``transfer=True``) on a small training share and do not save the
            result. When false, start from ImageNet weights, train on most of
            the data and save the network afterwards.
        num_epochs: Maximum number of epochs.

    Returns:
        The trained ``LightningBirdClassifier``.
    """
    if fast_train:
        shares = dict(train_share=0.2, valid_share=0.2, test_share=0.6)
        torch.set_float32_matmul_precision('medium')
    else:
        # test_share stays tiny; raise it if a test subset is needed.
        shares = dict(train_share=0.8, valid_share=0.199, test_share=0.001)

    data_module = ImgDataModule(
        img_dir=train_img_dir,
        gt_dict=train_gt,
        batch_size=BATCH_SIZE,
        **shares,
    )

    model = LightningBirdClassifier(
        transfer=fast_train,
        lr=BASE_LR,
    )

    on_cpu = get_device() == torch.device('cpu')
    train_model(
        # TODO: decide whether to drop this before submitting to the grading system.
        experiment_path=path_experiment,
        data_module=data_module,
        model=model,
        max_epochs=num_epochs,
        accelerator="cpu" if on_cpu else "gpu",
        devices=1,
        # The spelling Lightning asks for in place of the legacy precision=16.
        precision="16-mixed",
    )

    if not fast_train:
        model.save_model()

    return model


def classify(model_path, test_img_dir):
    """Predict a class for every image in ``test_img_dir``.

    Args:
        model_path: File holding the saved network to load.
        test_img_dir: Directory containing the images to classify.

    Returns:
        dict: Image file name -> predicted class index (``int``). Empty when
        the directory contains no files.
    """
    data_module = ImgDataModule(
        img_dir=test_img_dir,
        batch_size=BATCH_SIZE,
    )
    # The loader is built by hand below, so the path list must be prepared
    # explicitly; without this call ``_pred_data`` does not exist.
    data_module.setup(stage='predict')

    image_paths = data_module._pred_data
    if not image_paths:
        # torch.cat cannot concatenate an empty list of batches.
        return {}

    model = LightningBirdClassifier(
        transfer=True,
        lr=BASE_LR,
        model_path=model_path,
    )

    trainer = L.Trainer()
    predictions_batches = trainer.predict(model, data_module.predict_dataloader())

    # The loader is not shuffled, so predictions line up with image_paths.
    predictions = torch.cat(
        [torch.argmax(batch, dim=1) for batch in predictions_batches]
    ).cpu().numpy()

    return {
        basename(image_path): int(pred_class)
        for image_path, pred_class in zip(image_paths, predictions)
    }

In [None]:
def print_image(dl, batches=1):
    """Display every image of the first ``batches`` batches with its label as title.

    Args:
        dl: Iterable yielding ``(images, labels)`` batches, where each image is
            a channel-first tensor.
        batches: Number of batches to show before stopping.
    """
    for batch_no, (images, labels) in enumerate(dl, start=1):
        for image, label in zip(images, labels):
            image = image.detach().cpu()
            # Min-max rescale into [0, 1] for display; the clamp keeps a
            # constant image from dividing by zero and turning into NaNs.
            low, high = image.min(), image.max()
            image = (image - low) / (high - low).clamp_min(1e-8)

            # Show a plain Python value rather than the tensor repr.
            caption = label.item() if hasattr(label, "item") else label

            fig = plt.figure(figsize=(6, 6))
            plt.imshow(image.numpy().transpose((1, 2, 0)))
            plt.title(label=str(caption))

            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)
        if batch_no == batches:
            break

In [None]:
# Data module over the local training directory, kept for manual inspection.
data_module1 = ImgDataModule(
    data_dir=path_train,
    batch_size=16,
    train_share=0.7,
    valid_share=0.2,
    test_share=0.1,  # if testing is needed
)

# Manual visual checks of each dataloader; uncomment the ones you need.
# (Kept as real comments so the cell does not echo a string as its output.)
# data_module1.setup(stage='train')
# train_loader = data_module1.train_dataloader()
# print_image(train_loader, 10)
#
# val_loader = data_module1.val_dataloader()
# print_image(val_loader)
#
# data_module1.setup(stage='validate')
# val_loader = data_module1.val_dataloader()
# print_image(val_loader)
#
# data_module1.setup(stage='test')
# test_loader = data_module1.test_dataloader()
# print_image(test_loader)
#
# data_module1.setup(stage='predict')
# pred_loader = data_module1.predict_dataloader()
# print_image(pred_loader)

In [10]:
# Build the ImageNet-initialised classifier and write it to disk before any
# training. The constructor already builds the network through get_model(),
# so calling get_model() again would only construct a second copy and
# throw it away.
model = LightningBirdClassifier(
    transfer=False,
    lr=BASE_LR,
)
model.save_model()

In [19]:
def run_single_test(data_dir, output_dir):
    """Train on ``<data_dir>/train``, classify ``<data_dir>/test`` and save the result.

    Args:
        data_dir: Test-case directory containing ``train/`` (with ``gt.csv``
            and ``images/``) and ``test/`` (with ``images/``).
        output_dir: Existing directory that receives ``output.csv``.
    """
    # Deliberately exercises the packaged module rather than the definitions
    # made earlier in this notebook.
    from classification import train_classifier, classify
    from os.path import join

    train_dir = join(data_dir, 'train')
    test_dir = join(data_dir, 'test')

    train_gt = read_csv(join(train_dir, 'gt.csv'))
    train_img_dir = join(train_dir, 'images')

    # A full (non-fast) training run writes birds_model.pt, which classify()
    # reads back below.
    train_classifier(train_gt, train_img_dir, fast_train=False)

    code_dir = './'
    model_path = join(code_dir, 'birds_model.pt')
    test_img_dir = join(test_dir, 'images')
    img_classes = classify(model_path, test_img_dir)
    save_csv(img_classes, join(output_dir, 'output.csv'))

In [None]:
# sub() derives each "check" directory from its "input" directory;
# makedirs() creates the output folder on demand.
from re import sub
from os import makedirs

# Run the whole train -> classify pipeline once for every test-case directory
# matching ./tests/NN_*_input, in sorted order.
for input_dir in sorted(glob.glob(join('./tests/', '[0-9][0-9]_*_input'))):
    # "<case>_input" -> "<case>_check"; results are written to "<case>_check/output".
    output_dir = sub('input$', 'check', input_dir)
    run_output_dir = join(output_dir, 'output')
    makedirs(run_output_dir, exist_ok=True)
    run_single_test(input_dir, run_output_dir)

/Users/user/Documents/programing/Python/CV/cv_env/lib/python3.12/site-packages/lightning/fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/Users/user/Documents/programing/Python/CV/cv_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:512: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | model    | EfficientNet       | 53.2 M | train
1 | loss_fn  | CrossEntropyLoss   | 0      | train
2 | accuracy | MulticlassAccuracy | 0      | train
-----------------------------------------------

Using the CPU 😞
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/user/Documents/programing/Python/CV/cv_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]