In [1]:
!pip install timm

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 808 kB/s 
Installing collected packages: timm
Successfully installed timm-0.4.12


In [2]:
import os
import sys
from glob import glob

import cv2
import warnings
import argparse
import random
import gc
import pandas as pd
from glob import glob
from tqdm import tqdm
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, random_split, Dataset

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from sklearn.model_selection import KFold, GroupKFold, train_test_split, StratifiedKFold
warnings.simplefilter('ignore')

In [3]:
cat_path = glob("../input/cat-dataset/*/*.jpg")
dog_path = glob("../input/stanford-dogs-dataset/images/Images/*/*.jpg")
print(len(cat_path), len(dog_path))

9997 20580


In [4]:
dog_df = pd.DataFrame(columns=["file_path", "label"])
dog_df["file_path"] = dog_path
dog_df["label"] = 0

cat_df = pd.DataFrame(columns=["file_path", "label"])
cat_df["file_path"] = cat_path
cat_df["label"] = 1

df = pd.concat([dog_df, cat_df], axis=0)
df = df.reset_index(drop=True)

In [5]:
class CatDogDataset(Dataset):

    def __init__(self, df, transforms, type_):
        self.type = type_
        self.df = df
        self.transforms = transforms
        self.file_names = df['file_path'].values
        self.labels = df["label"].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        path = self.file_names[index]
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image'].float()
        label = torch.tensor(self.labels[index]).long()
        return image, label

In [6]:
image_size = 384
def train_transforms():
    train_transform = A.Compose(
        [
            A.RandomResizedCrop(int(image_size), int(
                image_size), scale=(0.90, 1.0)),
            A.RandomBrightnessContrast(p=0.2),
            A.ShiftScaleRotate(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ]
    )
    return train_transform

def valid_transforms():
    valid_transform = A.Compose(
        [
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ]
    )
    return valid_transform

In [7]:
df["fold"] = 0
Fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=241)
for n, (train_index, val_index) in enumerate(Fold.split(df, df.label)):
    df.loc[val_index, 'fold'] = int(n)

In [8]:
df.head()

Unnamed: 0,file_path,label,fold
0,../input/stanford-dogs-dataset/images/Images/n...,0,3
1,../input/stanford-dogs-dataset/images/Images/n...,0,4
2,../input/stanford-dogs-dataset/images/Images/n...,0,3
3,../input/stanford-dogs-dataset/images/Images/n...,0,4
4,../input/stanford-dogs-dataset/images/Images/n...,0,0


In [9]:
class Custom2DCNN(nn.Module):
    def __init__(self):
        super(Custom2DCNN, self).__init__()
        model_name = "tf_efficientnet_b1"
        backborn = timm.create_model(model_name, pretrained=True, in_chans=3)
        if 'efficientnet' in model_name:
            n_features = backborn.classifier.in_features
            backborn.classifier = nn.Identity()
            self.backborn = backborn
        else:
            n_features = list(backborn.children())[-1].in_features
            backborn = list(backborn.children())[:-1]
            self.backborn = torch.nn.Sequential(*backborn)
        self.fc = nn.Linear(n_features, 1)

    def forward(self, x):
        x = self.backborn(x)
        x = self.fc(x)
        return x

In [10]:
def train_fn(epoch, model, loss_fn, optimizer, train_loader, scaler, device, scheduler=None):
    model.train()
    losses = AverageMeter()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).float()

        with autocast():
            y_preds = model(imgs).squeeze(1)
            loss = loss_fn(y_preds, image_labels)
        scaler.scale(loss).backward()
        if ((step + 1) % 2 == 0) or ((step + 1) == len(train_loader)):
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        losses.update(loss.item(), 32)

    if ((step + 1) % 1 == 0) or ((step + 1) == len(train_loader)):
        description = f'epoch {epoch} loss: {losses.avg:.4f}'
        pbar.set_description(description)
    scheduler.step()

    return losses.avg

In [11]:
def valid_fn(epoch, model, loss_fn, val_loader, device, scheduler=None):
    model.eval()
    losses = AverageMeter()
    image_preds_all = []
    image_targets_all = []
    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).float()

        image_preds = model(imgs).squeeze(1)  # output = model(input)
        image_preds_all += [np.where(image_preds.detach().cpu().numpy() < 0.5, 0, 1)]
        image_targets_all += [image_labels.detach().cpu().numpy()]

        loss = loss_fn(image_preds, image_labels)
        losses.update(loss.item(), 32)


        if ((step + 1) % 1 == 0) or ((step + 1) == len(val_loader)):
            description = f'epoch {epoch} loss: {losses.avg:.4f}'
            pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    print('validation class accuracy = {:.4f}'.format((image_preds_all == image_targets_all).mean()))

    return losses.avg, image_preds_all, image_targets_all

In [12]:
def get_score(y_true, y_pred):
    return accuracy_score(y_true, y_pred)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [13]:
for fold in range(5):
    if fold >1:
        continue
    train = df[df.fold != fold]
    val = df[df.fold == fold]
    train_dataset = CatDogDataset(train, train_transforms(), 'train')
    val_dataset = CatDogDataset(val, valid_transforms(), 'train')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=False,
        drop_last=True)
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=False,
        drop_last=True)
    
    loss = torch.nn.BCEWithLogitsLoss()
    model = Custom2DCNN().to(device)

    best_score = 0.
    
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)    
    scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5, last_epoch=-1)

    for epoch in range(20):
        train_loss = train_fn(
            epoch, model, loss, optimizer, train_loader, scaler, device, scheduler=scheduler)
        with torch.no_grad():
            valid_loss, valid_preds, valid_labels = valid_fn(
                epoch, model, loss, val_loader, device, scheduler=None)
        score = get_score(valid_labels, valid_preds)
        print(f'Epoch {epoch+1} - avg_train_loss: {train_loss:.4f}  avg_val_loss: {valid_loss:.4f}')
        print(f'Epoch {epoch+1} - Accuracy: {score}')
        if score > best_score:
            best_score = score
            print(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(), 'preds': valid_preds}, f'fold_{fold}_best.pth')
        torch.save(model.state_dict(), f'fold_{fold}_{epoch}.pth')


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b1_aa-ea7a6ee0.pth
100%|██████████| 764/764 [06:33<00:00,  1.94it/s]
epoch 0 loss: 0.0039: 100%|██████████| 191/191 [01:04<00:00,  2.95it/s]


validation class accuracy = 0.9989
Epoch 1 - avg_train_loss: 0.0489  avg_val_loss: 0.0039
Epoch 1 - Accuracy: 0.9988547120418848
Epoch 1 - Save Best Score: 0.9989 Model


100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 1 loss: 0.0025: 100%|██████████| 191/191 [01:01<00:00,  3.12it/s]


validation class accuracy = 0.9990
Epoch 2 - avg_train_loss: 0.0044  avg_val_loss: 0.0025
Epoch 2 - Accuracy: 0.9990183246073299
Epoch 2 - Save Best Score: 0.9990 Model


100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 2 loss: 0.0038: 100%|██████████| 191/191 [01:03<00:00,  3.03it/s]

validation class accuracy = 0.9989
Epoch 3 - avg_train_loss: 0.0026  avg_val_loss: 0.0038
Epoch 3 - Accuracy: 0.9988547120418848



100%|██████████| 764/764 [06:27<00:00,  1.97it/s]
epoch 3 loss: 0.0011: 100%|██████████| 191/191 [01:04<00:00,  2.96it/s]


validation class accuracy = 0.9997
Epoch 4 - avg_train_loss: 0.0022  avg_val_loss: 0.0011
Epoch 4 - Accuracy: 0.99967277486911
Epoch 4 - Save Best Score: 0.9997 Model


100%|██████████| 764/764 [06:26<00:00,  1.97it/s]
epoch 4 loss: 0.0022: 100%|██████████| 191/191 [01:02<00:00,  3.04it/s]

validation class accuracy = 0.9992
Epoch 5 - avg_train_loss: 0.0027  avg_val_loss: 0.0022
Epoch 5 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:26<00:00,  1.97it/s]
epoch 5 loss: 0.0045: 100%|██████████| 191/191 [01:03<00:00,  3.00it/s]

validation class accuracy = 0.9990
Epoch 6 - avg_train_loss: 0.0012  avg_val_loss: 0.0045
Epoch 6 - Accuracy: 0.9990183246073299



100%|██████████| 764/764 [06:27<00:00,  1.97it/s]
epoch 6 loss: 0.0023: 100%|██████████| 191/191 [01:03<00:00,  3.03it/s]

validation class accuracy = 0.9993
Epoch 7 - avg_train_loss: 0.0013  avg_val_loss: 0.0023
Epoch 7 - Accuracy: 0.9993455497382199



100%|██████████| 764/764 [06:25<00:00,  1.98it/s]
epoch 7 loss: 0.0013: 100%|██████████| 191/191 [01:02<00:00,  3.04it/s]

validation class accuracy = 0.9993
Epoch 8 - avg_train_loss: 0.0009  avg_val_loss: 0.0013
Epoch 8 - Accuracy: 0.9993455497382199



100%|██████████| 764/764 [06:24<00:00,  1.99it/s]
epoch 8 loss: 0.0023: 100%|██████████| 191/191 [01:01<00:00,  3.12it/s]

validation class accuracy = 0.9989
Epoch 9 - avg_train_loss: 0.0015  avg_val_loss: 0.0023
Epoch 9 - Accuracy: 0.9988547120418848



100%|██████████| 764/764 [06:24<00:00,  1.99it/s]
epoch 9 loss: 0.0028: 100%|██████████| 191/191 [01:02<00:00,  3.03it/s]

validation class accuracy = 0.9990
Epoch 10 - avg_train_loss: 0.0009  avg_val_loss: 0.0028
Epoch 10 - Accuracy: 0.9990183246073299



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 10 loss: 0.0018: 100%|██████████| 191/191 [01:00<00:00,  3.17it/s]

validation class accuracy = 0.9995
Epoch 11 - avg_train_loss: 0.0003  avg_val_loss: 0.0018
Epoch 11 - Accuracy: 0.9995091623036649



100%|██████████| 764/764 [06:25<00:00,  1.98it/s]
epoch 11 loss: 0.0028: 100%|██████████| 191/191 [01:00<00:00,  3.14it/s]

validation class accuracy = 0.9990
Epoch 12 - avg_train_loss: 0.0004  avg_val_loss: 0.0028
Epoch 12 - Accuracy: 0.9990183246073299



100%|██████████| 764/764 [06:25<00:00,  1.98it/s]
epoch 12 loss: 0.0024: 100%|██████████| 191/191 [01:04<00:00,  2.96it/s]

validation class accuracy = 0.9990
Epoch 13 - avg_train_loss: 0.0002  avg_val_loss: 0.0024
Epoch 13 - Accuracy: 0.9990183246073299



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 13 loss: 0.0020: 100%|██████████| 191/191 [01:02<00:00,  3.06it/s]

validation class accuracy = 0.9995
Epoch 14 - avg_train_loss: 0.0003  avg_val_loss: 0.0020
Epoch 14 - Accuracy: 0.9995091623036649



100%|██████████| 764/764 [06:27<00:00,  1.97it/s]
epoch 14 loss: 0.0017: 100%|██████████| 191/191 [01:00<00:00,  3.16it/s]

validation class accuracy = 0.9995
Epoch 15 - avg_train_loss: 0.0000  avg_val_loss: 0.0017
Epoch 15 - Accuracy: 0.9995091623036649



100%|██████████| 764/764 [06:29<00:00,  1.96it/s]
epoch 15 loss: 0.0027: 100%|██████████| 191/191 [01:01<00:00,  3.13it/s]


validation class accuracy = 0.9993
Epoch 16 - avg_train_loss: 0.0001  avg_val_loss: 0.0027
Epoch 16 - Accuracy: 0.9993455497382199


100%|██████████| 764/764 [06:29<00:00,  1.96it/s]
epoch 16 loss: 0.0023: 100%|██████████| 191/191 [01:05<00:00,  2.90it/s]

validation class accuracy = 0.9992
Epoch 17 - avg_train_loss: 0.0002  avg_val_loss: 0.0023
Epoch 17 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 17 loss: 0.0023: 100%|██████████| 191/191 [01:00<00:00,  3.16it/s]

validation class accuracy = 0.9992
Epoch 18 - avg_train_loss: 0.0001  avg_val_loss: 0.0023
Epoch 18 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:26<00:00,  1.97it/s]
epoch 18 loss: 0.0022: 100%|██████████| 191/191 [01:00<00:00,  3.18it/s]

validation class accuracy = 0.9992
Epoch 19 - avg_train_loss: 0.0000  avg_val_loss: 0.0022
Epoch 19 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:21<00:00,  2.00it/s]
epoch 19 loss: 0.0025: 100%|██████████| 191/191 [01:05<00:00,  2.91it/s]


validation class accuracy = 0.9990
Epoch 20 - avg_train_loss: 0.0001  avg_val_loss: 0.0025
Epoch 20 - Accuracy: 0.9990183246073299


100%|██████████| 764/764 [06:22<00:00,  2.00it/s]
epoch 0 loss: 0.0027: 100%|██████████| 191/191 [01:00<00:00,  3.15it/s]


validation class accuracy = 0.9998
Epoch 1 - avg_train_loss: 0.0507  avg_val_loss: 0.0027
Epoch 1 - Accuracy: 0.9998363874345549
Epoch 1 - Save Best Score: 0.9998 Model


100%|██████████| 764/764 [06:27<00:00,  1.97it/s]
epoch 1 loss: 0.0039: 100%|██████████| 191/191 [01:00<00:00,  3.15it/s]

validation class accuracy = 0.9992
Epoch 2 - avg_train_loss: 0.0049  avg_val_loss: 0.0039
Epoch 2 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 2 loss: 0.0022: 100%|██████████| 191/191 [01:03<00:00,  3.00it/s]


validation class accuracy = 0.9995
Epoch 3 - avg_train_loss: 0.0036  avg_val_loss: 0.0022
Epoch 3 - Accuracy: 0.9995091623036649


100%|██████████| 764/764 [06:24<00:00,  1.99it/s]
epoch 3 loss: 0.0078: 100%|██████████| 191/191 [01:00<00:00,  3.13it/s]

validation class accuracy = 0.9974
Epoch 4 - avg_train_loss: 0.0017  avg_val_loss: 0.0078
Epoch 4 - Accuracy: 0.9973821989528796



100%|██████████| 764/764 [06:26<00:00,  1.98it/s]
epoch 4 loss: 0.0027: 100%|██████████| 191/191 [01:00<00:00,  3.16it/s]

validation class accuracy = 0.9992
Epoch 5 - avg_train_loss: 0.0021  avg_val_loss: 0.0027
Epoch 5 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 5 loss: 0.0023: 100%|██████████| 191/191 [01:04<00:00,  2.95it/s]

validation class accuracy = 0.9993
Epoch 6 - avg_train_loss: 0.0012  avg_val_loss: 0.0023
Epoch 6 - Accuracy: 0.9993455497382199



100%|██████████| 764/764 [06:24<00:00,  1.99it/s]
epoch 6 loss: 0.0024: 100%|██████████| 191/191 [01:00<00:00,  3.16it/s]

validation class accuracy = 0.9993
Epoch 7 - avg_train_loss: 0.0004  avg_val_loss: 0.0024
Epoch 7 - Accuracy: 0.9993455497382199



100%|██████████| 764/764 [06:28<00:00,  1.97it/s]
epoch 7 loss: 0.0019: 100%|██████████| 191/191 [01:00<00:00,  3.17it/s]

validation class accuracy = 0.9992
Epoch 8 - avg_train_loss: 0.0011  avg_val_loss: 0.0019
Epoch 8 - Accuracy: 0.9991819371727748



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 8 loss: 0.0019: 100%|██████████| 191/191 [01:08<00:00,  2.81it/s]


validation class accuracy = 0.9995
Epoch 9 - avg_train_loss: 0.0007  avg_val_loss: 0.0019
Epoch 9 - Accuracy: 0.9995091623036649


100%|██████████| 764/764 [06:25<00:00,  1.98it/s]
epoch 9 loss: 0.0030: 100%|██████████| 191/191 [01:00<00:00,  3.16it/s]

validation class accuracy = 0.9990
Epoch 10 - avg_train_loss: 0.0006  avg_val_loss: 0.0030
Epoch 10 - Accuracy: 0.9990183246073299



100%|██████████| 764/764 [06:30<00:00,  1.96it/s]
epoch 10 loss: 0.0017: 100%|██████████| 191/191 [01:00<00:00,  3.15it/s]

validation class accuracy = 0.9993
Epoch 11 - avg_train_loss: 0.0002  avg_val_loss: 0.0017
Epoch 11 - Accuracy: 0.9993455497382199



100%|██████████| 764/764 [06:23<00:00,  1.99it/s]
epoch 11 loss: 0.0022: 100%|██████████| 191/191 [01:03<00:00,  3.03it/s]


validation class accuracy = 0.9990
Epoch 12 - avg_train_loss: 0.0006  avg_val_loss: 0.0022
Epoch 12 - Accuracy: 0.9990183246073299


100%|██████████| 764/764 [06:29<00:00,  1.96it/s]
epoch 12 loss: 0.0018: 100%|██████████| 191/191 [01:01<00:00,  3.08it/s]

validation class accuracy = 0.9995
Epoch 13 - avg_train_loss: 0.0008  avg_val_loss: 0.0018
Epoch 13 - Accuracy: 0.9995091623036649



100%|██████████| 764/764 [06:28<00:00,  1.96it/s]
epoch 13 loss: 0.0017: 100%|██████████| 191/191 [01:03<00:00,  2.99it/s]

validation class accuracy = 0.9997
Epoch 14 - avg_train_loss: 0.0002  avg_val_loss: 0.0017
Epoch 14 - Accuracy: 0.99967277486911



100%|██████████| 764/764 [06:25<00:00,  1.98it/s]
epoch 14 loss: 0.0071: 100%|██████████| 191/191 [00:59<00:00,  3.19it/s]

validation class accuracy = 0.9984
Epoch 15 - avg_train_loss: 0.0004  avg_val_loss: 0.0071
Epoch 15 - Accuracy: 0.9983638743455497



100%|██████████| 764/764 [06:30<00:00,  1.96it/s]
epoch 15 loss: 0.0015: 100%|██████████| 191/191 [01:00<00:00,  3.17it/s]

validation class accuracy = 0.9995
Epoch 16 - avg_train_loss: 0.0002  avg_val_loss: 0.0015
Epoch 16 - Accuracy: 0.9995091623036649



100%|██████████| 764/764 [06:24<00:00,  1.99it/s]
epoch 16 loss: 0.0010: 100%|██████████| 191/191 [01:04<00:00,  2.96it/s]


validation class accuracy = 0.9997
Epoch 17 - avg_train_loss: 0.0002  avg_val_loss: 0.0010
Epoch 17 - Accuracy: 0.99967277486911


100%|██████████| 764/764 [06:29<00:00,  1.96it/s]
epoch 17 loss: 0.0012: 100%|██████████| 191/191 [01:00<00:00,  3.15it/s]

validation class accuracy = 0.9998
Epoch 18 - avg_train_loss: 0.0002  avg_val_loss: 0.0012
Epoch 18 - Accuracy: 0.9998363874345549



100%|██████████| 764/764 [06:30<00:00,  1.95it/s]
epoch 18 loss: 0.0011: 100%|██████████| 191/191 [01:12<00:00,  2.62it/s]


validation class accuracy = 0.9997
Epoch 19 - avg_train_loss: 0.0001  avg_val_loss: 0.0011
Epoch 19 - Accuracy: 0.99967277486911


100%|██████████| 764/764 [06:27<00:00,  1.97it/s]
epoch 19 loss: 0.0012: 100%|██████████| 191/191 [01:04<00:00,  2.97it/s]

validation class accuracy = 0.9998
Epoch 20 - avg_train_loss: 0.0001  avg_val_loss: 0.0012
Epoch 20 - Accuracy: 0.9998363874345549



