In [1]:
import os, sys
from glob import glob
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from torchvision.models import alexnet

from tqdm import tqdm
import numpy as np

import cv2 as cv

from fast_glcm import *

In [2]:
print(os.getcwd())

D:\GitCloneProject\CE-AB-CLS\OtherSOTA


In [3]:
os.chdir("..")

In [4]:
print(os.getcwd())

D:\GitCloneProject\CE-AB-CLS


In [5]:
data_dir = os.getcwd() + "\\Data set\\Original Form"
print(os.listdir(data_dir))

['Test', 'Train']


In [6]:
class CEDataset(Dataset):
    def __init__(self, root_dir = data_dir, subset = "Train"):
        self.root_dir = root_dir + f"\\{subset}"
        self.img_paths = glob(self.root_dir + "\\*\\*")
        self.transform = A.Compose([
            # A.RandomCrop(height = 256, width = 256),
            A.ColorJitter(p=0.5),
            A.GaussianBlur(p=0.5),
            A.RandomShadow(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Resize(height = 256, width = 256, p = 1),
        ])
        super().__init__()
    
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        path = self.img_paths[idx]
        img = cv.imread(path)
        
        if "Normal" in path:
            label = torch.tensor(0)
        elif "Abnormal" in path:
            label = torch.tensor(1)
        else:
            raise Exception("not shown type of label in path")
            
        img_torch = torch.from_numpy(fast_glcm(img = img, levels = 3)/255)
        zeros = torch.zeros(size = img_torch.shape)
        
        return torch.cat([img_torch[i] for i in range(img_torch.shape[0])], dim = 0), label  

In [7]:
train_dataset = CEDataset(subset = "Train")
test_dataset = CEDataset(subset = "Test")

print(len(train_dataset))
print(len(test_dataset))

176
176


In [8]:
sample_img, sample_label = train_dataset[0]
print(sample_img.shape)
print(sample_img.dtype)
print(sample_label.shape)

torch.Size([64, 481, 518])
torch.float32
torch.Size([])


In [9]:
print(train_dataset.img_paths[0])
print(test_dataset.img_paths[0])

D:\GitCloneProject\CE-AB-CLS\Data set\Original Form\Train\Group 1 - Normal\CHGastro_Normal_001.png
D:\GitCloneProject\CE-AB-CLS\Data set\Original Form\Test\Group 1 - Normal\CHGastro_Normal_002.png


In [10]:
train_dataloader = DataLoader(train_dataset, shuffle = True, batch_size = 32)
test_dataloader = DataLoader(test_dataset, shuffle = False, batch_size = 1)

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = alexnet(weights = 'IMAGENET1K_V1').to(device)
model.fc = nn.Linear(1000, 2)

In [12]:
model.features[0] = nn.Conv2d(64, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)).to(device)

In [13]:
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(64, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)


In [14]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [15]:
size = len(train_dataloader)
model.train()
epochs = 100
for t in range(epochs):
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device, dtype = torch.float), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # accuracy
        correct = (pred.argmax(1) == y).type(torch.float).sum().item()
        acc = correct / y.shape[0]

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}] acc: {acc}")

loss: 7.574877  [   32/    6] acc: 0.0
loss: 0.732704  [   32/    6] acc: 0.375
loss: 0.720123  [   32/    6] acc: 0.5
loss: 0.453377  [   32/    6] acc: 0.84375
loss: 0.647923  [   32/    6] acc: 0.6875
loss: 0.619065  [   32/    6] acc: 0.6875
loss: 0.583602  [   32/    6] acc: 0.71875
loss: 0.711852  [   32/    6] acc: 0.65625
loss: 0.579444  [   32/    6] acc: 0.78125
loss: 0.405811  [   32/    6] acc: 0.8125
loss: 0.484352  [   32/    6] acc: 0.71875
loss: 0.632154  [   32/    6] acc: 0.71875
loss: 0.660018  [   32/    6] acc: 0.6875
loss: 0.386452  [   32/    6] acc: 0.875
loss: 0.452078  [   32/    6] acc: 0.75
loss: 0.491462  [   32/    6] acc: 0.6875
loss: 0.466789  [   32/    6] acc: 0.78125
loss: 0.786605  [   32/    6] acc: 0.5625
loss: 0.573017  [   32/    6] acc: 0.6875
loss: 0.365475  [   32/    6] acc: 0.8125
loss: 0.578595  [   32/    6] acc: 0.65625
loss: 0.423116  [   32/    6] acc: 0.84375
loss: 0.500769  [   32/    6] acc: 0.78125
loss: 0.514940  [   32/    6] acc:

In [16]:
pred = []
true = []
for batch, (X, y) in enumerate(test_dataloader):
    X, y = X.to(device, dtype = torch.float), y.to(device)
    
    pred.append(model(X).argmax(1).item())
    true.append(y.item())

In [17]:
print(len(pred))
print(len(true))

176
176


In [18]:
from sklearn.metrics import classification_report

print(classification_report(true, pred, target_names=["Normal", "Abnormal"]))

              precision    recall  f1-score   support

      Normal       0.66      0.48      0.56        56
    Abnormal       0.79      0.88      0.83       120

    accuracy                           0.76       176
   macro avg       0.72      0.68      0.69       176
weighted avg       0.74      0.76      0.74       176

