In [None]:
!pip install pytorch-metric-learning

In [None]:
import copy
import getpass
import json
import math
import os
import random
import shutil
import subprocess
import zipfile

import matplotlib.pyplot as plt
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pytorch_metric_learning import losses, regularizers
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision.io import read_image,ImageReadMode
from torchvision.models.vision_transformer import VisionTransformer

In [None]:
# Seed every RNG the notebook can draw from, not just torch:
#   * random - torchvision's RandomChoice picks its transform with Python's
#              `random` module, so the augmentation is only reproducible when
#              this generator is seeded as well.
#   * numpy  - seeded for completeness; numpy is used for bookkeeping here.
#   * torch  - weight initialisation, dropout and DataLoader shuffling.
SEED = 20
random.seed(SEED)
numpy.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Dedicated CPU generator with its own fixed seed (2147483647 == 2**31 - 1).
# get_random_two_different_int() draws from it, so those draws do not depend
# on torch's global RNG state.
random_generator = torch.Generator(device='cpu')
# manual_seed() returns the generator, so the cell also displays it.
random_generator.manual_seed(2147483647)

In [None]:
# Root folder of the image dataset; the dataset cell appends '/training' and
# '/validation' to it.
dataset_save_dir = './dataset'

# File Paths
# Checkpoint written by validation() whenever the validation accuracy improves.
save_current_best_model_path = './triplet_classifier_model.pth'
# History of validation accuracies (in percent), saved together with the checkpoint.
save_validation_accuracy_path = './triplet_classifier_model_val_acc.npy'

In [None]:
def one_hot_encode(val, num_classes=6):
    """Return a one-hot integer vector of length ``num_classes``.

    Args:
        val: Index of the class to mark, ``0 <= val < num_classes``.
        num_classes: Length of the returned vector. Defaults to 6, the number
            of age buckets produced by ``get_bucket_id``.

    Returns:
        A 1-D ``numpy`` integer array holding 1 at position ``val`` and 0
        everywhere else.

    Raises:
        IndexError: If ``val`` is outside ``[0, num_classes)``. Negative values
            are rejected explicitly; numpy would otherwise wrap them around
            and silently mark a class counted from the end.
    """
    if not 0 <= val < num_classes:
        raise IndexError(
            f"class index {val} is out of range for {num_classes} classes"
        )
    arr = numpy.zeros((num_classes,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
    """Map an age to the id (0-5) of the age bucket it falls into.

    The age is truncated with ``int()`` before the lookup. Buckets 0-4 cover
    the inclusive ranges 0-5, 6-12, 13-19, 20-29 and 30-59; any other value
    (60 and above, or a negative truncated age) lands in bucket 5.
    """
    age_floor = int(age)
    # Inclusive (youngest, oldest) bounds, listed in bucket-id order.
    bucket_bounds = ((0, 5), (6, 12), (13, 19), (20, 29), (30, 59))
    for bucket_id, (youngest, oldest) in enumerate(bucket_bounds):
        if youngest <= age_floor <= oldest:
            return bucket_id
    return 5

def get_ground_truth(age):
    """Return the one-hot encoded age bucket for ``age``."""
    bucket_id = get_bucket_id(age)
    return one_hot_encode(bucket_id)

In [None]:
def get_random_two_different_int(low=0, high=6, size=1):
    """Draw two distinct integers from ``[low, high)``.

    Both numbers come from the module-level ``random_generator``; the second
    one is redrawn until it differs from the first.

    Args:
        low: Inclusive lower bound.
        high: Exclusive upper bound. The range must hold at least two values.
        size: Number of elements sampled per draw. Each draw is converted with
            ``.item()``, which only accepts one-element tensors, so only the
            default of 1 is usable.

    Returns:
        A tuple ``(num1, num2)`` of Python ints with ``num1 != num2``.

    Raises:
        ValueError: If ``[low, high)`` holds fewer than two integers. Without
            this check a one-value range would make the redraw loop spin
            forever.
    """
    if high - low < 2:
        raise ValueError(
            f"need at least two integers in [{low}, {high}) to draw two different values"
        )

    def draw():
        return torch.randint(low, high, (size,), generator=random_generator).item()

    num1 = draw()
    num2 = draw()
    while num1 == num2:
        num2 = draw()
    return num1, num2

In [None]:
def save_model_checkpoint(model, path):
    """Write the model's ``state_dict`` to ``path``.

    The file holds a dict with the single key ``'model_state_dict'``.
    """
    checkpoint = {'model_state_dict': model.state_dict()}
    torch.save(checkpoint, path)

def deep_copy_model(model_to_copy):
    """Return a fully independent deep copy of ``model_to_copy``."""
    return copy.deepcopy(model_to_copy)

In [None]:
def _mine_hardest_triplet_indices(embeddings, labels):
    """Return, per anchor, the indices of its hardest positive and negative.

    The hardest positive is the farthest sample with the same label (the
    anchor itself counts as a candidate, at distance 0); the hardest negative
    is the closest sample with a different label. Entries that do not qualify
    are replaced by -2e8 / 2e8 so argmax / argmin skip them; an anchor with no
    negative in the batch therefore gets whichever index argmin returns for an
    all-equal row.
    """
    # Only indices are needed, so no autograd graph is built for the
    # batch x batch distance matrix.
    with torch.no_grad():
        dist_matrix = torch.cdist(embeddings, embeddings)
        same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
        positive_idx = dist_matrix.masked_fill(~same_label, -2e8).argmax(dim=1)
        negative_idx = dist_matrix.masked_fill(same_label, 2e8).argmin(dim=1)
    return positive_idx, negative_idx


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch with in-batch hard triplet mining.

    Every batch is embedded once, the hardest positive and negative of each
    sample are selected inside the batch, and ``loss_fn`` is called with the
    logits, the mined triplets and the labels, using weights of 1.0 for both
    loss terms.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches of inputs and labels.
        model: Module exposing ``get_embedding`` and ``get_logit``.
        loss_fn: Callable taking ``(logits, anchors, positives, negatives,
            labels, scale1, scale2)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The loss summed over all batches of the epoch (the value that is also
        printed), mirroring what ``validation`` returns.
    """
    # Release cached blocks once per epoch. The per-step empty_cache() calls
    # were dropped: they do not give PyTorch itself more memory and force the
    # allocator to request it again on every step.
    # NOTE(review): restore them if fragmentation-related OOMs reappear.
    torch.cuda.empty_cache()
    model.train()
    loss_tot = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Find all embeddings of current batch
        embeddings = model.get_embedding(X)
        del X  # drop this reference to the input batch before mining

        # Mining the farthest positive and closest negative
        positive_idx, negative_idx = _mine_hardest_triplet_indices(embeddings, y)
        positives = embeddings[positive_idx]
        negatives = embeddings[negative_idx]

        # Finding logits of current batch
        logits = model.get_logit(embeddings)

        # Backpropagation
        optimizer.zero_grad()
        loss = loss_fn(logits, embeddings, positives, negatives, y, 1.0, 1.0)
        loss.backward()
        optimizer.step()

        loss_tot += loss.item()
    print(f'training loss: {(loss_tot):>0.5f}')
    return loss_tot

In [None]:
# Validation accuracy (in percent) of every validation() call, in call order.
validation_accuracy = []
# Best validation accuracy (in percent) seen so far; updated by validation().
current_max_val_acc = 0.0


def validation(dataloader, model, loss_fn, dont_log=False):
    """Evaluate the model on a validation loader and keep the best checkpoint.

    Accumulates the classification loss and counts correct predictions over
    all batches, appends the accuracy (in percent) to ``validation_accuracy``
    and, when the accuracy beats ``current_max_val_acc`` and ``dont_log`` is
    falsy, saves the model and the accuracy history to the configured paths.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module exposing ``eval``, ``get_embedding``, ``get_logit`` and
            ``get_softmax``.
        loss_fn: Callable invoked as ``loss_fn(logits=..., labels=...)``.
        dont_log: When truthy, the best accuracy is not updated and nothing is
            saved. The accuracy is still appended to the history.

    Returns:
        The loss summed over all validation batches.
    """
    global current_max_val_acc
    model.eval()
    correct = 0
    totalsize = 0
    loss_tot = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            embeddings = model.get_embedding(X)
            logits = model.get_logit(embeddings)
            probabilities = model.get_softmax(logits)
            loss = loss_fn(logits=logits, labels=y)
            correct += (probabilities.argmax(1) == y).sum().item()
            totalsize += X.shape[0]
            loss_tot += loss.item()
            # Drop the batch tensors now so they are not kept alive while the
            # next batch is being moved to the device.
            del X, y, embeddings, logits, probabilities, loss

    print(f"Correct/Total: {correct}/{totalsize}")
    if totalsize == 0:
        # An empty loader used to raise ZeroDivisionError; report 0% instead.
        print("Warning: validation dataloader yielded no samples")
        accuracy = 0.0
    else:
        accuracy = correct / totalsize
    accuracy_percent = 100 * accuracy
    validation_accuracy.append(accuracy_percent)
    print(f"Validation Loss:  {(loss_tot):>0.5f}")
    print(f"Validation Accuracy: {(accuracy_percent):>0.5f}%\n")
    print(f"Total images: {totalsize}")
    if accuracy_percent > current_max_val_acc and not dont_log:
        current_max_val_acc = accuracy_percent
        save_model_checkpoint(model, save_current_best_model_path)
        numpy.save(save_validation_accuracy_path, numpy.array(validation_accuracy), allow_pickle=True, fix_imports=True)
        print("Saved")

    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot

In [None]:
# Geometric augmentation: each call applies exactly one of the two transforms,
# picked at random by RandomChoice.
# NOTE(review): RandomAffine with only degrees=0 (no translate/scale/shear)
# looks like it leaves the image unchanged, i.e. "no flip" - confirm intent.
data_augmentation_transformations = T.RandomChoice(
    transforms=[
        T.RandomAffine(degrees=0),
        T.Lambda(TF.hflip),  # horizontal flip
    ]
)

In [None]:
class XRayToothDataset(Dataset):
    """Image dataset whose age label is encoded in each file name.

    Every file in ``cwd + '/' + img_dir`` is treated as one sample. The age is
    the second ``_``-separated field of the file name without its extension
    (e.g. ``patient_12.5.png`` -> 12.5) and is mapped to a bucket id with
    ``get_bucket_id``.

    Args:
        cwd: Base directory.
        img_dir: Image folder, relative to ``cwd``.
        transform: Optional callable applied to the ``(1, 3, H, W)`` image
            tensor after resizing.
        target_height: Output height; resizing happens only when both target
            dimensions are set.
        target_width: Output width.
    """

    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width
        # Filled on first access, so constructing the dataset does not touch
        # the disk.
        self._image_filenames = None

    def _list_images(self):
        """Return the sorted file names of the dataset folder (listed once)."""
        if self._image_filenames is None:
            # os.listdir() gives no ordering guarantee; sorting and caching
            # keeps index -> file stable and avoids one listing per sample.
            self._image_filenames = sorted(os.listdir(self.dataset_path))
        return self._image_filenames

    def __len__(self):
        return len(self._list_images())

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a float32 tensor of shape ``(3, H, W)`` min-max scaled to
        [0, 1]; the label is a 0-dim tensor holding the age bucket id.

        Raises:
            IndexError: If ``idx`` is out of range. Returning ``None`` here
                would only fail later, when the DataLoader collates the batch.
        """
        filenames = self._list_images()
        if idx >= len(filenames) or idx < -len(filenames):
            raise IndexError("No datafile/image at index : " + str(idx))
        img_filename = filenames[idx]
        # Strip the real extension instead of assuming it is 4 characters long.
        stem = os.path.splitext(img_filename)[0]
        age = float(stem.split("_")[1])
        age_gt = get_bucket_id(age)
        # Always decode to 3 channels so grayscale or RGBA files fit the
        # 3-channel layout used below.
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename, mode=ImageReadMode.RGB)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height, self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        # Min-max scale to [0, 1]; a constant image becomes all zeros instead
        # of NaN (0 / 0).
        image_tensor = image_tensor.to(torch.float32)
        lowest, highest = image_tensor.min(), image_tensor.max()
        value_range = highest - lowest
        if value_range > 0:
            image_tensor = (image_tensor - lowest) / value_range
        else:
            image_tensor = torch.zeros_like(image_tensor)
        return image_tensor.reshape(-1, image_tensor.shape[-2], image_tensor.shape[-1]), torch.tensor(age_gt)

In [None]:
# Both splits are resized to 224x224; only the training split is augmented.
training_data = XRayToothDataset(
    os.getcwd(),
    img_dir=dataset_save_dir+'/training',
    transform=data_augmentation_transformations,
    target_height=224,
    target_width=224,
)
validation_data = XRayToothDataset(
    os.getcwd(),
    img_dir=dataset_save_dir+'/validation',
    transform=None,
    target_height=224,
    target_width=224,
)

In [None]:
from torchvision.models import vit_l_32, ViT_L_32_Weights

# torchvision's ViT-L/32 with the IMAGENET1K_V1 weights; NeuralNetwork uses
# this model as its backbone.
vit_weights = ViT_L_32_Weights.IMAGENET1K_V1
pretrained_vit = vit_l_32(weights=vit_weights)

In [None]:
class NeuralNetwork(nn.Module):
    """Frozen pretrained backbone with a trainable embedding head and classifier.

    The backbone's ``heads.head`` layer is replaced by a new linear layer that
    projects to the embedding space; a dropout + linear classifier turns
    embeddings into class logits.

    Args:
        backbone: Module exposing ``heads.head`` as a linear layer. Defaults
            to the module-level ``pretrained_vit``.
        embedding_dim: Size of the embedding produced by the new head.
        num_classes: Number of output classes of the classifier.
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(self, backbone=None, embedding_dim=256, num_classes=6, dropout=0.6):
        super().__init__()
        self.backbone = pretrained_vit if backbone is None else backbone

        # Freeze the pretrained weights first ...
        for param in self.backbone.parameters():
            param.requires_grad = False

        # ... then attach the new head. Its freshly created parameters have
        # requires_grad=True, so the head is the only trainable part of the
        # backbone. (Assigning `requires_grad = True` on the Linear module, as
        # done before, only sets a plain attribute and leaves the head frozen.)
        head_in_features = self.backbone.heads.head.in_features
        self.backbone.heads.head = nn.Linear(head_in_features, embedding_dim)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, num_classes)
        )

    def get_embedding(self, x): # Embedding of only single image
        """Return the backbone embeddings for the batch ``x``."""
        return self.backbone(x)

    def get_logit(self, ea):
        """Return the class logits for the embeddings ``ea``."""
        return self.classifier(ea)

    def get_softmax(self, logit):
        """Return class probabilities: softmax over dimension 1 of ``logit``."""
        return torch.softmax(logit, dim=1)

    def forward(self, xA, xP, xN): # Get three embeddings
        """Embed an (anchor, positive, negative) triplet of batches."""
        ea = self.get_embedding(xA)
        ep = self.get_embedding(xP)
        en = self.get_embedding(xN)
        return ea, ep, en

In [None]:
# Build the network, then move its parameters to the selected device.
model = NeuralNetwork()
model = model.to(device)

In [None]:
# Training hyperparameters
epochs = 10000         # number of passes of the training loop
batch_size = 300       # samples per batch for both data loaders
learning_rate = 1e-2   # initial learning rate handed to the optimizer
momentum = 0.9         # NOTE(review): not referenced by the Adam optimizer cell
weight_decay = 0.9     # weight decay handed to the optimizer

In [None]:
# Training batches are reshuffled every epoch; validation keeps a fixed order.
training_data_loader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)
validation_data_loader = DataLoader(
    dataset=validation_data,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class CustomNLLoss(nn.Module):
    """Cross-entropy loss, optionally combined with a triplet margin loss."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, anchors=None, positives=None, negatives=None, labels=None, scale1=None, scale2=None):
        """Compute the (optionally combined) loss.

        Args:
            logits: Class logits passed to the cross-entropy term.
            anchors: Anchor embeddings. When ``None`` only the cross-entropy
                loss is returned.
            positives: Positive embeddings for the triplet term.
            negatives: Negative embeddings for the triplet term.
            labels: Target class indices for the cross-entropy term.
            scale1: Weight of the cross-entropy term; ``None`` means 1.0.
            scale2: Weight of the triplet term; ``None`` means 1.0.

        Returns:
            ``scale1 * cross_entropy + scale2 * triplet_margin_loss`` with a
            margin of 1.0, or the plain cross-entropy loss when no anchors
            are given.
        """
        loss_c = torch.nn.functional.cross_entropy(logits, labels)
        # Identity check: `anchors == None` would go through the tensor
        # comparison operator.
        if anchors is None:
            return loss_c
        loss_t = torch.nn.functional.triplet_margin_loss(anchors, positives, negatives, margin=1.0)
        # The weights default to None in the signature; treat that as 1.0
        # instead of failing on `None * tensor`.
        weight_c = 1.0 if scale1 is None else scale1
        weight_t = 1.0 if scale2 is None else scale2
        return weight_c * loss_c + weight_t * loss_t

In [None]:
# Combined cross-entropy + triplet loss.
loss_function = CustomNLLoss()
# Adam over all model parameters, with the configured learning rate and weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Multiply the learning rate by 0.9 when the monitored value has not decreased
# for 5 scheduler steps, never going below 1e-4.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=5,
    min_lr=1e-4,
    verbose=True,
)

In [None]:
# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_data_loader, model, loss_function, optimizer)
    if t < 100:
        # Warm-up: the first 100 epochs run without validation and without
        # stepping the learning-rate scheduler.
        continue
    val_loss = validation(validation_data_loader, model, loss_function)
    scheduler.step(val_loss)
print("Done!")