-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_utils.py
126 lines (102 loc) · 4.67 KB
/
dataset_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
import copy
from random import shuffle
import tqdm.notebook as tqdm
import sklearn
from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.metrics import classification_report
from PIL import Image
import cv2
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import itertools
# Root of the training split; each class lives in its own sub-directory.
DATA_PATH = './COVID19-DATASET/train'
# Per-class image directories under DATA_PATH.
COVID_PATH = DATA_PATH + '/covid19'
NORMAL_PATH = DATA_PATH + '/normal'
def smaple_ploter(dir_Path, title, num_images=4):
    """Display up to ``num_images`` randomly chosen images from a directory.

    The (misspelled) function name is kept so existing callers keep working.

    Args:
        dir_Path: Directory whose entries are loaded with ``cv2.imread``.
        title: Title drawn above the row of images.
        num_images: Maximum number of images shown in a single row (default 4).

    OpenCV decodes images in BGR channel order while matplotlib expects RGB,
    so each image is converted before display; without this, red and blue
    are swapped. Entries that ``cv2.imread`` cannot decode (it returns None
    instead of raising) are skipped, so a stray non-image file does not
    abort the plot.
    """
    fig = plt.figure(figsize=(16, 5))
    fig.suptitle(title, size=22)
    img_paths = os.listdir(dir_Path)
    shuffle(img_paths)
    shown = 0
    for image in img_paths:
        if shown >= num_images:
            break
        img = cv2.imread(os.path.join(dir_Path, image))
        if img is None:  # sub-directory or undecodable file
            continue
        shown += 1
        # Bind the axes to this figure explicitly rather than to whatever
        # figure pyplot currently considers "current".
        ax = fig.add_subplot(1, num_images, shown, frameon=False)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    fig.show()
# Per-channel (R, G, B) mean and standard deviation of the ImageNet data,
# used to normalise inputs the same way as ImageNet-pretrained models.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Preprocessing pipelines keyed by split name.
data_transforms = {}

# Training: fixed-size resize followed by light random augmentation.
data_transforms["train"] = transforms.Compose([
    transforms.Resize((150, 150)),           # force every image to 150x150
    transforms.RandomRotation(10),           # random rotation within +/-10 degrees
    transforms.RandomHorizontalFlip(p=0.4),  # horizontal flip with probability 0.4
    transforms.ToTensor(),                   # PIL image -> tensor
    transforms.Normalize(mean=mean_nums, std=std_nums),  # per-channel normalisation
])

# Validation: deterministic resize only (no augmentation).
data_transforms["val"] = transforms.Compose([
    transforms.Resize((150, 150)),
    # NOTE: after the 150x150 resize above this centre crop is a no-op;
    # kept so the pipeline is unchanged.
    transforms.CenterCrop(150),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_nums, std=std_nums),
])
def load_split_train_test(datadir, valid_size=.2, *, batch_size=8, seed=None):
    """Build train and validation DataLoaders from one image-folder directory.

    The directory is loaded twice with ``datasets.ImageFolder``: once with the
    ``data_transforms['train']`` pipeline and once with
    ``data_transforms['val']``. A single shuffled index list is then split so
    that each image index goes to exactly one of the two
    ``SubsetRandomSampler``s.

    Args:
        datadir: Root directory in the layout ``datasets.ImageFolder`` expects
            (one sub-directory per class).
        valid_size: Fraction of images held out for validation, in [0, 1].
        batch_size: Batch size used by both loaders (default 8).
        seed: Optional integer seed that makes the train/validation split
            reproducible. When None, NumPy's global RNG is used, exactly as
            before (so ``np.random.seed`` set by the caller still applies).

    Returns:
        A tuple ``(trainloader, testloader, dataset_size)`` where
        ``dataset_size`` is ``{"train": n_train, "val": n_val}``.

    Raises:
        ValueError: If ``valid_size`` is not within [0, 1].
    """
    # Out-of-range fractions silently produced an inverted or empty split.
    if not 0 <= valid_size <= 1:
        raise ValueError(f"valid_size must be in [0, 1], got {valid_size!r}")

    # Same files, two different preprocessing pipelines.
    train_data = datasets.ImageFolder(datadir, transform=data_transforms['train'])
    test_data = datasets.ImageFolder(datadir, transform=data_transforms['val'])

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)
    train_idx, test_idx = indices[split:], indices[:split]
    dataset_size = {"train": len(train_idx), "val": len(test_idx)}

    # Samplers restrict each loader to its own disjoint subset of indices.
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)
    trainloader = torch.utils.data.DataLoader(
        train_data, sampler=train_sampler, batch_size=batch_size)
    testloader = torch.utils.data.DataLoader(
        test_data, sampler=test_sampler, batch_size=batch_size)
    return trainloader, testloader, dataset_size
def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """Plot a confusion matrix as a heat map with per-cell value annotations.

    Args:
        cm: Square confusion matrix (array-like). Rows are plotted against the
            "True label" axis and columns against the "Predicted label" axis.
        target_names: Class names used as tick labels, or None for no labels.
        title: Figure title.
        cmap: Matplotlib colormap; defaults to ``'Blues'``.
        normalize: If True, each row is divided by its sum before plotting, so
            cells show the fraction of that row; rows summing to zero are shown
            as zeros instead of producing NaN.

    The overall accuracy (trace / total) and misclassification rate are always
    computed from the raw counts and shown under the x-axis.
    """
    cm = np.asarray(cm)
    total = cm.sum()
    accuracy = np.trace(cm) / float(total) if total else float('nan')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    # Normalise BEFORE drawing so the heat map, the colorbar and the
    # text-colour threshold below all refer to the same values.
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells darker than the threshold get white text for contrast.
    cm_max = cm.max() if cm.size else 0
    thresh = cm_max / 1.5 if normalize else cm_max / 2
    fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, fmt.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()