In [None]:
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import v2

In [None]:
# Root folder holding the train/val/test/auto_test ImageFolder splits.
DATASET_PATH = "./knee-osteoarthritis"

In [None]:
# One ImageFolder directory per split, all under DATASET_PATH.
TRAIN_PATH, VAL_PATH, TEST_PATH, AUTO_TEST_PATH = (
    f'{DATASET_PATH}/{split}' for split in ('train', 'val', 'test', 'auto_test')
)

In [None]:
# PIL image -> float tensor laid out (C, H, W) with values scaled to [0, 1].
transform_toTensor = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
# Only the training split is loaded for now; the other splits follow the same pattern.
train = torchvision.datasets.ImageFolder(root=TRAIN_PATH, transform=transform_toTensor)
# val = torchvision.datasets.ImageFolder(root=VAL_PATH, transform=transform_toTensor)
# test = torchvision.datasets.ImageFolder(root=TEST_PATH, transform=transform_toTensor)
# auto_test = torchvision.datasets.ImageFolder(root=AUTO_TEST_PATH, transform=transform_toTensor)

In [None]:
# Number of images found in the training split.
train_size = len(train)
print(train_size)
# Other splits are not loaded yet:
# print(len(val)); print(len(test)); print(len(auto_test))

### Augmentation 

In [None]:
import cv2
import numpy as np

#### Edges

In [None]:
# Vertical-gradient kernels. Despite the "sobel" name these are Prewitt-style
# (uniform row weights, no 2x centre weight). sobel_y_2 is the negation of
# sobel_y_1, i.e. it responds to the opposite edge polarity.
sobel_y_1 = np.array([
    [-1, -1, -1],
    [0, 0, 0],
    [1, 1, 1],
])
sobel_y_2 = sobel_y_1 * -1

def getAugmentationEdges(image):
    """Return a contrast-stretched 8-bit map of horizontal edges in ``image``.

    The image is converted to grayscale, filtered with ``sobel_y_1`` and
    ``sobel_y_2`` (opposite polarities, so both dark-to-bright and
    bright-to-dark transitions are kept), responses <= 10 are zeroed, the two
    maps are summed and finally rescaled so the strongest edge becomes 255.

    Args:
        image: 3-channel image accepted by ``cv2.cvtColor``; channel order is
            treated as BGR.

    Returns:
        ``uint8`` array with the same height/width as ``image``. All zeros
        when no response survives the threshold (e.g. a flat image).
    """
    grayscaled = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply two filters, upper and lower for different bones
    edges_1 = cv2.filter2D(grayscaled, -1, sobel_y_1)
    edges_2 = cv2.filter2D(grayscaled, -1, sobel_y_2)

    # Cut out noisy background: THRESH_TOZERO zeroes every response <= 10
    _, edges_1 = cv2.threshold(edges_1, 10, 255, cv2.THRESH_TOZERO)
    _, edges_2 = cv2.threshold(edges_2, 10, 255, cv2.THRESH_TOZERO)

    edges = edges_1 + edges_2

    max_brightness = edges.max()
    if max_brightness <= 0:
        # Nothing survived thresholding: return an all-black map instead of
        # scaling by 255 / 0 (inf/NaN alpha -> undefined uint8 conversion).
        return np.zeros(edges.shape, dtype=np.uint8)

    # Normalize color: stretch so the strongest edge maps to 255 (uint8 output)
    return cv2.convertScaleAbs(edges, alpha=255 / max_brightness, beta=0)

### Building Dataset

In [None]:
# Inspect the tensor layout of the first training image (ToTensor yields (C, H, W)).
first_sample = train[0]
example_img = first_sample[0]
print(example_img.shape)

In [None]:
# RGB branch: input is the ToTensor output (float, (C, H, W), values in [0, 1]).
# Scale the shorter side to 256, centre-crop 256 x 256 and normalize with the
# standard ImageNet statistics expected by the pretrained ResNet.
transform_baseImage = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(256),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Edge branch: input is a (1, H, W) uint8 edge map with values in [0, 255].
# scale=True maps it to [0, 1] *before* normalizing. The mean/std below are on
# the [0, 1] scale; without scaling, the normalized values would span roughly
# [-2, 1126] instead of a unit-ish range.
transform_edgesImage = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    # NOTE(review): unlike the RGB branch, edge maps are not resized/cropped and
    # keep the source resolution - confirm this is intended.
    # v2.Resize(256),
    # v2.CenterCrop(256),
    v2.Normalize(mean=[0.449], std=[0.226]),
])

In [None]:

from torch.utils.data import Dataset, DataLoader

class KneeOsteoarthritis(Dataset):
    """In-memory dataset yielding ``(image, edges_image, label)`` triples.

    Every sample of ``dataset`` is preprocessed once, in ``__init__``:

    * ``image`` is passed through ``transform_baseImage``;
    * ``edges_image`` is the ``getAugmentationEdges`` map of the same picture,
      given a leading channel dimension and passed through
      ``transform_edgesImage``.

    Args:
        dataset: iterable of ``(image, label)`` pairs, where ``image`` is a
            float ``(C, H, W)`` tensor in ``[0, 1]`` (e.g. ``ImageFolder`` with
            ``ToTensor``). Assumes C == 3 in RGB order - TODO confirm for other
            sources.
    """

    def __init__(self, dataset):
        self.images = []
        self.edges_images = []
        self.labels = []

        for image, label in dataset:
            self.images.append(transform_baseImage(image))
            self.edges_images.append(self._edges_from(image))
            self.labels.append(label)

    @staticmethod
    def _edges_from(image):
        """Compute the normalized ``(1, H, W)`` edge tensor for a ``(C, H, W)`` image in [0, 1]."""
        # OpenCV works on H x W x C arrays with pixel values in [0, 255].
        image_np = np.moveaxis(image.numpy() * 255, 0, -1)
        # torchvision/PIL images are RGB, but getAugmentationEdges converts with
        # COLOR_BGR2GRAY; flip to BGR so the grayscale weights hit the right
        # channels (no effect when R == G == B).
        image_bgr = np.ascontiguousarray(image_np[..., ::-1])
        edges = torch.tensor(getAugmentationEdges(image_bgr))
        return transform_edgesImage(edges.unsqueeze(0))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the precomputed ``(image, edges_image, label)`` at ``idx``."""
        return self.images[idx], self.edges_images[idx], self.labels[idx]

In [None]:
# Preprocess every training sample once, up front (edge extraction + transforms).
train_dataset = KneeOsteoarthritis(dataset=train)
# val_dataset = KneeOsteoarthritis(dataset=val)
# test_dataset = KneeOsteoarthritis(dataset=test)
# auto_test_dataset = KneeOsteoarthritis(dataset=auto_test)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.show()

In [None]:

# Show one preprocessed sample next to its edge map.
row = train_dataset[4]
normal_ex, augmented_ex = row[0], row[1]
print(normal_ex.shape, augmented_ex.shape)

for shown in (normal_ex, augmented_ex):
    imshow(shown)

### Configuring loader

In [None]:
from collections import Counter

# Count training samples per class label (Counter keeps first-seen order).
freq_table = {label: count for label, count in Counter(train_dataset.labels).items()}
least_class_frequency = min(freq_table.values())

print(freq_table, least_class_frequency, [*freq_table.values()])

In [None]:
# Inverse-frequency sampling: each sample is weighted by 1 / (size of its class),
# so every class has the same total weight and is drawn about equally often.

# list() is required: np.array(dict_values) builds a 0-d object array, not counts.
class_sample_count = np.array(list(freq_table.values()))
print(class_sample_count)

weights = np.array(
    [1 / freq_table[label] for label in train_dataset.labels],
    dtype=np.float64,
)
print(weights)

samples_weight = torch.from_numpy(weights).double()
samples_weigth = samples_weight  # backward-compatible alias for the original misspelled name
sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))

In [None]:
# Same tensor as the misspelled `samples_weigth` (`.double()` on a float64 tensor returns it unchanged).
print(samples_weight)

In [None]:
BATCH_SIZE = 16
# The weighted sampler takes the place of shuffle=True: it decides which samples make up each epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
)

### Building Model

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(device)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, ResNet18_Weights

class AugmentedModel(nn.Module):
    """Knee-OA classifier combining an edge-map CNN with an ImageNet ResNet18.

    All sub-modules are always built, so ``state_dict`` keys do not depend on ``mode``:

    * ``edgesClassifier`` - AlexNet-style conv stack + MLP on a 1-channel edge
      map, producing ``num_classes`` logits.
    * ``resnet18`` (pretrained, 1000-way output) followed by ``classifier`` -
      MLP head mapping those 1000 outputs to ``num_classes`` logits.
    * ``outputCombiner`` - linear layer mapping the concatenated logits of both
      branches (``2 * num_classes``) to ``num_classes``.

    Args:
        num_classes: number of output classes.
        dropout: dropout probability used in the fully-connected heads.
        mode: which path ``forward`` runs:
            ``"edges"`` (default, original behaviour) - edge branch only;
            ``"image"`` - ``resnet18`` + ``classifier`` only;
            ``"fused"`` - both branches, combined by ``outputCombiner``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """

    MODES = ("edges", "image", "fused")

    def __init__(self, num_classes: int = 5, dropout: float = 0.5, mode: str = "edges") -> None:
        super().__init__()
        if mode not in self.MODES:
            raise ValueError(f"mode must be one of {self.MODES}, got {mode!r}")
        self.mode = mode

        # Pretrained weights; torchvision may download them on first use.
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False)

        self.edgesClassifier = nn.Sequential(
            # Feature extractor on the 1-channel edge map.
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Fixed 6x6 grid so the head below works for any input resolution.
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Classification head.
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

        # Head on top of the 1000-way ResNet18 output.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(1000, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

        self.outputCombiner = nn.Sequential(
            nn.Linear(2 * num_classes, num_classes),
        )

    def forward(self, image: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, num_classes)`` logits for the configured ``mode``.

        Args:
            image: normalized RGB batch for ``resnet18``, assumed (N, 3, H, W);
                ignored when ``mode == "edges"``.
            edges: edge-map batch, assumed (N, 1, H, W); ignored when
                ``mode == "image"``.
        """
        if self.mode == "edges":
            return self.edgesClassifier(edges)

        out_image = self.classifier(self.resnet18(image))
        if self.mode == "image":
            return out_image

        out_edges = self.edgesClassifier(edges)
        # Image logits first, then edge logits.
        return self.outputCombiner(torch.cat((out_image, out_edges), 1))
      
# Three target classes; Module.to() moves the parameters and returns the same module.
net = AugmentedModel(num_classes=3).to(device)

In [None]:
# Parameter counts: image head, edge branch, then the whole model (incl. ResNet18).
classifier_params, edges_params, total_params = (
    sum(param.numel() for param in module.parameters())
    for module in (net.classifier, net.edgesClassifier, net)
)
print(classifier_params, edges_params)
print(total_params)

### Training Model

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# Optimize the whole model. forward() currently returns edgesClassifier's output,
# so optimizing only net.classifier (as before) left the trained path frozen and
# the loss could never improve. Parameters forward() does not reach get no
# gradient, and Adam skips them.
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [None]:
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    net.train()
    epoch_correct = 0
    epoch_samples = 0
    running_loss = 0.0

    for images, edges, labels in train_loader:
        images = images.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(images, edges)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss returns the batch *mean*; weight by batch size so the
        # epoch figure below is a true per-sample average.
        running_loss += loss.item() * batch_size
        # .item() keeps the counters as plain Python numbers (not device tensors).
        epoch_correct += (outputs.argmax(dim=1) == labels).sum().item()
        epoch_samples += batch_size

    seen = max(epoch_samples, 1)  # guard against an empty loader
    accuracy = epoch_correct / seen * 100
    print(f'Epoch {epoch + 1}: loss: {running_loss / seen:.3f}, accuracy: {accuracy:.2f}%')

print('Finished Training')