pip install split-folders

In [1]:
pip install split-folders

Collecting split-folders
  Downloading split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1
Note: you may need to restart the kernel to use updated packages.


In [4]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import transforms
import os
from PIL import Image
import torch
import ssl
import torchvision
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import splitfolders
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np
import warnings

# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation notices and other
# warnings raised by the libraries imported above are hidden as well —
# consider narrowing it to specific categories.
warnings.filterwarnings("ignore")


def calculate_accuracy(y_pred, y):
    """Return the fraction of rows in ``y_pred`` whose arg-max equals ``y``.

    Args:
        y_pred: score tensor; the predicted class is the arg-max along dim 1.
        y: target class indices, one per row of ``y_pred``.

    Returns:
        0-dim float tensor: number of matching rows divided by ``y.shape[0]``.
    """
    predicted = torch.argmax(y_pred, dim=1, keepdim=True)
    # Reshape the targets to the (N, 1) layout of `predicted` before comparing
    hits = (predicted == y.view_as(predicted)).sum()
    return hits.float() / y.shape[0]

# Swap the factory used for default HTTPS contexts for the unverified one.
# NOTE(review): this turns off TLS certificate verification for the whole
# process. Presumably it is a workaround for downloading the pretrained
# weights — confirm it is still required and remove it if not.
ssl._create_default_https_context = ssl._create_unverified_context

if __name__ == '__main__':
    # Prefer the GPU, but fall back to the CPU so the notebook still runs on
    # machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Preprocessing applied to every image: resize, convert to a tensor and
    # normalise each channel with mean 0.5 and std 0.5.
    # NOTE(review): Resize is given a single int, not a (height, width) pair;
    # batching then relies on all source images ending up the same size —
    # confirm against the dataset.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]),
    ])

    # Raw dataset location (read-only) and the writable folder that holds
    # the train/val split.
    input_folder = r"/kaggle/input/base-dataset-without-preprocessing/Without_any_preprocessing_small"
    output_folder = r"/kaggle/working/Images"
    train_path = os.path.join(output_folder, 'train')
    val_path = os.path.join(output_folder, 'val')

    # Create the 80/20 split only when it does not exist yet. This replaces
    # the manual "uncomment on first run" step, so a fresh kernel can run the
    # notebook top to bottom, while later runs reuse the existing split.
    if not (os.path.isdir(train_path) and os.path.isdir(val_path)):
        splitfolders.ratio(input_folder, output_folder, seed=42, ratio=(0.8, 0.2), group_prefix=None)

    # Create datasets for the training and validation sets; each class lives
    # in its own sub-folder of the split directories.
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(val_path, transform=transform)
    train_size = len(train_dataset)
    val_size = len(val_dataset)

    # Create the data loaders. Only the training data is shuffled: the
    # validation metrics do not depend on sample order.
    train_loader = DataLoader(train_dataset, batch_size=25, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=25, shuffle=False, num_workers=4)

    # Class folders as they appear on disk (unsorted), printed for reference
    list_of_classes = os.listdir(train_path)
    print(list_of_classes)

    # Class names in sorted order, used to label the per-class metrics
    classes = sorted(train_dataset.class_to_idx.keys())

from sklearn.metrics import accuracy_score, precision_score, recall_score







# Build the classifier: a ResNet-50 loaded with pretrained weights whose final
# fully connected layer is replaced by a three-layer MLP head that emits one
# output per entry in `classes`.
model = models.resnet50(pretrained=True)
num_features = model.fc.in_features
hidden_units = 4096
head_layers = [
    nn.Linear(num_features, hidden_units),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(hidden_units, hidden_units),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(hidden_units, len(classes)),
]
model.fc = nn.Sequential(*head_layers)
# Place every parameter on the device chosen for training
model = model.to(device)

# Optimisation setup: cross-entropy loss placed on the training device, and
# Adam over every model parameter with a fixed initial learning rate.
START_LR = 0.0001
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=START_LR)

# Train and Validate the model

# Number of full passes over the training data
NUM_EPOCHS = 10


def _export_epoch_report(y_true, y_pred, report_path, matrix_path):
    """Save the per-class report and row-normalised confusion matrix as CSV.

    Args:
        y_true: ground-truth class indices collected over one epoch.
        y_pred: predicted class indices, aligned with ``y_true``.
        report_path: destination CSV for the classification report.
        matrix_path: destination CSV for the confusion matrix.

    Returns:
        Tuple ``(report_pd, df_cm)`` holding the two DataFrames that were saved.
    """
    # Pin the label set explicitly. Without it sklearn infers the labels from
    # the data, so a class missing from both y_true and y_pred would make the
    # report reject `target_names` and shrink the matrix below len(classes).
    label_ids = list(range(len(classes)))

    report_dict = classification_report(y_true, y_pred, labels=label_ids,
                                        target_names=classes, output_dict=True)
    report_pd = pd.DataFrame(report_dict)
    report_pd.to_csv(report_path)

    cnf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)
    # Each row is divided by its total, so a row reads as the distribution of
    # predictions for one true class (rows with no samples come out as NaN).
    df_cm = pd.DataFrame(cnf_matrix / np.sum(cnf_matrix, axis=1)[:, None], index=classes, columns=classes)
    df_cm.to_csv(matrix_path)
    return report_pd, df_cm


for epoch in range(NUM_EPOCHS):
    torch.cuda.empty_cache()
    print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
    print('-' * 10)

    # Training phase
    model.train()
    running_loss = 0.0
    running_corrects = 0
    predictions = []
    true_labels = []

    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # obtain a per-sample average once divided by the dataset length.
        running_loss += loss.item() * inputs.size(0)
        # Accumulate a plain int so the totals stay valid for an empty loader
        running_corrects += (preds == labels).sum().item()

        # Copy each batch to the CPU once for the sklearn metrics below
        true_labels.extend(labels.cpu().numpy())
        predictions.extend(preds.cpu().numpy())

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects / len(train_dataset)

    # Per-class precision and recall, aligned with `classes` by position
    label_ids = list(range(len(classes)))
    precision = precision_score(true_labels, predictions, labels=label_ids, average=None)
    recall = recall_score(true_labels, predictions, labels=label_ids, average=None)

    # Per-class accuracy: share of each class's samples predicted correctly
    y_true = np.asarray(true_labels)
    y_pred = np.asarray(predictions)
    acc_per_class = {}
    for i, class_name in enumerate(classes):
        class_mask = y_true == i
        if class_mask.any():
            acc_per_class[class_name] = float((y_pred[class_mask] == i).mean())
        else:
            # No sample of this class was seen this epoch
            acc_per_class[class_name] = float('nan')

    print('Accuracy for all classes: ', acc_per_class)
    print("")

    print('Precision for all classes: ', precision)
    print("")
    print('Recall for all classes: ', recall)
    print("")

    report_pd, df_cm = _export_epoch_report(
        true_labels,
        predictions,
        'training-classification-epoch' + str(epoch + 1) + '.csv',
        'confusion-matrix-train-epoch' + str(epoch + 1) + '.csv',
    )

    print('Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

    # Validation phase
    running_loss = 0.0
    running_corrects = 0
    predictions = []
    true_labels = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

            true_labels.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())

    epoch_loss = running_loss / len(val_dataset)
    epoch_acc = running_corrects / len(val_dataset)

    report_pd, df_cm = _export_epoch_report(
        true_labels,
        predictions,
        'val-classification-epoch' + str(epoch + 1) + '.csv',
        'confusion-matrix-val-epoch' + str(epoch + 1) + '.csv',
    )

    print('Val Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

    torch.cuda.empty_cache()

['grbcam1', 'chibat1', 'gobwea1', 'abhori1', 'reccuc1', 'slbgre1', 'somgre1', 'chespa1', 'yebsto1', 'brican1', 'afghor1', 'spwlap1', 'blakit1', 'combul2', 'pygbat1', 'whbtit5', 'grewoo2', 'whctur2', 'vilwea1', 'whbcan1', 'palfly2', 'nubwoo1', 'gobbun1', 'rebfir2', 'norpuf1', 'chtapa3', 'afrjac1', 'colsun2', 'carcha1', 'spewea1']


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 50.1MB/s]


Epoch 1/10
----------


100%|██████████| 71/71 [00:19<00:00,  3.56it/s]


Accuracy for all classes:  {'abhori1': 0.03, 'afghor1': 0.0, 'afrjac1': 0.0, 'blakit1': 0.5454545454545454, 'brican1': 0.0, 'carcha1': 0.11475409836065574, 'chespa1': 0.0, 'chibat1': 0.0, 'chtapa3': 0.0, 'colsun2': 0.2569444444444444, 'combul2': 0.5427350427350427, 'gobbun1': 0.0, 'gobwea1': 0.0, 'grbcam1': 0.0, 'grewoo2': 0.012195121951219513, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.09090909090909091, 'rebfir2': 0.0, 'reccuc1': 0.18556701030927836, 'slbgre1': 0.0, 'somgre1': 0.23270440251572327, 'spewea1': 0.0, 'spwlap1': 0.0, 'vilwea1': 0.0, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.16666667 0.         0.         0.2472885  0.         0.14583333
 0.         0.         0.         0.15744681 0.18594436 0.
 0.         0.         0.5        0.         0.         0.
 0.05882353 0.         0.3        0.         0.21022727 0.
 0.         0.         0.         0.         0.         0.        ]

Recall for all classes:

100%|██████████| 19/19 [00:02<00:00,  6.41it/s]


Val Loss: 2.4363 Acc: 0.3654
Epoch 2/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.85it/s]


Accuracy for all classes:  {'abhori1': 0.28, 'afghor1': 0.03508771929824561, 'afrjac1': 0.0, 'blakit1': 0.8708133971291866, 'brican1': 0.0, 'carcha1': 0.6639344262295082, 'chespa1': 0.0, 'chibat1': 0.0, 'chtapa3': 0.0, 'colsun2': 0.8194444444444444, 'combul2': 0.7307692307692307, 'gobbun1': 0.015625, 'gobwea1': 0.0, 'grbcam1': 0.02666666666666667, 'grewoo2': 0.2804878048780488, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.0, 'reccuc1': 0.8041237113402062, 'slbgre1': 0.0, 'somgre1': 0.7358490566037735, 'spewea1': 0.0, 'spwlap1': 0.0, 'vilwea1': 0.15714285714285714, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.41176471 0.28571429 0.         0.60264901 0.         0.57857143
 0.         0.         0.         0.31299735 0.51661631 0.33333333
 0.         0.0625     0.17164179 0.         0.         0.
 0.         0.         0.624      0.         0.61904762 0.
 0.         0.18644068 0.         0.         0.     

100%|██████████| 19/19 [00:02<00:00,  6.88it/s]


Val Loss: 1.7333 Acc: 0.5230
Epoch 3/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.85it/s]


Accuracy for all classes:  {'abhori1': 0.76, 'afghor1': 0.43859649122807015, 'afrjac1': 0.0, 'blakit1': 0.9330143540669856, 'brican1': 0.0, 'carcha1': 0.8032786885245902, 'chespa1': 0.0, 'chibat1': 0.047619047619047616, 'chtapa3': 0.0, 'colsun2': 0.8680555555555556, 'combul2': 0.7777777777777778, 'gobbun1': 0.453125, 'gobwea1': 0.0, 'grbcam1': 0.10666666666666667, 'grewoo2': 0.6829268292682927, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.0, 'reccuc1': 0.8762886597938144, 'slbgre1': 0.0, 'somgre1': 0.8490566037735849, 'spewea1': 0.0, 'spwlap1': 0.2978723404255319, 'vilwea1': 0.08571428571428572, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.62295082 0.41666667 0.         0.75581395 0.         0.73134328
 0.         0.17647059 0.         0.6127451  0.66181818 0.42647059
 0.         0.13333333 0.29946524 0.         0.         0.
 0.         0.         0.80188679 0.         0.80357143 0.
 0.20289855 0.15    

100%|██████████| 19/19 [00:02<00:00,  7.13it/s]


Val Loss: 1.8534 Acc: 0.5711
Epoch 4/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.82it/s]


Accuracy for all classes:  {'abhori1': 0.83, 'afghor1': 0.5964912280701754, 'afrjac1': 0.0, 'blakit1': 0.8947368421052632, 'brican1': 0.0, 'carcha1': 0.8278688524590164, 'chespa1': 0.0, 'chibat1': 0.2222222222222222, 'chtapa3': 0.0, 'colsun2': 0.8402777777777778, 'combul2': 0.8376068376068376, 'gobbun1': 0.578125, 'gobwea1': 0.0, 'grbcam1': 0.4, 'grewoo2': 0.6829268292682927, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.02702702702702703, 'reccuc1': 0.9484536082474226, 'slbgre1': 0.0, 'somgre1': 0.8238993710691824, 'spewea1': 0.0, 'spwlap1': 0.40425531914893614, 'vilwea1': 0.2857142857142857, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.77570093 0.49275362 0.         0.83856502 0.         0.7890625
 0.         0.31111111 0.         0.57619048 0.75675676 0.48684211
 0.         0.36585366 0.42748092 0.         0.         0.
 0.         0.5        0.85981308 0.         0.75287356 0.
 0.3877551  0.18867925 0

100%|██████████| 19/19 [00:02<00:00,  6.96it/s]


Val Loss: 1.3613 Acc: 0.6280
Epoch 5/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.84it/s]


Accuracy for all classes:  {'abhori1': 0.87, 'afghor1': 0.6491228070175439, 'afrjac1': 0.0, 'blakit1': 0.9617224880382775, 'brican1': 0.08695652173913043, 'carcha1': 0.9098360655737705, 'chespa1': 0.0, 'chibat1': 0.4126984126984127, 'chtapa3': 0.0, 'colsun2': 0.9375, 'combul2': 0.8675213675213675, 'gobbun1': 0.78125, 'gobwea1': 0.0, 'grbcam1': 0.6133333333333333, 'grewoo2': 0.6951219512195121, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.02702702702702703, 'reccuc1': 0.9484536082474226, 'slbgre1': 0.0, 'somgre1': 0.9182389937106918, 'spewea1': 0.08, 'spwlap1': 0.723404255319149, 'vilwea1': 0.6142857142857143, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.71900826 0.62711864 0.         0.88546256 0.66666667 0.86046512
 0.         0.41935484 0.         0.80357143 0.79607843 0.58823529
 0.         0.56790123 0.55339806 0.         0.         0.
 0.         0.2        0.88461538 0.         0.88484848 0.6666666

100%|██████████| 19/19 [00:02<00:00,  6.76it/s]


Val Loss: 1.4374 Acc: 0.6018
Epoch 6/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.83it/s]


Accuracy for all classes:  {'abhori1': 0.92, 'afghor1': 0.8771929824561403, 'afrjac1': 0.125, 'blakit1': 0.9617224880382775, 'brican1': 0.43478260869565216, 'carcha1': 0.9508196721311475, 'chespa1': 0.0, 'chibat1': 0.6190476190476191, 'chtapa3': 0.38095238095238093, 'colsun2': 0.9236111111111112, 'combul2': 0.9145299145299145, 'gobbun1': 0.828125, 'gobwea1': 0.0, 'grbcam1': 0.7333333333333333, 'grewoo2': 0.7439024390243902, 'norpuf1': 0.0, 'nubwoo1': 0.0, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.10810810810810811, 'reccuc1': 0.9690721649484536, 'slbgre1': 0.0, 'somgre1': 0.9811320754716981, 'spewea1': 0.24, 'spwlap1': 0.8085106382978723, 'vilwea1': 0.7142857142857143, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.85981308 0.76923077 0.15       0.93488372 0.5        0.94308943
 0.         0.45348837 0.5        0.93006993 0.85943775 0.6091954
 0.         0.73333333 0.72619048 0.         0.         0.
 0.         0.12903226 0.96907216 

100%|██████████| 19/19 [00:02<00:00,  6.94it/s]


Val Loss: 1.2178 Acc: 0.6674
Epoch 7/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.85it/s]


Accuracy for all classes:  {'abhori1': 0.97, 'afghor1': 0.8070175438596491, 'afrjac1': 0.4166666666666667, 'blakit1': 0.9665071770334929, 'brican1': 0.43478260869565216, 'carcha1': 0.9344262295081968, 'chespa1': 0.0, 'chibat1': 0.8095238095238095, 'chtapa3': 0.2857142857142857, 'colsun2': 0.9305555555555556, 'combul2': 0.9444444444444444, 'gobbun1': 0.890625, 'gobwea1': 0.0, 'grbcam1': 0.8, 'grewoo2': 0.8780487804878049, 'norpuf1': 0.0, 'nubwoo1': 0.11764705882352941, 'palfly2': 0.0, 'pygbat1': 0.0, 'rebfir2': 0.2972972972972973, 'reccuc1': 0.9690721649484536, 'slbgre1': 0.0, 'somgre1': 0.9433962264150944, 'spewea1': 0.32, 'spwlap1': 0.8936170212765957, 'vilwea1': 0.8, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.92380952 0.82142857 0.52631579 0.93953488 0.43478261 0.95798319
 0.         0.47222222 0.42857143 0.88157895 0.8875502  0.7125
 0.         0.83333333 0.76595745 0.         1.         0.
 0.         0.32352941 0.97916667 0.    

100%|██████████| 19/19 [00:02<00:00,  6.94it/s]


Val Loss: 1.6580 Acc: 0.6324
Epoch 8/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.81it/s]


Accuracy for all classes:  {'abhori1': 0.9, 'afghor1': 0.9122807017543859, 'afrjac1': 0.4166666666666667, 'blakit1': 0.9712918660287081, 'brican1': 0.391304347826087, 'carcha1': 0.9672131147540983, 'chespa1': 0.0, 'chibat1': 0.7777777777777778, 'chtapa3': 0.6666666666666666, 'colsun2': 0.9375, 'combul2': 0.9017094017094017, 'gobbun1': 0.875, 'gobwea1': 0.0, 'grbcam1': 0.7466666666666667, 'grewoo2': 0.8536585365853658, 'norpuf1': 0.0625, 'nubwoo1': 0.17647058823529413, 'palfly2': 0.0, 'pygbat1': 0.09090909090909091, 'rebfir2': 0.1891891891891892, 'reccuc1': 0.979381443298969, 'slbgre1': 0.0, 'somgre1': 0.9622641509433962, 'spewea1': 0.28, 'spwlap1': 0.7872340425531915, 'vilwea1': 0.7571428571428571, 'whbcan1': 0.0, 'whbtit5': 0.05555555555555555, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.87378641 0.83870968 0.37037037 0.96208531 0.27272727 0.90076336
 0.         0.58333333 0.58333333 0.90604027 0.87190083 0.7
 0.         0.73684211 0.76923077 0.5        0.27272727 

100%|██████████| 19/19 [00:02<00:00,  7.07it/s]


Val Loss: 1.2894 Acc: 0.6674
Epoch 9/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.86it/s]


Accuracy for all classes:  {'abhori1': 0.94, 'afghor1': 0.9122807017543859, 'afrjac1': 0.625, 'blakit1': 0.9617224880382775, 'brican1': 0.30434782608695654, 'carcha1': 0.9508196721311475, 'chespa1': 0.0, 'chibat1': 0.8571428571428571, 'chtapa3': 0.5714285714285714, 'colsun2': 0.9722222222222222, 'combul2': 0.9786324786324786, 'gobbun1': 0.921875, 'gobwea1': 0.0, 'grbcam1': 0.84, 'grewoo2': 0.8780487804878049, 'norpuf1': 0.0, 'nubwoo1': 0.29411764705882354, 'palfly2': 0.0, 'pygbat1': 0.09090909090909091, 'rebfir2': 0.6216216216216216, 'reccuc1': 0.9690721649484536, 'slbgre1': 0.2222222222222222, 'somgre1': 0.949685534591195, 'spewea1': 0.32, 'spwlap1': 0.9148936170212766, 'vilwea1': 0.9428571428571428, 'whbcan1': 0.0, 'whbtit5': 0.0, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.94       0.78787879 0.51724138 0.96172249 0.38888889 0.95081967
 0.         0.65060241 0.6        0.93959732 0.94238683 0.89393939
 0.         0.85135135 0.81818182 0.         0.45454545 0.
 0.

100%|██████████| 19/19 [00:02<00:00,  7.05it/s]


Val Loss: 1.1724 Acc: 0.6783
Epoch 10/10
----------


100%|██████████| 71/71 [00:14<00:00,  4.83it/s]


Accuracy for all classes:  {'abhori1': 0.94, 'afghor1': 0.8596491228070176, 'afrjac1': 0.6666666666666666, 'blakit1': 0.9808612440191388, 'brican1': 0.7391304347826086, 'carcha1': 0.9672131147540983, 'chespa1': 0.0, 'chibat1': 0.9365079365079365, 'chtapa3': 0.6190476190476191, 'colsun2': 0.9930555555555556, 'combul2': 0.9401709401709402, 'gobbun1': 0.96875, 'gobwea1': 0.0, 'grbcam1': 0.88, 'grewoo2': 0.9146341463414634, 'norpuf1': 0.25, 'nubwoo1': 0.29411764705882354, 'palfly2': 0.0, 'pygbat1': 0.36363636363636365, 'rebfir2': 0.7027027027027027, 'reccuc1': 0.9587628865979382, 'slbgre1': 0.05555555555555555, 'somgre1': 0.949685534591195, 'spewea1': 0.4, 'spwlap1': 0.9787234042553191, 'vilwea1': 0.9142857142857143, 'whbcan1': 0.0, 'whbtit5': 0.2222222222222222, 'whctur2': 0.0, 'yebsto1': 0.0}

Precision for all classes:  [0.93069307 0.875      0.57142857 0.98557692 0.56666667 0.93650794
 0.         0.67816092 0.59090909 0.9862069  0.93220339 0.87323944
 0.         0.89189189 0.91463415 0

100%|██████████| 19/19 [00:02<00:00,  6.70it/s]


Val Loss: 1.3739 Acc: 0.6849
