In [1]:
!pip install pytorch-metric-learning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import os
import json
import zipfile
import subprocess
import shutil
import getpass
import math
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image,ImageReadMode
import matplotlib.pyplot as plt
from pytorch_metric_learning import losses, regularizers
from torchsummary import summary

In [3]:
torch.manual_seed(20)

<torch._C.Generator at 0x7fd2ec0b90f0>

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
print(f"Using {device} device")

Using cuda device


In [5]:
dataset_save_dir = './dataset'

In [9]:
def one_hot_encode(val):
    arr = numpy.zeros((6,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
  age_floor = int(age)
  if age_floor >= 0 and age_floor <= 5: return 0
  elif age_floor >= 6 and age_floor <= 12: return 1
  elif age_floor >= 13 and age_floor <= 19: return 2
  elif age_floor >= 20 and age_floor <= 29: return 3
  elif age_floor >= 30 and age_floor <= 59: return 4
  else: return 5

def get_ground_truth(age):
  return one_hot_encode(get_bucket_id(age))

In [10]:
def get_random_two_different_int(low=0, high=6, size=1):
  num1 = torch.randint(low,high, (size,)).item()
  num2 = torch.randint(low,high, (size,)).item()
  while num1 == num2: num2 = torch.randint(low,high, (size,)).item()
  return num1,num2

In [11]:
def train(dataloader, model, loss_fn, optimizer):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    num = 0
    for batch, (X1, y1, X2, y2) in enumerate(dataloader):
        X1, y1, X2, y2 = X1.to(device), y1, X2.to(device), y2
        targets = torch.eq(y1.argmax(dim=1), y2.argmax(dim=1)).to(torch.float32).to(device)

        # Forward
        optimizer.zero_grad()
        outputs = model(X1, X2)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_tot += loss.item()
        num += 1
        
        X1.cpu()
        X2.cpu()
        targets.cpu()

        # Gather data and report
        if batch % 4 == 0:
            current = (batch + 1) * len(X1)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

    # loss_tot /= num
    print(f'training loss: {(loss_tot):>0.5f}')

In [12]:
validation_accuracy = []
current_max_val_acc = 0.0
def validation(dataset, model, loss_fn):
    model.eval()
    global current_max_val_acc
    size = len(dataset)
    correct = 0
    total = 0
    loss_tot = 0
    with torch.no_grad():
        for i in range(size):
            XQ, yQ = dataset[i]
            XQ, yQ = XQ.reshape(-1,3,224,224).to(device), yQ.to(device)
            best = [-1000,-1000,-1000,-1000,-1000,-1000]
            rem = [1,1,1,1,1,1]
            anchors = []
            for j in range(size):
                if i == j: continue
                classLabel = (dataset[j][1]).argmax(dim=0)
                if rem[classLabel] <= 0: continue
                anchors.append(j)
                rem[classLabel] = rem[classLabel] - 1
                if sum(rem) <= 0: break 

            for j in anchors:
                XR, yR = dataset[j]
                classLabel = yR.argmax(dim=0)
                XR, yR = XR.reshape(-1,3,224,224).to(device), yR.to(device)
                rem[classLabel] = rem[classLabel] - 1
                output = model(XQ,XR)
                targets = torch.eq(yQ.argmax(dim=0), yR.argmax(dim=0)).to(torch.float32).reshape(-1).to(device)
                loss = loss_fn(output, targets)
                best[classLabel] = max(best[classLabel], output.cpu().numpy())
                loss_tot += loss.item() 
            classified_as = numpy.argmax(best)
            correct += (classified_as == yQ.argmax(dim=0).cpu())
            total = total + 1
            
    print(f"Correct/Total: {correct}/{total}")
    correct = correct*1.0 / total
    validation_accuracy.append(correct*100)
    print(f"Validation Accuracy: {(100*correct):>0.5f}%\n")
    current_max_val_acc = max(current_max_val_acc,100*correct)
    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot

In [13]:
data_augmentation_transformations = T.RandomChoice([ # Geometric Transformation
    T.RandomAffine(degrees=0),
    T.Lambda(lambda x: TF.hflip(img=x))
    # T.RandomAffine(degrees=0), # No Transformation
    # Geometric Transformations:
    # T.RandomAffine(degrees=0, scale=(1.3,1.3)), # Scale
    # T.RandomAffine(degrees=0, translate=(0.5,0.5)), # Translate
    # T.RandomAffine(degrees=(-8, 8)), # Rotate
    # T.Lambda(lambda x: TF.hflip(img=x)), # Reflect
    # Skipping Shearing & Skewing as they don't make sense in this context of Teeth X-Ray
    # Occlusion:
    # T.Compose([T.RandomErasing(p=1, scale=(0.0008, 0.0008), ratio=(1,1))]*100), # Occlusion
    # T.Compose([T.RandomErasing(p=1, scale=(0.0008, 0.0008), ratio=(1,1))]*100), # Occlusion
    # T.Compose([T.RandomErasing(p=1, scale=(0.0008, 0.0008), ratio=(1,1))]*100), # Occlusion
    # Intensity Operations
    # T.Lambda(lambda x: TF.adjust_gamma(img=x, gamma=0.5)), # Gamma Contrast
    # T.Lambda(lambda x: TF.adjust_contrast(x, contrast_factor=2.0)), # Linear Contrast
    # Histogram Equalizer skipped as we need to typecast it to uint8 for that
    # Skipping Noise injection as we want to easily normalize it later 
    # Filtering:
    # T.Lambda(lambda x: TF.adjust_sharpness(img=x, sharpness_factor=4)), #Sharpen
    # T.GaussianBlur(kernel_size=(15,15), sigma=(0.01, 1)), # Gaussian Blur
])  

In [14]:
class XRayToothDatasetPosNeg(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

        # Group the examples based on their label
        self.grouped_examples = {}
        for filename in os.listdir(self.dataset_path):
            age_bracket = get_bucket_id(float(filename.split("_")[1][:-4]))
            if age_bracket not in self.grouped_examples: self.grouped_examples[age_bracket] = []
            self.grouped_examples[age_bracket].append(filename)

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def get_datasample(self, img_filename):
        age = float(img_filename.split("_")[1][:-4])
        age_gt = get_ground_truth(age)
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image 
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None

        if idx%2 == 0: # Give them a positive example
            selected_class = torch.randint(0,6,(1,)).item()
            index1,index2 = get_random_two_different_int(0,len(self.grouped_examples[selected_class]),1)
            image1, label1 = self.get_datasample(self.grouped_examples[selected_class][index1])
            image2, label2 = self.get_datasample(self.grouped_examples[selected_class][index2])
            return (image1,label1,image2,label2)
        else:
            # Give them a negative example
            selected_class1, selected_class2 = get_random_two_different_int(0,6,1)
            index1 = torch.randint(0,len(self.grouped_examples[selected_class1]),(1,)).item()
            index2 = torch.randint(0,len(self.grouped_examples[selected_class2]),(1,)).item()
            image1, label1 = self.get_datasample(self.grouped_examples[selected_class1][index1])
            image2, label2 = self.get_datasample(self.grouped_examples[selected_class2][index2])
            return (image1,label1,image2,label2)
        

In [15]:
class XRayToothDataset(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        img_filename = os.listdir(self.dataset_path)[idx]
        age = float(img_filename.split("_")[1][:-4])
        age_gt = get_ground_truth(age)
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image 
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

In [16]:
training_data = XRayToothDatasetPosNeg(os.getcwd(), img_dir=dataset_save_dir+'/training', transform=data_augmentation_transformations, target_height=224, target_width=224)
validation_data = XRayToothDataset(os.getcwd(), img_dir=dataset_save_dir+'/validation', transform=None, target_height=224, target_width=224)

In [17]:
from torchvision.models import vit_l_32, ViT_L_32_Weights

pretrained_vit = vit_l_32(weights=ViT_L_32_Weights.IMAGENET1K_V1)

Downloading: "https://download.pytorch.org/models/vit_l_32-c7638314.pth" to /root/.cache/torch/hub/checkpoints/vit_l_32-c7638314.pth
100%|██████████| 1.14G/1.14G [00:13<00:00, 89.4MB/s]


In [18]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = pretrained_vit

        for param in self.backbone.parameters():
            param.requires_grad = False

        self.fc = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(1000,6)
        )

        self.sigmoid = nn.Sigmoid()
    
    def forward(self, xR, xQ):
        e1 = self.fc(self.backbone(xR))
        e2 = self.fc(self.backbone(xQ))

        similarity = torch.linalg.vecdot(e1, e2)
        output = self.sigmoid(similarity)
        
        return output

In [19]:
model = NeuralNetwork().to(device)

In [20]:
# Test a forward pass
image1,label1,image2,label2 = training_data[0]
with torch.no_grad():
    print(model(image1.reshape(-1,3,224,224).to(device), image2.reshape(-1,3,224,224).to(device)))

tensor([0.3061], device='cuda:0')


In [21]:
# Training Hyperparameters
epochs = 500
batch_size = 10
learning_rate = 1e-4
momentum=0.9
weight_decay=0.05

In [22]:
training_data_loader = DataLoader(training_data, batch_size, shuffle = True)
validation_data_loader = DataLoader(validation_data, batch_size, shuffle = False)

In [23]:
loss_function=nn.BCELoss()
optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=5, min_lr=1e-4,verbose=True)

In [None]:
# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_data_loader, model, loss_function, optimizer)
    val_loss = validation(validation_data, model, loss_function)
    scheduler.step(val_loss)
    # torch.save(model, 'model.pth')
print("Done!")

Epoch 1
-------------------------------
loss: 1.335649  [   10/  296]
loss: 1.563321  [   50/  296]
loss: 1.619494  [   90/  296]
loss: 0.663186  [  130/  296]
loss: 1.432854  [  170/  296]
loss: 0.483766  [  210/  296]
loss: 0.612344  [  250/  296]
loss: 0.877295  [  290/  296]
training loss: 29.29094
Correct/Total: 16/129
Validation Accuracy: 12.40310%

Current Best Validation Accuracy: 12.40310%

Epoch 2
-------------------------------
loss: 0.783023  [   10/  296]
loss: 0.423807  [   50/  296]
loss: 0.700526  [   90/  296]
loss: 0.942276  [  130/  296]
loss: 0.624152  [  170/  296]
loss: 1.184766  [  210/  296]
loss: 1.323326  [  250/  296]
loss: 0.501082  [  290/  296]
training loss: 28.05611
Correct/Total: 15/129
Validation Accuracy: 11.62791%

Current Best Validation Accuracy: 12.40310%

Epoch 3
-------------------------------
loss: 1.720139  [   10/  296]
loss: 0.952585  [   50/  296]
loss: 1.073041  [   90/  296]
loss: 0.411474  [  130/  296]
loss: 0.515103  [  170/  296]
loss

In [None]:
validation(validation_data, model, loss_function)

In [None]:
plt.plot(validation_accuracy)