In [8]:
import warnings
import os
import gc
import sys
import random
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFile
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold, train_test_split

import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
import albumentations as A

import wandb


# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the libraries imported
# above; consider narrowing the filter once the notebook is stable.
warnings.filterwarnings("ignore")

In [9]:
# Central run configuration: data locations, column names and training settings.
config = dict(
    # Directory containing the training images (trailing slash is part of the value).
    TRAIN_PATH="D:/Documents/GitHub/image_pipeline/data/aerial-cactus-identification/train/",
    # CSV file with one row per training image.
    TRAIN_FILE="D:/Documents/GitHub/image_pipeline/data/aerial-cactus-identification/train.csv",
    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Name of the label column in the CSV.
    TARGET_VAR="has_cactus",
    # Presumably the number of cross-validation folds -- not used in the visible cells.
    FOLD_NUMBER=5,
    # Name of the CSV column holding the image file names.
    IMAGE_ID="id",
    # Extension appended to each image id when building file paths.
    IMAGE_EXT=".jpg",
    # Target size handed to the dataset's `resize` argument.
    IMAGE_SIZE=(32, 32),
    EPOCHS=5,
    # Batch size of the training DataLoader.
    TRAIN_BS=32,
    # Presumably the validation batch size -- not used in the visible cells.
    VALID_BS=16,
)


In [10]:
# Per-split preprocessing pipelines, keyed by split name.
# Both splits currently run the same two steps (no random augmentation is applied):
# conversion to a tensor followed by per-channel normalisation.
Augmentations = {
    split: transforms.Compose(
        [
            transforms.ToTensor(),
            # Per-channel mean and std; these are the values torchvision documents
            # for inputs to its ImageNet-pretrained models.
            transforms.Normalize(
                [0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225],
            ),
        ]
    )
    # A fresh Compose is built for each split, so the two entries stay independent objects.
    for split in ("train", "valid")
}

In [11]:
class PETFINDER_DATASET(torch.utils.data.Dataset):
    '''
    Map-style PyTorch dataset that loads images from disk on demand.

    image_path : sequence of paths to individual images, e.g. "data/image_001.png"
    resize : if not None, a tuple used to resize every image. It is read as
             (resize[0], resize[1]) and passed to PIL as (resize[1], resize[0]);
             PIL expects (width, height), so the tuple is (height, width).
    label : optional sequence of labels aligned with image_path; when None the
            returned samples carry no "labels" entry
    transforms : optional callable applied to the (resized) PIL image

    Each item is a dict with key "images" and, when labels were given, "labels"
    (a float32 tensor).
    '''

    def __init__(self, image_path, resize=None, label=None, transforms=None):
        self.image_path = image_path
        self.resize = resize
        self.label = label
        self.transforms = transforms

    # RETURN THE LENGTH OF THE DATASET
    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, item):
        # LOADING IMAGES
        # Image.open is lazy and keeps the file open; the context manager closes the
        # handle, and convert() forces the pixel data to be read before that happens.
        # Converting to RGB also guarantees three channels, so grayscale / RGBA /
        # palette files do not break a 3-channel normalisation later on.
        with Image.open(self.image_path[item]) as raw_image:
            image = raw_image.convert("RGB")

        # RESIZING IMAGES
        if self.resize is not None:
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )

        # APPLYING DATA AUGMENTATIONS TO IMAGE DATA
        if self.transforms:
            image = self.transforms(image)

        sample = {"images": image}
        if self.label is not None:
            sample["labels"] = torch.tensor(self.label[item], dtype=torch.float32)
        return sample

In [12]:
class RESNET18(nn.Module):
    '''
    torchvision ResNet-18 followed by one extra linear layer.

    n_class : number of outputs produced by the extra linear layer
    pretrain : forwarded to models.resnet18 as its `pretrained` argument
    '''

    def __init__(self, n_class=2, pretrain=True):
        super().__init__()

        backbone = models.resnet18(pretrained=pretrain)
        self.base_model = backbone
        # The head consumes the backbone's own fc *output*, i.e. it is stacked on top
        # of the original classification layer rather than replacing it.
        # NOTE(review): if the intent was to classify from the pooled features,
        # the head should use fc.in_features and the fc layer should be swapped out -- confirm.
        self.l0 = nn.Linear(backbone.fc.out_features, n_class)

    def forward(self, image):
        # Backbone output first, then the extra head.
        return self.l0(self.base_model(image))

In [13]:
# Read the training table from the location declared in `config` rather than
# repeating the path literal, so the CSV path is defined in exactly one place.
df = pd.read_csv(config["TRAIN_FILE"])

# One image path per CSV row: drop whatever extension the id carries, append the
# configured extension, and prefix the training image directory.
train_img = [
    os.path.join(
        config["TRAIN_PATH"],
        os.path.splitext(image_id)[0] + config["IMAGE_EXT"],
    )
    for image_id in df[config["IMAGE_ID"]].values.tolist()
]

# Labels as a NumPy array, in the same row order as `train_img`.
train_labels = df[config["TARGET_VAR"]].values


In [14]:
# Dataset over all training images: each one is resized to the configured size
# and passed through the "train" preprocessing pipeline.
train_dataset = PETFINDER_DATASET(
    train_img,
    resize=config["IMAGE_SIZE"],
    label=train_labels,
    transforms=Augmentations["train"],
)

# Shuffled batches of the training dataset; num_workers=0 keeps loading in the main process.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config["TRAIN_BS"],
    shuffle=True,
    num_workers=0,
)

In [18]:
# Display the DataLoader object as the cell result.
# NOTE(review): the saved output under this cell is a NumPy array repr (dtype=int64),
# which a DataLoader would not produce -- it looks like stale output from a cell that
# displayed the labels. The execution count also jumps from 14 to 18; re-run the
# notebook top to bottom to refresh it.
train_loader

array([1, 1, 1, ..., 1, 0, 1], dtype=int64)