In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import random

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset
from torchmetrics import Accuracy
from torchvision.transforms import v2
from torch.utils.data.sampler import WeightedRandomSampler
import torchvision as tv



import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar, EarlyStopping, StochasticWeightAveraging,LearningRateFinder, Timer
from pytorch_lightning.loggers import CSVLogger

from tqdm import tqdm
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix, classification_report




%matplotlib inline 



# Seed every random number generator the notebook relies on (Python's stdlib,
# NumPy and PyTorch) with the same value so that runs are repeatable.
SEED = 21
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
DATASET_ROOT = '/kaggle/input/blood-cells-image-dataset/bloodcells_dataset'


def list_image_paths(dataset_root):
    """Return the path of every image file stored as ``<dataset_root>/<class>/<file>``.

    Only files that sit directly inside a class directory are returned.
    Class directories and file names are visited in sorted order, so the
    result is identical on every machine; this matters because the dataset
    later splits these paths by position with a seeded K-fold.

    Args:
        dataset_root: Directory that contains one sub-directory per class.

    Returns:
        A list of ``"<dataset_root>/<class>/<file>"`` strings. The list is
        empty when ``dataset_root`` does not exist or holds no class folders.
    """
    paths = []
    if not os.path.isdir(dataset_root):
        return paths

    for class_name in sorted(os.listdir(dataset_root)):
        class_dir = os.path.join(dataset_root, class_name)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are not images of a class.
            continue
        for file_name in sorted(os.listdir(class_dir)):
            if os.path.isfile(os.path.join(class_dir, file_name)):
                # Built with "/" so that the class name can be recovered from
                # the path by splitting on "/".
                paths.append(f"{dataset_root}/{class_name}/{file_name}")
    return paths


url_images = list_image_paths(DATASET_ROOT)

In [None]:
from sklearn.model_selection import StratifiedKFold

class BloodData(Dataset):
    """Blood-cell image dataset with stratified K-fold train/validation splits.

    Every path in ``url_images`` must have the class name as its parent
    directory (``.../<class>/<file>``); the class name is translated to an
    integer label through ``target_mappings``.

    Args:
        url_images: Sequence of image file paths.
        k_folds: Number of stratified folds (ignored when ``use == "test"``).
        fold: Which fold to use (ignored when ``use == "test"``).
        use: ``"test"`` exposes every image; ``"train"`` exposes the training
            part of the selected fold; any other value exposes the validation
            part of the selected fold.
        augment: Apply the training augmentation even when ``use`` is not
            ``"train"`` (training data is always augmented).
        reshape: Resize every image to 360x360. When false, images are
            returned at their original size.
    """

    def __init__(self, url_images, k_folds=5, fold=0, use="train", augment=False, reshape=True):
        self.url_images = url_images

        self.target_mappings = {
            "basophil": 0,
            "eosinophil": 1,
            "erythroblast": 2,
            "ig": 3,
            "lymphocyte": 4,
            "monocyte": 5,
            "neutrophil": 6,
            "platelet": 7,
        }

        # The class name is the image's parent directory. Indexing from the end
        # keeps this independent of how deep the dataset root is on disk.
        self.targets = np.array([self.target_mappings[url.split("/")[-2]] for url in self.url_images])
        self.augment = augment
        self.use = use
        if self.use == "test":
            self.indices = np.arange(len(self.url_images))
        else:
            self.k_folds = k_folds
            self.fold = fold
            self.kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=21)
            self.indices = self._fold_indices(fold)

        resize = v2.Compose([
            v2.ToPILImage(),
            v2.Resize(size=(360,360)),
            #v2.Resize(size=(224,224)),
            v2.ToTensor(),
        ])
        # Honour the `reshape` flag: the resize pipeline when requested, else
        # None so that `_augment` skips it.
        self.reshape = resize if reshape else None

        self.transform_train = v2.Compose([
            v2.ToPILImage(),
            v2.RandomAffine(degrees=360, translate=(0.15,0.15), shear=15),
            v2.ColorJitter(brightness=0.25,hue=0.1,contrast=0.1,saturation=0.1),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            v2.ToTensor(),
        ])

    def __len__(self):
        """Number of images exposed by the current split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Load, transform and return ``(image, label)`` for position ``idx`` of the split."""
        idx = self.indices[idx]
        array = np.asarray(mpimg.imread(self.url_images[idx]))
        array = self._augment(array)
        # NOTE(review): the transform pipelines end in v2.ToTensor(), which
        # presumably already scales pixels to [0, 1]; dividing by 255 again
        # would leave values in [0, 1/255]. Kept unchanged because the preview
        # cells of this notebook rescale for exactly this output range and any
        # trained model depends on it -- confirm before changing.
        return  torch.Tensor(array) / 255, self.targets[idx]

    def _augment(self, img):
        """Apply the training augmentation (if enabled) followed by the optional resize."""
        if self.use == 'train' or self.augment:
            img = self.transform_train(img)
        if self.reshape is not None:
            img = self.reshape(img)
        return img

    def _fold_indices(self, fold):
        """Return the image indices that fold ``fold`` assigns to this dataset's role."""
        # StratifiedKFold needs the labels to stratify on, hence `self.targets`.
        train_indices, valid_indices = list(self.kf.split(self.url_images, self.targets))[fold]
        return train_indices if self.use == 'train' else valid_indices

    def change_fold(self, new_fold):
        """Switch the dataset to another fold of the same K-fold split.

        Raises:
            RuntimeError: If the dataset was created with ``use="test"``, in
                which case no K-fold split exists.
            IndexError: If ``new_fold`` is not a valid fold number.
        """
        if self.use == "test":
            raise RuntimeError("change_fold() is not available when use='test': no K-fold split exists")
        self.indices = self._fold_indices(new_fold)
        self.fold = new_fold

In [None]:
def random_indexes_by_category(lst, count):
    """Draw ``count`` random positions of ``lst`` for every distinct value it holds.

    Positions are drawn independently with ``random.choice``, so the same
    position can be returned more than once for a category.

    Args:
        lst: Iterable of hashable category values.
        count: Number of positions to draw per category.

    Returns:
        A flat list of positions, grouped by category in order of each
        category's first appearance in ``lst``.
    """
    # Group the positions of every value, keeping first-seen order of values.
    positions_by_value = {}
    for position, value in enumerate(lst):
        positions_by_value.setdefault(value, []).append(position)

    return [
        random.choice(positions)
        for positions in positions_by_value.values()
        for _ in range(count)
    ]

In [None]:
# Preview: two randomly chosen, un-augmented images per class in a 4x4 grid.
data = BloodData(url_images, k_folds=10, use="test")
print(data[0][0].shape)
random_indices = random_indexes_by_category(data.targets,2)

# Inverse lookup (integer label -> class name), built once instead of
# rebuilding a key list for every image.
class_names = {label: name for name, label in data.target_mappings.items()}

fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for ax, idx in zip(axes.flat, random_indices):
    image, target = data[idx]
    # permute(1, 2, 0) moves the channel axis last while keeping rows and
    # columns in place; swapping axes 0 and 2 would show the image mirrored
    # across its diagonal.
    # The dataset divides by 255 after its transform pipeline, so pixel values
    # presumably lie in [0, 1/255]: scale back by 255 and clamp to the [0, 1]
    # range imshow expects for float RGB data.
    pixels = (image.permute(1, 2, 0) * 255).clamp(0, 1).numpy()
    ax.imshow(pixels, interpolation='nearest')
    ax.set_title(f'Class: {class_names[int(target)]}')
    ax.axis('off')

plt.tight_layout()
plt.savefig('normal.png', dpi=500)
plt.show()

In [None]:
# Preview of the training augmentation. `random_indices` from the previous
# cell is reused on purpose so this grid shows the same images as the
# un-augmented one.
data = BloodData(url_images, k_folds=10, use="test", augment=True)
print(data[0][0].shape)

# Inverse lookup (integer label -> class name), built once instead of
# rebuilding a key list for every image.
class_names = {label: name for name, label in data.target_mappings.items()}

fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for ax, idx in zip(axes.flat, random_indices):
    image, target = data[idx]
    # permute(1, 2, 0) moves the channel axis last while keeping rows and
    # columns in place; swapping axes 0 and 2 would show the image mirrored
    # across its diagonal.
    # The dataset divides by 255 after its transform pipeline, so pixel values
    # presumably lie in [0, 1/255]: scale back by 255 and clamp to the [0, 1]
    # range imshow expects for float RGB data.
    pixels = (image.permute(1, 2, 0) * 255).clamp(0, 1).numpy()
    ax.imshow(pixels, interpolation='nearest')
    ax.set_title(f'Class: {class_names[int(target)]}')
    ax.axis('off')

plt.tight_layout()
plt.savefig('aug.png', dpi=500)
plt.show()