In [None]:
import sys
# Extend the module search path with two directories under ../input/ --
# presumably attached Kaggle datasets holding offline copies of the `tez` and
# `timm` packages imported in the next cell (TODO confirm the dataset contents).
sys.path.append("../input/tez-lib/")
sys.path.append("../input/timmmaster/")

In [None]:
import tez
import albumentations
import pandas as pd
import cv2
import numpy as np
import timm
import torch.nn as nn
from sklearn import metrics
import torch
from tez.callbacks import EarlyStopping
from tqdm import tqdm

In [None]:
class args:
    """Run configuration, used as a plain namespace by the cells below."""

    fold = 0  # fold id held out for validation; the remaining folds are trained on
    epochs = 20  # number of epochs passed to `model.fit`
    image_size = 96  # height and width (pixels) every image is resized to
    batch_size = 256  # training batch size; validation uses twice this value

In [None]:
class StarfishDataset:
    """Map-style dataset that reads an image from disk and pairs it with its target.

    Args:
        image_paths: Integer-indexable collection of image file paths.
        targets: Integer-indexable collection of numeric targets, aligned
            with ``image_paths``.
        augmentations: Optional callable invoked as ``augmentations(image=img)``
            that returns a mapping holding the transformed array under
            ``"image"``. ``None`` skips augmentation.
    """

    def __init__(self, image_paths,  targets, augmentations):
        self.image_paths = image_paths
        self.targets = targets
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, item):
        """Load sample ``item``.

        Returns:
            dict with
            ``"image"``: float32 tensor in channel-first (C, H, W) layout, RGB order;
            ``"targets"``: float32 tensor holding the sample's target.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        path = self.image_paths[item]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside np.transpose.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")

        # OpenCV decodes to BGR. The normalisation statistics used in this
        # notebook are listed in RGB order and the backbone is pretrained, so
        # convert before augmenting.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]

        # (H, W, C) -> (C, H, W), the layout expected by torch conv layers.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        targets = self.targets[item]

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.float),
        }

In [None]:
class StarfishModel(tez.Model):
    """Pretrained ``tf_efficientnet_b0_ns`` backbone with a single-output head.

    The timm classifier is replaced by a 64-unit linear layer, followed by
    dropout (p=0.1) and a final ``Linear(64, 1)``. Training minimises MSE
    against the targets and reports RMSE as the monitored metric.
    """

    def __init__(self):
        super().__init__()

        self.model = timm.create_model("tf_efficientnet_b0_ns", pretrained=True, in_chans=3)
        self.model.classifier = nn.Linear(self.model.classifier.in_features, 64)
        self.dropout = nn.Dropout(0.1)
        self.dense2 = nn.Linear(64, 1)

        # Read by tez; presumably makes it step the scheduler once per epoch
        # rather than once per batch (confirm against the tez version in use).
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Return ``{"rmse": value}`` for one batch of predictions and targets."""
        outputs = outputs.cpu().detach().numpy()
        targets = targets.cpu().detach().numpy()
        # Take the square root explicitly: the `squared=False` keyword of
        # mean_squared_error is deprecated and absent from recent scikit-learn
        # releases, whereas this form works on every version.
        rmse = np.sqrt(metrics.mean_squared_error(targets, outputs))
        return {"rmse": rmse}

    def fetch_scheduler(self):
        """Build a cosine-annealing scheduler with warm restarts every 10 steps."""
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

    def fetch_optimizer(self):
        """Build an Adam optimizer over all parameters with lr=1e-4."""
        opt = torch.optim.Adam(self.parameters(), lr=1e-4)
        return opt

    def forward(self, image, targets=None):
        """Run the network on a batch of images.

        Args:
            image: Image batch fed to the timm backbone.
            targets: Optional targets; when given, loss and metrics are computed.

        Returns:
            ``(predictions, loss, metrics_dict)``. Without targets the loss is
            ``0`` and the metrics dict is empty.
        """
        x = self.model(image)
        x = self.dropout(x)
        x = self.dense2(x)

        if targets is not None:
            # Predictions have shape (N, 1); view targets the same way for MSELoss.
            loss = nn.MSELoss()(x, targets.view(-1, 1))
            # Named `batch_metrics` so the imported `metrics` module is not shadowed.
            batch_metrics = self.monitor_metrics(x, targets)
            return x, loss, batch_metrics
        return x, 0, {}

In [None]:
# Training-time pipeline: resize to a square of side `args.image_size`, apply
# random colour jitter (each jitter step with probability 0.5), then normalise
# with the per-channel mean/std below.
#
# NOTE(review): albumentations documents the HueSaturationValue limits as
# absolute shift ranges (library defaults are 20/30/20), so 0.2 is likely a
# near no-op on uint8 images -- confirm the intended augmentation strength.
_train_transforms = [
    albumentations.Resize(height=args.image_size, width=args.image_size, p=1),
    albumentations.HueSaturationValue(
        p=0.5,
        hue_shift_limit=0.2,
        sat_shift_limit=0.2,
        val_shift_limit=0.2,
    ),
    albumentations.RandomBrightnessContrast(
        p=0.5,
        brightness_limit=(-0.1, 0.1),
        contrast_limit=(-0.1, 0.1),
    ),
    albumentations.Normalize(
        p=1.0,
        max_pixel_value=255.0,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
train_aug = albumentations.Compose(_train_transforms, p=1.0)

# Validation-time pipeline: deterministic resize + normalisation only, using
# the same target size and mean/std as the training pipeline.
_valid_transforms = [
    albumentations.Resize(height=args.image_size, width=args.image_size, p=1),
    albumentations.Normalize(
        p=1.0,
        max_pixel_value=255.0,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
valid_aug = albumentations.Compose(_valid_transforms, p=1.0)

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
def _collect_image_paths(root_dir):
    """Return every file path found under ``root_dir`` (recursively), sorted.

    Sorting matters: ``os.walk`` yields entries in filesystem order, which can
    differ between runs or machines and would change which rows land in which
    fold further down.
    """
    return sorted(
        os.path.join(dirname, filename)
        for dirname, _, filenames in os.walk(root_dir)
        for filename in filenames
    )


# (directory, label) pairs. Judging by the directory names, `notcots_crops`
# holds negative crops (label 0) and `cots_crops` positive ones (label 1).
_CROP_DIRS = [
    ('/kaggle/input/binary-cropped-crown-of-thorns-dataset/notcots_crops', 0),
    ('/kaggle/input/binary-cropped-crown-of-thorns-dataset/cots_crops', 1),
]

# One [path, label] row per file, negatives first, each class in sorted order.
data_list = [
    [path, label]
    for root_dir, label in _CROP_DIRS
    for path in _collect_image_paths(root_dir)
]

# Fail with a clear message instead of the opaque "Length mismatch" error that
# an empty frame would otherwise trigger when naming the columns.
if not data_list:
    raise FileNotFoundError(
        "No files found under: " + ", ".join(root_dir for root_dir, _ in _CROP_DIRS)
    )

df = pd.DataFrame(data_list, columns=['image_dir', 'label'])
df.to_csv("train.csv", index=False)

from sklearn import datasets
from sklearn import model_selection

def create_folds(data, num_splits):
    """Assign each row a stratified cross-validation fold id.

    The ``label`` column is discretised with ``pd.cut`` into
    ``floor(1 + log2(n))`` bins and the bin ids are used as the stratification
    target for a shuffled ``StratifiedKFold`` (``random_state=42``).

    Args:
        data: DataFrame with a numeric ``label`` column. It is not modified.
        num_splits: Number of folds.

    Returns:
        A copy of ``data`` with an integer ``kfold`` column holding, for each
        row, the index of the fold in which that row is used for validation.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("create_folds() requires a non-empty DataFrame")

    num_bins = int(np.floor(1 + np.log2(len(data))))
    bins = pd.cut(data["label"], bins=num_bins, labels=False)

    kf = model_selection.StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

    # The splitter yields *positional* indices, so collect fold ids in a plain
    # array instead of writing through label-based `.loc`, which is only
    # correct when the frame happens to carry a default 0..n-1 index.
    fold_ids = np.full(len(data), -1, dtype=np.int64)
    for f, (_, valid_idx) in enumerate(kf.split(X=data, y=bins.values)):
        fold_ids[valid_idx] = f

    # Work on a copy so the caller's frame is left without stray helper columns.
    folded = data.copy()
    folded["kfold"] = fold_ids
    return folded


# Reload the image index written above, attach 5-fold ids, and persist the
# result so the training cells can pick a fold without recomputing the split.
df = pd.read_csv("./train.csv")

df_5 = create_folds(data=df, num_splits=5)
df_5.to_csv(path_or_buf="train_5folds.csv", index=False)

In [None]:
# Load the fold assignments and split on the configured fold: rows whose
# `kfold` equals `args.fold` are held out for validation, the rest are used
# for training. Both frames get a fresh 0..n-1 index.
df = pd.read_csv("./train_5folds.csv")
_in_valid_fold = df["kfold"] == args.fold
df_train = df.loc[~_in_valid_fold].reset_index(drop=True)
df_valid = df.loc[_in_valid_fold].reset_index(drop=True)

In [None]:
# Image path column of each split, kept as pandas Series.
train_img_paths, valid_img_paths = (
    df_train.loc[:, 'image_dir'],
    df_valid.loc[:, 'image_dir'],
)

In [None]:
# Wrap each split in a StarfishDataset: the training split gets the randomised
# pipeline, the validation split the deterministic one.
_train_targets = df_train["label"].values
_valid_targets = df_valid["label"].values

train_dataset = StarfishDataset(
    augmentations=train_aug,
    image_paths=train_img_paths,
    targets=_train_targets,
)

valid_dataset = StarfishDataset(
    augmentations=valid_aug,
    image_paths=valid_img_paths,
    targets=_valid_targets,
)

In [None]:
# Instantiate the model and train it on the current fold.
model = StarfishModel()

# Early stopping watches "valid_rmse" (lower is better) with a patience of 3
# and writes weights to a per-fold checkpoint file.
_checkpoint_path = f"model_f{args.fold}.bin"
es = EarlyStopping(
    monitor="valid_rmse",
    mode="min",
    patience=3,
    model_path=_checkpoint_path,
    save_weights_only=True,
)

# Validation batches are twice the training batch size.
_fit_options = dict(
    valid_dataset=valid_dataset,
    train_bs=args.batch_size,
    valid_bs=2*args.batch_size,
    device="cuda",
    epochs=args.epochs,
    callbacks=[es],
    fp16=True,
)
model.fit(train_dataset, **_fit_options)