In [None]:
!pip install pytorch-metric-learning
!pip install torch-summary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import json
import zipfile
import subprocess
import shutil
import getpass
import math
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image,ImageReadMode
import matplotlib.pyplot as plt
from pytorch_metric_learning import losses, regularizers
from torchsummary import summary

In [None]:
# Seed PyTorch's global random number generator so repeated runs draw the same values.
SEED = 20
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f9d70241b50>

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (Assign "cpu" unconditionally here to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# Dataset root, relative to the current working directory; the '/training' and
# '/validation' sub-folders are appended to it when the datasets are created.
dataset_save_dir = './dataset'

In [None]:
def one_hot_encode(val, num_classes=6):
    """Return a one-hot integer vector with a 1 at position ``val``.

    Args:
        val: Index of the active class. NumPy indexing rules apply, so an
            index >= ``num_classes`` raises ``IndexError`` and a negative
            index counts from the end.
        num_classes: Length of the returned vector. Defaults to 6, the number
            of age buckets produced by ``get_bucket_id``.

    Returns:
        numpy.ndarray of shape ``(num_classes,)`` with integer dtype.
    """
    arr = numpy.zeros((num_classes,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
  """Map an age in years to its age-bucket index.

  The age is truncated to an integer with ``int()`` and assigned to one of six
  buckets: 0-5 -> 0, 6-12 -> 1, 13-19 -> 2, 20-29 -> 3, 30-59 -> 4, 60+ -> 5.

  Args:
      age: Age in years (anything ``int()`` accepts).

  Returns:
      int bucket index in ``range(6)``.

  Raises:
      ValueError: If the truncated age is negative. Previously such values
          fell through to the last branch and were labelled as the 60+ bucket.
  """
  age_floor = int(age)
  if age_floor < 0:
    raise ValueError(f"age must be non-negative, got {age!r}")

  # Inclusive upper bound of every bucket except the open-ended last one.
  upper_bounds = (5, 12, 19, 29, 59)
  for bucket_id, upper_bound in enumerate(upper_bounds):
    if age_floor <= upper_bound:
      return bucket_id
  return len(upper_bounds)

def get_ground_truth(age):
  """Return the one-hot vector of the age bucket that ``age`` falls into."""
  bucket_id = get_bucket_id(age)
  return one_hot_encode(bucket_id)

In [None]:
def get_random_two_different_int(low=0, high=6, size=1):
  """Draw two distinct integers from ``[low, high)`` using torch's global RNG.

  Args:
      low: Inclusive lower bound.
      high: Exclusive upper bound.
      size: Length of the tensor requested from ``torch.randint``. The result
          is converted with ``.item()``, which only works for a
          single-element tensor, so only the default ``size=1`` is usable.

  Returns:
      Tuple ``(num1, num2)`` of Python ints with ``num1 != num2``.

  Raises:
      ValueError: If the range holds fewer than two integers; the re-draw
          loop could otherwise never terminate.
  """
  if high - low < 2:
    raise ValueError(
        f"need at least two integers in [low, high), got low={low}, high={high}")
  num1 = torch.randint(low, high, (size,)).item()
  num2 = torch.randint(low, high, (size,)).item()
  # Re-draw the second number until it differs from the first.
  while num1 == num2:
    num2 = torch.randint(low, high, (size,)).item()
  return num1, num2

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch using batch-hard triplet mining.

    For each batch, ``model.forward_once`` produces embeddings. Every sample
    acts as an anchor and is paired with its farthest same-label sample
    (possibly itself) and its nearest different-label sample, measured by
    Euclidean distance between those embeddings. The three embedding sets go
    through ``model.forward_only_fc`` and the triplet loss is back-propagated.

    Batches whose samples all share one label have no negatives to mine and
    are skipped (the per-sample ``argmin`` used to raise on them).

    Args:
        dataloader: Iterable of ``(X, y)`` batches that exposes ``.dataset``.
        model: Module providing ``forward_once`` and ``forward_only_fc``.
        loss_fn: Callable ``loss_fn(anchor, positive, negative)``.
        optimizer: Optimizer that updates the model parameters.

    Prints the loss of every 4th batch and, at the end, the SUM of the batch
    losses (not their mean).
    """
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        embeddings = model.forward_once(X)

        # same_label[i, j] is True when samples i and j carry the same label.
        same_label = y.unsqueeze(0) == y.unsqueeze(1)
        if same_label.all():
            # Single-class (or empty) batch: no negative exists for any anchor.
            continue

        # Mining only selects indices, so it needs no autograd bookkeeping.
        with torch.no_grad():
            pairwise_distances = torch.cdist(embeddings, embeddings)
            # Hide the other group with +/-inf so the row-wise argmax/argmin
            # only ever picks among positives / negatives respectively.
            farthest_pos_index = pairwise_distances.masked_fill(
                ~same_label, float("-inf")).argmax(dim=1)
            nearest_neg_index = pairwise_distances.masked_fill(
                same_label, float("inf")).argmin(dim=1)

        embedding_anchor = embeddings
        embedding_pos = embeddings[farthest_pos_index]
        embedding_neg = embeddings[nearest_neg_index]

        # Forward
        optimizer.zero_grad()
        ea, ep, en = model.forward_only_fc(embedding_anchor, embedding_pos, embedding_neg)

        loss = loss_fn(ea, ep, en)
        loss.backward()
        optimizer.step()

        loss_tot += loss.item()

        # Gather data and report
        if batch % 4 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

    print(f'training loss: {(loss_tot):>0.5f}')

In [None]:
# Triplet accuracy (in percent) of every validation pass, in call order.
validation_accuracy = []
# Best triplet accuracy (in percent) seen so far across validation passes.
current_max_val_acc = 0.0
def validation(dataloader, model, loss_fn):
    """Evaluate the model on mined triplets and report loss and accuracy.

    Triplets are mined per batch exactly as in training: each anchor is paired
    with its farthest same-label sample and its nearest different-label
    sample, by Euclidean distance between ``model.forward_once`` embeddings.
    A triplet counts as correct when, after ``model.forward_only_fc``, the
    anchor is at least as far from the negative as from the positive.

    Batches whose samples all share one label have no negatives and are
    skipped. If no triplet is scored at all, the accuracy is reported as 0
    instead of dividing by zero.

    Side effects: appends the accuracy (percent) to ``validation_accuracy``,
    updates the global ``current_max_val_acc``, and prints a summary.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.
        model: Module providing ``forward_once`` and ``forward_only_fc``.
        loss_fn: Callable ``loss_fn(anchor, positive, negative)``.

    Returns:
        float: SUM of the batch losses (not their mean).
    """
    global current_max_val_acc
    model.eval()
    correct = 0
    totalsize = 0
    loss_tot = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            embeddings = model.forward_once(X)

            # same_label[i, j] is True when samples i and j carry the same label.
            same_label = y.unsqueeze(0) == y.unsqueeze(1)
            if same_label.all():
                # Single-class (or empty) batch: no negative exists for any anchor.
                continue

            pairwise_distances = torch.cdist(embeddings, embeddings)
            # Hide the other group with +/-inf so the row-wise argmax/argmin
            # only ever picks among positives / negatives respectively.
            farthest_pos_index = pairwise_distances.masked_fill(
                ~same_label, float("-inf")).argmax(dim=1)
            nearest_neg_index = pairwise_distances.masked_fill(
                same_label, float("inf")).argmin(dim=1)

            # Forward
            ea, ep, en = model.forward_only_fc(
                embeddings, embeddings[farthest_pos_index], embeddings[nearest_neg_index])

            distances = torch.nn.functional.pairwise_distance(ea, en) - torch.nn.functional.pairwise_distance(ea, ep)
            predictions = (distances >= 0).float()

            correct += torch.sum(predictions).item()
            totalsize += ea.shape[0]

            loss_tot += loss_fn(ea, ep, en).item()

    print(f"Correct/Total: {correct}/{totalsize}")
    # Guard against an empty loader or a run in which every batch was skipped.
    accuracy = correct / totalsize if totalsize else 0.0
    validation_accuracy.append(accuracy*100)
    print(f"Validation Loss:  {(loss_tot):>0.5f}")
    print(f"Validation Accuracy: {(100*accuracy):>0.5f}%\n")
    current_max_val_acc = max(current_max_val_acc,100*accuracy)
    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot

In [None]:
# Training-time augmentation: RandomChoice picks one of the listed transforms
# for every image it is applied to.
_unchanged_affine = T.RandomAffine(degrees=0)           # affine with no rotation/translation/scale/shear configured
_horizontal_flip = T.Lambda(lambda x: TF.hflip(img=x))  # mirror the image left-right

data_augmentation_transformations = T.RandomChoice([_unchanged_affine, _horizontal_flip])

# Transforms that were tried and are currently disabled (kept for reference):
#   Geometric:
#     T.RandomAffine(degrees=0, scale=(1.3,1.3))        # Scale
#     T.RandomAffine(degrees=0, translate=(0.5,0.5))    # Translate
#     T.RandomAffine(degrees=(-8, 8))                   # Rotate
#     (shearing / skewing skipped: they don't make sense for teeth X-rays)
#   Occlusion (was listed three times):
#     T.Compose([T.RandomErasing(p=1, scale=(0.0008, 0.0008), ratio=(1,1))]*100)
#   Intensity:
#     T.Lambda(lambda x: TF.adjust_gamma(img=x, gamma=0.5))            # Gamma contrast
#     T.Lambda(lambda x: TF.adjust_contrast(x, contrast_factor=2.0))   # Linear contrast
#     (histogram equalisation skipped: needs a uint8 typecast;
#      noise injection skipped: keeps later normalisation simple)
#   Filtering:
#     T.Lambda(lambda x: TF.adjust_sharpness(img=x, sharpness_factor=4))  # Sharpen
#     T.GaussianBlur(kernel_size=(15,15), sigma=(0.01, 1))                # Gaussian blur

In [None]:
class XRayToothDataset(Dataset):
    """Image dataset whose labels are age buckets parsed from the file names.

    Every file in the dataset folder is treated as one sample. File names are
    expected to look like ``<prefix>_<age>.<extension>``; the age is converted
    to a bucket index with ``get_bucket_id``.

    Args:
        cwd: Base directory.
        img_dir: Folder of images, relative to ``cwd``.
        transform: Optional callable applied to the (1, 3, H, W) image tensor
            after resizing.
        target_height, target_width: When both are set, images are resized to
            this size with ``torch.nn.functional.interpolate``.
    """

    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width
        # Listed lazily on first use so construction still does no file-system access.
        self._filenames = None

    def _image_filenames(self):
        """Return the cached, sorted list of file names in the dataset folder.

        ``os.listdir`` returns entries in arbitrary order; listing once and
        sorting keeps the index -> file mapping stable and avoids re-scanning
        the directory on every access.
        """
        if self._filenames is None:
            self._filenames = sorted(os.listdir(self.dataset_path))
        return self._filenames

    def __len__(self):
        return len(self._image_filenames())

    def __getitem__(self, idx):
        """Return ``(image, bucket_id)`` for the sample at ``idx``.

        The image is a float32 tensor of shape (3, H, W), min-max scaled to
        [0, 1]; ``bucket_id`` is a 0-dim tensor.

        Raises:
            IndexError: If ``idx`` is past the end of the dataset (previously
                ``None`` was returned, which cannot be collated into a batch).
        """
        filenames = self._image_filenames()
        if idx >= len(filenames):
            raise IndexError("No datafile/image at index : " + str(idx))
        img_filename = filenames[idx]
        # Strip the extension whatever its length, then take the age field.
        age = float(os.path.splitext(img_filename)[0].split("_")[1])
        age_gt = get_bucket_id(age)
        # Request RGB explicitly so the 3-channel reshape below also holds for
        # grayscale or RGBA files.
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename, mode=ImageReadMode.RGB)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        # Min-max normalisation; a constant image would otherwise give 0/0 = NaN.
        lowest = image_tensor.min()
        value_range = image_tensor.max() - lowest
        if value_range == 0:
            image_tensor = torch.zeros_like(image_tensor, dtype=torch.float32)
        else:
            image_tensor = (image_tensor - lowest) / value_range
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

In [None]:
# Both splits are resized to 224x224; only the training split is augmented.
_dataset_root = os.getcwd()
_resize_kwargs = {"target_height": 224, "target_width": 224}
training_data = XRayToothDataset(
    _dataset_root,
    img_dir=dataset_save_dir + '/training',
    transform=data_augmentation_transformations,
    **_resize_kwargs,
)
validation_data = XRayToothDataset(
    _dataset_root,
    img_dir=dataset_save_dir + '/validation',
    transform=None,
    **_resize_kwargs,
)

In [None]:
from torchvision.models import vit_l_32, ViT_L_32_Weights

# ViT-L/32 initialised with its ImageNet-1k (V1) weights; used as the backbone of the network below.
_vit_weights = ViT_L_32_Weights.IMAGENET1K_V1
pretrained_vit = vit_l_32(weights=_vit_weights)

In [None]:
class NeuralNetwork(nn.Module):
    """Triplet embedding network: a frozen backbone followed by a trainable head.

    The backbone's parameters have ``requires_grad`` switched off, so only the
    head ``fc`` (Dropout -> Linear(1000, 512) -> Dropout -> Linear(512, 6)) is
    trained. The backbone must therefore emit 1000 features per sample.

    Args:
        backbone: Optional feature extractor module. When omitted, the
            notebook-level ``pretrained_vit`` is used, as before.
    """

    def __init__(self, backbone=None):
        super().__init__()
        # Allow a backbone to be injected (e.g. a small stand-in for tests)
        # instead of always depending on the notebook global.
        self.backbone = pretrained_vit if backbone is None else backbone

        # Freeze the backbone: it is used as a fixed feature extractor.
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.fc = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(1000,512),
            nn.Dropout(0.7),
            nn.Linear(512,6)
        )

    def forward_once(self, x):
        """Return the backbone features for a batch of inputs."""
        return self.backbone(x)

    def forward_only_fc(self, ea, ep, en):
        """Apply the head to already-computed anchor/positive/negative features."""
        ea = self.fc(ea)
        ep = self.fc(ep)
        en = self.fc(en)

        return ea, ep, en

    def forward(self, xA, xP, xN):
        """Embed an (anchor, positive, negative) triplet of raw input batches."""
        ea = self.fc(self.forward_once(xA))
        ep = self.fc(self.forward_once(xP))
        en = self.fc(self.forward_once(xN))

        return ea, ep, en

In [None]:
# Build the network, then move it to the selected device.
model = NeuralNetwork()
model = model.to(device)

In [None]:
# Training hyperparameters
epochs, batch_size = 600, 50   # number of epochs / samples per mini-batch
learning_rate = 0.01           # initial optimizer step size
# NOTE(review): `momentum` is not passed to the Adam optimizer created below — confirm it is still needed.
momentum, weight_decay = 0.9, 0.05

In [None]:
# Only the training split is shuffled; validation batches keep a fixed order.
training_data_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True)
validation_data_loader = DataLoader(dataset=validation_data, batch_size=batch_size, shuffle=False)

In [None]:
# Triplet loss with a margin of 5 between anchor-positive and anchor-negative distances.
loss_function = nn.TripletMarginLoss(margin=5.0)

# Adam over the model's parameters, with weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Multiply the learning rate by 0.9 (never below 1e-4) once the monitored
# value has stopped decreasing for more than 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=5,
    min_lr=1e-4,
    verbose=True,
)

In [65]:
# Training
# Each epoch: one pass over the training set, one validation pass, then the
# scheduler is stepped on the validation loss.
for t in range(epochs):
    epoch_header = f"Epoch {t+1}\n-------------------------------"
    print(epoch_header)
    train(dataloader=training_data_loader, model=model, loss_fn=loss_function, optimizer=optimizer)
    val_loss = validation(dataloader=validation_data_loader, model=model, loss_fn=loss_function)
    scheduler.step(val_loss)
    # Checkpointing is currently disabled:
    # torch.save(model, 'model.pth')
print("Done!")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m

Current Best Validation Accuracy: 32.55814%

Epoch 149
-------------------------------
loss: 4.818784  [   50/  296]
loss: 4.944316  [  250/  296]
training loss: 25.18595
Correct/Total: 26.0/129
Validation Loss:  17.10046
Validation Accuracy: 20.15504%

Current Best Validation Accuracy: 32.55814%

Epoch 150
-------------------------------
loss: 5.155196  [   50/  296]
loss: 5.029342  [  250/  296]
training loss: 27.26602
Correct/Total: 30.0/129
Validation Loss:  17.07662
Validation Accuracy: 23.25581%

Current Best Validation Accuracy: 32.55814%

Epoch 00150: reducing learning rate of group 0 to 1.0942e-03.
Epoch 151
-------------------------------
loss: 2.581809  [   50/  296]
loss: 3.566315  [  250/  296]
training loss: 24.41674
Correct/Total: 35.0/129
Validation Loss:  16.82063
Validation Accuracy: 27.13178%

Current Best Validation Accuracy: 32.55814%

Epoch 152
-------------------------------
loss: 5.687599  [   50/