In [1]:
!pip install pytorch-metric-learning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-2.1.2-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-2.1.2


In [3]:
import os
import json
import zipfile
import subprocess
import shutil
import getpass
import math
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image,ImageReadMode
import matplotlib.pyplot as plt
from pytorch_metric_learning import losses, regularizers
from torchsummary import summary
from torchvision.models.vision_transformer import VisionTransformer
import copy

In [4]:
torch.manual_seed(20)

<torch._C.Generator at 0x7f5f5054a710>

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
print(f"Using {device} device")

Using cuda device


In [6]:
random_generator = torch.Generator(device='cpu')
random_generator.manual_seed(2147483647)

<torch._C.Generator at 0x7f5e89f0afb0>

In [7]:
dataset_save_dir = './dataset'

# File Paths
save_current_best_model_path = './triplet_classifier_model.pth'
save_validation_accuracy_path = './triplet_classifier_model_val_acc.npy'

In [11]:
def one_hot_encode(val):
    arr = numpy.zeros((6,), dtype=int)
    arr[val] = 1
    return arr

def get_bucket_id(age):
  age_floor = int(age)
  if age_floor >= 0 and age_floor <= 5: return 0
  elif age_floor >= 6 and age_floor <= 12: return 1
  elif age_floor >= 13 and age_floor <= 19: return 2
  elif age_floor >= 20 and age_floor <= 29: return 3
  elif age_floor >= 30 and age_floor <= 59: return 4
  else: return 5

def get_ground_truth(age):
  return one_hot_encode(get_bucket_id(age))

In [12]:
def get_random_two_different_int(low=0, high=6, size=1):
  num1 = torch.randint(low,high, (size,), generator=random_generator).item()
  num2 = torch.randint(low,high, (size,), generator=random_generator).item()
  while num1 == num2: num2 = torch.randint(low,high, (size,), generator=random_generator).item()
  return num1,num2

In [13]:
def save_model_checkpoint(model, path):
    torch.save({
      'model_state_dict': model.state_dict()
    }, path)

def deep_copy_model(model_to_copy):
    model_copy = copy.deepcopy(model_to_copy)
    return model_copy

In [79]:
def train(dataloader, model, loss_fn, optimizer):
    torch.cuda.empty_cache()
    size = len(dataloader.dataset)
    model.train()
    loss_tot = 0.0
    num = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Find all embeddings of current batch
        embeddings = model.get_embedding(X)
        del X
        torch.cuda.empty_cache()

        # Mining the farthest positive and closest negative
        dist_matrix = torch.cdist(embeddings,embeddings)
        mask_positive = (y.unsqueeze(1) == y.unsqueeze(0)).float()
        dist_for_only_positive = ((1-mask_positive)*-2e8)+mask_positive*dist_matrix
        positives = embeddings[dist_for_only_positive.argmax(dim=1)]
        del dist_for_only_positive
        torch.cuda.empty_cache()
        dist_for_only_negative = (mask_positive*2e8)+(1-mask_positive)*dist_matrix
        del mask_positive
        del dist_matrix
        torch.cuda.empty_cache()
        negatives = embeddings[dist_for_only_negative.argmin(dim=1)]
        del dist_for_only_negative
        torch.cuda.empty_cache()

        # Finding logits of current batch
        logits = model.get_logit(embeddings)

        # Backpropagation
        optimizer.zero_grad()
        loss = loss_fn(logits, embeddings, positives, negatives, y, 1.0, 1.0)
        del logits
        del positives
        del negatives
        del y
        del embeddings
        torch.cuda.empty_cache()
        loss.backward()
        optimizer.step()

        loss_tot += loss.item()
        num += 1
    print(f'training loss: {(loss_tot):>0.5f}')
    return loss_tot

In [80]:
validation_accuracy = []
current_max_val_acc = 0.0
def validation(dataloader, model, loss_fn, dont_log = False):
    global current_max_val_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    correct = 0
    totalsize = 0
    loss_tot = 0.0
    num = 0
    cnt = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            validation_size = X.shape[0]

            # Get all the embeddings; make sure to set batchsize max
            embeddings = model.get_embedding(X)
            anchors = embeddings.unsqueeze(1).expand(-1, validation_size, -1).reshape(-1, embeddings.size(1))
            versus = embeddings.repeat(validation_size,1)
            del embeddings
            torch.cuda.empty_cache()
            # similarities = model.get_similarity(anchors, versus).reshape(validation_size, -1).argmax(dim=1)
            similarities = model.get_similarity(anchors, versus).reshape(validation_size, -1)
            mask = torch.ones((validation_size,validation_size))
            indices = torch.arange(validation_size)
            similarities[indices, indices] = -2e8 # assign -INF for same samples
            similarities = similarities.argmax(dim=1)
            del anchors
            del versus
            torch.cuda.empty_cache()
            predicted_labels = y[similarities]
            del similarities
            torch.cuda.empty_cache()
            num_equal_indices = torch.nonzero(predicted_labels == y).squeeze().numel()
            del predicted_labels
            torch.cuda.empty_cache()

            print(num_equal_indices)

            correct += num_equal_indices
            totalsize += X.shape[0]
            cnt += X.shape[0]
            num += 1
            del X
            del y
            torch.cuda.empty_cache()

    print(f"Correct/Total: {correct}/{totalsize}")
    correct /= totalsize
    validation_accuracy.append(correct*100)
    # print(f"Validation Loss:  {(loss_tot):>0.5f}")
    print(f"Validation Accuracy: {(100*correct):>0.5f}%\n")
    print(f"Total images: {cnt}")
    if 100*correct > current_max_val_acc and dont_log == False:
        current_max_val_acc = 100*correct
        save_model_checkpoint(model, save_current_best_model_path)
        numpy.save(save_validation_accuracy_path, numpy.array(validation_accuracy), allow_pickle=True, fix_imports=True)
        print("Saved")

    print(f"Current Best Validation Accuracy: {(current_max_val_acc):>0.5f}%\n")
    return loss_tot

In [81]:
data_augmentation_transformations = T.RandomChoice([ # Geometric Transformation
    T.RandomAffine(degrees=0),
    T.Lambda(lambda x: TF.hflip(img=x))
])

In [82]:
class XRayToothDataset(Dataset):
    def __init__(self, cwd, img_dir, transform=None, target_height=None, target_width=None):
        self.dataset_path = cwd + '/' + img_dir
        self.transform = transform
        self.target_height = target_height
        self.target_width = target_width

    def __len__(self):
        return len(os.listdir(self.dataset_path))

    def __getitem__(self, idx):
        if idx  >= len(os.listdir(self.dataset_path)):
            print("No datafile/image at index : "+ str(idx))
            return None
        img_filename = os.listdir(self.dataset_path)[idx]
        age = float(img_filename.split("_")[1][:-4])
        age_gt = get_bucket_id(age)
        image_tensor = read_image(path=self.dataset_path + '/' + img_filename)
        image_tensor = image_tensor.reshape(1, 3, image_tensor.shape[-2], image_tensor.shape[-1])
        if self.target_height and self.target_width: # Resize the image
            image_tensor = torch.nn.functional.interpolate(image_tensor, (self.target_height,self.target_width))
        if self.transform: image_tensor = self.transform(image_tensor) # Apply transformations
        image_tensor = (image_tensor-image_tensor.min())/(image_tensor.max()-image_tensor.min())
        return image_tensor.reshape(-1,image_tensor.shape[-2],image_tensor.shape[-1]).to(torch.float32), torch.tensor(age_gt)

In [83]:
training_data = XRayToothDataset(os.getcwd(), img_dir=dataset_save_dir+'/training', transform=data_augmentation_transformations, target_height=224, target_width=224)
validation_data = XRayToothDataset(os.getcwd(), img_dir=dataset_save_dir+'/validation', transform=None, target_height=224, target_width=224)

In [84]:
from torchvision.models import vit_l_32, ViT_L_32_Weights

pretrained_vit = vit_l_32(weights=ViT_L_32_Weights.IMAGENET1K_V1)

In [85]:
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

pretrained_effnet = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

In [86]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = pretrained_vit
        self.backbone2 = pretrained_effnet
        for param in self.backbone1.parameters():
            param.requires_grad = False

        for param in self.backbone2.parameters():
            param.requires_grad = False

        self.fc = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(2000,256)
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(256, 6)
        )

    def get_embedding(self, x): # Embedding of only single image
        x1 = self.backbone1(x)
        x2 = self.backbone2(x)
        x = torch.cat((x1,x2), 1)
        e = self.fc(x)
        return e

    def get_logit(self, ea):
        logit = self.classifier(ea)
        return logit

    def get_softmax(self, logit):
        probabilities = nn.Softmax(dim=1)(logit)
        return probabilities

    def get_similarity(self, e1, e2):
        return nn.CosineSimilarity(dim=1)(e1, e2)

    def forward(self, xA, xP, xN): # Get three embeddings
        ea = self.get_embedding(xA)
        ep = self.get_embedding(xP)
        en = self.get_embedding(xN)
        return ea, ep, en

In [87]:
model = NeuralNetwork().to(device)

In [88]:
# Training Hyperparameters
epochs = 10000
batch_size = 1000
learning_rate = 1e-2
momentum=0.9
weight_decay=0.9

In [89]:
training_data_loader = DataLoader(training_data, batch_size, shuffle = True)
validation_data_loader = DataLoader(validation_data, 129, shuffle = False)

In [90]:
class CustomNLLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, anchors, positives, negatives, labels, scale1, scale2):
        loss_c = torch.nn.functional.cross_entropy(logits, labels)
        loss_t = nn.TripletMarginLoss(margin=10.0)(anchors,positives,negatives)
        loss_total = scale1*loss_c + scale2*loss_t
        return loss_total

In [91]:
loss_function=CustomNLLoss()
optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=5, min_lr=1e-4,verbose=True)

In [None]:
# Training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train(training_data_loader, model, loss_function, optimizer)
    if t >= 100: # 20 warmup epochs
        validation(validation_data_loader, model, loss_function)
        scheduler.step(train_loss)
print("Done!")

Epoch 1
-------------------------------
training loss: 18.57322
tensor([  6,  80, 115, 113,  21,  56,  72, 109,  80, 117,  49,   8,  31,  30,
        122,   4,  21, 102,  96, 124,  55,   2, 115, 126,  34,  20,  39,  95,
         13,  91,  13,  62, 117, 101,  68,  58, 122, 124,  67,  13,  46, 113,
         64,  40,  68, 102,  64,  58,  69, 114, 103, 113, 113,  44,  97,  20,
         97,  74,  35, 115,  74,  35,  31,  31,  46, 100,   9,  63,  34,  31,
        101,  58,   6,  99,  89,  91,  71,  50,  13,  61,   8,  98,   8,  70,
         31,  35, 113,  15,  89, 113, 111,  23,  36,  55, 122,  27,  21,  54,
         21, 102,  30,  70, 126,   8,  23,  35, 113,  13, 114,   0, 115,  90,
        118, 122,  49,   2,  35,  55, 112,  30, 117,  99, 113,  63,  31,  23,
         54,  64,   2], device='cuda:0')
62
Correct/Total: 62/129
Validation Accuracy: 48.06202%

Total images: 129
Saved
Current Best Validation Accuracy: 48.06202%

Epoch 2
-------------------------------
