In [None]:
import torch
import torchvision
from torchvision import transforms

In [None]:
# Root directory of the dataset. The split paths below are built from it, and
# each split directory is read with torchvision's ImageFolder.
DATASET_PATH = "./knee-osteoarthritis"

In [None]:
# One directory per data split, all rooted at DATASET_PATH.
TRAIN_PATH, VAL_PATH, TEST_PATH, AUTO_TEST_PATH = (
    f"{DATASET_PATH}/{split_name}"
    for split_name in ("train", "val", "test", "auto_test")
)

In [None]:
# Preprocessing shared by every split:
# resize (shorter side to 256 px) -> 256x256 centre crop -> tensor ->
# per-channel normalisation with the ImageNet statistics that match the
# ImageNet-pretrained ResNet-18 weights used later in the notebook.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

TRANSFORM_IMG = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

In [None]:

# Wrap each split directory as an ImageFolder (class labels come from the
# sub-folder names); all splits share the same preprocessing pipeline.
train, val, test = (
    torchvision.datasets.ImageFolder(split_dir, TRANSFORM_IMG)
    for split_dir in (TRAIN_PATH, VAL_PATH, TEST_PATH)
)
# The auto_test split is currently disabled:
# auto_test = torchvision.datasets.ImageFolder(AUTO_TEST_PATH, TRANSFORM_IMG)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Number of samples in each split, one per line (auto_test is disabled above).
for split in (train, val, test):
    print(len(split))

In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader

class KneeOsteoarthritis(Dataset):
    """Eagerly cached copy of an indexable ``(image, label)`` dataset.

    Every sample of ``dataset`` is read, and therefore transformed, exactly once
    at construction time. Images and labels are stored in two parallel lists,
    so ``labels`` can be inspected without touching the images.

    NOTE(review): all images are held in memory at once, so memory grows
    linearly with the dataset size.
    """

    def __init__(self, dataset):
        # Single pass over the source. Element [0] of each sample is taken as
        # the image and element [1] as the label.
        samples = [(sample[0], sample[1]) for sample in dataset]
        self.images = [image for image, _ in samples]
        self.labels = [label for _, label in samples]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [None]:
# Materialise the preprocessed training split in memory. The val/test/auto_test
# splits are not wrapped yet.
train_dataset = KneeOsteoarthritis(dataset=train)

In [None]:
# train_dataset[0]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised CHW image tensor with matplotlib.

    Undoes the per-channel ``Normalize(mean, std)`` step of the input pipeline
    (``x * std + mean``) so the image is shown with its original intensities.
    The previous ``img / 2 + 0.5`` only inverts ``Normalize(0.5, 0.5)``, which
    is not what this notebook applies.

    Args:
        img: float tensor laid out as (C, H, W), as produced by the normalising
            transform. It may live on any device.
        mean: per-channel mean used for normalisation. The default matches the
            values passed to ``transforms.Normalize`` in this notebook.
        std: per-channel std used for normalisation (same defaults).
    """
    img = img.detach().cpu()
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Invert Normalize, then clip to [0, 1], the valid range for float RGB in imshow.
    img = (img * std_t + mean_t).clamp(0.0, 1.0)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

In [None]:
# Sanity check: display the first preprocessed training image.
sample_image, sample_label = train_dataset[0]
imshow(sample_image)

In [None]:
from collections import Counter

# Training samples per class label, plus the size of the rarest class.
freq_table = dict(Counter(train_dataset.labels))
per_class_counts = list(freq_table.values())
least_class_frequency = min(per_class_counts)

print(freq_table, least_class_frequency, per_class_counts)

In [None]:
# Per-sample weights for class-balanced sampling. Each sample gets the inverse
# of its class's frequency, so the weights of every class sum to 1 and all
# classes are drawn equally often in expectation.
# np.array(dict.values()) would build a 0-d object array, so materialise a list first.
class_sample_count = np.array(list(freq_table.values()))
print(class_sample_count)
weights = np.array(
    [1.0 / freq_table[label] for label in train_dataset.labels],
    dtype=np.float64,
)

print(weights)
samples_weight = torch.from_numpy(weights).double()
# Misspelled alias retained because a later cell prints `samples_weigth`.
samples_weigth = samples_weight
# Draws len(samples_weight) indices per pass, with replacement (sampler default).
sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))

In [None]:
# Inspect the per-sample sampling weights (correctly spelled variable name).
print(samples_weight)

In [None]:
# Mini-batches of 16 drawn through the class-balancing sampler. `shuffle` is
# left unset because DataLoader rejects a sampler combined with shuffle=True.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=16,
    sampler=sampler,
)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, ResNet18_Weights

class ResnetTL(nn.Module):
    """ImageNet-pretrained ResNet-18 followed by an MLP classification head.

    The backbone's own ``fc`` layer is kept, so the head consumes the
    backbone's ``fc.out_features`` outputs rather than pooled features.

    Args:
        num_classes: number of output logits produced by the head.
        dropout: dropout probability used inside the head.
        freeze_backbone: if True, every ResNet-18 parameter gets
            ``requires_grad=False`` so backward stops at the head. Defaults to
            False to keep the previous behaviour for existing callers.
            NOTE(review): this does not switch BatchNorm layers to eval mode;
            their running statistics still update whenever the module is in
            train mode.
    """

    def __init__(self, num_classes: int = 5, dropout: float = 0.5,
                 freeze_backbone: bool = False) -> None:
        super().__init__()

        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False)
        if freeze_backbone:
            for param in self.resnet18.parameters():
                param.requires_grad_(False)

        # Size the head from the backbone's output width instead of hard-coding 1000.
        in_features = self.resnet18.fc.out_features
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return head logits (last dim = num_classes) for an image batch.

        Assumes ``x`` is an (N, 3, H, W) batch as produced by the notebook's
        transform (TODO confirm for other inputs).
        """
        x = self.resnet18(x)
        return self.classifier(x)


# Only the classifier head is given to the optimizer, so freeze the backbone to
# skip its gradient computation. Size the head from the class folders that
# ImageFolder found, so logits always match the label range.
net = ResnetTL(num_classes=len(train.classes), freeze_backbone=True)
net = net.to(device)

In [None]:
# Total number of scalar parameters (backbone + head, trainable or not).
sum(map(torch.Tensor.numel, net.parameters()))

In [None]:
import torch.optim as optim

# Cross-entropy on raw logits. Only the classifier head's parameters are handed
# to Adam, so this optimizer never updates the pretrained ResNet-18 weights.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.classifier.parameters(), lr=1e-3)

In [None]:
# Emit one progress line every LOGGING_FREQ mini-batches during training.
LOGGING_FREQ = 100

In [None]:
# Train the classifier head. Loss/accuracy are logged every LOGGING_FREQ
# mini-batches (averaged over that window), and accuracy is also reported
# once per epoch.
NUM_EPOCHS = 20

net.train()  # enable dropout in the head (and batch statistics in BatchNorm layers)
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    epoch_correct = 0
    running_correct = 0
    running_loss = 0.0
    samples_epoch = 0
    samples_running = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients, then forward + backward + optimize
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit. Convert the counts to
        # Python numbers so the accumulators are not left as device tensors.
        batch_correct = (outputs.argmax(dim=1) == labels).sum().item()
        n_batch = labels.size(0)
        epoch_correct += batch_correct
        running_correct += batch_correct
        samples_epoch += n_batch
        samples_running += n_batch
        running_loss += loss.item()

        if i % LOGGING_FREQ == LOGGING_FREQ - 1:
            # Percentage over the samples seen since the last log line. The
            # multiplier is 100, independent of LOGGING_FREQ.
            accuracy = 100.0 * running_correct / samples_running
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOGGING_FREQ:.3f}, accuracy: {accuracy:.2f}%')

            running_correct = 0
            samples_running = 0
            running_loss = 0.0

    if samples_epoch:
        epoch_accuracy = 100.0 * epoch_correct / samples_epoch
        print(f'[{epoch + 1}] epoch accuracy: {epoch_accuracy:.2f}%')

print('Finished Training')