In [1]:
import os, sys
from glob import glob
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from torchvision.models import alexnet

from tqdm import tqdm
import numpy as np

import cv2 as cv

from fast_glcm import *

In [2]:
print(os.getcwd())

D:\GitCloneProject\CE-AB-CLS\OtherSOTA


In [3]:
os.chdir("..")

In [4]:
print(os.getcwd())

D:\GitCloneProject\CE-AB-CLS


In [5]:
data_dir = os.getcwd() + "\\Data set\\Original Form"
print(os.listdir(data_dir))

['Test', 'Train']


In [24]:
class CEDataset(Dataset):
    def __init__(self, root_dir = data_dir, subset = "Train"):
        self.root_dir = root_dir + f"\\{subset}"
        self.img_paths = glob(self.root_dir + "\\*\\*")
        self.transform = A.Compose([
            # A.RandomCrop(height = 256, width = 256),
            A.ColorJitter(p=0.5),
            A.GaussianBlur(p=0.5),
            A.RandomShadow(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Resize(height = 256, width = 256, p = 1),
        ])
        super().__init__()
    
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        path = self.img_paths[idx]
        img = cv.imread(path)
        
        if "Normal" in path:
            label = torch.tensor(0)
        elif "Abnormal" in path:
            label = torch.tensor(1)
        else:
            raise Exception("not shown type of label in path")
            
        img_torch = torch.from_numpy(np.array([fast_glcm_mean(img = img, levels = 8)/255]))        
        
        return img_torch, label  

In [25]:
train_dataset = CEDataset(subset = "Train")
test_dataset = CEDataset(subset = "Test")

print(len(train_dataset))
print(len(test_dataset))

176
176


In [26]:
sample_img, sample_label = train_dataset[0]
print(sample_img.shape)
print(sample_img.dtype)
print(sample_label.shape)

torch.Size([1, 481, 518])
torch.float32
torch.Size([])


In [27]:
print(train_dataset.img_paths[0])
print(test_dataset.img_paths[0])

D:\GitCloneProject\CE-AB-CLS\Data set\Original Form\Train\Group 1 - Normal\CHGastro_Normal_001.png
D:\GitCloneProject\CE-AB-CLS\Data set\Original Form\Test\Group 1 - Normal\CHGastro_Normal_002.png


In [28]:
train_dataloader = DataLoader(train_dataset, shuffle = True, batch_size = 32)
test_dataloader = DataLoader(test_dataset, shuffle = False, batch_size = 1)

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = alexnet(weights = 'IMAGENET1K_V1').to(device)
model.fc = nn.Linear(1000, 2)

In [30]:
model.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)).to(device)

In [31]:
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [32]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [33]:
size = len(train_dataloader)
model.train()
epochs = 100
for t in range(epochs):
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device, dtype = torch.float), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # accuracy
        correct = (pred.argmax(1) == y).type(torch.float).sum().item()
        acc = correct / y.shape[0]

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}] acc: {acc}")

loss: 7.471893  [   32/    6] acc: 0.0
loss: 0.707207  [   32/    6] acc: 0.6875
loss: 0.684324  [   32/    6] acc: 0.65625
loss: 0.675471  [   32/    6] acc: 0.53125
loss: 0.663816  [   32/    6] acc: 0.6875
loss: 0.713014  [   32/    6] acc: 0.4375
loss: 0.731688  [   32/    6] acc: 0.6875
loss: 0.618406  [   32/    6] acc: 0.71875
loss: 1.015067  [   32/    6] acc: 0.28125
loss: 0.803113  [   32/    6] acc: 0.625
loss: 0.692917  [   32/    6] acc: 0.625
loss: 0.710083  [   32/    6] acc: 0.65625
loss: 0.665321  [   32/    6] acc: 0.59375
loss: 0.600201  [   32/    6] acc: 0.75
loss: 0.678337  [   32/    6] acc: 0.625
loss: 0.641597  [   32/    6] acc: 0.6875
loss: 0.581192  [   32/    6] acc: 0.6875
loss: 0.649853  [   32/    6] acc: 0.65625
loss: 0.691753  [   32/    6] acc: 0.59375
loss: 0.609048  [   32/    6] acc: 0.71875
loss: 0.512113  [   32/    6] acc: 0.8125
loss: 0.665615  [   32/    6] acc: 0.65625
loss: 0.604881  [   32/    6] acc: 0.6875
loss: 0.651468  [   32/    6] ac

In [34]:
pred = []
true = []
for batch, (X, y) in enumerate(test_dataloader):
    X, y = X.to(device, dtype = torch.float), y.to(device)
    
    pred.append(model(X).argmax(1).item())
    true.append(y.item())

In [35]:
print(len(pred))
print(len(true))

176
176


In [36]:
from sklearn.metrics import classification_report

print(classification_report(true, pred, target_names=["Normal", "Abnormal"]))

              precision    recall  f1-score   support

      Normal       0.00      0.00      0.00        56
    Abnormal       0.68      1.00      0.81       120

    accuracy                           0.68       176
   macro avg       0.34      0.50      0.41       176
weighted avg       0.46      0.68      0.55       176



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
