In [None]:
import sys
!git clone https://github.com/data-sachez-2511/EasyPL.git /kaggle/working/github/easypl
sys.path.insert(1, '/kaggle/working/github/easypl')

!git clone https://github.com/rwightman/pytorch-image-models.git /kaggle/working/github/timm
sys.path.insert(1, '/kaggle/working/github/timm')
!pip uninstall -y torch
!pip uninstall -y pytorch-lightning
!python -m pip install -–upgrade pip
!pip install pytorch-lightning
!pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install torchtext==0.11.0

In [None]:
import pandas as pd
import numpy as np
import shutil
import cv2
import os
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb
from albumentations.augmentations import *
from albumentations.core.composition import *
from albumentations.pytorch.transforms import *
from timm import create_model
import random
from torchmetrics import *
import shutil

from easypl.learners import ClassificatorLearner
from easypl.metrics import TorchMetric
from easypl.optimizers import WrapperOptimizer
from easypl.lr_schedulers import WrapperScheduler
from easypl.datasets import CSVDatasetClassification
from easypl.callbacks.loggers.classification import ClassificationImageLogger
from easypl.callbacks.mixers import Mixup, Cutmix, Mosaic


In [None]:
def _letterbox(side):
    """Return transforms that shrink the longest edge to `side`, then zero-pad to a `side` x `side` square."""
    return [
        LongestMaxSize(max_size=side),
        PadIfNeeded(min_height=side, min_width=side, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
    ]


# Training: random horizontal flip and rotation (p=0.5 each), letterbox to 224,
# then albumentations' default Normalize() and conversion to a torch tensor.
train_transform = Compose(
    [HorizontalFlip(p=0.5), Rotate(p=0.5)]
    + _letterbox(224)
    + [Normalize(), ToTensorV2()]
)

# Validation: deterministic letterbox to 600.
# NOTE(review): training uses 224 but validation uses 600 — confirm this
# resolution mismatch is intentional.
val_transform = Compose(_letterbox(600) + [Normalize(), ToTensorV2()])

# Test: no resizing, so images keep their original size.
test_transform = Compose([Normalize(), ToTensorV2()])

In [None]:
# Train / validation splits read from CSV files of the Kaggle "cat-dog-test"
# input. Assumes each CSV lists image paths relative to `image_prefix` plus a
# label column — TODO confirm against CSVDatasetClassification.
train_dataset = CSVDatasetClassification('../input/cat-dog-test/train.csv', image_prefix='../input/cat-dog-test/train', transform=train_transform, return_label=True)
val_dataset = CSVDatasetClassification('../input/cat-dog-test/val.csv', image_prefix='../input/cat-dog-test/val', transform=val_transform, return_label=True)

# Only the training split is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True, num_workers=2)
# Must wrap val_dataset: wrapping train_dataset here would make every
# "validation" metric a measurement on training images.
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=2)

In [None]:
# Label names in class-index order (presumably 0 = cat, 1 = dog — confirm
# against the CSV labels). The model head and the F1 metric are sized from this list.
class_names = ['cat', 'dog']

# timm ResNet-18 with pretrained weights and a freshly sized classification head.
model = create_model('resnet18', num_classes=len(class_names), pretrained=True)

loss_f = nn.CrossEntropyLoss()

# Wrapped optimizer/scheduler specs: Adam at lr=1e-4, and a StepLR that
# multiplies the LR by 0.1 every 2 scheduler steps, stepped per epoch.
optimizer = WrapperOptimizer(optim.Adam, lr=1e-4)
lr_scheduler = WrapperScheduler(optim.lr_scheduler.StepLR, step_size=2, gamma=0.1, interval='epoch')

# No training metrics; validation reports per-class F1 (average='none').
train_metrics = []
val_metrics = [TorchMetric(F1(num_classes=len(class_names), average='none'), class_names=class_names)]

In [None]:
# Disabled experiments, kept for reference. Re-enabling the finetuner also
# requires importing SequentialFinetune, which the import cell does not do.
# Layer names below match timm's ResNet modules; entry '2' lists
# EfficientNet-style names (conv_stem / blocks) and does not apply to resnet18.
# finetuner = SequentialFinetune({
#     '0': {'layers': ['global_pool', 'fc']},
#     '1': {'layers': ['layer4', 'layer3', 'layer2', 'layer1', 'conv1', 'maxpool', 'bn1', 'act1']},
# #     '2': {'layers': ['conv_stem', 'bn1', 'act1', 'blocks']}
# })

# NOTE(review): confirm the learner actually logs a "val/loss" key before enabling.
# checkpoint_callback = ModelCheckpoint(
#     monitor="val/loss",
#     dirpath='/kaggle/working/weights',
#     filename='best_model',
#     save_top_k=1,
#     mode="min",
# )

# Image-logging callback for the training phase: at most 10 samples across up
# to 2 classes, also written to the local 'images' directory.
image_logger = ClassificationImageLogger(
    phase='train',
    max_samples=10,
    class_names=['cat', 'dog'],
    max_log_classes=2,
    dir_path='images',
    save_on_disk=True,
)
# Despite the variable name, this is a Mosaic mixer (not Mixup). With
# on_batch=True and p=1.0 it is presumably applied to every training batch —
# confirm in easypl. `domen` is passed as spelled, presumably easypl's own
# parameter name.
mixup = Mosaic(
    on_batch=True,
    p=1.0,
    n_mosaics=3,
    domen='classification',
)

In [None]:
# EasyPL learner wiring the model, loss, optimizer, scheduler and metrics
# defined above. Assumes dataset items expose 'image' and 'target' keys, as
# named in data_keys / target_keys — TODO confirm against CSVDatasetClassification.
learner = ClassificatorLearner(
    model=model,
    loss=loss_f,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_metrics=train_metrics,
    val_metrics=val_metrics,
    data_keys=['image'],
    target_keys=['target'],
    multilabel=False
)
# Single GPU, 16-bit precision, 3 epochs; `mixup` holds the Mosaic mixer callback.
trainer = Trainer(gpus=1, callbacks=[image_logger, mixup], max_epochs=3, precision=16)
# Trainer.fit takes `train_dataloaders` (plural) since Lightning 1.4; the
# singular `train_dataloader=` keyword was deprecated then and removed in 1.6.
trainer.fit(learner, train_dataloaders=train_dataloader, val_dataloaders=[val_dataloader])

In [None]:
# Delete the cloned repositories and the Lightning / W&B run directories from
# /kaggle/working, skipping any that do not exist.
for leftover_path in (
    '/kaggle/working/github',
    '/kaggle/working/lightning_logs',
    '/kaggle/working/wandb',
):
    if os.path.isdir(leftover_path):
        shutil.rmtree(leftover_path)