In [1]:
import numpy as np
import os
import pandas as pd
from skimage.io import imread, imsave
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn import svm
import torchvision.transforms as T
import torch
torch.manual_seed(42)

import warnings
warnings.filterwarnings('ignore')

from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
er_trdat_path = './data/train/ER/'
nr_trdat_path = './data/train/NR/'
cropdat_path = './data/train/all_cropped/'
balanced_path = './data/train/all_balanced/'

In [3]:
%%time
counts = dict()
i = 0
for filename in os.listdir(er_trdat_path):
    fn_parts = filename.split('_')
    ind =  fn_parts.index('ER')
    regr = fn_parts[ind+1]
    if regr not in {'1', '3', '6', '10', '20', '30'}:
        print(f'error fn_parts[7]')
        print(fn_parts)
        print(filename)
        break    
    img = imread(er_trdat_path + filename)    
    img = img[192:384, 192:384]
    img = resize(img, (64, 64))
    img = img*255
    img = img.astype(np.uint8)
    imsave(cropdat_path + f'{i}-ER-{regr}.png', img)
    counts[f'ER-{regr}'] = counts.get(f'ER-{regr}', 0) + 1
    i += 1
for filename in os.listdir(nr_trdat_path):
    fn_parts = filename.split('_')
    ind =  fn_parts.index('NR')
    regr = fn_parts[ind+1]
    if regr not in {'1', '3', '6', '10', '20', '30'}:
        print(f'error fn_parts[7]')
        print(fn_parts)
        print(filename)
        break    
    img = imread(nr_trdat_path + filename)
    img = img[192:384, 192:384]
    img = resize(img, (64, 64))
    img = img*255
    img = img.astype(np.uint8)
    imsave(cropdat_path + f'{i}-NR-{regr}.png', img)
    counts[f'NR-{regr}'] = counts.get(f'NR-{regr}', 0) + 1
    i += 1

Wall time: 59.6 s


In [4]:
counts

{'ER-30': 2237,
 'ER-3': 2243,
 'ER-10': 2272,
 'ER-20': 2,
 'ER-1': 3,
 'ER-6': 1,
 'NR-6': 2255,
 'NR-20': 2208,
 'NR-1': 2177,
 'NR-10': 2,
 'NR-3': 2,
 'NR-30': 2}

In [12]:
%%time
images = []
labels = []
bi_labels = []
for filename in os.listdir(cropdat_path):
    fn_parts = filename.split('-')
    label = fn_parts[1]+fn_parts[2].split('.')[0]
    img = imread(cropdat_path + filename)
    images.append(img.flatten())
    labels.append(label)
    bi_labels.append(fn_parts[1])
print(len(images))
print(len(labels))

13404
13404
Wall time: 4.18 s


In [7]:
unique, counts = np.unique(labels, return_counts=True)
dict(zip(unique, counts))

{'ER1': 3,
 'ER10': 2272,
 'ER20': 2,
 'ER3': 2243,
 'ER30': 2237,
 'ER6': 1,
 'NR1': 2177,
 'NR10': 2,
 'NR20': 2208,
 'NR3': 2,
 'NR30': 2,
 'NR6': 2255}

In [13]:
unique, counts = np.unique(bi_labels, return_counts=True)
dict(zip(unique, counts))

{'ER': 6758, 'NR': 6646}

### 6-class classification

In [9]:
%%time
train_images, test_images,train_labels, test_labels = train_test_split(images, labels, train_size=0.8, random_state=42)
clf = svm.SVC()
clf.fit(train_images, train_labels)
clf.score(test_images,test_labels)

Wall time: 54.2 s


0.9988810145468109

### binary classification

In [15]:
%%time
train_images, test_images,train_labels, test_labels = train_test_split(images, bi_labels, train_size=0.8, random_state=42)
clf = svm.SVC()
clf.fit(train_images, train_labels)
clf.score(test_images,test_labels)

Wall time: 59.9 s


0.9947780678851175

In [10]:
y_pred = clf.predict(test_images)

In [11]:
print(classification_report(test_labels, y_pred))

              precision    recall  f1-score   support

        ER10       1.00      1.00      1.00       448
         ER3       1.00      1.00      1.00       421
        ER30       1.00      1.00      1.00       459
         NR1       1.00      1.00      1.00       437
        NR20       1.00      1.00      1.00       451
         NR3       0.00      0.00      0.00         1
        NR30       0.00      0.00      0.00         1
         NR6       1.00      1.00      1.00       463

    accuracy                           1.00      2681
   macro avg       0.75      0.75      0.75      2681
weighted avg       1.00      1.00      1.00      2681



### balance classes

In [None]:
for filename in os.listdir('./data/train/small/'):
    img = imread('./data/train/small/' + filename) 
    fn_parts = filename.split('_')
    ind =  fn_parts.index('ER')
    regr = fn_parts[ind+1]       
    img = img[192:384, 192:384]
    img = resize(img, (64, 64))
    img = img*255
    img = img.astype(np.uint8)
    imsave(balanced_path + f'{i}-ER-{regr}.png', img)
    i += 1
    if regr in {'1', '6', '20'}:
        for

In [None]:
%%time
i = 0
for filename in os.listdir(er_trdat_path):
    fn_parts = filename.split('_')
    ind =  fn_parts.index('ER')
    regr = fn_parts[ind+1]
    if regr not in {'1', '3', '6', '10', '20', '30'}:
        print(f'error fn_parts[7]')
        print(fn_parts)
        print(filename)
        break    
    img = imread(er_trdat_path + filename)    
    img = img[192:384, 192:384]
    img = resize(img, (64, 64))
    img = img*255
    img = img.astype(np.uint8)
    imsave(cropdat_path + f'{i}-ER-{regr}.png', img)
    i += 1
for filename in os.listdir(nr_trdat_path):
    fn_parts = filename.split('_')
    ind =  fn_parts.index('NR')
    regr = fn_parts[ind+1]
    if regr not in {'1', '3', '6', '10', '20', '30'}:
        print(f'error fn_parts[7]')
        print(fn_parts)
        print(filename)
        break    
    img = imread(nr_trdat_path + filename)
    img = img[192:384, 192:384]
    img = resize(img, (64, 64))
    img = img*255
    img = img.astype(np.uint8)
    imsave(cropdat_path + f'{i}-NR-{regr}.png', img)
    i += 1