In [1]:
!pip install pytorch-metric-learning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import os
import json
import zipfile
import subprocess
import shutil
import getpass
import math
import numpy
import pandas
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image,ImageReadMode
import matplotlib.pyplot as plt
import plotly.express as px
from pytorch_metric_learning import losses, regularizers
from torchsummary import summary

In [3]:
torch.manual_seed(20)

<torch._C.Generator at 0x7f0f39d99010>

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
print(f"Using {device} device")

Using cuda device


In [5]:
dataset_save_dir = './dataset'

In [9]:
def one_hot_encode(val):
    arr = numpy.zeros((6,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
  age_floor = int(age)
  if age_floor >= 0 and age_floor <= 5: return 0
  elif age_floor >= 6 and age_floor <= 12: return 1
  elif age_floor >= 13 and age_floor <= 19: return 2
  elif age_floor >= 20 and age_floor <= 29: return 3
  elif age_floor >= 30 and age_floor <= 59: return 4
  else: return 5

def get_ground_truth(age):
  return one_hot_encode(get_bucket_id(age))

In [10]:
def train(dataloader, model, optimizer, loss_function, loss_optimizer):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    num = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Zeroing the gradients
        optimizer.zero_grad()
        loss_optimizer.zero_grad()

        # Get Embeddings
        embeddings = model(X)

        # Calculate Loss
        loss = loss_function(embeddings, y.argmax(1))

        # Backpropagation
        loss.backward()

        # Update
        optimizer.step()
        loss_optimizer.step()

        loss_tot += loss.item()
        num += 1
        
        X.cpu()
        y.cpu()

    # loss_tot /= num
    print(f'training loss: {(loss_tot):>0.5f}')

In [11]:
validation_accuracy = []
current_max_val_acc = 0.0
def validation(dataloader, model, loss_function):
    global current_max_val_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    correct = 0
    totalsize = 0
    loss_tot = 0.0
    num = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # Forward With Logits
            embeddings = model(X)        
            logits = loss_function.get_logits(embeddings)
            predSoftmax = nn.Softmax(dim=1)(logits)
            loss = loss_function(embeddings,y.argmax(1))
            
            correct += (predSoftmax.argmax(1) == y.argmax(1)).sum().item()
            totalsize += predSoftmax.shape[0]
            loss_tot += loss.item()
            num += 1
            X.cpu()
            y.cpu()
            
    print(f"Correct/Total: {correct}/{totalsize}")
    correct /= totalsize
    validation_accuracy.append(correct*100)
    print(f"Validation Loss:  {(loss_tot):>0.5f}")
    print(f"Validation Accuracy: {(100*correct):>0.5f}%\n")
    current_max_val_acc = max(current_max_val_acc,100*correct)
    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot

In [12]:
class XRayToothDataset(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        img_filename = os.listdir(self.dataset_path)[idx]
        age = float(img_filename.split("_")[1][:-4])
        age_gt = get_ground_truth(age)
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image 
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

In [13]:
# Data Augmentation Transformations 
data_augmentation_transformations = T.RandomChoice([
    T.RandomAffine(degrees=0), # No Augmentation
    T.Lambda(lambda x: TF.hflip(img=x)) # Horizontal Flip

    # Geometric Transformations:
    # T.RandomAffine(degrees=0, scale=(1.3,1.3)), # Scale
    # T.RandomAffine(degrees=0, translate=(0.5,0.5)), # Translate
    # T.RandomAffine(degrees=(-8, 8)), # Rotate
    # T.Lambda(lambda x: TF.hflip(img=x)), # Reflect

    # Occlusion:
    # T.Compose([T.RandomErasing(p=1, scale=(0.0008, 0.0008), ratio=(1,1))]*100), # Occlusion

    # Intensity Operations
    # T.Lambda(lambda x: TF.adjust_gamma(img=x, gamma=0.5)), # Gamma Contrast
    # T.Lambda(lambda x: TF.adjust_contrast(x, contrast_factor=2.0)), # Linear Contrast

    # Filtering:
    # T.Lambda(lambda x: TF.adjust_sharpness(img=x, sharpness_factor=4)), #Sharpen
    # T.GaussianBlur(kernel_size=(15,15), sigma=(0.01, 1)), # Gaussian Blur
])  

In [14]:
training_data = XRayToothDataset(os.getcwd(), img_dir=dataset_save_dir+'/training', transform=data_augmentation_transformations, target_height=384, target_width=384)
validation_data = XRayToothDataset(os.getcwd(), img_dir=dataset_save_dir+'/validation', transform=None, target_height=384, target_width=384)

In [15]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
            
            nn.Conv2d(in_channels=256, out_channels=280, kernel_size=2),
            nn.ReLU(),

            nn.Flatten(),
        )

        self.fc = nn.Sequential(
            nn.Linear(2520,1024),
            nn.ReLU(),
            nn.Linear(1024,512)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.fc(x)
        return x

In [16]:
model = NeuralNetwork().to(device)

In [17]:
summary(model, (3, 384, 384))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 383, 383]             416
              ReLU-2         [-1, 32, 383, 383]               0
         MaxPool2d-3         [-1, 32, 127, 127]               0
            Conv2d-4         [-1, 64, 126, 126]           8,256
              ReLU-5         [-1, 64, 126, 126]               0
         MaxPool2d-6           [-1, 64, 42, 42]               0
            Conv2d-7          [-1, 128, 41, 41]          32,896
              ReLU-8          [-1, 128, 41, 41]               0
         MaxPool2d-9          [-1, 128, 13, 13]               0
           Conv2d-10          [-1, 256, 12, 12]         131,328
             ReLU-11          [-1, 256, 12, 12]               0
        MaxPool2d-12            [-1, 256, 4, 4]               0
           Conv2d-13            [-1, 280, 3, 3]         287,000
             ReLU-14            [-1, 28

In [18]:
# Test a forward pass
with torch.no_grad():
    model.eval()
    data = training_data[0][0].reshape(-1,3,384,384).to(device)
    label = training_data[0][1].reshape(-1,6).to(device)
    embed = model(data)
    print(embed)

tensor([[ 1.4981e-02,  1.0995e-02, -2.3925e-02, -1.3316e-03, -2.2172e-02,
         -2.4104e-02, -2.9600e-02,  5.2513e-03,  7.5408e-03, -1.3752e-02,
          4.8989e-04,  4.5686e-02, -1.7035e-03,  6.2442e-03,  1.8329e-02,
          5.0660e-02,  6.4145e-03,  1.0992e-03, -1.7290e-02, -3.1199e-02,
          2.0574e-02,  3.2510e-02, -1.6959e-02,  1.5877e-02, -1.0236e-02,
          7.3419e-03,  1.3815e-02,  1.7355e-02, -2.9637e-03, -6.7780e-03,
         -9.3905e-03, -6.9001e-03, -1.6044e-02,  4.7284e-03, -6.3638e-03,
          3.3319e-02, -1.1262e-02,  3.0464e-02, -2.6038e-02,  2.2370e-02,
         -9.9354e-03,  2.2141e-02,  2.3058e-02, -1.7059e-02,  2.6607e-02,
         -2.8548e-02, -7.6964e-03,  5.6386e-03,  9.3916e-03,  6.8018e-03,
          2.5219e-02, -3.5884e-02,  1.7105e-03, -1.4890e-03, -8.5228e-03,
         -2.2192e-02,  8.3172e-03,  2.8158e-02, -2.4753e-02, -5.3097e-03,
          6.2231e-03, -2.1612e-02, -1.9070e-02,  1.0484e-02,  9.4734e-03,
         -1.1675e-02, -3.3699e-02,  1.

In [19]:
# Training Hyperparameters
epochs = 100
batch_size = 100
learning_rate = 1e-2
margin_loss_learning_rate = 1e-2
momentum=0.9
weight_decay=1e-2

In [20]:
training_data_loader = DataLoader(training_data, batch_size, shuffle = True)
validation_data_loader = DataLoader(validation_data, batch_size, shuffle = False)

In [21]:
# Loss
margin_loss_regularizer = regularizers.RegularFaceRegularizer()
margin_loss_function = losses.ArcFaceLoss(6, 512, margin=34.3, scale=1, weight_regularizer=margin_loss_regularizer).to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_optimizer = torch.optim.SGD(margin_loss_function.parameters(), lr=margin_loss_learning_rate)

# Schedulers
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5, min_lr=1e-8,verbose=True)
margin_loss_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5, min_lr=1e-8,verbose=True)

In [22]:
# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_data_loader, model, optimizer, margin_loss_function, loss_optimizer)
    val_loss = validation(validation_data_loader, model, margin_loss_function)
    scheduler.step(val_loss)
    # margin_loss_scheduler.step(val_loss)
print("Done!")

Epoch 1
-------------------------------
training loss: 6.38184
Correct/Total: 67/129
Validation Loss:  3.99504
Validation Accuracy: 51.93798%

Current Best Validation Accuracy: 51.93798%

Epoch 2
-------------------------------
training loss: 5.82019
Correct/Total: 67/129
Validation Loss:  3.94047
Validation Accuracy: 51.93798%

Current Best Validation Accuracy: 51.93798%

Epoch 3
-------------------------------
training loss: 5.76879
Correct/Total: 67/129
Validation Loss:  3.92588
Validation Accuracy: 51.93798%

Current Best Validation Accuracy: 51.93798%

Epoch 4
-------------------------------
training loss: 5.74991
Correct/Total: 67/129
Validation Loss:  3.92495
Validation Accuracy: 51.93798%

Current Best Validation Accuracy: 51.93798%

Epoch 5
-------------------------------
training loss: 5.73735
Correct/Total: 67/129
Validation Loss:  3.91809
Validation Accuracy: 51.93798%

Current Best Validation Accuracy: 51.93798%

Epoch 6
-------------------------------
training loss: 5.728